pub mod arrow_io;
pub mod distance;
pub mod index;
pub mod json;
pub mod scalar;
pub mod types;
pub mod vtab;
#[cfg(feature = "loadable_extension")]
use sqlite3_ext::*;
/// Entry point SQLite calls when this library is loaded as a runtime extension.
///
/// First registers the `vector` virtual-table module, built from
/// `vtab::VectorTable` with update, transaction and find-function support
/// switched on. Then registers the crate's scalar SQL functions through
/// `scalar::register_scalar_functions`.
///
/// # Errors
///
/// Stops at the first failing registration and returns its error to SQLite. If
/// creating the module fails, the scalar functions are not registered.
#[cfg(feature = "loadable_extension")]
#[sqlite3_ext_main(persistent)]
fn sqlite3_extension_init(db: &Connection) -> Result<()> {
    use sqlite3_ext::vtab::{Module, StandardModule};

    // Capabilities enabled on the virtual-table module.
    let vector_module = StandardModule::<vtab::VectorTable<'_>>::new()
        .with_update()
        .with_transactions()
        .with_find_function();

    db.create_module("vector", vector_module, ())?;
    scalar::register_scalar_functions(db)?;

    Ok(())
}
/// Loads the sqlite-vector-rs extension library into `conn`.
///
/// The library path comes from `find_extension_path`, which honours the
/// `SQLITE_VECTOR_RS_LIB` environment variable. Extension loading is switched on
/// only for this call and switched off again afterwards, whether or not the load
/// succeeded.
///
/// # Errors
///
/// - `rusqlite::Error::ModuleError` if no extension library can be found.
/// - Any error from enabling extension loading or from loading the library.
/// - If the load succeeded but extension loading could not be disabled again,
///   that error is returned. Otherwise `load_extension()` would silently stay
///   enabled on `conn`.
#[cfg(feature = "library")]
pub fn register(conn: &rusqlite::Connection) -> std::result::Result<(), rusqlite::Error> {
    let path = find_extension_path().ok_or_else(|| {
        rusqlite::Error::ModuleError(
            "Could not find the sqlite-vector-rs extension library. \
             Build with `cargo build` first, or set SQLITE_VECTOR_RS_LIB to the library path."
                .into(),
        )
    })?;

    // SAFETY: enabling extension loading lets native code be loaded into the
    // process. It stays enabled only until the disable call below, and the only
    // library we load is the path resolved above.
    unsafe {
        conn.load_extension_enable()?;
    }

    // SAFETY: `path` names this crate's own extension build (or an explicit
    // `SQLITE_VECTOR_RS_LIB` override chosen by the user). `None` lets SQLite
    // use its default entry-point lookup.
    let loaded = unsafe { conn.load_extension(&path, None::<&str>) };

    // Always switch extension loading back off, including after a failed load.
    // A load error takes precedence. If the load succeeded, a failure to disable
    // is surfaced instead of being ignored.
    let disabled = conn.load_extension_disable();
    loaded.and(disabled)
}
#[cfg(feature = "library")]
fn find_extension_path() -> Option<String> {
use std::path::Path;
let stem = if cfg!(target_os = "windows") {
"sqlite_vector_rs"
} else {
"libsqlite_vector_rs"
};
let extensions: &[&str] = if cfg!(target_os = "macos") {
&[".dylib"]
} else if cfg!(target_os = "windows") {
&[".dll"]
} else {
&[".so"]
};
if let Ok(val) = std::env::var("SQLITE_VECTOR_RS_LIB") {
if Path::new(&val).exists() {
return Some(val);
}
if extensions.iter().any(|ext| Path::new(&format!("{val}{ext}")).exists()) {
return Some(val);
}
}
if let Ok(exe) = std::env::current_exe()
&& let Some(dir) = exe.parent()
{
let base = dir.join(stem);
let base_str = base.to_string_lossy();
if extensions
.iter()
.any(|ext| Path::new(&format!("{base_str}{ext}")).exists())
{
return Some(base_str.into_owned());
}
}
for profile in &["debug", "release"] {
let base = format!("target/{profile}/{stem}");
if extensions
.iter()
.any(|ext| Path::new(&format!("{base}{ext}")).exists())
{
return Some(base);
}
}
None
}