1use std::any::Any;
6use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8use std::sync::Arc;
9
10use citadel_txn::read_txn::ReadTxn;
11use citadel_txn::write_txn::WriteTxn;
12use citadel_vector::{AnnIndex, Filter, Metric};
13use rustc_hash::FxHashMap;
14
15use crate::encoding::{
16 decode_column_raw, decode_pk_integer, encode_int_key_into, encode_key_value,
17};
18use crate::error::{Result, SqlError};
19use crate::eval::{eval_expr, is_truthy, ColumnMap, EvalCtx};
20use crate::parser::*;
21use crate::schema::SchemaManager;
22use crate::types::*;
23
24use super::aggregate::is_aggregate_expr;
25use super::ann_persist;
26use super::helpers::{decode_full_row, eval_const_expr, eval_const_int, project_rows};
27use super::window::has_any_window_function;
28
/// Result of a storage-layer call; the error type is `citadel_core::Error`.
type StorageResult<T> = std::result::Result<T, citadel_core::Error>;
/// SQL-level row callback: receives two byte slices (used as `key` and `value`
/// by the callers in this file) and returns a `bool` that callers here set to
/// `true` to keep scanning and `false` to stop.
type ScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> Result<bool> + 'a;
/// Storage-level counterpart of [`ScanRow`]; same shape, but it reports
/// failures as `citadel_core::Error`.
type RawScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> StorageResult<bool> + 'a;
32
/// Table access needed by the vector query plans, implemented in this file for
/// both `ReadTxn` and `WriteTxn`.
pub(super) trait AnnScan {
    /// Feeds the rows of `table` to `f`, propagating any error `f` returns.
    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()>;
    /// Point lookup of `key` in `table`; `Ok(None)` when the key is absent.
    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>>;
    /// Generation used to stamp a cached index. `None` makes `load_or_build`
    /// return the index without inserting it into the shared cache.
    fn cache_generation(&self) -> Option<u64>;
}
42
43fn bridge_scan(
45 scan: impl FnOnce(&mut RawScanRow<'_>) -> StorageResult<()>,
46 f: &mut ScanRow<'_>,
47) -> Result<()> {
48 let mut cb_err: Option<SqlError> = None;
49 scan(&mut |key, value| match f(key, value) {
50 Ok(go) => Ok(go),
51 Err(e) => {
52 cb_err = Some(e);
53 Ok(false)
54 }
55 })
56 .map_err(SqlError::Storage)?;
57 match cb_err {
58 Some(e) => Err(e),
59 None => Ok(()),
60 }
61}
62
63impl AnnScan for ReadTxn<'_> {
64 fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
65 bridge_scan(|cb| self.table_scan_from(table, b"", cb), f)
66 }
67
68 fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
69 self.table_get(table, key).map_err(SqlError::Storage)
70 }
71
72 fn cache_generation(&self) -> Option<u64> {
73 Some(self.commit_generation())
74 }
75}
76
77impl AnnScan for WriteTxn<'_> {
78 fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
79 bridge_scan(|cb| self.table_scan_from(table, b"", cb), f)
80 }
81
82 fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
83 self.table_get(table, key).map_err(SqlError::Storage)
84 }
85
86 fn cache_generation(&self) -> Option<u64> {
87 None
88 }
89}
90
/// How a cached ANN index came into existence.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnnIndexSource {
    /// Built from a table scan. `refusal` carries the reason a persisted
    /// segment was rejected before the build, if one was.
    Built { refusal: Option<String> },
    /// Loaded from a persisted segment whose header recorded `segment_b3`
    /// (the BLAKE3 hash checked against the segment body on load).
    Loaded { segment_b3: [u8; 32] },
}
103
/// An ANN index together with what is needed to query it and to judge whether
/// the cache entry is still valid.
struct CachedAnnIndex {
    /// The searchable index.
    index: AnnIndex,
    /// One dictionary per filter column: encoded filter value -> code.
    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
    /// Whether the index was built from a scan or loaded from a segment.
    source: AnnIndexSource,
    /// Generation stamp; an entry whose stamp is lower than the table's DML
    /// marker is treated as stale by `lookup_cached`.
    cached_gen: u64,
}
115
/// Plan for `ORDER BY <vector column> <distance op> <constant> LIMIT k`
/// answered through a declared ANN index.
pub(super) struct AnnTopKPlan {
    /// Schema position of the vector column.
    col_idx: usize,
    /// Declared dimension of the vector column.
    dim: u16,
    /// Distance metric taken from the ORDER BY operator.
    metric: AnnMetric,
    /// Constant query vector (its length equals `dim`).
    query_vec: Vec<f32>,
    /// LIMIT value; always greater than zero for a constructed plan.
    k: usize,
    /// OFFSET value (0 when absent).
    offset: usize,
    /// Filter columns declared on the ANN index (schema column indices).
    filter_cols: Vec<u16>,
    /// WHERE leaves pushed into the index: (position within `filter_cols`,
    /// accepted constant values).
    pushable: Vec<(usize, Vec<Value>)>,
    /// AND of the WHERE leaves that could not be pushed; evaluated per row.
    residual: Option<Expr>,
}
130
131fn topk_shape_ok(stmt: &SelectStmt) -> bool {
133 stmt.order_by.len() == 1
134 && !stmt.order_by[0].descending
135 && stmt.limit.is_some()
136 && stmt.group_by.is_empty()
137 && stmt.having.is_none()
138 && stmt.joins.is_empty()
139 && !stmt.distinct
140 && !has_any_window_function(stmt)
141 && !stmt
142 .columns
143 .iter()
144 .any(|c| matches!(c, SelectColumn::Expr { expr, .. } if is_aggregate_expr(expr)))
145}
146
147impl AnnTopKPlan {
    /// Tries to plan `stmt` as an index-backed ANN top-k query.
    ///
    /// Returns `Ok(None)` whenever the statement does not fit (wrong shape,
    /// ORDER BY is not `<vector column> <distance op> <constant vector>`, no
    /// matching ANN index, primary key is not a single INTEGER column, a WHERE
    /// clause with nothing pushable, or `LIMIT 0`).
    ///
    /// # Errors
    /// Fails when the query vector's length differs from the column's
    /// dimension, or when LIMIT/OFFSET cannot be evaluated as constant integers.
    pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
        if !topk_shape_ok(stmt) {
            return Ok(None);
        }
        // `topk_shape_ok` guarantees exactly one ORDER BY key.
        let ob = &stmt.order_by[0];

        let (col_idx, dim, op_metric, query_vec) = match &ob.expr {
            Expr::BinaryOp { left, op, right } => {
                let op_metric = match op {
                    BinOp::VectorL2 => AnnMetric::L2,
                    BinOp::VectorInner => AnnMetric::Inner,
                    BinOp::VectorCosine => AnnMetric::Cosine,
                    _ => return Ok(None),
                };
                // Only a bare column on the left-hand side is supported.
                let col_name = match left.as_ref() {
                    Expr::Column(name) => name.to_ascii_lowercase(),
                    _ => return Ok(None),
                };
                let (col_idx, dim) = match table_schema
                    .columns
                    .iter()
                    .enumerate()
                    .find(|(_, c)| c.name.to_ascii_lowercase() == col_name)
                {
                    Some((i, c)) => match c.data_type {
                        DataType::Vector { dim } => (i, dim),
                        _ => return Ok(None),
                    },
                    None => return Ok(None),
                };
                // Evaluate the right-hand side against an empty row; anything
                // that is not a vector value (including an evaluation error)
                // means this plan does not apply.
                let col_map = ColumnMap::new(&table_schema.columns);
                let ctx = EvalCtx::new(&col_map, &[]);
                let v = match eval_expr(right, &ctx) {
                    Ok(Value::Vector(v)) => v,
                    _ => return Ok(None),
                };
                if v.len() != dim as usize {
                    return Err(SqlError::InvalidValue(format!(
                        "ANN query vector dim {} does not match column dim {}",
                        v.len(),
                        dim
                    )));
                }
                (col_idx, dim, op_metric, v.to_vec())
            }
            _ => return Ok(None),
        };

        // Require a single-column ANN index on this column with the same metric.
        let ann_index = table_schema.indices.iter().find(|ix| {
            matches!(ix.kind,
                IndexKind::Inverted(InvertedKind::Ann { metric }) if metric == op_metric
            ) && ix.keys.len() == 1
                && matches!(ix.keys[0],
                    IndexKey::Column { idx, .. } if idx as usize == col_idx
                )
        });
        let Some(ann_index) = ann_index else {
            return Ok(None);
        };
        let filter_cols = ann_index.ann_filter_cols.clone();

        // Index hits are turned back into row keys via an integer key
        // (see `collect_survivors`), so a single INTEGER primary key is required.
        if table_schema.primary_key_columns.len() != 1 {
            return Ok(None);
        }
        let pk_col = &table_schema.columns[table_schema.primary_key_columns[0] as usize];
        if !matches!(pk_col.data_type, DataType::Integer) {
            return Ok(None);
        }

        let mut pushable: Vec<(usize, Vec<Value>)> = Vec::new();
        let mut residual_leaves: Vec<Expr> = Vec::new();
        if let Some(w) = &stmt.where_clause {
            split_where(
                w,
                &filter_cols,
                table_schema,
                &mut pushable,
                &mut residual_leaves,
            );
            // A WHERE clause with no pushable leaf makes this plan decline.
            if pushable.is_empty() {
                return Ok(None);
            }
        }
        let residual = fold_and(residual_leaves);

        // Negative LIMIT/OFFSET values are clamped to zero.
        let k_limit = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
        let offset = stmt
            .offset
            .as_ref()
            .map(eval_const_int)
            .transpose()?
            .unwrap_or(0)
            .max(0) as usize;
        if k_limit == 0 {
            return Ok(None);
        }

        Ok(Some(Self {
            col_idx,
            dim,
            metric: op_metric,
            query_vec,
            k: k_limit,
            offset,
            filter_cols,
            pushable,
            residual,
        }))
    }
259
260 pub(super) fn execute_with_read(
261 &self,
262 rtx: &mut ReadTxn<'_>,
263 schema: &SchemaManager,
264 stmt: &SelectStmt,
265 table_schema: &TableSchema,
266 ) -> Result<ExecutionResult> {
267 let cache_key = cache_key(&table_schema.name, self.col_idx, self.metric);
268 let Some(cached) = self.load_or_build_index(rtx, schema, &cache_key, table_schema)? else {
270 return empty_result(table_schema, stmt);
271 };
272 self.run_query(rtx, &cached, stmt, table_schema)
273 }
274
275 fn run_query(
277 &self,
278 txn: &mut dyn AnnScan,
279 cached: &CachedAnnIndex,
280 stmt: &SelectStmt,
281 table_schema: &TableSchema,
282 ) -> Result<ExecutionResult> {
283 let mut constraints: Vec<(usize, Vec<u32>)> = Vec::with_capacity(self.pushable.len());
285 for (dim, values) in &self.pushable {
286 let dict = &cached.dicts[*dim];
287 let mut codes = Vec::with_capacity(values.len());
288 for v in values {
289 if let Some(&code) = dict.get(encode_key_value(v).as_slice()) {
290 codes.push(code);
291 }
292 }
293 if codes.is_empty() {
294 return empty_result(table_schema, stmt);
295 }
296 constraints.push((*dim, codes));
297 }
298 let filter = if constraints.is_empty() {
299 Filter::none()
300 } else {
301 Filter::new(constraints)
302 };
303
304 let want = self.k.saturating_add(self.offset).max(1);
305 let mut rows = self.collect_survivors(txn, &cached.index, &filter, table_schema, want)?;
306
307 if self.offset >= rows.len() {
308 rows.clear();
309 } else if self.offset > 0 {
310 rows = rows.split_off(self.offset);
311 }
312 rows.truncate(self.k);
313
314 let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
315 Ok(ExecutionResult::Query(QueryResult {
316 columns: col_names,
317 rows: projected,
318 }))
319 }
320
321 fn collect_survivors(
325 &self,
326 txn: &mut dyn AnnScan,
327 index: &AnnIndex,
328 filter: &Filter,
329 table_schema: &TableSchema,
330 want: usize,
331 ) -> Result<Vec<Vec<Value>>> {
332 let col_map = ColumnMap::new(&table_schema.columns);
333 let max_target = index.indexed_len().max(1);
334 let mut key_buf: Vec<u8> = Vec::with_capacity(10);
335 let mut target = want;
336 loop {
337 target = target.min(max_target);
338 let hits = index.search_filtered_default_ef(&self.query_vec, target, filter);
339 let mut survivors: Vec<Vec<Value>> = Vec::with_capacity(want);
340 for (id, _dist) in &hits {
341 encode_int_key_into(*id as i64, &mut key_buf);
342 let Some(row_bytes) = txn.ann_get(table_schema.name.as_bytes(), &key_buf)? else {
343 continue;
344 };
345 let row = decode_full_row(table_schema, &key_buf, &row_bytes)?;
346 let keep = match &self.residual {
347 None => true,
348 Some(expr) => {
349 let ctx = EvalCtx::new(&col_map, &row);
350 is_truthy(&eval_expr(expr, &ctx)?)
351 }
352 };
353 if keep {
354 survivors.push(row);
355 if survivors.len() >= want {
356 break;
357 }
358 }
359 }
360 if survivors.len() >= want || target >= max_target || hits.len() < target {
363 return Ok(survivors);
364 }
365 target = target.saturating_mul(2);
366 }
367 }
368
369 fn load_or_build_index(
370 &self,
371 txn: &mut dyn AnnScan,
372 schema: &SchemaManager,
373 cache_key: &str,
374 table_schema: &TableSchema,
375 ) -> Result<Option<Arc<CachedAnnIndex>>> {
376 if let Some(existing) = lookup_cached(schema, cache_key, &table_schema.name)? {
377 return Ok(Some(existing));
378 }
379 let spec = AnnSpec {
380 col_idx: self.col_idx,
381 dim: self.dim,
382 metric: self.metric,
383 filter_cols: self.filter_cols.clone(),
384 };
385 load_or_build(txn, schema, cache_key, table_schema, &spec)
386 }
387}
388
/// Identity of an ANN index: which column it covers and how it is configured.
pub(super) struct AnnSpec {
    /// Schema position of the vector column.
    pub col_idx: usize,
    /// Vector dimension.
    pub dim: u16,
    /// Distance metric.
    pub metric: AnnMetric,
    /// Schema column indices of the filter attributes carried by the index.
    pub filter_cols: Vec<u16>,
}
398
399impl AnnSpec {
400 fn metric_tag(&self) -> u8 {
401 citadel_vector::segment::metric_tag(ann_metric_to_prism(self.metric))
402 }
403}
404
/// Everything `scan_rows` extracts from one pass over the table.
struct ScanOutcome {
    /// One entry per row with a non-NULL vector: (integer primary key as
    /// `u64`, vector, filter codes in `filter_cols` order).
    rows: Vec<(u64, Vec<f32>, Vec<u32>)>,
    /// One dictionary per filter column: encoded value -> code.
    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
    /// Content fingerprint over every scanned row (NULL-vector rows included).
    fingerprint: [u8; 32],
}
413
414fn scan_rows(
415 txn: &mut dyn AnnScan,
416 table_schema: &TableSchema,
417 spec: &AnnSpec,
418) -> Result<ScanOutcome> {
419 let non_pk = table_schema.non_pk_indices();
420 let enc_pos = table_schema.encoding_positions();
421 let nonpk_order = non_pk
422 .iter()
423 .position(|&i| i == spec.col_idx)
424 .ok_or_else(|| {
425 SqlError::InvalidValue("vector column must be non-PK for ANN build".into())
426 })?;
427 let enc_idx = enc_pos[nonpk_order] as usize;
428
429 let num_attrs = spec.filter_cols.len();
430 let extracts: Vec<Extract> = spec
431 .filter_cols
432 .iter()
433 .map(|&c| extract_plan(c, table_schema, non_pk, enc_pos))
434 .collect::<Result<_>>()?;
435 let mut dicts: Vec<FxHashMap<Vec<u8>, u32>> = vec![FxHashMap::default(); num_attrs];
436 let mut fp = ann_persist::FingerprintHasher::new(
437 &table_schema.name,
438 spec.col_idx as u32,
439 &spec
440 .filter_cols
441 .iter()
442 .map(|&c| c as u32)
443 .collect::<Vec<_>>(),
444 spec.dim,
445 spec.metric_tag(),
446 );
447 let mut rows: Vec<(u64, Vec<f32>, Vec<u32>)> = Vec::new();
448
449 txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
450 let vector = match decode_column_raw(value, enc_idx)?.to_value() {
451 Value::Vector(arr) => Some(arr.to_vec()),
452 Value::Null => None, _ => {
454 return Err(SqlError::InvalidValue(
455 "ANN column produced non-vector value".into(),
456 ))
457 }
458 };
459 let mut encoded_filters: Vec<Vec<u8>> = Vec::with_capacity(num_attrs);
460 for ex in &extracts {
461 encoded_filters.push(encode_key_value(&ex.extract(key, value)?));
462 }
463 let vec_bytes: Vec<u8> = vector
464 .as_deref()
465 .unwrap_or(&[])
466 .iter()
467 .flat_map(|f| f.to_le_bytes())
468 .collect();
469 fp.row(
470 key,
471 &vec_bytes,
472 &encoded_filters
473 .iter()
474 .map(Vec::as_slice)
475 .collect::<Vec<_>>(),
476 );
477 let Some(vector) = vector else {
478 return Ok(true);
479 };
480 let id = decode_pk_integer(key)? as u64;
481 let mut codes: Vec<u32> = Vec::with_capacity(num_attrs);
482 for (j, encoded) in encoded_filters.into_iter().enumerate() {
483 let next = dicts[j].len() as u32;
484 codes.push(*dicts[j].entry(encoded).or_insert(next));
485 }
486 rows.push((id, vector, codes));
487 Ok(true)
488 })?;
489
490 Ok(ScanOutcome {
491 rows,
492 dicts,
493 fingerprint: fp.finish(),
494 })
495}
496
497fn build_index(
499 txn: &mut dyn AnnScan,
500 table_schema: &TableSchema,
501 spec: &AnnSpec,
502 refusal: Option<String>,
503 cached_gen: u64,
504) -> Result<Option<CachedAnnIndex>> {
505 let outcome = scan_rows(txn, table_schema, spec)?;
506 if outcome.rows.is_empty() {
507 return Ok(None);
508 }
509 let index = AnnIndex::build_with_attrs(
510 outcome.rows,
511 spec.filter_cols.len(),
512 ann_metric_to_prism(spec.metric),
513 spec.dim,
514 )
515 .map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
516 Ok(Some(CachedAnnIndex {
517 index,
518 dicts: outcome.dicts,
519 source: AnnIndexSource::Built { refusal },
520 cached_gen,
521 }))
522}
523
/// Result of attempting to load a persisted ANN segment.
enum LoadOutcome {
    /// The segment passed every check and was rehydrated into an index.
    Loaded(Box<CachedAnnIndex>),
    /// No segment header could be read for the table.
    NoSegment,
    /// A segment exists but was rejected. `corrupt` separates damaged
    /// segments from ones that are merely stale or incompatible.
    Refused { reason: String, corrupt: bool },
}
532
533fn try_load_segment(
537 txn: &mut dyn AnnScan,
538 table_schema: &TableSchema,
539 spec: &AnnSpec,
540 cached_gen: u64,
541) -> Result<LoadOutcome> {
542 let seg_table = ann_persist::segment_table_name(&table_schema.name);
543 let header_bytes = match txn.ann_get(&seg_table, &ann_persist::segment_key(0)) {
544 Ok(Some(b)) => b,
545 Ok(None) | Err(_) => return Ok(LoadOutcome::NoSegment),
547 };
548 let refuse = |reason: String, corrupt: bool| Ok(LoadOutcome::Refused { reason, corrupt });
549 let header = match ann_persist::SegmentHeader::decode(&header_bytes) {
550 Ok(h) => h,
551 Err(e) => return refuse(format!("header: {e}"), true),
552 };
553 if header.format_version != ann_persist::ANNSEG_FORMAT_VERSION {
554 return refuse(
555 format!("format v{} (this binary reads v1)", header.format_version),
556 false,
557 );
558 }
559 let active_cfg = citadel_vector::segment::prism_config_hash(&AnnIndex::active_config(
560 ann_metric_to_prism(spec.metric),
561 ));
562 if header.prism_config_hash != active_cfg {
563 return refuse(
564 "PRISM config drift (segment built by another geometry)".into(),
565 false,
566 );
567 }
568 if header.dim != spec.dim
569 || header.metric_tag != spec.metric_tag()
570 || header.col_idx != spec.col_idx as u32
571 || header.filter_cols
572 != spec
573 .filter_cols
574 .iter()
575 .map(|&c| c as u32)
576 .collect::<Vec<_>>()
577 {
578 return refuse(
579 "index identity mismatch (column/metric/filter set)".into(),
580 false,
581 );
582 }
583
584 let mut body = Vec::new();
585 for chunk_no in 1..=header.chunk_count {
586 match txn.ann_get(&seg_table, &ann_persist::segment_key(chunk_no)) {
587 Ok(Some(c)) => body.extend_from_slice(&c),
588 _ => return refuse(format!("missing chunk {chunk_no}"), true),
589 }
590 }
591 if *blake3::hash(&body).as_bytes() != header.segment_b3 {
592 return refuse("segment body BLAKE3 mismatch (corrupt)".into(), true);
593 }
594 let parts = match citadel_vector::segment::decode(&body) {
595 Ok(p) => p,
596 Err(e) => return refuse(format!("segment decode: {e}"), true),
597 };
598 if parts.n() as u64 != header.n || parts.dim() != header.dim {
599 return refuse("segment body disagrees with header counts".into(), true);
600 }
601
602 let slot_of = parts.internal_of_row();
605 let dim = spec.dim as usize;
606 let mut vectors = vec![0.0f32; parts.n() * dim];
607 let mut filled = 0usize;
608 let outcome = scan_rows_rehydrate(txn, table_schema, spec, &mut |row_id, vector| {
609 let Some(&slot) = slot_of.get(&row_id) else {
610 return false; };
612 vectors[slot as usize * dim..(slot as usize + 1) * dim].copy_from_slice(vector);
613 filled += 1;
614 true
615 })?;
616 let Some(fingerprint) = outcome else {
617 return refuse(
618 "a scanned row is unknown to the segment (stale)".into(),
619 false,
620 );
621 };
622 if fingerprint != header.content_fingerprint {
623 return refuse("content fingerprint mismatch (stale)".into(), false);
624 }
625 let index = match parts.into_index(vectors, filled) {
626 Ok(i) => i,
627 Err(e) => return refuse(format!("rehydration: {e}"), true),
628 };
629 Ok(LoadOutcome::Loaded(Box::new(CachedAnnIndex {
630 index,
631 dicts: header.dict_maps(),
632 source: AnnIndexSource::Loaded {
633 segment_b3: header.segment_b3,
634 },
635 cached_gen,
636 })))
637}
638
/// Scans the table to rehydrate a loaded segment and recompute the content
/// fingerprint.
///
/// For every row with a non-NULL vector, `place(row_id, vector)` is called.
/// If `place` returns `false` the scan stops and `Ok(None)` is returned.
/// Otherwise the fingerprint over all scanned rows (NULL-vector rows
/// included) is returned. The fingerprint inputs mirror those of `scan_rows`.
///
/// # Errors
/// Fails when the vector column is part of the primary key, when a filter
/// column cannot be located, when the vector column decodes to something that
/// is neither a vector nor NULL, or when decoding / scanning fails.
fn scan_rows_rehydrate(
    txn: &mut dyn AnnScan,
    table_schema: &TableSchema,
    spec: &AnnSpec,
    place: &mut dyn FnMut(u64, &[f32]) -> bool,
) -> Result<Option<[u8; 32]>> {
    let non_pk = table_schema.non_pk_indices();
    let enc_pos = table_schema.encoding_positions();
    // Position of the vector column among the non-PK columns, then its
    // encoding position inside the row value.
    let nonpk_order = non_pk
        .iter()
        .position(|&i| i == spec.col_idx)
        .ok_or_else(|| {
            SqlError::InvalidValue("vector column must be non-PK for ANN build".into())
        })?;
    let enc_idx = enc_pos[nonpk_order] as usize;
    let extracts: Vec<Extract> = spec
        .filter_cols
        .iter()
        .map(|&c| extract_plan(c, table_schema, non_pk, enc_pos))
        .collect::<Result<_>>()?;
    let mut fp = ann_persist::FingerprintHasher::new(
        &table_schema.name,
        spec.col_idx as u32,
        &spec
            .filter_cols
            .iter()
            .map(|&c| c as u32)
            .collect::<Vec<_>>(),
        spec.dim,
        spec.metric_tag(),
    );
    // Set when `place` rejects a row; the scan is stopped at that point.
    let mut unknown_row = false;

    txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
        let vector = match decode_column_raw(value, enc_idx)?.to_value() {
            Value::Vector(arr) => Some(arr.to_vec()),
            Value::Null => None,
            _ => {
                return Err(SqlError::InvalidValue(
                    "ANN column produced non-vector value".into(),
                ))
            }
        };
        let mut encoded_filters: Vec<Vec<u8>> = Vec::with_capacity(extracts.len());
        for ex in &extracts {
            encoded_filters.push(encode_key_value(&ex.extract(key, value)?));
        }
        // Little-endian bytes of the vector; empty for a NULL vector.
        let vec_bytes: Vec<u8> = vector
            .as_deref()
            .unwrap_or(&[])
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        // The row is hashed before `place` runs, so it is fingerprinted
        // whether or not it has a vector.
        fp.row(
            key,
            &vec_bytes,
            &encoded_filters
                .iter()
                .map(Vec::as_slice)
                .collect::<Vec<_>>(),
        );
        if let Some(vector) = vector {
            let id = decode_pk_integer(key)? as u64;
            if !place(id, &vector) {
                unknown_row = true;
                return Ok(false);
            }
        }
        Ok(true)
    })?;

    Ok(if unknown_row { None } else { Some(fp.finish()) })
}
715
716fn load_or_build(
721 txn: &mut dyn AnnScan,
722 schema: &SchemaManager,
723 cache_key: &str,
724 table_schema: &TableSchema,
725 spec: &AnnSpec,
726) -> Result<Option<Arc<CachedAnnIndex>>> {
727 let gen = txn.cache_generation();
728 let cached_gen = gen.unwrap_or(u64::MAX);
729 let loaded = match try_load_segment(txn, table_schema, spec, cached_gen)? {
730 LoadOutcome::Loaded(c) => Some(*c),
731 LoadOutcome::NoSegment => None,
732 LoadOutcome::Refused { reason, corrupt } => {
733 if corrupt {
734 eprintln!(
735 "citadel-sql: ANN segment for `{}` REFUSED as corrupt ({reason}); \
736 rebuilding from scan - investigate before re-persisting",
737 table_schema.name
738 );
739 }
740 match build_index(txn, table_schema, spec, Some(reason), cached_gen)? {
743 Some(c) => Some(c),
744 None => return Ok(None),
745 }
746 }
747 };
748 let built = match loaded {
749 Some(c) => c,
750 None => match build_index(txn, table_schema, spec, None, cached_gen)? {
751 Some(c) => c,
752 None => return Ok(None),
753 },
754 };
755 let arc: Arc<CachedAnnIndex> = Arc::new(built);
756 if gen.is_none() {
757 return Ok(Some(arc));
759 }
760 let mut guard = schema.sql_caches.lock();
761 if let Some(existing) = guard.get(cache_key) {
762 return Arc::clone(existing)
764 .downcast::<CachedAnnIndex>()
765 .map(Some)
766 .map_err(|_| {
767 SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}"))
768 });
769 }
770 let marker = marker_gen_locked(&guard, &table_schema.name);
771 if marker.is_some_and(|g| arc.cached_gen < g) {
772 return Ok(Some(arc));
775 }
776 let as_any: Arc<dyn Any + Send + Sync> = arc.clone();
777 guard.insert(cache_key.to_string(), as_any);
778 Ok(Some(arc))
779}
780
/// Exact (full-scan) top-k plan for `ORDER BY <vector column> <distance op> ...
/// LIMIT k`; unlike `AnnTopKPlan` it does not require an ANN index.
pub(super) struct VectorTopKPlan {
    /// The ORDER BY expression, evaluated per row to obtain the distance.
    order_expr: Expr,
    /// The full WHERE clause, evaluated per row.
    where_clause: Option<Expr>,
    /// LIMIT value; always greater than zero for a constructed plan.
    k: usize,
    /// OFFSET value (0 when absent).
    offset: usize,
    /// When true a NULL distance ranks first (`-inf`), otherwise last (`+inf`).
    nulls_first: bool,
}
790
/// A candidate row in the top-k heap, ordered by distance and then by scan order.
struct Ranked {
    /// Distance produced by the ORDER BY expression.
    dist: f64,
    /// Scan sequence number; breaks ties between equal distances.
    seq: u64,
    /// The decoded row.
    row: Vec<Value>,
}
798
799impl PartialEq for Ranked {
800 fn eq(&self, other: &Self) -> bool {
801 self.cmp(other) == Ordering::Equal
802 }
803}
804impl Eq for Ranked {}
805impl PartialOrd for Ranked {
806 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
807 Some(self.cmp(other))
808 }
809}
810impl Ord for Ranked {
811 fn cmp(&self, other: &Self) -> Ordering {
812 self.dist
813 .total_cmp(&other.dist)
814 .then_with(|| self.seq.cmp(&other.seq))
815 }
816}
817
818impl VectorTopKPlan {
819 pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
820 if !topk_shape_ok(stmt) {
821 return Ok(None);
822 }
823 let ob = &stmt.order_by[0];
824 let Expr::BinaryOp { left, op, .. } = &ob.expr else {
825 return Ok(None);
826 };
827 if !matches!(
828 op,
829 BinOp::VectorL2 | BinOp::VectorInner | BinOp::VectorCosine
830 ) {
831 return Ok(None);
832 }
833 let Expr::Column(name) = left.as_ref() else {
835 return Ok(None);
836 };
837 let name = name.to_ascii_lowercase();
838 let is_vector_col = table_schema.columns.iter().any(|c| {
839 c.name.to_ascii_lowercase() == name && matches!(c.data_type, DataType::Vector { .. })
840 });
841 if !is_vector_col {
842 return Ok(None);
843 }
844
845 let k = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
846 if k == 0 {
847 return Ok(None);
848 }
849 let offset = stmt
850 .offset
851 .as_ref()
852 .map(eval_const_int)
853 .transpose()?
854 .unwrap_or(0)
855 .max(0) as usize;
856
857 Ok(Some(Self {
858 order_expr: ob.expr.clone(),
859 where_clause: stmt.where_clause.clone(),
860 k,
861 offset,
862 nulls_first: ob.nulls_first.unwrap_or(true),
864 }))
865 }
866
867 pub(super) fn execute(
868 &self,
869 txn: &mut dyn AnnScan,
870 table_schema: &TableSchema,
871 stmt: &SelectStmt,
872 ) -> Result<ExecutionResult> {
873 let want = self.k.saturating_add(self.offset);
874 let col_map = ColumnMap::new(&table_schema.columns);
875 let null_dist = if self.nulls_first {
877 f64::NEG_INFINITY
878 } else {
879 f64::INFINITY
880 };
881 let mut heap: BinaryHeap<Ranked> = BinaryHeap::new();
882 let mut seq: u64 = 0;
883
884 txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
885 let row = decode_full_row(table_schema, key, value)?;
886 let ctx = EvalCtx::new(&col_map, &row);
887 if let Some(w) = &self.where_clause {
888 if !is_truthy(&eval_expr(w, &ctx)?) {
889 return Ok(true);
890 }
891 }
892 let dist = match eval_expr(&self.order_expr, &ctx)? {
893 Value::Real(d) => d,
894 Value::Integer(i) => i as f64,
895 Value::Null => null_dist,
896 other => {
897 return Err(SqlError::InvalidValue(format!(
898 "ORDER BY vector distance produced a non-numeric {}",
899 other.data_type()
900 )))
901 }
902 };
903 let cand = Ranked { dist, seq, row };
904 seq += 1;
905 if heap.len() < want {
907 heap.push(cand);
908 } else if heap.peek().is_some_and(|top| cand < *top) {
909 heap.pop();
910 heap.push(cand);
911 }
912 Ok(true)
913 })?;
914
915 let mut rows: Vec<Vec<Value>> = heap.into_sorted_vec().into_iter().map(|r| r.row).collect();
916 if self.offset >= rows.len() {
917 rows.clear();
918 } else if self.offset > 0 {
919 rows = rows.split_off(self.offset);
920 }
921 rows.truncate(self.k);
922
923 let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
924 Ok(ExecutionResult::Query(QueryResult {
925 columns: col_names,
926 rows: projected,
927 }))
928 }
929}
930
/// Where a filter column's value is read from in a stored row.
enum Extract {
    /// Decoded from the row key as the integer primary key.
    Pk,
    /// Decoded from the row value at this encoding position.
    NonPk(usize),
}
938
939impl Extract {
940 fn extract(&self, key: &[u8], value: &[u8]) -> Result<Value> {
941 match self {
942 Extract::Pk => Ok(Value::Integer(decode_pk_integer(key)?)),
943 Extract::NonPk(ei) => Ok(decode_column_raw(value, *ei)?.to_value()),
944 }
945 }
946}
947
948fn extract_plan(
949 col: u16,
950 table_schema: &TableSchema,
951 non_pk: &[usize],
952 enc_pos: &[u16],
953) -> Result<Extract> {
954 if table_schema.primary_key_columns.contains(&col) {
955 return Ok(Extract::Pk);
956 }
957 let order = non_pk
958 .iter()
959 .position(|&i| i == col as usize)
960 .ok_or_else(|| SqlError::InvalidValue("ANN filter column not found in row".into()))?;
961 Ok(Extract::NonPk(enc_pos[order] as usize))
962}
963
964fn split_where(
967 expr: &Expr,
968 filter_cols: &[u16],
969 table_schema: &TableSchema,
970 pushable: &mut Vec<(usize, Vec<Value>)>,
971 residual: &mut Vec<Expr>,
972) {
973 if let Expr::BinaryOp {
974 left,
975 op: BinOp::And,
976 right,
977 } = expr
978 {
979 split_where(left, filter_cols, table_schema, pushable, residual);
980 split_where(right, filter_cols, table_schema, pushable, residual);
981 return;
982 }
983 match classify_leaf(expr, filter_cols, table_schema) {
984 Some(constraint) => pushable.push(constraint),
985 None => residual.push(expr.clone()),
986 }
987}
988
989fn classify_leaf(
992 leaf: &Expr,
993 filter_cols: &[u16],
994 table_schema: &TableSchema,
995) -> Option<(usize, Vec<Value>)> {
996 match leaf {
997 Expr::BinaryOp {
998 left,
999 op: BinOp::Eq,
1000 right,
1001 } => {
1002 let dim = filter_dim(left, filter_cols, table_schema)?;
1003 let val = eval_const_expr(right).ok()?;
1004 Some((dim, vec![val]))
1005 }
1006 Expr::InList {
1007 expr,
1008 list,
1009 negated: false,
1010 } => {
1011 let dim = filter_dim(expr, filter_cols, table_schema)?;
1012 let mut vals = Vec::with_capacity(list.len());
1013 for e in list {
1014 vals.push(eval_const_expr(e).ok()?);
1015 }
1016 Some((dim, vals))
1017 }
1018 _ => None,
1019 }
1020}
1021
1022fn filter_dim(expr: &Expr, filter_cols: &[u16], table_schema: &TableSchema) -> Option<usize> {
1025 let name = match expr {
1026 Expr::Column(c) => c.to_ascii_lowercase(),
1027 Expr::QualifiedColumn { column, .. } => column.to_ascii_lowercase(),
1028 _ => return None,
1029 };
1030 let col_idx = table_schema
1031 .columns
1032 .iter()
1033 .position(|c| c.name.to_ascii_lowercase() == name)? as u16;
1034 filter_cols.iter().position(|&c| c == col_idx)
1035}
1036
1037fn fold_and(mut leaves: Vec<Expr>) -> Option<Expr> {
1038 if leaves.is_empty() {
1039 return None;
1040 }
1041 let first = leaves.remove(0);
1042 Some(leaves.into_iter().fold(first, |acc, e| Expr::BinaryOp {
1043 left: Box::new(acc),
1044 op: BinOp::And,
1045 right: Box::new(e),
1046 }))
1047}
1048
1049fn empty_result(table_schema: &TableSchema, stmt: &SelectStmt) -> Result<ExecutionResult> {
1050 let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, Vec::new())?;
1051 Ok(ExecutionResult::Query(QueryResult {
1052 columns: col_names,
1053 rows: projected,
1054 }))
1055}
1056
1057pub(crate) fn persist_ann_index(
1065 db: &citadel::Database,
1066 schema: &SchemaManager,
1067 table_schema: &TableSchema,
1068 column: &str,
1069) -> Result<ann_persist::AnnSegmentInfo> {
1070 let col_lower = column.to_ascii_lowercase();
1071 let col_idx = table_schema
1072 .columns
1073 .iter()
1074 .position(|c| c.name == col_lower)
1075 .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
1076 let DataType::Vector { dim } = table_schema.columns[col_idx].data_type else {
1077 return Err(SqlError::InvalidValue(format!(
1078 "column `{column}` is not VECTOR(N)"
1079 )));
1080 };
1081 if table_schema.primary_key_columns.len() != 1
1085 || !matches!(
1086 table_schema.columns[table_schema.primary_key_columns[0] as usize].data_type,
1087 DataType::Integer
1088 )
1089 {
1090 return Err(SqlError::InvalidValue(
1091 "ANN persistence requires a single INTEGER primary key (same rule as the \
1092 ANN query plan)"
1093 .into(),
1094 ));
1095 }
1096 let ann_index = table_schema
1097 .indices
1098 .iter()
1099 .find(|ix| {
1100 matches!(ix.kind, IndexKind::Inverted(InvertedKind::Ann { .. }))
1101 && ix.keys.len() == 1
1102 && matches!(ix.keys[0], IndexKey::Column { idx, .. } if idx as usize == col_idx)
1103 })
1104 .ok_or_else(|| SqlError::InvalidValue(format!("no ANN index declared on `{column}`")))?;
1105 let IndexKind::Inverted(InvertedKind::Ann { metric }) = ann_index.kind else {
1106 unreachable!("matched above");
1107 };
1108 let spec = AnnSpec {
1109 col_idx,
1110 dim,
1111 metric,
1112 filter_cols: ann_index.ann_filter_cols.clone(),
1113 };
1114
1115 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
1116 let outcome = scan_rows(&mut wtx, table_schema, &spec)?;
1117 if outcome.rows.is_empty() {
1118 return Err(SqlError::InvalidValue(
1119 "nothing to persist: the table has no indexable (non-NULL) vectors".into(),
1120 ));
1121 }
1122 let n = outcome.rows.len() as u64;
1123 let index = AnnIndex::build_with_attrs(
1124 outcome.rows,
1125 spec.filter_cols.len(),
1126 ann_metric_to_prism(spec.metric),
1127 spec.dim,
1128 )
1129 .map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
1130
1131 let body = citadel_vector::segment::encode(&index);
1132 let segment_b3 = *blake3::hash(&body).as_bytes();
1133 let dicts_ordered: Vec<Vec<(Vec<u8>, u32)>> = outcome
1136 .dicts
1137 .iter()
1138 .map(|d| {
1139 let mut entries: Vec<(Vec<u8>, u32)> = d.iter().map(|(k, &v)| (k.clone(), v)).collect();
1140 entries.sort_by_key(|&(_, code)| code);
1141 entries
1142 })
1143 .collect();
1144 let header = ann_persist::SegmentHeader {
1145 format_version: ann_persist::ANNSEG_FORMAT_VERSION,
1146 prism_config_hash: ann_persist::active_config_hash(ann_metric_to_prism(spec.metric)),
1147 dim: spec.dim,
1148 metric_tag: spec.metric_tag(),
1149 n,
1150 snapshot_max: index.snapshot_max,
1151 col_idx: spec.col_idx as u32,
1152 filter_cols: spec.filter_cols.iter().map(|&c| c as u32).collect(),
1153 dicts: dicts_ordered,
1154 content_fingerprint: outcome.fingerprint,
1155 segment_b3,
1156 chunk_count: body.len().div_ceil(ann_persist::CHUNK_BYTES) as u32,
1157 writer: format!("citadel-sql {}", env!("CARGO_PKG_VERSION")),
1158 };
1159
1160 let seg_table = ann_persist::segment_table_name(&table_schema.name);
1161 ann_persist::purge_segment(&mut wtx, &table_schema.name)?;
1162 wtx.create_table(&seg_table).map_err(SqlError::Storage)?;
1163 wtx.table_insert(&seg_table, &ann_persist::segment_key(0), &header.encode())
1164 .map_err(SqlError::Storage)?;
1165 for (chunk_no, chunk) in ann_persist::chunks(&body) {
1166 wtx.table_insert(&seg_table, &ann_persist::segment_key(chunk_no), chunk)
1167 .map_err(SqlError::Storage)?;
1168 }
1169 wtx.commit().map_err(SqlError::Storage)?;
1170
1171 let cached = CachedAnnIndex {
1174 index,
1175 dicts: outcome.dicts,
1176 source: AnnIndexSource::Built { refusal: None },
1177 cached_gen: db.manager().commit_generation(),
1178 };
1179 let key = cache_key(&table_schema.name, spec.col_idx, spec.metric);
1180 let as_any: Arc<dyn Any + Send + Sync> = Arc::new(cached);
1181 schema.sql_caches.lock().insert(key, as_any);
1182
1183 Ok(ann_persist::AnnSegmentInfo {
1184 segment_b3,
1185 content_fingerprint: header.content_fingerprint,
1186 n,
1187 dim: spec.dim,
1188 metric_tag: header.metric_tag,
1189 chunk_count: header.chunk_count,
1190 })
1191}
1192
1193pub(crate) fn ann_cache_status(
1196 schema: &SchemaManager,
1197 table_schema: &TableSchema,
1198 column: &str,
1199) -> Result<Option<(AnnIndexSource, u64)>> {
1200 let col_lower = column.to_ascii_lowercase();
1201 let col_idx = table_schema
1202 .columns
1203 .iter()
1204 .position(|c| c.name == col_lower)
1205 .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
1206 let guard = schema.sql_caches.lock();
1207 for metric in [AnnMetric::L2, AnnMetric::Inner, AnnMetric::Cosine] {
1208 let key = cache_key(&table_schema.name, col_idx, metric);
1209 if let Some(entry) = guard.get(&key) {
1210 if let Ok(c) = Arc::clone(entry).downcast::<CachedAnnIndex>() {
1211 return Ok(Some((c.source.clone(), c.cached_gen)));
1212 }
1213 }
1214 }
1215 Ok(None)
1216}
1217
/// Cache key under which a table's DML generation marker is stored
/// (`"ann_dml_gen:"` followed by the table name, unmodified).
pub(crate) fn ann_dml_gen_key(table_name: &str) -> String {
    const PREFIX: &str = "ann_dml_gen:";
    let mut key = String::with_capacity(PREFIX.len() + table_name.len());
    key.push_str(PREFIX);
    key.push_str(table_name);
    key
}
1224
1225fn marker_gen_locked(
1227 entries: &FxHashMap<String, Arc<dyn Any + Send + Sync>>,
1228 table_name: &str,
1229) -> Option<u64> {
1230 entries
1231 .get(&ann_dml_gen_key(table_name))
1232 .and_then(|e| e.downcast_ref::<u64>())
1233 .copied()
1234}
1235
1236fn lookup_cached(
1237 schema: &SchemaManager,
1238 cache_key: &str,
1239 table_name: &str,
1240) -> Result<Option<Arc<CachedAnnIndex>>> {
1241 let mut guard = schema.sql_caches.lock();
1242 let Some(entry) = guard.get(cache_key) else {
1243 return Ok(None);
1244 };
1245 let entry = Arc::clone(entry)
1246 .downcast::<CachedAnnIndex>()
1247 .map_err(|_| SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}")))?;
1248 if marker_gen_locked(&guard, table_name).is_some_and(|g| entry.cached_gen < g) {
1249 guard.remove(cache_key);
1252 return Ok(None);
1253 }
1254 Ok(Some(entry))
1255}
1256
1257pub(super) fn cache_key(table_name: &str, col_idx: usize, metric: AnnMetric) -> String {
1258 let tag = match metric {
1259 AnnMetric::L2 => "l2",
1260 AnnMetric::Inner => "inner",
1261 AnnMetric::Cosine => "cosine",
1262 };
1263 format!(
1264 "ann:{}:{}:{}",
1265 table_name.to_ascii_lowercase(),
1266 col_idx,
1267 tag
1268 )
1269}
1270
/// Converts the SQL-level `AnnMetric` into the `citadel_vector` `Metric`.
fn ann_metric_to_prism(m: AnnMetric) -> Metric {
    match m {
        AnnMetric::L2 => Metric::L2,
        AnnMetric::Inner => Metric::InnerProduct,
        AnnMetric::Cosine => Metric::Cosine,
    }
}