use std::any::Any;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::Arc;
use citadel_txn::read_txn::ReadTxn;
use citadel_txn::write_txn::WriteTxn;
use citadel_vector::{AnnIndex, Filter, Metric};
use rustc_hash::FxHashMap;
use crate::encoding::{
decode_column_raw, decode_pk_integer, encode_int_key_into, encode_key_value,
encode_key_value_collated_into,
};
use crate::error::{Result, SqlError};
use crate::eval::{eval_expr, is_truthy, ColumnMap, EvalCtx};
use crate::parser::*;
use crate::schema::SchemaManager;
use crate::types::*;
use super::aggregate::is_aggregate_expr;
use super::ann_persist;
use super::helpers::{decode_full_row, eval_const_expr, eval_const_int, project_rows};
use super::window::has_any_window_function;
/// Result type returned by the storage layer (`citadel_core`).
type StorageResult<T> = std::result::Result<T, citadel_core::Error>;
/// SQL-level row visitor: receives the raw `(key, value)` bytes of a row and returns
/// `Ok(true)` to keep scanning, `Ok(false)` to stop, or a SQL error.
type ScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> Result<bool> + 'a;
/// Row visitor in the shape the transactions' `table_scan_from` accepts (storage
/// errors instead of SQL errors).
type RawScanRow<'a> = dyn FnMut(&[u8], &[u8]) -> StorageResult<bool> + 'a;
/// A ranked candidate: `(distance, integer primary key, decoded row values)`.
type RankedRow = (f64, i64, Vec<Value>);
/// Minimal transaction surface needed by the ANN and exact vector top-K executors.
/// It is implemented for both read and write transactions so the same code paths
/// serve either.
pub(super) trait AnnScan {
    /// Scans `table` from its first row, calling `f` per row until `f` returns
    /// `Ok(false)` or an error.
    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()>;
    /// Like `ann_scan`, but the scan is positioned at `start_key` (an encoded key).
    fn ann_scan_from(&mut self, table: &[u8], start_key: &[u8], f: &mut ScanRow<'_>) -> Result<()>;
    /// Point lookup of the raw row bytes stored under `key` in `table`.
    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>>;
    /// Commit generation observed by this transaction. Returns `None` when indexes
    /// built inside it must not be published to the shared SQL cache (the write
    /// transaction impl returns `None`).
    fn cache_generation(&self) -> Option<u64>;
    /// Root page number of `table` widened to `u64`, or `None` if it cannot be read.
    fn ann_table_root(&self, table: &[u8]) -> Option<u64>;
}
/// Runs a storage-level scan while forwarding each row to the SQL-level visitor `f`.
///
/// When `f` fails, its error is stashed, the scan is told to stop (`Ok(false)`), and
/// the stashed error is returned once the storage scan has finished. A storage error
/// from the scan itself takes precedence and is wrapped in `SqlError::Storage`.
fn bridge_scan(
    scan: impl FnOnce(&mut RawScanRow<'_>) -> StorageResult<()>,
    f: &mut ScanRow<'_>,
) -> Result<()> {
    let mut deferred: Option<SqlError> = None;
    let storage_outcome = scan(&mut |key, value| {
        f(key, value).or_else(|err| {
            deferred = Some(err);
            Ok(false)
        })
    });
    storage_outcome.map_err(SqlError::Storage)?;
    deferred.map_or(Ok(()), Err)
}
/// Read transactions expose their commit generation, so indexes built under them
/// may be shared through the SQL cache.
impl AnnScan for ReadTxn<'_> {
    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
        // An empty start key positions the scan at the very first row.
        self.ann_scan_from(table, b"", f)
    }
    fn ann_scan_from(&mut self, table: &[u8], start_key: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
        bridge_scan(|visit| self.table_scan_from(table, start_key, visit), f)
    }
    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
        let found = self.table_get(table, key);
        found.map_err(SqlError::Storage)
    }
    fn cache_generation(&self) -> Option<u64> {
        let generation = self.commit_generation();
        Some(generation)
    }
    fn ann_table_root(&self, table: &[u8]) -> Option<u64> {
        match self.table_root_page(table) {
            Ok(Some(page)) => Some(u64::from(page.0)),
            _ => None,
        }
    }
}
/// Write transactions report no cache generation, so anything built under them stays
/// private to the caller and is never published to the shared SQL cache.
impl AnnScan for WriteTxn<'_> {
    fn ann_scan(&mut self, table: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
        // An empty start key positions the scan at the very first row.
        self.ann_scan_from(table, b"", f)
    }
    fn ann_scan_from(&mut self, table: &[u8], start_key: &[u8], f: &mut ScanRow<'_>) -> Result<()> {
        bridge_scan(|visit| self.table_scan_from(table, start_key, visit), f)
    }
    fn ann_get(&mut self, table: &[u8], key: &[u8]) -> Result<Option<Vec<u8>>> {
        let found = self.table_get(table, key);
        found.map_err(SqlError::Storage)
    }
    fn cache_generation(&self) -> Option<u64> {
        // Uncommitted state must never be cached.
        None
    }
    fn ann_table_root(&self, table: &[u8]) -> Option<u64> {
        match self.table_root_page(table) {
            Ok(Some(page)) => Some(u64::from(page.0)),
            _ => None,
        }
    }
}
/// Where a cached ANN index came from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnnIndexSource {
    /// Built in memory by scanning the table. `refusal` holds the reason a persisted
    /// segment was rejected, if one existed and was refused.
    Built { refusal: Option<String> },
    /// Loaded from a persisted segment; `segment_b3` is the BLAKE3 hash of its body.
    Loaded { segment_b3: [u8; 32] },
}
/// An ANN index plus the metadata needed to query it and to judge its freshness.
struct CachedAnnIndex {
    index: AnnIndex,
    /// One map per filter column: collated key encoding of a value -> dictionary code
    /// stored as that row attribute in `index`.
    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
    source: AnnIndexSource,
    /// Commit generation the index was built/loaded at (`u64::MAX` when built inside a
    /// transaction without a generation). Compared against the table's DML marker to
    /// detect staleness.
    cached_gen: u64,
}
/// Planned approximate top-K query: `ORDER BY <vector col> <dist-op> <const vector>`
/// served from an ANN index plus an exact scan of rows newer than the index.
pub(super) struct AnnTopKPlan {
    /// Schema index of the vector column.
    col_idx: usize,
    /// Declared dimension of the vector column.
    dim: u16,
    metric: AnnMetric,
    /// The constant query vector (length == `dim`).
    query_vec: Vec<f32>,
    /// LIMIT (always > 0).
    k: usize,
    /// OFFSET (0 when absent).
    offset: usize,
    /// The ANN index's filter columns (schema column indices).
    filter_cols: Vec<u16>,
    /// Pushed-down equality / IN constraints: (position in `filter_cols`, accepted values).
    pushable: Vec<(usize, Vec<Value>)>,
    /// Remaining WHERE conjuncts, AND-ed together, evaluated per candidate row.
    residual: Option<Expr>,
}
/// Returns `true` when `stmt` has the simple shape both vector top-K plans can run.
///
/// That shape is exactly one ascending ORDER BY key and a LIMIT. It excludes GROUP BY,
/// HAVING, joins, DISTINCT, window functions, and aggregate select columns.
fn topk_shape_ok(stmt: &SelectStmt) -> bool {
    let [only_key] = stmt.order_by.as_slice() else {
        return false;
    };
    if only_key.descending || stmt.limit.is_none() {
        return false;
    }
    if !stmt.group_by.is_empty() || stmt.having.is_some() || !stmt.joins.is_empty() || stmt.distinct {
        return false;
    }
    if has_any_window_function(stmt) {
        return false;
    }
    let has_aggregate = stmt.columns.iter().any(|col| match col {
        SelectColumn::Expr { expr, .. } => is_aggregate_expr(expr),
        _ => false,
    });
    !has_aggregate
}
/// Result of one attempt at running an `AnnTopKPlan`.
enum RunOutcome {
    /// The query completed; this is its result.
    Done(ExecutionResult),
    /// The un-indexed tail was too large; the caller should evict and rebuild the index.
    Rebuild,
}
/// Exact distance between query `q` and candidate `v` under `metric`, accumulated in
/// `f64`. It is used for rows newer than the ANN snapshot.
///
/// * `L2`: Euclidean distance (square root of the summed squared differences).
/// * `Inner`: negated dot product, so a smaller value means more similar.
/// * `Cosine`: `1 - cos(q, v)`. Returns `None` when either vector has zero norm.
///
/// Components are paired with `zip`, so only the common prefix is considered.
fn tail_distance(metric: AnnMetric, q: &[f32], v: &[f32]) -> Option<f64> {
    // Widen each component pair once. Every metric accumulates in index order,
    // starting from +0.0.
    let pairs = || q.iter().zip(v).map(|(&a, &b)| (f64::from(a), f64::from(b)));
    match metric {
        AnnMetric::L2 => {
            let squared = pairs().fold(0.0f64, |acc, (a, b)| {
                let delta = a - b;
                acc + delta * delta
            });
            Some(squared.sqrt())
        }
        AnnMetric::Inner => {
            let dot = pairs().fold(0.0f64, |acc, (a, b)| acc + a * b);
            Some(-dot)
        }
        AnnMetric::Cosine => {
            let (dot, norm_q_sq, norm_v_sq) = pairs().fold(
                (0.0f64, 0.0f64, 0.0f64),
                |(d, nq, nv), (a, b)| (d + a * b, nq + a * a, nv + b * b),
            );
            let denom = norm_q_sq.sqrt() * norm_v_sq.sqrt();
            if denom == 0.0 {
                None
            } else {
                Some(1.0 - dot / denom)
            }
        }
    }
}
impl AnnTopKPlan {
    /// Plans an approximate top-K query served by an ANN index.
    ///
    /// The statement must have the shape accepted by `topk_shape_ok`, with the single
    /// ORDER BY key being `<column> <vector distance op> <expr>`, and:
    /// * `<column>` is a `VECTOR(dim)` column and `<expr>` evaluates (with no row bound)
    ///   to a vector value;
    /// * an ANN index exists on exactly that column with the operator's metric;
    /// * the table has a single INTEGER primary key;
    /// * if there is a WHERE clause, at least one top-level conjunct can be pushed down
    ///   to the index's filter columns;
    /// * the LIMIT evaluates to a positive number.
    ///
    /// Returns `Ok(None)` when any condition fails, so the caller can pick another plan.
    ///
    /// # Errors
    /// Returns `SqlError::InvalidValue` when the query vector's length differs from the
    /// column dimension. Errors from evaluating LIMIT / OFFSET are propagated.
    pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
        if !topk_shape_ok(stmt) {
            return Ok(None);
        }
        let ob = &stmt.order_by[0];
        let (col_idx, dim, op_metric, query_vec) = match &ob.expr {
            Expr::BinaryOp { left, op, right } => {
                let op_metric = match op {
                    BinOp::VectorL2 => AnnMetric::L2,
                    BinOp::VectorInner => AnnMetric::Inner,
                    BinOp::VectorCosine => AnnMetric::Cosine,
                    _ => return Ok(None),
                };
                let col_name = match left.as_ref() {
                    Expr::Column(name) => name.to_ascii_lowercase(),
                    _ => return Ok(None),
                };
                let (col_idx, dim) = match table_schema
                    .columns
                    .iter()
                    .enumerate()
                    .find(|(_, c)| c.name.to_ascii_lowercase() == col_name)
                {
                    Some((i, c)) => match c.data_type {
                        DataType::Vector { dim } => (i, dim),
                        _ => return Ok(None),
                    },
                    None => return Ok(None),
                };
                // The right operand is evaluated with an empty row: only constant
                // expressions yielding a vector qualify for the ANN plan.
                let col_map = ColumnMap::new(&table_schema.columns);
                let ctx = EvalCtx::new(&col_map, &[]);
                let v = match eval_expr(right, &ctx) {
                    Ok(Value::Vector(v)) => v,
                    _ => return Ok(None),
                };
                if v.len() != dim as usize {
                    return Err(SqlError::InvalidValue(format!(
                        "ANN query vector dim {} does not match column dim {}",
                        v.len(),
                        dim
                    )));
                }
                (col_idx, dim, op_metric, v.to_vec())
            }
            _ => return Ok(None),
        };
        let ann_index = table_schema.indices.iter().find(|ix| {
            matches!(ix.kind,
                IndexKind::Inverted(InvertedKind::Ann { metric }) if metric == op_metric
            ) && ix.keys.len() == 1
                && matches!(ix.keys[0],
                    IndexKey::Column { idx, .. } if idx as usize == col_idx
                )
        });
        let Some(ann_index) = ann_index else {
            return Ok(None);
        };
        let filter_cols = ann_index.ann_filter_cols.clone();
        // Index ids are integer primary keys, so only single-INTEGER-PK tables qualify.
        if table_schema.primary_key_columns.len() != 1 {
            return Ok(None);
        }
        let pk_col = &table_schema.columns[table_schema.primary_key_columns[0] as usize];
        if !matches!(pk_col.data_type, DataType::Integer) {
            return Ok(None);
        }
        let mut pushable: Vec<(usize, Vec<Value>)> = Vec::new();
        let mut residual_leaves: Vec<Expr> = Vec::new();
        if let Some(w) = &stmt.where_clause {
            split_where(
                w,
                &filter_cols,
                table_schema,
                &mut pushable,
                &mut residual_leaves,
            );
            // A WHERE clause with nothing the index can filter on is left to other plans.
            if pushable.is_empty() {
                return Ok(None);
            }
        }
        let residual = fold_and(residual_leaves);
        let k_limit = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
        let offset = stmt
            .offset
            .as_ref()
            .map(eval_const_int)
            .transpose()?
            .unwrap_or(0)
            .max(0) as usize;
        if k_limit == 0 {
            return Ok(None);
        }
        Ok(Some(Self {
            col_idx,
            dim,
            metric: op_metric,
            query_vec,
            k: k_limit,
            offset,
            filter_cols,
            pushable,
            residual,
        }))
    }

    /// Executes the plan against a read transaction.
    ///
    /// If the first attempt reports `RunOutcome::Rebuild` (the un-indexed tail is too
    /// large), the cached index is evicted and the query is retried with rebuild
    /// requests disabled. The loop therefore runs at most twice. When the table has no
    /// indexable vectors, an empty result with the projected column names is returned.
    pub(super) fn execute_with_read(
        &self,
        rtx: &mut ReadTxn<'_>,
        schema: &SchemaManager,
        stmt: &SelectStmt,
        table_schema: &TableSchema,
    ) -> Result<ExecutionResult> {
        let cache_key = cache_key(&table_schema.name, self.col_idx, self.metric);
        let mut force_rebuild = false;
        loop {
            if force_rebuild {
                schema.sql_caches.lock().remove(&cache_key);
            }
            let Some(cached) = self.load_or_build_index(rtx, schema, &cache_key, table_schema)?
            else {
                return empty_result(table_schema, stmt);
            };
            match self.run_query(rtx, &cached, stmt, table_schema, !force_rebuild)? {
                RunOutcome::Done(result) => return Ok(result),
                RunOutcome::Rebuild => force_rebuild = true,
            }
        }
    }

    /// Runs one attempt of the query against `cached`.
    ///
    /// Pushed-down values are translated to dictionary codes. If some constraint has no
    /// known code, no indexed row can match, so the index search is skipped, but the
    /// tail is still scanned. Index survivors and tail rows are merged, ordered by
    /// (distance, primary key), then OFFSET / LIMIT are applied and the columns are
    /// projected.
    ///
    /// Returns `RunOutcome::Rebuild` only when `allow_rebuild` is set and the tail scan
    /// asks for it.
    fn run_query(
        &self,
        txn: &mut dyn AnnScan,
        cached: &CachedAnnIndex,
        stmt: &SelectStmt,
        table_schema: &TableSchema,
        allow_rebuild: bool,
    ) -> Result<RunOutcome> {
        let mut constraints: Vec<(usize, Vec<u32>)> = Vec::with_capacity(self.pushable.len());
        let mut index_unsat = false;
        for (dim, values) in &self.pushable {
            let dict = &cached.dicts[*dim];
            let coll = table_schema.columns[self.filter_cols[*dim] as usize].collation;
            let mut codes = Vec::with_capacity(values.len());
            let mut canon = Vec::with_capacity(16);
            for v in values {
                canon.clear();
                encode_key_value_collated_into(v, coll, &mut canon);
                if let Some(&code) = dict.get(canon.as_slice()) {
                    codes.push(code);
                }
            }
            if codes.is_empty() {
                index_unsat = true;
            }
            constraints.push((*dim, codes));
        }
        let want = self.k.saturating_add(self.offset).max(1);
        let mut merged: Vec<RankedRow> = if index_unsat {
            Vec::new()
        } else {
            let filter = if constraints.is_empty() {
                Filter::none()
            } else {
                Filter::new(constraints)
            };
            self.collect_survivors(txn, &cached.index, &filter, table_schema, want)?
        };
        match self.collect_tail(txn, &cached.index, table_schema, allow_rebuild)? {
            Some(tail) => merged.extend(tail),
            None => return Ok(RunOutcome::Rebuild),
        }
        merged.sort_by(|a, b| a.0.total_cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
        let mut rows: Vec<Vec<Value>> = merged.into_iter().map(|(_, _, row)| row).collect();
        if self.offset >= rows.len() {
            rows.clear();
        } else if self.offset > 0 {
            rows = rows.split_off(self.offset);
        }
        rows.truncate(self.k);
        let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
        Ok(RunOutcome::Done(ExecutionResult::Query(QueryResult {
            columns: col_names,
            rows: projected,
        })))
    }

    /// Collects up to `want` index hits whose rows still exist and pass `residual`.
    ///
    /// The number of requested candidates starts at `want` and doubles until one of
    /// these holds: enough rows survive, the request covers every indexed vector, or
    /// the index returns fewer hits than requested. Hits whose row no longer exists are
    /// skipped.
    fn collect_survivors(
        &self,
        txn: &mut dyn AnnScan,
        index: &AnnIndex,
        filter: &Filter,
        table_schema: &TableSchema,
        want: usize,
    ) -> Result<Vec<RankedRow>> {
        let col_map = ColumnMap::new(&table_schema.columns);
        let max_target = index.indexed_len().max(1);
        let mut key_buf: Vec<u8> = Vec::with_capacity(10);
        let mut target = want;
        loop {
            target = target.min(max_target);
            let hits = index.search_filtered_default_ef(&self.query_vec, target, filter);
            // `want` comes from LIMIT + OFFSET and can be arbitrarily large (e.g.
            // LIMIT 9223372036854775807). Bound the allocation by what can actually be
            // pushed, otherwise `with_capacity` panics with a capacity overflow.
            let mut survivors: Vec<RankedRow> = Vec::with_capacity(want.min(hits.len()));
            for (id, dist) in &hits {
                // Reset explicitly so the key is correct whether or not the encoder
                // clears its output buffer itself.
                key_buf.clear();
                encode_int_key_into(*id as i64, &mut key_buf);
                let Some(row_bytes) = txn.ann_get(table_schema.name.as_bytes(), &key_buf)? else {
                    continue;
                };
                let row = decode_full_row(table_schema, &key_buf, &row_bytes)?;
                let keep = match &self.residual {
                    None => true,
                    Some(expr) => {
                        let ctx = EvalCtx::new(&col_map, &row);
                        is_truthy(&eval_expr(expr, &ctx)?)
                    }
                };
                if keep {
                    survivors.push((*dist as f64, *id as i64, row));
                    if survivors.len() >= want {
                        break;
                    }
                }
            }
            if survivors.len() >= want || target >= max_target || hits.len() < target {
                return Ok(survivors);
            }
            target = target.saturating_mul(2);
        }
    }

    /// Exactly ranks rows whose primary key is above the index's `snapshot_max`.
    ///
    /// Rows must pass the pushed-down constraints and `residual`. Rows with a NULL
    /// vector, or a zero-norm vector under cosine, are skipped. Returns `Ok(None)` when
    /// `allow_rebuild` is set and `tail_is_stale` reports the tail has grown too large.
    /// Returns an empty tail when `snapshot_max` (as `i64`) is negative or `i64::MAX`.
    fn collect_tail(
        &self,
        txn: &mut dyn AnnScan,
        index: &AnnIndex,
        table_schema: &TableSchema,
        allow_rebuild: bool,
    ) -> Result<Option<Vec<RankedRow>>> {
        let snapshot_max = index.snapshot_max;
        let first_tail_pk = match (snapshot_max as i64).checked_add(1) {
            Some(pk) if (snapshot_max as i64) >= 0 => pk,
            _ => return Ok(Some(Vec::new())),
        };
        let mut start_key: Vec<u8> = Vec::with_capacity(10);
        encode_int_key_into(first_tail_pk, &mut start_key);
        let col_map = ColumnMap::new(&table_schema.columns);
        let mut out: Vec<RankedRow> = Vec::new();
        let mut seen: u64 = 0;
        let mut over_threshold = false;
        txn.ann_scan_from(
            table_schema.name.as_bytes(),
            &start_key,
            &mut |key, value| {
                seen += 1;
                if allow_rebuild && index.tail_is_stale(snapshot_max.saturating_add(seen)) {
                    over_threshold = true;
                    return Ok(false);
                }
                let row = decode_full_row(table_schema, key, value)?;
                if !self.tail_passes_pushable(&row, table_schema) {
                    return Ok(true);
                }
                if let Some(expr) = &self.residual {
                    let ctx = EvalCtx::new(&col_map, &row);
                    if !is_truthy(&eval_expr(expr, &ctx)?) {
                        return Ok(true);
                    }
                }
                let dist = match &row[self.col_idx] {
                    Value::Vector(v) => match tail_distance(self.metric, &self.query_vec, v) {
                        Some(d) => d,
                        None => return Ok(true),
                    },
                    Value::Null => return Ok(true),
                    _ => {
                        return Err(SqlError::InvalidValue(
                            "ANN column produced non-vector value".into(),
                        ))
                    }
                };
                out.push((dist, decode_pk_integer(key)?, row));
                Ok(true)
            },
        )?;
        if over_threshold {
            return Ok(None);
        }
        Ok(Some(out))
    }

    /// Checks that `row` satisfies every pushed-down constraint.
    ///
    /// Values are compared by their collated key encoding, the same canonical form used
    /// for the index dictionaries. A constraint with no accepted values rejects every
    /// row.
    fn tail_passes_pushable(&self, row: &[Value], table_schema: &TableSchema) -> bool {
        // Scratch buffers reused across constraints and values. Allocating them per
        // value made this tail-scan hot path allocate on every comparison.
        let mut canon_row = Vec::with_capacity(16);
        let mut canon_v = Vec::with_capacity(16);
        for (dim, values) in &self.pushable {
            let col = self.filter_cols[*dim] as usize;
            let coll = table_schema.columns[col].collation;
            canon_row.clear();
            encode_key_value_collated_into(&row[col], coll, &mut canon_row);
            let matched = values.iter().any(|v| {
                canon_v.clear();
                encode_key_value_collated_into(v, coll, &mut canon_v);
                canon_v == canon_row
            });
            if !matched {
                return false;
            }
        }
        true
    }

    /// Returns the cached index for `cache_key` if it is still fresh. Otherwise it loads
    /// the persisted segment or builds from a scan (see `load_or_build`). Returns
    /// `Ok(None)` when there is nothing to index.
    fn load_or_build_index(
        &self,
        txn: &mut dyn AnnScan,
        schema: &SchemaManager,
        cache_key: &str,
        table_schema: &TableSchema,
    ) -> Result<Option<Arc<CachedAnnIndex>>> {
        if let Some(existing) = lookup_cached(schema, cache_key, &table_schema.name)? {
            return Ok(Some(existing));
        }
        let spec = AnnSpec {
            col_idx: self.col_idx,
            dim: self.dim,
            metric: self.metric,
            filter_cols: self.filter_cols.clone(),
        };
        load_or_build(txn, schema, cache_key, table_schema, &spec)
    }
}
/// Identity of an ANN index: the vector column, its dimension and metric, and the
/// filter columns whose values are stored as per-row attributes.
pub(super) struct AnnSpec {
    /// Schema index of the vector column.
    pub col_idx: usize,
    /// Vector dimension.
    pub dim: u16,
    pub metric: AnnMetric,
    /// Schema indices of the filter (attribute) columns, in attribute order.
    pub filter_cols: Vec<u16>,
}
impl AnnSpec {
    /// Single-byte metric tag from `citadel_vector::segment::metric_tag`. It is written
    /// to segment headers and fed to the content fingerprint.
    fn metric_tag(&self) -> u8 {
        citadel_vector::segment::metric_tag(ann_metric_to_prism(self.metric))
    }
}
/// Everything `scan_rows` gathers in one pass over the table.
struct ScanOutcome {
    /// `(primary key as u64, vector, per-filter-column dictionary codes)` for each row
    /// whose vector is non-NULL.
    rows: Vec<(u64, Vec<f32>, Vec<u32>)>,
    /// Per filter column: collated key encoding -> code, codes assigned in first-seen order.
    dicts: Vec<FxHashMap<Vec<u8>, u32>>,
    /// Content fingerprint over every row, including rows with a NULL vector.
    fingerprint: [u8; 32],
}
/// Scans the whole table once to produce the inputs for building an ANN index.
///
/// Every row contributes to the content fingerprint: its key, its vector bytes (empty
/// for NULL), and the `encode_key_value` encoding of each filter value. Only rows with
/// a non-NULL vector are returned in `rows`. Their filter values are mapped to
/// dictionary codes via the collated key encoding.
///
/// # Errors
/// * `SqlError::InvalidValue` if the vector column is a PK column, a filter column
///   cannot be located, or the vector column decodes to a non-vector, non-NULL value.
/// * Decoding and storage errors are propagated.
fn scan_rows(
    txn: &mut dyn AnnScan,
    table_schema: &TableSchema,
    spec: &AnnSpec,
) -> Result<ScanOutcome> {
    let non_pk = table_schema.non_pk_indices();
    let enc_pos = table_schema.encoding_positions();
    // Map the vector column's schema index to its encoded position inside the row value.
    let nonpk_order = non_pk
        .iter()
        .position(|&i| i == spec.col_idx)
        .ok_or_else(|| {
            SqlError::InvalidValue("vector column must be non-PK for ANN build".into())
        })?;
    let enc_idx = enc_pos[nonpk_order] as usize;
    let num_attrs = spec.filter_cols.len();
    let extracts: Vec<Extract> = spec
        .filter_cols
        .iter()
        .map(|&c| extract_plan(c, table_schema, non_pk, enc_pos))
        .collect::<Result<_>>()?;
    let collations: Vec<Collation> = spec
        .filter_cols
        .iter()
        .map(|&c| table_schema.columns[c as usize].collation)
        .collect();
    let mut dicts: Vec<FxHashMap<Vec<u8>, u32>> = vec![FxHashMap::default(); num_attrs];
    let mut fp = ann_persist::FingerprintHasher::new(
        &table_schema.name,
        spec.col_idx as u32,
        &spec
            .filter_cols
            .iter()
            .map(|&c| c as u32)
            .collect::<Vec<_>>(),
        spec.dim,
        spec.metric_tag(),
    );
    let mut rows: Vec<(u64, Vec<f32>, Vec<u32>)> = Vec::new();
    txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
        let vector = match decode_column_raw(value, enc_idx)?.to_value() {
            Value::Vector(arr) => Some(arr.to_vec()),
            Value::Null => None,
            _ => {
                return Err(SqlError::InvalidValue(
                    "ANN column produced non-vector value".into(),
                ))
            }
        };
        let mut filter_vals: Vec<Value> = Vec::with_capacity(num_attrs);
        for ex in &extracts {
            filter_vals.push(ex.extract(key, value)?);
        }
        // The fingerprint uses the plain (non-collated) key encoding of filter values.
        let encoded_filters: Vec<Vec<u8>> = filter_vals.iter().map(encode_key_value).collect();
        let vec_bytes: Vec<u8> = vector
            .as_deref()
            .unwrap_or(&[])
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        // Fingerprint before the NULL check so rows with NULL vectors are covered too.
        fp.row(
            key,
            &vec_bytes,
            &encoded_filters
                .iter()
                .map(Vec::as_slice)
                .collect::<Vec<_>>(),
        );
        let Some(vector) = vector else {
            return Ok(true);
        };
        // Integer PKs become index ids (negative PKs wrap when cast to u64).
        let id = decode_pk_integer(key)? as u64;
        let mut codes: Vec<u32> = Vec::with_capacity(num_attrs);
        for (j, v) in filter_vals.iter().enumerate() {
            let mut canon = Vec::with_capacity(16);
            encode_key_value_collated_into(v, collations[j], &mut canon);
            // New values get the next code, so codes are dense and in first-seen order.
            let next = dicts[j].len() as u32;
            codes.push(*dicts[j].entry(canon).or_insert(next));
        }
        rows.push((id, vector, codes));
        Ok(true)
    })?;
    Ok(ScanOutcome {
        rows,
        dicts,
        fingerprint: fp.finish(),
    })
}
/// Test-only instrumentation: records that an ANN index was built from a table scan.
#[cfg(test)]
fn note_ann_rebuild() {
    ANN_REBUILD_COUNT.with(|counter| {
        let bumped = counter.get() + 1;
        counter.set(bumped);
    });
}
#[cfg(test)]
thread_local! {
    /// Per-thread count of scan builds since the last `take_ann_rebuilds` call.
    static ANN_REBUILD_COUNT: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
}
/// Test-only: returns the number of scan builds recorded on this thread and resets
/// the counter to zero.
#[cfg(test)]
pub(super) fn take_ann_rebuilds() -> u64 {
    ANN_REBUILD_COUNT.with(|counter| counter.take())
}
fn build_index(
txn: &mut dyn AnnScan,
table_schema: &TableSchema,
spec: &AnnSpec,
refusal: Option<String>,
cached_gen: u64,
) -> Result<Option<CachedAnnIndex>> {
let outcome = scan_rows(txn, table_schema, spec)?;
if outcome.rows.is_empty() {
return Ok(None);
}
let index = AnnIndex::build_with_attrs(
outcome.rows,
spec.filter_cols.len(),
ann_metric_to_prism(spec.metric),
spec.dim,
)
.map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
#[cfg(test)]
note_ann_rebuild();
Ok(Some(CachedAnnIndex {
index,
dicts: outcome.dicts,
source: AnnIndexSource::Built { refusal },
cached_gen,
}))
}
/// Result of trying to load a persisted ANN segment.
enum LoadOutcome {
    /// The segment passed every check and was loaded.
    Loaded(Box<CachedAnnIndex>),
    /// No segment header could be read (absent, or the lookup itself failed).
    NoSegment,
    /// A segment exists but cannot be used. `corrupt` is `true` when an integrity check
    /// failed, and `false` when the segment is merely stale or incompatible.
    Refused { reason: String, corrupt: bool },
}
/// Tries to load the persisted ANN segment for `table_schema` and validate it against
/// `spec` and the live table.
///
/// The checks run in order:
/// 1. header decode;
/// 2. format version;
/// 3. PRISM configuration hash;
/// 4. index identity (dim, metric, column, filter columns);
/// 5. presence of all chunks;
/// 6. BLAKE3 of the body;
/// 7. body decode and count agreement;
/// 8. the table's root page being unchanged since persistence.
///
/// The first failing check yields `LoadOutcome::Refused`, with `corrupt` set for
/// integrity failures. A missing or unreadable header yields `LoadOutcome::NoSegment`.
fn try_load_segment(
    txn: &mut dyn AnnScan,
    table_schema: &TableSchema,
    spec: &AnnSpec,
    cached_gen: u64,
) -> Result<LoadOutcome> {
    let seg_table = ann_persist::segment_table_name(&table_schema.name);
    // Chunk 0 holds the header; a lookup error (e.g. no segment table) means "no segment".
    let header_bytes = match txn.ann_get(&seg_table, &ann_persist::segment_key(0)) {
        Ok(Some(b)) => b,
        Ok(None) | Err(_) => return Ok(LoadOutcome::NoSegment),
    };
    let refuse = |reason: String, corrupt: bool| Ok(LoadOutcome::Refused { reason, corrupt });
    let header = match ann_persist::SegmentHeader::decode(&header_bytes) {
        Ok(h) => h,
        Err(e) => return refuse(format!("header: {e}"), true),
    };
    if header.format_version != ann_persist::ANNSEG_FORMAT_VERSION {
        // Report the version this binary actually accepts rather than a hard-coded
        // number that would go stale when the format constant is bumped.
        return refuse(
            format!(
                "format v{} (this binary reads v{})",
                header.format_version,
                ann_persist::ANNSEG_FORMAT_VERSION
            ),
            false,
        );
    }
    let active_cfg = citadel_vector::segment::prism_config_hash(&AnnIndex::active_config(
        ann_metric_to_prism(spec.metric),
    ));
    if header.prism_config_hash != active_cfg {
        return refuse(
            "PRISM config drift (segment built by another geometry)".into(),
            false,
        );
    }
    if header.dim != spec.dim
        || header.metric_tag != spec.metric_tag()
        || header.col_idx != spec.col_idx as u32
        || header.filter_cols
            != spec
                .filter_cols
                .iter()
                .map(|&c| c as u32)
                .collect::<Vec<_>>()
    {
        return refuse(
            "index identity mismatch (column/metric/filter set)".into(),
            false,
        );
    }
    // Body chunks are stored under keys 1..=chunk_count and concatenated in order.
    let mut body = Vec::new();
    for chunk_no in 1..=header.chunk_count {
        match txn.ann_get(&seg_table, &ann_persist::segment_key(chunk_no)) {
            Ok(Some(c)) => body.extend_from_slice(&c),
            _ => return refuse(format!("missing chunk {chunk_no}"), true),
        }
    }
    if *blake3::hash(&body).as_bytes() != header.segment_b3 {
        return refuse("segment body BLAKE3 mismatch (corrupt)".into(), true);
    }
    let parts = match citadel_vector::segment::decode(&body) {
        Ok(p) => p,
        Err(e) => return refuse(format!("segment decode: {e}"), true),
    };
    if parts.n() as u64 != header.n || parts.dim() != header.dim {
        return refuse("segment body disagrees with header counts".into(), true);
    }
    // Any write to the table since persistence moves its root, making the segment stale.
    match txn.ann_table_root(table_schema.name.as_bytes()) {
        Some(live) if live == header.table_root => {}
        _ => {
            return refuse(
                "stale: table root moved since the segment was persisted".into(),
                false,
            )
        }
    }
    let index = parts.into_index_embedded();
    Ok(LoadOutcome::Loaded(Box::new(CachedAnnIndex {
        index,
        dicts: header.dict_maps(),
        source: AnnIndexSource::Loaded {
            segment_b3: header.segment_b3,
        },
        cached_gen,
    })))
}
/// Obtains an ANN index for `spec`, preferring a valid persisted segment and falling
/// back to a full-scan build. When the transaction has a cache generation, the result
/// is published to the shared SQL cache.
///
/// Returns `Ok(None)` when the table has no indexable (non-NULL) vectors.
///
/// Publication rules:
/// * Transactions without a generation (`cache_generation() == None`) never publish.
/// * If another caller already inserted an entry under `cache_key`, that entry is
///   returned instead.
/// * If the table's DML marker is newer than this index's generation, the index is
///   returned to the caller but not cached.
fn load_or_build(
    txn: &mut dyn AnnScan,
    schema: &SchemaManager,
    cache_key: &str,
    table_schema: &TableSchema,
    spec: &AnnSpec,
) -> Result<Option<Arc<CachedAnnIndex>>> {
    let gen = txn.cache_generation();
    // Uncacheable (write) transactions tag their index with u64::MAX.
    let cached_gen = gen.unwrap_or(u64::MAX);
    let loaded = match try_load_segment(txn, table_schema, spec, cached_gen)? {
        LoadOutcome::Loaded(c) => Some(*c),
        LoadOutcome::NoSegment => None,
        LoadOutcome::Refused { reason, corrupt } => {
            if corrupt {
                eprintln!(
                    "citadel-sql: ANN segment for `{}` REFUSED as corrupt ({reason}); \
                     rebuilding from scan - investigate before re-persisting",
                    table_schema.name
                );
            }
            // Build from scan, recording why the segment was refused.
            match build_index(txn, table_schema, spec, Some(reason), cached_gen)? {
                Some(c) => Some(c),
                None => return Ok(None),
            }
        }
    };
    let built = match loaded {
        Some(c) => c,
        None => match build_index(txn, table_schema, spec, None, cached_gen)? {
            Some(c) => c,
            None => return Ok(None),
        },
    };
    let arc: Arc<CachedAnnIndex> = Arc::new(built);
    if gen.is_none() {
        return Ok(Some(arc));
    }
    let mut guard = schema.sql_caches.lock();
    // Another caller may have populated the entry while we were building; prefer theirs.
    if let Some(existing) = guard.get(cache_key) {
        return Arc::clone(existing)
            .downcast::<CachedAnnIndex>()
            .map(Some)
            .map_err(|_| {
                SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}"))
            });
    }
    // Do not cache an index older than the table's latest recorded DML generation.
    let marker = marker_gen_locked(&guard, &table_schema.name);
    if marker.is_some_and(|g| arc.cached_gen < g) {
        return Ok(Some(arc));
    }
    let as_any: Arc<dyn Any + Send + Sync> = arc.clone();
    guard.insert(cache_key.to_string(), as_any);
    Ok(Some(arc))
}
/// Planned exact (full-scan) vector top-K query, used when no ANN plan applies.
pub(super) struct VectorTopKPlan {
    /// The ORDER BY distance expression, evaluated per row.
    order_expr: Expr,
    /// Optional WHERE predicate, evaluated per row before ranking.
    where_clause: Option<Expr>,
    /// LIMIT (always > 0).
    k: usize,
    /// OFFSET (0 when absent).
    offset: usize,
    /// Whether NULL distances rank before (`true`) or after (`false`) all others.
    nulls_first: bool,
}
/// A candidate row in the exact vector top-K scan.
///
/// Candidates are ordered by `dist` under IEEE total ordering. Ties are broken by
/// `seq` (scan order), so earlier rows win.
struct Ranked {
    dist: f64,
    seq: u64,
    row: Vec<Value>,
}
impl Ord for Ranked {
    fn cmp(&self, other: &Self) -> Ordering {
        match self.dist.total_cmp(&other.dist) {
            Ordering::Equal => self.seq.cmp(&other.seq),
            unequal => unequal,
        }
    }
}
impl PartialOrd for Ranked {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl PartialEq for Ranked {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), Ordering::Equal)
    }
}
impl Eq for Ranked {}
impl VectorTopKPlan {
    /// Plans an exact top-K for `ORDER BY <vector column> <distance op> <expr> LIMIT k
    /// [OFFSET m]`.
    ///
    /// Returns `Ok(None)` in any of these cases:
    /// * the statement does not have the top-K shape;
    /// * the left operand is not a VECTOR column of `table_schema`;
    /// * the LIMIT evaluates to 0 or less.
    ///
    /// # Errors
    /// Errors from evaluating the constant LIMIT / OFFSET expressions are propagated.
    pub(super) fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
        if !topk_shape_ok(stmt) {
            return Ok(None);
        }
        let order_item = &stmt.order_by[0];
        let left = match &order_item.expr {
            Expr::BinaryOp {
                left,
                op: BinOp::VectorL2 | BinOp::VectorInner | BinOp::VectorCosine,
                ..
            } => left,
            _ => return Ok(None),
        };
        let Expr::Column(column_name) = left.as_ref() else {
            return Ok(None);
        };
        let wanted = column_name.to_ascii_lowercase();
        let targets_vector_column = table_schema.columns.iter().any(|col| {
            matches!(col.data_type, DataType::Vector { .. })
                && col.name.to_ascii_lowercase() == wanted
        });
        if !targets_vector_column {
            return Ok(None);
        }
        let k = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
        if k == 0 {
            return Ok(None);
        }
        let offset = match stmt.offset.as_ref() {
            Some(expr) => eval_const_int(expr)?.max(0) as usize,
            None => 0,
        };
        Ok(Some(Self {
            order_expr: order_item.expr.clone(),
            where_clause: stmt.where_clause.clone(),
            k,
            offset,
            nulls_first: order_item.nulls_first.unwrap_or(true),
        }))
    }

    /// Scans the whole table, keeps the `k + offset` best rows in a bounded max-heap,
    /// then applies OFFSET / LIMIT and projects the select list.
    ///
    /// # Errors
    /// * Returns `SqlError::InvalidValue` when the distance expression yields a value
    ///   that is neither REAL, INTEGER nor NULL.
    /// * Decode, evaluation and storage errors are propagated.
    pub(super) fn execute(
        &self,
        txn: &mut dyn AnnScan,
        table_schema: &TableSchema,
        stmt: &SelectStmt,
    ) -> Result<ExecutionResult> {
        let capacity = self.k.saturating_add(self.offset);
        let col_map = ColumnMap::new(&table_schema.columns);
        // NULL distances sort before (NULLS FIRST) or after (NULLS LAST) every real one.
        let null_dist = if self.nulls_first {
            f64::NEG_INFINITY
        } else {
            f64::INFINITY
        };
        // Max-heap of the best `capacity` candidates; its top is the worst one kept.
        let mut best: BinaryHeap<Ranked> = BinaryHeap::new();
        let mut next_seq: u64 = 0;
        txn.ann_scan(table_schema.name.as_bytes(), &mut |key, value| {
            let row = decode_full_row(table_schema, key, value)?;
            let ctx = EvalCtx::new(&col_map, &row);
            let passes = match &self.where_clause {
                Some(predicate) => is_truthy(&eval_expr(predicate, &ctx)?),
                None => true,
            };
            if !passes {
                return Ok(true);
            }
            let dist = match eval_expr(&self.order_expr, &ctx)? {
                Value::Real(d) => d,
                Value::Integer(i) => i as f64,
                Value::Null => null_dist,
                other => {
                    return Err(SqlError::InvalidValue(format!(
                        "ORDER BY vector distance produced a non-numeric {}",
                        other.data_type()
                    )))
                }
            };
            let candidate = Ranked {
                dist,
                seq: next_seq,
                row,
            };
            next_seq += 1;
            if best.len() < capacity {
                best.push(candidate);
            } else if best.peek().is_some_and(|worst| candidate < *worst) {
                best.pop();
                best.push(candidate);
            }
            Ok(true)
        })?;
        let rows: Vec<Vec<Value>> = best
            .into_sorted_vec()
            .into_iter()
            .skip(self.offset)
            .take(self.k)
            .map(|ranked| ranked.row)
            .collect();
        let (columns, rows) = project_rows(&table_schema.columns, &stmt.columns, rows)?;
        Ok(ExecutionResult::Query(QueryResult { columns, rows }))
    }
}
/// Where to read one ANN filter attribute from a raw `(key, value)` row.
enum Extract {
    /// The attribute is the integer primary key, decoded from the row key.
    Pk,
    /// The attribute is a non-PK column at this encoded position in the row value.
    NonPk(usize),
}
impl Extract {
    /// Decodes the attribute described by `self` from a raw row.
    ///
    /// # Errors
    /// Decoding failures from the row key or value are propagated.
    fn extract(&self, key: &[u8], value: &[u8]) -> Result<Value> {
        let decoded = match *self {
            Extract::Pk => Value::Integer(decode_pk_integer(key)?),
            Extract::NonPk(position) => decode_column_raw(value, position)?.to_value(),
        };
        Ok(decoded)
    }
}
/// Decides how to read filter column `col` from raw rows.
///
/// A primary-key column is read from the key. Any other column is read from the value
/// at its encoded position, found via `non_pk` (non-PK column indices) and `enc_pos`
/// (their encoded positions).
///
/// # Errors
/// Returns `SqlError::InvalidValue` when `col` is neither a PK column nor listed in
/// `non_pk`.
fn extract_plan(
    col: u16,
    table_schema: &TableSchema,
    non_pk: &[usize],
    enc_pos: &[u16],
) -> Result<Extract> {
    if table_schema.primary_key_columns.contains(&col) {
        return Ok(Extract::Pk);
    }
    let wanted = usize::from(col);
    match non_pk.iter().position(|&idx| idx == wanted) {
        Some(order) => Ok(Extract::NonPk(usize::from(enc_pos[order]))),
        None => Err(SqlError::InvalidValue(
            "ANN filter column not found in row".into(),
        )),
    }
}
/// Splits a WHERE clause on its top-level `AND`s.
///
/// Each conjunct that `classify_leaf` recognises as an equality / `IN` constraint on an
/// ANN filter column is appended to `pushable`. Every other conjunct is cloned into
/// `residual`. Both lists keep left-to-right order.
fn split_where(
    expr: &Expr,
    filter_cols: &[u16],
    table_schema: &TableSchema,
    pushable: &mut Vec<(usize, Vec<Value>)>,
    residual: &mut Vec<Expr>,
) {
    match expr {
        Expr::BinaryOp {
            left,
            op: BinOp::And,
            right,
        } => {
            for side in [left, right] {
                split_where(side, filter_cols, table_schema, pushable, residual);
            }
        }
        leaf => {
            if let Some(constraint) = classify_leaf(leaf, filter_cols, table_schema) {
                pushable.push(constraint);
            } else {
                residual.push(leaf.clone());
            }
        }
    }
}
/// Outcome of coercing a WHERE literal to an ANN filter column's type.
enum Coerced {
    /// The literal converts losslessly; compare against this value.
    Exact(Value),
    /// No value of the column's type can equal the literal.
    NeverMatches,
    /// Cannot be decided by pushdown; leave the predicate to row evaluation.
    Residual,
}
/// Coerces `val` to `col_type` for index pushdown.
///
/// * NULL literals are `Residual`.
/// * Same-type literals are `Exact`.
/// * REAL -> INTEGER:
///   * non-finite or fractional values are `NeverMatches`;
///   * magnitudes above 2^53 are `Residual` (they cannot be checked exactly);
///   * other values convert exactly.
/// * INTEGER -> REAL is `Exact` up to magnitude 2^53, otherwise `Residual`.
/// * Any other combination is `Residual`.
fn coerce_pushdown_literal(val: Value, col_type: DataType) -> Coerced {
    // 2^53: every integer with at most this magnitude is exactly representable in f64.
    const EXACT_F64_INT: f64 = 9_007_199_254_740_992.0;
    if val.is_null() {
        return Coerced::Residual;
    }
    if val.data_type() == col_type {
        return Coerced::Exact(val);
    }
    match (val, col_type) {
        (Value::Real(r), DataType::Integer) if !r.is_finite() => Coerced::NeverMatches,
        (Value::Real(r), DataType::Integer) if r.abs() > EXACT_F64_INT => Coerced::Residual,
        (Value::Real(r), DataType::Integer) if r.fract() != 0.0 => Coerced::NeverMatches,
        (Value::Real(r), DataType::Integer) => Coerced::Exact(Value::Integer(r as i64)),
        (Value::Integer(i), DataType::Real) if i.unsigned_abs() > EXACT_F64_INT as u64 => {
            Coerced::Residual
        }
        (Value::Integer(i), DataType::Real) => Coerced::Exact(Value::Real(i as f64)),
        _ => Coerced::Residual,
    }
}
/// Recognises a WHERE conjunct that can be pushed down to the ANN index.
///
/// The conjunct must be `<filter column> = <const>` or
/// `<filter column> IN (<const>, ...)` (not negated). Returns the column's position in
/// `filter_cols` and the literals coerced to the column type. Literals that can never
/// match are dropped.
///
/// Returns `None` (keep the conjunct as residual) when:
/// * the shape does not match;
/// * a literal fails to evaluate as a constant;
/// * a literal cannot be decided exactly.
fn classify_leaf(
    leaf: &Expr,
    filter_cols: &[u16],
    table_schema: &TableSchema,
) -> Option<(usize, Vec<Value>)> {
    let (column_side, operands): (&Expr, Vec<&Expr>) = match leaf {
        Expr::BinaryOp {
            left,
            op: BinOp::Eq,
            right,
        } => (left, vec![right.as_ref()]),
        Expr::InList {
            expr,
            list,
            negated: false,
        } => (expr, list.iter().collect()),
        _ => return None,
    };
    let dim = filter_dim(column_side, filter_cols, table_schema)?;
    let col_type = table_schema.columns[usize::from(filter_cols[dim])].data_type;
    let mut accepted = Vec::with_capacity(operands.len());
    for operand in operands {
        let literal = eval_const_expr(operand).ok()?;
        match coerce_pushdown_literal(literal, col_type) {
            Coerced::Exact(v) => accepted.push(v),
            Coerced::NeverMatches => {}
            Coerced::Residual => return None,
        }
    }
    Some((dim, accepted))
}
/// Resolves `expr` (a plain or table-qualified column reference) to its position in
/// `filter_cols`.
///
/// The column name is matched ASCII-case-insensitively against the schema, and the
/// first schema match is used. Returns `None` when `expr` is not a column reference,
/// the name is unknown, or the column is not an ANN filter column.
fn filter_dim(expr: &Expr, filter_cols: &[u16], table_schema: &TableSchema) -> Option<usize> {
    let referenced = match expr {
        Expr::Column(c) => c.to_ascii_lowercase(),
        Expr::QualifiedColumn { column, .. } => column.to_ascii_lowercase(),
        _ => return None,
    };
    let col_idx = table_schema
        .columns
        .iter()
        .enumerate()
        .find_map(|(i, col)| (col.name.to_ascii_lowercase() == referenced).then_some(i as u16))?;
    filter_cols.iter().position(|&fc| fc == col_idx)
}
/// Combines `leaves` into one left-associated `AND` chain: `((a AND b) AND c) ...`.
/// Returns `None` for an empty list and the single leaf unchanged for a one-element list.
fn fold_and(leaves: Vec<Expr>) -> Option<Expr> {
    leaves.into_iter().reduce(|acc, next| Expr::BinaryOp {
        left: Box::new(acc),
        op: BinOp::And,
        right: Box::new(next),
    })
}
fn empty_result(table_schema: &TableSchema, stmt: &SelectStmt) -> Result<ExecutionResult> {
let (col_names, projected) = project_rows(&table_schema.columns, &stmt.columns, Vec::new())?;
Ok(ExecutionResult::Query(QueryResult {
columns: col_names,
rows: projected,
}))
}
pub(crate) fn persist_ann_index(
db: &citadel::Database,
schema: &SchemaManager,
table_schema: &TableSchema,
column: &str,
) -> Result<ann_persist::AnnSegmentInfo> {
let col_lower = column.to_ascii_lowercase();
let col_idx = table_schema
.columns
.iter()
.position(|c| c.name == col_lower)
.ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
let DataType::Vector { dim } = table_schema.columns[col_idx].data_type else {
return Err(SqlError::InvalidValue(format!(
"column `{column}` is not VECTOR(N)"
)));
};
if table_schema.primary_key_columns.len() != 1
|| !matches!(
table_schema.columns[table_schema.primary_key_columns[0] as usize].data_type,
DataType::Integer
)
{
return Err(SqlError::InvalidValue(
"ANN persistence requires a single INTEGER primary key (same rule as the \
ANN query plan)"
.into(),
));
}
let ann_index = table_schema
.indices
.iter()
.find(|ix| {
matches!(ix.kind, IndexKind::Inverted(InvertedKind::Ann { .. }))
&& ix.keys.len() == 1
&& matches!(ix.keys[0], IndexKey::Column { idx, .. } if idx as usize == col_idx)
})
.ok_or_else(|| SqlError::InvalidValue(format!("no ANN index declared on `{column}`")))?;
let IndexKind::Inverted(InvertedKind::Ann { metric }) = ann_index.kind else {
unreachable!("matched above");
};
let spec = AnnSpec {
col_idx,
dim,
metric,
filter_cols: ann_index.ann_filter_cols.clone(),
};
let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
let outcome = scan_rows(&mut wtx, table_schema, &spec)?;
if outcome.rows.is_empty() {
return Err(SqlError::InvalidValue(
"nothing to persist: the table has no indexable (non-NULL) vectors".into(),
));
}
let n = outcome.rows.len() as u64;
let index = AnnIndex::build_with_attrs(
outcome.rows,
spec.filter_cols.len(),
ann_metric_to_prism(spec.metric),
spec.dim,
)
.map_err(|e| SqlError::InvalidValue(format!("ANN build failed: {e}")))?;
let body = citadel_vector::segment::encode(&index);
let segment_b3 = *blake3::hash(&body).as_bytes();
let dicts_ordered: Vec<Vec<(Vec<u8>, u32)>> = outcome
.dicts
.iter()
.map(|d| {
let mut entries: Vec<(Vec<u8>, u32)> = d.iter().map(|(k, &v)| (k.clone(), v)).collect();
entries.sort_by_key(|&(_, code)| code);
entries
})
.collect();
let table_root = wtx
.table_root_page(table_schema.name.as_bytes())
.map_err(SqlError::Storage)?
.map(|p| u64::from(p.0))
.ok_or_else(|| SqlError::InvalidValue("table vanished during ANN persist".into()))?;
let header = ann_persist::SegmentHeader {
format_version: ann_persist::ANNSEG_FORMAT_VERSION,
prism_config_hash: ann_persist::active_config_hash(ann_metric_to_prism(spec.metric)),
dim: spec.dim,
metric_tag: spec.metric_tag(),
n,
snapshot_max: index.snapshot_max,
table_root,
col_idx: spec.col_idx as u32,
filter_cols: spec.filter_cols.iter().map(|&c| c as u32).collect(),
dicts: dicts_ordered,
content_fingerprint: outcome.fingerprint,
segment_b3,
chunk_count: body.len().div_ceil(ann_persist::CHUNK_BYTES) as u32,
writer: format!("citadel-sql {}", env!("CARGO_PKG_VERSION")),
};
let seg_table = ann_persist::segment_table_name(&table_schema.name);
ann_persist::purge_segment(&mut wtx, &table_schema.name)?;
wtx.create_table(&seg_table).map_err(SqlError::Storage)?;
wtx.table_insert(&seg_table, &ann_persist::segment_key(0), &header.encode())
.map_err(SqlError::Storage)?;
for (chunk_no, chunk) in ann_persist::chunks(&body) {
wtx.table_insert(&seg_table, &ann_persist::segment_key(chunk_no), chunk)
.map_err(SqlError::Storage)?;
}
wtx.commit().map_err(SqlError::Storage)?;
let cached = CachedAnnIndex {
index,
dicts: outcome.dicts,
source: AnnIndexSource::Built { refusal: None },
cached_gen: db.manager().commit_generation(),
};
let key = cache_key(&table_schema.name, spec.col_idx, spec.metric);
let as_any: Arc<dyn Any + Send + Sync> = Arc::new(cached);
schema.sql_caches.lock().insert(key, as_any);
Ok(ann_persist::AnnSegmentInfo {
segment_b3,
content_fingerprint: header.content_fingerprint,
n,
dim: spec.dim,
metric_tag: header.metric_tag,
chunk_count: header.chunk_count,
})
}
/// Reports whether an ANN index for `column` is currently cached. If so, it returns
/// the index's source and cached generation.
///
/// The metrics are probed in the order L2, inner product, cosine, and the first cached
/// entry found is returned. `column` is matched ASCII-case-insensitively, the same way
/// the query planner resolves column names.
///
/// # Errors
/// Returns `SqlError::ColumnNotFound` if `column` does not exist in `table_schema`.
pub(crate) fn ann_cache_status(
    schema: &SchemaManager,
    table_schema: &TableSchema,
    column: &str,
) -> Result<Option<(AnnIndexSource, u64)>> {
    // Case-insensitive, consistent with `AnnTopKPlan::try_new`. A plain `==` against the
    // lowercased argument missed schema names containing capitals.
    let col_idx = table_schema
        .columns
        .iter()
        .position(|c| c.name.eq_ignore_ascii_case(column))
        .ok_or_else(|| SqlError::ColumnNotFound(column.to_string()))?;
    let guard = schema.sql_caches.lock();
    for metric in [AnnMetric::L2, AnnMetric::Inner, AnnMetric::Cosine] {
        let key = cache_key(&table_schema.name, col_idx, metric);
        if let Some(entry) = guard.get(&key) {
            if let Ok(c) = Arc::clone(entry).downcast::<CachedAnnIndex>() {
                return Ok(Some((c.source.clone(), c.cached_gen)));
            }
        }
    }
    Ok(None)
}
/// Cache key under which the DML-generation marker for `table_name` is stored.
pub(crate) fn ann_dml_gen_key(table_name: &str) -> String {
    let mut key = String::from("ann_dml_gen:");
    key.push_str(table_name);
    key
}
/// Returns `true` when inserting rows with primary keys `>= min_pk` leaves every cached
/// ANN index for `table` valid.
///
/// An append is safe only if every cached index's `snapshot_max` (as `i64`) is
/// non-negative and strictly below `min_pk`, so new rows fall into the scanned tail.
/// Cache entries that are not ANN indexes are ignored.
pub(crate) fn ann_appends_safe(schema: &SchemaManager, table: &str, min_pk: i64) -> bool {
    let prefix = format!("ann:{}:", table.to_ascii_lowercase());
    let guard = schema.sql_caches.lock();
    let any_unsafe = guard
        .iter()
        .filter(|(key, _)| key.starts_with(&prefix))
        .filter_map(|(_, entry)| entry.downcast_ref::<CachedAnnIndex>())
        .any(|cached| {
            let snap = cached.index.snapshot_max as i64;
            snap < 0 || min_pk <= snap
        });
    !any_unsafe
}
/// Reads the DML generation marker for `table_name` from an already-locked
/// cache map.
///
/// Returns `None` when no marker is stored, or when the stored value is not a
/// `u64`.
fn marker_gen_locked(
    entries: &FxHashMap<String, Arc<dyn Any + Send + Sync>>,
    table_name: &str,
) -> Option<u64> {
    let marker = entries.get(&ann_dml_gen_key(table_name))?;
    marker.downcast_ref::<u64>().copied()
}
/// Fetches the cached ANN index stored under `cache_key`, evicting it if stale.
///
/// An entry is stale when the table's DML generation marker exists and is
/// newer than the entry's `cached_gen`. Stale entries are removed from the
/// cache while the lock is held, and `Ok(None)` is returned.
///
/// # Errors
/// `SqlError::InvalidValue` if the entry under `cache_key` is not a
/// `CachedAnnIndex`.
fn lookup_cached(
    schema: &SchemaManager,
    cache_key: &str,
    table_name: &str,
) -> Result<Option<Arc<CachedAnnIndex>>> {
    let mut caches = schema.sql_caches.lock();
    let entry = match caches.get(cache_key) {
        Some(raw) => Arc::clone(raw).downcast::<CachedAnnIndex>().map_err(|_| {
            SqlError::InvalidValue(format!("ANN cache type mismatch for {cache_key}"))
        })?,
        None => return Ok(None),
    };
    let stale = match marker_gen_locked(&caches, table_name) {
        Some(marker) => entry.cached_gen < marker,
        None => false,
    };
    if stale {
        caches.remove(cache_key);
        Ok(None)
    } else {
        Ok(Some(entry))
    }
}
/// Builds the `sql_caches` key for the ANN index on column `col_idx` of
/// `table_name` under `metric`.
///
/// The format is `ann:<table lowercased>:<col_idx>:<l2|inner|cosine>`.
/// `ann_appends_safe` relies on the `ann:<table>:` prefix.
pub(super) fn cache_key(table_name: &str, col_idx: usize, metric: AnnMetric) -> String {
    let table = table_name.to_ascii_lowercase();
    let metric_tag = match metric {
        AnnMetric::Cosine => "cosine",
        AnnMetric::Inner => "inner",
        AnnMetric::L2 => "l2",
    };
    format!("ann:{table}:{col_idx}:{metric_tag}")
}
/// Maps the SQL-level `AnnMetric` onto the `citadel_vector::Metric` used by
/// the index. `Inner` corresponds to `Metric::InnerProduct`; the other
/// variants map by name.
fn ann_metric_to_prism(m: AnnMetric) -> Metric {
    match m {
        AnnMetric::Cosine => Metric::Cosine,
        AnnMetric::Inner => Metric::InnerProduct,
        AnnMetric::L2 => Metric::L2,
    }
}
#[cfg(test)]
mod thrash_tests {
    //! Regression tests for the cached ANN index.
    //!
    //! - Appends above the cached snapshot must be served without rebuilding,
    //!   and must stay freshly visible.
    //! - In-place UPDATEs, DELETEs and gap-fill inserts below the snapshot must
    //!   invalidate the cache.
    //! - A long enough tail of appends must trigger exactly one rebuild, on the
    //!   next recall.
    //!
    //! NOTE(review): these tests read `take_ann_rebuilds()` as a counter. If
    //! that counter is process-global rather than thread-local, tests run in
    //! parallel by the harness could observe each other's rebuilds. Confirm
    //! its scope.
    use super::take_ann_rebuilds;
    use crate::{Connection, ExecutionResult, Value};
    use citadel::{Argon2Profile, DatabaseBuilder};

    /// Dimensionality of every test vector; also used in the table DDL.
    const DIM: usize = 8;

    /// Deterministic pseudo-random vector for row `i`, with components in `[0, 1)`.
    fn vec_for(i: u64) -> Vec<f32> {
        (0..DIM)
            .map(|d| {
                let x = (i.wrapping_mul(2654435761).wrapping_add(d as u64 * 40503) % 1000) as f32;
                x / 1000.0
            })
            .collect()
    }

    /// Renders `v` as a SQL vector literal cast to `VECTOR(DIM)`.
    fn vec_literal(v: &[f32]) -> String {
        let parts: Vec<String> = v.iter().map(|x| format!("{x}")).collect();
        format!("'[{}]'::VECTOR({})", parts.join(", "), DIM)
    }

    /// Runs a filtered k-NN query and returns the ids of the hits in rank order.
    fn recall_ids(conn: &Connection<'_>, qvec: &[f32], k: usize) -> Vec<i64> {
        let sql = format!(
            "SELECT id FROM t WHERE category = 0 ORDER BY v <-> {} LIMIT {k}",
            vec_literal(qvec)
        );
        match conn.execute(&sql).unwrap() {
            ExecutionResult::Query(qr) => qr
                .rows
                .iter()
                .map(|r| match &r[0] {
                    Value::Integer(i) => *i,
                    other => panic!("expected Integer id, got {other:?}"),
                })
                .collect(),
            _ => panic!("expected query result"),
        }
    }

    /// Creates a fresh encrypted database inside `dir`.
    fn fresh_db(dir: &std::path::Path) -> citadel::Database {
        DatabaseBuilder::new(dir.join("t.db"))
            .passphrase(b"test-passphrase")
            .argon2_profile(Argon2Profile::Iot)
            .create()
            .unwrap()
    }

    /// Creates the test table `t`; the vector column width follows `DIM`.
    fn setup(conn: &Connection<'_>) {
        conn.execute(&format!(
            "CREATE TABLE t (id INTEGER PRIMARY KEY, category INTEGER, score REAL, v VECTOR({DIM}))"
        ))
        .unwrap();
    }

    /// Inserts row `id` in category 0 with vector `v`.
    fn insert(conn: &Connection<'_>, id: u64, v: &[f32]) {
        conn.execute(&format!(
            "INSERT INTO t VALUES ({id}, 0, 1.0, {})",
            vec_literal(v)
        ))
        .unwrap();
    }

    /// Creates the L2 ANN index on `t.v`, filterable by `category`.
    fn build_index(conn: &Connection<'_>) {
        conn.execute(
            "CREATE INDEX ix_v ON t USING ann (v) WITH (metric = 'l2', filters = 'category')",
        )
        .unwrap();
    }

    /// Issues one recall so the index gets cached, then discards any rebuild
    /// count. Later `take_ann_rebuilds()` calls then only see what the test
    /// itself caused.
    fn warm_cache(conn: &Connection<'_>) {
        let _ = recall_ids(conn, &vec_for(7), 5);
        let _ = take_ann_rebuilds();
    }

    #[test]
    fn interleaved_append_recall_does_not_thrash() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        setup(&conn);
        let base = 200u64;
        for i in 1..=base {
            insert(&conn, i, &vec_for(i));
        }
        build_index(&conn);
        warm_cache(&conn);

        let appends = 10u64;
        let mut total_rebuilds = 0u64;
        for j in 0..appends {
            let new_id = base + 1 + j;
            let qvec = vec![0.50005f32 + (j as f32) * 0.0001; DIM];
            insert(&conn, new_id, &qvec);
            let ids = recall_ids(&conn, &qvec, 5);
            total_rebuilds += take_ann_rebuilds();
            assert_eq!(
                ids.first().copied(),
                Some(new_id as i64),
                "freshly appended exact-match row must rank #0 (I1 fresh-visibility)"
            );
        }
        assert_eq!(
            total_rebuilds, 0,
            "appends must not trigger PRISM rebuilds (got {total_rebuilds} over {appends} recalls = thrash)"
        );
    }

    #[test]
    fn inplace_vector_update_is_reflected() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        setup(&conn);
        for i in 1..=200 {
            insert(&conn, i, &vec_for(i));
        }
        build_index(&conn);
        warm_cache(&conn);

        let qvec = vec![0.50007f32; DIM];
        conn.execute(&format!(
            "UPDATE t SET v = {} WHERE id = 50",
            vec_literal(&qvec)
        ))
        .unwrap();
        let ids = recall_ids(&conn, &qvec, 5);
        assert!(
            take_ann_rebuilds() >= 1,
            "an in-place vector UPDATE must invalidate the cached index"
        );
        assert_eq!(ids.first().copied(), Some(50), "updated row must rank #0");
    }

    #[test]
    fn delete_indexed_row_disappears() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        setup(&conn);
        for i in 1..=200 {
            insert(&conn, i, &vec_for(i));
        }
        build_index(&conn);

        let q = vec_for(7);
        let before = recall_ids(&conn, &q, 5);
        assert_eq!(before.first().copied(), Some(7), "id 7 is the exact match");
        let _ = take_ann_rebuilds();

        conn.execute("DELETE FROM t WHERE id = 7").unwrap();
        let after = recall_ids(&conn, &q, 5);
        assert!(
            take_ann_rebuilds() >= 1,
            "a DELETE must invalidate the cached index"
        );
        assert!(!after.contains(&7), "deleted row must not appear: {after:?}");
    }

    #[test]
    fn gap_fill_below_snapshot_is_visible() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        setup(&conn);
        // Leave ids 51..=59 unused so a later insert lands below the snapshot max.
        for i in (1..=50).chain(60..=100) {
            insert(&conn, i, &vec_for(i));
        }
        build_index(&conn);
        warm_cache(&conn);

        let qvec = vec![0.50009f32; DIM];
        insert(&conn, 55, &qvec);
        let ids = recall_ids(&conn, &qvec, 5);
        assert!(
            take_ann_rebuilds() >= 1,
            "a gap-fill insert below snapshot must invalidate, not tail-merge"
        );
        assert_eq!(
            ids.first().copied(),
            Some(55),
            "gap-fill row must be visible at rank #0: {ids:?}"
        );
    }

    #[test]
    fn long_tail_triggers_single_rebuild() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        setup(&conn);
        for i in 1..=40 {
            insert(&conn, i, &vec_for(i));
        }
        build_index(&conn);
        warm_cache(&conn);

        let qvec = vec![0.50011f32; DIM];
        for i in 41..=55u64 {
            // Only the last append carries the query vector; the rest are noise.
            let v = if i == 55 {
                qvec.clone()
            } else {
                vec_for(i + 1000)
            };
            insert(&conn, i, &v);
        }
        assert_eq!(
            take_ann_rebuilds(),
            0,
            "appends alone must not rebuild (retained for tail merge)"
        );

        let ids = recall_ids(&conn, &qvec, 5);
        assert_eq!(
            take_ann_rebuilds(),
            1,
            "a tail past the threshold must trigger exactly one rebuild on recall"
        );
        assert_eq!(ids.first().copied(), Some(55), "post-rebuild result correct");
    }
}