1use std::any::Any;
6use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8use std::sync::Arc;
9
10use citadel_txn::read_txn::ReadTxn;
11use citadel_txn::write_txn::WriteTxn;
12use citadel_vector::{AnnIndex, Filter, Metric};
13use rustc_hash::FxHashMap;
14
15use crate::encoding::{
16 decode_column_raw, decode_pk_integer, encode_int_key_into, encode_key_value,
17 encode_key_value_collated_into,
18};
19use crate::error::{Result, SqlError};
20use crate::eval::{eval_expr, is_truthy, ColumnMap, EvalCtx};
21use crate::parser::*;
22use crate::schema::SchemaManager;
23use crate::types::*;
24
25use super::aggregate::is_aggregate_expr;
26use super::ann_persist;
27use super::helpers::{decode_full_row, eval_const_expr, eval_const_int, project_rows};
28use super::window::has_any_window_function;
29
/// Result type of the raw storage layer; its error is `citadel_core::Error`.
type StorageResult<T> = std::result::Result<T, citadel_core::Error>;
/// Per-row scan callback with SQL-level errors. It receives a row's key bytes and value
/// bytes; by the convention used in this file it returns `Ok(true)` to keep scanning and
/// `Ok(false)` to stop.
type ScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> Result<bool> + 'a;
/// Same callback shape as `ScanRow`, but reporting storage-level errors.
type RawScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> StorageResult<bool> + 'a;
/// A ranked candidate: `(distance, integer primary key, decoded row)`.
type RankedRow = (f64, i64, Vec<Value>);
35
/// Storage access needed by the vector query plans in this file.
///
/// Implemented below for both `ReadTxn` and `WriteTxn`, so the same plan code can run
/// over either kind of transaction.
pub(super) trait AnnScan {
    /// Scans `table`, calling `f` with the key and value bytes of each row.
    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()>;
    /// Like `ann_scan`, but the scan is started from `start_key`.
    fn ann_scan_from(&mut self, table: &[u8], start_key: &[u8], f: &mut ScanRow<'_>) -> Result<()>;
    /// Point lookup of `key` in `table`.
    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>>;
    /// Generation used to tag cached indexes. `None` means an index built through this
    /// transaction is never inserted into the shared cache (see `load_or_build`).
    fn cache_generation(&self) -> Option<u64>;
    /// Root page of `table` widened to `u64`, or `None` when it cannot be obtained.
    fn ann_table_root(&self, table: &[u8]) -> Option<u64>;
}
48
49fn bridge_scan(
51 scan: impl FnOnce(&mut RawScanRow<'_>) -> StorageResult<()>,
52 f: &mut ScanRow<'_>,
53) -> Result<()> {
54 let mut cb_err: Option<SqlError> = None;
55 scan(&mut |key, value| match f(key, value) {
56 Ok(go) => Ok(go),
57 Err(e) => {
58 cb_err = Some(e);
59 Ok(false)
60 }
61 })
62 .map_err(SqlError::Storage)?;
63 match cb_err {
64 Some(e) => Err(e),
65 None => Ok(()),
66 }
67}
68
69impl AnnScan for ReadTxn<'_> {
70 fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
71 bridge_scan(|cb| self.table_scan_from(table, b"", cb), f)
72 }
73
74 fn ann_scan_from(&mut self, table: &[u8], start_key: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
75 bridge_scan(|cb| self.table_scan_from(table, start_key, cb), f)
76 }
77
78 fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
79 self.table_get(table, key).map_err(SqlError::Storage)
80 }
81
82 fn cache_generation(&self) -> Option<u64> {
83 Some(self.commit_generation())
84 }
85
86 fn ann_table_root(&self, table: &[u8]) -> Option<u64> {
87 self.table_root_page(table)
88 .ok()
89 .flatten()
90 .map(|p| u64::from(p.0))
91 }
92}
93
94impl AnnScan for WriteTxn<'_> {
95 fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
96 bridge_scan(|cb| self.table_scan_from(table, b"", cb), f)
97 }
98
99 fn ann_scan_from(&mut self, table: &[u8], start_key: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
100 bridge_scan(|cb| self.table_scan_from(table, start_key, cb), f)
101 }
102
103 fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
104 self.table_get(table, key).map_err(SqlError::Storage)
105 }
106
107 fn cache_generation(&self) -> Option<u64> {
108 None
109 }
110
111 fn ann_table_root(&self, table: &[u8]) -> Option<u64> {
112 self.table_root_page(table)
113 .ok()
114 .flatten()
115 .map(|p| u64::from(p.0))
116 }
117}
118
/// Records how a cached ANN index came into existence.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnnIndexSource {
    /// Built by scanning the table. `refusal` carries the reason a persisted segment was
    /// refused first, or `None` when the build did not follow a refusal.
    Built { refusal: Option<String> },
    /// Loaded from a persisted segment; `segment_b3` is the segment hash taken from the
    /// segment header (checked against the BLAKE3 hash of the body on load).
    Loaded { segment_b3: [u8; 32] },
}
129
/// An ANN index together with the metadata needed to query and cache it.
struct CachedAnnIndex {
    /// The searchable index.
    index: AnnIndex,
    /// One dictionary per filter column, mapping a collation-encoded value to its code.
    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
    /// Whether the index was built from a scan or loaded from a segment.
    source: AnnIndexSource,
    /// Cache generation of the transaction that produced the index (`u64::MAX` when the
    /// transaction reported none).
    cached_gen: u64,
}
140
/// Plan for an `ORDER BY <vector column> <distance op> <constant vector> LIMIT k` query
/// answered through a declared ANN index.
pub(super) struct AnnTopKPlan {
    /// Schema position of the vector column being ranked.
    col_idx: usize,
    /// Declared dimension of that column.
    dim: u16,
    /// Distance metric implied by the ORDER BY operator.
    metric: AnnMetric,
    /// The constant query vector.
    query_vec: Vec<f32>,
    /// LIMIT value (always non-zero for a constructed plan).
    k: usize,
    /// OFFSET value (0 when absent).
    offset: usize,
    /// Schema positions of the index's filter columns.
    filter_cols: Vec<u16>,
    /// WHERE constraints pushed into the index: `(position in filter_cols, accepted values)`.
    pushable: Vec<(usize, Vec<Value>)>,
    /// Conjunction of the WHERE leaves that could not be pushed down, if any.
    residual: Option<Expr>,
}
155
156fn topk_shape_ok(stmt: &SelectStmt) -> bool {
158 stmt.order_by.len() == 1
159 && !stmt.order_by[0].descending
160 && stmt.limit.is_some()
161 && stmt.group_by.is_empty()
162 && stmt.having.is_none()
163 && stmt.joins.is_empty()
164 && !stmt.distinct
165 && !has_any_window_function(stmt)
166 && !stmt
167 .columns
168 .iter()
169 .any(|c| matches!(c, SelectColumn::Expr { expr, .. } if is_aggregate_expr(expr)))
170}
171
/// Result of one attempt to answer a query from a cached index.
enum RunOutcome {
    /// The query was answered.
    Done(ExecutionResult),
    /// The unindexed tail was judged stale; the caller should drop the cached index,
    /// rebuild and retry.
    Rebuild,
}
177
178fn tail_distance(metric: AnnMetric, q: &[f32], v: &[f32]) -> Option<f64> {
180 let d = match metric {
181 AnnMetric::L2 => {
182 let mut sum = 0.0f64;
183 for (x, y) in q.iter().zip(v.iter()) {
184 let diff = (*x as f64) - (*y as f64);
185 sum += diff * diff;
186 }
187 sum.sqrt()
188 }
189 AnnMetric::Inner => {
190 let mut sum = 0.0f64;
191 for (x, y) in q.iter().zip(v.iter()) {
192 sum += (*x as f64) * (*y as f64);
193 }
194 -sum
195 }
196 AnnMetric::Cosine => {
197 let mut dot = 0.0f64;
198 let mut nq = 0.0f64;
199 let mut nv = 0.0f64;
200 for (x, y) in q.iter().zip(v.iter()) {
201 let xf = *x as f64;
202 let yf = *y as f64;
203 dot += xf * yf;
204 nq += xf * xf;
205 nv += yf * yf;
206 }
207 let denom = nq.sqrt() * nv.sqrt();
208 if denom == 0.0 {
209 return None;
210 }
211 1.0 - dot / denom
212 }
213 };
214 Some(d)
215}
216
217impl AnnTopKPlan {
218 pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
219 if !topk_shape_ok(stmt) {
220 return Ok(None);
221 }
222 let ob = &stmt.order_by[0];
223
224 let (col_idx, dim, op_metric, query_vec) = match &ob.expr {
225 Expr::BinaryOp { left, op, right } => {
226 let op_metric = match op {
227 BinOp::VectorL2 => AnnMetric::L2,
228 BinOp::VectorInner => AnnMetric::Inner,
229 BinOp::VectorCosine => AnnMetric::Cosine,
230 _ => return Ok(None),
231 };
232 let col_name = match left.as_ref() {
233 Expr::Column(name) => name.to_ascii_lowercase(),
234 _ => return Ok(None),
235 };
236 let (col_idx, dim) = match table_schema
237 .columns
238 .iter()
239 .enumerate()
240 .find(|(_, c)| c.name.to_ascii_lowercase() == col_name)
241 {
242 Some((i, c)) => match c.data_type {
243 DataType::Vector { dim } => (i, dim),
244 _ => return Ok(None),
245 },
246 None => return Ok(None),
247 };
248 let col_map = ColumnMap::new(&table_schema.columns);
249 let ctx = EvalCtx::new(&col_map, &[]);
250 let v = match eval_expr(right, &ctx) {
251 Ok(Value::Vector(v)) => v,
252 _ => return Ok(None),
253 };
254 if v.len() != dim as usize {
255 return Err(SqlError::InvalidValue(format!(
256 "ANN query vector dim {} does not match column dim {}",
257 v.len(),
258 dim
259 )));
260 }
261 (col_idx, dim, op_metric, v.to_vec())
262 }
263 _ => return Ok(None),
264 };
265
266 let ann_index = table_schema.indices.iter().find(|ix| {
267 matches!(ix.kind,
268 IndexKind::Inverted(InvertedKind::Ann { metric }) if metric == op_metric
269 ) && ix.keys.len() == 1
270 && matches!(ix.keys[0],
271 IndexKey::Column { idx, .. } if idx as usize == col_idx
272 )
273 });
274 let Some(ann_index) = ann_index else {
275 return Ok(None);
276 };
277 let filter_cols = ann_index.ann_filter_cols.clone();
278
279 if table_schema.primary_key_columns.len() != 1 {
280 return Ok(None);
281 }
282 let pk_col = &table_schema.columns[table_schema.primary_key_columns[0] as usize];
283 if !matches!(pk_col.data_type, DataType::Integer) {
284 return Ok(None);
285 }
286
287 let mut pushable: Vec<(usize, Vec<Value>)> = Vec::new();
289 let mut residual_leaves: Vec<Expr> = Vec::new();
290 if let Some(w) = &stmt.where_clause {
291 split_where(
292 w,
293 &filter_cols,
294 table_schema,
295 &mut pushable,
296 &mut residual_leaves,
297 );
298 if pushable.is_empty() {
299 return Ok(None);
300 }
301 }
302 let residual = fold_and(residual_leaves);
303
304 let k_limit = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
305 let offset = stmt
306 .offset
307 .as_ref()
308 .map(eval_const_int)
309 .transpose()?
310 .unwrap_or(0)
311 .max(0) as usize;
312 if k_limit == 0 {
313 return Ok(None);
314 }
315
316 Ok(Some(Self {
317 col_idx,
318 dim,
319 metric: op_metric,
320 query_vec,
321 k: k_limit,
322 offset,
323 filter_cols,
324 pushable,
325 residual,
326 }))
327 }
328
329 pub(super) fn execute_with_read(
330 &self,
331 rtx: &mut ReadTxn<'_>,
332 schema: &SchemaManager,
333 stmt: &SelectStmt,
334 table_schema: &TableSchema,
335 ) -> Result<ExecutionResult> {
336 let cache_key = cache_key(&table_schema.name, self.col_idx, self.metric);
337 let mut force_rebuild = false;
339 loop {
340 if force_rebuild {
341 schema.sql_caches.lock().remove(&cache_key);
342 }
343 let Some(cached) = self.load_or_build_index(rtx, schema, &cache_key, table_schema)?
344 else {
345 return empty_result(table_schema, stmt);
346 };
347 match self.run_query(rtx, &cached, stmt, table_schema, !force_rebuild)? {
348 RunOutcome::Done(result) => return Ok(result),
349 RunOutcome::Rebuild => force_rebuild = true,
350 }
351 }
352 }
353
354 fn run_query(
356 &self,
357 txn: &mut dyn AnnScan,
358 cached: &CachedAnnIndex,
359 stmt: &SelectStmt,
360 table_schema: &TableSchema,
361 allow_rebuild: bool,
362 ) -> Result<RunOutcome> {
363 let mut constraints: Vec<(usize, Vec<u32>)> = Vec::with_capacity(self.pushable.len());
366 let mut index_unsat = false;
367 for (dim, values) in &self.pushable {
368 let dict = &cached.dicts[*dim];
369 let coll = table_schema.columns[self.filter_cols[*dim] as usize].collation;
370 let mut codes = Vec::with_capacity(values.len());
371 let mut canon = Vec::with_capacity(16);
372 for v in values {
373 canon.clear();
374 encode_key_value_collated_into(v, coll, &mut canon);
375 if let Some(&code) = dict.get(canon.as_slice()) {
376 codes.push(code);
377 }
378 }
379 if codes.is_empty() {
380 index_unsat = true;
381 }
382 constraints.push((*dim, codes));
383 }
384
385 let want = self.k.saturating_add(self.offset).max(1);
386 let mut merged: Vec<RankedRow> = if index_unsat {
387 Vec::new()
388 } else {
389 let filter = if constraints.is_empty() {
390 Filter::none()
391 } else {
392 Filter::new(constraints)
393 };
394 self.collect_survivors(txn, &cached.index, &filter, table_schema, want)?
395 };
396
397 match self.collect_tail(txn, &cached.index, table_schema, allow_rebuild)? {
398 Some(tail) => merged.extend(tail),
399 None => return Ok(RunOutcome::Rebuild),
400 }
401
402 merged.sort_by(|a, b| a.0.total_cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
404 let mut rows: Vec<Vec<Value>> = merged.into_iter().map(|(_, _, row)| row).collect();
405
406 if self.offset >= rows.len() {
407 rows.clear();
408 } else if self.offset > 0 {
409 rows = rows.split_off(self.offset);
410 }
411 rows.truncate(self.k);
412
413 let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
414 Ok(RunOutcome::Done(ExecutionResult::Query(QueryResult {
415 columns: col_names,
416 rows: projected,
417 })))
418 }
419
420 fn collect_survivors(
422 &self,
423 txn: &mut dyn AnnScan,
424 index: &AnnIndex,
425 filter: &Filter,
426 table_schema: &TableSchema,
427 want: usize,
428 ) -> Result<Vec<RankedRow>> {
429 let col_map = ColumnMap::new(&table_schema.columns);
430 let max_target = index.indexed_len().max(1);
431 let mut key_buf: Vec<u8> = Vec::with_capacity(10);
432 let mut target = want;
433 loop {
434 target = target.min(max_target);
435 let hits = index.search_filtered_default_ef(&self.query_vec, target, filter);
436 let mut survivors: Vec<RankedRow> = Vec::with_capacity(want);
437 for (id, dist) in &hits {
438 encode_int_key_into(*id as i64, &mut key_buf);
439 let Some(row_bytes) = txn.ann_get(table_schema.name.as_bytes(), &key_buf)? else {
440 continue;
441 };
442 let row = decode_full_row(table_schema, &key_buf, &row_bytes)?;
443 let keep = match &self.residual {
444 None => true,
445 Some(expr) => {
446 let ctx = EvalCtx::new(&col_map, &row);
447 is_truthy(&eval_expr(expr, &ctx)?)
448 }
449 };
450 if keep {
451 survivors.push((*dist as f64, *id as i64, row));
452 if survivors.len() >= want {
453 break;
454 }
455 }
456 }
457 if survivors.len() >= want || target >= max_target || hits.len() < target {
459 return Ok(survivors);
460 }
461 target = target.saturating_mul(2);
462 }
463 }
464
/// Exactly ranks the rows stored after the index snapshot (the "tail").
///
/// The scan starts at primary key `index.snapshot_max + 1` (presumably the first key
/// the index does not cover — confirm against `AnnIndex`). Each tail row must satisfy
/// the pushed-down constraints and the residual predicate, and is scored with
/// `tail_distance`.
///
/// Returns `Ok(None)` when `allow_rebuild` is set and `index.tail_is_stale` reports the
/// tail as stale, which tells the caller to rebuild the index; otherwise
/// `Ok(Some(rows))`.
///
/// # Errors
/// `SqlError::InvalidValue` when the ranked column decodes to something other than a
/// vector or NULL; scan, decode and evaluation errors are propagated.
fn collect_tail(
    &self,
    txn: &mut dyn AnnScan,
    index: &AnnIndex,
    table_schema: &TableSchema,
    allow_rebuild: bool,
) -> Result<Option<Vec<RankedRow>>> {
    let snapshot_max = index.snapshot_max;
    // No tail can exist when `snapshot_max` does not fit a non-negative i64 or when
    // adding one would overflow.
    let first_tail_pk = match (snapshot_max as i64).checked_add(1) {
        Some(pk) if (snapshot_max as i64) >= 0 => pk,
        _ => return Ok(Some(Vec::new())),
    };
    let mut start_key: Vec<u8> = Vec::with_capacity(10);
    encode_int_key_into(first_tail_pk, &mut start_key);

    let col_map = ColumnMap::new(&table_schema.columns);
    let mut out: Vec<RankedRow> = Vec::new();
    let mut seen: u64 = 0;
    let mut over_threshold = false;

    txn.ann_scan_from(
        table_schema.name.as_bytes(),
        &start_key,
        &mut |key, value| {
            // Every scanned row counts toward the staleness check, including rows that
            // are filtered out below.
            seen += 1;
            if allow_rebuild && index.tail_is_stale(snapshot_max.saturating_add(seen)) {
                over_threshold = true;
                // Stop scanning; the caller will rebuild.
                return Ok(false);
            }
            let row = decode_full_row(table_schema, key, value)?;
            if !self.tail_passes_pushable(&row, table_schema) {
                return Ok(true);
            }
            if let Some(expr) = &self.residual {
                let ctx = EvalCtx::new(&col_map, &row);
                if !is_truthy(&eval_expr(expr, &ctx)?) {
                    return Ok(true);
                }
            }
            let dist = match &row[self.col_idx] {
                Value::Vector(v) => match tail_distance(self.metric, &self.query_vec, v) {
                    Some(d) => d,
                    // No distance is defined (zero-norm cosine input): skip the row.
                    None => return Ok(true),
                },
                // Rows without a vector are not ranked.
                Value::Null => return Ok(true),
                _ => {
                    return Err(SqlError::InvalidValue(
                        "ANN column produced non-vector value".into(),
                    ))
                }
            };
            out.push((dist, decode_pk_integer(key)?, row));
            Ok(true)
        },
    )?;

    if over_threshold {
        return Ok(None);
    }
    Ok(Some(out))
}
529
530 fn tail_passes_pushable(&self, row: &[Value], table_schema: &TableSchema) -> bool {
532 for (dim, values) in &self.pushable {
533 let col = self.filter_cols[*dim] as usize;
534 let coll = table_schema.columns[col].collation;
535 let mut canon_row = Vec::with_capacity(16);
536 encode_key_value_collated_into(&row[col], coll, &mut canon_row);
537 let matched = values.iter().any(|v| {
538 let mut canon_v = Vec::with_capacity(16);
539 encode_key_value_collated_into(v, coll, &mut canon_v);
540 canon_v == canon_row
541 });
542 if !matched {
543 return false;
544 }
545 }
546 true
547 }
548
549 fn load_or_build_index(
550 &self,
551 txn: &mut dyn AnnScan,
552 schema: &SchemaManager,
553 cache_key: &str,
554 table_schema: &TableSchema,
555 ) -> Result<Option<Arc<CachedAnnIndex>>> {
556 if let Some(existing) = lookup_cached(schema, cache_key, &table_schema.name)? {
557 return Ok(Some(existing));
558 }
559 let spec = AnnSpec {
560 col_idx: self.col_idx,
561 dim: self.dim,
562 metric: self.metric,
563 filter_cols: self.filter_cols.clone(),
564 };
565 load_or_build(txn, schema, cache_key, table_schema, &spec)
566 }
567}
568
/// Identity of an ANN index: which column it covers and how it was configured.
pub(super) struct AnnSpec {
    /// Schema position of the vector column.
    pub col_idx: usize,
    /// Vector dimension.
    pub dim: u16,
    /// Distance metric.
    pub metric: AnnMetric,
    /// Schema positions of the filter columns carried by the index.
    pub filter_cols: Vec<u16>,
}
577
578impl AnnSpec {
579 fn metric_tag(&self) -> u8 {
580 citadel_vector::segment::metric_tag(ann_metric_to_prism(self.metric))
581 }
582}
583
/// Everything `scan_rows` gathers from one pass over a table.
struct ScanOutcome {
    /// One entry per row with a non-NULL vector: `(id, vector, filter codes)`.
    rows: Vec<(u64, Vec<f32>, Vec<u32>)>,
    /// One dictionary per filter column, mapping a collation-encoded value to its code.
    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
    /// Fingerprint produced by `ann_persist::FingerprintHasher` over the scanned rows.
    fingerprint: [u8; 32],
}
591
/// Scans the whole table and gathers the inputs for an ANN index build.
///
/// For every row the vector column and the filter columns are decoded and fed to the
/// fingerprint hasher. Rows whose vector is NULL contribute to the fingerprint but are
/// not indexed. For indexed rows each filter value is mapped to a dictionary code,
/// assigned in first-seen order per filter column.
///
/// # Errors
/// `SqlError::InvalidValue` when the vector column is a primary-key column or decodes
/// to something other than a vector or NULL; scan and decode errors are propagated.
fn scan_rows(
    txn: &mut dyn AnnScan,
    table_schema: &TableSchema,
    spec: &AnnSpec,
) -> Result<ScanOutcome> {
    let non_pk = table_schema.non_pk_indices();
    let enc_pos = table_schema.encoding_positions();
    // Locate the vector column among the non-PK columns to find its encoded position.
    let nonpk_order = non_pk
        .iter()
        .position(|&i| i == spec.col_idx)
        .ok_or_else(|| {
            SqlError::InvalidValue("vector column must be non-PK for ANN build".into())
        })?;
    let enc_idx = enc_pos[nonpk_order] as usize;

    let num_attrs = spec.filter_cols.len();
    // How to pull each filter column out of a raw key/value pair.
    let extracts: Vec<Extract> = spec
        .filter_cols
        .iter()
        .map(|&c| extract_plan(c, table_schema, non_pk, enc_pos))
        .collect::<Result<_>>()?;
    let collations: Vec<Collation> = spec
        .filter_cols
        .iter()
        .map(|&c| table_schema.columns[c as usize].collation)
        .collect();
    let mut dicts: Vec<FxHashMap<Vec<u8>, u32>> = vec![FxHashMap::default(); num_attrs];
    let mut fp = ann_persist::FingerprintHasher::new(
        &table_schema.name,
        spec.col_idx as u32,
        &spec
            .filter_cols
            .iter()
            .map(|&c| c as u32)
            .collect::<Vec<_>>(),
        spec.dim,
        spec.metric_tag(),
    );
    let mut rows: Vec<(u64, Vec<f32>, Vec<u32>)> = Vec::new();

    txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
        let vector = match decode_column_raw(value, enc_idx)?.to_value() {
            Value::Vector(arr) => Some(arr.to_vec()),
            // NULL vectors are hashed below but never indexed.
            Value::Null => None,
            _ => {
                return Err(SqlError::InvalidValue(
                    "ANN column produced non-vector value".into(),
                ))
            }
        };
        let mut filter_vals: Vec<Value> = Vec::with_capacity(num_attrs);
        for ex in &extracts {
            filter_vals.push(ex.extract(key, value)?);
        }
        // The fingerprint uses the plain key encoding of the filter values and the
        // little-endian bytes of the vector (empty for NULL).
        let encoded_filters: Vec<Vec<u8>> = filter_vals.iter().map(encode_key_value).collect();
        let vec_bytes: Vec<u8> = vector
            .as_deref()
            .unwrap_or(&[])
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        fp.row(
            key,
            &vec_bytes,
            &encoded_filters
                .iter()
                .map(Vec::as_slice)
                .collect::<Vec<_>>(),
        );
        // Must come after `fp.row`, so that NULL-vector rows are still fingerprinted.
        let Some(vector) = vector else {
            return Ok(true);
        };
        let id = decode_pk_integer(key)? as u64;
        let mut codes: Vec<u32> = Vec::with_capacity(num_attrs);
        for (j, v) in filter_vals.iter().enumerate() {
            // Dictionary keys use the collation-aware encoding, unlike the fingerprint.
            let mut canon = Vec::with_capacity(16);
            encode_key_value_collated_into(v, collations[j], &mut canon);
            // A value seen for the first time gets the next free code.
            let next = dicts[j].len() as u32;
            codes.push(*dicts[j].entry(canon).or_insert(next));
        }
        rows.push((id, vector, codes));
        Ok(true)
    })?;

    Ok(ScanOutcome {
        rows,
        dicts,
        fingerprint: fp.finish(),
    })
}
684
#[cfg(test)]
thread_local! {
    /// Test-only, per-thread count of ANN index builds.
    static ANN_REBUILD_COUNT: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
}

/// Test-only hook: records that an ANN index was built on the current thread.
#[cfg(test)]
fn note_ann_rebuild() {
    ANN_REBUILD_COUNT.with(|counter| {
        let bumped = counter.get() + 1;
        counter.set(bumped);
    });
}
695
/// Test-only: returns the number of ANN builds recorded on this thread and resets the
/// counter to zero.
#[cfg(test)]
pub(super) fn take_ann_rebuilds() -> u64 {
    ANN_REBUILD_COUNT.with(|counter| counter.take())
}
700
701fn build_index(
703 txn: &mut dyn AnnScan,
704 table_schema: &TableSchema,
705 spec: &AnnSpec,
706 refusal: Option<String>,
707 cached_gen: u64,
708) -> Result<Option<CachedAnnIndex>> {
709 let outcome = scan_rows(txn, table_schema, spec)?;
710 if outcome.rows.is_empty() {
711 return Ok(None);
712 }
713 let index = AnnIndex::build_with_attrs(
714 outcome.rows,
715 spec.filter_cols.len(),
716 ann_metric_to_prism(spec.metric),
717 spec.dim,
718 )
719 .map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
720 #[cfg(test)]
721 note_ann_rebuild();
722 Ok(Some(CachedAnnIndex {
723 index,
724 dicts: outcome.dicts,
725 source: AnnIndexSource::Built { refusal },
726 cached_gen,
727 }))
728}
729
/// Result of trying to load a persisted ANN segment.
enum LoadOutcome {
    /// The segment passed every check and was turned into an index.
    Loaded(Box<CachedAnnIndex>),
    /// No segment header could be read for the table.
    NoSegment,
    /// A segment exists but was not used. `corrupt` distinguishes damaged data from a
    /// well-formed segment that merely does not apply (version, config, identity, stale).
    Refused { reason: String, corrupt: bool },
}
737
738fn try_load_segment(
741 txn: &mut dyn AnnScan,
742 table_schema: &TableSchema,
743 spec: &AnnSpec,
744 cached_gen: u64,
745) -> Result<LoadOutcome> {
746 let seg_table = ann_persist::segment_table_name(&table_schema.name);
747 let header_bytes = match txn.ann_get(&seg_table, &ann_persist::segment_key(0)) {
748 Ok(Some(b)) => b,
749 Ok(None) | Err(_) => return Ok(LoadOutcome::NoSegment),
751 };
752 let refuse = |reason: String, corrupt: bool| Ok(LoadOutcome::Refused { reason, corrupt });
753 let header = match ann_persist::SegmentHeader::decode(&header_bytes) {
754 Ok(h) => h,
755 Err(e) => return refuse(format!("header: {e}"), true),
756 };
757 if header.format_version != ann_persist::ANNSEG_FORMAT_VERSION {
758 return refuse(
759 format!("format v{} (this binary reads v2)", header.format_version),
760 false,
761 );
762 }
763 let active_cfg = citadel_vector::segment::prism_config_hash(&AnnIndex::active_config(
764 ann_metric_to_prism(spec.metric),
765 ));
766 if header.prism_config_hash != active_cfg {
767 return refuse(
768 "PRISM config drift (segment built by another geometry)".into(),
769 false,
770 );
771 }
772 if header.dim != spec.dim
773 || header.metric_tag != spec.metric_tag()
774 || header.col_idx != spec.col_idx as u32
775 || header.filter_cols
776 != spec
777 .filter_cols
778 .iter()
779 .map(|&c| c as u32)
780 .collect::<Vec<_>>()
781 {
782 return refuse(
783 "index identity mismatch (column/metric/filter set)".into(),
784 false,
785 );
786 }
787
788 let mut body = Vec::new();
789 for chunk_no in 1..=header.chunk_count {
790 match txn.ann_get(&seg_table, &ann_persist::segment_key(chunk_no)) {
791 Ok(Some(c)) => body.extend_from_slice(&c),
792 _ => return refuse(format!("missing chunk {chunk_no}"), true),
793 }
794 }
795 if *blake3::hash(&body).as_bytes() != header.segment_b3 {
796 return refuse("segment body BLAKE3 mismatch (corrupt)".into(), true);
797 }
798 let parts = match citadel_vector::segment::decode(&body) {
799 Ok(p) => p,
800 Err(e) => return refuse(format!("segment decode: {e}"), true),
801 };
802 if parts.n() as u64 != header.n || parts.dim() != header.dim {
803 return refuse("segment body disagrees with header counts".into(), true);
804 }
805
806 match txn.ann_table_root(table_schema.name.as_bytes()) {
808 Some(live) if live == header.table_root => {}
809 _ => {
810 return refuse(
811 "stale: table root moved since the segment was persisted".into(),
812 false,
813 )
814 }
815 }
816
817 let index = parts.into_index_embedded();
819 Ok(LoadOutcome::Loaded(Box::new(CachedAnnIndex {
820 index,
821 dicts: header.dict_maps(),
822 source: AnnIndexSource::Loaded {
823 segment_b3: header.segment_b3,
824 },
825 cached_gen,
826 })))
827}
828
829fn load_or_build(
832 txn: &mut dyn AnnScan,
833 schema: &SchemaManager,
834 cache_key: &str,
835 table_schema: &TableSchema,
836 spec: &AnnSpec,
837) -> Result<Option<Arc<CachedAnnIndex>>> {
838 let gen = txn.cache_generation();
839 let cached_gen = gen.unwrap_or(u64::MAX);
840 let loaded = match try_load_segment(txn, table_schema, spec, cached_gen)? {
841 LoadOutcome::Loaded(c) => Some(*c),
842 LoadOutcome::NoSegment => None,
843 LoadOutcome::Refused { reason, corrupt } => {
844 if corrupt {
845 eprintln!(
846 "citadel-sql: ANN segment for `{}` REFUSED as corrupt ({reason}); \
847 rebuilding from scan - investigate before re-persisting",
848 table_schema.name
849 );
850 }
851 match build_index(txn, table_schema, spec, Some(reason), cached_gen)? {
853 Some(c) => Some(c),
854 None => return Ok(None),
855 }
856 }
857 };
858 let built = match loaded {
859 Some(c) => c,
860 None => match build_index(txn, table_schema, spec, None, cached_gen)? {
861 Some(c) => c,
862 None => return Ok(None),
863 },
864 };
865 let arc: Arc<CachedAnnIndex> = Arc::new(built);
866 if gen.is_none() {
867 return Ok(Some(arc));
869 }
870 let mut guard = schema.sql_caches.lock();
871 if let Some(existing) = guard.get(cache_key) {
872 return Arc::clone(existing)
874 .downcast::<CachedAnnIndex>()
875 .map(Some)
876 .map_err(|_| {
877 SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}"))
878 });
879 }
880 let marker = marker_gen_locked(&guard, &table_schema.name);
881 if marker.is_some_and(|g| arc.cached_gen < g) {
882 return Ok(Some(arc));
884 }
885 let as_any: Arc<dyn Any + Send + Sync> = arc.clone();
886 guard.insert(cache_key.to_string(), as_any);
887 Ok(Some(arc))
888}
889
/// Plan for an exact (full-scan) top-k query ordered by a vector distance expression.
pub(super) struct VectorTopKPlan {
    /// The ORDER BY expression, evaluated per row to obtain the distance.
    order_expr: Expr,
    /// The full WHERE clause, evaluated per row.
    where_clause: Option<Expr>,
    /// LIMIT value (always non-zero for a constructed plan).
    k: usize,
    /// OFFSET value (0 when absent).
    offset: usize,
    /// Whether rows with a NULL distance sort before all others (defaults to `true`).
    nulls_first: bool,
}
899
/// Heap entry used by `VectorTopKPlan::execute`.
struct Ranked {
    /// Distance of the row from the query vector (or the NULL sentinel).
    dist: f64,
    /// Scan-order sequence number, used to break distance ties.
    seq: u64,
    /// The decoded row.
    row: Vec<Value>,
}
907
908impl PartialEq for Ranked {
909 fn eq(&self, other: &Self) -> bool {
910 self.cmp(other) == Ordering::Equal
911 }
912}
913impl Eq for Ranked {}
914impl PartialOrd for Ranked {
915 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
916 Some(self.cmp(other))
917 }
918}
919impl Ord for Ranked {
920 fn cmp(&self, other: &Self) -> Ordering {
921 self.dist
922 .total_cmp(&other.dist)
923 .then_with(|| self.seq.cmp(&other.seq))
924 }
925}
926
927impl VectorTopKPlan {
928 pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
929 if !topk_shape_ok(stmt) {
930 return Ok(None);
931 }
932 let ob = &stmt.order_by[0];
933 let Expr::BinaryOp { left, op, .. } = &ob.expr else {
934 return Ok(None);
935 };
936 if !matches!(
937 op,
938 BinOp::VectorL2 | BinOp::VectorInner | BinOp::VectorCosine
939 ) {
940 return Ok(None);
941 }
942 let Expr::Column(name) = left.as_ref() else {
944 return Ok(None);
945 };
946 let name = name.to_ascii_lowercase();
947 let is_vector_col = table_schema.columns.iter().any(|c| {
948 c.name.to_ascii_lowercase() == name && matches!(c.data_type, DataType::Vector { .. })
949 });
950 if !is_vector_col {
951 return Ok(None);
952 }
953
954 let k = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
955 if k == 0 {
956 return Ok(None);
957 }
958 let offset = stmt
959 .offset
960 .as_ref()
961 .map(eval_const_int)
962 .transpose()?
963 .unwrap_or(0)
964 .max(0) as usize;
965
966 Ok(Some(Self {
967 order_expr: ob.expr.clone(),
968 where_clause: stmt.where_clause.clone(),
969 k,
970 offset,
971 nulls_first: ob.nulls_first.unwrap_or(true),
973 }))
974 }
975
976 pub(super) fn execute(
977 &self,
978 txn: &mut dyn AnnScan,
979 table_schema: &TableSchema,
980 stmt: &SelectStmt,
981 ) -> Result<ExecutionResult> {
982 let want = self.k.saturating_add(self.offset);
983 let col_map = ColumnMap::new(&table_schema.columns);
984 let null_dist = if self.nulls_first {
986 f64::NEG_INFINITY
987 } else {
988 f64::INFINITY
989 };
990 let mut heap: BinaryHeap<Ranked> = BinaryHeap::new();
991 let mut seq: u64 = 0;
992
993 txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
994 let row = decode_full_row(table_schema, key, value)?;
995 let ctx = EvalCtx::new(&col_map, &row);
996 if let Some(w) = &self.where_clause {
997 if !is_truthy(&eval_expr(w, &ctx)?) {
998 return Ok(true);
999 }
1000 }
1001 let dist = match eval_expr(&self.order_expr, &ctx)? {
1002 Value::Real(d) => d,
1003 Value::Integer(i) => i as f64,
1004 Value::Null => null_dist,
1005 other => {
1006 return Err(SqlError::InvalidValue(format!(
1007 "ORDER BY vector distance produced a non-numeric {}",
1008 other.data_type()
1009 )))
1010 }
1011 };
1012 let cand = Ranked { dist, seq, row };
1013 seq += 1;
1014 if heap.len() < want {
1016 heap.push(cand);
1017 } else if heap.peek().is_some_and(|top| cand < *top) {
1018 heap.pop();
1019 heap.push(cand);
1020 }
1021 Ok(true)
1022 })?;
1023
1024 let mut rows: Vec<Vec<Value>> = heap.into_sorted_vec().into_iter().map(|r| r.row).collect();
1025 if self.offset >= rows.len() {
1026 rows.clear();
1027 } else if self.offset > 0 {
1028 rows = rows.split_off(self.offset);
1029 }
1030 rows.truncate(self.k);
1031
1032 let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
1033 Ok(ExecutionResult::Query(QueryResult {
1034 columns: col_names,
1035 rows: projected,
1036 }))
1037 }
1038}
1039
/// How to obtain a filter column's value from a raw key/value pair.
enum Extract {
    /// The column is the primary key: decode it from the key bytes.
    Pk,
    /// A non-PK column: decode the encoded column at this position from the value bytes.
    NonPk(usize),
}
1047
1048impl Extract {
1049 fn extract(&self, key: &[u8], value: &[u8]) -> Result<Value> {
1050 match self {
1051 Extract::Pk => Ok(Value::Integer(decode_pk_integer(key)?)),
1052 Extract::NonPk(ei) => Ok(decode_column_raw(value, *ei)?.to_value()),
1053 }
1054 }
1055}
1056
1057fn extract_plan(
1058 col: u16,
1059 table_schema: &TableSchema,
1060 non_pk: &[usize],
1061 enc_pos: &[u16],
1062) -> Result<Extract> {
1063 if table_schema.primary_key_columns.contains(&col) {
1064 return Ok(Extract::Pk);
1065 }
1066 let order = non_pk
1067 .iter()
1068 .position(|&i| i == col as usize)
1069 .ok_or_else(|| SqlError::InvalidValue("ANN filter column not found in row".into()))?;
1070 Ok(Extract::NonPk(enc_pos[order] as usize))
1071}
1072
1073fn split_where(
1076 expr: &Expr,
1077 filter_cols: &[u16],
1078 table_schema: &TableSchema,
1079 pushable: &mut Vec<(usize, Vec<Value>)>,
1080 residual: &mut Vec<Expr>,
1081) {
1082 if let Expr::BinaryOp {
1083 left,
1084 op: BinOp::And,
1085 right,
1086 } = expr
1087 {
1088 split_where(left, filter_cols, table_schema, pushable, residual);
1089 split_where(right, filter_cols, table_schema, pushable, residual);
1090 return;
1091 }
1092 match classify_leaf(expr, filter_cols, table_schema) {
1093 Some(constraint) => pushable.push(constraint),
1094 None => residual.push(expr.clone()),
1095 }
1096}
1097
/// Outcome of fitting a WHERE literal to a filter column's type for pushdown.
enum Coerced {
    /// The literal, possibly converted, can be compared exactly against column values.
    Exact(Value),
    /// The literal can never equal a value of the column type; it is dropped from the
    /// accepted set.
    NeverMatches,
    /// Exact pushdown is not possible; the predicate must be evaluated per row.
    Residual,
}
1110
1111fn coerce_pushdown_literal(val: Value, col_type: DataType) -> Coerced {
1112 const EXACT_F64_INT: f64 = 9_007_199_254_740_992.0;
1114 if val.is_null() {
1115 return Coerced::Residual;
1116 }
1117 if val.data_type() == col_type {
1118 return Coerced::Exact(val);
1119 }
1120 match (val, col_type) {
1121 (Value::Real(r), DataType::Integer) => {
1122 if r.is_nan() || r.is_infinite() {
1123 Coerced::NeverMatches
1124 } else if r.abs() > EXACT_F64_INT {
1125 Coerced::Residual
1126 } else if r.fract() == 0.0 {
1127 Coerced::Exact(Value::Integer(r as i64))
1128 } else {
1129 Coerced::NeverMatches
1130 }
1131 }
1132 (Value::Integer(i), DataType::Real) => {
1133 if i.unsigned_abs() <= EXACT_F64_INT as u64 {
1134 Coerced::Exact(Value::Real(i as f64))
1135 } else {
1136 Coerced::Residual
1137 }
1138 }
1139 _ => Coerced::Residual,
1140 }
1141}
1142
1143fn classify_leaf(
1148 leaf: &Expr,
1149 filter_cols: &[u16],
1150 table_schema: &TableSchema,
1151) -> Option<(usize, Vec<Value>)> {
1152 let (col_expr, rhs): (&Expr, Vec<&Expr>) = match leaf {
1153 Expr::BinaryOp {
1154 left,
1155 op: BinOp::Eq,
1156 right,
1157 } => (left, vec![right.as_ref()]),
1158 Expr::InList {
1159 expr,
1160 list,
1161 negated: false,
1162 } => (expr, list.iter().collect()),
1163 _ => return None,
1164 };
1165 let dim = filter_dim(col_expr, filter_cols, table_schema)?;
1166 let col_type = table_schema.columns[filter_cols[dim] as usize].data_type;
1167 let mut vals = Vec::with_capacity(rhs.len());
1168 for e in rhs {
1169 match coerce_pushdown_literal(eval_const_expr(e).ok()?, col_type) {
1170 Coerced::Exact(v) => vals.push(v),
1171 Coerced::NeverMatches => {}
1172 Coerced::Residual => return None,
1173 }
1174 }
1175 Some((dim, vals))
1176}
1177
1178fn filter_dim(expr: &Expr, filter_cols: &[u16], table_schema: &TableSchema) -> Option<usize> {
1181 let name = match expr {
1182 Expr::Column(c) => c.to_ascii_lowercase(),
1183 Expr::QualifiedColumn { column, .. } => column.to_ascii_lowercase(),
1184 _ => return None,
1185 };
1186 let col_idx = table_schema
1187 .columns
1188 .iter()
1189 .position(|c| c.name.to_ascii_lowercase() == name)? as u16;
1190 filter_cols.iter().position(|&c| c == col_idx)
1191}
1192
1193fn fold_and(mut leaves: Vec<Expr>) -> Option<Expr> {
1194 if leaves.is_empty() {
1195 return None;
1196 }
1197 let first = leaves.remove(0);
1198 Some(leaves.into_iter().fold(first, |acc, e| Expr::BinaryOp {
1199 left: Box::new(acc),
1200 op: BinOp::And,
1201 right: Box::new(e),
1202 }))
1203}
1204
1205fn empty_result(table_schema: &TableSchema, stmt: &SelectStmt) -> Result<ExecutionResult> {
1206 let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, Vec::new())?;
1207 Ok(ExecutionResult::Query(QueryResult {
1208 columns: col_names,
1209 rows: projected,
1210 }))
1211}
1212
/// Builds the ANN index for `column` from the table's current rows, stores it
/// as a persisted segment, and publishes the built index in the SQL cache.
///
/// Steps, all inside one write transaction:
/// 1. validate the column, primary key and declared ANN index;
/// 2. scan the rows and build the index;
/// 3. encode the index, hash it, and assemble a segment header;
/// 4. purge the previous segment, then write the header and body chunks;
/// 5. commit, then insert the freshly built index into `schema.sql_caches`.
///
/// # Errors
/// - `SqlError::ColumnNotFound` if no column equals the lowercased `column`.
/// - `SqlError::InvalidValue` if the column is not `VECTOR(N)`, the table does
///   not have exactly one `INTEGER` primary-key column, no single-key ANN index
///   is declared on the column, the scan yields no rows, the index build
///   fails, or the table's root page cannot be found.
/// - `SqlError::Storage` for failures reported by the storage layer.
pub(crate) fn persist_ann_index(
    db: &citadel::Database,
    schema: &SchemaManager,
    table_schema: &TableSchema,
    column: &str,
) -> Result<ann_persist::AnnSegmentInfo> {
    // NOTE(review): the comparison is exact against the lowercased input, so
    // this presumably relies on schema column names being stored lowercase.
    let col_lower = column.to_ascii_lowercase();
    let col_idx = table_schema
        .columns
        .iter()
        .position(|c| c.name == col_lower)
        .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
    let DataType::Vector { dim } = table_schema.columns[col_idx].data_type else {
        return Err(SqlError::InvalidValue(format!(
            "column `{column}` is not VECTOR(N)"
        )));
    };
    // Exactly one primary-key column, and it must be INTEGER.
    if table_schema.primary_key_columns.len() != 1
        || !matches!(
            table_schema.columns[table_schema.primary_key_columns[0] as usize].data_type,
            DataType::Integer
        )
    {
        return Err(SqlError::InvalidValue(
            "ANN persistence requires a single INTEGER primary key (same rule as the \
             ANN query plan)"
                .into(),
        ));
    }
    // Find the ANN index whose only key is this column.
    let ann_index = table_schema
        .indices
        .iter()
        .find(|ix| {
            matches!(ix.kind, IndexKind::Inverted(InvertedKind::Ann { .. }))
                && ix.keys.len() == 1
                && matches!(ix.keys[0], IndexKey::Column { idx, .. } if idx as usize == col_idx)
        })
        .ok_or_else(|| SqlError::InvalidValue(format!("no ANN index declared on `{column}`")))?;
    // The `find` predicate above already guaranteed this shape.
    let IndexKind::Inverted(InvertedKind::Ann { metric }) = ann_index.kind else {
        unreachable!("matched above");
    };
    let spec = AnnSpec {
        col_idx,
        dim,
        metric,
        filter_cols: ann_index.ann_filter_cols.clone(),
    };

    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
    let outcome = scan_rows(&mut wtx, table_schema, &spec)?;
    if outcome.rows.is_empty() {
        // Returns without committing the write transaction.
        return Err(SqlError::InvalidValue(
            "nothing to persist: the table has no indexable (non-NULL) vectors".into(),
        ));
    }
    // Capture the row count before `outcome.rows` is moved into the builder.
    let n = outcome.rows.len() as u64;
    let index = AnnIndex::build_with_attrs(
        outcome.rows,
        spec.filter_cols.len(),
        ann_metric_to_prism(spec.metric),
        spec.dim,
    )
    .map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;

    // Serialized index body and its BLAKE3 digest.
    let body = citadel_vector::segment::encode(&index);
    let segment_b3 = *blake3::hash(&body).as_bytes();
    // Flatten each dictionary into (key, code) pairs sorted by code so the
    // header stores them in a deterministic order.
    let dicts_ordered: Vec<Vec<(Vec<u8>, u32)>> = outcome
        .dicts
        .iter()
        .map(|d| {
            let mut entries: Vec<(Vec<u8>, u32)> = d.iter().map(|(k, &v)| (k.clone(), v)).collect();
            entries.sort_by_key(|&(_, code)| code);
            entries
        })
        .collect();
    // Root page of the source table as seen by this write transaction; it is
    // recorded in the header.
    let table_root = wtx
        .table_root_page(table_schema.name.as_bytes())
        .map_err(SqlError::Storage)?
        .map(|p| u64::from(p.0))
        .ok_or_else(|| SqlError::InvalidValue("table vanished during ANN persist".into()))?;
    let header = ann_persist::SegmentHeader {
        format_version: ann_persist::ANNSEG_FORMAT_VERSION,
        prism_config_hash: ann_persist::active_config_hash(ann_metric_to_prism(spec.metric)),
        dim: spec.dim,
        metric_tag: spec.metric_tag(),
        n,
        snapshot_max: index.snapshot_max,
        table_root,
        col_idx: spec.col_idx as u32,
        filter_cols: spec.filter_cols.iter().map(|&c| c as u32).collect(),
        dicts: dicts_ordered,
        content_fingerprint: outcome.fingerprint,
        segment_b3,
        chunk_count: body.len().div_ceil(ann_persist::CHUNK_BYTES) as u32,
        writer: format!("citadel-sql {}", env!("CARGO_PKG_VERSION")),
    };

    // Replace any previous segment: purge, recreate the segment table, write
    // the header under key 0, then the body chunks under the keys yielded by
    // `ann_persist::chunks`.
    let seg_table = ann_persist::segment_table_name(&table_schema.name);
    ann_persist::purge_segment(&mut wtx, &table_schema.name)?;
    wtx.create_table(&seg_table).map_err(SqlError::Storage)?;
    wtx.table_insert(&seg_table, &ann_persist::segment_key(0), &header.encode())
        .map_err(SqlError::Storage)?;
    for (chunk_no, chunk) in ann_persist::chunks(&body) {
        wtx.table_insert(&seg_table, &ann_persist::segment_key(chunk_no), chunk)
            .map_err(SqlError::Storage)?;
    }
    wtx.commit().map_err(SqlError::Storage)?;

    // Publish the just-built index in the cache, stamped with the commit
    // generation read after the commit above.
    let cached = CachedAnnIndex {
        index,
        dicts: outcome.dicts,
        source: AnnIndexSource::Built { refusal: None },
        cached_gen: db.manager().commit_generation(),
    };
    let key = cache_key(&table_schema.name, spec.col_idx, spec.metric);
    let as_any: Arc<dyn Any + Send + Sync> = Arc::new(cached);
    schema.sql_caches.lock().insert(key, as_any);

    Ok(ann_persist::AnnSegmentInfo {
        segment_b3,
        content_fingerprint: header.content_fingerprint,
        n,
        dim: spec.dim,
        metric_tag: header.metric_tag,
        chunk_count: header.chunk_count,
    })
}
1351
1352pub(crate) fn ann_cache_status(
1355 schema: &SchemaManager,
1356 table_schema: &TableSchema,
1357 column: &str,
1358) -> Result<Option<(AnnIndexSource, u64)>> {
1359 let col_lower = column.to_ascii_lowercase();
1360 let col_idx = table_schema
1361 .columns
1362 .iter()
1363 .position(|c| c.name == col_lower)
1364 .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
1365 let guard = schema.sql_caches.lock();
1366 for metric in [AnnMetric::L2, AnnMetric::Inner, AnnMetric::Cosine] {
1367 let key = cache_key(&table_schema.name, col_idx, metric);
1368 if let Some(entry) = guard.get(&key) {
1369 if let Ok(c) = Arc::clone(entry).downcast::<CachedAnnIndex>() {
1370 return Ok(Some((c.source.clone(), c.cached_gen)));
1371 }
1372 }
1373 }
1374 Ok(None)
1375}
1376
/// Returns the cache key under which the DML generation marker for
/// `table_name` is stored: the literal prefix `ann_dml_gen:` followed by the
/// table name exactly as given.
pub(crate) fn ann_dml_gen_key(table_name: &str) -> String {
    const PREFIX: &str = "ann_dml_gen:";
    let mut key = String::with_capacity(PREFIX.len() + table_name.len());
    key.push_str(PREFIX);
    key.push_str(table_name);
    key
}
1383
1384pub(crate) fn ann_appends_safe(schema: &SchemaManager, table: &str, min_pk: i64) -> bool {
1387 let prefix = format!("ann:{}:", table.to_ascii_lowercase());
1388 let guard = schema.sql_caches.lock();
1389 for (key, val) in guard.iter() {
1390 if !key.starts_with(&prefix) {
1391 continue;
1392 }
1393 if let Some(cached) = val.downcast_ref::<CachedAnnIndex>() {
1394 let snap = cached.index.snapshot_max as i64;
1395 if snap < 0 || min_pk <= snap {
1396 return false;
1397 }
1398 }
1399 }
1400 true
1401}
1402
1403fn marker_gen_locked(
1405 entries: &FxHashMap<String, Arc<dyn Any + Send + Sync>>,
1406 table_name: &str,
1407) -> Option<u64> {
1408 entries
1409 .get(&ann_dml_gen_key(table_name))
1410 .and_then(|e| e.downcast_ref::<u64>())
1411 .copied()
1412}
1413
1414fn lookup_cached(
1415 schema: &SchemaManager,
1416 cache_key: &str,
1417 table_name: &str,
1418) -> Result<Option<Arc<CachedAnnIndex>>> {
1419 let mut guard = schema.sql_caches.lock();
1420 let Some(entry) = guard.get(cache_key) else {
1421 return Ok(None);
1422 };
1423 let entry = Arc::clone(entry)
1424 .downcast::<CachedAnnIndex>()
1425 .map_err(|_| SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}")))?;
1426 if marker_gen_locked(&guard, table_name).is_some_and(|g| entry.cached_gen < g) {
1427 guard.remove(cache_key);
1429 return Ok(None);
1430 }
1431 Ok(Some(entry))
1432}
1433
1434pub(super) fn cache_key(table_name: &str, col_idx: usize, metric: AnnMetric) -> String {
1435 let tag = match metric {
1436 AnnMetric::L2 => "l2",
1437 AnnMetric::Inner => "inner",
1438 AnnMetric::Cosine => "cosine",
1439 };
1440 format!(
1441 "ann:{}:{}:{}",
1442 table_name.to_ascii_lowercase(),
1443 col_idx,
1444 tag
1445 )
1446}
1447
1448fn ann_metric_to_prism(m: AnnMetric) -> Metric {
1449 match m {
1450 AnnMetric::L2 => Metric::L2,
1451 AnnMetric::Inner => Metric::InnerProduct,
1452 AnnMetric::Cosine => Metric::Cosine,
1453 }
1454}
1455
/// End-to-end tests for how DML interacts with the cached ANN index: plain
/// appends must not force rebuilds, while updates, deletes and gap-fill
/// inserts must invalidate the cache and still return correct results.
///
/// NOTE(review): the tests discard `take_ann_rebuilds()` before the phase they
/// measure, so it presumably returns the rebuild count since the previous call
/// and resets it — confirm against its definition.
#[cfg(test)]
mod thrash_tests {
    use super::take_ann_rebuilds;
    use crate::{Connection, ExecutionResult, Value};
    use citadel::{Argon2Profile, DatabaseBuilder};

    /// Vector dimension used by every test; matches `VECTOR(8)` in the DDL below.
    const DIM: usize = 8;

    /// Deterministic test vector for row `i`, with every component in `[0, 1)`.
    fn vec_for(i: u64) -> Vec<f32> {
        (0..DIM)
            .map(|d| {
                // Multiplicative scramble of `i` plus a per-dimension offset,
                // reduced to 0..1000 and scaled down.
                let x = (i.wrapping_mul(2654435761).wrapping_add(d as u64 * 40503) % 1000) as f32;
                x / 1000.0
            })
            .collect()
    }

    /// Renders `v` as a SQL vector literal cast to `VECTOR(DIM)`.
    fn vec_literal(v: &[f32]) -> String {
        let parts: Vec<String> = v.iter().map(|x| format!("{x}")).collect();
        format!("'[{}]'::VECTOR({})", parts.join(", "), DIM)
    }

    /// Runs a filtered `ORDER BY v <-> q LIMIT k` query against table `t` and
    /// returns the resulting ids in rank order.
    fn recall_ids(conn: &Connection<'_>, qvec: &[f32], k: usize) -> Vec<i64> {
        let sql = format!(
            "SELECT id FROM t WHERE category = 0 ORDER BY v <-> {} LIMIT {k}",
            vec_literal(qvec)
        );
        match conn.execute(&sql).unwrap() {
            ExecutionResult::Query(qr) => qr
                .rows
                .iter()
                .map(|r| match &r[0] {
                    Value::Integer(i) => *i,
                    other => panic!("expected Integer id, got {other:?}"),
                })
                .collect(),
            _ => panic!("expected query result"),
        }
    }

    /// Alternating single-row appends and recalls must never rebuild the index,
    /// and each freshly appended exact-match row must rank first.
    #[test]
    fn interleaved_append_recall_does_not_thrash() {
        let dir = tempfile::tempdir().unwrap();
        let db = DatabaseBuilder::new(dir.path().join("test.db"))
            .passphrase(b"test-passphrase")
            .argon2_profile(Argon2Profile::Iot)
            .create()
            .unwrap();
        let conn = Connection::open(&db).unwrap();
        conn.execute(
            "CREATE TABLE t (id INTEGER PRIMARY KEY, category INTEGER, score REAL, v VECTOR(8))",
        )
        .unwrap();
        let base = 200u64;
        // Seed the table before the index exists.
        for i in 1..=base {
            conn.execute(&format!(
                "INSERT INTO t VALUES ({i}, 0, 1.0, {})",
                vec_literal(&vec_for(i))
            ))
            .unwrap();
        }
        conn.execute(
            "CREATE INDEX ix_v ON t USING ann (v) WITH (metric = 'l2', filters = 'category')",
        )
        .unwrap();

        // Warm-up recall, then discard whatever rebuilds it caused so the loop
        // below only counts rebuilds triggered by appends.
        let _ = recall_ids(&conn, &vec_for(7), 5);
        let _ = take_ann_rebuilds();
        let appends = 10u64;
        let mut total_rebuilds = 0u64;
        for j in 0..appends {
            let new_id = base + 1 + j;
            // A distinct constant vector per iteration; the query uses the same
            // vector, so the new row is an exact match.
            let qvec = vec![0.50005f32 + (j as f32) * 0.0001; DIM];
            conn.execute(&format!(
                "INSERT INTO t VALUES ({new_id}, 0, 1.0, {})",
                vec_literal(&qvec)
            ))
            .unwrap();
            let ids = recall_ids(&conn, &qvec, 5);
            total_rebuilds += take_ann_rebuilds();
            assert_eq!(
                ids.first().copied(),
                Some(new_id as i64),
                "freshly appended exact-match row must rank #0 (I1 fresh-visibility)"
            );
        }
        assert_eq!(
            total_rebuilds, 0,
            "appends must not trigger PRISM rebuilds (got {total_rebuilds} over {appends} recalls = thrash)"
        );
    }

    /// Creates a new database file `t.db` inside `dir`.
    fn fresh_db(dir: &std::path::Path) -> citadel::Database {
        DatabaseBuilder::new(dir.join("t.db"))
            .passphrase(b"test-passphrase")
            .argon2_profile(Argon2Profile::Iot)
            .create()
            .unwrap()
    }

    /// Creates the test table `t`.
    fn setup(conn: &Connection<'_>) {
        conn.execute(
            "CREATE TABLE t (id INTEGER PRIMARY KEY, category INTEGER, score REAL, v VECTOR(8))",
        )
        .unwrap();
    }

    /// Inserts one row with the given id and vector, category 0 and score 1.0.
    fn insert(conn: &Connection<'_>, id: u64, v: &[f32]) {
        conn.execute(&format!(
            "INSERT INTO t VALUES ({id}, 0, 1.0, {})",
            vec_literal(v)
        ))
        .unwrap();
    }

    /// Declares the L2 ANN index on `v`, with `category` as a filter column.
    fn build_index(conn: &Connection<'_>) {
        conn.execute(
            "CREATE INDEX ix_v ON t USING ann (v) WITH (metric = 'l2', filters = 'category')",
        )
        .unwrap();
    }

    /// Updating the vector of an already-indexed row must invalidate the cached
    /// index, and the row must be found at its new position.
    #[test]
    fn inplace_vector_update_is_reflected() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        setup(&conn);
        for i in 1..=200 {
            insert(&conn, i, &vec_for(i));
        }
        build_index(&conn);
        let qvec = vec![0.50007f32; DIM];
        // Warm-up recall, then reset the rebuild counter.
        let _ = recall_ids(&conn, &vec_for(7), 5);
        let _ = take_ann_rebuilds();

        conn.execute(&format!(
            "UPDATE t SET v = {} WHERE id = 50",
            vec_literal(&qvec)
        ))
        .unwrap();
        let ids = recall_ids(&conn, &qvec, 5);
        assert!(
            take_ann_rebuilds() >= 1,
            "an in-place vector UPDATE must invalidate the cached index"
        );
        assert_eq!(ids.first().copied(), Some(50), "updated row must rank #0");
    }

    /// Deleting an indexed row must invalidate the cached index, and the row
    /// must no longer be returned.
    #[test]
    fn delete_indexed_row_disappears() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        setup(&conn);
        for i in 1..=200 {
            insert(&conn, i, &vec_for(i));
        }
        build_index(&conn);
        let q = vec_for(7);
        let before = recall_ids(&conn, &q, 5);
        assert_eq!(before.first().copied(), Some(7), "id 7 is the exact match");
        // Reset the rebuild counter before the DELETE under test.
        let _ = take_ann_rebuilds();

        conn.execute("DELETE FROM t WHERE id = 7").unwrap();
        let after = recall_ids(&conn, &q, 5);
        assert!(
            take_ann_rebuilds() >= 1,
            "a DELETE must invalidate the cached index"
        );
        assert!(
            !after.contains(&7),
            "deleted row must not appear: {after:?}"
        );
    }

    /// An insert whose primary key falls inside a gap below the highest indexed
    /// key must invalidate the cached index and be visible afterwards.
    #[test]
    fn gap_fill_below_snapshot_is_visible() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        setup(&conn);
        // Ids 1..=50 and 60..=100, leaving 51..=59 unused.
        for i in 1..=50 {
            insert(&conn, i, &vec_for(i));
        }
        for i in 60..=100 {
            insert(&conn, i, &vec_for(i));
        }
        build_index(&conn);
        // Warm-up recall, then reset the rebuild counter.
        let _ = recall_ids(&conn, &vec_for(7), 5);
        let _ = take_ann_rebuilds();

        let qvec = vec![0.50009f32; DIM];
        // Id 55 lies inside the gap, below the largest existing id (100).
        insert(&conn, 55, &qvec);
        let ids = recall_ids(&conn, &qvec, 5);
        assert!(
            take_ann_rebuilds() >= 1,
            "a gap-fill insert below snapshot must invalidate, not tail-merge"
        );
        assert_eq!(
            ids.first().copied(),
            Some(55),
            "gap-fill row must be visible at rank #0: {ids:?}"
        );
    }

    /// Appends alone must not rebuild; the first recall after a long enough
    /// tail must rebuild exactly once and return the correct row.
    #[test]
    fn long_tail_triggers_single_rebuild() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        setup(&conn);
        for i in 1..=40 {
            insert(&conn, i, &vec_for(i));
        }
        build_index(&conn);
        // Warm-up recall, then reset the rebuild counter.
        let _ = recall_ids(&conn, &vec_for(7), 5);
        let _ = take_ann_rebuilds();

        let qvec = vec![0.50011f32; DIM];
        // Append 15 rows on top of the 40 indexed ones; only the last carries
        // the query vector.
        // NOTE(review): the rebuild threshold is defined elsewhere; this tail
        // is presumably sized to exceed it — confirm.
        for i in 41..=55u64 {
            let v = if i == 55 {
                qvec.clone()
            } else {
                vec_for(i + 1000)
            };
            insert(&conn, i, &v);
        }
        assert_eq!(
            take_ann_rebuilds(),
            0,
            "appends alone must not rebuild (retained for tail merge)"
        );

        let ids = recall_ids(&conn, &qvec, 5);
        assert_eq!(
            take_ann_rebuilds(),
            1,
            "a tail past the threshold must trigger exactly one rebuild on recall"
        );
        assert_eq!(
            ids.first().copied(),
            Some(55),
            "post-rebuild result correct"
        );
    }
}