Skip to main content

citadel_sql/executor/
ann_topk.rs

1//! Plans for `SELECT ... ORDER BY col <dist> :q LIMIT k`: [`AnnTopKPlan`] uses a
2//! cached PRISM index; [`VectorTopKPlan`] streams a bounded-heap top-k when no
3//! index applies or inside a write txn (uncommitted rows).
4
5use std::any::Any;
6use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8use std::sync::Arc;
9
10use citadel_txn::read_txn::ReadTxn;
11use citadel_txn::write_txn::WriteTxn;
12use citadel_vector::{AnnIndex, Filter, Metric};
13use rustc_hash::FxHashMap;
14
15use crate::encoding::{
16    decode_column_raw, decode_pk_integer, encode_int_key_into, encode_key_value,
17};
18use crate::error::{Result, SqlError};
19use crate::eval::{eval_expr, is_truthy, ColumnMap, EvalCtx};
20use crate::parser::*;
21use crate::schema::SchemaManager;
22use crate::types::*;
23
24use super::aggregate::is_aggregate_expr;
25use super::ann_persist;
26use super::helpers::{decode_full_row, eval_const_expr, eval_const_int, project_rows};
27use super::window::has_any_window_function;
28
29type StorageResult<T> = std::result::Result<T, citadel_core::Error>;
30type ScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> Result<bool> + 'a;
31type RawScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> StorageResult<bool> + 'a;
32
33/// Scan + point-get over a read or write txn, materializing overflow values.
34pub(super) trait AnnScan {
35    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()>;
36    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>>;
37    /// The commit generation this txn's snapshot reflects, or `None` when the
38    /// view includes uncommitted writes - an index over such a view must NEVER
39    /// enter the shared cache.
40    fn cache_generation(&self) -> Option<u64>;
41}
42
43/// Adapt a storage-level scan to report `SqlError`, surfacing the first callback error.
44fn bridge_scan(
45    scan: impl FnOnce(&mut RawScanRow<'_>) -> StorageResult<()>,
46    f: &mut ScanRow<'_>,
47) -> Result<()> {
48    let mut cb_err: Option<SqlError> = None;
49    scan(&mut |key, value| match f(key, value) {
50        Ok(go) => Ok(go),
51        Err(e) => {
52            cb_err = Some(e);
53            Ok(false)
54        }
55    })
56    .map_err(SqlError::Storage)?;
57    match cb_err {
58        Some(e) => Err(e),
59        None => Ok(()),
60    }
61}
62
63impl AnnScan for ReadTxn<'_> {
64    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
65        bridge_scan(|cb| self.table_scan_from(table, b"", cb), f)
66    }
67
68    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
69        self.table_get(table, key).map_err(SqlError::Storage)
70    }
71
72    fn cache_generation(&self) -> Option<u64> {
73        Some(self.commit_generation())
74    }
75}
76
77impl AnnScan for WriteTxn<'_> {
78    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
79        bridge_scan(|cb| self.table_scan_from(table, b"", cb), f)
80    }
81
82    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
83        self.table_get(table, key).map_err(SqlError::Storage)
84    }
85
86    fn cache_generation(&self) -> Option<u64> {
87        None
88    }
89}
90
/// Provenance of a cached index. Queryable via `ann_cache_status`, and the
/// place where load-refusal diagnostics live: a refused segment degrades to
/// the slow rebuild, but the reason must remain visible rather than being
/// log-only.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnnIndexSource {
    /// The index was built from a table scan in this process. `refusal` holds
    /// the reason a persisted segment was NOT used, when one existed and was
    /// rejected.
    Built { refusal: Option<String> },
    /// The index was loaded from a persisted segment. `segment_b3` is the
    /// BLAKE3 of the segment body; the content fingerprint matched the
    /// rehydration scan.
    Loaded { segment_b3: [u8; 32] },
}
103
104/// A cached ANN index plus the metadata needed to push SQL filters into it.
105struct CachedAnnIndex {
106    index: AnnIndex,
107    /// Per attribute dim: maps an encoded filter-column value to its PRISM code.
108    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
109    source: AnnIndexSource,
110    /// The commit generation the index reflects: inserts into the shared cache
111    /// are declined if the database moved past it (the build/load raced a
112    /// commit), so a cached index can never describe a superseded snapshot.
113    cached_gen: u64,
114}
115
116pub(super) struct AnnTopKPlan {
117    col_idx: usize,
118    dim: u16,
119    metric: AnnMetric,
120    query_vec: Vec<f32>,
121    k: usize,
122    offset: usize,
123    /// Schema column indices declared filterable on the index, in attr-dim order.
124    filter_cols: Vec<u16>,
125    /// Pushable conjuncts: `(attr_dim, allowed_values)` from `col = v` / `col IN (...)`.
126    pushable: Vec<(usize, Vec<Value>)>,
127    /// Remaining WHERE predicate evaluated as a recheck on decoded candidates.
128    residual: Option<Expr>,
129}
130
131/// Gate for single-key ascending ORDER BY ... LIMIT k (no group/having/join/distinct/window/agg).
132fn topk_shape_ok(stmt: &SelectStmt) -> bool {
133    stmt.order_by.len() == 1
134        && !stmt.order_by[0].descending
135        && stmt.limit.is_some()
136        && stmt.group_by.is_empty()
137        && stmt.having.is_none()
138        && stmt.joins.is_empty()
139        && !stmt.distinct
140        && !has_any_window_function(stmt)
141        && !stmt
142            .columns
143            .iter()
144            .any(|c| matches!(c, SelectColumn::Expr { expr, .. } if is_aggregate_expr(expr)))
145}
146
147impl AnnTopKPlan {
148    pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
149        if !topk_shape_ok(stmt) {
150            return Ok(None);
151        }
152        let ob = &stmt.order_by[0];
153
154        let (col_idx, dim, op_metric, query_vec) = match &ob.expr {
155            Expr::BinaryOp { left, op, right } => {
156                let op_metric = match op {
157                    BinOp::VectorL2 => AnnMetric::L2,
158                    BinOp::VectorInner => AnnMetric::Inner,
159                    BinOp::VectorCosine => AnnMetric::Cosine,
160                    _ => return Ok(None),
161                };
162                let col_name = match left.as_ref() {
163                    Expr::Column(name) => name.to_ascii_lowercase(),
164                    _ => return Ok(None),
165                };
166                let (col_idx, dim) = match table_schema
167                    .columns
168                    .iter()
169                    .enumerate()
170                    .find(|(_, c)| c.name.to_ascii_lowercase() == col_name)
171                {
172                    Some((i, c)) => match c.data_type {
173                        DataType::Vector { dim } => (i, dim),
174                        _ => return Ok(None),
175                    },
176                    None => return Ok(None),
177                };
178                let col_map = ColumnMap::new(&table_schema.columns);
179                let ctx = EvalCtx::new(&col_map, &[]);
180                let v = match eval_expr(right, &ctx) {
181                    Ok(Value::Vector(v)) => v,
182                    _ => return Ok(None),
183                };
184                if v.len() != dim as usize {
185                    return Err(SqlError::InvalidValue(format!(
186                        "ANN query vector dim {} does not match column dim {}",
187                        v.len(),
188                        dim
189                    )));
190                }
191                (col_idx, dim, op_metric, v.to_vec())
192            }
193            _ => return Ok(None),
194        };
195
196        let ann_index = table_schema.indices.iter().find(|ix| {
197            matches!(ix.kind,
198                IndexKind::Inverted(InvertedKind::Ann { metric }) if metric == op_metric
199            ) && ix.keys.len() == 1
200                && matches!(ix.keys[0],
201                    IndexKey::Column { idx, .. } if idx as usize == col_idx
202                )
203        });
204        let Some(ann_index) = ann_index else {
205            return Ok(None);
206        };
207        let filter_cols = ann_index.ann_filter_cols.clone();
208
209        if table_schema.primary_key_columns.len() != 1 {
210            return Ok(None);
211        }
212        let pk_col = &table_schema.columns[table_schema.primary_key_columns[0] as usize];
213        if !matches!(pk_col.data_type, DataType::Integer) {
214            return Ok(None);
215        }
216
217        // No pushable predicate means the index gives no leverage; decline so
218        // the exact filtered scan runs instead.
219        let mut pushable: Vec<(usize, Vec<Value>)> = Vec::new();
220        let mut residual_leaves: Vec<Expr> = Vec::new();
221        if let Some(w) = &stmt.where_clause {
222            split_where(
223                w,
224                &filter_cols,
225                table_schema,
226                &mut pushable,
227                &mut residual_leaves,
228            );
229            if pushable.is_empty() {
230                return Ok(None);
231            }
232        }
233        let residual = fold_and(residual_leaves);
234
235        let k_limit = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
236        let offset = stmt
237            .offset
238            .as_ref()
239            .map(eval_const_int)
240            .transpose()?
241            .unwrap_or(0)
242            .max(0) as usize;
243        if k_limit == 0 {
244            return Ok(None);
245        }
246
247        Ok(Some(Self {
248            col_idx,
249            dim,
250            metric: op_metric,
251            query_vec,
252            k: k_limit,
253            offset,
254            filter_cols,
255            pushable,
256            residual,
257        }))
258    }
259
260    pub(super) fn execute_with_read(
261        &self,
262        rtx: &mut ReadTxn<'_>,
263        schema: &SchemaManager,
264        stmt: &SelectStmt,
265        table_schema: &TableSchema,
266    ) -> Result<ExecutionResult> {
267        let cache_key = cache_key(&table_schema.name, self.col_idx, self.metric);
268        // Empty table: nothing to build, ORDER BY ... LIMIT yields no rows.
269        let Some(cached) = self.load_or_build_index(rtx, schema, &cache_key, table_schema)? else {
270            return empty_result(table_schema, stmt);
271        };
272        self.run_query(rtx, &cached, stmt, table_schema)
273    }
274
275    /// Search the index, apply filters and the residual recheck, then page and project.
276    fn run_query(
277        &self,
278        txn: &mut dyn AnnScan,
279        cached: &CachedAnnIndex,
280        stmt: &SelectStmt,
281        table_schema: &TableSchema,
282    ) -> Result<ExecutionResult> {
283        // Map values to codes; a value absent from the dictionary matches no row.
284        let mut constraints: Vec<(usize, Vec<u32>)> = Vec::with_capacity(self.pushable.len());
285        for (dim, values) in &self.pushable {
286            let dict = &cached.dicts[*dim];
287            let mut codes = Vec::with_capacity(values.len());
288            for v in values {
289                if let Some(&code) = dict.get(encode_key_value(v).as_slice()) {
290                    codes.push(code);
291                }
292            }
293            if codes.is_empty() {
294                return empty_result(table_schema, stmt);
295            }
296            constraints.push((*dim, codes));
297        }
298        let filter = if constraints.is_empty() {
299            Filter::none()
300        } else {
301            Filter::new(constraints)
302        };
303
304        let want = self.k.saturating_add(self.offset).max(1);
305        let mut rows = self.collect_survivors(txn, &cached.index, &filter, table_schema, want)?;
306
307        if self.offset >= rows.len() {
308            rows.clear();
309        } else if self.offset > 0 {
310            rows = rows.split_off(self.offset);
311        }
312        rows.truncate(self.k);
313
314        let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
315        Ok(ExecutionResult::Query(QueryResult {
316            columns: col_names,
317            rows: projected,
318        }))
319    }
320
321    /// Search the index (with `filter` pushed in) and recheck the residual
322    /// predicate on decoded rows, over-fetching until `want` rows survive or the
323    /// index is exhausted. Distance order is preserved.
324    fn collect_survivors(
325        &self,
326        txn: &mut dyn AnnScan,
327        index: &AnnIndex,
328        filter: &Filter,
329        table_schema: &TableSchema,
330        want: usize,
331    ) -> Result<Vec<Vec<Value>>> {
332        let col_map = ColumnMap::new(&table_schema.columns);
333        let max_target = index.indexed_len().max(1);
334        let mut key_buf: Vec<u8> = Vec::with_capacity(10);
335        let mut target = want;
336        loop {
337            target = target.min(max_target);
338            let hits = index.search_filtered_default_ef(&self.query_vec, target, filter);
339            let mut survivors: Vec<Vec<Value>> = Vec::with_capacity(want);
340            for (id, _dist) in &hits {
341                encode_int_key_into(*id as i64, &mut key_buf);
342                let Some(row_bytes) = txn.ann_get(table_schema.name.as_bytes(), &key_buf)? else {
343                    continue;
344                };
345                let row = decode_full_row(table_schema, &key_buf, &row_bytes)?;
346                let keep = match &self.residual {
347                    None => true,
348                    Some(expr) => {
349                        let ctx = EvalCtx::new(&col_map, &row);
350                        is_truthy(&eval_expr(expr, &ctx)?)
351                    }
352                };
353                if keep {
354                    survivors.push(row);
355                    if survivors.len() >= want {
356                        break;
357                    }
358                }
359            }
360            // Stop when satisfied, when the index is exhausted, or when PRISM
361            // returns fewer candidates than asked (no more to find).
362            if survivors.len() >= want || target >= max_target || hits.len() < target {
363                return Ok(survivors);
364            }
365            target = target.saturating_mul(2);
366        }
367    }
368
369    fn load_or_build_index(
370        &self,
371        txn: &mut dyn AnnScan,
372        schema: &SchemaManager,
373        cache_key: &str,
374        table_schema: &TableSchema,
375    ) -> Result<Option<Arc<CachedAnnIndex>>> {
376        if let Some(existing) = lookup_cached(schema, cache_key, &table_schema.name)? {
377            return Ok(Some(existing));
378        }
379        let spec = AnnSpec {
380            col_idx: self.col_idx,
381            dim: self.dim,
382            metric: self.metric,
383            filter_cols: self.filter_cols.clone(),
384        };
385        load_or_build(txn, schema, cache_key, table_schema, &spec)
386    }
387}
388
389/// The index identity a build/load/persist operates on - what `AnnTopKPlan`
390/// resolves from the statement, and what `persist_ann_index` resolves from the
391/// declared index.
392pub(super) struct AnnSpec {
393    pub col_idx: usize,
394    pub dim: u16,
395    pub metric: AnnMetric,
396    pub filter_cols: Vec<u16>,
397}
398
399impl AnnSpec {
400    fn metric_tag(&self) -> u8 {
401        citadel_vector::segment::metric_tag(ann_metric_to_prism(self.metric))
402    }
403}
404
405/// One scan pass: the build rows, the filter dictionaries (codes in first-seen
406/// order), and the injective content fingerprint. Build, persist, and load all
407/// decode rows through HERE - one decode path, one fingerprint definition.
408struct ScanOutcome {
409    rows: Vec<(u64, Vec<f32>, Vec<u32>)>,
410    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
411    fingerprint: [u8; 32],
412}
413
414fn scan_rows(
415    txn: &mut dyn AnnScan,
416    table_schema: &TableSchema,
417    spec: &AnnSpec,
418) -> Result<ScanOutcome> {
419    let non_pk = table_schema.non_pk_indices();
420    let enc_pos = table_schema.encoding_positions();
421    let nonpk_order = non_pk
422        .iter()
423        .position(|&i| i == spec.col_idx)
424        .ok_or_else(|| {
425            SqlError::InvalidValue("vector column must be non-PK for ANN build".into())
426        })?;
427    let enc_idx = enc_pos[nonpk_order] as usize;
428
429    let num_attrs = spec.filter_cols.len();
430    let extracts: Vec<Extract> = spec
431        .filter_cols
432        .iter()
433        .map(|&c| extract_plan(c, table_schema, non_pk, enc_pos))
434        .collect::<Result<_>>()?;
435    let mut dicts: Vec<FxHashMap<Vec<u8>, u32>> = vec![FxHashMap::default(); num_attrs];
436    let mut fp = ann_persist::FingerprintHasher::new(
437        &table_schema.name,
438        spec.col_idx as u32,
439        &spec
440            .filter_cols
441            .iter()
442            .map(|&c| c as u32)
443            .collect::<Vec<_>>(),
444        spec.dim,
445        spec.metric_tag(),
446    );
447    let mut rows: Vec<(u64, Vec<f32>, Vec<u32>)> = Vec::new();
448
449    txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
450        let vector = match decode_column_raw(value, enc_idx)?.to_value() {
451            Value::Vector(arr) => Some(arr.to_vec()),
452            Value::Null => None, // null vectors are content, but not indexed
453            _ => {
454                return Err(SqlError::InvalidValue(
455                    "ANN column produced non-vector value".into(),
456                ))
457            }
458        };
459        let mut encoded_filters: Vec<Vec<u8>> = Vec::with_capacity(num_attrs);
460        for ex in &extracts {
461            encoded_filters.push(encode_key_value(&ex.extract(key, value)?));
462        }
463        let vec_bytes: Vec<u8> = vector
464            .as_deref()
465            .unwrap_or(&[])
466            .iter()
467            .flat_map(|f| f.to_le_bytes())
468            .collect();
469        fp.row(
470            key,
471            &vec_bytes,
472            &encoded_filters
473                .iter()
474                .map(Vec::as_slice)
475                .collect::<Vec<_>>(),
476        );
477        let Some(vector) = vector else {
478            return Ok(true);
479        };
480        let id = decode_pk_integer(key)? as u64;
481        let mut codes: Vec<u32> = Vec::with_capacity(num_attrs);
482        for (j, encoded) in encoded_filters.into_iter().enumerate() {
483            let next = dicts[j].len() as u32;
484            codes.push(*dicts[j].entry(encoded).or_insert(next));
485        }
486        rows.push((id, vector, codes));
487        Ok(true)
488    })?;
489
490    Ok(ScanOutcome {
491        rows,
492        dicts,
493        fingerprint: fp.finish(),
494    })
495}
496
497/// Build the index from a scan; `None` if there are no indexable rows.
498fn build_index(
499    txn: &mut dyn AnnScan,
500    table_schema: &TableSchema,
501    spec: &AnnSpec,
502    refusal: Option<String>,
503    cached_gen: u64,
504) -> Result<Option<CachedAnnIndex>> {
505    let outcome = scan_rows(txn, table_schema, spec)?;
506    if outcome.rows.is_empty() {
507        return Ok(None);
508    }
509    let index = AnnIndex::build_with_attrs(
510        outcome.rows,
511        spec.filter_cols.len(),
512        ann_metric_to_prism(spec.metric),
513        spec.dim,
514    )
515    .map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
516    Ok(Some(CachedAnnIndex {
517        index,
518        dicts: outcome.dicts,
519        source: AnnIndexSource::Built { refusal },
520        cached_gen,
521    }))
522}
523
524/// How a persisted segment answered the load attempt. `Refused` always falls
525/// back to a rebuild - corrupt segments additionally report loudly (an
526/// HMAC-authenticated page with a failing BLAKE3 means a writer bug).
527enum LoadOutcome {
528    Loaded(Box<CachedAnnIndex>),
529    NoSegment,
530    Refused { reason: String, corrupt: bool },
531}
532
533/// Try to serve the table's persisted segment: header pins, body decode, and
534/// the rehydration scan whose fingerprint PROVES the index describes exactly
535/// the rows this snapshot holds.
536fn try_load_segment(
537    txn: &mut dyn AnnScan,
538    table_schema: &TableSchema,
539    spec: &AnnSpec,
540    cached_gen: u64,
541) -> Result<LoadOutcome> {
542    let seg_table = ann_persist::segment_table_name(&table_schema.name);
543    let header_bytes = match txn.ann_get(&seg_table, &ann_persist::segment_key(0)) {
544        Ok(Some(b)) => b,
545        // Missing tree and missing header are both "never persisted".
546        Ok(None) | Err(_) => return Ok(LoadOutcome::NoSegment),
547    };
548    let refuse = |reason: String, corrupt: bool| Ok(LoadOutcome::Refused { reason, corrupt });
549    let header = match ann_persist::SegmentHeader::decode(&header_bytes) {
550        Ok(h) => h,
551        Err(e) => return refuse(format!("header: {e}"), true),
552    };
553    if header.format_version != ann_persist::ANNSEG_FORMAT_VERSION {
554        return refuse(
555            format!("format v{} (this binary reads v1)", header.format_version),
556            false,
557        );
558    }
559    let active_cfg = citadel_vector::segment::prism_config_hash(&AnnIndex::active_config(
560        ann_metric_to_prism(spec.metric),
561    ));
562    if header.prism_config_hash != active_cfg {
563        return refuse(
564            "PRISM config drift (segment built by another geometry)".into(),
565            false,
566        );
567    }
568    if header.dim != spec.dim
569        || header.metric_tag != spec.metric_tag()
570        || header.col_idx != spec.col_idx as u32
571        || header.filter_cols
572            != spec
573                .filter_cols
574                .iter()
575                .map(|&c| c as u32)
576                .collect::<Vec<_>>()
577    {
578        return refuse(
579            "index identity mismatch (column/metric/filter set)".into(),
580            false,
581        );
582    }
583
584    let mut body = Vec::new();
585    for chunk_no in 1..=header.chunk_count {
586        match txn.ann_get(&seg_table, &ann_persist::segment_key(chunk_no)) {
587            Ok(Some(c)) => body.extend_from_slice(&c),
588            _ => return refuse(format!("missing chunk {chunk_no}"), true),
589        }
590    }
591    if *blake3::hash(&body).as_bytes() != header.segment_b3 {
592        return refuse("segment body BLAKE3 mismatch (corrupt)".into(), true);
593    }
594    let parts = match citadel_vector::segment::decode(&body) {
595        Ok(p) => p,
596        Err(e) => return refuse(format!("segment decode: {e}"), true),
597    };
598    if parts.n() as u64 != header.n || parts.dim() != header.dim {
599        return refuse("segment body disagrees with header counts".into(), true);
600    }
601
602    // The rehydration scan: vectors placed by the id_map PERMUTATION (scan
603    // order is NOT internal order), fingerprint computed over ALL rows.
604    let slot_of = parts.internal_of_row();
605    let dim = spec.dim as usize;
606    let mut vectors = vec![0.0f32; parts.n() * dim];
607    let mut filled = 0usize;
608    let outcome = scan_rows_rehydrate(txn, table_schema, spec, &mut |row_id, vector| {
609        let Some(&slot) = slot_of.get(&row_id) else {
610            return false; // a row the segment does not know: stale
611        };
612        vectors[slot as usize * dim..(slot as usize + 1) * dim].copy_from_slice(vector);
613        filled += 1;
614        true
615    })?;
616    let Some(fingerprint) = outcome else {
617        return refuse(
618            "a scanned row is unknown to the segment (stale)".into(),
619            false,
620        );
621    };
622    if fingerprint != header.content_fingerprint {
623        return refuse("content fingerprint mismatch (stale)".into(), false);
624    }
625    let index = match parts.into_index(vectors, filled) {
626        Ok(i) => i,
627        Err(e) => return refuse(format!("rehydration: {e}"), true),
628    };
629    Ok(LoadOutcome::Loaded(Box::new(CachedAnnIndex {
630        index,
631        dicts: header.dict_maps(),
632        source: AnnIndexSource::Loaded {
633            segment_b3: header.segment_b3,
634        },
635        cached_gen,
636    })))
637}
638
639/// The rehydration variant of the scan: same decode + fingerprint as
640/// [`scan_rows`], but vectors stream to the placer instead of accumulating.
641/// Returns `None` if the placer rejects a row (unknown to the segment).
642fn scan_rows_rehydrate(
643    txn: &mut dyn AnnScan,
644    table_schema: &TableSchema,
645    spec: &AnnSpec,
646    place: &mut dyn FnMut(u64, &[f32]) -> bool,
647) -> Result<Option<[u8; 32]>> {
648    let non_pk = table_schema.non_pk_indices();
649    let enc_pos = table_schema.encoding_positions();
650    let nonpk_order = non_pk
651        .iter()
652        .position(|&i| i == spec.col_idx)
653        .ok_or_else(|| {
654            SqlError::InvalidValue("vector column must be non-PK for ANN build".into())
655        })?;
656    let enc_idx = enc_pos[nonpk_order] as usize;
657    let extracts: Vec<Extract> = spec
658        .filter_cols
659        .iter()
660        .map(|&c| extract_plan(c, table_schema, non_pk, enc_pos))
661        .collect::<Result<_>>()?;
662    let mut fp = ann_persist::FingerprintHasher::new(
663        &table_schema.name,
664        spec.col_idx as u32,
665        &spec
666            .filter_cols
667            .iter()
668            .map(|&c| c as u32)
669            .collect::<Vec<_>>(),
670        spec.dim,
671        spec.metric_tag(),
672    );
673    let mut unknown_row = false;
674
675    txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
676        let vector = match decode_column_raw(value, enc_idx)?.to_value() {
677            Value::Vector(arr) => Some(arr.to_vec()),
678            Value::Null => None,
679            _ => {
680                return Err(SqlError::InvalidValue(
681                    "ANN column produced non-vector value".into(),
682                ))
683            }
684        };
685        let mut encoded_filters: Vec<Vec<u8>> = Vec::with_capacity(extracts.len());
686        for ex in &extracts {
687            encoded_filters.push(encode_key_value(&ex.extract(key, value)?));
688        }
689        let vec_bytes: Vec<u8> = vector
690            .as_deref()
691            .unwrap_or(&[])
692            .iter()
693            .flat_map(|f| f.to_le_bytes())
694            .collect();
695        fp.row(
696            key,
697            &vec_bytes,
698            &encoded_filters
699                .iter()
700                .map(Vec::as_slice)
701                .collect::<Vec<_>>(),
702        );
703        if let Some(vector) = vector {
704            let id = decode_pk_integer(key)? as u64;
705            if !place(id, &vector) {
706                unknown_row = true;
707                return Ok(false);
708            }
709        }
710        Ok(true)
711    })?;
712
713    Ok(if unknown_row { None } else { Some(fp.finish()) })
714}
715
716/// The shared load-then-build flow: try the persisted segment, fall back to a
717/// scan build carrying the refusal as a diagnostic, and insert into the shared
718/// cache ONLY if no DML on this table committed past the snapshot the index
719/// reflects (and never from a write-txn view).
720fn load_or_build(
721    txn: &mut dyn AnnScan,
722    schema: &SchemaManager,
723    cache_key: &str,
724    table_schema: &TableSchema,
725    spec: &AnnSpec,
726) -> Result<Option<Arc<CachedAnnIndex>>> {
727    let gen = txn.cache_generation();
728    let cached_gen = gen.unwrap_or(u64::MAX);
729    let loaded = match try_load_segment(txn, table_schema, spec, cached_gen)? {
730        LoadOutcome::Loaded(c) => Some(*c),
731        LoadOutcome::NoSegment => None,
732        LoadOutcome::Refused { reason, corrupt } => {
733            if corrupt {
734                eprintln!(
735                    "citadel-sql: ANN segment for `{}` REFUSED as corrupt ({reason}); \
736                     rebuilding from scan - investigate before re-persisting",
737                    table_schema.name
738                );
739            }
740            // Stale/drift refusals are the expected degradation path; the
741            // reason stays queryable on the rebuilt entry either way.
742            match build_index(txn, table_schema, spec, Some(reason), cached_gen)? {
743                Some(c) => Some(c),
744                None => return Ok(None),
745            }
746        }
747    };
748    let built = match loaded {
749        Some(c) => c,
750        None => match build_index(txn, table_schema, spec, None, cached_gen)? {
751            Some(c) => c,
752            None => return Ok(None),
753        },
754    };
755    let arc: Arc<CachedAnnIndex> = Arc::new(built);
756    if gen.is_none() {
757        // A write-txn view may include uncommitted rows: serve, never cache.
758        return Ok(Some(arc));
759    }
760    let mut guard = schema.sql_caches.lock();
761    if let Some(existing) = guard.get(cache_key) {
762        // Another thread won the race; prefer that one and drop ours.
763        return Arc::clone(existing)
764            .downcast::<CachedAnnIndex>()
765            .map(Some)
766            .map_err(|_| {
767                SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}"))
768            });
769    }
770    let marker = marker_gen_locked(&guard, &table_schema.name);
771    if marker.is_some_and(|g| arc.cached_gen < g) {
772        // DML committed while the scan/build ran: the index is a superseded
773        // snapshot. Serve it to THIS query, decline the cache.
774        return Ok(Some(arc));
775    }
776    let as_any: Arc<dyn Any + Send + Sync> = arc.clone();
777    guard.insert(cache_key.to_string(), as_any);
778    Ok(Some(arc))
779}
780
781/// Streaming brute-force top-k for `ORDER BY <distance> LIMIT k` when no ANN
782/// index applies (or inside a write txn); bounded heap, O(k) memory.
783pub(super) struct VectorTopKPlan {
784    order_expr: Expr,
785    where_clause: Option<Expr>,
786    k: usize,
787    offset: usize,
788    nulls_first: bool,
789}
790
791/// A candidate keyed by (distance, scan position); `seq` breaks ties by scan
792/// order so the bounded heap matches the stable sort.
793struct Ranked {
794    dist: f64,
795    seq: u64,
796    row: Vec<Value>,
797}
798
799impl PartialEq for Ranked {
800    fn eq(&self, other: &Self) -> bool {
801        self.cmp(other) == Ordering::Equal
802    }
803}
804impl Eq for Ranked {}
805impl PartialOrd for Ranked {
806    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
807        Some(self.cmp(other))
808    }
809}
810impl Ord for Ranked {
811    fn cmp(&self, other: &Self) -> Ordering {
812        self.dist
813            .total_cmp(&other.dist)
814            .then_with(|| self.seq.cmp(&other.seq))
815    }
816}
817
818impl VectorTopKPlan {
819    pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
820        if !topk_shape_ok(stmt) {
821            return Ok(None);
822        }
823        let ob = &stmt.order_by[0];
824        let Expr::BinaryOp { left, op, .. } = &ob.expr else {
825            return Ok(None);
826        };
827        if !matches!(
828            op,
829            BinOp::VectorL2 | BinOp::VectorInner | BinOp::VectorCosine
830        ) {
831            return Ok(None);
832        }
833        // Only claim a vector-distance sort key; anything else uses the general path.
834        let Expr::Column(name) = left.as_ref() else {
835            return Ok(None);
836        };
837        let name = name.to_ascii_lowercase();
838        let is_vector_col = table_schema.columns.iter().any(|c| {
839            c.name.to_ascii_lowercase() == name && matches!(c.data_type, DataType::Vector { .. })
840        });
841        if !is_vector_col {
842            return Ok(None);
843        }
844
845        let k = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
846        if k == 0 {
847            return Ok(None);
848        }
849        let offset = stmt
850            .offset
851            .as_ref()
852            .map(eval_const_int)
853            .transpose()?
854            .unwrap_or(0)
855            .max(0) as usize;
856
857        Ok(Some(Self {
858            order_expr: ob.expr.clone(),
859            where_clause: stmt.where_clause.clone(),
860            k,
861            offset,
862            // citadel defaults to NULLS FIRST for ascending order.
863            nulls_first: ob.nulls_first.unwrap_or(true),
864        }))
865    }
866
867    pub(super) fn execute(
868        &self,
869        txn: &mut dyn AnnScan,
870        table_schema: &TableSchema,
871        stmt: &SelectStmt,
872    ) -> Result<ExecutionResult> {
873        let want = self.k.saturating_add(self.offset);
874        let col_map = ColumnMap::new(&table_schema.columns);
875        // NULL distances sort like NULLs under the requested ordering.
876        let null_dist = if self.nulls_first {
877            f64::NEG_INFINITY
878        } else {
879            f64::INFINITY
880        };
881        let mut heap: BinaryHeap<Ranked> = BinaryHeap::new();
882        let mut seq: u64 = 0;
883
884        txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
885            let row = decode_full_row(table_schema, key, value)?;
886            let ctx = EvalCtx::new(&col_map, &row);
887            if let Some(w) = &self.where_clause {
888                if !is_truthy(&eval_expr(w, &ctx)?) {
889                    return Ok(true);
890                }
891            }
892            let dist = match eval_expr(&self.order_expr, &ctx)? {
893                Value::Real(d) => d,
894                Value::Integer(i) => i as f64,
895                Value::Null => null_dist,
896                other => {
897                    return Err(SqlError::InvalidValue(format!(
898                        "ORDER BY vector distance produced a non-numeric {}",
899                        other.data_type()
900                    )))
901                }
902            };
903            let cand = Ranked { dist, seq, row };
904            seq += 1;
905            // `seq` only grows, so ties never evict an earlier row (stable-sort order).
906            if heap.len() < want {
907                heap.push(cand);
908            } else if heap.peek().is_some_and(|top| cand < *top) {
909                heap.pop();
910                heap.push(cand);
911            }
912            Ok(true)
913        })?;
914
915        let mut rows: Vec<Vec<Value>> = heap.into_sorted_vec().into_iter().map(|r| r.row).collect();
916        if self.offset >= rows.len() {
917            rows.clear();
918        } else if self.offset > 0 {
919            rows = rows.split_off(self.offset);
920        }
921        rows.truncate(self.k);
922
923        let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
924        Ok(ExecutionResult::Query(QueryResult {
925            columns: col_names,
926            rows: projected,
927        }))
928    }
929}
930
931/// How to read a filter column's value out of a raw row during the build scan.
932enum Extract {
933    /// The single integer primary key, read from the row key.
934    Pk,
935    /// A non-PK column at the given encoding position in the row value.
936    NonPk(usize),
937}
938
939impl Extract {
940    fn extract(&self, key: &[u8], value: &[u8]) -> Result<Value> {
941        match self {
942            Extract::Pk => Ok(Value::Integer(decode_pk_integer(key)?)),
943            Extract::NonPk(ei) => Ok(decode_column_raw(value, *ei)?.to_value()),
944        }
945    }
946}
947
948fn extract_plan(
949    col: u16,
950    table_schema: &TableSchema,
951    non_pk: &[usize],
952    enc_pos: &[u16],
953) -> Result<Extract> {
954    if table_schema.primary_key_columns.contains(&col) {
955        return Ok(Extract::Pk);
956    }
957    let order = non_pk
958        .iter()
959        .position(|&i| i == col as usize)
960        .ok_or_else(|| SqlError::InvalidValue("ANN filter column not found in row".into()))?;
961    Ok(Extract::NonPk(enc_pos[order] as usize))
962}
963
964/// Walk the AND-tree, sorting each leaf into a pushable attribute predicate or
965/// the recheck residual.
966fn split_where(
967    expr: &Expr,
968    filter_cols: &[u16],
969    table_schema: &TableSchema,
970    pushable: &mut Vec<(usize, Vec<Value>)>,
971    residual: &mut Vec<Expr>,
972) {
973    if let Expr::BinaryOp {
974        left,
975        op: BinOp::And,
976        right,
977    } = expr
978    {
979        split_where(left, filter_cols, table_schema, pushable, residual);
980        split_where(right, filter_cols, table_schema, pushable, residual);
981        return;
982    }
983    match classify_leaf(expr, filter_cols, table_schema) {
984        Some(constraint) => pushable.push(constraint),
985        None => residual.push(expr.clone()),
986    }
987}
988
989/// A leaf is pushable if it is `col = literal` or `col IN (literal, ...)` on a
990/// declared filter column with all-constant right-hand side.
991fn classify_leaf(
992    leaf: &Expr,
993    filter_cols: &[u16],
994    table_schema: &TableSchema,
995) -> Option<(usize, Vec<Value>)> {
996    match leaf {
997        Expr::BinaryOp {
998            left,
999            op: BinOp::Eq,
1000            right,
1001        } => {
1002            let dim = filter_dim(left, filter_cols, table_schema)?;
1003            let val = eval_const_expr(right).ok()?;
1004            Some((dim, vec![val]))
1005        }
1006        Expr::InList {
1007            expr,
1008            list,
1009            negated: false,
1010        } => {
1011            let dim = filter_dim(expr, filter_cols, table_schema)?;
1012            let mut vals = Vec::with_capacity(list.len());
1013            for e in list {
1014                vals.push(eval_const_expr(e).ok()?);
1015            }
1016            Some((dim, vals))
1017        }
1018        _ => None,
1019    }
1020}
1021
1022/// Resolve a column expression to its attribute-dim index (position in
1023/// `filter_cols`), or `None` if it is not a declared filter column.
1024fn filter_dim(expr: &Expr, filter_cols: &[u16], table_schema: &TableSchema) -> Option<usize> {
1025    let name = match expr {
1026        Expr::Column(c) => c.to_ascii_lowercase(),
1027        Expr::QualifiedColumn { column, .. } => column.to_ascii_lowercase(),
1028        _ => return None,
1029    };
1030    let col_idx = table_schema
1031        .columns
1032        .iter()
1033        .position(|c| c.name.to_ascii_lowercase() == name)? as u16;
1034    filter_cols.iter().position(|&c| c == col_idx)
1035}
1036
1037fn fold_and(mut leaves: Vec<Expr>) -> Option<Expr> {
1038    if leaves.is_empty() {
1039        return None;
1040    }
1041    let first = leaves.remove(0);
1042    Some(leaves.into_iter().fold(first, |acc, e| Expr::BinaryOp {
1043        left: Box::new(acc),
1044        op: BinOp::And,
1045        right: Box::new(e),
1046    }))
1047}
1048
1049fn empty_result(table_schema: &TableSchema, stmt: &SelectStmt) -> Result<ExecutionResult> {
1050    let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, Vec::new())?;
1051    Ok(ExecutionResult::Query(QueryResult {
1052        columns: col_names,
1053        rows: projected,
1054    }))
1055}
1056
1057/// The explicit freeze operation behind `Connection::persist_ann_index`: ONE
1058/// write txn scans the table (computing the content fingerprint), pays the
1059/// PRISM build, serializes the segment, replaces any prior one, and commits -
1060/// atomic by shadow paging. The single writer lock is held for the full build
1061/// (minutes on large tables): an offline/builder operation by design. The
1062/// fresh index also warms the shared RAM cache, so the NEXT attach is the
1063/// fast loaded one and THIS process serves queries immediately.
1064pub(crate) fn persist_ann_index(
1065    db: &citadel::Database,
1066    schema: &SchemaManager,
1067    table_schema: &TableSchema,
1068    column: &str,
1069) -> Result<ann_persist::AnnSegmentInfo> {
1070    let col_lower = column.to_ascii_lowercase();
1071    let col_idx = table_schema
1072        .columns
1073        .iter()
1074        .position(|c| c.name == col_lower)
1075        .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
1076    let DataType::Vector { dim } = table_schema.columns[col_idx].data_type else {
1077        return Err(SqlError::InvalidValue(format!(
1078            "column `{column}` is not VECTOR(N)"
1079        )));
1080    };
1081    // Same admission as AnnTopKPlan::try_new: a table no plan can serve must
1082    // not get a persisted segment (it would be unservable dead weight with
1083    // mis-decoded row ids).
1084    if table_schema.primary_key_columns.len() != 1
1085        || !matches!(
1086            table_schema.columns[table_schema.primary_key_columns[0] as usize].data_type,
1087            DataType::Integer
1088        )
1089    {
1090        return Err(SqlError::InvalidValue(
1091            "ANN persistence requires a single INTEGER primary key (same rule as the \
1092             ANN query plan)"
1093                .into(),
1094        ));
1095    }
1096    let ann_index = table_schema
1097        .indices
1098        .iter()
1099        .find(|ix| {
1100            matches!(ix.kind, IndexKind::Inverted(InvertedKind::Ann { .. }))
1101                && ix.keys.len() == 1
1102                && matches!(ix.keys[0], IndexKey::Column { idx, .. } if idx as usize == col_idx)
1103        })
1104        .ok_or_else(|| SqlError::InvalidValue(format!("no ANN index declared on `{column}`")))?;
1105    let IndexKind::Inverted(InvertedKind::Ann { metric }) = ann_index.kind else {
1106        unreachable!("matched above");
1107    };
1108    let spec = AnnSpec {
1109        col_idx,
1110        dim,
1111        metric,
1112        filter_cols: ann_index.ann_filter_cols.clone(),
1113    };
1114
1115    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
1116    let outcome = scan_rows(&mut wtx, table_schema, &spec)?;
1117    if outcome.rows.is_empty() {
1118        return Err(SqlError::InvalidValue(
1119            "nothing to persist: the table has no indexable (non-NULL) vectors".into(),
1120        ));
1121    }
1122    let n = outcome.rows.len() as u64;
1123    let index = AnnIndex::build_with_attrs(
1124        outcome.rows,
1125        spec.filter_cols.len(),
1126        ann_metric_to_prism(spec.metric),
1127        spec.dim,
1128    )
1129    .map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
1130
1131    let body = citadel_vector::segment::encode(&index);
1132    let segment_b3 = *blake3::hash(&body).as_bytes();
1133    // Dict entries ordered by code (codes are assigned in first-seen scan
1134    // order, so by-code IS scan order).
1135    let dicts_ordered: Vec<Vec<(Vec<u8>, u32)>> = outcome
1136        .dicts
1137        .iter()
1138        .map(|d| {
1139            let mut entries: Vec<(Vec<u8>, u32)> = d.iter().map(|(k, &v)| (k.clone(), v)).collect();
1140            entries.sort_by_key(|&(_, code)| code);
1141            entries
1142        })
1143        .collect();
1144    let header = ann_persist::SegmentHeader {
1145        format_version: ann_persist::ANNSEG_FORMAT_VERSION,
1146        prism_config_hash: ann_persist::active_config_hash(ann_metric_to_prism(spec.metric)),
1147        dim: spec.dim,
1148        metric_tag: spec.metric_tag(),
1149        n,
1150        snapshot_max: index.snapshot_max,
1151        col_idx: spec.col_idx as u32,
1152        filter_cols: spec.filter_cols.iter().map(|&c| c as u32).collect(),
1153        dicts: dicts_ordered,
1154        content_fingerprint: outcome.fingerprint,
1155        segment_b3,
1156        chunk_count: body.len().div_ceil(ann_persist::CHUNK_BYTES) as u32,
1157        writer: format!("citadel-sql {}", env!("CARGO_PKG_VERSION")),
1158    };
1159
1160    let seg_table = ann_persist::segment_table_name(&table_schema.name);
1161    ann_persist::purge_segment(&mut wtx, &table_schema.name)?;
1162    wtx.create_table(&seg_table).map_err(SqlError::Storage)?;
1163    wtx.table_insert(&seg_table, &ann_persist::segment_key(0), &header.encode())
1164        .map_err(SqlError::Storage)?;
1165    for (chunk_no, chunk) in ann_persist::chunks(&body) {
1166        wtx.table_insert(&seg_table, &ann_persist::segment_key(chunk_no), chunk)
1167            .map_err(SqlError::Storage)?;
1168    }
1169    wtx.commit().map_err(SqlError::Storage)?;
1170
1171    // Warm the shared cache: this index reflects exactly the just-committed
1172    // state (single writer - our commit is the current generation).
1173    let cached = CachedAnnIndex {
1174        index,
1175        dicts: outcome.dicts,
1176        source: AnnIndexSource::Built { refusal: None },
1177        cached_gen: db.manager().commit_generation(),
1178    };
1179    let key = cache_key(&table_schema.name, spec.col_idx, spec.metric);
1180    let as_any: Arc<dyn Any + Send + Sync> = Arc::new(cached);
1181    schema.sql_caches.lock().insert(key, as_any);
1182
1183    Ok(ann_persist::AnnSegmentInfo {
1184        segment_b3,
1185        content_fingerprint: header.content_fingerprint,
1186        n,
1187        dim: spec.dim,
1188        metric_tag: header.metric_tag,
1189        chunk_count: header.chunk_count,
1190    })
1191}
1192
1193/// The queryable identity of the index currently cached for `table.column`:
1194/// `(source, snapshot generation)`, or `None` when nothing is cached.
1195pub(crate) fn ann_cache_status(
1196    schema: &SchemaManager,
1197    table_schema: &TableSchema,
1198    column: &str,
1199) -> Result<Option<(AnnIndexSource, u64)>> {
1200    let col_lower = column.to_ascii_lowercase();
1201    let col_idx = table_schema
1202        .columns
1203        .iter()
1204        .position(|c| c.name == col_lower)
1205        .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
1206    let guard = schema.sql_caches.lock();
1207    for metric in [AnnMetric::L2, AnnMetric::Inner, AnnMetric::Cosine] {
1208        let key = cache_key(&table_schema.name, col_idx, metric);
1209        if let Some(entry) = guard.get(&key) {
1210            if let Ok(c) = Arc::clone(entry).downcast::<CachedAnnIndex>() {
1211                return Ok(Some((c.source.clone(), c.cached_gen)));
1212            }
1213        }
1214    }
1215    Ok(None)
1216}
1217
/// Cache key of the per-table "last DML generation" marker. The marker is
/// stamped by the commit-time invalidation in `connection.rs` and read in this
/// module to refuse any index whose snapshot predates the most recent DML
/// commit on its table.
pub(crate) fn ann_dml_gen_key(table_name: &str) -> String {
    const PREFIX: &str = "ann_dml_gen:";
    let mut key = String::with_capacity(PREFIX.len() + table_name.len());
    key.push_str(PREFIX);
    key.push_str(table_name);
    key
}
1224
1225/// Read the marker under an already-held cache lock.
1226fn marker_gen_locked(
1227    entries: &FxHashMap<String, Arc<dyn Any + Send + Sync>>,
1228    table_name: &str,
1229) -> Option<u64> {
1230    entries
1231        .get(&ann_dml_gen_key(table_name))
1232        .and_then(|e| e.downcast_ref::<u64>())
1233        .copied()
1234}
1235
1236fn lookup_cached(
1237    schema: &SchemaManager,
1238    cache_key: &str,
1239    table_name: &str,
1240) -> Result<Option<Arc<CachedAnnIndex>>> {
1241    let mut guard = schema.sql_caches.lock();
1242    let Some(entry) = guard.get(cache_key) else {
1243        return Ok(None);
1244    };
1245    let entry = Arc::clone(entry)
1246        .downcast::<CachedAnnIndex>()
1247        .map_err(|_| SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}")))?;
1248    if marker_gen_locked(&guard, table_name).is_some_and(|g| entry.cached_gen < g) {
1249        // The entry predates a DML commit on this table (a build that raced a
1250        // commit slipped past eviction): drop it and rebuild/reload.
1251        guard.remove(cache_key);
1252        return Ok(None);
1253    }
1254    Ok(Some(entry))
1255}
1256
1257pub(super) fn cache_key(table_name: &str, col_idx: usize, metric: AnnMetric) -> String {
1258    let tag = match metric {
1259        AnnMetric::L2 => "l2",
1260        AnnMetric::Inner => "inner",
1261        AnnMetric::Cosine => "cosine",
1262    };
1263    format!(
1264        "ann:{}:{}:{}",
1265        table_name.to_ascii_lowercase(),
1266        col_idx,
1267        tag
1268    )
1269}
1270
1271fn ann_metric_to_prism(m: AnnMetric) -> Metric {
1272    match m {
1273        AnnMetric::L2 => Metric::L2,
1274        AnnMetric::Inner => Metric::InnerProduct,
1275        AnnMetric::Cosine => Metric::Cosine,
1276    }
1277}