Skip to main content

citadel_sql/executor/
ann_topk.rs

1//! Plans for `SELECT ... ORDER BY col <dist> :q LIMIT k`: [`AnnTopKPlan`] uses a
2//! cached PRISM index; [`VectorTopKPlan`] streams a bounded-heap top-k when no
3//! index applies or inside a write txn (uncommitted rows).
4
5use std::any::Any;
6use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8use std::sync::Arc;
9
10use citadel_txn::read_txn::ReadTxn;
11use citadel_txn::write_txn::WriteTxn;
12use citadel_vector::{AnnIndex, Filter, Metric};
13use rustc_hash::FxHashMap;
14
15use crate::encoding::{
16    decode_column_raw, decode_pk_integer, encode_int_key_into, encode_key_value,
17    encode_key_value_collated_into,
18};
19use crate::error::{Result, SqlError};
20use crate::eval::{eval_expr, is_truthy, ColumnMap, EvalCtx};
21use crate::parser::*;
22use crate::schema::SchemaManager;
23use crate::types::*;
24
25use super::aggregate::is_aggregate_expr;
26use super::ann_persist;
27use super::helpers::{decode_full_row, eval_const_expr, eval_const_int, project_rows};
28use super::window::has_any_window_function;
29
/// Result type of the storage layer (`citadel_core`), before mapping to `SqlError`.
type StorageResult<T> = std::result::Result<T, citadel_core::Error>;
/// SQL-layer per-row scan callback over `(key, value)`: `Ok(true)` continues,
/// `Ok(false)` stops the scan, and an `Err` stops it and is surfaced by
/// `bridge_scan`.
type ScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> Result<bool> + 'a;
/// Storage-layer counterpart of [`ScanRow`], reporting `citadel_core::Error`.
type RawScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> StorageResult<bool> + 'a;
33
34/// Scan + point-get over a read or write txn, materializing overflow values.
35pub(super) trait AnnScan {
36    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()>;
37    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>>;
38    /// The commit generation this txn's snapshot reflects, or `None` when the
39    /// view includes uncommitted writes - an index over such a view must NEVER
40    /// enter the shared cache.
41    fn cache_generation(&self) -> Option<u64>;
42}
43
44/// Adapt a storage-level scan to report `SqlError`, surfacing the first callback error.
45fn bridge_scan(
46    scan: impl FnOnce(&mut RawScanRow<'_>) -> StorageResult<()>,
47    f: &mut ScanRow<'_>,
48) -> Result<()> {
49    let mut cb_err: Option<SqlError> = None;
50    scan(&mut |key, value| match f(key, value) {
51        Ok(go) => Ok(go),
52        Err(e) => {
53            cb_err = Some(e);
54            Ok(false)
55        }
56    })
57    .map_err(SqlError::Storage)?;
58    match cb_err {
59        Some(e) => Err(e),
60        None => Ok(()),
61    }
62}
63
64impl AnnScan for ReadTxn<'_> {
65    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
66        bridge_scan(|cb| self.table_scan_from(table, b"", cb), f)
67    }
68
69    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
70        self.table_get(table, key).map_err(SqlError::Storage)
71    }
72
73    fn cache_generation(&self) -> Option<u64> {
74        Some(self.commit_generation())
75    }
76}
77
78impl AnnScan for WriteTxn<'_> {
79    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
80        bridge_scan(|cb| self.table_scan_from(table, b"", cb), f)
81    }
82
83    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
84        self.table_get(table, key).map_err(SqlError::Storage)
85    }
86
87    fn cache_generation(&self) -> Option<u64> {
88        None
89    }
90}
91
/// Where a cached index came from - queryable via `ann_cache_status`, and the
/// carrier for load-refusal diagnostics (a refused segment degrades to the
/// slow rebuild, but the refusal reason must stay visible, never log-only).
/// Corrupt refusals are additionally printed to stderr by `load_or_build`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnnIndexSource {
    /// Built from a table scan this process. `refusal` records why a persisted
    /// segment was NOT used, if one existed and was rejected.
    Built { refusal: Option<String> },
    /// Loaded from a persisted segment whose body BLAKE3 is `segment_b3` and
    /// whose content fingerprint matched the rehydration scan.
    Loaded { segment_b3: [u8; 32] },
}
104
/// A cached ANN index plus the metadata needed to push SQL filters into it.
struct CachedAnnIndex {
    /// The PRISM index. Its ids are the table's INTEGER primary keys: the build
    /// scan casts them to `u64`, and row lookups cast them back to `i64`.
    index: AnnIndex,
    /// Per attribute dim: maps an encoded filter-column value to its PRISM code.
    /// Keys are collation-canonical encodings.
    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
    /// Provenance (built vs. loaded) and any segment-refusal diagnostic.
    source: AnnIndexSource,
    /// The commit generation the index reflects: inserts into the shared cache
    /// are declined if the database moved past it (the build/load raced a
    /// commit), so a cached index can never describe a superseded snapshot.
    /// Set to `u64::MAX` for indexes built over a write-txn view, which are
    /// never cached.
    cached_gen: u64,
}
116
/// A resolved ANN top-k plan: the constant query vector, paging, and the split
/// of the WHERE clause into index-pushable filters plus a residual recheck.
pub(super) struct AnnTopKPlan {
    /// Schema index of the VECTOR column in the ORDER BY key.
    col_idx: usize,
    /// Declared dimension of that column; `query_vec` has exactly this length.
    dim: u16,
    /// Metric implied by the distance operator; equals the chosen index's metric.
    metric: AnnMetric,
    /// The constant query vector.
    query_vec: Vec<f32>,
    /// Effective LIMIT (always > 0; zero declines the plan).
    k: usize,
    /// Effective OFFSET (negative values clamp to 0).
    offset: usize,
    /// Schema column indices declared filterable on the index, in attr-dim order.
    filter_cols: Vec<u16>,
    /// Pushable conjuncts: `(attr_dim, allowed_values)` from `col = v` / `col IN (...)`.
    pushable: Vec<(usize, Vec<Value>)>,
    /// Remaining WHERE predicate evaluated as a recheck on decoded candidates.
    residual: Option<Expr>,
}
131
132/// Gate for single-key ascending ORDER BY ... LIMIT k (no group/having/join/distinct/window/agg).
133fn topk_shape_ok(stmt: &SelectStmt) -> bool {
134    stmt.order_by.len() == 1
135        && !stmt.order_by[0].descending
136        && stmt.limit.is_some()
137        && stmt.group_by.is_empty()
138        && stmt.having.is_none()
139        && stmt.joins.is_empty()
140        && !stmt.distinct
141        && !has_any_window_function(stmt)
142        && !stmt
143            .columns
144            .iter()
145            .any(|c| matches!(c, SelectColumn::Expr { expr, .. } if is_aggregate_expr(expr)))
146}
147
148impl AnnTopKPlan {
149    pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
150        if !topk_shape_ok(stmt) {
151            return Ok(None);
152        }
153        let ob = &stmt.order_by[0];
154
155        let (col_idx, dim, op_metric, query_vec) = match &ob.expr {
156            Expr::BinaryOp { left, op, right } => {
157                let op_metric = match op {
158                    BinOp::VectorL2 => AnnMetric::L2,
159                    BinOp::VectorInner => AnnMetric::Inner,
160                    BinOp::VectorCosine => AnnMetric::Cosine,
161                    _ => return Ok(None),
162                };
163                let col_name = match left.as_ref() {
164                    Expr::Column(name) => name.to_ascii_lowercase(),
165                    _ => return Ok(None),
166                };
167                let (col_idx, dim) = match table_schema
168                    .columns
169                    .iter()
170                    .enumerate()
171                    .find(|(_, c)| c.name.to_ascii_lowercase() == col_name)
172                {
173                    Some((i, c)) => match c.data_type {
174                        DataType::Vector { dim } => (i, dim),
175                        _ => return Ok(None),
176                    },
177                    None => return Ok(None),
178                };
179                let col_map = ColumnMap::new(&table_schema.columns);
180                let ctx = EvalCtx::new(&col_map, &[]);
181                let v = match eval_expr(right, &ctx) {
182                    Ok(Value::Vector(v)) => v,
183                    _ => return Ok(None),
184                };
185                if v.len() != dim as usize {
186                    return Err(SqlError::InvalidValue(format!(
187                        "ANN query vector dim {} does not match column dim {}",
188                        v.len(),
189                        dim
190                    )));
191                }
192                (col_idx, dim, op_metric, v.to_vec())
193            }
194            _ => return Ok(None),
195        };
196
197        let ann_index = table_schema.indices.iter().find(|ix| {
198            matches!(ix.kind,
199                IndexKind::Inverted(InvertedKind::Ann { metric }) if metric == op_metric
200            ) && ix.keys.len() == 1
201                && matches!(ix.keys[0],
202                    IndexKey::Column { idx, .. } if idx as usize == col_idx
203                )
204        });
205        let Some(ann_index) = ann_index else {
206            return Ok(None);
207        };
208        let filter_cols = ann_index.ann_filter_cols.clone();
209
210        if table_schema.primary_key_columns.len() != 1 {
211            return Ok(None);
212        }
213        let pk_col = &table_schema.columns[table_schema.primary_key_columns[0] as usize];
214        if !matches!(pk_col.data_type, DataType::Integer) {
215            return Ok(None);
216        }
217
218        // No pushable predicate means the index gives no leverage; decline so
219        // the exact filtered scan runs instead.
220        let mut pushable: Vec<(usize, Vec<Value>)> = Vec::new();
221        let mut residual_leaves: Vec<Expr> = Vec::new();
222        if let Some(w) = &stmt.where_clause {
223            split_where(
224                w,
225                &filter_cols,
226                table_schema,
227                &mut pushable,
228                &mut residual_leaves,
229            );
230            if pushable.is_empty() {
231                return Ok(None);
232            }
233        }
234        let residual = fold_and(residual_leaves);
235
236        let k_limit = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
237        let offset = stmt
238            .offset
239            .as_ref()
240            .map(eval_const_int)
241            .transpose()?
242            .unwrap_or(0)
243            .max(0) as usize;
244        if k_limit == 0 {
245            return Ok(None);
246        }
247
248        Ok(Some(Self {
249            col_idx,
250            dim,
251            metric: op_metric,
252            query_vec,
253            k: k_limit,
254            offset,
255            filter_cols,
256            pushable,
257            residual,
258        }))
259    }
260
261    pub(super) fn execute_with_read(
262        &self,
263        rtx: &mut ReadTxn<'_>,
264        schema: &SchemaManager,
265        stmt: &SelectStmt,
266        table_schema: &TableSchema,
267    ) -> Result<ExecutionResult> {
268        let cache_key = cache_key(&table_schema.name, self.col_idx, self.metric);
269        // Empty table: nothing to build, ORDER BY ... LIMIT yields no rows.
270        let Some(cached) = self.load_or_build_index(rtx, schema, &cache_key, table_schema)? else {
271            return empty_result(table_schema, stmt);
272        };
273        self.run_query(rtx, &cached, stmt, table_schema)
274    }
275
276    /// Search the index, apply filters and the residual recheck, then page and project.
277    fn run_query(
278        &self,
279        txn: &mut dyn AnnScan,
280        cached: &CachedAnnIndex,
281        stmt: &SelectStmt,
282        table_schema: &TableSchema,
283    ) -> Result<ExecutionResult> {
284        // Map values to codes via the same collation-canonical encoding the
285        // dictionary was built with; a value absent from it matches no row.
286        let mut constraints: Vec<(usize, Vec<u32>)> = Vec::with_capacity(self.pushable.len());
287        for (dim, values) in &self.pushable {
288            let dict = &cached.dicts[*dim];
289            let coll = table_schema.columns[self.filter_cols[*dim] as usize].collation;
290            let mut codes = Vec::with_capacity(values.len());
291            let mut canon = Vec::with_capacity(16);
292            for v in values {
293                canon.clear();
294                encode_key_value_collated_into(v, coll, &mut canon);
295                if let Some(&code) = dict.get(canon.as_slice()) {
296                    codes.push(code);
297                }
298            }
299            if codes.is_empty() {
300                return empty_result(table_schema, stmt);
301            }
302            constraints.push((*dim, codes));
303        }
304        let filter = if constraints.is_empty() {
305            Filter::none()
306        } else {
307            Filter::new(constraints)
308        };
309
310        let want = self.k.saturating_add(self.offset).max(1);
311        let mut rows = self.collect_survivors(txn, &cached.index, &filter, table_schema, want)?;
312
313        if self.offset >= rows.len() {
314            rows.clear();
315        } else if self.offset > 0 {
316            rows = rows.split_off(self.offset);
317        }
318        rows.truncate(self.k);
319
320        let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
321        Ok(ExecutionResult::Query(QueryResult {
322            columns: col_names,
323            rows: projected,
324        }))
325    }
326
327    /// Search the index (with `filter` pushed in) and recheck the residual
328    /// predicate on decoded rows, over-fetching until `want` rows survive or the
329    /// index is exhausted. Distance order is preserved.
330    fn collect_survivors(
331        &self,
332        txn: &mut dyn AnnScan,
333        index: &AnnIndex,
334        filter: &Filter,
335        table_schema: &TableSchema,
336        want: usize,
337    ) -> Result<Vec<Vec<Value>>> {
338        let col_map = ColumnMap::new(&table_schema.columns);
339        let max_target = index.indexed_len().max(1);
340        let mut key_buf: Vec<u8> = Vec::with_capacity(10);
341        let mut target = want;
342        loop {
343            target = target.min(max_target);
344            let hits = index.search_filtered_default_ef(&self.query_vec, target, filter);
345            let mut survivors: Vec<Vec<Value>> = Vec::with_capacity(want);
346            for (id, _dist) in &hits {
347                encode_int_key_into(*id as i64, &mut key_buf);
348                let Some(row_bytes) = txn.ann_get(table_schema.name.as_bytes(), &key_buf)? else {
349                    continue;
350                };
351                let row = decode_full_row(table_schema, &key_buf, &row_bytes)?;
352                let keep = match &self.residual {
353                    None => true,
354                    Some(expr) => {
355                        let ctx = EvalCtx::new(&col_map, &row);
356                        is_truthy(&eval_expr(expr, &ctx)?)
357                    }
358                };
359                if keep {
360                    survivors.push(row);
361                    if survivors.len() >= want {
362                        break;
363                    }
364                }
365            }
366            // Stop when satisfied, when the index is exhausted, or when PRISM
367            // returns fewer candidates than asked (no more to find).
368            if survivors.len() >= want || target >= max_target || hits.len() < target {
369                return Ok(survivors);
370            }
371            target = target.saturating_mul(2);
372        }
373    }
374
375    fn load_or_build_index(
376        &self,
377        txn: &mut dyn AnnScan,
378        schema: &SchemaManager,
379        cache_key: &str,
380        table_schema: &TableSchema,
381    ) -> Result<Option<Arc<CachedAnnIndex>>> {
382        if let Some(existing) = lookup_cached(schema, cache_key, &table_schema.name)? {
383            return Ok(Some(existing));
384        }
385        let spec = AnnSpec {
386            col_idx: self.col_idx,
387            dim: self.dim,
388            metric: self.metric,
389            filter_cols: self.filter_cols.clone(),
390        };
391        load_or_build(txn, schema, cache_key, table_schema, &spec)
392    }
393}
394
/// The index identity a build/load/persist operates on - what `AnnTopKPlan`
/// resolves from the statement, and what `persist_ann_index` resolves from the
/// declared index.
pub(super) struct AnnSpec {
    /// Schema index of the VECTOR column being indexed (must be a non-PK column).
    pub col_idx: usize,
    /// Declared vector dimension of that column.
    pub dim: u16,
    /// Distance metric the index is built for.
    pub metric: AnnMetric,
    /// Schema indices of the filterable columns, in attribute-dim order.
    pub filter_cols: Vec<u16>,
}
404
405impl AnnSpec {
406    fn metric_tag(&self) -> u8 {
407        citadel_vector::segment::metric_tag(ann_metric_to_prism(self.metric))
408    }
409}
410
/// One scan pass: the build rows, the filter dictionaries (codes in first-seen
/// order), and the injective content fingerprint. Build and persist decode rows
/// through [`scan_rows`]. Load rehydrates through [`scan_rows_rehydrate`], which
/// repeats the same decode + fingerprint steps; the two must stay in lockstep,
/// or every persisted segment is refused as stale.
struct ScanOutcome {
    /// `(row id, vector, per-attr codes)` for every row with a non-NULL vector.
    rows: Vec<(u64, Vec<f32>, Vec<u32>)>,
    /// Per attribute dim: collation-canonical value encoding -> code.
    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
    /// Fingerprint over ALL scanned rows, NULL-vector rows included.
    fingerprint: [u8; 32],
}
419
/// Scan the table once, decoding the ANN column and every filter column.
///
/// Rows whose ANN column is non-NULL become build rows with per-attribute
/// dictionary codes. Every row, including NULL-vector rows, feeds the content
/// fingerprint.
///
/// # Errors
/// `SqlError::InvalidValue` if `spec.col_idx` is not a non-PK column, or if a
/// row's ANN column decodes to something other than a vector or NULL. Decode,
/// extract-plan and storage errors propagate.
fn scan_rows(
    txn: &mut dyn AnnScan,
    table_schema: &TableSchema,
    spec: &AnnSpec,
) -> Result<ScanOutcome> {
    let non_pk = table_schema.non_pk_indices();
    let enc_pos = table_schema.encoding_positions();
    let nonpk_order = non_pk
        .iter()
        .position(|&i| i == spec.col_idx)
        .ok_or_else(|| {
            SqlError::InvalidValue("vector column must be non-PK for ANN build".into())
        })?;
    // Index handed to `decode_column_raw` to pull the ANN column out of a row value.
    let enc_idx = enc_pos[nonpk_order] as usize;

    let num_attrs = spec.filter_cols.len();
    let extracts: Vec<Extract> = spec
        .filter_cols
        .iter()
        .map(|&c| extract_plan(c, table_schema, non_pk, enc_pos))
        .collect::<Result<_>>()?;
    // Dictionary keys are COLLATION-CANONICAL so collation-equal stored values
    // share one attr code (matching the eval path's equality); the fingerprint
    // keeps raw encodings so content edits between collation-equal values are
    // still detected.
    let collations: Vec<Collation> = spec
        .filter_cols
        .iter()
        .map(|&c| table_schema.columns[c as usize].collation)
        .collect();
    let mut dicts: Vec<FxHashMap<Vec<u8>, u32>> = vec![FxHashMap::default(); num_attrs];
    // Seeded with the index identity so a fingerprint cannot match a segment
    // built for another column/filter set/dim/metric.
    let mut fp = ann_persist::FingerprintHasher::new(
        &table_schema.name,
        spec.col_idx as u32,
        &spec
            .filter_cols
            .iter()
            .map(|&c| c as u32)
            .collect::<Vec<_>>(),
        spec.dim,
        spec.metric_tag(),
    );
    let mut rows: Vec<(u64, Vec<f32>, Vec<u32>)> = Vec::new();

    txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
        let vector = match decode_column_raw(value, enc_idx)?.to_value() {
            Value::Vector(arr) => Some(arr.to_vec()),
            Value::Null => None, // null vectors are content, but not indexed
            _ => {
                return Err(SqlError::InvalidValue(
                    "ANN column produced non-vector value".into(),
                ))
            }
        };
        let mut filter_vals: Vec<Value> = Vec::with_capacity(num_attrs);
        for ex in &extracts {
            filter_vals.push(ex.extract(key, value)?);
        }
        let encoded_filters: Vec<Vec<u8>> = filter_vals.iter().map(encode_key_value).collect();
        // A NULL vector fingerprints as an empty byte string.
        let vec_bytes: Vec<u8> = vector
            .as_deref()
            .unwrap_or(&[])
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        // Fingerprint BEFORE the NULL check so NULL-vector rows count as
        // content. This must match `scan_rows_rehydrate` exactly.
        fp.row(
            key,
            &vec_bytes,
            &encoded_filters
                .iter()
                .map(Vec::as_slice)
                .collect::<Vec<_>>(),
        );
        let Some(vector) = vector else {
            return Ok(true);
        };
        // The INTEGER PK reinterpreted as u64; `collect_survivors` casts it back to i64.
        let id = decode_pk_integer(key)? as u64;
        let mut codes: Vec<u32> = Vec::with_capacity(num_attrs);
        for (j, v) in filter_vals.iter().enumerate() {
            let mut canon = Vec::with_capacity(16);
            encode_key_value_collated_into(v, collations[j], &mut canon);
            // A new canonical value gets the next code (first-seen order).
            let next = dicts[j].len() as u32;
            codes.push(*dicts[j].entry(canon).or_insert(next));
        }
        rows.push((id, vector, codes));
        Ok(true)
    })?;

    Ok(ScanOutcome {
        rows,
        dicts,
        fingerprint: fp.finish(),
    })
}
514
515/// Build the index from a scan; `None` if there are no indexable rows.
516fn build_index(
517    txn: &mut dyn AnnScan,
518    table_schema: &TableSchema,
519    spec: &AnnSpec,
520    refusal: Option<String>,
521    cached_gen: u64,
522) -> Result<Option<CachedAnnIndex>> {
523    let outcome = scan_rows(txn, table_schema, spec)?;
524    if outcome.rows.is_empty() {
525        return Ok(None);
526    }
527    let index = AnnIndex::build_with_attrs(
528        outcome.rows,
529        spec.filter_cols.len(),
530        ann_metric_to_prism(spec.metric),
531        spec.dim,
532    )
533    .map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
534    Ok(Some(CachedAnnIndex {
535        index,
536        dicts: outcome.dicts,
537        source: AnnIndexSource::Built { refusal },
538        cached_gen,
539    }))
540}
541
/// How a persisted segment answered the load attempt. `Refused` always falls
/// back to a rebuild - corrupt segments additionally report loudly (an
/// HMAC-authenticated page with a failing BLAKE3 means a writer bug).
enum LoadOutcome {
    /// Every check passed; the boxed entry is ready to serve.
    Loaded(Box<CachedAnnIndex>),
    /// The header could not be read (missing segment tree, missing header row,
    /// or any error reading it): treated as "never persisted".
    NoSegment,
    /// The segment exists but must not be used. `reason` is kept on the rebuilt
    /// entry; `corrupt` marks integrity failures, as opposed to stale data or
    /// config drift.
    Refused { reason: String, corrupt: bool },
}
550
551/// Try to serve the table's persisted segment: header pins, body decode, and
552/// the rehydration scan whose fingerprint PROVES the index describes exactly
553/// the rows this snapshot holds.
554fn try_load_segment(
555    txn: &mut dyn AnnScan,
556    table_schema: &TableSchema,
557    spec: &AnnSpec,
558    cached_gen: u64,
559) -> Result<LoadOutcome> {
560    let seg_table = ann_persist::segment_table_name(&table_schema.name);
561    let header_bytes = match txn.ann_get(&seg_table, &ann_persist::segment_key(0)) {
562        Ok(Some(b)) => b,
563        // Missing tree and missing header are both "never persisted".
564        Ok(None) | Err(_) => return Ok(LoadOutcome::NoSegment),
565    };
566    let refuse = |reason: String, corrupt: bool| Ok(LoadOutcome::Refused { reason, corrupt });
567    let header = match ann_persist::SegmentHeader::decode(&header_bytes) {
568        Ok(h) => h,
569        Err(e) => return refuse(format!("header: {e}"), true),
570    };
571    if header.format_version != ann_persist::ANNSEG_FORMAT_VERSION {
572        return refuse(
573            format!("format v{} (this binary reads v1)", header.format_version),
574            false,
575        );
576    }
577    let active_cfg = citadel_vector::segment::prism_config_hash(&AnnIndex::active_config(
578        ann_metric_to_prism(spec.metric),
579    ));
580    if header.prism_config_hash != active_cfg {
581        return refuse(
582            "PRISM config drift (segment built by another geometry)".into(),
583            false,
584        );
585    }
586    if header.dim != spec.dim
587        || header.metric_tag != spec.metric_tag()
588        || header.col_idx != spec.col_idx as u32
589        || header.filter_cols
590            != spec
591                .filter_cols
592                .iter()
593                .map(|&c| c as u32)
594                .collect::<Vec<_>>()
595    {
596        return refuse(
597            "index identity mismatch (column/metric/filter set)".into(),
598            false,
599        );
600    }
601
602    let mut body = Vec::new();
603    for chunk_no in 1..=header.chunk_count {
604        match txn.ann_get(&seg_table, &ann_persist::segment_key(chunk_no)) {
605            Ok(Some(c)) => body.extend_from_slice(&c),
606            _ => return refuse(format!("missing chunk {chunk_no}"), true),
607        }
608    }
609    if *blake3::hash(&body).as_bytes() != header.segment_b3 {
610        return refuse("segment body BLAKE3 mismatch (corrupt)".into(), true);
611    }
612    let parts = match citadel_vector::segment::decode(&body) {
613        Ok(p) => p,
614        Err(e) => return refuse(format!("segment decode: {e}"), true),
615    };
616    if parts.n() as u64 != header.n || parts.dim() != header.dim {
617        return refuse("segment body disagrees with header counts".into(), true);
618    }
619
620    // The rehydration scan: vectors placed by the id_map PERMUTATION (scan
621    // order is NOT internal order), fingerprint computed over ALL rows.
622    let slot_of = parts.internal_of_row();
623    let dim = spec.dim as usize;
624    let mut vectors = vec![0.0f32; parts.n() * dim];
625    let mut filled = 0usize;
626    let outcome = scan_rows_rehydrate(txn, table_schema, spec, &mut |row_id, vector| {
627        let Some(&slot) = slot_of.get(&row_id) else {
628            return false; // a row the segment does not know: stale
629        };
630        vectors[slot as usize * dim..(slot as usize + 1) * dim].copy_from_slice(vector);
631        filled += 1;
632        true
633    })?;
634    let Some(fingerprint) = outcome else {
635        return refuse(
636            "a scanned row is unknown to the segment (stale)".into(),
637            false,
638        );
639    };
640    if fingerprint != header.content_fingerprint {
641        return refuse("content fingerprint mismatch (stale)".into(), false);
642    }
643    let index = match parts.into_index(vectors, filled) {
644        Ok(i) => i,
645        Err(e) => return refuse(format!("rehydration: {e}"), true),
646    };
647    Ok(LoadOutcome::Loaded(Box::new(CachedAnnIndex {
648        index,
649        dicts: header.dict_maps(),
650        source: AnnIndexSource::Loaded {
651            segment_b3: header.segment_b3,
652        },
653        cached_gen,
654    })))
655}
656
/// The rehydration variant of the scan: same decode + fingerprint as
/// [`scan_rows`], but vectors stream to the placer instead of accumulating.
/// Returns `None` if the placer rejects a row (unknown to the segment).
///
/// `place` receives `(row id, vector)` for each row with a non-NULL vector;
/// returning `false` stops the scan early, leaving the fingerprint incomplete,
/// hence `None`. The per-row fingerprint input must stay identical to
/// `scan_rows`, or persisted segments are refused as stale.
///
/// # Errors
/// Same as `scan_rows`: `SqlError::InvalidValue` for a PK vector column or a
/// non-vector ANN value; decode and storage errors propagate.
fn scan_rows_rehydrate(
    txn: &mut dyn AnnScan,
    table_schema: &TableSchema,
    spec: &AnnSpec,
    place: &mut dyn FnMut(u64, &[f32]) -> bool,
) -> Result<Option<[u8; 32]>> {
    let non_pk = table_schema.non_pk_indices();
    let enc_pos = table_schema.encoding_positions();
    let nonpk_order = non_pk
        .iter()
        .position(|&i| i == spec.col_idx)
        .ok_or_else(|| {
            SqlError::InvalidValue("vector column must be non-PK for ANN build".into())
        })?;
    let enc_idx = enc_pos[nonpk_order] as usize;
    let extracts: Vec<Extract> = spec
        .filter_cols
        .iter()
        .map(|&c| extract_plan(c, table_schema, non_pk, enc_pos))
        .collect::<Result<_>>()?;
    // Seeded exactly as in `scan_rows`.
    let mut fp = ann_persist::FingerprintHasher::new(
        &table_schema.name,
        spec.col_idx as u32,
        &spec
            .filter_cols
            .iter()
            .map(|&c| c as u32)
            .collect::<Vec<_>>(),
        spec.dim,
        spec.metric_tag(),
    );
    let mut unknown_row = false;

    txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
        let vector = match decode_column_raw(value, enc_idx)?.to_value() {
            Value::Vector(arr) => Some(arr.to_vec()),
            Value::Null => None,
            _ => {
                return Err(SqlError::InvalidValue(
                    "ANN column produced non-vector value".into(),
                ))
            }
        };
        // Raw (not collation-canonical) encodings, as in `scan_rows`.
        let mut encoded_filters: Vec<Vec<u8>> = Vec::with_capacity(extracts.len());
        for ex in &extracts {
            encoded_filters.push(encode_key_value(&ex.extract(key, value)?));
        }
        let vec_bytes: Vec<u8> = vector
            .as_deref()
            .unwrap_or(&[])
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        fp.row(
            key,
            &vec_bytes,
            &encoded_filters
                .iter()
                .map(Vec::as_slice)
                .collect::<Vec<_>>(),
        );
        if let Some(vector) = vector {
            let id = decode_pk_integer(key)? as u64;
            if !place(id, &vector) {
                // Stop the scan; the partial fingerprint must not be reported.
                unknown_row = true;
                return Ok(false);
            }
        }
        Ok(true)
    })?;

    Ok(if unknown_row { None } else { Some(fp.finish()) })
}
733
734/// The shared load-then-build flow: try the persisted segment, fall back to a
735/// scan build carrying the refusal as a diagnostic, and insert into the shared
736/// cache ONLY if no DML on this table committed past the snapshot the index
737/// reflects (and never from a write-txn view).
738fn load_or_build(
739    txn: &mut dyn AnnScan,
740    schema: &SchemaManager,
741    cache_key: &str,
742    table_schema: &TableSchema,
743    spec: &AnnSpec,
744) -> Result<Option<Arc<CachedAnnIndex>>> {
745    let gen = txn.cache_generation();
746    let cached_gen = gen.unwrap_or(u64::MAX);
747    let loaded = match try_load_segment(txn, table_schema, spec, cached_gen)? {
748        LoadOutcome::Loaded(c) => Some(*c),
749        LoadOutcome::NoSegment => None,
750        LoadOutcome::Refused { reason, corrupt } => {
751            if corrupt {
752                eprintln!(
753                    "citadel-sql: ANN segment for `{}` REFUSED as corrupt ({reason}); \
754                     rebuilding from scan - investigate before re-persisting",
755                    table_schema.name
756                );
757            }
758            // Stale/drift refusals are the expected degradation path; the
759            // reason stays queryable on the rebuilt entry either way.
760            match build_index(txn, table_schema, spec, Some(reason), cached_gen)? {
761                Some(c) => Some(c),
762                None => return Ok(None),
763            }
764        }
765    };
766    let built = match loaded {
767        Some(c) => c,
768        None => match build_index(txn, table_schema, spec, None, cached_gen)? {
769            Some(c) => c,
770            None => return Ok(None),
771        },
772    };
773    let arc: Arc<CachedAnnIndex> = Arc::new(built);
774    if gen.is_none() {
775        // A write-txn view may include uncommitted rows: serve, never cache.
776        return Ok(Some(arc));
777    }
778    let mut guard = schema.sql_caches.lock();
779    if let Some(existing) = guard.get(cache_key) {
780        // Another thread won the race; prefer that one and drop ours.
781        return Arc::clone(existing)
782            .downcast::<CachedAnnIndex>()
783            .map(Some)
784            .map_err(|_| {
785                SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}"))
786            });
787    }
788    let marker = marker_gen_locked(&guard, &table_schema.name);
789    if marker.is_some_and(|g| arc.cached_gen < g) {
790        // DML committed while the scan/build ran: the index is a superseded
791        // snapshot. Serve it to THIS query, decline the cache.
792        return Ok(Some(arc));
793    }
794    let as_any: Arc<dyn Any + Send + Sync> = arc.clone();
795    guard.insert(cache_key.to_string(), as_any);
796    Ok(Some(arc))
797}
798
/// Streaming brute-force top-k for `ORDER BY <distance> LIMIT k` when no ANN
/// index applies (or inside a write txn); bounded heap, O(k) memory.
pub(super) struct VectorTopKPlan {
    /// The whole ORDER BY key, a vector-distance binary op on a VECTOR column.
    order_expr: Expr,
    /// The statement's WHERE predicate, if any.
    where_clause: Option<Expr>,
    /// Effective LIMIT; always positive (zero declines the plan in `try_new`).
    k: usize,
    /// Effective OFFSET (negative values clamp to 0).
    offset: usize,
    /// Whether NULL sort keys order first (citadel defaults to NULLS FIRST for
    /// ascending order).
    nulls_first: bool,
}
808
/// A candidate keyed by (distance, scan position); `seq` breaks ties by scan
/// order so the bounded heap matches the stable sort.
struct Ranked {
    /// Primary sort key, compared with `f64::total_cmp`.
    dist: f64,
    /// Scan position of the row; used only as the tie-breaker.
    seq: u64,
    /// The row carried with its key; not part of the ordering.
    row: Vec<Value>,
}
816
817impl PartialEq for Ranked {
818    fn eq(&self, other: &Self) -> bool {
819        self.cmp(other) == Ordering::Equal
820    }
821}
822impl Eq for Ranked {}
823impl PartialOrd for Ranked {
824    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
825        Some(self.cmp(other))
826    }
827}
828impl Ord for Ranked {
829    fn cmp(&self, other: &Self) -> Ordering {
830        self.dist
831            .total_cmp(&other.dist)
832            .then_with(|| self.seq.cmp(&other.seq))
833    }
834}
835
836impl VectorTopKPlan {
837    pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
838        if !topk_shape_ok(stmt) {
839            return Ok(None);
840        }
841        let ob = &stmt.order_by[0];
842        let Expr::BinaryOp { left, op, .. } = &ob.expr else {
843            return Ok(None);
844        };
845        if !matches!(
846            op,
847            BinOp::VectorL2 | BinOp::VectorInner | BinOp::VectorCosine
848        ) {
849            return Ok(None);
850        }
851        // Only claim a vector-distance sort key; anything else uses the general path.
852        let Expr::Column(name) = left.as_ref() else {
853            return Ok(None);
854        };
855        let name = name.to_ascii_lowercase();
856        let is_vector_col = table_schema.columns.iter().any(|c| {
857            c.name.to_ascii_lowercase() == name && matches!(c.data_type, DataType::Vector { .. })
858        });
859        if !is_vector_col {
860            return Ok(None);
861        }
862
863        let k = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
864        if k == 0 {
865            return Ok(None);
866        }
867        let offset = stmt
868            .offset
869            .as_ref()
870            .map(eval_const_int)
871            .transpose()?
872            .unwrap_or(0)
873            .max(0) as usize;
874
875        Ok(Some(Self {
876            order_expr: ob.expr.clone(),
877            where_clause: stmt.where_clause.clone(),
878            k,
879            offset,
880            // citadel defaults to NULLS FIRST for ascending order.
881            nulls_first: ob.nulls_first.unwrap_or(true),
882        }))
883    }
884
885    pub(super) fn execute(
886        &self,
887        txn: &mut dyn AnnScan,
888        table_schema: &TableSchema,
889        stmt: &SelectStmt,
890    ) -> Result<ExecutionResult> {
891        let want = self.k.saturating_add(self.offset);
892        let col_map = ColumnMap::new(&table_schema.columns);
893        // NULL distances sort like NULLs under the requested ordering.
894        let null_dist = if self.nulls_first {
895            f64::NEG_INFINITY
896        } else {
897            f64::INFINITY
898        };
899        let mut heap: BinaryHeap<Ranked> = BinaryHeap::new();
900        let mut seq: u64 = 0;
901
902        txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
903            let row = decode_full_row(table_schema, key, value)?;
904            let ctx = EvalCtx::new(&col_map, &row);
905            if let Some(w) = &self.where_clause {
906                if !is_truthy(&eval_expr(w, &ctx)?) {
907                    return Ok(true);
908                }
909            }
910            let dist = match eval_expr(&self.order_expr, &ctx)? {
911                Value::Real(d) => d,
912                Value::Integer(i) => i as f64,
913                Value::Null => null_dist,
914                other => {
915                    return Err(SqlError::InvalidValue(format!(
916                        "ORDER BY vector distance produced a non-numeric {}",
917                        other.data_type()
918                    )))
919                }
920            };
921            let cand = Ranked { dist, seq, row };
922            seq += 1;
923            // `seq` only grows, so ties never evict an earlier row (stable-sort order).
924            if heap.len() < want {
925                heap.push(cand);
926            } else if heap.peek().is_some_and(|top| cand < *top) {
927                heap.pop();
928                heap.push(cand);
929            }
930            Ok(true)
931        })?;
932
933        let mut rows: Vec<Vec<Value>> = heap.into_sorted_vec().into_iter().map(|r| r.row).collect();
934        if self.offset >= rows.len() {
935            rows.clear();
936        } else if self.offset > 0 {
937            rows = rows.split_off(self.offset);
938        }
939        rows.truncate(self.k);
940
941        let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
942        Ok(ExecutionResult::Query(QueryResult {
943            columns: col_names,
944            rows: projected,
945        }))
946    }
947}
948
949/// How to read a filter column's value out of a raw row during the build scan.
950enum Extract {
951    /// The single integer primary key, read from the row key.
952    Pk,
953    /// A non-PK column at the given encoding position in the row value.
954    NonPk(usize),
955}
956
957impl Extract {
958    fn extract(&self, key: &[u8], value: &[u8]) -> Result<Value> {
959        match self {
960            Extract::Pk => Ok(Value::Integer(decode_pk_integer(key)?)),
961            Extract::NonPk(ei) => Ok(decode_column_raw(value, *ei)?.to_value()),
962        }
963    }
964}
965
966fn extract_plan(
967    col: u16,
968    table_schema: &TableSchema,
969    non_pk: &[usize],
970    enc_pos: &[u16],
971) -> Result<Extract> {
972    if table_schema.primary_key_columns.contains(&col) {
973        return Ok(Extract::Pk);
974    }
975    let order = non_pk
976        .iter()
977        .position(|&i| i == col as usize)
978        .ok_or_else(|| SqlError::InvalidValue("ANN filter column not found in row".into()))?;
979    Ok(Extract::NonPk(enc_pos[order] as usize))
980}
981
982/// Walk the AND-tree, sorting each leaf into a pushable attribute predicate or
983/// the recheck residual.
984fn split_where(
985    expr: &Expr,
986    filter_cols: &[u16],
987    table_schema: &TableSchema,
988    pushable: &mut Vec<(usize, Vec<Value>)>,
989    residual: &mut Vec<Expr>,
990) {
991    if let Expr::BinaryOp {
992        left,
993        op: BinOp::And,
994        right,
995    } = expr
996    {
997        split_where(left, filter_cols, table_schema, pushable, residual);
998        split_where(right, filter_cols, table_schema, pushable, residual);
999        return;
1000    }
1001    match classify_leaf(expr, filter_cols, table_schema) {
1002        Some(constraint) => pushable.push(constraint),
1003        None => residual.push(expr.clone()),
1004    }
1005}
1006
1007/// Outcome of coercing a pushdown literal to the filter column's stored type.
1008enum Coerced {
1009    /// Encodes exactly like a stored value; safe for the dictionary lookup.
1010    Exact(Value),
1011    /// Can never equal any stored value of this column (e.g. a fractional
1012    /// literal vs INTEGER); contributes no codes.
1013    NeverMatches,
1014    /// Eval equality may diverge from encoded-byte equality (NULL three-valued
1015    /// logic, cross-type comparisons, floats past 2^53); the whole leaf must
1016    /// stay in the residual so the eval path decides.
1017    Residual,
1018}
1019
1020fn coerce_pushdown_literal(val: Value, col_type: DataType) -> Coerced {
1021    // Past 2^53 the int<->f64 mapping is not 1:1, so encoded equality and
1022    // numeric equality diverge.
1023    const EXACT_F64_INT: f64 = 9_007_199_254_740_992.0;
1024    if val.is_null() {
1025        return Coerced::Residual;
1026    }
1027    if val.data_type() == col_type {
1028        return Coerced::Exact(val);
1029    }
1030    match (val, col_type) {
1031        (Value::Real(r), DataType::Integer) => {
1032            if r.is_nan() || r.is_infinite() {
1033                Coerced::NeverMatches
1034            } else if r.abs() > EXACT_F64_INT {
1035                Coerced::Residual
1036            } else if r.fract() == 0.0 {
1037                Coerced::Exact(Value::Integer(r as i64))
1038            } else {
1039                Coerced::NeverMatches
1040            }
1041        }
1042        (Value::Integer(i), DataType::Real) => {
1043            if i.unsigned_abs() <= EXACT_F64_INT as u64 {
1044                Coerced::Exact(Value::Real(i as f64))
1045            } else {
1046                Coerced::Residual
1047            }
1048        }
1049        _ => Coerced::Residual,
1050    }
1051}
1052
1053/// A leaf is pushable if it is `col = literal` or `col IN (literal, ...)` on a
1054/// declared filter column whose constant right-hand side coerces exactly to
1055/// the column's stored type. An empty value list means the leaf is provably
1056/// unsatisfiable (the caller short-circuits to an empty result).
1057fn classify_leaf(
1058    leaf: &Expr,
1059    filter_cols: &[u16],
1060    table_schema: &TableSchema,
1061) -> Option<(usize, Vec<Value>)> {
1062    let (col_expr, rhs): (&Expr, Vec<&Expr>) = match leaf {
1063        Expr::BinaryOp {
1064            left,
1065            op: BinOp::Eq,
1066            right,
1067        } => (left, vec![right.as_ref()]),
1068        Expr::InList {
1069            expr,
1070            list,
1071            negated: false,
1072        } => (expr, list.iter().collect()),
1073        _ => return None,
1074    };
1075    let dim = filter_dim(col_expr, filter_cols, table_schema)?;
1076    let col_type = table_schema.columns[filter_cols[dim] as usize].data_type;
1077    let mut vals = Vec::with_capacity(rhs.len());
1078    for e in rhs {
1079        match coerce_pushdown_literal(eval_const_expr(e).ok()?, col_type) {
1080            Coerced::Exact(v) => vals.push(v),
1081            Coerced::NeverMatches => {}
1082            Coerced::Residual => return None,
1083        }
1084    }
1085    Some((dim, vals))
1086}
1087
1088/// Resolve a column expression to its attribute-dim index (position in
1089/// `filter_cols`), or `None` if it is not a declared filter column.
1090fn filter_dim(expr: &Expr, filter_cols: &[u16], table_schema: &TableSchema) -> Option<usize> {
1091    let name = match expr {
1092        Expr::Column(c) => c.to_ascii_lowercase(),
1093        Expr::QualifiedColumn { column, .. } => column.to_ascii_lowercase(),
1094        _ => return None,
1095    };
1096    let col_idx = table_schema
1097        .columns
1098        .iter()
1099        .position(|c| c.name.to_ascii_lowercase() == name)? as u16;
1100    filter_cols.iter().position(|&c| c == col_idx)
1101}
1102
1103fn fold_and(mut leaves: Vec<Expr>) -> Option<Expr> {
1104    if leaves.is_empty() {
1105        return None;
1106    }
1107    let first = leaves.remove(0);
1108    Some(leaves.into_iter().fold(first, |acc, e| Expr::BinaryOp {
1109        left: Box::new(acc),
1110        op: BinOp::And,
1111        right: Box::new(e),
1112    }))
1113}
1114
1115fn empty_result(table_schema: &TableSchema, stmt: &SelectStmt) -> Result<ExecutionResult> {
1116    let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, Vec::new())?;
1117    Ok(ExecutionResult::Query(QueryResult {
1118        columns: col_names,
1119        rows: projected,
1120    }))
1121}
1122
1123/// The explicit freeze operation behind `Connection::persist_ann_index`: ONE
1124/// write txn scans the table (computing the content fingerprint), pays the
1125/// PRISM build, serializes the segment, replaces any prior one, and commits -
1126/// atomic by shadow paging. The single writer lock is held for the full build
1127/// (minutes on large tables): an offline/builder operation by design. The
1128/// fresh index also warms the shared RAM cache, so the NEXT attach is the
1129/// fast loaded one and THIS process serves queries immediately.
1130pub(crate) fn persist_ann_index(
1131    db: &citadel::Database,
1132    schema: &SchemaManager,
1133    table_schema: &TableSchema,
1134    column: &str,
1135) -> Result<ann_persist::AnnSegmentInfo> {
1136    let col_lower = column.to_ascii_lowercase();
1137    let col_idx = table_schema
1138        .columns
1139        .iter()
1140        .position(|c| c.name == col_lower)
1141        .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
1142    let DataType::Vector { dim } = table_schema.columns[col_idx].data_type else {
1143        return Err(SqlError::InvalidValue(format!(
1144            "column `{column}` is not VECTOR(N)"
1145        )));
1146    };
1147    // Same admission as AnnTopKPlan::try_new: a table no plan can serve must
1148    // not get a persisted segment (it would be unservable dead weight with
1149    // mis-decoded row ids).
1150    if table_schema.primary_key_columns.len() != 1
1151        || !matches!(
1152            table_schema.columns[table_schema.primary_key_columns[0] as usize].data_type,
1153            DataType::Integer
1154        )
1155    {
1156        return Err(SqlError::InvalidValue(
1157            "ANN persistence requires a single INTEGER primary key (same rule as the \
1158             ANN query plan)"
1159                .into(),
1160        ));
1161    }
1162    let ann_index = table_schema
1163        .indices
1164        .iter()
1165        .find(|ix| {
1166            matches!(ix.kind, IndexKind::Inverted(InvertedKind::Ann { .. }))
1167                && ix.keys.len() == 1
1168                && matches!(ix.keys[0], IndexKey::Column { idx, .. } if idx as usize == col_idx)
1169        })
1170        .ok_or_else(|| SqlError::InvalidValue(format!("no ANN index declared on `{column}`")))?;
1171    let IndexKind::Inverted(InvertedKind::Ann { metric }) = ann_index.kind else {
1172        unreachable!("matched above");
1173    };
1174    let spec = AnnSpec {
1175        col_idx,
1176        dim,
1177        metric,
1178        filter_cols: ann_index.ann_filter_cols.clone(),
1179    };
1180
1181    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
1182    let outcome = scan_rows(&mut wtx, table_schema, &spec)?;
1183    if outcome.rows.is_empty() {
1184        return Err(SqlError::InvalidValue(
1185            "nothing to persist: the table has no indexable (non-NULL) vectors".into(),
1186        ));
1187    }
1188    let n = outcome.rows.len() as u64;
1189    let index = AnnIndex::build_with_attrs(
1190        outcome.rows,
1191        spec.filter_cols.len(),
1192        ann_metric_to_prism(spec.metric),
1193        spec.dim,
1194    )
1195    .map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
1196
1197    let body = citadel_vector::segment::encode(&index);
1198    let segment_b3 = *blake3::hash(&body).as_bytes();
1199    // Dict entries ordered by code (codes are assigned in first-seen scan
1200    // order, so by-code IS scan order).
1201    let dicts_ordered: Vec<Vec<(Vec<u8>, u32)>> = outcome
1202        .dicts
1203        .iter()
1204        .map(|d| {
1205            let mut entries: Vec<(Vec<u8>, u32)> = d.iter().map(|(k, &v)| (k.clone(), v)).collect();
1206            entries.sort_by_key(|&(_, code)| code);
1207            entries
1208        })
1209        .collect();
1210    let header = ann_persist::SegmentHeader {
1211        format_version: ann_persist::ANNSEG_FORMAT_VERSION,
1212        prism_config_hash: ann_persist::active_config_hash(ann_metric_to_prism(spec.metric)),
1213        dim: spec.dim,
1214        metric_tag: spec.metric_tag(),
1215        n,
1216        snapshot_max: index.snapshot_max,
1217        col_idx: spec.col_idx as u32,
1218        filter_cols: spec.filter_cols.iter().map(|&c| c as u32).collect(),
1219        dicts: dicts_ordered,
1220        content_fingerprint: outcome.fingerprint,
1221        segment_b3,
1222        chunk_count: body.len().div_ceil(ann_persist::CHUNK_BYTES) as u32,
1223        writer: format!("citadel-sql {}", env!("CARGO_PKG_VERSION")),
1224    };
1225
1226    let seg_table = ann_persist::segment_table_name(&table_schema.name);
1227    ann_persist::purge_segment(&mut wtx, &table_schema.name)?;
1228    wtx.create_table(&seg_table).map_err(SqlError::Storage)?;
1229    wtx.table_insert(&seg_table, &ann_persist::segment_key(0), &header.encode())
1230        .map_err(SqlError::Storage)?;
1231    for (chunk_no, chunk) in ann_persist::chunks(&body) {
1232        wtx.table_insert(&seg_table, &ann_persist::segment_key(chunk_no), chunk)
1233            .map_err(SqlError::Storage)?;
1234    }
1235    wtx.commit().map_err(SqlError::Storage)?;
1236
1237    // Warm the shared cache: this index reflects exactly the just-committed
1238    // state (single writer - our commit is the current generation).
1239    let cached = CachedAnnIndex {
1240        index,
1241        dicts: outcome.dicts,
1242        source: AnnIndexSource::Built { refusal: None },
1243        cached_gen: db.manager().commit_generation(),
1244    };
1245    let key = cache_key(&table_schema.name, spec.col_idx, spec.metric);
1246    let as_any: Arc<dyn Any + Send + Sync> = Arc::new(cached);
1247    schema.sql_caches.lock().insert(key, as_any);
1248
1249    Ok(ann_persist::AnnSegmentInfo {
1250        segment_b3,
1251        content_fingerprint: header.content_fingerprint,
1252        n,
1253        dim: spec.dim,
1254        metric_tag: header.metric_tag,
1255        chunk_count: header.chunk_count,
1256    })
1257}
1258
1259/// The queryable identity of the index currently cached for `table.column`:
1260/// `(source, snapshot generation)`, or `None` when nothing is cached.
1261pub(crate) fn ann_cache_status(
1262    schema: &SchemaManager,
1263    table_schema: &TableSchema,
1264    column: &str,
1265) -> Result<Option<(AnnIndexSource, u64)>> {
1266    let col_lower = column.to_ascii_lowercase();
1267    let col_idx = table_schema
1268        .columns
1269        .iter()
1270        .position(|c| c.name == col_lower)
1271        .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
1272    let guard = schema.sql_caches.lock();
1273    for metric in [AnnMetric::L2, AnnMetric::Inner, AnnMetric::Cosine] {
1274        let key = cache_key(&table_schema.name, col_idx, metric);
1275        if let Some(entry) = guard.get(&key) {
1276            if let Ok(c) = Arc::clone(entry).downcast::<CachedAnnIndex>() {
1277                return Ok(Some((c.source.clone(), c.cached_gen)));
1278            }
1279        }
1280    }
1281    Ok(None)
1282}
1283
1284/// The per-table last-DML generation marker's cache key. Stamped by the
1285/// commit-time invalidation in `connection.rs`; read here to refuse any index
1286/// whose snapshot predates the most recent DML commit on its table.
1287pub(crate) fn ann_dml_gen_key(table_name: &str) -> String {
1288    format!("ann_dml_gen:{table_name}")
1289}
1290
1291/// Read the marker under an already-held cache lock.
1292fn marker_gen_locked(
1293    entries: &FxHashMap<String, Arc<dyn Any + Send + Sync>>,
1294    table_name: &str,
1295) -> Option<u64> {
1296    entries
1297        .get(&ann_dml_gen_key(table_name))
1298        .and_then(|e| e.downcast_ref::<u64>())
1299        .copied()
1300}
1301
1302fn lookup_cached(
1303    schema: &SchemaManager,
1304    cache_key: &str,
1305    table_name: &str,
1306) -> Result<Option<Arc<CachedAnnIndex>>> {
1307    let mut guard = schema.sql_caches.lock();
1308    let Some(entry) = guard.get(cache_key) else {
1309        return Ok(None);
1310    };
1311    let entry = Arc::clone(entry)
1312        .downcast::<CachedAnnIndex>()
1313        .map_err(|_| SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}")))?;
1314    if marker_gen_locked(&guard, table_name).is_some_and(|g| entry.cached_gen < g) {
1315        // The entry predates a DML commit on this table (a build that raced a
1316        // commit slipped past eviction): drop it and rebuild/reload.
1317        guard.remove(cache_key);
1318        return Ok(None);
1319    }
1320    Ok(Some(entry))
1321}
1322
1323pub(super) fn cache_key(table_name: &str, col_idx: usize, metric: AnnMetric) -> String {
1324    let tag = match metric {
1325        AnnMetric::L2 => "l2",
1326        AnnMetric::Inner => "inner",
1327        AnnMetric::Cosine => "cosine",
1328    };
1329    format!(
1330        "ann:{}:{}:{}",
1331        table_name.to_ascii_lowercase(),
1332        col_idx,
1333        tag
1334    )
1335}
1336
1337fn ann_metric_to_prism(m: AnnMetric) -> Metric {
1338    match m {
1339        AnnMetric::L2 => Metric::L2,
1340        AnnMetric::Inner => Metric::InnerProduct,
1341        AnnMetric::Cosine => Metric::Cosine,
1342    }
1343}