Skip to main content

citadel_sql/executor/
ann_persist.rs

1//! Persisted ANN segments: the storage envelope around
2//! [`citadel_vector::segment`]. Each indexed table owns one hidden storage
3//! tree `__annseg_{table}` (never registered in the schema manager, invisible
4//! to SQL) holding a header row plus body chunks, encrypted like every tree.
5//!
6//! Three independent layers refuse a stale segment (any failure falls through
7//! to a rebuild):
8//! - transactional: every DML/DDL site that marks a table dirty drops its
9//!   segment in the same write txn (shadow paging keeps "table changed but
10//!   segment survived" unrepresentable for those paths);
11//! - content fingerprint: BLAKE3 over the scan-order row content
12//!   (domain-separated, length-framed) at persist time, recomputed by the
13//!   load-time rehydration scan;
14//! - header checks: format/config/shape pins compared before the scan.
15
16use citadel_vector::segment;
17use citadel_vector::PrismConfig;
18use rustc_hash::FxHashMap;
19
20use crate::error::{Result, SqlError};
21
/// Layout version of the persisted header and segment body.
/// Bump on ANY layout change of the header or the segment body.
pub const ANNSEG_FORMAT_VERSION: u16 = 2;

/// Seven-byte tag (`ANNSEG` followed by a NUL) that `SegmentHeader::encode`
/// writes first and `SegmentHeader::decode` requires at the start of a header.
const MAGIC: &[u8; 7] = b"ANNSEG\0";

/// Body chunk size in bytes (1 MiB). Chunking bounds the peak memory of a
/// single value read/write; storage chains overflow pages above ~2 KB anyway,
/// so smaller chunks cost only a few hundred point-gets per attach while
/// keeping buffers modest.
pub const CHUNK_BYTES: usize = 1024 * 1024;
32
/// Name of the hidden storage tree that holds `table`'s segment: the bytes of
/// the fixed prefix `__annseg_` immediately followed by the table name.
pub fn segment_table_name(table: &str) -> Vec<u8> {
    let prefix: &[u8] = b"__annseg_";
    [prefix, table.as_bytes()].concat()
}
37
/// Storage key for one row of the segment tree. Key 0 is the header; body
/// chunks use 1..=chunk_count. The number is laid out big-endian so that
/// byte-wise key order matches numeric chunk order.
pub fn segment_key(chunk_no: u32) -> [u8; 4] {
    u32::to_be_bytes(chunk_no)
}
42
/// Everything the loader must verify BEFORE paying for the rehydration scan,
/// plus the two content hashes it verifies during/after it.
///
/// The byte layout produced by `encode` follows the field order below
/// (after a magic tag, and followed by a trailing self-hash).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentHeader {
    /// Layout version of this header/body (see `ANNSEG_FORMAT_VERSION`).
    pub format_version: u16,
    /// BLAKE3 of the canonical encoding of the builder's `PrismConfig`.
    pub prism_config_hash: [u8; 32],
    /// Vector dimension of the index.
    /// NOTE(review): presumably the indexed column's dimension - confirm against the builder.
    pub dim: u16,
    /// One-byte tag identifying the distance metric.
    /// NOTE(review): the tag-to-metric mapping is not visible in this file.
    pub metric_tag: u8,
    /// Indexed (non-null) row count - compared as `n <= live rows` pre-scan
    /// (NULL vectors are unindexed), exactly via the fingerprint scan.
    pub n: u64,
    /// NOTE(review): meaning not established in this file - presumably the
    /// highest snapshot/txn id covered at persist time; confirm with the writer.
    pub snapshot_max: u64,
    /// The table's catalog root at persist - a differing live root means stale (CoW gate).
    pub table_root: u64,
    /// The indexed column and the filter columns, IN ATTRIBUTE ORDER - an
    /// index re-created over different columns must be refused explicitly,
    /// never discovered via fingerprint luck.
    pub col_idx: u32,
    /// Filter column indices; see the note on `col_idx` about ordering.
    pub filter_cols: Vec<u32>,
    /// Per attribute dim: encoded filter value -> PRISM code, in scan order.
    pub dicts: Vec<Vec<(Vec<u8>, u32)>>,
    /// Fingerprint of the scanned row content (see `FingerprintHasher`).
    pub content_fingerprint: [u8; 32],
    /// BLAKE3 of the concatenated body chunks (the segment.rs payload).
    pub segment_b3: [u8; 32],
    /// Number of body chunks stored under keys 1..=chunk_count.
    pub chunk_count: u32,
    /// Forensics only - never compared.
    pub writer: String,
}
72
73impl SegmentHeader {
74    pub fn encode(&self) -> Vec<u8> {
75        let mut b = Vec::new();
76        b.extend_from_slice(MAGIC);
77        b.extend_from_slice(&self.format_version.to_le_bytes());
78        b.extend_from_slice(&self.prism_config_hash);
79        b.extend_from_slice(&self.dim.to_le_bytes());
80        b.push(self.metric_tag);
81        b.extend_from_slice(&self.n.to_le_bytes());
82        b.extend_from_slice(&self.snapshot_max.to_le_bytes());
83        b.extend_from_slice(&self.table_root.to_le_bytes());
84        b.extend_from_slice(&self.col_idx.to_le_bytes());
85        b.extend_from_slice(&(self.filter_cols.len() as u32).to_le_bytes());
86        for &c in &self.filter_cols {
87            b.extend_from_slice(&c.to_le_bytes());
88        }
89        b.extend_from_slice(&(self.dicts.len() as u32).to_le_bytes());
90        for dict in &self.dicts {
91            b.extend_from_slice(&(dict.len() as u64).to_le_bytes());
92            for (k, v) in dict {
93                b.extend_from_slice(&(k.len() as u64).to_le_bytes());
94                b.extend_from_slice(k);
95                b.extend_from_slice(&v.to_le_bytes());
96            }
97        }
98        b.extend_from_slice(&self.content_fingerprint);
99        b.extend_from_slice(&self.segment_b3);
100        b.extend_from_slice(&self.chunk_count.to_le_bytes());
101        b.extend_from_slice(&(self.writer.len() as u32).to_le_bytes());
102        b.extend_from_slice(self.writer.as_bytes());
103        // Self-hash binds header fields beyond page-level HMAC (cheap
104        // hardening: a header is never accepted with internal bit-rot).
105        let self_hash = blake3::hash(&b);
106        b.extend_from_slice(self_hash.as_bytes());
107        b
108    }
109
110    pub fn decode(bytes: &[u8]) -> Result<Self> {
111        let fail = |what: &str| SqlError::InvalidValue(format!("ANN segment header: {what}"));
112        if bytes.len() < 32 {
113            return Err(fail("truncated"));
114        }
115        let (body, hash) = bytes.split_at(bytes.len() - 32);
116        if blake3::hash(body).as_bytes() != hash {
117            return Err(fail("self-hash mismatch (corrupt)"));
118        }
119        let mut at = 0usize;
120        let mut take = |n: usize| -> Result<&[u8]> {
121            let end = at.checked_add(n).filter(|&e| e <= body.len());
122            let end = end.ok_or_else(|| fail("truncated"))?;
123            let s = &body[at..end];
124            at = end;
125            Ok(s)
126        };
127        if take(7)? != MAGIC {
128            return Err(fail("bad magic"));
129        }
130        let format_version = u16::from_le_bytes(take(2)?.try_into().unwrap());
131        let prism_config_hash: [u8; 32] = take(32)?.try_into().unwrap();
132        let dim = u16::from_le_bytes(take(2)?.try_into().unwrap());
133        let metric_tag = take(1)?[0];
134        let n = u64::from_le_bytes(take(8)?.try_into().unwrap());
135        let snapshot_max = u64::from_le_bytes(take(8)?.try_into().unwrap());
136        let table_root = u64::from_le_bytes(take(8)?.try_into().unwrap());
137        let col_idx = u32::from_le_bytes(take(4)?.try_into().unwrap());
138        let fc_len = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
139        let mut filter_cols = Vec::with_capacity(fc_len);
140        for _ in 0..fc_len {
141            filter_cols.push(u32::from_le_bytes(take(4)?.try_into().unwrap()));
142        }
143        let dicts_len = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
144        let mut dicts = Vec::with_capacity(dicts_len);
145        for _ in 0..dicts_len {
146            let entries = u64::from_le_bytes(take(8)?.try_into().unwrap()) as usize;
147            let mut dict = Vec::with_capacity(entries);
148            for _ in 0..entries {
149                let klen = u64::from_le_bytes(take(8)?.try_into().unwrap()) as usize;
150                let k = take(klen)?.to_vec();
151                let v = u32::from_le_bytes(take(4)?.try_into().unwrap());
152                dict.push((k, v));
153            }
154            dicts.push(dict);
155        }
156        let content_fingerprint: [u8; 32] = take(32)?.try_into().unwrap();
157        let segment_b3: [u8; 32] = take(32)?.try_into().unwrap();
158        let chunk_count = u32::from_le_bytes(take(4)?.try_into().unwrap());
159        let wlen = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
160        let writer = String::from_utf8_lossy(take(wlen)?).into_owned();
161        if at != body.len() {
162            return Err(fail("trailing bytes"));
163        }
164        Ok(Self {
165            format_version,
166            prism_config_hash,
167            dim,
168            metric_tag,
169            n,
170            snapshot_max,
171            table_root,
172            col_idx,
173            filter_cols,
174            dicts,
175            content_fingerprint,
176            segment_b3,
177            chunk_count,
178            writer,
179        })
180    }
181
182    /// The dicts as the runtime maps the filter pushdown uses.
183    pub fn dict_maps(&self) -> Vec<FxHashMap<Vec<u8>, u32>> {
184        self.dicts
185            .iter()
186            .map(|d| d.iter().cloned().collect())
187            .collect()
188    }
189}
190
191/// The INJECTIVE content fingerprint: domain-separated, every component
192/// length-framed (unframed concatenation admits boundary-shift collisions),
193/// bound to the table/column/filter identity, fed rows IN SCAN ORDER. Persist
194/// and load MUST construct it identically - both go through this one type.
195pub struct FingerprintHasher {
196    h: blake3::Hasher,
197}
198
199impl FingerprintHasher {
200    pub fn new(table: &str, col_idx: u32, filter_cols: &[u32], dim: u16, metric_tag: u8) -> Self {
201        let mut h = blake3::Hasher::new();
202        h.update(b"citadel-annseg-fp-v1");
203        h.update(&(table.len() as u64).to_le_bytes());
204        h.update(table.as_bytes());
205        h.update(&col_idx.to_le_bytes());
206        h.update(&(filter_cols.len() as u32).to_le_bytes());
207        for &c in filter_cols {
208            h.update(&c.to_le_bytes());
209        }
210        h.update(&dim.to_le_bytes());
211        h.update(&[metric_tag]);
212        Self { h }
213    }
214
215    /// One scanned row: its key, the RAW encoded vector-column bytes (null =
216    /// empty, still framed - unindexed rows are part of the content), and each
217    /// filter column's encoded bytes.
218    pub fn row(&mut self, key: &[u8], vector_raw: &[u8], filter_encoded: &[&[u8]]) {
219        self.h.update(&(key.len() as u64).to_le_bytes());
220        self.h.update(key);
221        self.h.update(&(vector_raw.len() as u64).to_le_bytes());
222        self.h.update(vector_raw);
223        for f in filter_encoded {
224            self.h.update(&(f.len() as u64).to_le_bytes());
225            self.h.update(f);
226        }
227    }
228
229    pub fn finish(self) -> [u8; 32] {
230        *self.h.finalize().as_bytes()
231    }
232}
233
234/// The active config's hash for `metric` - what persist writes and the loader
235/// requires (a binary with a different geometry must rebuild, not load).
236pub fn active_config_hash(metric: citadel_vector::Metric) -> [u8; 32] {
237    let cfg: PrismConfig = citadel_vector::AnnIndex::active_config(metric);
238    segment::prism_config_hash(&cfg)
239}
240
/// What `persist_ann_index` returns for the caller's manifest: the hashes a
/// later attach verifies against, and the shape for the record.
///
/// Field names and types mirror the same-named fields of `SegmentHeader`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AnnSegmentInfo {
    /// Hash of the segment body (cf. `SegmentHeader::segment_b3`).
    pub segment_b3: [u8; 32],
    /// Fingerprint of the row content (cf. `SegmentHeader::content_fingerprint`).
    pub content_fingerprint: [u8; 32],
    /// Indexed row count (cf. `SegmentHeader::n`).
    pub n: u64,
    /// Vector dimension (cf. `SegmentHeader::dim`).
    pub dim: u16,
    /// Metric tag (cf. `SegmentHeader::metric_tag`).
    pub metric_tag: u8,
    /// Number of body chunks (cf. `SegmentHeader::chunk_count`).
    pub chunk_count: u32,
}
252
253/// Drop a table's persisted segment INSIDE the caller's write txn (the
254/// transactional staleness layer). Absent segment = nothing to do; savepoint
255/// rollback restores a dropped one automatically.
256pub(crate) fn purge_segment(
257    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
258    table_name: &str,
259) -> Result<()> {
260    match wtx.drop_table(&segment_table_name(table_name)) {
261        Ok(()) => Ok(()),
262        Err(citadel_core::Error::TableNotFound(_)) => Ok(()),
263        Err(e) => Err(SqlError::Storage(e)),
264    }
265}
266
267/// Split a segment body into storage chunks (chunk 0 is the header's key).
268pub fn chunks(body: &[u8]) -> impl Iterator<Item = (u32, &[u8])> {
269    body.chunks(CHUNK_BYTES)
270        .enumerate()
271        .map(|(i, c)| ((i + 1) as u32, c))
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// A header with every field populated, including two filter dicts.
    fn header_fixture() -> SegmentHeader {
        let first_dict = vec![(b"region".to_vec(), 0), (b"other".to_vec(), 1)];
        let second_dict = vec![(b"kind".to_vec(), 0)];
        SegmentHeader {
            format_version: ANNSEG_FORMAT_VERSION,
            prism_config_hash: [7; 32],
            dim: 768,
            metric_tag: 2,
            n: 311_592,
            snapshot_max: 99,
            table_root: 1234,
            col_idx: 3,
            filter_cols: vec![1, 2],
            dicts: vec![first_dict, second_dict],
            content_fingerprint: [9; 32],
            segment_b3: [4; 32],
            chunk_count: 41,
            writer: "citadel-test".into(),
        }
    }

    #[test]
    fn header_roundtrips_exactly() {
        let original = header_fixture();
        let encoded = original.encode();
        let decoded = SegmentHeader::decode(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn header_corruption_is_refused() {
        let pristine = header_fixture().encode();
        // Flip a byte in the magic, the fixed fields, the middle, and near the tail.
        let spots = [0, 9, 45, pristine.len() / 2, pristine.len() - 40];
        for spot in spots {
            let mut corrupt = pristine.clone();
            corrupt[spot] ^= 0xFF;
            let outcome = SegmentHeader::decode(&corrupt);
            assert!(outcome.is_err(), "corruption at {spot} must refuse");
        }
    }

    #[test]
    fn fingerprint_is_framed_against_boundary_shifts() {
        // Fingerprint of a single row (no filter values) under the given identity.
        let fingerprint = |col_idx: u32, filter_cols: &[u32], key: &[u8], vector: &[u8]| {
            let mut hasher = FingerprintHasher::new("t", col_idx, filter_cols, 4, 2);
            hasher.row(key, vector, &[]);
            hasher.finish()
        };

        // Same concatenated bytes, different row framing -> different hashes.
        let split_after_two = fingerprint(0, &[], b"ab", b"cd");
        let split_after_three = fingerprint(0, &[], b"abc", b"d");
        assert_ne!(split_after_two, split_after_three);

        // Identity changes perturb it too.
        let base = fingerprint(0, &[], b"ab", b"cd");
        let other_column = fingerprint(1, &[], b"ab", b"cd");
        let other_filters = fingerprint(0, &[2], b"ab", b"cd");
        assert_ne!(other_column, base);
        assert_ne!(other_filters, base);
    }

    #[test]
    fn chunking_covers_the_body_in_order() {
        let body = vec![0xABu8; CHUNK_BYTES + 17];
        let mut parts: Vec<(u32, usize)> = Vec::new();
        for (chunk_no, part) in chunks(&body) {
            parts.push((chunk_no, part.len()));
        }
        assert_eq!(parts, vec![(1, CHUNK_BYTES), (2, 17)]);
    }
}