engram-core 0.21.1

AI Memory Infrastructure - Persistent memory for AI agents with semantic search
Documentation
//! Merkle tree for attestation proofs
//!
//! Builds a binary Merkle tree over a set of attestation records.
//! Leaves are the `record_hash` values; parent nodes are the SHA-256 hash
//! of their left and right children, each prefixed with its byte length
//! encoded as an 8-byte big-endian integer (hash scheme v2).
//! When the number of leaves is odd the last leaf is duplicated.

use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;

use super::types::{AttestationRecord, MerkleProof};

/// Binary Merkle tree built from attestation record hashes.
///
/// The tree is stored level by level, from the leaves up to the root, so
/// that sibling hashes can be looked up directly when generating proofs.
/// A tree built from no records has no levels.
#[derive(Debug, Clone)]
pub struct MerkleTree {
    /// Level 0 = raw leaves (record_hash values), last level = root
    levels: Vec<Vec<String>>,
}

impl MerkleTree {
    /// Construct the tree for `records`, using each record's `record_hash`
    /// as a leaf.
    ///
    /// Levels are produced bottom-up: adjacent nodes are paired and hashed
    /// with `hash_pair`, and a node without a right-hand partner is paired
    /// with itself. An empty `records` slice yields a tree with no levels;
    /// a single record yields a one-level tree whose root is that record's
    /// hash.
    pub fn build(records: &[AttestationRecord]) -> Self {
        if records.is_empty() {
            return Self { levels: Vec::new() };
        }

        let mut frontier: Vec<String> = records
            .iter()
            .map(|record| record.record_hash.clone())
            .collect();
        let mut levels: Vec<Vec<String>> = Vec::new();

        // Each pass folds the current level into its parent level and
        // stores the level that was just consumed.
        while frontier.len() > 1 {
            let parents: Vec<String> = frontier
                .chunks(2)
                .map(|pair| {
                    let left = &pair[0];
                    // A lone trailing node is hashed with a copy of itself.
                    let right = pair.get(1).unwrap_or(left);
                    hash_pair(left, right)
                })
                .collect();
            levels.push(std::mem::replace(&mut frontier, parents));
        }
        // Only the single-node root level is left to record.
        levels.push(frontier);

        Self { levels }
    }

    /// Root hash of the tree.
    ///
    /// Returns `None` when the tree has no levels (it was built from an
    /// empty slice).
    pub fn root(&self) -> Option<&str> {
        let top_level = self.levels.last()?;
        top_level.first().map(|hash| hash.as_str())
    }

    /// Produce an inclusion proof for the leaf stored at `leaf_index`.
    ///
    /// The proof lists, from the leaf level upwards, the sibling hash at
    /// each level together with a flag telling whether that sibling sits
    /// to the right of the path node. Proofs are always emitted with
    /// `scheme_version` 2.
    ///
    /// Returns `None` if the tree is empty or `leaf_index` is out of range.
    pub fn generate_proof(&self, leaf_index: usize) -> Option<MerkleProof> {
        let leaves = self.levels.first()?;
        let leaf_hash = leaves.get(leaf_index)?.clone();
        let total_leaves = leaves.len();
        let root_hash = self.root()?.to_string();

        // The root level has no sibling, so only the levels below it
        // contribute to the proof.
        let (_root_level, lower_levels) = self.levels.split_last()?;

        let mut proof_hashes: Vec<(String, bool)> = Vec::with_capacity(lower_levels.len());
        let mut position = leaf_index;

        for level in lower_levels {
            let is_right_sibling = position.is_multiple_of(2);
            let sibling_index = match (is_right_sibling, position + 1 < level.len()) {
                // Path node is a right child: its sibling is directly to the left.
                (false, _) => position - 1,
                // Path node is a left child with a real right-hand neighbour.
                (true, true) => position + 1,
                // Path node is the unpaired last node: it was hashed with itself.
                (true, false) => position,
            };

            proof_hashes.push((level[sibling_index].clone(), is_right_sibling));
            // Move to the parent's position in the next level up.
            position /= 2;
        }

        Some(MerkleProof {
            leaf_hash,
            leaf_index,
            proof_hashes,
            root_hash,
            total_leaves,
            scheme_version: 2,
        })
    }

    /// Check a Merkle proof against the root hash it carries.
    ///
    /// Starting from `proof.leaf_hash`, each `(sibling, is_right_sibling)`
    /// entry is folded in to recompute the root, which is then compared
    /// with `proof.root_hash` using a constant-time byte comparison.
    ///
    /// The pair-hashing function is selected by `proof.scheme_version`:
    /// - `1` — naive `left || right` concatenation
    /// - any other value — length-domain-separated
    ///   `len(left) || left || len(right) || right` (scheme v2)
    pub fn verify_proof(proof: &MerkleProof) -> bool {
        let combine: fn(&str, &str) -> String = if proof.scheme_version == 1 {
            hash_pair_v1
        } else {
            hash_pair_v2
        };

        let recomputed_root = proof.proof_hashes.iter().fold(
            proof.leaf_hash.clone(),
            |node, (sibling, sibling_on_right)| {
                if *sibling_on_right {
                    // Sibling on the right, so the running node is the left input.
                    combine(&node, sibling)
                } else {
                    // Sibling on the left, so the running node is the right input.
                    combine(sibling, &node)
                }
            },
        );

        bool::from(
            recomputed_root
                .as_bytes()
                .ct_eq(proof.root_hash.as_bytes()),
        )
    }
}

/// Hash two child nodes into their parent using the current scheme.
///
/// This is a thin alias for [`hash_pair_v2`], i.e. the length-prefixed
/// encoding `len(left) || left || len(right) || right`, so that tree
/// construction always uses the current scheme through one entry point.
fn hash_pair(left: &str, right: &str) -> String {
    hash_pair_v2(left, right)
}

/// Legacy pair hash (scheme v1): SHA-256 over the plain concatenation
/// `left || right`, returned as `sha256:<hex digest>`.
///
/// Kept only for backwards-compatible verification of Merkle proofs
/// generated before June 2026.
fn hash_pair_v1(left: &str, right: &str) -> String {
    let mut digest = Sha256::new();
    for side in [left, right] {
        digest.update(side.as_bytes());
    }
    let hex_digest = hex::encode(digest.finalize());
    format!("sha256:{hex_digest}")
}

/// Pair hash, scheme v2: SHA-256 over
/// `len(left) || left || len(right) || right`, returned as
/// `sha256:<hex digest>`.
///
/// Each length is the byte length of the operand encoded as an 8-byte
/// big-endian integer. The prefix removes the ambiguity of plain
/// concatenation, where ("ab", "cd") and ("a", "bcd") would feed the same
/// bytes to the hasher.
fn hash_pair_v2(left: &str, right: &str) -> String {
    let mut digest = Sha256::new();
    for side in [left, right] {
        // Length prefix first, then the operand's bytes.
        let length_prefix = (side.len() as u64).to_be_bytes();
        digest.update(length_prefix);
        digest.update(side.as_bytes());
    }
    let hex_digest = hex::encode(digest.finalize());
    format!("sha256:{hex_digest}")
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::attestation::chain::AttestationChain;
    use crate::storage::Storage;
    use chrono::Utc;

    /// All-zero digest used to stand in for a forged root hash.
    const ZERO_ROOT: &str =
        "sha256:0000000000000000000000000000000000000000000000000000000000000000";

    /// Build a minimal attestation record whose hashes are derived from `hash`.
    fn make_record(hash: &str, name: &str) -> AttestationRecord {
        let document_hash = format!("sha256:{hash}");
        let record_hash = format!("sha256:{hash}-record");

        AttestationRecord {
            id: None,
            document_hash,
            document_name: name.to_owned(),
            document_size: 0,
            ingested_at: Utc::now(),
            agent_id: None,
            memory_ids: vec![],
            previous_hash: String::from("genesis"),
            record_hash,
            signature: None,
            metadata: serde_json::Value::Object(serde_json::Map::new()),
            created_at: None,
        }
    }

    /// An empty tree has neither a root nor any proofs.
    #[test]
    fn test_empty_tree() {
        let empty = MerkleTree::build(&[]);
        assert!(empty.root().is_none());
        assert!(empty.generate_proof(0).is_none());
    }

    /// With one leaf, the root is the leaf itself and its proof verifies.
    #[test]
    fn test_single_leaf() {
        let only = make_record("aabbcc", "a.txt");
        let tree = MerkleTree::build(std::slice::from_ref(&only));
        assert_eq!(tree.root(), Some(only.record_hash.as_str()));

        let proof = tree.generate_proof(0).unwrap();
        assert_eq!(proof.leaf_hash, only.record_hash);
        assert_eq!(proof.total_leaves, 1);
        assert!(MerkleTree::verify_proof(&proof));
    }

    /// With two leaves, the root is their pair hash and both proofs are v2.
    #[test]
    fn test_two_leaves() {
        let first = make_record("aa", "a.txt");
        let second = make_record("bb", "b.txt");
        let expected_root = hash_pair(&first.record_hash, &second.record_hash);

        let tree = MerkleTree::build(&[first, second]);
        assert_eq!(tree.root(), Some(expected_root.as_str()));

        for leaf in [0, 1] {
            let proof = tree.generate_proof(leaf).unwrap();
            assert_eq!(proof.scheme_version, 2);
            assert!(MerkleTree::verify_proof(&proof));
        }
    }

    /// An odd leaf count still yields verifiable proofs for every leaf.
    #[test]
    fn test_three_leaves_odd() {
        let mut records = Vec::new();
        for (i, h) in ["aa", "bb", "cc"].iter().enumerate() {
            records.push(make_record(h, &format!("{i}.txt")));
        }

        let tree = MerkleTree::build(&records);
        assert!(tree.root().is_some());

        for i in 0..3 {
            let proof = tree.generate_proof(i).unwrap();
            assert!(MerkleTree::verify_proof(&proof), "proof {i} failed");
        }
    }

    /// Records logged through a real attestation chain produce valid proofs.
    #[test]
    fn test_proof_from_real_chain() {
        let chain = AttestationChain::new(Storage::open_in_memory().unwrap());

        let r1 = chain
            .log_document(b"doc1", "d1.txt", None, &[], None)
            .unwrap();
        let r2 = chain
            .log_document(b"doc2", "d2.txt", None, &[], None)
            .unwrap();
        let r3 = chain
            .log_document(b"doc3", "d3.txt", None, &[], None)
            .unwrap();

        let tree = MerkleTree::build(&[r1, r2, r3]);
        for i in 0..3 {
            let proof = tree.generate_proof(i).unwrap();
            assert!(
                MerkleTree::verify_proof(&proof),
                "real chain proof {i} failed"
            );
        }
    }

    /// Replacing the root hash must make verification fail.
    #[test]
    fn test_tampered_proof_fails() {
        let records = [make_record("aa", "a.txt"), make_record("bb", "b.txt")];
        let tree = MerkleTree::build(&records);

        let mut forged = tree.generate_proof(0).unwrap();
        // Tamper with the root
        forged.root_hash = ZERO_ROOT.to_string();
        assert!(!MerkleTree::verify_proof(&forged));
    }

    /// Old-format proofs (scheme_version = 1, naive concatenation) must
    /// still verify, and tampering with them must still be detected.
    #[test]
    fn test_v1_proof_backwards_compat() {
        let left = make_record("aa", "a.txt");
        let right = make_record("bb", "b.txt");

        let legacy_proof = MerkleProof {
            leaf_hash: left.record_hash.clone(),
            leaf_index: 0,
            proof_hashes: vec![(right.record_hash.clone(), true)],
            root_hash: hash_pair_v1(&left.record_hash, &right.record_hash),
            total_leaves: 2,
            scheme_version: 1,
        };
        assert!(
            MerkleTree::verify_proof(&legacy_proof),
            "v1 proof must verify"
        );

        // Tampered v1 proof must still fail
        let forged_legacy = MerkleProof {
            root_hash: ZERO_ROOT.to_string(),
            ..legacy_proof
        };
        assert!(
            !MerkleTree::verify_proof(&forged_legacy),
            "tampered v1 must fail"
        );
    }
}