1use std::any::Any;
6use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8use std::sync::Arc;
9
10use citadel_txn::read_txn::ReadTxn;
11use citadel_txn::write_txn::WriteTxn;
12use citadel_vector::{AnnIndex, Filter, Metric};
13use rustc_hash::FxHashMap;
14
15use crate::encoding::{
16 decode_column_raw, decode_pk_integer, encode_int_key_into, encode_key_value,
17 encode_key_value_collated_into,
18};
19use crate::error::{Result, SqlError};
20use crate::eval::{eval_expr, is_truthy, ColumnMap, EvalCtx};
21use crate::parser::*;
22use crate::schema::SchemaManager;
23use crate::types::*;
24
25use super::aggregate::is_aggregate_expr;
26use super::ann_persist;
27use super::helpers::{decode_full_row, eval_const_expr, eval_const_int, project_rows};
28use super::window::has_any_window_function;
29
/// Result alias for storage-layer calls, whose error type is `citadel_core::Error`.
type StorageResult<T> = std::result::Result<T, citadel_core::Error>;
/// Row callback used at the SQL layer: receives `(key, value)` bytes and reports a `bool`
/// (presumably "keep scanning") or a SQL-level error.
type ScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> Result<bool> + 'a;
/// Row callback in the shape the storage scan accepts: same arguments, storage-level error.
type RawScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> StorageResult<bool> + 'a;
33
/// Minimal table access the vector top-k paths need from a transaction.
///
/// Implemented in this file for both `ReadTxn` and `WriteTxn`.
pub(super) trait AnnScan {
    /// Scans `table`, handing every `(key, value)` pair to `f`.
    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()>;
    /// Point lookup of `key` in `table`; presumably `None` when the key is absent.
    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>>;
    /// Generation used to stamp a cached index. `None` makes `load_or_build` hand the index
    /// back without inserting it into the shared cache.
    fn cache_generation(&self) -> Option<u64>;
}
43
44fn bridge_scan(
46 scan: impl FnOnce(&mut RawScanRow<'_>) -> StorageResult<()>,
47 f: &mut ScanRow<'_>,
48) -> Result<()> {
49 let mut cb_err: Option<SqlError> = None;
50 scan(&mut |key, value| match f(key, value) {
51 Ok(go) => Ok(go),
52 Err(e) => {
53 cb_err = Some(e);
54 Ok(false)
55 }
56 })
57 .map_err(SqlError::Storage)?;
58 match cb_err {
59 Some(e) => Err(e),
60 None => Ok(()),
61 }
62}
63
64impl AnnScan for ReadTxn<'_> {
65 fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
66 bridge_scan(|cb| self.table_scan_from(table, b"", cb), f)
67 }
68
69 fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
70 self.table_get(table, key).map_err(SqlError::Storage)
71 }
72
73 fn cache_generation(&self) -> Option<u64> {
74 Some(self.commit_generation())
75 }
76}
77
78impl AnnScan for WriteTxn<'_> {
79 fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
80 bridge_scan(|cb| self.table_scan_from(table, b"", cb), f)
81 }
82
83 fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
84 self.table_get(table, key).map_err(SqlError::Storage)
85 }
86
87 fn cache_generation(&self) -> Option<u64> {
88 None
89 }
90}
91
/// How a cached ANN index came into existence.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnnIndexSource {
    /// Built from a table scan. `refusal` carries the reason a persisted segment was
    /// rejected before the rebuild, or `None` when no segment was refused.
    Built { refusal: Option<String> },
    /// Loaded from a persisted segment whose body hashes (BLAKE3) to `segment_b3`.
    Loaded { segment_b3: [u8; 32] },
}
104
/// An ANN index together with the bookkeeping needed to query and invalidate it.
struct CachedAnnIndex {
    /// The searchable index itself.
    index: AnnIndex,
    /// One dictionary per filter column: collated key encoding of a value -> attribute code.
    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
    /// Whether the index was built from a scan or loaded from a persisted segment.
    source: AnnIndexSource,
    /// Generation the index was stamped with; compared against the per-table DML marker
    /// to detect staleness. `u64::MAX` when the producing transaction had no generation.
    cached_gen: u64,
}
116
/// Plan for `ORDER BY <vector col> <op> <const vector> LIMIT k` served by an ANN index.
pub(super) struct AnnTopKPlan {
    /// Position of the vector column in `table_schema.columns`.
    col_idx: usize,
    /// Declared dimension of the vector column.
    dim: u16,
    /// Distance metric implied by the ORDER BY operator.
    metric: AnnMetric,
    /// The constant query vector (length equals `dim`).
    query_vec: Vec<f32>,
    /// LIMIT value (always greater than zero for a constructed plan).
    k: usize,
    /// OFFSET value (zero when absent).
    offset: usize,
    /// Column indices the ANN index declares as filterable attributes.
    filter_cols: Vec<u16>,
    /// WHERE leaves pushed into the index: (position within `filter_cols`, accepted values).
    pushable: Vec<(usize, Vec<Value>)>,
    /// AND of the WHERE leaves that could not be pushed down; evaluated per candidate row.
    residual: Option<Expr>,
}
131
132fn topk_shape_ok(stmt: &SelectStmt) -> bool {
134 stmt.order_by.len() == 1
135 && !stmt.order_by[0].descending
136 && stmt.limit.is_some()
137 && stmt.group_by.is_empty()
138 && stmt.having.is_none()
139 && stmt.joins.is_empty()
140 && !stmt.distinct
141 && !has_any_window_function(stmt)
142 && !stmt
143 .columns
144 .iter()
145 .any(|c| matches!(c, SelectColumn::Expr { expr, .. } if is_aggregate_expr(expr)))
146}
147
148impl AnnTopKPlan {
149 pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
150 if !topk_shape_ok(stmt) {
151 return Ok(None);
152 }
153 let ob = &stmt.order_by[0];
154
155 let (col_idx, dim, op_metric, query_vec) = match &ob.expr {
156 Expr::BinaryOp { left, op, right } => {
157 let op_metric = match op {
158 BinOp::VectorL2 => AnnMetric::L2,
159 BinOp::VectorInner => AnnMetric::Inner,
160 BinOp::VectorCosine => AnnMetric::Cosine,
161 _ => return Ok(None),
162 };
163 let col_name = match left.as_ref() {
164 Expr::Column(name) => name.to_ascii_lowercase(),
165 _ => return Ok(None),
166 };
167 let (col_idx, dim) = match table_schema
168 .columns
169 .iter()
170 .enumerate()
171 .find(|(_, c)| c.name.to_ascii_lowercase() == col_name)
172 {
173 Some((i, c)) => match c.data_type {
174 DataType::Vector { dim } => (i, dim),
175 _ => return Ok(None),
176 },
177 None => return Ok(None),
178 };
179 let col_map = ColumnMap::new(&table_schema.columns);
180 let ctx = EvalCtx::new(&col_map, &[]);
181 let v = match eval_expr(right, &ctx) {
182 Ok(Value::Vector(v)) => v,
183 _ => return Ok(None),
184 };
185 if v.len() != dim as usize {
186 return Err(SqlError::InvalidValue(format!(
187 "ANN query vector dim {} does not match column dim {}",
188 v.len(),
189 dim
190 )));
191 }
192 (col_idx, dim, op_metric, v.to_vec())
193 }
194 _ => return Ok(None),
195 };
196
197 let ann_index = table_schema.indices.iter().find(|ix| {
198 matches!(ix.kind,
199 IndexKind::Inverted(InvertedKind::Ann { metric }) if metric == op_metric
200 ) && ix.keys.len() == 1
201 && matches!(ix.keys[0],
202 IndexKey::Column { idx, .. } if idx as usize == col_idx
203 )
204 });
205 let Some(ann_index) = ann_index else {
206 return Ok(None);
207 };
208 let filter_cols = ann_index.ann_filter_cols.clone();
209
210 if table_schema.primary_key_columns.len() != 1 {
211 return Ok(None);
212 }
213 let pk_col = &table_schema.columns[table_schema.primary_key_columns[0] as usize];
214 if !matches!(pk_col.data_type, DataType::Integer) {
215 return Ok(None);
216 }
217
218 let mut pushable: Vec<(usize, Vec<Value>)> = Vec::new();
221 let mut residual_leaves: Vec<Expr> = Vec::new();
222 if let Some(w) = &stmt.where_clause {
223 split_where(
224 w,
225 &filter_cols,
226 table_schema,
227 &mut pushable,
228 &mut residual_leaves,
229 );
230 if pushable.is_empty() {
231 return Ok(None);
232 }
233 }
234 let residual = fold_and(residual_leaves);
235
236 let k_limit = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
237 let offset = stmt
238 .offset
239 .as_ref()
240 .map(eval_const_int)
241 .transpose()?
242 .unwrap_or(0)
243 .max(0) as usize;
244 if k_limit == 0 {
245 return Ok(None);
246 }
247
248 Ok(Some(Self {
249 col_idx,
250 dim,
251 metric: op_metric,
252 query_vec,
253 k: k_limit,
254 offset,
255 filter_cols,
256 pushable,
257 residual,
258 }))
259 }
260
261 pub(super) fn execute_with_read(
262 &self,
263 rtx: &mut ReadTxn<'_>,
264 schema: &SchemaManager,
265 stmt: &SelectStmt,
266 table_schema: &TableSchema,
267 ) -> Result<ExecutionResult> {
268 let cache_key = cache_key(&table_schema.name, self.col_idx, self.metric);
269 let Some(cached) = self.load_or_build_index(rtx, schema, &cache_key, table_schema)? else {
271 return empty_result(table_schema, stmt);
272 };
273 self.run_query(rtx, &cached, stmt, table_schema)
274 }
275
276 fn run_query(
278 &self,
279 txn: &mut dyn AnnScan,
280 cached: &CachedAnnIndex,
281 stmt: &SelectStmt,
282 table_schema: &TableSchema,
283 ) -> Result<ExecutionResult> {
284 let mut constraints: Vec<(usize, Vec<u32>)> = Vec::with_capacity(self.pushable.len());
287 for (dim, values) in &self.pushable {
288 let dict = &cached.dicts[*dim];
289 let coll = table_schema.columns[self.filter_cols[*dim] as usize].collation;
290 let mut codes = Vec::with_capacity(values.len());
291 let mut canon = Vec::with_capacity(16);
292 for v in values {
293 canon.clear();
294 encode_key_value_collated_into(v, coll, &mut canon);
295 if let Some(&code) = dict.get(canon.as_slice()) {
296 codes.push(code);
297 }
298 }
299 if codes.is_empty() {
300 return empty_result(table_schema, stmt);
301 }
302 constraints.push((*dim, codes));
303 }
304 let filter = if constraints.is_empty() {
305 Filter::none()
306 } else {
307 Filter::new(constraints)
308 };
309
310 let want = self.k.saturating_add(self.offset).max(1);
311 let mut rows = self.collect_survivors(txn, &cached.index, &filter, table_schema, want)?;
312
313 if self.offset >= rows.len() {
314 rows.clear();
315 } else if self.offset > 0 {
316 rows = rows.split_off(self.offset);
317 }
318 rows.truncate(self.k);
319
320 let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
321 Ok(ExecutionResult::Query(QueryResult {
322 columns: col_names,
323 rows: projected,
324 }))
325 }
326
327 fn collect_survivors(
331 &self,
332 txn: &mut dyn AnnScan,
333 index: &AnnIndex,
334 filter: &Filter,
335 table_schema: &TableSchema,
336 want: usize,
337 ) -> Result<Vec<Vec<Value>>> {
338 let col_map = ColumnMap::new(&table_schema.columns);
339 let max_target = index.indexed_len().max(1);
340 let mut key_buf: Vec<u8> = Vec::with_capacity(10);
341 let mut target = want;
342 loop {
343 target = target.min(max_target);
344 let hits = index.search_filtered_default_ef(&self.query_vec, target, filter);
345 let mut survivors: Vec<Vec<Value>> = Vec::with_capacity(want);
346 for (id, _dist) in &hits {
347 encode_int_key_into(*id as i64, &mut key_buf);
348 let Some(row_bytes) = txn.ann_get(table_schema.name.as_bytes(), &key_buf)? else {
349 continue;
350 };
351 let row = decode_full_row(table_schema, &key_buf, &row_bytes)?;
352 let keep = match &self.residual {
353 None => true,
354 Some(expr) => {
355 let ctx = EvalCtx::new(&col_map, &row);
356 is_truthy(&eval_expr(expr, &ctx)?)
357 }
358 };
359 if keep {
360 survivors.push(row);
361 if survivors.len() >= want {
362 break;
363 }
364 }
365 }
366 if survivors.len() >= want || target >= max_target || hits.len() < target {
369 return Ok(survivors);
370 }
371 target = target.saturating_mul(2);
372 }
373 }
374
375 fn load_or_build_index(
376 &self,
377 txn: &mut dyn AnnScan,
378 schema: &SchemaManager,
379 cache_key: &str,
380 table_schema: &TableSchema,
381 ) -> Result<Option<Arc<CachedAnnIndex>>> {
382 if let Some(existing) = lookup_cached(schema, cache_key, &table_schema.name)? {
383 return Ok(Some(existing));
384 }
385 let spec = AnnSpec {
386 col_idx: self.col_idx,
387 dim: self.dim,
388 metric: self.metric,
389 filter_cols: self.filter_cols.clone(),
390 };
391 load_or_build(txn, schema, cache_key, table_schema, &spec)
392 }
393}
394
395pub(super) struct AnnSpec {
399 pub col_idx: usize,
400 pub dim: u16,
401 pub metric: AnnMetric,
402 pub filter_cols: Vec<u16>,
403}
404
405impl AnnSpec {
406 fn metric_tag(&self) -> u8 {
407 citadel_vector::segment::metric_tag(ann_metric_to_prism(self.metric))
408 }
409}
410
/// Everything `scan_rows` gathers in one pass over the table.
struct ScanOutcome {
    /// One entry per row with a non-NULL vector: (primary key as `u64`, vector, one
    /// attribute code per filter column).
    rows: Vec<(u64, Vec<f32>, Vec<u32>)>,
    /// Per filter column: collated key encoding of a value -> attribute code.
    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
    /// Fingerprint over every scanned row (including rows with a NULL vector).
    fingerprint: [u8; 32],
}
419
420fn scan_rows(
421 txn: &mut dyn AnnScan,
422 table_schema: &TableSchema,
423 spec: &AnnSpec,
424) -> Result<ScanOutcome> {
425 let non_pk = table_schema.non_pk_indices();
426 let enc_pos = table_schema.encoding_positions();
427 let nonpk_order = non_pk
428 .iter()
429 .position(|&i| i == spec.col_idx)
430 .ok_or_else(|| {
431 SqlError::InvalidValue("vector column must be non-PK for ANN build".into())
432 })?;
433 let enc_idx = enc_pos[nonpk_order] as usize;
434
435 let num_attrs = spec.filter_cols.len();
436 let extracts: Vec<Extract> = spec
437 .filter_cols
438 .iter()
439 .map(|&c| extract_plan(c, table_schema, non_pk, enc_pos))
440 .collect::<Result<_>>()?;
441 let collations: Vec<Collation> = spec
446 .filter_cols
447 .iter()
448 .map(|&c| table_schema.columns[c as usize].collation)
449 .collect();
450 let mut dicts: Vec<FxHashMap<Vec<u8>, u32>> = vec![FxHashMap::default(); num_attrs];
451 let mut fp = ann_persist::FingerprintHasher::new(
452 &table_schema.name,
453 spec.col_idx as u32,
454 &spec
455 .filter_cols
456 .iter()
457 .map(|&c| c as u32)
458 .collect::<Vec<_>>(),
459 spec.dim,
460 spec.metric_tag(),
461 );
462 let mut rows: Vec<(u64, Vec<f32>, Vec<u32>)> = Vec::new();
463
464 txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
465 let vector = match decode_column_raw(value, enc_idx)?.to_value() {
466 Value::Vector(arr) => Some(arr.to_vec()),
467 Value::Null => None, _ => {
469 return Err(SqlError::InvalidValue(
470 "ANN column produced non-vector value".into(),
471 ))
472 }
473 };
474 let mut filter_vals: Vec<Value> = Vec::with_capacity(num_attrs);
475 for ex in &extracts {
476 filter_vals.push(ex.extract(key, value)?);
477 }
478 let encoded_filters: Vec<Vec<u8>> = filter_vals.iter().map(encode_key_value).collect();
479 let vec_bytes: Vec<u8> = vector
480 .as_deref()
481 .unwrap_or(&[])
482 .iter()
483 .flat_map(|f| f.to_le_bytes())
484 .collect();
485 fp.row(
486 key,
487 &vec_bytes,
488 &encoded_filters
489 .iter()
490 .map(Vec::as_slice)
491 .collect::<Vec<_>>(),
492 );
493 let Some(vector) = vector else {
494 return Ok(true);
495 };
496 let id = decode_pk_integer(key)? as u64;
497 let mut codes: Vec<u32> = Vec::with_capacity(num_attrs);
498 for (j, v) in filter_vals.iter().enumerate() {
499 let mut canon = Vec::with_capacity(16);
500 encode_key_value_collated_into(v, collations[j], &mut canon);
501 let next = dicts[j].len() as u32;
502 codes.push(*dicts[j].entry(canon).or_insert(next));
503 }
504 rows.push((id, vector, codes));
505 Ok(true)
506 })?;
507
508 Ok(ScanOutcome {
509 rows,
510 dicts,
511 fingerprint: fp.finish(),
512 })
513}
514
515fn build_index(
517 txn: &mut dyn AnnScan,
518 table_schema: &TableSchema,
519 spec: &AnnSpec,
520 refusal: Option<String>,
521 cached_gen: u64,
522) -> Result<Option<CachedAnnIndex>> {
523 let outcome = scan_rows(txn, table_schema, spec)?;
524 if outcome.rows.is_empty() {
525 return Ok(None);
526 }
527 let index = AnnIndex::build_with_attrs(
528 outcome.rows,
529 spec.filter_cols.len(),
530 ann_metric_to_prism(spec.metric),
531 spec.dim,
532 )
533 .map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
534 Ok(Some(CachedAnnIndex {
535 index,
536 dicts: outcome.dicts,
537 source: AnnIndexSource::Built { refusal },
538 cached_gen,
539 }))
540}
541
/// Result of attempting to load a persisted ANN segment.
enum LoadOutcome {
    /// The segment was validated and rehydrated into a usable index.
    Loaded(Box<CachedAnnIndex>),
    /// No segment header could be read for the table.
    NoSegment,
    /// A segment exists but was rejected. `corrupt` separates damaged data from segments
    /// that are merely stale or built for a different configuration.
    Refused { reason: String, corrupt: bool },
}
550
551fn try_load_segment(
555 txn: &mut dyn AnnScan,
556 table_schema: &TableSchema,
557 spec: &AnnSpec,
558 cached_gen: u64,
559) -> Result<LoadOutcome> {
560 let seg_table = ann_persist::segment_table_name(&table_schema.name);
561 let header_bytes = match txn.ann_get(&seg_table, &ann_persist::segment_key(0)) {
562 Ok(Some(b)) => b,
563 Ok(None) | Err(_) => return Ok(LoadOutcome::NoSegment),
565 };
566 let refuse = |reason: String, corrupt: bool| Ok(LoadOutcome::Refused { reason, corrupt });
567 let header = match ann_persist::SegmentHeader::decode(&header_bytes) {
568 Ok(h) => h,
569 Err(e) => return refuse(format!("header: {e}"), true),
570 };
571 if header.format_version != ann_persist::ANNSEG_FORMAT_VERSION {
572 return refuse(
573 format!("format v{} (this binary reads v1)", header.format_version),
574 false,
575 );
576 }
577 let active_cfg = citadel_vector::segment::prism_config_hash(&AnnIndex::active_config(
578 ann_metric_to_prism(spec.metric),
579 ));
580 if header.prism_config_hash != active_cfg {
581 return refuse(
582 "PRISM config drift (segment built by another geometry)".into(),
583 false,
584 );
585 }
586 if header.dim != spec.dim
587 || header.metric_tag != spec.metric_tag()
588 || header.col_idx != spec.col_idx as u32
589 || header.filter_cols
590 != spec
591 .filter_cols
592 .iter()
593 .map(|&c| c as u32)
594 .collect::<Vec<_>>()
595 {
596 return refuse(
597 "index identity mismatch (column/metric/filter set)".into(),
598 false,
599 );
600 }
601
602 let mut body = Vec::new();
603 for chunk_no in 1..=header.chunk_count {
604 match txn.ann_get(&seg_table, &ann_persist::segment_key(chunk_no)) {
605 Ok(Some(c)) => body.extend_from_slice(&c),
606 _ => return refuse(format!("missing chunk {chunk_no}"), true),
607 }
608 }
609 if *blake3::hash(&body).as_bytes() != header.segment_b3 {
610 return refuse("segment body BLAKE3 mismatch (corrupt)".into(), true);
611 }
612 let parts = match citadel_vector::segment::decode(&body) {
613 Ok(p) => p,
614 Err(e) => return refuse(format!("segment decode: {e}"), true),
615 };
616 if parts.n() as u64 != header.n || parts.dim() != header.dim {
617 return refuse("segment body disagrees with header counts".into(), true);
618 }
619
620 let slot_of = parts.internal_of_row();
623 let dim = spec.dim as usize;
624 let mut vectors = vec![0.0f32; parts.n() * dim];
625 let mut filled = 0usize;
626 let outcome = scan_rows_rehydrate(txn, table_schema, spec, &mut |row_id, vector| {
627 let Some(&slot) = slot_of.get(&row_id) else {
628 return false; };
630 vectors[slot as usize * dim..(slot as usize + 1) * dim].copy_from_slice(vector);
631 filled += 1;
632 true
633 })?;
634 let Some(fingerprint) = outcome else {
635 return refuse(
636 "a scanned row is unknown to the segment (stale)".into(),
637 false,
638 );
639 };
640 if fingerprint != header.content_fingerprint {
641 return refuse("content fingerprint mismatch (stale)".into(), false);
642 }
643 let index = match parts.into_index(vectors, filled) {
644 Ok(i) => i,
645 Err(e) => return refuse(format!("rehydration: {e}"), true),
646 };
647 Ok(LoadOutcome::Loaded(Box::new(CachedAnnIndex {
648 index,
649 dicts: header.dict_maps(),
650 source: AnnIndexSource::Loaded {
651 segment_b3: header.segment_b3,
652 },
653 cached_gen,
654 })))
655}
656
657fn scan_rows_rehydrate(
661 txn: &mut dyn AnnScan,
662 table_schema: &TableSchema,
663 spec: &AnnSpec,
664 place: &mut dyn FnMut(u64, &[f32]) -> bool,
665) -> Result<Option<[u8; 32]>> {
666 let non_pk = table_schema.non_pk_indices();
667 let enc_pos = table_schema.encoding_positions();
668 let nonpk_order = non_pk
669 .iter()
670 .position(|&i| i == spec.col_idx)
671 .ok_or_else(|| {
672 SqlError::InvalidValue("vector column must be non-PK for ANN build".into())
673 })?;
674 let enc_idx = enc_pos[nonpk_order] as usize;
675 let extracts: Vec<Extract> = spec
676 .filter_cols
677 .iter()
678 .map(|&c| extract_plan(c, table_schema, non_pk, enc_pos))
679 .collect::<Result<_>>()?;
680 let mut fp = ann_persist::FingerprintHasher::new(
681 &table_schema.name,
682 spec.col_idx as u32,
683 &spec
684 .filter_cols
685 .iter()
686 .map(|&c| c as u32)
687 .collect::<Vec<_>>(),
688 spec.dim,
689 spec.metric_tag(),
690 );
691 let mut unknown_row = false;
692
693 txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
694 let vector = match decode_column_raw(value, enc_idx)?.to_value() {
695 Value::Vector(arr) => Some(arr.to_vec()),
696 Value::Null => None,
697 _ => {
698 return Err(SqlError::InvalidValue(
699 "ANN column produced non-vector value".into(),
700 ))
701 }
702 };
703 let mut encoded_filters: Vec<Vec<u8>> = Vec::with_capacity(extracts.len());
704 for ex in &extracts {
705 encoded_filters.push(encode_key_value(&ex.extract(key, value)?));
706 }
707 let vec_bytes: Vec<u8> = vector
708 .as_deref()
709 .unwrap_or(&[])
710 .iter()
711 .flat_map(|f| f.to_le_bytes())
712 .collect();
713 fp.row(
714 key,
715 &vec_bytes,
716 &encoded_filters
717 .iter()
718 .map(Vec::as_slice)
719 .collect::<Vec<_>>(),
720 );
721 if let Some(vector) = vector {
722 let id = decode_pk_integer(key)? as u64;
723 if !place(id, &vector) {
724 unknown_row = true;
725 return Ok(false);
726 }
727 }
728 Ok(true)
729 })?;
730
731 Ok(if unknown_row { None } else { Some(fp.finish()) })
732}
733
734fn load_or_build(
739 txn: &mut dyn AnnScan,
740 schema: &SchemaManager,
741 cache_key: &str,
742 table_schema: &TableSchema,
743 spec: &AnnSpec,
744) -> Result<Option<Arc<CachedAnnIndex>>> {
745 let gen = txn.cache_generation();
746 let cached_gen = gen.unwrap_or(u64::MAX);
747 let loaded = match try_load_segment(txn, table_schema, spec, cached_gen)? {
748 LoadOutcome::Loaded(c) => Some(*c),
749 LoadOutcome::NoSegment => None,
750 LoadOutcome::Refused { reason, corrupt } => {
751 if corrupt {
752 eprintln!(
753 "citadel-sql: ANN segment for `{}` REFUSED as corrupt ({reason}); \
754 rebuilding from scan - investigate before re-persisting",
755 table_schema.name
756 );
757 }
758 match build_index(txn, table_schema, spec, Some(reason), cached_gen)? {
761 Some(c) => Some(c),
762 None => return Ok(None),
763 }
764 }
765 };
766 let built = match loaded {
767 Some(c) => c,
768 None => match build_index(txn, table_schema, spec, None, cached_gen)? {
769 Some(c) => c,
770 None => return Ok(None),
771 },
772 };
773 let arc: Arc<CachedAnnIndex> = Arc::new(built);
774 if gen.is_none() {
775 return Ok(Some(arc));
777 }
778 let mut guard = schema.sql_caches.lock();
779 if let Some(existing) = guard.get(cache_key) {
780 return Arc::clone(existing)
782 .downcast::<CachedAnnIndex>()
783 .map(Some)
784 .map_err(|_| {
785 SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}"))
786 });
787 }
788 let marker = marker_gen_locked(&guard, &table_schema.name);
789 if marker.is_some_and(|g| arc.cached_gen < g) {
790 return Ok(Some(arc));
793 }
794 let as_any: Arc<dyn Any + Send + Sync> = arc.clone();
795 guard.insert(cache_key.to_string(), as_any);
796 Ok(Some(arc))
797}
798
/// Exact (full-scan) top-k plan for `ORDER BY <vector distance> LIMIT k`.
pub(super) struct VectorTopKPlan {
    /// The ORDER BY expression, evaluated per row to obtain the distance.
    order_expr: Expr,
    /// The statement's WHERE clause, evaluated per row before ranking.
    where_clause: Option<Expr>,
    /// LIMIT value (always greater than zero for a constructed plan).
    k: usize,
    /// OFFSET value (zero when absent).
    offset: usize,
    /// Whether rows with a NULL distance sort before all others.
    nulls_first: bool,
}
808
809struct Ranked {
812 dist: f64,
813 seq: u64,
814 row: Vec<Value>,
815}
816
817impl PartialEq for Ranked {
818 fn eq(&self, other: &Self) -> bool {
819 self.cmp(other) == Ordering::Equal
820 }
821}
822impl Eq for Ranked {}
823impl PartialOrd for Ranked {
824 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
825 Some(self.cmp(other))
826 }
827}
828impl Ord for Ranked {
829 fn cmp(&self, other: &Self) -> Ordering {
830 self.dist
831 .total_cmp(&other.dist)
832 .then_with(|| self.seq.cmp(&other.seq))
833 }
834}
835
836impl VectorTopKPlan {
837 pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
838 if !topk_shape_ok(stmt) {
839 return Ok(None);
840 }
841 let ob = &stmt.order_by[0];
842 let Expr::BinaryOp { left, op, .. } = &ob.expr else {
843 return Ok(None);
844 };
845 if !matches!(
846 op,
847 BinOp::VectorL2 | BinOp::VectorInner | BinOp::VectorCosine
848 ) {
849 return Ok(None);
850 }
851 let Expr::Column(name) = left.as_ref() else {
853 return Ok(None);
854 };
855 let name = name.to_ascii_lowercase();
856 let is_vector_col = table_schema.columns.iter().any(|c| {
857 c.name.to_ascii_lowercase() == name && matches!(c.data_type, DataType::Vector { .. })
858 });
859 if !is_vector_col {
860 return Ok(None);
861 }
862
863 let k = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
864 if k == 0 {
865 return Ok(None);
866 }
867 let offset = stmt
868 .offset
869 .as_ref()
870 .map(eval_const_int)
871 .transpose()?
872 .unwrap_or(0)
873 .max(0) as usize;
874
875 Ok(Some(Self {
876 order_expr: ob.expr.clone(),
877 where_clause: stmt.where_clause.clone(),
878 k,
879 offset,
880 nulls_first: ob.nulls_first.unwrap_or(true),
882 }))
883 }
884
885 pub(super) fn execute(
886 &self,
887 txn: &mut dyn AnnScan,
888 table_schema: &TableSchema,
889 stmt: &SelectStmt,
890 ) -> Result<ExecutionResult> {
891 let want = self.k.saturating_add(self.offset);
892 let col_map = ColumnMap::new(&table_schema.columns);
893 let null_dist = if self.nulls_first {
895 f64::NEG_INFINITY
896 } else {
897 f64::INFINITY
898 };
899 let mut heap: BinaryHeap<Ranked> = BinaryHeap::new();
900 let mut seq: u64 = 0;
901
902 txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
903 let row = decode_full_row(table_schema, key, value)?;
904 let ctx = EvalCtx::new(&col_map, &row);
905 if let Some(w) = &self.where_clause {
906 if !is_truthy(&eval_expr(w, &ctx)?) {
907 return Ok(true);
908 }
909 }
910 let dist = match eval_expr(&self.order_expr, &ctx)? {
911 Value::Real(d) => d,
912 Value::Integer(i) => i as f64,
913 Value::Null => null_dist,
914 other => {
915 return Err(SqlError::InvalidValue(format!(
916 "ORDER BY vector distance produced a non-numeric {}",
917 other.data_type()
918 )))
919 }
920 };
921 let cand = Ranked { dist, seq, row };
922 seq += 1;
923 if heap.len() < want {
925 heap.push(cand);
926 } else if heap.peek().is_some_and(|top| cand < *top) {
927 heap.pop();
928 heap.push(cand);
929 }
930 Ok(true)
931 })?;
932
933 let mut rows: Vec<Vec<Value>> = heap.into_sorted_vec().into_iter().map(|r| r.row).collect();
934 if self.offset >= rows.len() {
935 rows.clear();
936 } else if self.offset > 0 {
937 rows = rows.split_off(self.offset);
938 }
939 rows.truncate(self.k);
940
941 let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
942 Ok(ExecutionResult::Query(QueryResult {
943 columns: col_names,
944 rows: projected,
945 }))
946 }
947}
948
949enum Extract {
951 Pk,
953 NonPk(usize),
955}
956
957impl Extract {
958 fn extract(&self, key: &[u8], value: &[u8]) -> Result<Value> {
959 match self {
960 Extract::Pk => Ok(Value::Integer(decode_pk_integer(key)?)),
961 Extract::NonPk(ei) => Ok(decode_column_raw(value, *ei)?.to_value()),
962 }
963 }
964}
965
966fn extract_plan(
967 col: u16,
968 table_schema: &TableSchema,
969 non_pk: &[usize],
970 enc_pos: &[u16],
971) -> Result<Extract> {
972 if table_schema.primary_key_columns.contains(&col) {
973 return Ok(Extract::Pk);
974 }
975 let order = non_pk
976 .iter()
977 .position(|&i| i == col as usize)
978 .ok_or_else(|| SqlError::InvalidValue("ANN filter column not found in row".into()))?;
979 Ok(Extract::NonPk(enc_pos[order] as usize))
980}
981
982fn split_where(
985 expr: &Expr,
986 filter_cols: &[u16],
987 table_schema: &TableSchema,
988 pushable: &mut Vec<(usize, Vec<Value>)>,
989 residual: &mut Vec<Expr>,
990) {
991 if let Expr::BinaryOp {
992 left,
993 op: BinOp::And,
994 right,
995 } = expr
996 {
997 split_where(left, filter_cols, table_schema, pushable, residual);
998 split_where(right, filter_cols, table_schema, pushable, residual);
999 return;
1000 }
1001 match classify_leaf(expr, filter_cols, table_schema) {
1002 Some(constraint) => pushable.push(constraint),
1003 None => residual.push(expr.clone()),
1004 }
1005}
1006
/// Outcome of coercing a WHERE literal to a filter column's type for pushdown.
enum Coerced {
    /// The literal has an exact representation in the column type.
    Exact(Value),
    /// The literal can never equal a value of the column type; it contributes no match.
    NeverMatches,
    /// The comparison cannot be decided here; the leaf must stay in the residual predicate.
    Residual,
}
1019
1020fn coerce_pushdown_literal(val: Value, col_type: DataType) -> Coerced {
1021 const EXACT_F64_INT: f64 = 9_007_199_254_740_992.0;
1024 if val.is_null() {
1025 return Coerced::Residual;
1026 }
1027 if val.data_type() == col_type {
1028 return Coerced::Exact(val);
1029 }
1030 match (val, col_type) {
1031 (Value::Real(r), DataType::Integer) => {
1032 if r.is_nan() || r.is_infinite() {
1033 Coerced::NeverMatches
1034 } else if r.abs() > EXACT_F64_INT {
1035 Coerced::Residual
1036 } else if r.fract() == 0.0 {
1037 Coerced::Exact(Value::Integer(r as i64))
1038 } else {
1039 Coerced::NeverMatches
1040 }
1041 }
1042 (Value::Integer(i), DataType::Real) => {
1043 if i.unsigned_abs() <= EXACT_F64_INT as u64 {
1044 Coerced::Exact(Value::Real(i as f64))
1045 } else {
1046 Coerced::Residual
1047 }
1048 }
1049 _ => Coerced::Residual,
1050 }
1051}
1052
1053fn classify_leaf(
1058 leaf: &Expr,
1059 filter_cols: &[u16],
1060 table_schema: &TableSchema,
1061) -> Option<(usize, Vec<Value>)> {
1062 let (col_expr, rhs): (&Expr, Vec<&Expr>) = match leaf {
1063 Expr::BinaryOp {
1064 left,
1065 op: BinOp::Eq,
1066 right,
1067 } => (left, vec![right.as_ref()]),
1068 Expr::InList {
1069 expr,
1070 list,
1071 negated: false,
1072 } => (expr, list.iter().collect()),
1073 _ => return None,
1074 };
1075 let dim = filter_dim(col_expr, filter_cols, table_schema)?;
1076 let col_type = table_schema.columns[filter_cols[dim] as usize].data_type;
1077 let mut vals = Vec::with_capacity(rhs.len());
1078 for e in rhs {
1079 match coerce_pushdown_literal(eval_const_expr(e).ok()?, col_type) {
1080 Coerced::Exact(v) => vals.push(v),
1081 Coerced::NeverMatches => {}
1082 Coerced::Residual => return None,
1083 }
1084 }
1085 Some((dim, vals))
1086}
1087
1088fn filter_dim(expr: &Expr, filter_cols: &[u16], table_schema: &TableSchema) -> Option<usize> {
1091 let name = match expr {
1092 Expr::Column(c) => c.to_ascii_lowercase(),
1093 Expr::QualifiedColumn { column, .. } => column.to_ascii_lowercase(),
1094 _ => return None,
1095 };
1096 let col_idx = table_schema
1097 .columns
1098 .iter()
1099 .position(|c| c.name.to_ascii_lowercase() == name)? as u16;
1100 filter_cols.iter().position(|&c| c == col_idx)
1101}
1102
1103fn fold_and(mut leaves: Vec<Expr>) -> Option<Expr> {
1104 if leaves.is_empty() {
1105 return None;
1106 }
1107 let first = leaves.remove(0);
1108 Some(leaves.into_iter().fold(first, |acc, e| Expr::BinaryOp {
1109 left: Box::new(acc),
1110 op: BinOp::And,
1111 right: Box::new(e),
1112 }))
1113}
1114
1115fn empty_result(table_schema: &TableSchema, stmt: &SelectStmt) -> Result<ExecutionResult> {
1116 let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, Vec::new())?;
1117 Ok(ExecutionResult::Query(QueryResult {
1118 columns: col_names,
1119 rows: projected,
1120 }))
1121}
1122
1123pub(crate) fn persist_ann_index(
1131 db: &citadel::Database,
1132 schema: &SchemaManager,
1133 table_schema: &TableSchema,
1134 column: &str,
1135) -> Result<ann_persist::AnnSegmentInfo> {
1136 let col_lower = column.to_ascii_lowercase();
1137 let col_idx = table_schema
1138 .columns
1139 .iter()
1140 .position(|c| c.name == col_lower)
1141 .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
1142 let DataType::Vector { dim } = table_schema.columns[col_idx].data_type else {
1143 return Err(SqlError::InvalidValue(format!(
1144 "column `{column}` is not VECTOR(N)"
1145 )));
1146 };
1147 if table_schema.primary_key_columns.len() != 1
1151 || !matches!(
1152 table_schema.columns[table_schema.primary_key_columns[0] as usize].data_type,
1153 DataType::Integer
1154 )
1155 {
1156 return Err(SqlError::InvalidValue(
1157 "ANN persistence requires a single INTEGER primary key (same rule as the \
1158 ANN query plan)"
1159 .into(),
1160 ));
1161 }
1162 let ann_index = table_schema
1163 .indices
1164 .iter()
1165 .find(|ix| {
1166 matches!(ix.kind, IndexKind::Inverted(InvertedKind::Ann { .. }))
1167 && ix.keys.len() == 1
1168 && matches!(ix.keys[0], IndexKey::Column { idx, .. } if idx as usize == col_idx)
1169 })
1170 .ok_or_else(|| SqlError::InvalidValue(format!("no ANN index declared on `{column}`")))?;
1171 let IndexKind::Inverted(InvertedKind::Ann { metric }) = ann_index.kind else {
1172 unreachable!("matched above");
1173 };
1174 let spec = AnnSpec {
1175 col_idx,
1176 dim,
1177 metric,
1178 filter_cols: ann_index.ann_filter_cols.clone(),
1179 };
1180
1181 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
1182 let outcome = scan_rows(&mut wtx, table_schema, &spec)?;
1183 if outcome.rows.is_empty() {
1184 return Err(SqlError::InvalidValue(
1185 "nothing to persist: the table has no indexable (non-NULL) vectors".into(),
1186 ));
1187 }
1188 let n = outcome.rows.len() as u64;
1189 let index = AnnIndex::build_with_attrs(
1190 outcome.rows,
1191 spec.filter_cols.len(),
1192 ann_metric_to_prism(spec.metric),
1193 spec.dim,
1194 )
1195 .map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
1196
1197 let body = citadel_vector::segment::encode(&index);
1198 let segment_b3 = *blake3::hash(&body).as_bytes();
1199 let dicts_ordered: Vec<Vec<(Vec<u8>, u32)>> = outcome
1202 .dicts
1203 .iter()
1204 .map(|d| {
1205 let mut entries: Vec<(Vec<u8>, u32)> = d.iter().map(|(k, &v)| (k.clone(), v)).collect();
1206 entries.sort_by_key(|&(_, code)| code);
1207 entries
1208 })
1209 .collect();
1210 let header = ann_persist::SegmentHeader {
1211 format_version: ann_persist::ANNSEG_FORMAT_VERSION,
1212 prism_config_hash: ann_persist::active_config_hash(ann_metric_to_prism(spec.metric)),
1213 dim: spec.dim,
1214 metric_tag: spec.metric_tag(),
1215 n,
1216 snapshot_max: index.snapshot_max,
1217 col_idx: spec.col_idx as u32,
1218 filter_cols: spec.filter_cols.iter().map(|&c| c as u32).collect(),
1219 dicts: dicts_ordered,
1220 content_fingerprint: outcome.fingerprint,
1221 segment_b3,
1222 chunk_count: body.len().div_ceil(ann_persist::CHUNK_BYTES) as u32,
1223 writer: format!("citadel-sql {}", env!("CARGO_PKG_VERSION")),
1224 };
1225
1226 let seg_table = ann_persist::segment_table_name(&table_schema.name);
1227 ann_persist::purge_segment(&mut wtx, &table_schema.name)?;
1228 wtx.create_table(&seg_table).map_err(SqlError::Storage)?;
1229 wtx.table_insert(&seg_table, &ann_persist::segment_key(0), &header.encode())
1230 .map_err(SqlError::Storage)?;
1231 for (chunk_no, chunk) in ann_persist::chunks(&body) {
1232 wtx.table_insert(&seg_table, &ann_persist::segment_key(chunk_no), chunk)
1233 .map_err(SqlError::Storage)?;
1234 }
1235 wtx.commit().map_err(SqlError::Storage)?;
1236
1237 let cached = CachedAnnIndex {
1240 index,
1241 dicts: outcome.dicts,
1242 source: AnnIndexSource::Built { refusal: None },
1243 cached_gen: db.manager().commit_generation(),
1244 };
1245 let key = cache_key(&table_schema.name, spec.col_idx, spec.metric);
1246 let as_any: Arc<dyn Any + Send + Sync> = Arc::new(cached);
1247 schema.sql_caches.lock().insert(key, as_any);
1248
1249 Ok(ann_persist::AnnSegmentInfo {
1250 segment_b3,
1251 content_fingerprint: header.content_fingerprint,
1252 n,
1253 dim: spec.dim,
1254 metric_tag: header.metric_tag,
1255 chunk_count: header.chunk_count,
1256 })
1257}
1258
1259pub(crate) fn ann_cache_status(
1262 schema: &SchemaManager,
1263 table_schema: &TableSchema,
1264 column: &str,
1265) -> Result<Option<(AnnIndexSource, u64)>> {
1266 let col_lower = column.to_ascii_lowercase();
1267 let col_idx = table_schema
1268 .columns
1269 .iter()
1270 .position(|c| c.name == col_lower)
1271 .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
1272 let guard = schema.sql_caches.lock();
1273 for metric in [AnnMetric::L2, AnnMetric::Inner, AnnMetric::Cosine] {
1274 let key = cache_key(&table_schema.name, col_idx, metric);
1275 if let Some(entry) = guard.get(&key) {
1276 if let Ok(c) = Arc::clone(entry).downcast::<CachedAnnIndex>() {
1277 return Ok(Some((c.source.clone(), c.cached_gen)));
1278 }
1279 }
1280 }
1281 Ok(None)
1282}
1283
/// Cache key under which the DML generation marker of `table_name` is stored.
pub(crate) fn ann_dml_gen_key(table_name: &str) -> String {
    const PREFIX: &str = "ann_dml_gen:";
    let mut key = String::with_capacity(PREFIX.len() + table_name.len());
    key.push_str(PREFIX);
    key.push_str(table_name);
    key
}
1290
1291fn marker_gen_locked(
1293 entries: &FxHashMap<String, Arc<dyn Any + Send + Sync>>,
1294 table_name: &str,
1295) -> Option<u64> {
1296 entries
1297 .get(&ann_dml_gen_key(table_name))
1298 .and_then(|e| e.downcast_ref::<u64>())
1299 .copied()
1300}
1301
1302fn lookup_cached(
1303 schema: &SchemaManager,
1304 cache_key: &str,
1305 table_name: &str,
1306) -> Result<Option<Arc<CachedAnnIndex>>> {
1307 let mut guard = schema.sql_caches.lock();
1308 let Some(entry) = guard.get(cache_key) else {
1309 return Ok(None);
1310 };
1311 let entry = Arc::clone(entry)
1312 .downcast::<CachedAnnIndex>()
1313 .map_err(|_| SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}")))?;
1314 if marker_gen_locked(&guard, table_name).is_some_and(|g| entry.cached_gen < g) {
1315 guard.remove(cache_key);
1318 return Ok(None);
1319 }
1320 Ok(Some(entry))
1321}
1322
1323pub(super) fn cache_key(table_name: &str, col_idx: usize, metric: AnnMetric) -> String {
1324 let tag = match metric {
1325 AnnMetric::L2 => "l2",
1326 AnnMetric::Inner => "inner",
1327 AnnMetric::Cosine => "cosine",
1328 };
1329 format!(
1330 "ann:{}:{}:{}",
1331 table_name.to_ascii_lowercase(),
1332 col_idx,
1333 tag
1334 )
1335}
1336
1337fn ann_metric_to_prism(m: AnnMetric) -> Metric {
1338 match m {
1339 AnnMetric::L2 => Metric::L2,
1340 AnnMetric::Inner => Metric::InnerProduct,
1341 AnnMetric::Cosine => Metric::Cosine,
1342 }
1343}