Skip to main content

normalize_semantic/
vec_ext.rs

1//! sqlite-vec extension registration.
2//!
3//! sqlite-vec provides a `vec0` virtual table for approximate nearest-neighbor
4//! search. The crate (`sqlite-vec`) ships a static C library that must be
5//! registered with SQLite on each connection that wants `vec0` tables.
6//!
7//! # Per-connection registration
8//!
9//! We call `sqlite3_vec_init(db, NULL, NULL)` directly on the connection's
10//! underlying raw `*mut sqlite3` handle.  This avoids `sqlite3_auto_extension`,
11//! which internally calls `sqlite3_initialize()` and conflicts with libsql's
12//! own initialization (libsql asserts that `sqlite3_config(SERIALIZED)` returns
13//! `SQLITE_OK`, which fails if sqlite was already initialized).
14//!
15//! Because `libsql::Connection` does not expose the raw handle through its
16//! public API, we obtain it by opening a lightweight raw FFI connection to the
17//! same database file.  A `VecConnection` wraps this raw handle and provides
18//! helper methods for vec-specific operations.
19//!
20//! # Safety
21//!
22//! `sqlite3_vec_init` is declared by the `sqlite-vec` crate as `fn()` (zero
23//! arguments) but the actual C function has the standard SQLite extension entry
24//! point signature `(sqlite3*, char**, const sqlite3_api_routines*) -> int`.
25//! We re-declare it with the correct signature and call it directly.  Because
26//! the library is compiled with `SQLITE_CORE`, the `pApi` parameter is ignored
27//! (the extension calls SQLite functions directly rather than through the API
28//! struct).
29
30use libsql::ffi;
31use std::ffi::CString;
32use std::path::Path;
33
/// The correct C signature for `sqlite3_vec_init`.
///
/// The `sqlite-vec` crate declares the symbol as `fn()` (zero arguments), but
/// the real symbol is a standard SQLite extension entry point:
/// `(sqlite3*, char**, const sqlite3_api_routines*) -> int`.
///
/// The three parameters are, in order, the database handle, an out-pointer
/// for an error message, and the extension API routine table.  The return
/// value is an SQLite result code.
type VecInitFn = unsafe extern "C" fn(
    *mut ffi::sqlite3,
    *mut *mut ::std::os::raw::c_char,
    *const ffi::sqlite3_api_routines,
) -> ::std::os::raw::c_int;
42
43/// Register the sqlite-vec extension on a specific raw SQLite handle.
44///
45/// # Safety
46///
47/// `db` must be a valid, open `sqlite3*` handle.
48unsafe fn register_vec_on_handle(db: *mut ffi::sqlite3) -> bool {
49    // Transmute `sqlite_vec::sqlite3_vec_init` (declared as `fn()`) to the
50    // correct 3-argument entry-point signature.  This is safe because the
51    // underlying C function actually has that signature — the crate just
52    // declares it with zero args for use with `sqlite3_auto_extension`.
53    let init_fn: VecInitFn =
54        unsafe { std::mem::transmute(sqlite_vec::sqlite3_vec_init as *const ()) };
55    let rc = unsafe { init_fn(db, std::ptr::null_mut(), std::ptr::null()) };
56    rc == ffi::SQLITE_OK
57}
58
/// A raw SQLite connection with sqlite-vec registered.
///
/// This wraps a `*mut sqlite3` handle opened via FFI and provides vec-specific
/// operations.  The handle points to the same database file as the main
/// `libsql::Connection`, allowing vec operations (ANN virtual table) alongside
/// normal queries on the main connection.
pub struct VecConnection {
    /// Raw `sqlite3*` handle owned by this wrapper; it is closed when the
    /// `VecConnection` is dropped.
    raw: *mut ffi::sqlite3,
}
68
// SAFETY: The underlying sqlite3 handle is compiled with SQLITE_THREADSAFE=1
// (serialized mode), so it's safe to send/share across threads.
//
// NOTE(review): the threading mode is a property of how the linked SQLite
// library was built and configured; nothing in this file checks it.  Confirm
// it holds for every build configuration that uses `VecConnection`.
unsafe impl Send for VecConnection {}
unsafe impl Sync for VecConnection {}
73
74impl VecConnection {
75    /// Open a raw connection to `db_path` and register sqlite-vec on it.
76    ///
77    /// Returns `None` if the path can't be converted to a C string or if
78    /// opening / extension registration fails.
79    pub fn open(db_path: &Path) -> Option<Self> {
80        let c_path = CString::new(db_path.to_str()?).ok()?;
81        let mut raw: *mut ffi::sqlite3 = std::ptr::null_mut();
82        let rc = unsafe {
83            ffi::sqlite3_open_v2(
84                c_path.as_ptr(),
85                &mut raw,
86                ffi::SQLITE_OPEN_READWRITE | ffi::SQLITE_OPEN_CREATE,
87                std::ptr::null(),
88            )
89        };
90        if rc != ffi::SQLITE_OK || raw.is_null() {
91            if !raw.is_null() {
92                unsafe { ffi::sqlite3_close(raw) };
93            }
94            return None;
95        }
96
97        if !unsafe { register_vec_on_handle(raw) } {
98            unsafe { ffi::sqlite3_close(raw) };
99            return None;
100        }
101
102        Some(VecConnection { raw })
103    }
104
105    /// Execute a SQL statement with no parameters.
106    pub fn execute(&self, sql: &str) -> Result<(), String> {
107        let c_sql = CString::new(sql).map_err(|e| e.to_string())?;
108        let rc = unsafe {
109            ffi::sqlite3_exec(
110                self.raw,
111                c_sql.as_ptr(),
112                None,
113                std::ptr::null_mut(),
114                std::ptr::null_mut(),
115            )
116        };
117        if rc == ffi::SQLITE_OK {
118            Ok(())
119        } else {
120            Err(format!("sqlite3_exec failed with code {rc}"))
121        }
122    }
123
124    /// Prepare a SQL statement and return a raw handle for manual binding.
125    ///
126    /// The caller is responsible for stepping, reading results, and
127    /// finalizing via [`VecStmt`].
128    pub fn prepare(&self, sql: &str) -> Result<VecStmt, String> {
129        let c_sql = CString::new(sql).map_err(|e| e.to_string())?;
130        let mut stmt: *mut ffi::sqlite3_stmt = std::ptr::null_mut();
131        let rc = unsafe {
132            ffi::sqlite3_prepare_v2(
133                self.raw,
134                c_sql.as_ptr(),
135                -1,
136                &mut stmt,
137                std::ptr::null_mut(),
138            )
139        };
140        if rc != ffi::SQLITE_OK {
141            return Err(format!("prepare failed: {rc}"));
142        }
143        Ok(VecStmt { raw: stmt })
144    }
145
146    /// Get the raw sqlite3 handle (for direct FFI use in store operations).
147    pub fn handle(&self) -> *mut ffi::sqlite3 {
148        self.raw
149    }
150
151    /// Return the rowid returned by `last_insert_rowid()`.
152    pub fn last_insert_rowid(&self) -> i64 {
153        unsafe { ffi::sqlite3_last_insert_rowid(self.raw) }
154    }
155}
156
157impl Drop for VecConnection {
158    fn drop(&mut self) {
159        if !self.raw.is_null() {
160            unsafe { ffi::sqlite3_close(self.raw) };
161        }
162    }
163}
164
/// A prepared statement on a [`VecConnection`].
///
/// Provides methods for binding parameters, stepping, and reading results
/// via raw SQLite FFI.
pub struct VecStmt {
    /// Raw `sqlite3_stmt*` handle owned by this wrapper; it is finalized when
    /// the `VecStmt` is dropped.
    raw: *mut ffi::sqlite3_stmt,
}
172
173impl VecStmt {
174    /// Bind an integer at 1-based position `idx`.
175    pub fn bind_int64(&self, idx: i32, val: i64) {
176        unsafe { ffi::sqlite3_bind_int64(self.raw, idx, val) };
177    }
178
179    /// Bind a BLOB at 1-based position `idx`.
180    pub fn bind_blob(&self, idx: i32, data: &[u8]) {
181        unsafe {
182            ffi::sqlite3_bind_blob(
183                self.raw,
184                idx,
185                data.as_ptr() as *const _,
186                data.len() as i32,
187                ffi::SQLITE_TRANSIENT(),
188            );
189        }
190    }
191
192    /// Bind a text string at 1-based position `idx`.
193    pub fn bind_text(&self, idx: i32, val: &str) {
194        let c_val = CString::new(val).unwrap_or_default();
195        unsafe {
196            ffi::sqlite3_bind_text(self.raw, idx, c_val.as_ptr(), -1, ffi::SQLITE_TRANSIENT());
197        }
198    }
199
200    /// Step the statement.  Returns `true` if a row is available
201    /// (`SQLITE_ROW`), `false` on `SQLITE_DONE`.  Other codes return an error.
202    pub fn step(&self) -> Result<bool, String> {
203        let rc = unsafe { ffi::sqlite3_step(self.raw) };
204        match rc {
205            _ if rc == ffi::SQLITE_ROW => Ok(true),
206            _ if rc == ffi::SQLITE_DONE => Ok(false),
207            _ => Err(format!("step failed: {rc}")),
208        }
209    }
210
211    /// Read a 64-bit integer from column `idx` (0-based).
212    pub fn column_int64(&self, idx: i32) -> i64 {
213        unsafe { ffi::sqlite3_column_int64(self.raw, idx) }
214    }
215
216    /// Read a double from column `idx` (0-based).
217    pub fn column_double(&self, idx: i32) -> f64 {
218        unsafe { ffi::sqlite3_column_double(self.raw, idx) }
219    }
220
221    /// Read a text string from column `idx` (0-based).
222    pub fn column_text(&self, idx: i32) -> Option<String> {
223        let ptr = unsafe { ffi::sqlite3_column_text(self.raw, idx) };
224        if ptr.is_null() {
225            None
226        } else {
227            let c_str = unsafe { std::ffi::CStr::from_ptr(ptr as *const _) };
228            c_str.to_str().ok().map(|s| s.to_string())
229        }
230    }
231
232    /// Read a BLOB from column `idx` (0-based).
233    pub fn column_blob(&self, idx: i32) -> Vec<u8> {
234        let ptr = unsafe { ffi::sqlite3_column_blob(self.raw, idx) };
235        let len = unsafe { ffi::sqlite3_column_bytes(self.raw, idx) };
236        if ptr.is_null() || len <= 0 {
237            Vec::new()
238        } else {
239            unsafe { std::slice::from_raw_parts(ptr as *const u8, len as usize) }.to_vec()
240        }
241    }
242}
243
244impl Drop for VecStmt {
245    fn drop(&mut self) {
246        if !self.raw.is_null() {
247            unsafe { ffi::sqlite3_finalize(self.raw) };
248        }
249    }
250}
251
252/// Register the sqlite-vec extension on a `libsql::Connection` by opening a
253/// raw FFI handle to the same database and registering there.
254///
255/// This is a convenience for the common case where you have a
256/// `libsql::Connection` and a `db_path` and want vec on a parallel handle.
257/// Returns the `VecConnection` if successful.
258pub fn open_vec_connection(db_path: &Path) -> Option<VecConnection> {
259    VecConnection::open(db_path)
260}
261
262/// Check whether the sqlite-vec extension is available on a connection by
263/// trying to call `vec_version()`.  Returns `true` if the extension is loaded.
264pub async fn vec_available(conn: &libsql::Connection) -> bool {
265    match conn.query("SELECT vec_version()", ()).await {
266        Ok(mut rows) => rows.next().await.is_ok(),
267        Err(_) => false,
268    }
269}