1use citadel_vector::segment;
17use citadel_vector::PrismConfig;
18use rustc_hash::FxHashMap;
19
20use crate::error::{Result, SqlError};
21
/// Current ANN segment header format version (the value carried in
/// `SegmentHeader::format_version`). Note that `SegmentHeader::decode` does not
/// itself compare the decoded version against this constant.
pub const ANNSEG_FORMAT_VERSION: u16 = 1;

/// Seven-byte magic that opens every encoded `SegmentHeader`.
const MAGIC: &[u8; 7] = b"ANNSEG\x00";

/// Size in bytes (1 MiB) of each piece produced by `chunks`; only the final
/// piece of a body may be shorter.
pub const CHUNK_BYTES: usize = 1 << 20;
32
/// Name (as raw bytes) of the table holding the ANN segment for `table`:
/// the literal prefix `__annseg_` followed by the UTF-8 bytes of `table`.
pub fn segment_table_name(table: &str) -> Vec<u8> {
    let prefix: &[u8] = b"__annseg_";
    let mut name = Vec::with_capacity(prefix.len() + table.len());
    name.extend_from_slice(prefix);
    name.extend_from_slice(table.as_bytes());
    name
}
37
/// Row key for chunk number `chunk_no`.
///
/// The key is the big-endian encoding of the number, so byte-wise key order
/// matches numeric chunk order.
pub fn segment_key(chunk_no: u32) -> [u8; 4] {
    u32::to_be_bytes(chunk_no)
}
42
/// Metadata header of a persisted ANN segment.
///
/// `SegmentHeader::encode` serializes the fields in declaration order, after a
/// magic prefix and before a trailing 32-byte BLAKE3 self-hash;
/// `SegmentHeader::decode` reverses it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentHeader {
    /// Header format version (see `ANNSEG_FORMAT_VERSION`).
    pub format_version: u16,
    /// Hash of the PRISM configuration the segment was built with
    /// (presumably comparable to `active_config_hash` — confirm with callers).
    pub prism_config_hash: [u8; 32],
    /// Vector dimensionality.
    pub dim: u16,
    /// Distance-metric tag; the mapping to metrics is defined elsewhere.
    pub metric_tag: u8,
    /// Presumably the number of indexed vectors — confirm with the writer.
    pub n: u64,
    /// NOTE(review): presumably the highest snapshot covered by the segment — confirm.
    pub snapshot_max: u64,
    /// Presumably the index of the vector column in the table — confirm.
    pub col_idx: u32,
    /// Presumably the column indices used as filters — confirm.
    pub filter_cols: Vec<u32>,
    /// Ordered `(key bytes, code)` dictionaries; `dict_maps` turns each into a hash map.
    pub dicts: Vec<Vec<(Vec<u8>, u32)>>,
    /// Content digest, presumably produced by `FingerprintHasher` — confirm.
    pub content_fingerprint: [u8; 32],
    /// NOTE(review): presumably the BLAKE3 digest of the segment body — confirm.
    pub segment_b3: [u8; 32],
    /// Number of body chunks stored for the segment.
    pub chunk_count: u32,
    /// Free-form writer identifier; decoded lossily as UTF-8.
    pub writer: String,
}
70
71impl SegmentHeader {
72 pub fn encode(&self) -> Vec<u8> {
73 let mut b = Vec::new();
74 b.extend_from_slice(MAGIC);
75 b.extend_from_slice(&self.format_version.to_le_bytes());
76 b.extend_from_slice(&self.prism_config_hash);
77 b.extend_from_slice(&self.dim.to_le_bytes());
78 b.push(self.metric_tag);
79 b.extend_from_slice(&self.n.to_le_bytes());
80 b.extend_from_slice(&self.snapshot_max.to_le_bytes());
81 b.extend_from_slice(&self.col_idx.to_le_bytes());
82 b.extend_from_slice(&(self.filter_cols.len() as u32).to_le_bytes());
83 for &c in &self.filter_cols {
84 b.extend_from_slice(&c.to_le_bytes());
85 }
86 b.extend_from_slice(&(self.dicts.len() as u32).to_le_bytes());
87 for dict in &self.dicts {
88 b.extend_from_slice(&(dict.len() as u64).to_le_bytes());
89 for (k, v) in dict {
90 b.extend_from_slice(&(k.len() as u64).to_le_bytes());
91 b.extend_from_slice(k);
92 b.extend_from_slice(&v.to_le_bytes());
93 }
94 }
95 b.extend_from_slice(&self.content_fingerprint);
96 b.extend_from_slice(&self.segment_b3);
97 b.extend_from_slice(&self.chunk_count.to_le_bytes());
98 b.extend_from_slice(&(self.writer.len() as u32).to_le_bytes());
99 b.extend_from_slice(self.writer.as_bytes());
100 let self_hash = blake3::hash(&b);
103 b.extend_from_slice(self_hash.as_bytes());
104 b
105 }
106
107 pub fn decode(bytes: &[u8]) -> Result<Self> {
108 let fail = |what: &str| SqlError::InvalidValue(format!("ANN segment header: {what}"));
109 if bytes.len() < 32 {
110 return Err(fail("truncated"));
111 }
112 let (body, hash) = bytes.split_at(bytes.len() - 32);
113 if blake3::hash(body).as_bytes() != hash {
114 return Err(fail("self-hash mismatch (corrupt)"));
115 }
116 let mut at = 0usize;
117 let mut take = |n: usize| -> Result<&[u8]> {
118 let end = at.checked_add(n).filter(|&e| e <= body.len());
119 let end = end.ok_or_else(|| fail("truncated"))?;
120 let s = &body[at..end];
121 at = end;
122 Ok(s)
123 };
124 if take(7)? != MAGIC {
125 return Err(fail("bad magic"));
126 }
127 let format_version = u16::from_le_bytes(take(2)?.try_into().unwrap());
128 let prism_config_hash: [u8; 32] = take(32)?.try_into().unwrap();
129 let dim = u16::from_le_bytes(take(2)?.try_into().unwrap());
130 let metric_tag = take(1)?[0];
131 let n = u64::from_le_bytes(take(8)?.try_into().unwrap());
132 let snapshot_max = u64::from_le_bytes(take(8)?.try_into().unwrap());
133 let col_idx = u32::from_le_bytes(take(4)?.try_into().unwrap());
134 let fc_len = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
135 let mut filter_cols = Vec::with_capacity(fc_len);
136 for _ in 0..fc_len {
137 filter_cols.push(u32::from_le_bytes(take(4)?.try_into().unwrap()));
138 }
139 let dicts_len = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
140 let mut dicts = Vec::with_capacity(dicts_len);
141 for _ in 0..dicts_len {
142 let entries = u64::from_le_bytes(take(8)?.try_into().unwrap()) as usize;
143 let mut dict = Vec::with_capacity(entries);
144 for _ in 0..entries {
145 let klen = u64::from_le_bytes(take(8)?.try_into().unwrap()) as usize;
146 let k = take(klen)?.to_vec();
147 let v = u32::from_le_bytes(take(4)?.try_into().unwrap());
148 dict.push((k, v));
149 }
150 dicts.push(dict);
151 }
152 let content_fingerprint: [u8; 32] = take(32)?.try_into().unwrap();
153 let segment_b3: [u8; 32] = take(32)?.try_into().unwrap();
154 let chunk_count = u32::from_le_bytes(take(4)?.try_into().unwrap());
155 let wlen = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
156 let writer = String::from_utf8_lossy(take(wlen)?).into_owned();
157 if at != body.len() {
158 return Err(fail("trailing bytes"));
159 }
160 Ok(Self {
161 format_version,
162 prism_config_hash,
163 dim,
164 metric_tag,
165 n,
166 snapshot_max,
167 col_idx,
168 filter_cols,
169 dicts,
170 content_fingerprint,
171 segment_b3,
172 chunk_count,
173 writer,
174 })
175 }
176
177 pub fn dict_maps(&self) -> Vec<FxHashMap<Vec<u8>, u32>> {
179 self.dicts
180 .iter()
181 .map(|d| d.iter().cloned().collect())
182 .collect()
183 }
184}
185
186pub struct FingerprintHasher {
191 h: blake3::Hasher,
192}
193
194impl FingerprintHasher {
195 pub fn new(table: &str, col_idx: u32, filter_cols: &[u32], dim: u16, metric_tag: u8) -> Self {
196 let mut h = blake3::Hasher::new();
197 h.update(b"citadel-annseg-fp-v1");
198 h.update(&(table.len() as u64).to_le_bytes());
199 h.update(table.as_bytes());
200 h.update(&col_idx.to_le_bytes());
201 h.update(&(filter_cols.len() as u32).to_le_bytes());
202 for &c in filter_cols {
203 h.update(&c.to_le_bytes());
204 }
205 h.update(&dim.to_le_bytes());
206 h.update(&[metric_tag]);
207 Self { h }
208 }
209
210 pub fn row(&mut self, key: &[u8], vector_raw: &[u8], filter_encoded: &[&[u8]]) {
214 self.h.update(&(key.len() as u64).to_le_bytes());
215 self.h.update(key);
216 self.h.update(&(vector_raw.len() as u64).to_le_bytes());
217 self.h.update(vector_raw);
218 for f in filter_encoded {
219 self.h.update(&(f.len() as u64).to_le_bytes());
220 self.h.update(f);
221 }
222 }
223
224 pub fn finish(self) -> [u8; 32] {
225 *self.h.finalize().as_bytes()
226 }
227}
228
229pub fn active_config_hash(metric: citadel_vector::Metric) -> [u8; 32] {
232 let cfg: PrismConfig = citadel_vector::AnnIndex::active_config(metric);
233 segment::prism_config_hash(&cfg)
234}
235
/// Summary of a stored ANN segment.
///
/// Every field has the same name and type as a field of `SegmentHeader`;
/// presumably the values are copied from the decoded header — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AnnSegmentInfo {
    /// See `SegmentHeader::segment_b3`.
    pub segment_b3: [u8; 32],
    /// See `SegmentHeader::content_fingerprint`.
    pub content_fingerprint: [u8; 32],
    /// See `SegmentHeader::n`.
    pub n: u64,
    /// Vector dimensionality.
    pub dim: u16,
    /// Distance-metric tag.
    pub metric_tag: u8,
    /// Number of body chunks stored for the segment.
    pub chunk_count: u32,
}
247
248pub(crate) fn purge_segment(
252 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
253 table_name: &str,
254) -> Result<()> {
255 match wtx.drop_table(&segment_table_name(table_name)) {
256 Ok(()) => Ok(()),
257 Err(citadel_core::Error::TableNotFound(_)) => Ok(()),
258 Err(e) => Err(SqlError::Storage(e)),
259 }
260}
261
262pub fn chunks(body: &[u8]) -> impl Iterator<Item = (u32, &[u8])> {
264 body.chunks(CHUNK_BYTES)
265 .enumerate()
266 .map(|(i, c)| ((i + 1) as u32, c))
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// A representative header exercising every field, including two dictionaries.
    fn sample_header() -> SegmentHeader {
        SegmentHeader {
            format_version: ANNSEG_FORMAT_VERSION,
            prism_config_hash: [7; 32],
            dim: 768,
            metric_tag: 2,
            n: 311_592,
            snapshot_max: 99,
            col_idx: 3,
            filter_cols: vec![1, 2],
            dicts: vec![
                vec![(b"region".to_vec(), 0), (b"other".to_vec(), 1)],
                vec![(b"kind".to_vec(), 0)],
            ],
            content_fingerprint: [9; 32],
            segment_b3: [4; 32],
            chunk_count: 41,
            writer: "citadel-test".into(),
        }
    }

    /// Fingerprint of one row hashed under table `"t"`, dim 4 and metric tag 2.
    fn single_row_fp(col_idx: u32, filter_cols: &[u32], key: &[u8], vector: &[u8]) -> [u8; 32] {
        let mut fp = FingerprintHasher::new("t", col_idx, filter_cols, 4, 2);
        fp.row(key, vector, &[]);
        fp.finish()
    }

    #[test]
    fn header_roundtrips_exactly() {
        let original = sample_header();
        let decoded = SegmentHeader::decode(&original.encode()).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn header_corruption_is_refused() {
        let pristine = sample_header().encode();
        let spots = [0, 9, 45, pristine.len() / 2, pristine.len() - 40];
        for spot in spots {
            let mut damaged = pristine.clone();
            damaged[spot] ^= 0xFF;
            let outcome = SegmentHeader::decode(&damaged);
            assert!(outcome.is_err(), "corruption at {spot} must refuse");
        }
    }

    #[test]
    fn fingerprint_is_framed_against_boundary_shifts() {
        // Moving a byte from the vector into the key must change the digest.
        assert_ne!(
            single_row_fp(0, &[], b"ab", b"cd"),
            single_row_fp(0, &[], b"abc", b"d")
        );

        // Index parameters are part of the seed.
        let base = single_row_fp(0, &[], b"ab", b"cd");
        assert_ne!(single_row_fp(1, &[], b"ab", b"cd"), base);
        assert_ne!(single_row_fp(0, &[2], b"ab", b"cd"), base);
    }

    #[test]
    fn chunking_covers_the_body_in_order() {
        let body = vec![0xABu8; CHUNK_BYTES + 17];
        let layout: Vec<(u32, usize)> = chunks(&body)
            .map(|(chunk_no, piece)| (chunk_no, piece.len()))
            .collect();
        assert_eq!(layout, [(1, CHUNK_BYTES), (2, 17)]);
    }
}