pub mod config;
pub mod cursor;
pub mod shadow;
pub mod transaction;
use std::cell::RefCell;
use std::sync::Arc;
use sqlite3_ext::query::ToParam;
use sqlite3_ext::vtab::{
ChangeInfo, ChangeType, ConstraintOp, CreateVTab, DisconnectResult, FindFunctionVTab,
IndexInfo, TransactionVTab, UpdateVTab, VTab, VTabConnection, VTabFunctionList,
};
use sqlite3_ext::{Error, FromValue, Result, SQLITE_EMPTY, ValueRef, function::Context};
use crate::index::HnswIndex;
use crate::vtab::config::VectorTableConfig;
use crate::vtab::cursor::{CursorMode, VectorCursor};
use crate::vtab::shadow::ShadowOps;
use crate::vtab::transaction::{IndexState, VectorTransaction};
/// `idxNum` reported by `best_index` when no `knn_match` constraint was claimed
/// (costed as a 1,000,000-row scan).
const INDEX_SCAN: i32 = 0;
/// `idxNum` reported by `best_index` when it claimed a `knn_match` constraint on the
/// distance column (costed as a 10-row lookup).
const INDEX_KNN: i32 = 1;
/// SQLite virtual table backed by an HNSW index plus the `data` and `index` shadow tables.
pub struct VectorTable<'vtab> {
    /// Parsed module arguments (dimension, vector type, metric, metadata columns, table name, ...).
    config: VectorTableConfig,
    /// Index state shared (via `Arc::clone`) with every cursor from `open` and every
    /// transaction from `begin`.
    state: Arc<RefCell<IndexState>>,
    /// Connection passed to `create`/`connect`, stored as a raw pointer and dereferenced with
    /// `unsafe` in `update`/`destroy`; also handed to cursors and transactions.
    // NOTE(review): assumes the connection outlives this table and its cursors — confirm.
    db: *const VTabConnection,
    /// Overloaded SQL functions reported to SQLite through `FindFunctionVTab` (`knn_match`).
    functions: VTabFunctionList<'vtab, Self>,
}
// SAFETY(review): `VectorTable` holds a raw pointer and an `Arc<RefCell<_>>`, neither of which is
// thread-safe by itself. These impls are presumably sound only because SQLite invokes the
// virtual-table methods of one connection from a single thread at a time — TODO confirm this holds
// for every threading mode the extension may be loaded under.
unsafe impl Send for VectorTable<'_> {}
// SAFETY(review): see the note on `Send` above; `RefCell` is not `Sync`, so concurrent shared
// access from two threads would be undefined behaviour.
unsafe impl Sync for VectorTable<'_> {}
/// Fetches the serialized HNSW graph from the `index` shadow table of `table_name`.
///
/// Runs `ShadowOps::select_index_sql` with `"hnsw_graph"` as its parameter and copies the first
/// column of the returned row. Yields `Ok(None)` when the query reports `SQLITE_EMPTY` (no row);
/// every other error is returned unchanged.
fn load_index_from_shadow(db: &VTabConnection, table_name: &str) -> Result<Option<Vec<u8>>> {
    let select_sql = ShadowOps::select_index_sql(table_name);
    let lookup = db.query_row(&select_sql, ["hnsw_graph"], |row| {
        let bytes = row[0].get_blob()?;
        Ok(bytes.to_vec())
    });
    lookup.map(Some).or_else(|err| {
        if err == SQLITE_EMPTY {
            Ok(None)
        } else {
            Err(err)
        }
    })
}
/// Writes `meta_json` under the `"meta"` key of the `index` shadow table of `table_name`,
/// using the statement from `ShadowOps::upsert_index_sql`.
// Currently unused (hence the `allow`); presumably reserved for persisting table metadata —
// wire it up or remove it.
#[allow(dead_code)]
fn save_meta_to_shadow(db: &VTabConnection, table_name: &str, meta_json: &str) -> Result<()> {
    let upsert_sql = ShadowOps::upsert_index_sql(table_name);
    let params = ["meta", meta_json];
    db.execute(&upsert_sql, params)?;
    Ok(())
}
/// Inserts one row into the `data` shadow table using `ShadowOps::insert_data_sql(config)`.
///
/// The vector is bound as parameter 1 and the metadata values as parameters 2, 3, ... in order.
/// Returns the value produced by `db.insert` (presumably the new row's rowid).
fn insert_into_data_shadow(
    db: &VTabConnection,
    config: &VectorTableConfig,
    vector_blob: &[u8],
    metadata_args: &mut [&mut ValueRef],
) -> Result<i64> {
    use sqlite3_ext::query::Statement;

    let insert_sql = ShadowOps::insert_data_sql(config);
    db.insert(&insert_sql, |stmt: &mut Statement| {
        vector_blob.bind_param(&mut *stmt, 1)?;
        // Metadata parameters start right after the vector.
        for (position, value) in (2..).zip(metadata_args.iter_mut()) {
            value.bind_param(&mut *stmt, position)?;
        }
        Ok(())
    })
}
/// Executes `ShadowOps::delete_data_sql(table_name)` with `rowid` bound, removing that row from
/// the `data` shadow table.
fn delete_from_data_shadow(db: &VTabConnection, table_name: &str, rowid: i64) -> Result<()> {
    db.execute(&ShadowOps::delete_data_sql(table_name), [rowid])
        .map(|_| ())
}
/// Replaces the `data` shadow row `rowid` by deleting it and inserting `vector_blob` and
/// `metadata_args` as a new row.
///
/// Returns the rowid of the re-inserted row. `insert_into_data_shadow` binds only the vector and
/// metadata, so the new row presumably receives a fresh rowid that generally differs from
/// `rowid`. Callers must key the HNSW index by the returned value so that the index and the
/// `data` table stay in sync.
// NOTE(review): an in-place `UPDATE ... WHERE rowid = ?` in `ShadowOps` would keep rowids stable.
fn update_data_shadow(
    db: &VTabConnection,
    config: &VectorTableConfig,
    rowid: i64,
    vector_blob: &[u8],
    metadata_args: &mut [&mut ValueRef],
) -> Result<i64> {
    delete_from_data_shadow(db, &config.table_name, rowid)?;
    insert_into_data_shadow(db, config, vector_blob, metadata_args)
}
/// How `build_table` treats a failure to read the persisted HNSW graph.
#[derive(Clone, Copy, Debug)]
enum GraphLoad {
    /// Propagate read errors (`connect`: the shadow tables must already exist, and silently
    /// starting from an empty index would hide, and later overwrite, the stored graph).
    Strict,
    /// Treat a read error like "nothing persisted" (`create`: the shadow tables are only created
    /// after the table is built, so the lookup normally fails).
    Lenient,
}

/// Parses the module arguments and builds a [`VectorTable`], restoring any persisted graph.
///
/// Used by `connect`. Returns the schema SQLite should declare, together with the table.
///
/// # Errors
/// Fails if the arguments do not parse, the HNSW index cannot be constructed, the persisted
/// graph cannot be read from the `index` shadow table, or the stored graph fails to load.
fn init<'vtab>(db: &VTabConnection, args: &[&str]) -> Result<(String, VectorTable<'vtab>)> {
    build_table(db, args, GraphLoad::Strict)
}

/// Shared constructor behind `create` and `connect`; `load` selects how read errors on the
/// persisted graph are handled (see [`GraphLoad`]).
// The `Arc<RefCell<_>>` is deliberately shared with cursors and transactions; thread-safety is
// asserted by the `unsafe impl Send/Sync` on `VectorTable`, so the clippy lint is silenced here.
#[allow(clippy::arc_with_non_send_sync)]
fn build_table<'vtab>(
    db: &VTabConnection,
    args: &[&str],
    load: GraphLoad,
) -> Result<(String, VectorTable<'vtab>)> {
    let config = VectorTableConfig::parse(args).map_err(|e| Error::Module(e.to_string()))?;
    let schema = config.vtab_schema();

    // Read the persisted graph (if any) before constructing the index.
    let persisted = match load_index_from_shadow(db, &config.table_name) {
        Ok(found) => found,
        Err(_) if matches!(load, GraphLoad::Lenient) => None,
        Err(e) => return Err(e),
    };

    let index = HnswIndex::new(
        config.dim,
        config.vtype,
        config.metric,
        Some(config.hnsw_params),
    )
    .map_err(|e| Error::Module(e.to_string()))?;
    if let Some(buf) = persisted {
        index
            .load_from_buffer(&buf)
            .map_err(|e| Error::Module(e.to_string()))?;
    }

    let state = Arc::new(RefCell::new(IndexState {
        index,
        dirty: false,
        last_committed: None,
    }));

    let functions = VTabFunctionList::default();
    // Two-argument `knn_match` is overloaded with constraint op 150 so `best_index` can claim it.
    // If SQLite ever evaluates it itself (constraint not omitted), it just returns 1 (true).
    functions.add(
        2,
        "knn_match",
        Some(ConstraintOp::Function(150)),
        |ctx: &Context, _args: &mut [&mut ValueRef]| ctx.set_result(1i32),
    );

    let vtab = VectorTable {
        config,
        state,
        db: db as *const VTabConnection,
        functions,
    };
    Ok((schema, vtab))
}

impl<'vtab> VTab<'vtab> for VectorTable<'vtab> {
    type Aux = ();
    type Cursor = VectorCursor;

    /// Reconnects to an existing table, restoring the persisted graph (errors are propagated).
    fn connect(
        db: &'vtab VTabConnection,
        _aux: &'vtab Self::Aux,
        args: &[&str],
    ) -> Result<(String, Self)> {
        init(db, args)
    }

    /// Chooses between a KNN lookup (`INDEX_KNN`) and a full scan (`INDEX_SCAN`).
    ///
    /// The first usable function constraint on the distance column (`knn_match`) becomes the
    /// first filter argument and is omitted from SQLite's own evaluation. Only when one was
    /// claimed, a usable `LIMIT` constraint becomes the next filter argument. Using two passes
    /// makes the outcome independent of the order in which SQLite lists the constraints.
    fn best_index(&'vtab self, info: &mut IndexInfo) -> Result<()> {
        // NOTE(review): presumably the distance column declared after the vector and metadata
        // columns by `vtab_schema` — confirm against `VectorTableConfig`.
        let distance_col = (2 + self.config.metadata_columns.len()) as i32;
        // 0-based argv position for the next claimed constraint.
        let mut next_argv: u32 = 0;

        // Pass 1: claim a single knn constraint. Any further ones are left to SQLite, which
        // evaluates them with the always-true overload instead of misaligning the argv.
        for mut c in info.constraints() {
            if c.usable()
                && c.column() == distance_col
                && matches!(c.op(), ConstraintOp::Function(_))
            {
                c.set_argv_index(Some(next_argv));
                c.set_omit(true);
                next_argv += 1;
                break;
            }
        }
        let found_knn = next_argv > 0;

        // Pass 2: hand LIMIT to the cursor only for KNN plans, always after the knn argument.
        if found_knn {
            for mut c in info.constraints() {
                if c.usable() && matches!(c.op(), ConstraintOp::Limit) {
                    c.set_argv_index(Some(next_argv));
                    c.set_omit(true);
                    break;
                }
            }
        }

        if found_knn {
            info.set_index_num(INDEX_KNN);
            info.set_estimated_cost(10.0);
            info.set_estimated_rows(10);
        } else {
            info.set_index_num(INDEX_SCAN);
            info.set_estimated_cost(1_000_000.0);
            info.set_estimated_rows(1_000_000);
        }
        Ok(())
    }

    /// Opens a cursor that starts in scan mode with no rows and shares this table's index state.
    fn open(&'vtab self) -> Result<Self::Cursor> {
        Ok(VectorCursor {
            mode: CursorMode::Scan {
                rows: Vec::new(),
                pos: 0,
            },
            num_metadata_cols: self.config.metadata_columns.len(),
            db: self.db,
            config: &self.config as *const VectorTableConfig,
            state: Arc::clone(&self.state),
        })
    }
}

impl<'vtab> CreateVTab<'vtab> for VectorTable<'vtab> {
    const SHADOW_NAMES: &'static [&'static str] = &["data", "index"];

    /// Creates a new table: builds it, then creates the `data` and `index` shadow tables.
    fn create(
        db: &'vtab VTabConnection,
        _aux: &'vtab Self::Aux,
        args: &[&str],
    ) -> Result<(String, Self)> {
        // The shadow tables do not exist yet, so reading the persisted graph normally fails
        // here; tolerate that and start from an empty index.
        let (schema, vtab) = build_table(db, args, GraphLoad::Lenient)?;
        db.execute(&ShadowOps::create_data_table_sql(&vtab.config), ())?;
        db.execute(&ShadowOps::create_index_table_sql(&vtab.config), ())?;
        Ok((schema, vtab))
    }

    /// Drops every shadow table; on the first failure, hands the table back with the error.
    fn destroy(self) -> DisconnectResult<Self> {
        // SAFETY(review): `db` came from the connection passed to create/connect; assumes that
        // connection is still alive while SQLite destroys the table — confirm.
        let db = unsafe { &*self.db };
        for sql in ShadowOps::drop_shadow_tables_sql(&self.config.table_name) {
            if let Err(e) = db.execute(&sql, ()) {
                return Err((self, e));
            }
        }
        Ok(())
    }
}
impl<'vtab> UpdateVTab<'vtab> for VectorTable<'vtab> {
fn update(&'vtab self, info: &mut ChangeInfo) -> Result<i64> {
let db = unsafe { &*self.db };
match info.change_type() {
ChangeType::Delete => {
let rowid = info.rowid().get_i64();
delete_from_data_shadow(db, &self.config.table_name, rowid)?;
self.state
.borrow()
.index
.remove(rowid as u64)
.map_err(|e| Error::Module(e.to_string()))?;
self.state.borrow_mut().dirty = true;
Ok(0)
}
ChangeType::Insert => {
let args = info.args_mut();
let vector_blob = args[2].get_blob()?.to_vec();
let num_meta = self.config.metadata_columns.len();
let meta_args = &mut args[3..3 + num_meta];
self.config
.vtype
.validate_blob(&vector_blob, self.config.dim)
.map_err(|e| Error::Module(e.to_string()))?;
self.config
.vtype
.validate_finite(&vector_blob, self.config.dim)
.map_err(|e| Error::Module(e.to_string()))?;
let rowid = insert_into_data_shadow(db, &self.config, &vector_blob, meta_args)?;
let state = self.state.borrow();
state
.index
.add(rowid as u64, &vector_blob)
.map_err(|e| Error::Module(e.to_string()))?;
drop(state);
self.state.borrow_mut().dirty = true;
Ok(rowid)
}
ChangeType::Update => {
let rowid = info.rowid().get_i64();
let args = info.args_mut();
let vector_blob = args[2].get_blob()?.to_vec();
let num_meta = self.config.metadata_columns.len();
let meta_args = &mut args[3..3 + num_meta];
update_data_shadow(db, &self.config, rowid, &vector_blob, meta_args)?;
let state = self.state.borrow();
state
.index
.remove(rowid as u64)
.map_err(|e| Error::Module(e.to_string()))?;
state
.index
.add(rowid as u64, &vector_blob)
.map_err(|e| Error::Module(e.to_string()))?;
drop(state);
self.state.borrow_mut().dirty = true;
Ok(rowid)
}
}
}
}
impl<'vtab> TransactionVTab<'vtab> for VectorTable<'vtab> {
    type Transaction = VectorTransaction;

    /// Starts a transaction that shares this table's index state, and carries its table name
    /// and connection pointer.
    fn begin(&'vtab self) -> Result<Self::Transaction> {
        let shared_state = Arc::clone(&self.state);
        let table_name = self.config.table_name.clone();
        let transaction = VectorTransaction {
            state: shared_state,
            table_name,
            db: self.db,
        };
        Ok(transaction)
    }
}

impl<'vtab> FindFunctionVTab<'vtab> for VectorTable<'vtab> {
    /// Exposes the overloaded SQL functions registered when the table was built (`knn_match`).
    fn functions(&'vtab self) -> &'vtab VTabFunctionList<'vtab, Self> {
        &self.functions
    }
}