1use citadel_vector::segment;
17use citadel_vector::PrismConfig;
18use rustc_hash::FxHashMap;
19
20use crate::error::{Result, SqlError};
21
/// Current on-disk format version for ANN segment headers.
///
/// Presumably written into [`SegmentHeader::format_version`] by segment
/// writers; note that [`SegmentHeader::decode`] does not itself check it.
pub const ANNSEG_FORMAT_VERSION: u16 = 2;

/// Seven-byte magic prefix that opens every encoded [`SegmentHeader`].
const MAGIC: &[u8; 7] = b"ANNSEG\0";

/// Maximum size in bytes (1 MiB) of each slice yielded by [`chunks`].
pub const CHUNK_BYTES: usize = 1 << 20;
32
/// Returns the name of the companion table holding the ANN segment for
/// `table`: the bytes of `"__annseg_"` followed by the table name.
pub fn segment_table_name(table: &str) -> Vec<u8> {
    let prefix: &[u8] = b"__annseg_";
    let mut name = Vec::with_capacity(prefix.len() + table.len());
    name.extend_from_slice(prefix);
    name.extend_from_slice(table.as_bytes());
    name
}
37
/// Encodes a chunk number as a 4-byte big-endian key.
///
/// Big-endian makes byte-wise key order match numeric chunk order. Presumably
/// that is why it was chosen; confirm against the storage layer's key ordering.
pub fn segment_key(chunk_no: u32) -> [u8; 4] {
    u32::to_be_bytes(chunk_no)
}
42
/// Metadata record stored in front of a persisted ANN segment.
///
/// Serialised by [`SegmentHeader::encode`] and parsed by
/// [`SegmentHeader::decode`]. Field meanings below are inferred from names and
/// neighbouring code; confirm against the segment writer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SegmentHeader {
    /// Header format version (see [`ANNSEG_FORMAT_VERSION`]).
    pub format_version: u16,
    /// Hash of the PRISM configuration the segment was built with
    /// (presumably comparable with [`active_config_hash`]).
    pub prism_config_hash: [u8; 32],
    /// Vector dimensionality.
    pub dim: u16,
    /// Opaque tag identifying the distance metric.
    pub metric_tag: u8,
    /// Presumably the number of indexed vectors.
    pub n: u64,
    /// Snapshot watermark recorded at build time (semantics defined by the writer).
    pub snapshot_max: u64,
    /// Root of the source table at build time (semantics defined by the writer).
    pub table_root: u64,
    /// Index of the vector column in the source table.
    pub col_idx: u32,
    /// Indices of the filter columns carried alongside the vectors.
    pub filter_cols: Vec<u32>,
    /// One dictionary of `(encoded value, id)` pairs per dictionary-encoded
    /// filter column; `decode` does not check its length against `filter_cols`.
    pub dicts: Vec<Vec<(Vec<u8>, u32)>>,
    /// Presumably the output of a [`FingerprintHasher`] over the source rows.
    pub content_fingerprint: [u8; 32],
    /// Presumably a BLAKE3 digest of the segment body.
    pub segment_b3: [u8; 32],
    /// Presumably the number of body chunks (see [`chunks`]).
    pub chunk_count: u32,
    /// Free-form identifier of the writer; decoded lossily as UTF-8.
    pub writer: String,
}
72
73impl SegmentHeader {
74 pub fn encode(&self) -> Vec<u8> {
75 let mut b = Vec::new();
76 b.extend_from_slice(MAGIC);
77 b.extend_from_slice(&self.format_version.to_le_bytes());
78 b.extend_from_slice(&self.prism_config_hash);
79 b.extend_from_slice(&self.dim.to_le_bytes());
80 b.push(self.metric_tag);
81 b.extend_from_slice(&self.n.to_le_bytes());
82 b.extend_from_slice(&self.snapshot_max.to_le_bytes());
83 b.extend_from_slice(&self.table_root.to_le_bytes());
84 b.extend_from_slice(&self.col_idx.to_le_bytes());
85 b.extend_from_slice(&(self.filter_cols.len() as u32).to_le_bytes());
86 for &c in &self.filter_cols {
87 b.extend_from_slice(&c.to_le_bytes());
88 }
89 b.extend_from_slice(&(self.dicts.len() as u32).to_le_bytes());
90 for dict in &self.dicts {
91 b.extend_from_slice(&(dict.len() as u64).to_le_bytes());
92 for (k, v) in dict {
93 b.extend_from_slice(&(k.len() as u64).to_le_bytes());
94 b.extend_from_slice(k);
95 b.extend_from_slice(&v.to_le_bytes());
96 }
97 }
98 b.extend_from_slice(&self.content_fingerprint);
99 b.extend_from_slice(&self.segment_b3);
100 b.extend_from_slice(&self.chunk_count.to_le_bytes());
101 b.extend_from_slice(&(self.writer.len() as u32).to_le_bytes());
102 b.extend_from_slice(self.writer.as_bytes());
103 let self_hash = blake3::hash(&b);
106 b.extend_from_slice(self_hash.as_bytes());
107 b
108 }
109
110 pub fn decode(bytes: &[u8]) -> Result<Self> {
111 let fail = |what: &str| SqlError::InvalidValue(format!("ANN segment header: {what}"));
112 if bytes.len() < 32 {
113 return Err(fail("truncated"));
114 }
115 let (body, hash) = bytes.split_at(bytes.len() - 32);
116 if blake3::hash(body).as_bytes() != hash {
117 return Err(fail("self-hash mismatch (corrupt)"));
118 }
119 let mut at = 0usize;
120 let mut take = |n: usize| -> Result<&[u8]> {
121 let end = at.checked_add(n).filter(|&e| e <= body.len());
122 let end = end.ok_or_else(|| fail("truncated"))?;
123 let s = &body[at..end];
124 at = end;
125 Ok(s)
126 };
127 if take(7)? != MAGIC {
128 return Err(fail("bad magic"));
129 }
130 let format_version = u16::from_le_bytes(take(2)?.try_into().unwrap());
131 let prism_config_hash: [u8; 32] = take(32)?.try_into().unwrap();
132 let dim = u16::from_le_bytes(take(2)?.try_into().unwrap());
133 let metric_tag = take(1)?[0];
134 let n = u64::from_le_bytes(take(8)?.try_into().unwrap());
135 let snapshot_max = u64::from_le_bytes(take(8)?.try_into().unwrap());
136 let table_root = u64::from_le_bytes(take(8)?.try_into().unwrap());
137 let col_idx = u32::from_le_bytes(take(4)?.try_into().unwrap());
138 let fc_len = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
139 let mut filter_cols = Vec::with_capacity(fc_len);
140 for _ in 0..fc_len {
141 filter_cols.push(u32::from_le_bytes(take(4)?.try_into().unwrap()));
142 }
143 let dicts_len = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
144 let mut dicts = Vec::with_capacity(dicts_len);
145 for _ in 0..dicts_len {
146 let entries = u64::from_le_bytes(take(8)?.try_into().unwrap()) as usize;
147 let mut dict = Vec::with_capacity(entries);
148 for _ in 0..entries {
149 let klen = u64::from_le_bytes(take(8)?.try_into().unwrap()) as usize;
150 let k = take(klen)?.to_vec();
151 let v = u32::from_le_bytes(take(4)?.try_into().unwrap());
152 dict.push((k, v));
153 }
154 dicts.push(dict);
155 }
156 let content_fingerprint: [u8; 32] = take(32)?.try_into().unwrap();
157 let segment_b3: [u8; 32] = take(32)?.try_into().unwrap();
158 let chunk_count = u32::from_le_bytes(take(4)?.try_into().unwrap());
159 let wlen = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
160 let writer = String::from_utf8_lossy(take(wlen)?).into_owned();
161 if at != body.len() {
162 return Err(fail("trailing bytes"));
163 }
164 Ok(Self {
165 format_version,
166 prism_config_hash,
167 dim,
168 metric_tag,
169 n,
170 snapshot_max,
171 table_root,
172 col_idx,
173 filter_cols,
174 dicts,
175 content_fingerprint,
176 segment_b3,
177 chunk_count,
178 writer,
179 })
180 }
181
182 pub fn dict_maps(&self) -> Vec<FxHashMap<Vec<u8>, u32>> {
184 self.dicts
185 .iter()
186 .map(|d| d.iter().cloned().collect())
187 .collect()
188 }
189}
190
191pub struct FingerprintHasher {
196 h: blake3::Hasher,
197}
198
199impl FingerprintHasher {
200 pub fn new(table: &str, col_idx: u32, filter_cols: &[u32], dim: u16, metric_tag: u8) -> Self {
201 let mut h = blake3::Hasher::new();
202 h.update(b"citadel-annseg-fp-v1");
203 h.update(&(table.len() as u64).to_le_bytes());
204 h.update(table.as_bytes());
205 h.update(&col_idx.to_le_bytes());
206 h.update(&(filter_cols.len() as u32).to_le_bytes());
207 for &c in filter_cols {
208 h.update(&c.to_le_bytes());
209 }
210 h.update(&dim.to_le_bytes());
211 h.update(&[metric_tag]);
212 Self { h }
213 }
214
215 pub fn row(&mut self, key: &[u8], vector_raw: &[u8], filter_encoded: &[&[u8]]) {
219 self.h.update(&(key.len() as u64).to_le_bytes());
220 self.h.update(key);
221 self.h.update(&(vector_raw.len() as u64).to_le_bytes());
222 self.h.update(vector_raw);
223 for f in filter_encoded {
224 self.h.update(&(f.len() as u64).to_le_bytes());
225 self.h.update(f);
226 }
227 }
228
229 pub fn finish(self) -> [u8; 32] {
230 *self.h.finalize().as_bytes()
231 }
232}
233
234pub fn active_config_hash(metric: citadel_vector::Metric) -> [u8; 32] {
237 let cfg: PrismConfig = citadel_vector::AnnIndex::active_config(metric);
238 segment::prism_config_hash(&cfg)
239}
240
/// Compact summary of a persisted ANN segment.
///
/// The fields mirror the identically named [`SegmentHeader`] fields.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AnnSegmentInfo {
    /// See [`SegmentHeader::segment_b3`].
    pub segment_b3: [u8; 32],
    /// See [`SegmentHeader::content_fingerprint`].
    pub content_fingerprint: [u8; 32],
    /// See [`SegmentHeader::n`].
    pub n: u64,
    /// See [`SegmentHeader::dim`].
    pub dim: u16,
    /// See [`SegmentHeader::metric_tag`].
    pub metric_tag: u8,
    /// See [`SegmentHeader::chunk_count`].
    pub chunk_count: u32,
}
252
253pub(crate) fn purge_segment(
257 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
258 table_name: &str,
259) -> Result<()> {
260 match wtx.drop_table(&segment_table_name(table_name)) {
261 Ok(()) => Ok(()),
262 Err(citadel_core::Error::TableNotFound(_)) => Ok(()),
263 Err(e) => Err(SqlError::Storage(e)),
264 }
265}
266
267pub fn chunks(body: &[u8]) -> impl Iterator<Item = (u32, &[u8])> {
269 body.chunks(CHUNK_BYTES)
270 .enumerate()
271 .map(|(i, c)| ((i + 1) as u32, c))
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// A header in which every field carries a distinctive value.
    fn header_fixture() -> SegmentHeader {
        SegmentHeader {
            format_version: ANNSEG_FORMAT_VERSION,
            prism_config_hash: [7; 32],
            dim: 768,
            metric_tag: 2,
            n: 311_592,
            snapshot_max: 99,
            table_root: 1234,
            col_idx: 3,
            filter_cols: vec![1, 2],
            dicts: vec![
                vec![(b"region".to_vec(), 0), (b"other".to_vec(), 1)],
                vec![(b"kind".to_vec(), 0)],
            ],
            content_fingerprint: [9; 32],
            segment_b3: [4; 32],
            chunk_count: 41,
            writer: "citadel-test".into(),
        }
    }

    /// Fingerprint of a single row with no filter values, under table `t`,
    /// `dim = 4` and `metric_tag = 2`.
    fn single_row_fp(col_idx: u32, filter_cols: &[u32], key: &[u8], vector: &[u8]) -> [u8; 32] {
        let mut hasher = FingerprintHasher::new("t", col_idx, filter_cols, 4, 2);
        hasher.row(key, vector, &[]);
        hasher.finish()
    }

    #[test]
    fn header_roundtrips_exactly() {
        let original = header_fixture();
        let decoded = SegmentHeader::decode(&original.encode()).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn header_corruption_is_refused() {
        let pristine = header_fixture().encode();
        let spots = [0, 9, 45, pristine.len() / 2, pristine.len() - 40];
        for spot in spots {
            let mut corrupt = pristine.clone();
            corrupt[spot] ^= 0xFF;
            assert!(
                SegmentHeader::decode(&corrupt).is_err(),
                "corruption at {spot} must refuse"
            );
        }
    }

    #[test]
    fn fingerprint_is_framed_against_boundary_shifts() {
        // Moving a byte from the key into the vector must change the digest.
        assert_ne!(
            single_row_fp(0, &[], b"ab", b"cd"),
            single_row_fp(0, &[], b"abc", b"d")
        );

        // Changing the index identity must change the digest of identical rows.
        let base = single_row_fp(0, &[], b"ab", b"cd");
        assert_ne!(single_row_fp(1, &[], b"ab", b"cd"), base);
        assert_ne!(single_row_fp(0, &[2], b"ab", b"cd"), base);
    }

    #[test]
    fn chunking_covers_the_body_in_order() {
        let body = vec![0xABu8; CHUNK_BYTES + 17];
        let parts: Vec<(u32, usize)> = chunks(&body)
            .map(|(chunk_no, piece)| (chunk_no, piece.len()))
            .collect();
        assert_eq!(parts, vec![(1, CHUNK_BYTES), (2, 17)]);
    }
}
346}