use exo_core::types::{Did, Hash256, Signature, Timestamp};
use sqlx::PgPool;
use crate::{
dag::DagNode,
error::{DagError, Result},
store::DagStore,
};
/// Convert any displayable failure (sqlx errors, formatted messages, ...)
/// into [`DagError::StoreError`], keeping only its rendered text.
fn store_err<E: std::fmt::Display>(err: E) -> DagError {
    let message = err.to_string();
    DagError::StoreError(message)
}
/// [`DagStore`] implementation backed by a PostgreSQL connection pool.
///
/// DAG nodes are stored in the `dag_nodes` table and commit markers in the
/// `dag_committed` table. Cloning the store is cheap and shares the same pool.
#[derive(Clone, Debug)]
pub struct PostgresStore {
    /// Pool every query of this store is executed against.
    pool: PgPool,
}
impl PostgresStore {
pub async fn new(pool: PgPool) -> Result<Self> {
Ok(Self { pool })
}
pub async fn migrate(pool: &PgPool) -> Result<()> {
sqlx::migrate!("./migrations")
.run(pool)
.await
.map_err(store_err)?;
Ok(())
}
fn encode_parents(parents: &[Hash256]) -> Vec<Vec<u8>> {
parents.iter().map(|h| h.as_bytes().to_vec()).collect()
}
fn decode_hash256(bytes: &[u8], column: &str) -> Result<Hash256> {
let arr = <[u8; 32]>::try_from(bytes).map_err(|_| {
store_err(format!(
"invalid dag_nodes.{column}: expected 32 bytes, got {}",
bytes.len()
))
})?;
Ok(Hash256::from_bytes(arr))
}
fn decode_parents(raw: &[Vec<u8>]) -> Result<Vec<Hash256>> {
raw.iter()
.enumerate()
.map(|(idx, bytes)| Self::decode_hash256(bytes, &format!("parents[{idx}]")))
.collect()
}
fn decode_tip_hashes(rows: Vec<(Vec<u8>,)>) -> Result<Vec<Hash256>> {
rows.into_iter()
.map(|(bytes,)| Self::decode_hash256(&bytes, "tips.hash"))
.collect()
}
fn encode_signature(sig: &Signature) -> Result<Vec<u8>> {
serde_json::to_vec(sig)
.map_err(|e| store_err(format!("failed to serialize DAG node signature: {e}")))
}
fn decode_signature(bytes: &[u8]) -> Result<Signature> {
serde_json::from_slice(bytes)
.map_err(|e| store_err(format!("invalid DAG node signature encoding: {e}")))
}
fn decode_timestamp(physical_ms: i64, logical: i64) -> Result<Timestamp> {
let physical_ms = u64::try_from(physical_ms).map_err(|_| {
store_err(format!(
"invalid dag_nodes.ts_physical_ms: expected non-negative value, got {physical_ms}"
))
})?;
let logical = u32::try_from(logical).map_err(|_| {
store_err(format!(
"invalid dag_nodes.ts_logical: expected u32-compatible value, got {logical}"
))
})?;
Ok(Timestamp::new(physical_ms, logical))
}
fn encode_timestamp_physical_ms(timestamp: Timestamp) -> Result<i64> {
Self::encode_u64_as_bigint(timestamp.physical_ms, "dag_nodes.ts_physical_ms")
}
fn decode_committed_height(height: i64) -> Result<u64> {
u64::try_from(height).map_err(|_| {
store_err(format!(
"invalid dag_committed.height: expected non-negative value, got {height}"
))
})
}
fn encode_height(height: u64, column: &str) -> Result<i64> {
Self::encode_u64_as_bigint(height, column)
}
fn encode_u64_as_bigint(value: u64, column: &str) -> Result<i64> {
i64::try_from(value).map_err(|_| {
store_err(format!(
"invalid {column}: value {value} exceeds PostgreSQL BIGINT maximum {}",
i64::MAX
))
})
}
}
#[async_trait::async_trait]
impl DagStore for PostgresStore {
    /// Load the node whose `hash` column equals `hash`, or `Ok(None)` when no
    /// row matches.
    ///
    /// # Errors
    ///
    /// Returns `DagError::StoreError` if the query fails or any stored column
    /// cannot be decoded back into its `DagNode` field.
    async fn get(&self, hash: &Hash256) -> Result<Option<DagNode>> {
        // Column order matches the SELECT list below.
        type NodeRow = (
            Vec<u8>,      // hash
            Vec<Vec<u8>>, // parents
            Vec<u8>,      // payload_hash
            String,       // creator_did
            i64,          // ts_physical_ms
            i64,          // ts_logical
            Vec<u8>,      // signature
        );

        let maybe_row: Option<NodeRow> = sqlx::query_as(
            "SELECT hash, parents, payload_hash, creator_did, ts_physical_ms, ts_logical, signature
FROM dag_nodes WHERE hash = $1",
        )
        .bind(hash.as_bytes().as_slice())
        .fetch_optional(&self.pool)
        .await
        .map_err(store_err)?;

        let Some((stored_hash, stored_parents, stored_payload, creator, physical, logical, sig)) =
            maybe_row
        else {
            return Ok(None);
        };

        Ok(Some(DagNode {
            hash: Self::decode_hash256(&stored_hash, "hash")?,
            parents: Self::decode_parents(&stored_parents)?,
            payload_hash: Self::decode_hash256(&stored_payload, "payload_hash")?,
            creator_did: Did::new(&creator)
                .map_err(|e| store_err(format!("invalid DID: {e}")))?,
            timestamp: Self::decode_timestamp(physical, logical)?,
            signature: Self::decode_signature(&sig)?,
        }))
    }

    /// Insert `node`.
    ///
    /// A row that already exists under the same hash is left untouched
    /// (`ON CONFLICT (hash) DO NOTHING`).
    ///
    /// # Errors
    ///
    /// Returns `DagError::StoreError` if the signature cannot be serialized,
    /// the physical timestamp exceeds BIGINT range, or the INSERT fails.
    async fn put(&mut self, node: DagNode) -> Result<()> {
        let encoded_parents = Self::encode_parents(&node.parents);
        let encoded_signature = Self::encode_signature(&node.signature)?;
        let ts_physical = Self::encode_timestamp_physical_ms(node.timestamp)?;
        let ts_logical = i64::from(node.timestamp.logical);

        sqlx::query(
            "INSERT INTO dag_nodes (hash, parents, payload_hash, creator_did, ts_physical_ms, ts_logical, signature)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (hash) DO NOTHING",
        )
        .bind(node.hash.as_bytes().as_slice())
        .bind(&encoded_parents)
        .bind(node.payload_hash.as_bytes().as_slice())
        .bind(node.creator_did.as_str())
        .bind(ts_physical)
        .bind(ts_logical)
        .bind(&encoded_signature)
        .execute(&self.pool)
        .await
        .map(|_| ())
        .map_err(store_err)
    }

    /// Report whether a node with `hash` is stored.
    ///
    /// # Errors
    ///
    /// Returns `DagError::StoreError` if the query fails.
    async fn contains(&self, hash: &Hash256) -> Result<bool> {
        let (exists,): (bool,) =
            sqlx::query_as("SELECT EXISTS(SELECT 1 FROM dag_nodes WHERE hash = $1)")
                .bind(hash.as_bytes().as_slice())
                .fetch_one(&self.pool)
                .await
                .map_err(store_err)?;
        Ok(exists)
    }

    /// List nodes that no stored node names as a parent, ordered by `hash`.
    ///
    /// # Errors
    ///
    /// Returns `DagError::StoreError` if the query fails or a returned hash
    /// is not 32 bytes wide.
    async fn tips(&self) -> Result<Vec<Hash256>> {
        let tip_rows: Vec<(Vec<u8>,)> = sqlx::query_as(
            "SELECT hash FROM dag_nodes dn
WHERE NOT EXISTS (
SELECT 1 FROM dag_nodes other
WHERE dn.hash = ANY(other.parents)
)
ORDER BY hash",
        )
        .fetch_all(&self.pool)
        .await
        .map_err(store_err)?;
        Self::decode_tip_hashes(tip_rows)
    }

    /// Highest recorded commit height, or 0 when nothing is committed.
    ///
    /// # Errors
    ///
    /// Returns `DagError::StoreError` if the query fails or the stored
    /// maximum is negative.
    async fn committed_height(&self) -> Result<u64> {
        let (max_height,): (i64,) =
            sqlx::query_as("SELECT COALESCE(MAX(height), 0) FROM dag_committed")
                .fetch_one(&self.pool)
                .await
                .map_err(store_err)?;
        Self::decode_committed_height(max_height)
    }

    /// Record (or overwrite) the commit height of an existing node.
    ///
    /// Existence check and write happen in one INSERT ... SELECT ... WHERE
    /// EXISTS statement, so zero affected rows means the node is absent.
    ///
    /// # Errors
    ///
    /// Returns `DagError::NodeNotFound` when no `dag_nodes` row has `hash`,
    /// and `DagError::StoreError` on BIGINT overflow or query failure.
    async fn mark_committed(&mut self, hash: &Hash256, height: u64) -> Result<()> {
        let db_height = Self::encode_height(height, "dag_committed.height")?;
        let outcome = sqlx::query(
            "INSERT INTO dag_committed (hash, height)
SELECT $1, $2
WHERE EXISTS (SELECT 1 FROM dag_nodes WHERE hash = $1)
ON CONFLICT (hash) DO UPDATE SET height = EXCLUDED.height",
        )
        .bind(hash.as_bytes().as_slice())
        .bind(db_height)
        .execute(&self.pool)
        .await
        .map_err(store_err)?;
        match outcome.rows_affected() {
            0 => Err(DagError::NodeNotFound(*hash)),
            _ => Ok(()),
        }
    }
}
#[cfg(test)]
// `await_holding_lock` is allowed deliberately: each `#[tokio::test]` runs on
// its own current-thread runtime, so holding the Postgres test lock across
// `.await` only blocks that test's own thread while the holder keeps running.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::await_holding_lock)]
mod tests {
    use super::*;
    use crate::{
        dag::{Dag, DeterministicDagClock, append},
        store::MemoryStore,
    };

    type SignFn = Box<dyn Fn(&[u8]) -> Signature>;

    /// Deterministic fake signer: the BLAKE3 digest of the data fills the
    /// first 32 bytes of a 64-byte signature; the rest stays zero.
    fn make_sign_fn() -> SignFn {
        Box::new(|data: &[u8]| {
            let h = blake3::hash(data);
            let mut sig = [0u8; 64];
            sig[..32].copy_from_slice(h.as_bytes());
            Signature::from_bytes(sig)
        })
    }

    /// Append a single genesis node to a fresh DAG with a deterministic clock.
    fn make_test_node() -> DagNode {
        let mut dag = Dag::new();
        let mut clock = DeterministicDagClock::new();
        let creator = Did::new("did:exo:test").expect("valid");
        let sign_fn = make_sign_fn();
        append(&mut dag, &[], b"genesis", &creator, &*sign_fn, &mut clock).unwrap()
    }

    /// Serialize the Postgres-backed tests.
    ///
    /// They all share the database behind `DATABASE_URL` and each starts by
    /// deleting every row, so running them in parallel (the libtest default)
    /// lets one test wipe or pollute another's fixtures. A poisoned lock from
    /// a panicking test is recovered because the guarded state is `()`.
    fn lock_pg_tests() -> std::sync::MutexGuard<'static, ()> {
        static PG_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
        PG_TEST_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Connect to `DATABASE_URL`, apply migrations and empty both DAG tables.
    ///
    /// Returns the reason as `Err` when the database cannot be used, so the
    /// skip message reports what actually went wrong.
    async fn maybe_pool() -> std::result::Result<PgPool, String> {
        let url =
            std::env::var("DATABASE_URL").map_err(|_| "DATABASE_URL not set".to_owned())?;
        let pool = PgPool::connect(&url)
            .await
            .map_err(|e| format!("cannot connect to DATABASE_URL: {e}"))?;
        PostgresStore::migrate(&pool)
            .await
            .map_err(|e| format!("migrations failed: {e}"))?;
        // Committed markers are cleared before the nodes they reference.
        for statement in ["DELETE FROM dag_committed", "DELETE FROM dag_nodes"] {
            sqlx::query(statement)
                .execute(&pool)
                .await
                .map_err(|e| format!("cannot reset DAG tables: {e}"))?;
        }
        Ok(pool)
    }

    // Bind a clean pool to `$pool` or skip the test with the reason. The lock
    // guard lives until the calling test returns, so the whole test body has
    // exclusive use of the database.
    macro_rules! pg_test {
        ($pool:ident) => {
            let _pg_guard = lock_pg_tests();
            let $pool = match maybe_pool().await {
                Ok(pool) => pool,
                Err(reason) => {
                    eprintln!("Skipping Postgres test: {reason}");
                    return;
                }
            };
        };
    }

    #[tokio::test]
    async fn test_pg_put_and_get() {
        pg_test!(pool);
        let mut store = PostgresStore::new(pool).await.unwrap();
        let node = make_test_node();
        store.put(node.clone()).await.unwrap();
        let retrieved = store.get(&node.hash).await.unwrap();
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.hash, node.hash);
        assert_eq!(retrieved.parents, node.parents);
        assert_eq!(retrieved.payload_hash, node.payload_hash);
        assert_eq!(retrieved.creator_did, node.creator_did);
        assert_eq!(retrieved.timestamp, node.timestamp);
    }

    #[test]
    fn decode_signature_rejects_invalid_stored_bytes() {
        let decoded = PostgresStore::decode_signature(b"not a serialized signature");
        assert!(
            decoded.is_err(),
            "corrupt signature storage must not decode as Signature::Empty"
        );
    }

    #[test]
    fn signature_encoding_roundtrips_without_fabrication() {
        let signature = Signature::from_bytes([7u8; 64]);
        let encoded = PostgresStore::encode_signature(&signature).unwrap();
        let decoded = PostgresStore::decode_signature(&encoded).unwrap();
        assert_eq!(decoded, signature);
    }

    #[test]
    fn decode_hash256_rejects_wrong_width_storage() {
        let err = PostgresStore::decode_hash256(&[1u8; 31], "hash").unwrap_err();
        let message = err.to_string();
        assert!(message.contains("dag_nodes.hash"));
        assert!(message.contains("expected 32 bytes, got 31"));
    }

    #[test]
    fn decode_timestamp_rejects_negative_storage_values() {
        let err = PostgresStore::decode_timestamp(-1, 0).unwrap_err();
        let message = err.to_string();
        assert!(message.contains("ts_physical_ms"));
        assert!(message.contains("expected non-negative"));
    }

    #[test]
    fn decode_tip_hashes_rejects_wrong_width_storage() {
        let err = PostgresStore::decode_tip_hashes(vec![(vec![1u8; 31],)]).unwrap_err();
        let message = err.to_string();
        assert!(message.contains("tips.hash"));
        assert!(message.contains("expected 32 bytes, got 31"));
    }

    #[test]
    fn decode_tip_hashes_preserves_query_order() {
        let first = Hash256::from_bytes([1u8; 32]);
        let second = Hash256::from_bytes([2u8; 32]);
        let decoded = PostgresStore::decode_tip_hashes(vec![
            (first.as_bytes().to_vec(),),
            (second.as_bytes().to_vec(),),
        ])
        .unwrap();
        assert_eq!(decoded, vec![first, second]);
    }

    #[test]
    fn decode_committed_height_rejects_negative_storage_values() {
        let err = PostgresStore::decode_committed_height(-1).unwrap_err();
        let message = err.to_string();
        assert!(message.contains("dag_committed.height"));
        assert!(message.contains("expected non-negative"));
    }

    #[test]
    fn encode_timestamp_rejects_postgres_bigint_overflow() {
        let err =
            PostgresStore::encode_timestamp_physical_ms(Timestamp::new(u64::MAX, 0)).unwrap_err();
        let message = err.to_string();
        assert!(message.contains("ts_physical_ms"));
        assert!(message.contains("exceeds PostgreSQL BIGINT"));
    }

    #[test]
    fn encode_height_rejects_postgres_bigint_overflow() {
        let err = PostgresStore::encode_height(u64::MAX, "dag_committed.height").unwrap_err();
        let message = err.to_string();
        assert!(message.contains("dag_committed.height"));
        assert!(message.contains("exceeds PostgreSQL BIGINT"));
    }

    #[test]
    fn encode_height_accepts_postgres_bigint_maximum() {
        let encoded =
            PostgresStore::encode_height(u64::try_from(i64::MAX).unwrap(), "dag_committed.height")
                .unwrap();
        assert_eq!(encoded, i64::MAX);
    }

    #[test]
    fn mark_committed_uses_atomic_insert_without_preflight_contains() {
        let source = include_str!("pg_store.rs");
        let method = source
            .split("async fn mark_committed")
            .nth(1)
            .expect("mark_committed method must exist");
        // The trait impl ends at the first column-0 closing brace after the
        // method, regardless of how deeply the method body is indented.
        let method = method
            .split("\n}")
            .next()
            .expect("mark_committed method body must be delimited by impl end");
        assert!(
            !method.contains("self.contains(hash).await"),
            "Postgres mark_committed must not perform a separate contains preflight before INSERT"
        );
        assert!(
            method.contains("INSERT INTO dag_committed")
                && method.contains("SELECT $1, $2")
                && method.contains("WHERE EXISTS"),
            "Postgres mark_committed must insert only when the referenced DAG node exists"
        );
        assert!(
            method.contains("rows_affected"),
            "Postgres mark_committed must map an unmatched atomic INSERT to NodeNotFound"
        );
    }

    #[test]
    fn postgres_schema_is_tracked_by_sqlx_migrations_not_inline_ddl() {
        let source = include_str!("pg_store.rs");
        // The delimiter carries no leading whitespace so it matches the doc
        // comment at any indentation width.
        let migrate_method = source
            .split("pub async fn migrate")
            .nth(1)
            .expect("PostgresStore::migrate method must exist")
            .split("/// Encode parents")
            .next()
            .expect("migrate method must end before encode_parents");
        assert!(
            migrate_method.contains("sqlx::migrate!(\"./migrations\")"),
            "PostgresStore schema changes must be tracked by exo-dag sqlx migrations"
        );
        assert!(
            !migrate_method.contains("CREATE TABLE IF NOT EXISTS dag_nodes"),
            "PostgresStore::migrate must not carry inline DAG table DDL"
        );
        assert!(
            !migrate_method.contains("CREATE TABLE IF NOT EXISTS dag_committed"),
            "PostgresStore::migrate must not carry inline committed-table DDL"
        );
    }

    #[tokio::test]
    async fn test_pg_contains() {
        pg_test!(pool);
        let mut store = PostgresStore::new(pool).await.unwrap();
        let node = make_test_node();
        assert!(!store.contains(&node.hash).await.unwrap());
        store.put(node.clone()).await.unwrap();
        assert!(store.contains(&node.hash).await.unwrap());
    }

    #[tokio::test]
    async fn test_pg_tips_single() {
        pg_test!(pool);
        let mut store = PostgresStore::new(pool).await.unwrap();
        let node = make_test_node();
        store.put(node.clone()).await.unwrap();
        let t = store.tips().await.unwrap();
        assert_eq!(t, vec![node.hash]);
    }

    #[tokio::test]
    async fn test_pg_tips_with_children() {
        pg_test!(pool);
        let mut store = PostgresStore::new(pool).await.unwrap();
        let mut dag = Dag::new();
        let mut clock = DeterministicDagClock::new();
        let creator = Did::new("did:exo:test").expect("valid");
        let sign_fn = make_sign_fn();
        let genesis = append(&mut dag, &[], b"genesis", &creator, &*sign_fn, &mut clock).unwrap();
        let child = append(
            &mut dag,
            &[genesis.hash],
            b"child",
            &creator,
            &*sign_fn,
            &mut clock,
        )
        .unwrap();
        store.put(genesis).await.unwrap();
        store.put(child.clone()).await.unwrap();
        let t = store.tips().await.unwrap();
        assert_eq!(t, vec![child.hash]);
    }

    #[tokio::test]
    async fn test_pg_tips_multiple() {
        pg_test!(pool);
        let mut store = PostgresStore::new(pool).await.unwrap();
        let mut dag = Dag::new();
        let mut clock = DeterministicDagClock::new();
        let creator = Did::new("did:exo:test").expect("valid");
        let sign_fn = make_sign_fn();
        let genesis = append(&mut dag, &[], b"genesis", &creator, &*sign_fn, &mut clock).unwrap();
        let c1 = append(
            &mut dag,
            &[genesis.hash],
            b"c1",
            &creator,
            &*sign_fn,
            &mut clock,
        )
        .unwrap();
        let c2 = append(
            &mut dag,
            &[genesis.hash],
            b"c2",
            &creator,
            &*sign_fn,
            &mut clock,
        )
        .unwrap();
        store.put(genesis).await.unwrap();
        store.put(c1.clone()).await.unwrap();
        store.put(c2.clone()).await.unwrap();
        let t = store.tips().await.unwrap();
        assert_eq!(t.len(), 2);
        assert!(t.contains(&c1.hash));
        assert!(t.contains(&c2.hash));
    }

    #[tokio::test]
    async fn test_pg_committed_height() {
        pg_test!(pool);
        let mut store = PostgresStore::new(pool).await.unwrap();
        let node = make_test_node();
        store.put(node.clone()).await.unwrap();
        assert_eq!(store.committed_height().await.unwrap(), 0);
        store.mark_committed(&node.hash, 1).await.unwrap();
        assert_eq!(store.committed_height().await.unwrap(), 1);
        store.mark_committed(&node.hash, 5).await.unwrap();
        assert_eq!(store.committed_height().await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_pg_committed_nonexistent_fails() {
        pg_test!(pool);
        let mut store = PostgresStore::new(pool).await.unwrap();
        let err = store.mark_committed(&Hash256::ZERO, 1).await.unwrap_err();
        assert!(matches!(err, DagError::NodeNotFound(_)));
    }

    #[tokio::test]
    async fn test_pg_roundtrip_deterministic() {
        pg_test!(pool);
        let mut store = PostgresStore::new(pool).await.unwrap();
        let node = make_test_node();
        store.put(node.clone()).await.unwrap();
        let retrieved = store.get(&node.hash).await.unwrap().unwrap();
        assert_eq!(retrieved.hash, node.hash);
        assert_eq!(retrieved.parents, node.parents);
        assert_eq!(retrieved.payload_hash, node.payload_hash);
        assert_eq!(retrieved.creator_did, node.creator_did);
        assert_eq!(retrieved.timestamp.physical_ms, node.timestamp.physical_ms);
        assert_eq!(retrieved.timestamp.logical, node.timestamp.logical);
    }

    #[tokio::test]
    async fn test_pg_parents_ordering() {
        pg_test!(pool);
        let mut store = PostgresStore::new(pool).await.unwrap();
        let mut dag = Dag::new();
        let mut clock = DeterministicDagClock::new();
        let creator = Did::new("did:exo:test").expect("valid");
        let sign_fn = make_sign_fn();
        let g = append(&mut dag, &[], b"g", &creator, &*sign_fn, &mut clock).unwrap();
        let a = append(&mut dag, &[g.hash], b"a", &creator, &*sign_fn, &mut clock).unwrap();
        let b = append(&mut dag, &[g.hash], b"b", &creator, &*sign_fn, &mut clock).unwrap();
        let merge = append(
            &mut dag,
            &[a.hash, b.hash],
            b"merge",
            &creator,
            &*sign_fn,
            &mut clock,
        )
        .unwrap();
        store.put(g).await.unwrap();
        store.put(a.clone()).await.unwrap();
        store.put(b.clone()).await.unwrap();
        store.put(merge.clone()).await.unwrap();
        let retrieved = store.get(&merge.hash).await.unwrap().unwrap();
        assert_eq!(retrieved.parents, merge.parents);
        let mut sorted = retrieved.parents.clone();
        sorted.sort();
        assert_eq!(retrieved.parents, sorted);
    }

    #[tokio::test]
    async fn test_pg_large_payload_hash() {
        pg_test!(pool);
        let mut store = PostgresStore::new(pool).await.unwrap();
        let creator = Did::new("did:exo:test").expect("valid");
        let sign_fn = make_sign_fn();
        let payload_hash = Hash256::from_bytes([0xFF; 32]);
        let timestamp = Timestamp::new(1000, 1);
        let hash =
            crate::dag::compute_node_hash(&[], &payload_hash, &creator, &timestamp).unwrap();
        let signature = (*sign_fn)(hash.as_bytes());
        let node = DagNode {
            hash,
            parents: vec![],
            payload_hash,
            creator_did: creator.clone(),
            timestamp,
            signature,
        };
        store.put(node.clone()).await.unwrap();
        let retrieved = store.get(&hash).await.unwrap().unwrap();
        assert_eq!(retrieved.payload_hash, payload_hash);
        assert!(store.get(&Hash256::ZERO).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_memory_and_pg_parity() {
        pg_test!(pool);
        let mut pg_store = PostgresStore::new(pool).await.unwrap();
        let mut mem_store = MemoryStore::new();
        let mut dag = Dag::new();
        let mut clock = DeterministicDagClock::new();
        let creator = Did::new("did:exo:test").expect("valid");
        let sign_fn = make_sign_fn();
        let genesis = append(&mut dag, &[], b"genesis", &creator, &*sign_fn, &mut clock).unwrap();
        let c1 = append(
            &mut dag,
            &[genesis.hash],
            b"c1",
            &creator,
            &*sign_fn,
            &mut clock,
        )
        .unwrap();
        let c2 = append(
            &mut dag,
            &[genesis.hash],
            b"c2",
            &creator,
            &*sign_fn,
            &mut clock,
        )
        .unwrap();
        for node in [genesis.clone(), c1.clone(), c2.clone()] {
            pg_store.put(node.clone()).await.unwrap();
            mem_store.put(node).await.unwrap();
        }
        let pg_tips = pg_store.tips().await.unwrap();
        let mem_tips = mem_store.tips().await.unwrap();
        assert_eq!(pg_tips, mem_tips, "tips mismatch between PG and memory");
        pg_store.mark_committed(&genesis.hash, 1).await.unwrap();
        mem_store.mark_committed(&genesis.hash, 1).await.unwrap();
        assert_eq!(
            pg_store.committed_height().await.unwrap(),
            mem_store.committed_height().await.unwrap(),
            "committed height mismatch"
        );
        for hash in [genesis.hash, c1.hash, c2.hash, Hash256::ZERO] {
            assert_eq!(
                pg_store.contains(&hash).await.unwrap(),
                mem_store.contains(&hash).await.unwrap(),
                "contains mismatch for {hash}"
            );
        }
    }
}