Skip to main content

citadel_sql/executor/
ann_persist.rs

1//! Persisted ANN segments: the storage envelope around
2//! [`citadel_vector::segment`]. Each indexed table owns one hidden storage
3//! tree `__annseg_{table}` (never registered in the schema manager, invisible
4//! to SQL) holding a header row plus body chunks, encrypted like every tree.
5//!
6//! Three independent layers refuse a stale segment (any failure falls through
7//! to a rebuild):
8//! - transactional: every DML/DDL site that marks a table dirty drops its
9//!   segment in the same write txn (shadow paging keeps "table changed but
10//!   segment survived" unrepresentable for those paths);
11//! - content fingerprint: BLAKE3 over the scan-order row content
12//!   (domain-separated, length-framed) at persist time, recomputed by the
13//!   load-time rehydration scan;
14//! - header checks: format/config/shape pins compared before the scan.
15
16use citadel_vector::segment;
17use citadel_vector::PrismConfig;
18use rustc_hash::FxHashMap;
19
20use crate::error::{Result, SqlError};
21
/// Layout version of the persisted header and segment body.
/// Bump on ANY layout change of the header or the segment body.
pub const ANNSEG_FORMAT_VERSION: u16 = 1;

/// First bytes of every encoded header: "ANNSEG" followed by a NUL byte.
const MAGIC: &[u8; 7] = b"ANNSEG\0";

/// Body chunk size (1 MiB).
///
/// Chunking bounds the peak memory of any single value read or write. The
/// storage layer already spreads values above ~2 KB across overflow-page
/// chains, so splitting the body into chunks this size costs only a few
/// hundred point-gets per attach while keeping buffers modest.
pub const CHUNK_BYTES: usize = 1 << 20;
32
/// Name of the hidden storage tree that holds `table`'s persisted segment:
/// the UTF-8 bytes of `__annseg_` followed by the table name.
pub fn segment_table_name(table: &str) -> Vec<u8> {
    const PREFIX: &str = "__annseg_";
    let mut name = Vec::with_capacity(PREFIX.len() + table.len());
    name.extend_from_slice(PREFIX.as_bytes());
    name.extend_from_slice(table.as_bytes());
    name
}
37
/// Storage key for a row of a segment tree. Key 0 holds the header and body
/// chunks use 1..=chunk_count. The encoding is big-endian, so byte-wise key
/// order matches numeric chunk order during a scan.
pub fn segment_key(chunk_no: u32) -> [u8; 4] {
    u32::to_be_bytes(chunk_no)
}
42
/// Header row of a persisted ANN segment (stored under `segment_key(0)`).
///
/// Holds everything the loader must verify BEFORE paying for the rehydration
/// scan, plus the two content hashes (`content_fingerprint`, `segment_b3`)
/// it verifies during/after it. [`SegmentHeader::encode`] serializes the
/// fields in declaration order and appends a BLAKE3 self-hash.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentHeader {
    /// Layout version the header was written with; current writers use
    /// [`ANNSEG_FORMAT_VERSION`].
    pub format_version: u16,
    /// BLAKE3 of the canonical encoding of the builder's `PrismConfig`
    /// (what [`active_config_hash`] computes for the running binary).
    pub prism_config_hash: [u8; 32],
    /// Vector dimension pin; also bound into the content fingerprint.
    pub dim: u16,
    /// Metric tag pin; also bound into the content fingerprint.
    /// NOTE(review): the tag <-> metric mapping lives outside this file.
    pub metric_tag: u8,
    /// Indexed (non-null) row count - compared as `n <= live rows` pre-scan
    /// (NULL vectors are unindexed), exactly via the fingerprint scan.
    pub n: u64,
    /// NOTE(review): semantics not visible in this file - presumably the
    /// highest snapshot/txn id covered at persist time; confirm.
    pub snapshot_max: u64,
    /// The indexed column and the filter columns, IN ATTRIBUTE ORDER - an
    /// index re-created over different columns must be refused explicitly,
    /// never discovered via fingerprint luck.
    pub col_idx: u32,
    pub filter_cols: Vec<u32>,
    /// Per attribute dim: encoded filter value -> PRISM code, in scan order.
    pub dicts: Vec<Vec<(Vec<u8>, u32)>>,
    /// Scan-order content fingerprint produced by [`FingerprintHasher`].
    pub content_fingerprint: [u8; 32],
    /// BLAKE3 of the concatenated body chunks (the segment.rs payload).
    pub segment_b3: [u8; 32],
    /// Number of body chunks, stored under keys `1..=chunk_count`.
    pub chunk_count: u32,
    /// Forensics only - never compared.
    pub writer: String,
}
70
71impl SegmentHeader {
72    pub fn encode(&self) -> Vec<u8> {
73        let mut b = Vec::new();
74        b.extend_from_slice(MAGIC);
75        b.extend_from_slice(&self.format_version.to_le_bytes());
76        b.extend_from_slice(&self.prism_config_hash);
77        b.extend_from_slice(&self.dim.to_le_bytes());
78        b.push(self.metric_tag);
79        b.extend_from_slice(&self.n.to_le_bytes());
80        b.extend_from_slice(&self.snapshot_max.to_le_bytes());
81        b.extend_from_slice(&self.col_idx.to_le_bytes());
82        b.extend_from_slice(&(self.filter_cols.len() as u32).to_le_bytes());
83        for &c in &self.filter_cols {
84            b.extend_from_slice(&c.to_le_bytes());
85        }
86        b.extend_from_slice(&(self.dicts.len() as u32).to_le_bytes());
87        for dict in &self.dicts {
88            b.extend_from_slice(&(dict.len() as u64).to_le_bytes());
89            for (k, v) in dict {
90                b.extend_from_slice(&(k.len() as u64).to_le_bytes());
91                b.extend_from_slice(k);
92                b.extend_from_slice(&v.to_le_bytes());
93            }
94        }
95        b.extend_from_slice(&self.content_fingerprint);
96        b.extend_from_slice(&self.segment_b3);
97        b.extend_from_slice(&self.chunk_count.to_le_bytes());
98        b.extend_from_slice(&(self.writer.len() as u32).to_le_bytes());
99        b.extend_from_slice(self.writer.as_bytes());
100        // Self-hash binds header fields beyond page-level HMAC (cheap
101        // hardening: a header is never accepted with internal bit-rot).
102        let self_hash = blake3::hash(&b);
103        b.extend_from_slice(self_hash.as_bytes());
104        b
105    }
106
107    pub fn decode(bytes: &[u8]) -> Result<Self> {
108        let fail = |what: &str| SqlError::InvalidValue(format!("ANN segment header: {what}"));
109        if bytes.len() < 32 {
110            return Err(fail("truncated"));
111        }
112        let (body, hash) = bytes.split_at(bytes.len() - 32);
113        if blake3::hash(body).as_bytes() != hash {
114            return Err(fail("self-hash mismatch (corrupt)"));
115        }
116        let mut at = 0usize;
117        let mut take = |n: usize| -> Result<&[u8]> {
118            let end = at.checked_add(n).filter(|&e| e <= body.len());
119            let end = end.ok_or_else(|| fail("truncated"))?;
120            let s = &body[at..end];
121            at = end;
122            Ok(s)
123        };
124        if take(7)? != MAGIC {
125            return Err(fail("bad magic"));
126        }
127        let format_version = u16::from_le_bytes(take(2)?.try_into().unwrap());
128        let prism_config_hash: [u8; 32] = take(32)?.try_into().unwrap();
129        let dim = u16::from_le_bytes(take(2)?.try_into().unwrap());
130        let metric_tag = take(1)?[0];
131        let n = u64::from_le_bytes(take(8)?.try_into().unwrap());
132        let snapshot_max = u64::from_le_bytes(take(8)?.try_into().unwrap());
133        let col_idx = u32::from_le_bytes(take(4)?.try_into().unwrap());
134        let fc_len = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
135        let mut filter_cols = Vec::with_capacity(fc_len);
136        for _ in 0..fc_len {
137            filter_cols.push(u32::from_le_bytes(take(4)?.try_into().unwrap()));
138        }
139        let dicts_len = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
140        let mut dicts = Vec::with_capacity(dicts_len);
141        for _ in 0..dicts_len {
142            let entries = u64::from_le_bytes(take(8)?.try_into().unwrap()) as usize;
143            let mut dict = Vec::with_capacity(entries);
144            for _ in 0..entries {
145                let klen = u64::from_le_bytes(take(8)?.try_into().unwrap()) as usize;
146                let k = take(klen)?.to_vec();
147                let v = u32::from_le_bytes(take(4)?.try_into().unwrap());
148                dict.push((k, v));
149            }
150            dicts.push(dict);
151        }
152        let content_fingerprint: [u8; 32] = take(32)?.try_into().unwrap();
153        let segment_b3: [u8; 32] = take(32)?.try_into().unwrap();
154        let chunk_count = u32::from_le_bytes(take(4)?.try_into().unwrap());
155        let wlen = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
156        let writer = String::from_utf8_lossy(take(wlen)?).into_owned();
157        if at != body.len() {
158            return Err(fail("trailing bytes"));
159        }
160        Ok(Self {
161            format_version,
162            prism_config_hash,
163            dim,
164            metric_tag,
165            n,
166            snapshot_max,
167            col_idx,
168            filter_cols,
169            dicts,
170            content_fingerprint,
171            segment_b3,
172            chunk_count,
173            writer,
174        })
175    }
176
177    /// The dicts as the runtime maps the filter pushdown uses.
178    pub fn dict_maps(&self) -> Vec<FxHashMap<Vec<u8>, u32>> {
179        self.dicts
180            .iter()
181            .map(|d| d.iter().cloned().collect())
182            .collect()
183    }
184}
185
186/// The INJECTIVE content fingerprint: domain-separated, every component
187/// length-framed (unframed concatenation admits boundary-shift collisions),
188/// bound to the table/column/filter identity, fed rows IN SCAN ORDER. Persist
189/// and load MUST construct it identically - both go through this one type.
190pub struct FingerprintHasher {
191    h: blake3::Hasher,
192}
193
194impl FingerprintHasher {
195    pub fn new(table: &str, col_idx: u32, filter_cols: &[u32], dim: u16, metric_tag: u8) -> Self {
196        let mut h = blake3::Hasher::new();
197        h.update(b"citadel-annseg-fp-v1");
198        h.update(&(table.len() as u64).to_le_bytes());
199        h.update(table.as_bytes());
200        h.update(&col_idx.to_le_bytes());
201        h.update(&(filter_cols.len() as u32).to_le_bytes());
202        for &c in filter_cols {
203            h.update(&c.to_le_bytes());
204        }
205        h.update(&dim.to_le_bytes());
206        h.update(&[metric_tag]);
207        Self { h }
208    }
209
210    /// One scanned row: its key, the RAW encoded vector-column bytes (null =
211    /// empty, still framed - unindexed rows are part of the content), and each
212    /// filter column's encoded bytes.
213    pub fn row(&mut self, key: &[u8], vector_raw: &[u8], filter_encoded: &[&[u8]]) {
214        self.h.update(&(key.len() as u64).to_le_bytes());
215        self.h.update(key);
216        self.h.update(&(vector_raw.len() as u64).to_le_bytes());
217        self.h.update(vector_raw);
218        for f in filter_encoded {
219            self.h.update(&(f.len() as u64).to_le_bytes());
220            self.h.update(f);
221        }
222    }
223
224    pub fn finish(self) -> [u8; 32] {
225        *self.h.finalize().as_bytes()
226    }
227}
228
229/// The active config's hash for `metric` - what persist writes and the loader
230/// requires (a binary with a different geometry must rebuild, not load).
231pub fn active_config_hash(metric: citadel_vector::Metric) -> [u8; 32] {
232    let cfg: PrismConfig = citadel_vector::AnnIndex::active_config(metric);
233    segment::prism_config_hash(&cfg)
234}
235
/// What `persist_ann_index` returns for the caller's manifest: the hashes a
/// later attach verifies against, and the shape for the record.
///
/// NOTE(review): the fields presumably carry the same values as the
/// same-named fields of the written [`SegmentHeader`]; confirm in
/// `persist_ann_index`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AnnSegmentInfo {
    pub segment_b3: [u8; 32],
    pub content_fingerprint: [u8; 32],
    pub n: u64,
    pub dim: u16,
    pub metric_tag: u8,
    pub chunk_count: u32,
}
247
248/// Drop a table's persisted segment INSIDE the caller's write txn (the
249/// transactional staleness layer). Absent segment = nothing to do; savepoint
250/// rollback restores a dropped one automatically.
251pub(crate) fn purge_segment(
252    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
253    table_name: &str,
254) -> Result<()> {
255    match wtx.drop_table(&segment_table_name(table_name)) {
256        Ok(()) => Ok(()),
257        Err(citadel_core::Error::TableNotFound(_)) => Ok(()),
258        Err(e) => Err(SqlError::Storage(e)),
259    }
260}
261
262/// Split a segment body into storage chunks (chunk 0 is the header's key).
263pub fn chunks(body: &[u8]) -> impl Iterator<Item = (u32, &[u8])> {
264    body.chunks(CHUNK_BYTES)
265        .enumerate()
266        .map(|(i, c)| ((i + 1) as u32, c))
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    fn header_fixture() -> SegmentHeader {
        SegmentHeader {
            format_version: ANNSEG_FORMAT_VERSION,
            prism_config_hash: [7; 32],
            dim: 768,
            metric_tag: 2,
            n: 311_592,
            snapshot_max: 99,
            col_idx: 3,
            filter_cols: vec![1, 2],
            dicts: vec![
                vec![(b"region".to_vec(), 0), (b"other".to_vec(), 1)],
                vec![(b"kind".to_vec(), 0)],
            ],
            content_fingerprint: [9; 32],
            segment_b3: [4; 32],
            chunk_count: 41,
            writer: "citadel-test".into(),
        }
    }

    /// The encoded fixture without its trailing 32-byte self-hash.
    fn fixture_body() -> Vec<u8> {
        let mut bytes = header_fixture().encode();
        bytes.truncate(bytes.len() - 32);
        bytes
    }

    /// Append a correct self-hash to `body`. This lets a test get past the
    /// self-hash and reach the structural checks behind it.
    fn reseal(mut body: Vec<u8>) -> Vec<u8> {
        let self_hash = blake3::hash(&body);
        body.extend_from_slice(self_hash.as_bytes());
        body
    }

    /// The message of the `InvalidValue` refusal `decode` must produce.
    fn refusal_reason(bytes: &[u8]) -> String {
        match SegmentHeader::decode(bytes) {
            Err(SqlError::InvalidValue(msg)) => msg.to_string(),
            Err(_) => panic!("decode refused with an unexpected error kind"),
            Ok(_) => panic!("decode accepted a malformed header"),
        }
    }

    #[test]
    fn header_roundtrips_exactly() {
        let h = header_fixture();
        assert_eq!(SegmentHeader::decode(&h.encode()).unwrap(), h);
    }

    #[test]
    fn header_corruption_is_refused() {
        let bytes = header_fixture().encode();
        for spot in [0, 9, 45, bytes.len() / 2, bytes.len() - 40] {
            let mut corrupt = bytes.clone();
            corrupt[spot] ^= 0xFF;
            assert!(
                SegmentHeader::decode(&corrupt).is_err(),
                "corruption at {spot} must refuse"
            );
        }
    }

    #[test]
    fn input_shorter_than_the_self_hash_is_refused() {
        assert!(refusal_reason(&[]).contains("truncated"));
        assert!(refusal_reason(&[0u8; 31]).contains("truncated"));
    }

    #[test]
    fn bad_magic_is_refused_behind_a_valid_self_hash() {
        let mut body = fixture_body();
        body[0] ^= 0xFF;
        assert!(refusal_reason(&reseal(body)).contains("bad magic"));
    }

    #[test]
    fn truncated_body_is_refused_behind_a_valid_self_hash() {
        let mut body = fixture_body();
        body.pop();
        assert!(refusal_reason(&reseal(body)).contains("truncated"));
    }

    #[test]
    fn trailing_bytes_are_refused_behind_a_valid_self_hash() {
        let mut body = fixture_body();
        body.push(0);
        assert!(refusal_reason(&reseal(body)).contains("trailing bytes"));
    }

    #[test]
    fn dict_maps_mirror_the_dicts() {
        let maps = header_fixture().dict_maps();
        assert_eq!(maps.len(), 2);
        assert_eq!(maps[0].len(), 2);
        assert_eq!(maps[0].get(&b"region"[..]), Some(&0));
        assert_eq!(maps[0].get(&b"other"[..]), Some(&1));
        assert_eq!(maps[1].get(&b"kind"[..]), Some(&0));
        assert_eq!(maps[1].get(&b"region"[..]), None);
    }

    #[test]
    fn fingerprint_is_framed_against_boundary_shifts() {
        // Same concatenated bytes, different row framing -> different hashes.
        let mut a = FingerprintHasher::new("t", 0, &[], 4, 2);
        a.row(b"ab", b"cd", &[]);
        let mut b = FingerprintHasher::new("t", 0, &[], 4, 2);
        b.row(b"abc", b"d", &[]);
        assert_ne!(a.finish(), b.finish());

        // Identity changes perturb it too.
        let mut c = FingerprintHasher::new("t", 1, &[], 4, 2);
        c.row(b"ab", b"cd", &[]);
        let mut d = FingerprintHasher::new("t", 0, &[2], 4, 2);
        d.row(b"ab", b"cd", &[]);
        let mut base = FingerprintHasher::new("t", 0, &[], 4, 2);
        base.row(b"ab", b"cd", &[]);
        let base = base.finish();
        assert_ne!(c.finish(), base);
        assert_ne!(d.finish(), base);
    }

    #[test]
    fn chunking_covers_the_body_in_order() {
        let body = vec![0xABu8; CHUNK_BYTES + 17];
        let parts: Vec<(u32, usize)> = chunks(&body).map(|(no, c)| (no, c.len())).collect();
        assert_eq!(parts, vec![(1, CHUNK_BYTES), (2, 17)]);
    }
}