// aedb 0.2.3
//
// Embedded Rust storage engine with transactional commits, WAL durability, and snapshot-consistent reads.
// Documentation
use crate::error::AedbError;
use crate::wal::segment::{SEGMENT_HEADER_SIZE, SegmentHeader};
use std::fs;
use std::io::Read;
use std::path::{Path, PathBuf};

/// Lists the WAL segment files found directly inside `data_dir`, ordered by ascending
/// sequence number.
///
/// Only directory entries whose file name is accepted by `parse_seq` are returned; every
/// other entry is ignored.
///
/// # Errors
///
/// Returns an error if the directory cannot be opened or if reading any individual
/// directory entry fails. Entry-level failures are propagated rather than skipped so that a
/// segment is never silently missing from the returned list.
pub fn scan_segments(data_dir: &Path) -> Result<Vec<PathBuf>, AedbError> {
    let mut segments: Vec<(u64, PathBuf)> = Vec::new();
    for entry in fs::read_dir(data_dir)? {
        let entry = entry?;
        // Borrow the lossily decoded name; an owned `String` is not needed just to parse it.
        if let Some(seq) = parse_seq(&entry.file_name().to_string_lossy()) {
            segments.push((seq, entry.path()));
        }
    }
    segments.sort_by_key(|(seq, _)| *seq);
    Ok(segments.into_iter().map(|(_, path)| path).collect())
}

/// Verifies that `paths`, taken in the order given, form an unbroken hash chain.
///
/// For every segment the header's `prev_segment_hash` must equal the BLAKE3 hash of the
/// complete preceding file (its header bytes followed by everything after them). The first
/// segment is compared against an all-zero hash.
///
/// # Errors
///
/// - An I/O error if a segment cannot be opened or read, which includes a file shorter
///   than `SEGMENT_HEADER_SIZE`.
/// - `AedbError::Validation` if a header fails to parse or a link in the chain does not
///   match.
pub fn verify_hash_chain(paths: &[PathBuf]) -> Result<(), AedbError> {
    let mut prev_hash = [0u8; 32];
    // One scratch buffer is enough for all segments; its contents are overwritten per read.
    let mut buffer = [0u8; 64 * 1024];
    for path in paths {
        let mut file = fs::File::open(path)?;
        let mut header = [0u8; SEGMENT_HEADER_SIZE];
        file.read_exact(&mut header)?;
        let parsed = SegmentHeader::from_bytes(&header)
            .map_err(|e| AedbError::Validation(format!("bad segment header: {e}")))?;
        if parsed.prev_segment_hash != prev_hash {
            return Err(AedbError::Validation("segment hash chain mismatch".into()));
        }

        // The chained hash covers the whole file, so the header bytes are fed in first.
        let mut hasher = blake3::Hasher::new();
        hasher.update(&header);
        loop {
            match file.read(&mut buffer) {
                Ok(0) => break,
                Ok(n) => {
                    hasher.update(&buffer[..n]);
                }
                // An interrupted read is non-fatal and must be retried, not reported.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e.into()),
            }
        }
        prev_hash = *blake3::Hasher::finalize(&hasher).as_bytes();
    }
    Ok(())
}

/// Runs `verify_hash_chain` over `paths` when `required` is `true`; otherwise succeeds
/// without touching the filesystem.
///
/// Before the chain is verified every path is checked to be at least
/// `SEGMENT_HEADER_SIZE` bytes long.
///
/// # Errors
///
/// - `AedbError::Decode("segment too small")` if any segment is shorter than a header.
///   A segment whose metadata cannot be read is reported the same way.
/// - Whatever `verify_hash_chain` returns for the remaining checks.
pub fn verify_hash_chain_if_required(paths: &[PathBuf], required: bool) -> Result<(), AedbError> {
    if required {
        let min_len = SEGMENT_HEADER_SIZE as u64;
        for path in paths {
            let long_enough = match fs::metadata(path) {
                Ok(meta) => meta.len() >= min_len,
                // Unreadable metadata is folded into the "too small" outcome.
                Err(_) => false,
            };
            if !long_enough {
                return Err(AedbError::Decode("segment too small".into()));
            }
        }
        verify_hash_chain(paths)
    } else {
        Ok(())
    }
}

/// Counts how many leading segments of `paths` pass hash-chain validation.
///
/// This is [`validated_hash_chain_prefix_len_from_checkpoint`] with checkpoint tail
/// anchoring switched off.
///
/// | `required` | `strict` | outcome                                                        |
/// |------------|----------|----------------------------------------------------------------|
/// | `false`    | any      | every segment counts as valid                                  |
/// | `true`     | `true`   | any validation problem is returned as an error                 |
/// | `true`     | `false`  | stops at the first bad segment, returns the valid prefix length |
pub fn validated_hash_chain_prefix_len(
    paths: &[PathBuf],
    required: bool,
    strict: bool,
) -> Result<usize, AedbError> {
    let allow_checkpoint_tail_anchor = false;
    validated_hash_chain_prefix_len_from_checkpoint(
        paths,
        required,
        strict,
        allow_checkpoint_tail_anchor,
    )
}

pub fn validated_hash_chain_prefix_len_from_checkpoint(
    paths: &[PathBuf],
    required: bool,
    strict: bool,
    allow_checkpoint_tail_anchor: bool,
) -> Result<usize, AedbError> {
    if !required {
        return Ok(paths.len());
    }

    let mut prev_hash = [0u8; 32];
    let mut valid_segment_count = 0usize;

    for path in paths {
        let segment_metadata = match fs::metadata(path) {
            Ok(metadata) => metadata,
            Err(e) => {
                if strict {
                    return Err(AedbError::Io(e));
                }
                break;
            }
        };
        let segment_size_bytes = segment_metadata.len();
        if segment_size_bytes < SEGMENT_HEADER_SIZE as u64 {
            if strict {
                return Err(AedbError::Decode("segment too small".into()));
            }
            break;
        }

        let mut file = match fs::File::open(path) {
            Ok(f) => f,
            Err(e) => {
                if strict {
                    return Err(AedbError::Io(e));
                }
                break;
            }
        };
        let mut header = [0u8; SEGMENT_HEADER_SIZE];
        if let Err(e) = file.read_exact(&mut header) {
            if strict {
                return Err(AedbError::Io(e));
            }
            break;
        }
        let parsed = match SegmentHeader::from_bytes(&header) {
            Ok(h) => h,
            Err(e) => {
                if strict {
                    return Err(AedbError::Validation(format!("bad segment header: {e}")));
                }
                break;
            }
        };
        if parsed.prev_segment_hash != prev_hash
            && !(allow_checkpoint_tail_anchor && valid_segment_count == 0 && parsed.segment_seq > 1)
        {
            if strict {
                return Err(AedbError::Validation("segment hash chain mismatch".into()));
            }
            break;
        }

        let mut hasher = blake3::Hasher::new();
        hasher.update(&header);
        let mut buffer = [0u8; 64 * 1024];
        loop {
            match file.read(&mut buffer) {
                Ok(0) => break,
                Ok(n) => {
                    hasher.update(&buffer[..n]);
                }
                Err(e) => {
                    if strict {
                        return Err(AedbError::Io(e));
                    }
                    return Ok(valid_segment_count);
                }
            }
        }
        prev_hash = *blake3::Hasher::finalize(&hasher).as_bytes();
        valid_segment_count += 1;
    }

    debug_assert!(valid_segment_count <= paths.len());
    Ok(valid_segment_count)
}

/// Extracts the sequence number from a WAL segment file name of the form
/// `segment_<digits>.aedbwal`.
///
/// Returns `None` when the prefix or suffix is missing, when the part in between is not
/// made up solely of ASCII digits, or when the number does not fit in a `u64`.
fn parse_seq(name: &str) -> Option<u64> {
    // `strip_prefix`/`strip_suffix` remove exactly one occurrence. The `trim_*_matches`
    // family removes repeated occurrences, which would let a name such as
    // `segment_segment_7.aedbwal` or `segment_7.aedbwal.aedbwal` pass as sequence 7.
    let digits = name.strip_prefix("segment_")?.strip_suffix(".aedbwal")?;
    // `u64::from_str` tolerates a leading `+`; only plain digits name a segment.
    if !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    digits.parse().ok()
}

// Unit tests for hash-chain validation. Each test builds real segment files inside a
// temporary directory and runs the validators against them.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wal::segment::SegmentHeader;
    use std::io::Write;

    /// Creates (or overwrites) the file at `path` with the serialized `header` followed by
    /// `payload`, then syncs it to disk.
    fn write_segment(path: &Path, header: SegmentHeader, payload: &[u8]) {
        let mut file = std::fs::File::create(path).expect("create segment");
        file.write_all(&header.to_bytes()).expect("write header");
        file.write_all(payload).expect("write payload");
        file.sync_all().expect("sync");
    }

    /// Returns the BLAKE3 hash of the complete file at `path`, i.e. the value a successor
    /// segment's header must carry as its previous-segment hash.
    fn segment_hash(path: &Path) -> [u8; 32] {
        let bytes = std::fs::read(path).expect("read segment");
        *blake3::hash(&bytes).as_bytes()
    }

    /// A broken link after the first segment yields a prefix length of 1 in permissive
    /// mode and an error in strict mode; repairing the link makes both segments valid.
    #[test]
    fn permissive_mode_trims_invalid_chain_tail() {
        let dir = tempfile::tempdir().expect("tempdir");
        let seg1 = dir.path().join("segment_0000000000000001.aedbwal");
        let seg2 = dir.path().join("segment_0000000000000002.aedbwal");

        // First segment: all-zero previous hash, which is what the validators start from.
        // NOTE(review): the meaning of the first `SegmentHeader::new` argument is not
        // visible from this file; the second is the sequence number and the third the
        // previous-segment hash.
        let h1 = SegmentHeader::new(1, 1, [0u8; 32]);
        write_segment(&seg1, h1, b"frame-data-1");
        let hash1 = segment_hash(&seg1);

        // Valid successor header for seq=2 would use hash1; intentionally break it.
        let h2_bad = SegmentHeader::new(1, 2, [9u8; 32]);
        write_segment(&seg2, h2_bad, b"frame-data-2");

        let paths = vec![seg1.clone(), seg2.clone()];
        // Permissive: only the first segment survives.
        assert_eq!(
            validated_hash_chain_prefix_len(&paths, true, false).expect("permissive"),
            1
        );

        // Strict: the same mismatch is surfaced as an error.
        let strict_err = validated_hash_chain_prefix_len(&paths, true, true).expect_err("strict");
        assert!(
            strict_err
                .to_string()
                .contains("segment hash chain mismatch")
        );

        // Rewrite the second segment with the correct link; the whole chain now validates.
        let h2_good = SegmentHeader::new(1, 2, hash1);
        write_segment(&seg2, h2_good, b"frame-data-2");
        assert_eq!(
            validated_hash_chain_prefix_len(&paths, true, true).expect("strict valid"),
            2
        );
    }

    /// With checkpoint tail anchoring enabled, the first segment may carry a previous hash
    /// that cannot be checked (its predecessor is not in `paths`), but every later link is
    /// still enforced.
    #[test]
    fn checkpoint_tail_anchor_allows_first_retained_segment_to_reference_reclaimed_segment() {
        let dir = tempfile::tempdir().expect("tempdir");
        let seg2 = dir.path().join("segment_0000000000000002.aedbwal");
        let seg3 = dir.path().join("segment_0000000000000003.aedbwal");

        // Segment 2 points at a predecessor that does not exist on disk.
        let reclaimed_hash = [7u8; 32];
        let h2 = SegmentHeader::new(1, 2, reclaimed_hash);
        write_segment(&seg2, h2, b"frame-data-2");
        let hash2 = segment_hash(&seg2);

        let h3 = SegmentHeader::new(1, 3, hash2);
        write_segment(&seg3, h3, b"frame-data-3");

        let paths = vec![seg2.clone(), seg3.clone()];
        // Without anchoring, the non-zero previous hash on the first segment is a mismatch.
        let err = validated_hash_chain_prefix_len(&paths, true, true).expect_err("strict");
        assert!(err.to_string().contains("segment hash chain mismatch"));
        // With anchoring, both segments are accepted.
        assert_eq!(
            validated_hash_chain_prefix_len_from_checkpoint(&paths, true, true, true)
                .expect("checkpoint anchored tail"),
            2
        );

        // The exemption covers only the first segment: a bad link on segment 3 still fails.
        let h3_bad = SegmentHeader::new(1, 3, [9u8; 32]);
        write_segment(&seg3, h3_bad, b"frame-data-3");
        let err = validated_hash_chain_prefix_len_from_checkpoint(&paths, true, true, true)
            .expect_err("broken retained tail must still fail");
        assert!(err.to_string().contains("segment hash chain mismatch"));
    }

    /// `verify_hash_chain` succeeds on two correctly linked segments.
    #[test]
    fn verify_hash_chain_accepts_valid_segment_sequence() {
        let dir = tempfile::tempdir().expect("tempdir");
        let seg1 = dir.path().join("segment_0000000000000001.aedbwal");
        let seg2 = dir.path().join("segment_0000000000000002.aedbwal");

        let h1 = SegmentHeader::new(7, 1, [0u8; 32]);
        write_segment(&seg1, h1, b"frame-a");
        let hash1 = segment_hash(&seg1);

        let h2 = SegmentHeader::new(7, 2, hash1);
        write_segment(&seg2, h2, b"frame-b");

        verify_hash_chain(&[seg1, seg2]).expect("valid hash chain");
    }
}