use sqlite3_ext::function::FunctionOptions;
use sqlite3_ext::query::ToParam;
use sqlite3_ext::*;
use crate::arrow_io;
use crate::distance::{DistanceMetric, compute_distance};
use crate::index::HnswIndex;
use crate::json::{blob_to_json, json_to_blob};
use crate::types::VectorType;
use crate::vtab::shadow::ShadowOps;
pub fn register_scalar_functions(db: &Connection) -> Result<()> {
db.create_scalar_function(
"vector_distance",
&FunctionOptions::default()
.set_n_args(4)
.set_deterministic(true),
|ctx, args| {
let metric_name = args[2].get_str()?.to_owned();
let type_name = args[3].get_str()?.to_owned();
let blob_a = args[0].get_blob()?.to_vec();
let blob_b = args[1].get_blob()?.to_vec();
let vtype =
VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
let metric = DistanceMetric::from_name(&metric_name)
.map_err(|e| Error::Module(e.to_string()))?;
let dim = blob_a.len() / vtype.element_size();
let dist = compute_distance(&blob_a, &blob_b, vtype, metric, dim)
.map_err(|e| Error::Module(e.to_string()))?;
ctx.set_result(dist)?;
Ok(())
},
)?;
db.create_scalar_function(
"vector_from_json",
&FunctionOptions::default()
.set_n_args(2)
.set_deterministic(true),
|ctx, args| {
let json_text = args[0].get_str()?.to_owned();
let type_name = args[1].get_str()?.to_owned();
let vtype =
VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
let blob = json_to_blob(&json_text, vtype).map_err(|e| Error::Module(e.to_string()))?;
ctx.set_result(&blob[..])?;
Ok(())
},
)?;
db.create_scalar_function(
"vector_to_json",
&FunctionOptions::default()
.set_n_args(2)
.set_deterministic(true),
|ctx, args| {
let type_name = args[1].get_str()?.to_owned();
let blob = args[0].get_blob()?.to_vec();
let vtype =
VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
let json = blob_to_json(&blob, vtype).map_err(|e| Error::Module(e.to_string()))?;
ctx.set_result(json)?;
Ok(())
},
)?;
db.create_scalar_function(
"vector_dims",
&FunctionOptions::default()
.set_n_args(2)
.set_deterministic(true),
|ctx, args| {
let type_name = args[1].get_str()?.to_owned();
let blob = args[0].get_blob()?;
let vtype =
VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
let dims = blob.len() / vtype.element_size();
ctx.set_result(dims as i64)?;
Ok(())
},
)?;
db.create_scalar_function(
"knn_match",
&FunctionOptions::default().set_n_args(2),
|ctx, _args| {
ctx.set_result(1i32)?;
Ok(())
},
)?;
db.create_scalar_function(
"vector_rebuild_index",
&FunctionOptions::default().set_n_args(3),
|ctx, args| {
let table_name = args[0].get_str()?.to_owned();
let type_name = args[1].get_str()?.to_owned();
let metric_name = args[2].get_str()?.to_owned();
let vtype =
VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
let metric = DistanceMetric::from_name(&metric_name)
.map_err(|e| Error::Module(e.to_string()))?;
let db = ctx.db();
let sql = ShadowOps::select_all_data_sql(&table_name);
let mut stmt = db.prepare(&sql)?;
stmt.query(())?;
let mut rows: Vec<(i64, Vec<u8>)> = Vec::new();
while let Some(row) = stmt.next()? {
let id = row[0].get_i64();
let blob = row[1].get_blob()?.to_vec();
rows.push((id, blob));
}
if rows.is_empty() {
ctx.set_result(0i64)?;
return Ok(());
}
let dim = rows[0].1.len() / vtype.element_size();
let index = HnswIndex::new(dim, vtype, metric, None)
.map_err(|e| Error::Module(e.to_string()))?;
for (id, blob) in &rows {
index
.add(*id as u64, blob)
.map_err(|e| Error::Module(e.to_string()))?;
}
let buf = index
.save_to_buffer()
.map_err(|e| Error::Module(e.to_string()))?;
let upsert_sql = ShadowOps::upsert_index_sql(&table_name);
db.insert(&upsert_sql, |stmt: &mut query::Statement| {
"hnsw_graph".bind_param(&mut *stmt, 1)?;
buf.as_slice().bind_param(&mut *stmt, 2)?;
Ok(())
})?;
ctx.set_result(rows.len() as i64)?;
Ok(())
},
)?;
db.create_scalar_function(
"vector_export_arrow",
&FunctionOptions::default().set_n_args(2),
|ctx, args| {
let table_name = args[0].get_str()?.to_owned();
let type_name = args[1].get_str()?.to_owned();
let vtype =
VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
let db = ctx.db();
let sql = ShadowOps::select_all_data_sql(&table_name);
let mut stmt = db.prepare(&sql)?;
stmt.query(())?;
let mut blobs: Vec<Vec<u8>> = Vec::new();
while let Some(row) = stmt.next()? {
blobs.push(row[1].get_blob()?.to_vec());
}
if blobs.is_empty() {
let empty: &[u8] = &[];
ctx.set_result(empty)?;
return Ok(());
}
let dim = blobs[0].len() / vtype.element_size();
let ipc = arrow_io::vectors_to_arrow_ipc(&blobs, vtype, dim)
.map_err(|e| Error::Module(e.to_string()))?;
ctx.set_result(&ipc[..])?;
Ok(())
},
)?;
db.create_scalar_function(
"vector_insert_arrow",
&FunctionOptions::default().set_n_args(3),
|ctx, args| {
let table_name = args[0].get_str()?.to_owned();
let type_name = args[1].get_str()?.to_owned();
let ipc_blob = args[2].get_blob()?.to_vec();
let vtype =
VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
if ipc_blob.is_empty() {
ctx.set_result(0i64)?;
return Ok(());
}
let blobs = arrow_io::arrow_ipc_to_vectors(&ipc_blob, vtype, 0)
.map_err(|e| Error::Module(e.to_string()))?;
if blobs.is_empty() {
ctx.set_result(0i64)?;
return Ok(());
}
let db = ctx.db();
let insert_sql = ShadowOps::insert_vector_only_sql(&table_name);
for blob in &blobs {
db.insert(&insert_sql, [blob.as_slice()])?;
}
ctx.set_result(blobs.len() as i64)?;
Ok(())
},
)?;
Ok(())
}