use citadel_vector::segment;
use citadel_vector::PrismConfig;
use rustc_hash::FxHashMap;
use crate::error::{Result, SqlError};
/// Format version of the ANN segment header that this module's test fixture
/// stores in `SegmentHeader::format_version`.
pub const ANNSEG_FORMAT_VERSION: u16 = 1;
/// Seven-byte magic prefix (`ANNSEG` plus a trailing NUL) that
/// `SegmentHeader::encode` writes first and `SegmentHeader::decode` requires.
const MAGIC: &[u8; 7] = b"ANNSEG\0";
/// Maximum size in bytes (1 MiB) of each piece yielded by `chunks`.
pub const CHUNK_BYTES: usize = 1024 * 1024;
/// Returns the name of the table that holds the ANN segment for `table`:
/// the bytes of `__annseg_` followed by the bytes of `table`.
pub fn segment_table_name(table: &str) -> Vec<u8> {
    const PREFIX: &[u8] = b"__annseg_";

    let mut name = Vec::with_capacity(PREFIX.len() + table.len());
    name.extend_from_slice(PREFIX);
    name.extend_from_slice(table.as_bytes());
    name
}
/// Encodes a chunk number as a 4-byte big-endian key.
///
/// Big-endian means the byte-wise lexicographic order of the keys equals the
/// numeric order of the chunk numbers.
pub fn segment_key(chunk_no: u32) -> [u8; 4] {
    [
        (chunk_no >> 24) as u8,
        (chunk_no >> 16) as u8,
        (chunk_no >> 8) as u8,
        chunk_no as u8,
    ]
}
/// Metadata record describing a persisted ANN segment.
///
/// `encode` serializes the fields in declaration order (integers
/// little-endian) behind a magic prefix and appends a BLAKE3 hash of the
/// serialized bytes; `decode` verifies that hash before parsing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentHeader {
/// Header format version; the test fixture uses `ANNSEG_FORMAT_VERSION`.
pub format_version: u16,
/// 32-byte hash of the index configuration.
/// NOTE(review): presumably the value returned by `active_config_hash` — confirm.
pub prism_config_hash: [u8; 32],
/// Vector dimensionality (assumed from the name — TODO confirm).
pub dim: u16,
/// One-byte tag identifying the distance metric (assumed — TODO confirm the mapping).
pub metric_tag: u8,
/// NOTE(review): presumably the number of rows/vectors in the segment — confirm.
pub n: u64,
/// NOTE(review): meaning not visible in this file; looks like a snapshot high-water mark — confirm.
pub snapshot_max: u64,
/// NOTE(review): presumably the index of the vector column in the table — confirm.
pub col_idx: u32,
/// Column indices of the filter columns; serialized with a `u32` count.
pub filter_cols: Vec<u32>,
/// One dictionary per entry, each an ordered list of `(key bytes, u32 id)`
/// pairs; `dict_maps` turns each list into a hash map.
pub dicts: Vec<Vec<(Vec<u8>, u32)>>,
/// 32-byte content fingerprint.
/// NOTE(review): presumably produced by `FingerprintHasher::finish` — confirm.
pub content_fingerprint: [u8; 32],
/// 32-byte digest of the segment (presumably BLAKE3 of the segment body — TODO confirm).
pub segment_b3: [u8; 32],
/// NOTE(review): presumably the number of chunks the segment body is split into — confirm.
pub chunk_count: u32,
/// Free-form writer identification string; serialized as UTF-8 with a `u32` length.
pub writer: String,
}
impl SegmentHeader {
/// Serializes the header into its self-checking byte form.
///
/// Layout, with every integer little-endian: the magic prefix, then the
/// struct fields in declaration order. `filter_cols` carries a `u32` count,
/// `dicts` a `u32` count with each dictionary prefixed by a `u64` entry
/// count and each key by a `u64` length, and `writer` a `u32` byte length.
/// A 32-byte BLAKE3 hash of everything written so far is appended last.
pub fn encode(&self) -> Vec<u8> {
    let mut out: Vec<u8> = Vec::new();
    {
        // Every field goes through this single sink so the write order below
        // reads as the wire layout.
        let mut put = |bytes: &[u8]| out.extend_from_slice(bytes);

        put(MAGIC);
        put(&self.format_version.to_le_bytes());
        put(&self.prism_config_hash);
        put(&self.dim.to_le_bytes());
        put(&[self.metric_tag]);
        put(&self.n.to_le_bytes());
        put(&self.snapshot_max.to_le_bytes());
        put(&self.col_idx.to_le_bytes());

        put(&(self.filter_cols.len() as u32).to_le_bytes());
        for col in &self.filter_cols {
            put(&col.to_le_bytes());
        }

        put(&(self.dicts.len() as u32).to_le_bytes());
        for entries in &self.dicts {
            put(&(entries.len() as u64).to_le_bytes());
            for (key, id) in entries {
                put(&(key.len() as u64).to_le_bytes());
                put(key);
                put(&id.to_le_bytes());
            }
        }

        put(&self.content_fingerprint);
        put(&self.segment_b3);
        put(&self.chunk_count.to_le_bytes());
        put(&(self.writer.len() as u32).to_le_bytes());
        put(self.writer.as_bytes());
    }

    // The trailing hash covers every byte written above.
    let digest = blake3::hash(&out);
    out.extend_from_slice(digest.as_bytes());
    out
}
pub fn decode(bytes: &[u8]) -> Result<Self> {
let fail = |what: &str| SqlError::InvalidValue(format!("ANN segment header: {what}"));
if bytes.len() < 32 {
return Err(fail("truncated"));
}
let (body, hash) = bytes.split_at(bytes.len() - 32);
if blake3::hash(body).as_bytes() != hash {
return Err(fail("self-hash mismatch (corrupt)"));
}
let mut at = 0usize;
let mut take = |n: usize| -> Result<&[u8]> {
let end = at.checked_add(n).filter(|&e| e <= body.len());
let end = end.ok_or_else(|| fail("truncated"))?;
let s = &body[at..end];
at = end;
Ok(s)
};
if take(7)? != MAGIC {
return Err(fail("bad magic"));
}
let format_version = u16::from_le_bytes(take(2)?.try_into().unwrap());
let prism_config_hash: [u8; 32] = take(32)?.try_into().unwrap();
let dim = u16::from_le_bytes(take(2)?.try_into().unwrap());
let metric_tag = take(1)?[0];
let n = u64::from_le_bytes(take(8)?.try_into().unwrap());
let snapshot_max = u64::from_le_bytes(take(8)?.try_into().unwrap());
let col_idx = u32::from_le_bytes(take(4)?.try_into().unwrap());
let fc_len = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
let mut filter_cols = Vec::with_capacity(fc_len);
for _ in 0..fc_len {
filter_cols.push(u32::from_le_bytes(take(4)?.try_into().unwrap()));
}
let dicts_len = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
let mut dicts = Vec::with_capacity(dicts_len);
for _ in 0..dicts_len {
let entries = u64::from_le_bytes(take(8)?.try_into().unwrap()) as usize;
let mut dict = Vec::with_capacity(entries);
for _ in 0..entries {
let klen = u64::from_le_bytes(take(8)?.try_into().unwrap()) as usize;
let k = take(klen)?.to_vec();
let v = u32::from_le_bytes(take(4)?.try_into().unwrap());
dict.push((k, v));
}
dicts.push(dict);
}
let content_fingerprint: [u8; 32] = take(32)?.try_into().unwrap();
let segment_b3: [u8; 32] = take(32)?.try_into().unwrap();
let chunk_count = u32::from_le_bytes(take(4)?.try_into().unwrap());
let wlen = u32::from_le_bytes(take(4)?.try_into().unwrap()) as usize;
let writer = String::from_utf8_lossy(take(wlen)?).into_owned();
if at != body.len() {
return Err(fail("trailing bytes"));
}
Ok(Self {
format_version,
prism_config_hash,
dim,
metric_tag,
n,
snapshot_max,
col_idx,
filter_cols,
dicts,
content_fingerprint,
segment_b3,
chunk_count,
writer,
})
}
pub fn dict_maps(&self) -> Vec<FxHashMap<Vec<u8>, u32>> {
self.dicts
.iter()
.map(|d| d.iter().cloned().collect())
.collect()
}
}
/// Incremental BLAKE3 hasher that produces a 32-byte content fingerprint
/// from a table/index description followed by a sequence of rows.
pub struct FingerprintHasher {
    h: blake3::Hasher,
}

impl FingerprintHasher {
    /// Starts a fingerprint seeded with a domain tag and the index shape.
    ///
    /// Absorbed in order: the domain tag, the table name with a `u64` length
    /// prefix, `col_idx`, the filter columns with a `u32` count prefix,
    /// `dim`, and `metric_tag`. All integers are little-endian.
    pub fn new(table: &str, col_idx: u32, filter_cols: &[u32], dim: u16, metric_tag: u8) -> Self {
        let mut state = blake3::Hasher::new();
        state
            .update(b"citadel-annseg-fp-v1")
            .update(&(table.len() as u64).to_le_bytes())
            .update(table.as_bytes())
            .update(&col_idx.to_le_bytes())
            .update(&(filter_cols.len() as u32).to_le_bytes());
        for col in filter_cols.iter().copied() {
            state.update(&col.to_le_bytes());
        }
        state.update(&dim.to_le_bytes()).update(&[metric_tag]);
        Self { h: state }
    }

    /// Absorbs one row: the key, the raw vector bytes, then each encoded
    /// filter value, every one preceded by its `u64` little-endian length so
    /// that shifting bytes between adjacent fields changes the fingerprint.
    pub fn row(&mut self, key: &[u8], vector_raw: &[u8], filter_encoded: &[&[u8]]) {
        let fields = [key, vector_raw]
            .into_iter()
            .chain(filter_encoded.iter().copied());
        for field in fields {
            self.h
                .update(&(field.len() as u64).to_le_bytes())
                .update(field);
        }
    }

    /// Consumes the hasher and returns the 32-byte fingerprint.
    pub fn finish(self) -> [u8; 32] {
        <[u8; 32]>::from(self.h.finalize())
    }
}
/// Returns the 32-byte hash of the index configuration that
/// `citadel_vector::AnnIndex::active_config` reports for `metric`.
pub fn active_config_hash(metric: citadel_vector::Metric) -> [u8; 32] {
    let active: PrismConfig = citadel_vector::AnnIndex::active_config(metric);
    let digest: [u8; 32] = segment::prism_config_hash(&active);
    digest
}
/// Compact summary of an ANN segment.
///
/// NOTE(review): the field names and types match the same-named fields of
/// `SegmentHeader`; presumably this is filled from a decoded header — confirm
/// against the code that constructs it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AnnSegmentInfo {
/// 32-byte digest of the segment (presumably BLAKE3 of the segment body — TODO confirm).
pub segment_b3: [u8; 32],
/// 32-byte content fingerprint (presumably from `FingerprintHasher::finish` — TODO confirm).
pub content_fingerprint: [u8; 32],
/// NOTE(review): presumably the number of rows/vectors in the segment — confirm.
pub n: u64,
/// Vector dimensionality (assumed from the name — TODO confirm).
pub dim: u16,
/// One-byte tag identifying the distance metric (assumed — TODO confirm the mapping).
pub metric_tag: u8,
/// NOTE(review): presumably the number of stored chunks — confirm.
pub chunk_count: u32,
}
pub(crate) fn purge_segment(
wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
table_name: &str,
) -> Result<()> {
match wtx.drop_table(&segment_table_name(table_name)) {
Ok(()) => Ok(()),
Err(citadel_core::Error::TableNotFound(_)) => Ok(()),
Err(e) => Err(SqlError::Storage(e)),
}
}
/// Splits `body` into consecutive pieces of at most `CHUNK_BYTES` bytes and
/// pairs each piece with its chunk number.
///
/// Numbering starts at 1, not 0. Only the final piece may be shorter than
/// `CHUNK_BYTES`; an empty `body` yields nothing.
pub fn chunks(body: &[u8]) -> impl Iterator<Item = (u32, &[u8])> {
    let pieces = body.chunks(CHUNK_BYTES).enumerate();
    pieces.map(|(index, piece)| {
        let chunk_no = (index + 1) as u32;
        (chunk_no, piece)
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully populated header with two dictionaries, used by the header tests.
    fn header_fixture() -> SegmentHeader {
        let dicts = vec![
            vec![(b"region".to_vec(), 0), (b"other".to_vec(), 1)],
            vec![(b"kind".to_vec(), 0)],
        ];
        SegmentHeader {
            format_version: ANNSEG_FORMAT_VERSION,
            prism_config_hash: [7; 32],
            dim: 768,
            metric_tag: 2,
            n: 311_592,
            snapshot_max: 99,
            col_idx: 3,
            filter_cols: vec![1, 2],
            dicts,
            content_fingerprint: [9; 32],
            segment_b3: [4; 32],
            chunk_count: 41,
            writer: "citadel-test".into(),
        }
    }

    /// Encoding then decoding must give back an equal header.
    #[test]
    fn header_roundtrips_exactly() {
        let original = header_fixture();
        let encoded = original.encode();
        let decoded = SegmentHeader::decode(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    /// Inverting any single byte of the encoded header must make decode fail.
    #[test]
    fn header_corruption_is_refused() {
        let pristine = header_fixture().encode();
        let offsets = [0, 9, 45, pristine.len() / 2, pristine.len() - 40];
        for &spot in &offsets {
            let mut corrupt = pristine.clone();
            corrupt[spot] = !corrupt[spot];
            let outcome = SegmentHeader::decode(&corrupt);
            assert!(outcome.is_err(), "corruption at {spot} must refuse");
        }
    }

    /// Moving a byte across the key/vector boundary, or changing the column
    /// index or filter columns, must change the fingerprint.
    #[test]
    fn fingerprint_is_framed_against_boundary_shifts() {
        // Fingerprint of a single row under table "t", dim 4, metric tag 2.
        let digest = |col_idx: u32, filter_cols: &[u32], key: &[u8], vector: &[u8]| {
            let mut hasher = FingerprintHasher::new("t", col_idx, filter_cols, 4, 2);
            hasher.row(key, vector, &[]);
            hasher.finish()
        };

        let base = digest(0, &[], b"ab", b"cd");
        assert_ne!(base, digest(0, &[], b"abc", b"d"));
        assert_ne!(digest(1, &[], b"ab", b"cd"), base);
        assert_ne!(digest(0, &[2], b"ab", b"cd"), base);
    }

    /// A body slightly over one chunk yields a full chunk then the remainder,
    /// numbered from 1.
    #[test]
    fn chunking_covers_the_body_in_order() {
        let body = vec![0xABu8; CHUNK_BYTES + 17];
        let mut parts: Vec<(u32, usize)> = Vec::new();
        for (no, piece) in chunks(&body) {
            parts.push((no, piece.len()));
        }
        assert_eq!(parts, vec![(1, CHUNK_BYTES), (2, 17)]);
    }
}