Skip to main content

citadel_sql/executor/
ann_topk.rs

1//! Plans for `SELECT ... ORDER BY col <dist> :q LIMIT k`: [`AnnTopKPlan`] uses a
2//! cached PRISM index; [`VectorTopKPlan`] streams a bounded-heap top-k when no
3//! index applies or inside a write txn (uncommitted rows).
4
5use std::any::Any;
6use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8use std::sync::Arc;
9
10use citadel_txn::read_txn::ReadTxn;
11use citadel_txn::write_txn::WriteTxn;
12use citadel_vector::{AnnIndex, Filter, Metric};
13use rustc_hash::FxHashMap;
14
15use crate::encoding::{
16    decode_column_raw, decode_pk_integer, encode_int_key_into, encode_key_value,
17    encode_key_value_collated_into,
18};
19use crate::error::{Result, SqlError};
20use crate::eval::{eval_expr, is_truthy, ColumnMap, EvalCtx};
21use crate::parser::*;
22use crate::schema::SchemaManager;
23use crate::types::*;
24
25use super::aggregate::is_aggregate_expr;
26use super::ann_persist;
27use super::helpers::{decode_full_row, eval_const_expr, eval_const_int, project_rows};
28use super::window::has_any_window_function;
29
/// Result type of the storage layer (`citadel_core`), before mapping into `SqlError`.
type StorageResult<T> = std::result::Result<T, citadel_core::Error>;
/// SQL-level row visitor over raw `(key, value)` bytes; returning `Ok(false)` stops the scan.
type ScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> Result<bool> + 'a;
/// Storage-level row visitor handed to the txn scan APIs (reports `citadel_core::Error`).
type RawScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> StorageResult<bool> + 'a;
/// Recall candidate: (distance in SQL operator units, row id, decoded row).
type RankedRow = (f64, i64, Vec<Value>);
35
36/// Scan + point-get over a read or write txn, materializing overflow values.
37pub(super) trait AnnScan {
38    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()>;
39    /// Forward scan from `start_key` (inclusive); O(tail) for the tail merge.
40    fn ann_scan_from(&mut self, table: &[u8], start_key: &[u8], f: &mut ScanRow<'_>) -> Result<()>;
41    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>>;
42    /// Commit generation this snapshot reflects; `None` when the view has uncommitted
43    /// writes - such an index cannot enter the shared cache.
44    fn cache_generation(&self) -> Option<u64>;
45    /// The table's live catalog root (the CoW freshness anchor) - a lookup, not a scan.
46    fn ann_table_root(&self, table: &[u8]) -> Option<u64>;
47}
48
49/// Adapt a storage-level scan to report `SqlError`, surfacing the first callback error.
50fn bridge_scan(
51    scan: impl FnOnce(&mut RawScanRow<'_>) -> StorageResult<()>,
52    f: &mut ScanRow<'_>,
53) -> Result<()> {
54    let mut cb_err: Option<SqlError> = None;
55    scan(&mut |key, value| match f(key, value) {
56        Ok(go) => Ok(go),
57        Err(e) => {
58            cb_err = Some(e);
59            Ok(false)
60        }
61    })
62    .map_err(SqlError::Storage)?;
63    match cb_err {
64        Some(e) => Err(e),
65        None => Ok(()),
66    }
67}
68
69impl AnnScan for ReadTxn<'_> {
70    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
71        bridge_scan(|cb| self.table_scan_from(table, b"", cb), f)
72    }
73
74    fn ann_scan_from(&mut self, table: &[u8], start_key: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
75        bridge_scan(|cb| self.table_scan_from(table, start_key, cb), f)
76    }
77
78    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
79        self.table_get(table, key).map_err(SqlError::Storage)
80    }
81
82    fn cache_generation(&self) -> Option<u64> {
83        Some(self.commit_generation())
84    }
85
86    fn ann_table_root(&self, table: &[u8]) -> Option<u64> {
87        self.table_root_page(table)
88            .ok()
89            .flatten()
90            .map(|p| u64::from(p.0))
91    }
92}
93
94impl AnnScan for WriteTxn<'_> {
95    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
96        bridge_scan(|cb| self.table_scan_from(table, b"", cb), f)
97    }
98
99    fn ann_scan_from(&mut self, table: &[u8], start_key: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
100        bridge_scan(|cb| self.table_scan_from(table, start_key, cb), f)
101    }
102
103    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
104        self.table_get(table, key).map_err(SqlError::Storage)
105    }
106
107    fn cache_generation(&self) -> Option<u64> {
108        None
109    }
110
111    fn ann_table_root(&self, table: &[u8]) -> Option<u64> {
112        self.table_root_page(table)
113            .ok()
114            .flatten()
115            .map(|p| u64::from(p.0))
116    }
117}
118
/// Provenance of a cached index, queryable via `ann_cache_status`.
///
/// It carries a load-refusal reason, so the cause of a refused segment stays
/// visible instead of living only in the log.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnnIndexSource {
    /// Built from a table scan in this process. `refusal` records why a persisted
    /// segment was rejected, if one existed (`None` when there was no segment).
    Built { refusal: Option<String> },
    /// Loaded from a persisted segment (body BLAKE3 `segment_b3`) that the freshness gate accepted.
    Loaded { segment_b3: [u8; 32] },
}
129
/// A cached ANN index plus the metadata needed to push SQL filters into it.
///
/// Shared as `Arc<CachedAnnIndex>`. It is produced either by a scan build or by
/// loading a persisted segment.
struct CachedAnnIndex {
    /// The PRISM index. Rows whose pk exceeds `index.snapshot_max` are not in it;
    /// the query path ranks them exactly as the "tail".
    index: AnnIndex,
    /// Per attribute dim: maps an encoded filter-column value to its PRISM code.
    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
    /// Whether the index was scan-built or loaded, plus any segment refusal reason.
    source: AnnIndexSource,
    /// Commit generation the index reflects; a cache insert is declined if the DB
    /// moved past it, so a cached index never describes a superseded snapshot.
    /// Set to `u64::MAX` when built from a write-txn view, which is never cached.
    cached_gen: u64,
}
140
/// Executable plan for an index-served `ORDER BY <vec col> <dist op> :q LIMIT k`
/// query; constructed by `AnnTopKPlan::try_new`.
pub(super) struct AnnTopKPlan {
    /// Schema index of the `VECTOR` column being ranked.
    col_idx: usize,
    /// Declared dimension of that column; the query vector has exactly this length.
    dim: u16,
    /// Metric of the ORDER BY operator, which equals the chosen index's metric.
    metric: AnnMetric,
    /// Query vector, evaluated once at plan time.
    query_vec: Vec<f32>,
    /// Effective LIMIT (non-zero for plans built by `try_new`).
    k: usize,
    /// Effective OFFSET (negative values clamped to 0).
    offset: usize,
    /// Schema column indices declared filterable on the index, in attr-dim order.
    filter_cols: Vec<u16>,
    /// Pushable conjuncts: `(attr_dim, allowed_values)` from `col = v` / `col IN (...)`.
    pushable: Vec<(usize, Vec<Value>)>,
    /// Remaining WHERE predicate evaluated as a recheck on decoded candidates.
    residual: Option<Expr>,
}
155
156/// Gate for single-key ascending ORDER BY ... LIMIT k (no group/having/join/distinct/window/agg).
157fn topk_shape_ok(stmt: &SelectStmt) -> bool {
158    stmt.order_by.len() == 1
159        && !stmt.order_by[0].descending
160        && stmt.limit.is_some()
161        && stmt.group_by.is_empty()
162        && stmt.having.is_none()
163        && stmt.joins.is_empty()
164        && !stmt.distinct
165        && !has_any_window_function(stmt)
166        && !stmt
167            .columns
168            .iter()
169            .any(|c| matches!(c, SelectColumn::Expr { expr, .. } if is_aggregate_expr(expr)))
170}
171
/// A finished result, or a request to rebuild the cache (tail too long to merge).
enum RunOutcome {
    /// The query completed against the current index.
    Done(ExecutionResult),
    /// The un-indexed tail crossed the index's staleness threshold. The caller
    /// drops the cached index and reruns once against a fresh build.
    Rebuild,
}
177
178/// Tail-row distance in SQL operator units; None for a zero vector under cosine.
179fn tail_distance(metric: AnnMetric, q: &[f32], v: &[f32]) -> Option<f64> {
180    let d = match metric {
181        AnnMetric::L2 => {
182            let mut sum = 0.0f64;
183            for (x, y) in q.iter().zip(v.iter()) {
184                let diff = (*x as f64) - (*y as f64);
185                sum += diff * diff;
186            }
187            sum.sqrt()
188        }
189        AnnMetric::Inner => {
190            let mut sum = 0.0f64;
191            for (x, y) in q.iter().zip(v.iter()) {
192                sum += (*x as f64) * (*y as f64);
193            }
194            -sum
195        }
196        AnnMetric::Cosine => {
197            let mut dot = 0.0f64;
198            let mut nq = 0.0f64;
199            let mut nv = 0.0f64;
200            for (x, y) in q.iter().zip(v.iter()) {
201                let xf = *x as f64;
202                let yf = *y as f64;
203                dot += xf * yf;
204                nq += xf * xf;
205                nv += yf * yf;
206            }
207            let denom = nq.sqrt() * nv.sqrt();
208            if denom == 0.0 {
209                return None;
210            }
211            1.0 - dot / denom
212        }
213    };
214    Some(d)
215}
216
217impl AnnTopKPlan {
218    pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
219        if !topk_shape_ok(stmt) {
220            return Ok(None);
221        }
222        let ob = &stmt.order_by[0];
223
224        let (col_idx, dim, op_metric, query_vec) = match &ob.expr {
225            Expr::BinaryOp { left, op, right } => {
226                let op_metric = match op {
227                    BinOp::VectorL2 => AnnMetric::L2,
228                    BinOp::VectorInner => AnnMetric::Inner,
229                    BinOp::VectorCosine => AnnMetric::Cosine,
230                    _ => return Ok(None),
231                };
232                let col_name = match left.as_ref() {
233                    Expr::Column(name) => name.to_ascii_lowercase(),
234                    _ => return Ok(None),
235                };
236                let (col_idx, dim) = match table_schema
237                    .columns
238                    .iter()
239                    .enumerate()
240                    .find(|(_, c)| c.name.to_ascii_lowercase() == col_name)
241                {
242                    Some((i, c)) => match c.data_type {
243                        DataType::Vector { dim } => (i, dim),
244                        _ => return Ok(None),
245                    },
246                    None => return Ok(None),
247                };
248                let col_map = ColumnMap::new(&table_schema.columns);
249                let ctx = EvalCtx::new(&col_map, &[]);
250                let v = match eval_expr(right, &ctx) {
251                    Ok(Value::Vector(v)) => v,
252                    _ => return Ok(None),
253                };
254                if v.len() != dim as usize {
255                    return Err(SqlError::InvalidValue(format!(
256                        "ANN query vector dim {} does not match column dim {}",
257                        v.len(),
258                        dim
259                    )));
260                }
261                (col_idx, dim, op_metric, v.to_vec())
262            }
263            _ => return Ok(None),
264        };
265
266        let ann_index = table_schema.indices.iter().find(|ix| {
267            matches!(ix.kind,
268                IndexKind::Inverted(InvertedKind::Ann { metric }) if metric == op_metric
269            ) && ix.keys.len() == 1
270                && matches!(ix.keys[0],
271                    IndexKey::Column { idx, .. } if idx as usize == col_idx
272                )
273        });
274        let Some(ann_index) = ann_index else {
275            return Ok(None);
276        };
277        let filter_cols = ann_index.ann_filter_cols.clone();
278
279        if table_schema.primary_key_columns.len() != 1 {
280            return Ok(None);
281        }
282        let pk_col = &table_schema.columns[table_schema.primary_key_columns[0] as usize];
283        if !matches!(pk_col.data_type, DataType::Integer) {
284            return Ok(None);
285        }
286
287        // No pushable predicate = no index leverage; decline for the exact filtered scan.
288        let mut pushable: Vec<(usize, Vec<Value>)> = Vec::new();
289        let mut residual_leaves: Vec<Expr> = Vec::new();
290        if let Some(w) = &stmt.where_clause {
291            split_where(
292                w,
293                &filter_cols,
294                table_schema,
295                &mut pushable,
296                &mut residual_leaves,
297            );
298            if pushable.is_empty() {
299                return Ok(None);
300            }
301        }
302        let residual = fold_and(residual_leaves);
303
304        let k_limit = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
305        let offset = stmt
306            .offset
307            .as_ref()
308            .map(eval_const_int)
309            .transpose()?
310            .unwrap_or(0)
311            .max(0) as usize;
312        if k_limit == 0 {
313            return Ok(None);
314        }
315
316        Ok(Some(Self {
317            col_idx,
318            dim,
319            metric: op_metric,
320            query_vec,
321            k: k_limit,
322            offset,
323            filter_cols,
324            pushable,
325            residual,
326        }))
327    }
328
329    pub(super) fn execute_with_read(
330        &self,
331        rtx: &mut ReadTxn<'_>,
332        schema: &SchemaManager,
333        stmt: &SelectStmt,
334        table_schema: &TableSchema,
335    ) -> Result<ExecutionResult> {
336        let cache_key = cache_key(&table_schema.name, self.col_idx, self.metric);
337        // One rebuild at most; the rebuilt snapshot has an empty tail.
338        let mut force_rebuild = false;
339        loop {
340            if force_rebuild {
341                schema.sql_caches.lock().remove(&cache_key);
342            }
343            let Some(cached) = self.load_or_build_index(rtx, schema, &cache_key, table_schema)?
344            else {
345                return empty_result(table_schema, stmt);
346            };
347            match self.run_query(rtx, &cached, stmt, table_schema, !force_rebuild)? {
348                RunOutcome::Done(result) => return Ok(result),
349                RunOutcome::Rebuild => force_rebuild = true,
350            }
351        }
352    }
353
354    /// Merge index hits with the brute-forced tail; `Rebuild` when the tail is too long.
355    fn run_query(
356        &self,
357        txn: &mut dyn AnnScan,
358        cached: &CachedAnnIndex,
359        stmt: &SelectStmt,
360        table_schema: &TableSchema,
361        allow_rebuild: bool,
362    ) -> Result<RunOutcome> {
363        // A filter value absent from the dict matches no indexed row, but a fresh
364        // tail row still might, so skip only the index search (not the tail).
365        let mut constraints: Vec<(usize, Vec<u32>)> = Vec::with_capacity(self.pushable.len());
366        let mut index_unsat = false;
367        for (dim, values) in &self.pushable {
368            let dict = &cached.dicts[*dim];
369            let coll = table_schema.columns[self.filter_cols[*dim] as usize].collation;
370            let mut codes = Vec::with_capacity(values.len());
371            let mut canon = Vec::with_capacity(16);
372            for v in values {
373                canon.clear();
374                encode_key_value_collated_into(v, coll, &mut canon);
375                if let Some(&code) = dict.get(canon.as_slice()) {
376                    codes.push(code);
377                }
378            }
379            if codes.is_empty() {
380                index_unsat = true;
381            }
382            constraints.push((*dim, codes));
383        }
384
385        let want = self.k.saturating_add(self.offset).max(1);
386        let mut merged: Vec<RankedRow> = if index_unsat {
387            Vec::new()
388        } else {
389            let filter = if constraints.is_empty() {
390                Filter::none()
391            } else {
392                Filter::new(constraints)
393            };
394            self.collect_survivors(txn, &cached.index, &filter, table_schema, want)?
395        };
396
397        match self.collect_tail(txn, &cached.index, table_schema, allow_rebuild)? {
398            Some(tail) => merged.extend(tail),
399            None => return Ok(RunOutcome::Rebuild),
400        }
401
402        // Global distance order; ties broken by id for determinism.
403        merged.sort_by(|a, b| a.0.total_cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
404        let mut rows: Vec<Vec<Value>> = merged.into_iter().map(|(_, _, row)| row).collect();
405
406        if self.offset >= rows.len() {
407            rows.clear();
408        } else if self.offset > 0 {
409            rows = rows.split_off(self.offset);
410        }
411        rows.truncate(self.k);
412
413        let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
414        Ok(RunOutcome::Done(ExecutionResult::Query(QueryResult {
415            columns: col_names,
416            rows: projected,
417        })))
418    }
419
420    /// Index hits passing the residual recheck, over-fetched until `want` survive.
421    fn collect_survivors(
422        &self,
423        txn: &mut dyn AnnScan,
424        index: &AnnIndex,
425        filter: &Filter,
426        table_schema: &TableSchema,
427        want: usize,
428    ) -> Result<Vec<RankedRow>> {
429        let col_map = ColumnMap::new(&table_schema.columns);
430        let max_target = index.indexed_len().max(1);
431        let mut key_buf: Vec<u8> = Vec::with_capacity(10);
432        let mut target = want;
433        loop {
434            target = target.min(max_target);
435            let hits = index.search_filtered_default_ef(&self.query_vec, target, filter);
436            let mut survivors: Vec<RankedRow> = Vec::with_capacity(want);
437            for (id, dist) in &hits {
438                encode_int_key_into(*id as i64, &mut key_buf);
439                let Some(row_bytes) = txn.ann_get(table_schema.name.as_bytes(), &key_buf)? else {
440                    continue;
441                };
442                let row = decode_full_row(table_schema, &key_buf, &row_bytes)?;
443                let keep = match &self.residual {
444                    None => true,
445                    Some(expr) => {
446                        let ctx = EvalCtx::new(&col_map, &row);
447                        is_truthy(&eval_expr(expr, &ctx)?)
448                    }
449                };
450                if keep {
451                    survivors.push((*dist as f64, *id as i64, row));
452                    if survivors.len() >= want {
453                        break;
454                    }
455                }
456            }
457            // Stop when satisfied, the index is exhausted, or PRISM returns fewer than asked.
458            if survivors.len() >= want || target >= max_target || hits.len() < target {
459                return Ok(survivors);
460            }
461            target = target.saturating_mul(2);
462        }
463    }
464
465    /// Exact-rank rows appended past the snapshot; `None` when the tail is too long.
466    fn collect_tail(
467        &self,
468        txn: &mut dyn AnnScan,
469        index: &AnnIndex,
470        table_schema: &TableSchema,
471        allow_rebuild: bool,
472    ) -> Result<Option<Vec<RankedRow>>> {
473        let snapshot_max = index.snapshot_max;
474        // Negative pks (snapshot_max reads negative as i64) make the pk>snapshot_max
475        // boundary unsound; those tables hard-invalidate on append, so the tail is empty.
476        let first_tail_pk = match (snapshot_max as i64).checked_add(1) {
477            Some(pk) if (snapshot_max as i64) >= 0 => pk,
478            _ => return Ok(Some(Vec::new())),
479        };
480        let mut start_key: Vec<u8> = Vec::with_capacity(10);
481        encode_int_key_into(first_tail_pk, &mut start_key);
482
483        let col_map = ColumnMap::new(&table_schema.columns);
484        let mut out: Vec<RankedRow> = Vec::new();
485        let mut seen: u64 = 0;
486        let mut over_threshold = false;
487
488        txn.ann_scan_from(
489            table_schema.name.as_bytes(),
490            &start_key,
491            &mut |key, value| {
492                seen += 1;
493                if allow_rebuild && index.tail_is_stale(snapshot_max.saturating_add(seen)) {
494                    over_threshold = true;
495                    return Ok(false);
496                }
497                let row = decode_full_row(table_schema, key, value)?;
498                if !self.tail_passes_pushable(&row, table_schema) {
499                    return Ok(true);
500                }
501                if let Some(expr) = &self.residual {
502                    let ctx = EvalCtx::new(&col_map, &row);
503                    if !is_truthy(&eval_expr(expr, &ctx)?) {
504                        return Ok(true);
505                    }
506                }
507                let dist = match &row[self.col_idx] {
508                    Value::Vector(v) => match tail_distance(self.metric, &self.query_vec, v) {
509                        Some(d) => d,
510                        None => return Ok(true), // undefined distance (zero vector under cosine)
511                    },
512                    Value::Null => return Ok(true), // null vectors are unindexable
513                    _ => {
514                        return Err(SqlError::InvalidValue(
515                            "ANN column produced non-vector value".into(),
516                        ))
517                    }
518                };
519                out.push((dist, decode_pk_integer(key)?, row));
520                Ok(true)
521            },
522        )?;
523
524        if over_threshold {
525            return Ok(None);
526        }
527        Ok(Some(out))
528    }
529
530    /// Pushable conjuncts checked on decoded tail values (the tail has no PRISM codes).
531    fn tail_passes_pushable(&self, row: &[Value], table_schema: &TableSchema) -> bool {
532        for (dim, values) in &self.pushable {
533            let col = self.filter_cols[*dim] as usize;
534            let coll = table_schema.columns[col].collation;
535            let mut canon_row = Vec::with_capacity(16);
536            encode_key_value_collated_into(&row[col], coll, &mut canon_row);
537            let matched = values.iter().any(|v| {
538                let mut canon_v = Vec::with_capacity(16);
539                encode_key_value_collated_into(v, coll, &mut canon_v);
540                canon_v == canon_row
541            });
542            if !matched {
543                return false;
544            }
545        }
546        true
547    }
548
549    fn load_or_build_index(
550        &self,
551        txn: &mut dyn AnnScan,
552        schema: &SchemaManager,
553        cache_key: &str,
554        table_schema: &TableSchema,
555    ) -> Result<Option<Arc<CachedAnnIndex>>> {
556        if let Some(existing) = lookup_cached(schema, cache_key, &table_schema.name)? {
557            return Ok(Some(existing));
558        }
559        let spec = AnnSpec {
560            col_idx: self.col_idx,
561            dim: self.dim,
562            metric: self.metric,
563            filter_cols: self.filter_cols.clone(),
564        };
565        load_or_build(txn, schema, cache_key, table_schema, &spec)
566    }
567}
568
/// The index identity build/load/persist operates on, resolved from the statement
/// (`AnnTopKPlan`) or the declared index (`persist_ann_index`).
pub(super) struct AnnSpec {
    /// Schema index of the indexed `VECTOR` column (must be a non-PK column).
    pub col_idx: usize,
    /// Vector dimension; compared against a persisted segment's header.
    pub dim: u16,
    /// Distance metric the index is built for.
    pub metric: AnnMetric,
    /// Schema column indices of the filterable attributes, in attr-dim order.
    pub filter_cols: Vec<u16>,
}
577
578impl AnnSpec {
579    fn metric_tag(&self) -> u8 {
580        citadel_vector::segment::metric_tag(ann_metric_to_prism(self.metric))
581    }
582}
583
/// One scan pass: build rows, filter dicts (codes in first-seen order), and the
/// injective content fingerprint; the single decode path for build/persist/load.
struct ScanOutcome {
    /// `(row id, vector, per-attribute dict code)` for every row whose vector is non-NULL.
    rows: Vec<(u64, Vec<f32>, Vec<u32>)>,
    /// Per attribute dim: collation-canonical encoded value -> dense code.
    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
    /// Fingerprint over every scanned row, including rows with a NULL vector.
    fingerprint: [u8; 32],
}
591
/// Single pass over `table_schema`'s rows, collecting what a build, load, or
/// persist of `spec` needs.
///
/// Every row has its vector column and filter columns decoded. It is then fed
/// into the content fingerprint as (key, little-endian vector bytes, raw filter
/// encodings). Rows with a non-NULL vector are also recorded as
/// `(pk as u64, vector, per-attribute dict codes)`.
///
/// # Errors
/// `SqlError::InvalidValue` when `spec.col_idx` is not a non-PK column, or when
/// the column decodes to something other than a vector or NULL. Extract-plan,
/// decode, and scan errors propagate.
fn scan_rows(
    txn: &mut dyn AnnScan,
    table_schema: &TableSchema,
    spec: &AnnSpec,
) -> Result<ScanOutcome> {
    let non_pk = table_schema.non_pk_indices();
    let enc_pos = table_schema.encoding_positions();
    // Locate the vector column's slot inside the encoded non-PK row value.
    let nonpk_order = non_pk
        .iter()
        .position(|&i| i == spec.col_idx)
        .ok_or_else(|| {
            SqlError::InvalidValue("vector column must be non-PK for ANN build".into())
        })?;
    let enc_idx = enc_pos[nonpk_order] as usize;

    let num_attrs = spec.filter_cols.len();
    let extracts: Vec<Extract> = spec
        .filter_cols
        .iter()
        .map(|&c| extract_plan(c, table_schema, non_pk, enc_pos))
        .collect::<Result<_>>()?;
    // Dict keys are collation-canonical so collation-equal values share a code (matching
    // eval equality); the fingerprint keeps raw encodings to still detect content edits.
    let collations: Vec<Collation> = spec
        .filter_cols
        .iter()
        .map(|&c| table_schema.columns[c as usize].collation)
        .collect();
    let mut dicts: Vec<FxHashMap<Vec<u8>, u32>> = vec![FxHashMap::default(); num_attrs];
    // Seed the fingerprint with the index identity: table, vector column, filter
    // columns, dim, and metric tag.
    let mut fp = ann_persist::FingerprintHasher::new(
        &table_schema.name,
        spec.col_idx as u32,
        &spec
            .filter_cols
            .iter()
            .map(|&c| c as u32)
            .collect::<Vec<_>>(),
        spec.dim,
        spec.metric_tag(),
    );
    let mut rows: Vec<(u64, Vec<f32>, Vec<u32>)> = Vec::new();

    txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
        let vector = match decode_column_raw(value, enc_idx)?.to_value() {
            Value::Vector(arr) => Some(arr.to_vec()),
            Value::Null => None, // null vectors are content, but not indexed
            _ => {
                return Err(SqlError::InvalidValue(
                    "ANN column produced non-vector value".into(),
                ))
            }
        };
        let mut filter_vals: Vec<Value> = Vec::with_capacity(num_attrs);
        for ex in &extracts {
            filter_vals.push(ex.extract(key, value)?);
        }
        let encoded_filters: Vec<Vec<u8>> = filter_vals.iter().map(encode_key_value).collect();
        // A NULL vector contributes an empty byte string; every row is fingerprinted,
        // indexed or not, and the call order here must stay fixed per row.
        let vec_bytes: Vec<u8> = vector
            .as_deref()
            .unwrap_or(&[])
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        fp.row(
            key,
            &vec_bytes,
            &encoded_filters
                .iter()
                .map(Vec::as_slice)
                .collect::<Vec<_>>(),
        );
        let Some(vector) = vector else {
            return Ok(true);
        };
        // `as u64` reinterprets negative pks as large ids (two's complement).
        let id = decode_pk_integer(key)? as u64;
        let mut codes: Vec<u32> = Vec::with_capacity(num_attrs);
        for (j, v) in filter_vals.iter().enumerate() {
            let mut canon = Vec::with_capacity(16);
            encode_key_value_collated_into(v, collations[j], &mut canon);
            // Dense codes in first-seen order: a new value gets the current dict size.
            let next = dicts[j].len() as u32;
            codes.push(*dicts[j].entry(canon).or_insert(next));
        }
        rows.push((id, vector, codes));
        Ok(true)
    })?;

    Ok(ScanOutcome {
        rows,
        dicts,
        fingerprint: fp.finish(),
    })
}
684
685/// Count a full O(N) rebuild; thrash tests assert this stays 0 on pure appends.
686#[cfg(test)]
687fn note_ann_rebuild() {
688    ANN_REBUILD_COUNT.with(|c| c.set(c.get() + 1));
689}
690
691#[cfg(test)]
692thread_local! {
693    static ANN_REBUILD_COUNT: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
694}
695
696#[cfg(test)]
697pub(super) fn take_ann_rebuilds() -> u64 {
698    ANN_REBUILD_COUNT.with(|c| c.replace(0))
699}
700
701/// Build the index from a scan; `None` if there are no indexable rows.
702fn build_index(
703    txn: &mut dyn AnnScan,
704    table_schema: &TableSchema,
705    spec: &AnnSpec,
706    refusal: Option<String>,
707    cached_gen: u64,
708) -> Result<Option<CachedAnnIndex>> {
709    let outcome = scan_rows(txn, table_schema, spec)?;
710    if outcome.rows.is_empty() {
711        return Ok(None);
712    }
713    let index = AnnIndex::build_with_attrs(
714        outcome.rows,
715        spec.filter_cols.len(),
716        ann_metric_to_prism(spec.metric),
717        spec.dim,
718    )
719    .map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
720    #[cfg(test)]
721    note_ann_rebuild();
722    Ok(Some(CachedAnnIndex {
723        index,
724        dicts: outcome.dicts,
725        source: AnnIndexSource::Built { refusal },
726        cached_gen,
727    }))
728}
729
/// Outcome of a persisted-segment load. `Refused` triggers a rebuild; corrupt
/// segments also warn (HMAC-authenticated page + failing BLAKE3 = writer bug).
enum LoadOutcome {
    /// The segment passed every check and is ready to serve.
    Loaded(Box<CachedAnnIndex>),
    /// No segment header could be read for the table ("never persisted").
    NoSegment,
    /// A check failed. `reason` is kept as a diagnostic; `corrupt` marks integrity
    /// failures, as opposed to staleness, version, config, or identity drift.
    Refused { reason: String, corrupt: bool },
}
737
738/// Try to serve the table's persisted segment: header pins, body decode, and the
739/// table-root freshness gate confirming it matches this snapshot.
740fn try_load_segment(
741    txn: &mut dyn AnnScan,
742    table_schema: &TableSchema,
743    spec: &AnnSpec,
744    cached_gen: u64,
745) -> Result<LoadOutcome> {
746    let seg_table = ann_persist::segment_table_name(&table_schema.name);
747    let header_bytes = match txn.ann_get(&seg_table, &ann_persist::segment_key(0)) {
748        Ok(Some(b)) => b,
749        // Missing tree and missing header are both "never persisted".
750        Ok(None) | Err(_) => return Ok(LoadOutcome::NoSegment),
751    };
752    let refuse = |reason: String, corrupt: bool| Ok(LoadOutcome::Refused { reason, corrupt });
753    let header = match ann_persist::SegmentHeader::decode(&header_bytes) {
754        Ok(h) => h,
755        Err(e) => return refuse(format!("header: {e}"), true),
756    };
757    if header.format_version != ann_persist::ANNSEG_FORMAT_VERSION {
758        return refuse(
759            format!("format v{} (this binary reads v2)", header.format_version),
760            false,
761        );
762    }
763    let active_cfg = citadel_vector::segment::prism_config_hash(&AnnIndex::active_config(
764        ann_metric_to_prism(spec.metric),
765    ));
766    if header.prism_config_hash != active_cfg {
767        return refuse(
768            "PRISM config drift (segment built by another geometry)".into(),
769            false,
770        );
771    }
772    if header.dim != spec.dim
773        || header.metric_tag != spec.metric_tag()
774        || header.col_idx != spec.col_idx as u32
775        || header.filter_cols
776            != spec
777                .filter_cols
778                .iter()
779                .map(|&c| c as u32)
780                .collect::<Vec<_>>()
781    {
782        return refuse(
783            "index identity mismatch (column/metric/filter set)".into(),
784            false,
785        );
786    }
787
788    let mut body = Vec::new();
789    for chunk_no in 1..=header.chunk_count {
790        match txn.ann_get(&seg_table, &ann_persist::segment_key(chunk_no)) {
791            Ok(Some(c)) => body.extend_from_slice(&c),
792            _ => return refuse(format!("missing chunk {chunk_no}"), true),
793        }
794    }
795    if *blake3::hash(&body).as_bytes() != header.segment_b3 {
796        return refuse("segment body BLAKE3 mismatch (corrupt)".into(), true);
797    }
798    let parts = match citadel_vector::segment::decode(&body) {
799        Ok(p) => p,
800        Err(e) => return refuse(format!("segment decode: {e}"), true),
801    };
802    if parts.n() as u64 != header.n || parts.dim() != header.dim {
803        return refuse("segment body disagrees with header counts".into(), true);
804    }
805
806    // CoW freshness gate: a committed DML rewrites the root, so live root != stamp means stale.
807    match txn.ann_table_root(table_schema.name.as_bytes()) {
808        Some(live) if live == header.table_root => {}
809        _ => {
810            return refuse(
811                "stale: table root moved since the segment was persisted".into(),
812                false,
813            )
814        }
815    }
816
817    // Vectors ride in the segment (TAG_VECTORS), so the load is a bulk read, no rescan.
818    let index = parts.into_index_embedded();
819    Ok(LoadOutcome::Loaded(Box::new(CachedAnnIndex {
820        index,
821        dicts: header.dict_maps(),
822        source: AnnIndexSource::Loaded {
823            segment_b3: header.segment_b3,
824        },
825        cached_gen,
826    })))
827}
828
829/// Shared load-then-build flow: try the segment, else scan-build carrying the refusal
830/// as a diagnostic; cache only if no DML committed past the snapshot, never from a write txn.
831fn load_or_build(
832    txn: &mut dyn AnnScan,
833    schema: &SchemaManager,
834    cache_key: &str,
835    table_schema: &TableSchema,
836    spec: &AnnSpec,
837) -> Result<Option<Arc<CachedAnnIndex>>> {
838    let gen = txn.cache_generation();
839    let cached_gen = gen.unwrap_or(u64::MAX);
840    let loaded = match try_load_segment(txn, table_schema, spec, cached_gen)? {
841        LoadOutcome::Loaded(c) => Some(*c),
842        LoadOutcome::NoSegment => None,
843        LoadOutcome::Refused { reason, corrupt } => {
844            if corrupt {
845                eprintln!(
846                    "citadel-sql: ANN segment for `{}` REFUSED as corrupt ({reason}); \
847                     rebuilding from scan - investigate before re-persisting",
848                    table_schema.name
849                );
850            }
851            // Stale/drift refusals are the expected degradation; the reason stays queryable on the rebuild.
852            match build_index(txn, table_schema, spec, Some(reason), cached_gen)? {
853                Some(c) => Some(c),
854                None => return Ok(None),
855            }
856        }
857    };
858    let built = match loaded {
859        Some(c) => c,
860        None => match build_index(txn, table_schema, spec, None, cached_gen)? {
861            Some(c) => c,
862            None => return Ok(None),
863        },
864    };
865    let arc: Arc<CachedAnnIndex> = Arc::new(built);
866    if gen.is_none() {
867        // A write-txn view may include uncommitted rows: serve, never cache.
868        return Ok(Some(arc));
869    }
870    let mut guard = schema.sql_caches.lock();
871    if let Some(existing) = guard.get(cache_key) {
872        // Another thread won the race; prefer that one and drop ours.
873        return Arc::clone(existing)
874            .downcast::<CachedAnnIndex>()
875            .map(Some)
876            .map_err(|_| {
877                SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}"))
878            });
879    }
880    let marker = marker_gen_locked(&guard, &table_schema.name);
881    if marker.is_some_and(|g| arc.cached_gen < g) {
882        // DML committed during the build: a superseded snapshot. Serve this query, decline the cache.
883        return Ok(Some(arc));
884    }
885    let as_any: Arc<dyn Any + Send + Sync> = arc.clone();
886    guard.insert(cache_key.to_string(), as_any);
887    Ok(Some(arc))
888}
889
890/// Streaming brute-force top-k for `ORDER BY <distance> LIMIT k` when no ANN
891/// index applies (or inside a write txn); bounded heap, O(k) memory.
892pub(super) struct VectorTopKPlan {
893    order_expr: Expr,
894    where_clause: Option<Expr>,
895    k: usize,
896    offset: usize,
897    nulls_first: bool,
898}
899
900/// A candidate keyed by (distance, scan position); `seq` breaks ties by scan
901/// order so the bounded heap matches the stable sort.
902struct Ranked {
903    dist: f64,
904    seq: u64,
905    row: Vec<Value>,
906}
907
908impl PartialEq for Ranked {
909    fn eq(&self, other: &Self) -> bool {
910        self.cmp(other) == Ordering::Equal
911    }
912}
913impl Eq for Ranked {}
914impl PartialOrd for Ranked {
915    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
916        Some(self.cmp(other))
917    }
918}
919impl Ord for Ranked {
920    fn cmp(&self, other: &Self) -> Ordering {
921        self.dist
922            .total_cmp(&other.dist)
923            .then_with(|| self.seq.cmp(&other.seq))
924    }
925}
926
927impl VectorTopKPlan {
928    pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
929        if !topk_shape_ok(stmt) {
930            return Ok(None);
931        }
932        let ob = &stmt.order_by[0];
933        let Expr::BinaryOp { left, op, .. } = &ob.expr else {
934            return Ok(None);
935        };
936        if !matches!(
937            op,
938            BinOp::VectorL2 | BinOp::VectorInner | BinOp::VectorCosine
939        ) {
940            return Ok(None);
941        }
942        // Only claim a vector-distance sort key; anything else uses the general path.
943        let Expr::Column(name) = left.as_ref() else {
944            return Ok(None);
945        };
946        let name = name.to_ascii_lowercase();
947        let is_vector_col = table_schema.columns.iter().any(|c| {
948            c.name.to_ascii_lowercase() == name && matches!(c.data_type, DataType::Vector { .. })
949        });
950        if !is_vector_col {
951            return Ok(None);
952        }
953
954        let k = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
955        if k == 0 {
956            return Ok(None);
957        }
958        let offset = stmt
959            .offset
960            .as_ref()
961            .map(eval_const_int)
962            .transpose()?
963            .unwrap_or(0)
964            .max(0) as usize;
965
966        Ok(Some(Self {
967            order_expr: ob.expr.clone(),
968            where_clause: stmt.where_clause.clone(),
969            k,
970            offset,
971            // citadel defaults to NULLS FIRST for ascending order.
972            nulls_first: ob.nulls_first.unwrap_or(true),
973        }))
974    }
975
976    pub(super) fn execute(
977        &self,
978        txn: &mut dyn AnnScan,
979        table_schema: &TableSchema,
980        stmt: &SelectStmt,
981    ) -> Result<ExecutionResult> {
982        let want = self.k.saturating_add(self.offset);
983        let col_map = ColumnMap::new(&table_schema.columns);
984        // NULL distances sort like NULLs under the requested ordering.
985        let null_dist = if self.nulls_first {
986            f64::NEG_INFINITY
987        } else {
988            f64::INFINITY
989        };
990        let mut heap: BinaryHeap<Ranked> = BinaryHeap::new();
991        let mut seq: u64 = 0;
992
993        txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
994            let row = decode_full_row(table_schema, key, value)?;
995            let ctx = EvalCtx::new(&col_map, &row);
996            if let Some(w) = &self.where_clause {
997                if !is_truthy(&eval_expr(w, &ctx)?) {
998                    return Ok(true);
999                }
1000            }
1001            let dist = match eval_expr(&self.order_expr, &ctx)? {
1002                Value::Real(d) => d,
1003                Value::Integer(i) => i as f64,
1004                Value::Null => null_dist,
1005                other => {
1006                    return Err(SqlError::InvalidValue(format!(
1007                        "ORDER BY vector distance produced a non-numeric {}",
1008                        other.data_type()
1009                    )))
1010                }
1011            };
1012            let cand = Ranked { dist, seq, row };
1013            seq += 1;
1014            // `seq` only grows, so ties never evict an earlier row (stable-sort order).
1015            if heap.len() < want {
1016                heap.push(cand);
1017            } else if heap.peek().is_some_and(|top| cand < *top) {
1018                heap.pop();
1019                heap.push(cand);
1020            }
1021            Ok(true)
1022        })?;
1023
1024        let mut rows: Vec<Vec<Value>> = heap.into_sorted_vec().into_iter().map(|r| r.row).collect();
1025        if self.offset >= rows.len() {
1026            rows.clear();
1027        } else if self.offset > 0 {
1028            rows = rows.split_off(self.offset);
1029        }
1030        rows.truncate(self.k);
1031
1032        let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
1033        Ok(ExecutionResult::Query(QueryResult {
1034            columns: col_names,
1035            rows: projected,
1036        }))
1037    }
1038}
1039
1040/// How to read a filter column's value out of a raw row during the build scan.
1041enum Extract {
1042    /// The single integer primary key, read from the row key.
1043    Pk,
1044    /// A non-PK column at the given encoding position in the row value.
1045    NonPk(usize),
1046}
1047
1048impl Extract {
1049    fn extract(&self, key: &[u8], value: &[u8]) -> Result<Value> {
1050        match self {
1051            Extract::Pk => Ok(Value::Integer(decode_pk_integer(key)?)),
1052            Extract::NonPk(ei) => Ok(decode_column_raw(value, *ei)?.to_value()),
1053        }
1054    }
1055}
1056
1057fn extract_plan(
1058    col: u16,
1059    table_schema: &TableSchema,
1060    non_pk: &[usize],
1061    enc_pos: &[u16],
1062) -> Result<Extract> {
1063    if table_schema.primary_key_columns.contains(&col) {
1064        return Ok(Extract::Pk);
1065    }
1066    let order = non_pk
1067        .iter()
1068        .position(|&i| i == col as usize)
1069        .ok_or_else(|| SqlError::InvalidValue("ANN filter column not found in row".into()))?;
1070    Ok(Extract::NonPk(enc_pos[order] as usize))
1071}
1072
1073/// Walk the AND-tree, sorting each leaf into a pushable attribute predicate or
1074/// the recheck residual.
1075fn split_where(
1076    expr: &Expr,
1077    filter_cols: &[u16],
1078    table_schema: &TableSchema,
1079    pushable: &mut Vec<(usize, Vec<Value>)>,
1080    residual: &mut Vec<Expr>,
1081) {
1082    if let Expr::BinaryOp {
1083        left,
1084        op: BinOp::And,
1085        right,
1086    } = expr
1087    {
1088        split_where(left, filter_cols, table_schema, pushable, residual);
1089        split_where(right, filter_cols, table_schema, pushable, residual);
1090        return;
1091    }
1092    match classify_leaf(expr, filter_cols, table_schema) {
1093        Some(constraint) => pushable.push(constraint),
1094        None => residual.push(expr.clone()),
1095    }
1096}
1097
1098/// Outcome of coercing a pushdown literal to the filter column's stored type.
1099enum Coerced {
1100    /// Encodes exactly like a stored value; safe for the dictionary lookup.
1101    Exact(Value),
1102    /// Can never equal any stored value of this column (e.g. a fractional
1103    /// literal vs INTEGER); contributes no codes.
1104    NeverMatches,
1105    /// Eval equality may diverge from encoded-byte equality (NULL three-valued
1106    /// logic, cross-type comparisons, floats past 2^53); the whole leaf must
1107    /// stay in the residual so the eval path decides.
1108    Residual,
1109}
1110
1111fn coerce_pushdown_literal(val: Value, col_type: DataType) -> Coerced {
1112    // Past 2^53 int<->f64 is not 1:1, so encoded and numeric equality diverge.
1113    const EXACT_F64_INT: f64 = 9_007_199_254_740_992.0;
1114    if val.is_null() {
1115        return Coerced::Residual;
1116    }
1117    if val.data_type() == col_type {
1118        return Coerced::Exact(val);
1119    }
1120    match (val, col_type) {
1121        (Value::Real(r), DataType::Integer) => {
1122            if r.is_nan() || r.is_infinite() {
1123                Coerced::NeverMatches
1124            } else if r.abs() > EXACT_F64_INT {
1125                Coerced::Residual
1126            } else if r.fract() == 0.0 {
1127                Coerced::Exact(Value::Integer(r as i64))
1128            } else {
1129                Coerced::NeverMatches
1130            }
1131        }
1132        (Value::Integer(i), DataType::Real) => {
1133            if i.unsigned_abs() <= EXACT_F64_INT as u64 {
1134                Coerced::Exact(Value::Real(i as f64))
1135            } else {
1136                Coerced::Residual
1137            }
1138        }
1139        _ => Coerced::Residual,
1140    }
1141}
1142
1143/// A leaf is pushable if it is `col = literal` or `col IN (literal, ...)` on a
1144/// declared filter column whose constant right-hand side coerces exactly to
1145/// the column's stored type. An empty value list means the leaf is provably
1146/// unsatisfiable (the caller short-circuits to an empty result).
1147fn classify_leaf(
1148    leaf: &Expr,
1149    filter_cols: &[u16],
1150    table_schema: &TableSchema,
1151) -> Option<(usize, Vec<Value>)> {
1152    let (col_expr, rhs): (&Expr, Vec<&Expr>) = match leaf {
1153        Expr::BinaryOp {
1154            left,
1155            op: BinOp::Eq,
1156            right,
1157        } => (left, vec![right.as_ref()]),
1158        Expr::InList {
1159            expr,
1160            list,
1161            negated: false,
1162        } => (expr, list.iter().collect()),
1163        _ => return None,
1164    };
1165    let dim = filter_dim(col_expr, filter_cols, table_schema)?;
1166    let col_type = table_schema.columns[filter_cols[dim] as usize].data_type;
1167    let mut vals = Vec::with_capacity(rhs.len());
1168    for e in rhs {
1169        match coerce_pushdown_literal(eval_const_expr(e).ok()?, col_type) {
1170            Coerced::Exact(v) => vals.push(v),
1171            Coerced::NeverMatches => {}
1172            Coerced::Residual => return None,
1173        }
1174    }
1175    Some((dim, vals))
1176}
1177
1178/// Resolve a column expression to its attribute-dim index (position in
1179/// `filter_cols`), or `None` if it is not a declared filter column.
1180fn filter_dim(expr: &Expr, filter_cols: &[u16], table_schema: &TableSchema) -> Option<usize> {
1181    let name = match expr {
1182        Expr::Column(c) => c.to_ascii_lowercase(),
1183        Expr::QualifiedColumn { column, .. } => column.to_ascii_lowercase(),
1184        _ => return None,
1185    };
1186    let col_idx = table_schema
1187        .columns
1188        .iter()
1189        .position(|c| c.name.to_ascii_lowercase() == name)? as u16;
1190    filter_cols.iter().position(|&c| c == col_idx)
1191}
1192
1193fn fold_and(mut leaves: Vec<Expr>) -> Option<Expr> {
1194    if leaves.is_empty() {
1195        return None;
1196    }
1197    let first = leaves.remove(0);
1198    Some(leaves.into_iter().fold(first, |acc, e| Expr::BinaryOp {
1199        left: Box::new(acc),
1200        op: BinOp::And,
1201        right: Box::new(e),
1202    }))
1203}
1204
1205fn empty_result(table_schema: &TableSchema, stmt: &SelectStmt) -> Result<ExecutionResult> {
1206    let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, Vec::new())?;
1207    Ok(ExecutionResult::Query(QueryResult {
1208        columns: col_names,
1209        rows: projected,
1210    }))
1211}
1212
1213/// Freeze behind `Connection::persist_ann_index`: one write txn scans the table
1214/// (computing the fingerprint), builds PRISM, serializes + replaces the segment, and
1215/// commits (atomic by shadow paging). Holds the writer lock for the full build (minutes
1216/// on large tables) - an offline operation. Warms the shared cache so the next attach
1217/// loads fast and this process serves queries immediately.
1218pub(crate) fn persist_ann_index(
1219    db: &citadel::Database,
1220    schema: &SchemaManager,
1221    table_schema: &TableSchema,
1222    column: &str,
1223) -> Result<ann_persist::AnnSegmentInfo> {
1224    let col_lower = column.to_ascii_lowercase();
1225    let col_idx = table_schema
1226        .columns
1227        .iter()
1228        .position(|c| c.name == col_lower)
1229        .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
1230    let DataType::Vector { dim } = table_schema.columns[col_idx].data_type else {
1231        return Err(SqlError::InvalidValue(format!(
1232            "column `{column}` is not VECTOR(N)"
1233        )));
1234    };
1235    // Same admission as AnnTopKPlan::try_new: an unservable table gets no segment
1236    // (dead weight with mis-decoded row ids).
1237    if table_schema.primary_key_columns.len() != 1
1238        || !matches!(
1239            table_schema.columns[table_schema.primary_key_columns[0] as usize].data_type,
1240            DataType::Integer
1241        )
1242    {
1243        return Err(SqlError::InvalidValue(
1244            "ANN persistence requires a single INTEGER primary key (same rule as the \
1245             ANN query plan)"
1246                .into(),
1247        ));
1248    }
1249    let ann_index = table_schema
1250        .indices
1251        .iter()
1252        .find(|ix| {
1253            matches!(ix.kind, IndexKind::Inverted(InvertedKind::Ann { .. }))
1254                && ix.keys.len() == 1
1255                && matches!(ix.keys[0], IndexKey::Column { idx, .. } if idx as usize == col_idx)
1256        })
1257        .ok_or_else(|| SqlError::InvalidValue(format!("no ANN index declared on `{column}`")))?;
1258    let IndexKind::Inverted(InvertedKind::Ann { metric }) = ann_index.kind else {
1259        unreachable!("matched above");
1260    };
1261    let spec = AnnSpec {
1262        col_idx,
1263        dim,
1264        metric,
1265        filter_cols: ann_index.ann_filter_cols.clone(),
1266    };
1267
1268    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
1269    let outcome = scan_rows(&mut wtx, table_schema, &spec)?;
1270    if outcome.rows.is_empty() {
1271        return Err(SqlError::InvalidValue(
1272            "nothing to persist: the table has no indexable (non-NULL) vectors".into(),
1273        ));
1274    }
1275    let n = outcome.rows.len() as u64;
1276    let index = AnnIndex::build_with_attrs(
1277        outcome.rows,
1278        spec.filter_cols.len(),
1279        ann_metric_to_prism(spec.metric),
1280        spec.dim,
1281    )
1282    .map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
1283
1284    let body = citadel_vector::segment::encode(&index);
1285    let segment_b3 = *blake3::hash(&body).as_bytes();
1286    // Order dict entries by code; codes are first-seen order, so by-code is scan order.
1287    let dicts_ordered: Vec<Vec<(Vec<u8>, u32)>> = outcome
1288        .dicts
1289        .iter()
1290        .map(|d| {
1291            let mut entries: Vec<(Vec<u8>, u32)> = d.iter().map(|(k, &v)| (k.clone(), v)).collect();
1292            entries.sort_by_key(|&(_, code)| code);
1293            entries
1294        })
1295        .collect();
1296    // Stamp the table's CoW root; the loader refuses a segment whose root != live.
1297    let table_root = wtx
1298        .table_root_page(table_schema.name.as_bytes())
1299        .map_err(SqlError::Storage)?
1300        .map(|p| u64::from(p.0))
1301        .ok_or_else(|| SqlError::InvalidValue("table vanished during ANN persist".into()))?;
1302    let header = ann_persist::SegmentHeader {
1303        format_version: ann_persist::ANNSEG_FORMAT_VERSION,
1304        prism_config_hash: ann_persist::active_config_hash(ann_metric_to_prism(spec.metric)),
1305        dim: spec.dim,
1306        metric_tag: spec.metric_tag(),
1307        n,
1308        snapshot_max: index.snapshot_max,
1309        table_root,
1310        col_idx: spec.col_idx as u32,
1311        filter_cols: spec.filter_cols.iter().map(|&c| c as u32).collect(),
1312        dicts: dicts_ordered,
1313        content_fingerprint: outcome.fingerprint,
1314        segment_b3,
1315        chunk_count: body.len().div_ceil(ann_persist::CHUNK_BYTES) as u32,
1316        writer: format!("citadel-sql {}", env!("CARGO_PKG_VERSION")),
1317    };
1318
1319    let seg_table = ann_persist::segment_table_name(&table_schema.name);
1320    ann_persist::purge_segment(&mut wtx, &table_schema.name)?;
1321    wtx.create_table(&seg_table).map_err(SqlError::Storage)?;
1322    wtx.table_insert(&seg_table, &ann_persist::segment_key(0), &header.encode())
1323        .map_err(SqlError::Storage)?;
1324    for (chunk_no, chunk) in ann_persist::chunks(&body) {
1325        wtx.table_insert(&seg_table, &ann_persist::segment_key(chunk_no), chunk)
1326            .map_err(SqlError::Storage)?;
1327    }
1328    wtx.commit().map_err(SqlError::Storage)?;
1329
1330    // Warm the shared cache: this index reflects the just-committed state
1331    // (single writer, so the commit is the current generation).
1332    let cached = CachedAnnIndex {
1333        index,
1334        dicts: outcome.dicts,
1335        source: AnnIndexSource::Built { refusal: None },
1336        cached_gen: db.manager().commit_generation(),
1337    };
1338    let key = cache_key(&table_schema.name, spec.col_idx, spec.metric);
1339    let as_any: Arc<dyn Any + Send + Sync> = Arc::new(cached);
1340    schema.sql_caches.lock().insert(key, as_any);
1341
1342    Ok(ann_persist::AnnSegmentInfo {
1343        segment_b3,
1344        content_fingerprint: header.content_fingerprint,
1345        n,
1346        dim: spec.dim,
1347        metric_tag: header.metric_tag,
1348        chunk_count: header.chunk_count,
1349    })
1350}
1351
1352/// The queryable identity of the index currently cached for `table.column`:
1353/// `(source, snapshot generation)`, or `None` when nothing is cached.
1354pub(crate) fn ann_cache_status(
1355    schema: &SchemaManager,
1356    table_schema: &TableSchema,
1357    column: &str,
1358) -> Result<Option<(AnnIndexSource, u64)>> {
1359    let col_lower = column.to_ascii_lowercase();
1360    let col_idx = table_schema
1361        .columns
1362        .iter()
1363        .position(|c| c.name == col_lower)
1364        .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
1365    let guard = schema.sql_caches.lock();
1366    for metric in [AnnMetric::L2, AnnMetric::Inner, AnnMetric::Cosine] {
1367        let key = cache_key(&table_schema.name, col_idx, metric);
1368        if let Some(entry) = guard.get(&key) {
1369            if let Ok(c) = Arc::clone(entry).downcast::<CachedAnnIndex>() {
1370                return Ok(Some((c.source.clone(), c.cached_gen)));
1371            }
1372        }
1373    }
1374    Ok(None)
1375}
1376
/// The per-table last-DML generation marker's cache key. Stamped by the
/// commit-time invalidation in `connection.rs`; read here to refuse any index
/// whose snapshot predates the most recent DML commit on its table.
///
/// The table name is ASCII-lowercased, like `cache_key` and `ann_appends_safe` do.
/// Callers that build the key through this function therefore agree on it however
/// they spell the table name.
pub(crate) fn ann_dml_gen_key(table_name: &str) -> String {
    format!("ann_dml_gen:{}", table_name.to_ascii_lowercase())
}
1383
1384/// Whether a pure append (smallest pk `min_pk`) can keep `table`'s cached ANN
1385/// indexes: false if any has negative pks or `min_pk <= snapshot_max` (a gap-fill).
1386pub(crate) fn ann_appends_safe(schema: &SchemaManager, table: &str, min_pk: i64) -> bool {
1387    let prefix = format!("ann:{}:", table.to_ascii_lowercase());
1388    let guard = schema.sql_caches.lock();
1389    for (key, val) in guard.iter() {
1390        if !key.starts_with(&prefix) {
1391            continue;
1392        }
1393        if let Some(cached) = val.downcast_ref::<CachedAnnIndex>() {
1394            let snap = cached.index.snapshot_max as i64;
1395            if snap < 0 || min_pk <= snap {
1396                return false;
1397            }
1398        }
1399    }
1400    true
1401}
1402
1403/// Read the marker under an already-held cache lock.
1404fn marker_gen_locked(
1405    entries: &FxHashMap<String, Arc<dyn Any + Send + Sync>>,
1406    table_name: &str,
1407) -> Option<u64> {
1408    entries
1409        .get(&ann_dml_gen_key(table_name))
1410        .and_then(|e| e.downcast_ref::<u64>())
1411        .copied()
1412}
1413
1414fn lookup_cached(
1415    schema: &SchemaManager,
1416    cache_key: &str,
1417    table_name: &str,
1418) -> Result<Option<Arc<CachedAnnIndex>>> {
1419    let mut guard = schema.sql_caches.lock();
1420    let Some(entry) = guard.get(cache_key) else {
1421        return Ok(None);
1422    };
1423    let entry = Arc::clone(entry)
1424        .downcast::<CachedAnnIndex>()
1425        .map_err(|_| SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}")))?;
1426    if marker_gen_locked(&guard, table_name).is_some_and(|g| entry.cached_gen < g) {
1427        // Entry predates a DML commit (a build that raced eviction): drop and rebuild.
1428        guard.remove(cache_key);
1429        return Ok(None);
1430    }
1431    Ok(Some(entry))
1432}
1433
1434pub(super) fn cache_key(table_name: &str, col_idx: usize, metric: AnnMetric) -> String {
1435    let tag = match metric {
1436        AnnMetric::L2 => "l2",
1437        AnnMetric::Inner => "inner",
1438        AnnMetric::Cosine => "cosine",
1439    };
1440    format!(
1441        "ann:{}:{}:{}",
1442        table_name.to_ascii_lowercase(),
1443        col_idx,
1444        tag
1445    )
1446}
1447
1448fn ann_metric_to_prism(m: AnnMetric) -> Metric {
1449    match m {
1450        AnnMetric::L2 => Metric::L2,
1451        AnnMetric::Inner => Metric::InnerProduct,
1452        AnnMetric::Cosine => Metric::Cosine,
1453    }
1454}
1455
1456#[cfg(test)]
1457mod thrash_tests {
1458    use super::take_ann_rebuilds;
1459    use crate::{Connection, ExecutionResult, Value};
1460    use citadel::{Argon2Profile, DatabaseBuilder};
1461
1462    const DIM: usize = 8;
1463
1464    fn vec_for(i: u64) -> Vec<f32> {
1465        (0..DIM)
1466            .map(|d| {
1467                let x = (i.wrapping_mul(2654435761).wrapping_add(d as u64 * 40503) % 1000) as f32;
1468                x / 1000.0
1469            })
1470            .collect()
1471    }
1472
1473    fn vec_literal(v: &[f32]) -> String {
1474        let parts: Vec<String> = v.iter().map(|x| format!("{x}")).collect();
1475        format!("'[{}]'::VECTOR({})", parts.join(", "), DIM)
1476    }
1477
1478    fn recall_ids(conn: &Connection<'_>, qvec: &[f32], k: usize) -> Vec<i64> {
1479        let sql = format!(
1480            "SELECT id FROM t WHERE category = 0 ORDER BY v <-> {} LIMIT {k}",
1481            vec_literal(qvec)
1482        );
1483        match conn.execute(&sql).unwrap() {
1484            ExecutionResult::Query(qr) => qr
1485                .rows
1486                .iter()
1487                .map(|r| match &r[0] {
1488                    Value::Integer(i) => *i,
1489                    other => panic!("expected Integer id, got {other:?}"),
1490                })
1491                .collect(),
1492            _ => panic!("expected query result"),
1493        }
1494    }
1495
1496    /// Interleaved append+recall must tail-merge, not rebuild per recall (thrash).
1497    #[test]
1498    fn interleaved_append_recall_does_not_thrash() {
1499        let dir = tempfile::tempdir().unwrap();
1500        let db = DatabaseBuilder::new(dir.path().join("test.db"))
1501            .passphrase(b"test-passphrase")
1502            .argon2_profile(Argon2Profile::Iot)
1503            .create()
1504            .unwrap();
1505        let conn = Connection::open(&db).unwrap();
1506        conn.execute(
1507            "CREATE TABLE t (id INTEGER PRIMARY KEY, category INTEGER, score REAL, v VECTOR(8))",
1508        )
1509        .unwrap();
1510        // category 0 for everything so the pushable filter keeps all rows.
1511        let base = 200u64;
1512        for i in 1..=base {
1513            conn.execute(&format!(
1514                "INSERT INTO t VALUES ({i}, 0, 1.0, {})",
1515                vec_literal(&vec_for(i))
1516            ))
1517            .unwrap();
1518        }
1519        conn.execute(
1520            "CREATE INDEX ix_v ON t USING ann (v) WITH (metric = 'l2', filters = 'category')",
1521        )
1522        .unwrap();
1523
1524        // Warm: first recall builds/loads + caches.
1525        let _ = recall_ids(&conn, &vec_for(7), 5);
1526        let _ = take_ann_rebuilds(); // reset after warm-up
1527
1528        // Each append is a unique off-grid vector, queried exactly -> it's the nearest.
1529        let appends = 10u64;
1530        let mut total_rebuilds = 0u64;
1531        for j in 0..appends {
1532            let new_id = base + 1 + j;
1533            let qvec = vec![0.50005f32 + (j as f32) * 0.0001; DIM];
1534            conn.execute(&format!(
1535                "INSERT INTO t VALUES ({new_id}, 0, 1.0, {})",
1536                vec_literal(&qvec)
1537            ))
1538            .unwrap();
1539            let ids = recall_ids(&conn, &qvec, 5);
1540            total_rebuilds += take_ann_rebuilds();
1541            assert_eq!(
1542                ids.first().copied(),
1543                Some(new_id as i64),
1544                "freshly appended exact-match row must rank #0 (I1 fresh-visibility)"
1545            );
1546        }
1547        assert_eq!(
1548            total_rebuilds, 0,
1549            "appends must not trigger PRISM rebuilds (got {total_rebuilds} over {appends} recalls = thrash)"
1550        );
1551    }
1552
1553    fn fresh_db(dir: &std::path::Path) -> citadel::Database {
1554        DatabaseBuilder::new(dir.join("t.db"))
1555            .passphrase(b"test-passphrase")
1556            .argon2_profile(Argon2Profile::Iot)
1557            .create()
1558            .unwrap()
1559    }
1560
1561    fn setup(conn: &Connection<'_>) {
1562        conn.execute(
1563            "CREATE TABLE t (id INTEGER PRIMARY KEY, category INTEGER, score REAL, v VECTOR(8))",
1564        )
1565        .unwrap();
1566    }
1567
1568    fn insert(conn: &Connection<'_>, id: u64, v: &[f32]) {
1569        conn.execute(&format!(
1570            "INSERT INTO t VALUES ({id}, 0, 1.0, {})",
1571            vec_literal(v)
1572        ))
1573        .unwrap();
1574    }
1575
1576    fn build_index(conn: &Connection<'_>) {
1577        conn.execute(
1578            "CREATE INDEX ix_v ON t USING ann (v) WITH (metric = 'l2', filters = 'category')",
1579        )
1580        .unwrap();
1581    }
1582
1583    /// I2: an in-place vector UPDATE must hard-invalidate (new vector reflected).
1584    #[test]
1585    fn inplace_vector_update_is_reflected() {
1586        let dir = tempfile::tempdir().unwrap();
1587        let db = fresh_db(dir.path());
1588        let conn = Connection::open(&db).unwrap();
1589        setup(&conn);
1590        for i in 1..=200 {
1591            insert(&conn, i, &vec_for(i));
1592        }
1593        build_index(&conn);
1594        let qvec = vec![0.50007f32; DIM];
1595        let _ = recall_ids(&conn, &vec_for(7), 5); // warm
1596        let _ = take_ann_rebuilds();
1597
1598        conn.execute(&format!(
1599            "UPDATE t SET v = {} WHERE id = 50",
1600            vec_literal(&qvec)
1601        ))
1602        .unwrap();
1603        let ids = recall_ids(&conn, &qvec, 5);
1604        assert!(
1605            take_ann_rebuilds() >= 1,
1606            "an in-place vector UPDATE must invalidate the cached index"
1607        );
1608        assert_eq!(ids.first().copied(), Some(50), "updated row must rank #0");
1609    }
1610
1611    /// A DELETE of an indexed row must hard-invalidate so the row stops appearing.
1612    #[test]
1613    fn delete_indexed_row_disappears() {
1614        let dir = tempfile::tempdir().unwrap();
1615        let db = fresh_db(dir.path());
1616        let conn = Connection::open(&db).unwrap();
1617        setup(&conn);
1618        for i in 1..=200 {
1619            insert(&conn, i, &vec_for(i));
1620        }
1621        build_index(&conn);
1622        let q = vec_for(7);
1623        let before = recall_ids(&conn, &q, 5);
1624        assert_eq!(before.first().copied(), Some(7), "id 7 is the exact match");
1625        let _ = take_ann_rebuilds();
1626
1627        conn.execute("DELETE FROM t WHERE id = 7").unwrap();
1628        let after = recall_ids(&conn, &q, 5);
1629        assert!(
1630            take_ann_rebuilds() >= 1,
1631            "a DELETE must invalidate the cached index"
1632        );
1633        assert!(
1634            !after.contains(&7),
1635            "deleted row must not appear: {after:?}"
1636        );
1637    }
1638
1639    /// I3: a gap-fill INSERT below the snapshot must hard-invalidate (tail misses it).
1640    #[test]
1641    fn gap_fill_below_snapshot_is_visible() {
1642        let dir = tempfile::tempdir().unwrap();
1643        let db = fresh_db(dir.path());
1644        let conn = Connection::open(&db).unwrap();
1645        setup(&conn);
1646        // Leave a gap at ids 51..=59; snapshot_max becomes 100.
1647        for i in 1..=50 {
1648            insert(&conn, i, &vec_for(i));
1649        }
1650        for i in 60..=100 {
1651            insert(&conn, i, &vec_for(i));
1652        }
1653        build_index(&conn);
1654        let _ = recall_ids(&conn, &vec_for(7), 5); // warm, snapshot_max = 100
1655        let _ = take_ann_rebuilds();
1656
1657        let qvec = vec![0.50009f32; DIM];
1658        insert(&conn, 55, &qvec); // gap-fill: 55 < snapshot_max
1659        let ids = recall_ids(&conn, &qvec, 5);
1660        assert!(
1661            take_ann_rebuilds() >= 1,
1662            "a gap-fill insert below snapshot must invalidate, not tail-merge"
1663        );
1664        assert_eq!(
1665            ids.first().copied(),
1666            Some(55),
1667            "gap-fill row must be visible at rank #0: {ids:?}"
1668        );
1669    }
1670
1671    /// A tail past the threshold triggers exactly one rebuild on recall.
1672    #[test]
1673    fn long_tail_triggers_single_rebuild() {
1674        let dir = tempfile::tempdir().unwrap();
1675        let db = fresh_db(dir.path());
1676        let conn = Connection::open(&db).unwrap();
1677        setup(&conn);
1678        for i in 1..=40 {
1679            insert(&conn, i, &vec_for(i));
1680        }
1681        build_index(&conn);
1682        let _ = recall_ids(&conn, &vec_for(7), 5); // warm, snapshot_max = 40, indexed_len/4 = 10
1683        let _ = take_ann_rebuilds();
1684
1685        // Append 15 rows (> indexed_len/4) with no recall between: all retained.
1686        let qvec = vec![0.50011f32; DIM];
1687        for i in 41..=55u64 {
1688            let v = if i == 55 {
1689                qvec.clone()
1690            } else {
1691                vec_for(i + 1000)
1692            };
1693            insert(&conn, i, &v);
1694        }
1695        assert_eq!(
1696            take_ann_rebuilds(),
1697            0,
1698            "appends alone must not rebuild (retained for tail merge)"
1699        );
1700
1701        let ids = recall_ids(&conn, &qvec, 5);
1702        assert_eq!(
1703            take_ann_rebuilds(),
1704            1,
1705            "a tail past the threshold must trigger exactly one rebuild on recall"
1706        );
1707        assert_eq!(
1708            ids.first().copied(),
1709            Some(55),
1710            "post-rebuild result correct"
1711        );
1712    }
1713}