use libsql::ffi;
use std::ffi::CString;
use std::path::Path;
/// C signature of a SQLite extension entry point
/// (`int init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi)`).
///
/// `sqlite3_vec_init` is reinterpreted as this type so it can be called with
/// correctly typed arguments.
type VecInitFn = unsafe extern "C" fn(
    db: *mut ffi::sqlite3,
    err_msg: *mut *mut std::ffi::c_char,
    api: *const ffi::sqlite3_api_routines,
) -> std::ffi::c_int;
/// Runs the `sqlite-vec` extension initializer against a raw connection handle.
///
/// Returns `true` when `sqlite3_vec_init` reports `SQLITE_OK`. Any error message
/// the initializer produces is released and not surfaced.
///
/// # Safety
///
/// `db` must be a valid, open `sqlite3` handle that no other thread uses for the
/// duration of the call.
unsafe fn register_vec_on_handle(db: *mut ffi::sqlite3) -> bool {
    // SAFETY: `sqlite_vec` exposes `sqlite3_vec_init` without a typed parameter
    // list; `VecInitFn` spells out the extension entry-point signature so the
    // function can be called with the proper arguments.
    let init_fn: VecInitFn =
        unsafe { std::mem::transmute(sqlite_vec::sqlite3_vec_init as *const ()) };

    // The entry-point contract allows the initializer to store an
    // `sqlite3_mprintf`-allocated message in `*pzErrMsg`, so give it a real
    // out-pointer instead of NULL and free whatever it writes.
    let mut err_msg: *mut std::os::raw::c_char = std::ptr::null_mut();

    // SAFETY: the caller guarantees `db` is valid; `err_msg` is a live
    // out-pointer. NOTE(review): a null `pApi` assumes the extension is linked
    // statically against this SQLite (compiled as core) — confirm.
    let rc = unsafe { init_fn(db, &mut err_msg, std::ptr::null()) };

    if !err_msg.is_null() {
        // SAFETY: messages returned through `pzErrMsg` are owned by the caller
        // and must be released with `sqlite3_free`.
        unsafe { ffi::sqlite3_free(err_msg.cast()) };
    }

    rc == ffi::SQLITE_OK
}
/// Owning wrapper around a raw `sqlite3*` handle returned by
/// `VecConnection::open`, which only succeeds after the `sqlite-vec`
/// extension has been registered on it. The handle is closed when this value
/// is dropped.
pub struct VecConnection {
raw: *mut ffi::sqlite3,
}
// SAFETY(review): SQLite permits moving a connection to another thread (`Send`)
// in multi-thread or serialized mode. Sharing it between threads (`Sync`)
// additionally requires serialized mode (`SQLITE_OPEN_FULLMUTEX` at open time,
// or a library configured serialized). Confirm `VecConnection::open`
// guarantees this mode before relying on `Sync`.
unsafe impl Send for VecConnection {}
unsafe impl Sync for VecConnection {}
impl VecConnection {
    /// Opens (creating if needed) the database at `db_path` and registers the
    /// `sqlite-vec` extension on the new handle.
    ///
    /// Returns `None` in any of these cases:
    /// - the path is not valid UTF-8 or contains a NUL byte;
    /// - `sqlite3_open_v2` fails;
    /// - extension registration fails.
    ///
    /// Every failure path closes the partially opened handle.
    pub fn open(db_path: &Path) -> Option<Self> {
        let c_path = CString::new(db_path.to_str()?).ok()?;
        let mut raw: *mut ffi::sqlite3 = std::ptr::null_mut();

        // SQLITE_OPEN_FULLMUTEX requests serialized threading mode. The `Sync`
        // impl on this type needs that mode to be sound. SQLite only ignores
        // the flag when the library itself is configured single-threaded.
        let flags =
            ffi::SQLITE_OPEN_READWRITE | ffi::SQLITE_OPEN_CREATE | ffi::SQLITE_OPEN_FULLMUTEX;

        // SAFETY: `c_path` is NUL-terminated and outlives the call, `raw` is a
        // valid out-pointer, and a null VFS name selects the default VFS.
        let rc =
            unsafe { ffi::sqlite3_open_v2(c_path.as_ptr(), &mut raw, flags, std::ptr::null()) };

        if rc != ffi::SQLITE_OK || raw.is_null() {
            // A handle may be returned even when opening fails; it still has to
            // be closed.
            if !raw.is_null() {
                // SAFETY: `raw` came from `sqlite3_open_v2` and is not used again.
                unsafe { ffi::sqlite3_close(raw) };
            }
            return None;
        }

        // SAFETY: `raw` is a freshly opened handle that nothing else references yet.
        if !unsafe { register_vec_on_handle(raw) } {
            // SAFETY: no statements exist on `raw` yet and it is not used again.
            unsafe { ffi::sqlite3_close(raw) };
            return None;
        }

        Some(VecConnection { raw })
    }

    /// Copies SQLite's description of the most recent error on this handle.
    ///
    /// NOTE(review): if the handle is shared across threads, another thread's
    /// call can replace the message between the failing call and this read.
    fn last_error_message(&self) -> String {
        // SAFETY: `self.raw` is a live handle owned by `self`.
        let msg = unsafe { ffi::sqlite3_errmsg(self.raw) };
        if msg.is_null() {
            return String::new();
        }
        // SAFETY: `sqlite3_errmsg` returns a NUL-terminated string that stays
        // valid until the next API call on this handle; it is copied right away.
        unsafe { std::ffi::CStr::from_ptr(msg) }
            .to_string_lossy()
            .into_owned()
    }

    /// Runs `sql` through `sqlite3_exec`.
    ///
    /// No row callback is installed, so any result rows are discarded.
    ///
    /// # Errors
    ///
    /// Returns an error in either of these cases:
    /// - `sql` contains a NUL byte;
    /// - `sqlite3_exec` returns anything other than `SQLITE_OK`.
    ///
    /// In the second case the message includes the result code and SQLite's
    /// error text.
    pub fn execute(&self, sql: &str) -> Result<(), String> {
        let c_sql = CString::new(sql).map_err(|e| e.to_string())?;
        // SAFETY: `self.raw` is live, `c_sql` is NUL-terminated and outlives the
        // call, and neither a callback nor an error out-pointer is supplied.
        let rc = unsafe {
            ffi::sqlite3_exec(
                self.raw,
                c_sql.as_ptr(),
                None,
                std::ptr::null_mut(),
                std::ptr::null_mut(),
            )
        };
        if rc == ffi::SQLITE_OK {
            Ok(())
        } else {
            Err(format!(
                "sqlite3_exec failed with code {rc} ({})",
                self.last_error_message()
            ))
        }
    }

    /// Compiles the first SQL statement in `sql` into a [`VecStmt`].
    ///
    /// No tail pointer is requested, so any text after the first statement is
    /// ignored.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `sql` contains a NUL byte;
    /// - `sqlite3_prepare_v2` fails (the message includes the result code and
    ///   SQLite's error text);
    /// - `sql` contains no statement at all, e.g. it is empty or only comments.
    pub fn prepare(&self, sql: &str) -> Result<VecStmt, String> {
        let c_sql = CString::new(sql).map_err(|e| e.to_string())?;
        let mut stmt: *mut ffi::sqlite3_stmt = std::ptr::null_mut();
        // SAFETY: `self.raw` is live, `c_sql` is NUL-terminated (`-1` tells
        // SQLite to read up to the terminator), and `stmt` is a valid out-pointer.
        let rc = unsafe {
            ffi::sqlite3_prepare_v2(
                self.raw,
                c_sql.as_ptr(),
                -1,
                &mut stmt,
                std::ptr::null_mut(),
            )
        };
        if rc != ffi::SQLITE_OK {
            return Err(format!("prepare failed: {rc} ({})", self.last_error_message()));
        }
        // Empty or comment-only SQL compiles "successfully" to a NULL statement.
        // Report it rather than returning a wrapper around a null pointer.
        if stmt.is_null() {
            return Err("prepare failed: SQL contains no statement".to_string());
        }
        Ok(VecStmt { raw: stmt })
    }

    /// Returns the raw `sqlite3*` handle.
    ///
    /// The handle remains owned by this `VecConnection`, which closes it on
    /// drop; callers must not close it themselves.
    pub fn handle(&self) -> *mut ffi::sqlite3 {
        self.raw
    }

    /// Returns `sqlite3_last_insert_rowid` for this handle.
    pub fn last_insert_rowid(&self) -> i64 {
        // SAFETY: `self.raw` is a live handle owned by `self`.
        unsafe { ffi::sqlite3_last_insert_rowid(self.raw) }
    }
}
impl Drop for VecConnection {
    /// Closes the underlying SQLite handle.
    ///
    /// `sqlite3_close_v2` is used instead of `sqlite3_close` because `VecStmt`
    /// carries no lifetime tying it to its connection. If statements are still
    /// alive, `sqlite3_close` returns `SQLITE_BUSY` and leaves the handle open
    /// (leaked). `sqlite3_close_v2` instead turns the handle into a "zombie"
    /// that SQLite frees once the last statement is finalized.
    fn drop(&mut self) {
        if !self.raw.is_null() {
            // SAFETY: `self.raw` came from a successful `sqlite3_open_v2` and is
            // closed exactly once, here.
            unsafe { ffi::sqlite3_close_v2(self.raw) };
        }
    }
}
/// Owning wrapper around a prepared `sqlite3_stmt*` created by
/// `VecConnection::prepare`; the statement is finalized when this value is
/// dropped.
///
/// NOTE(review): this type has no lifetime parameter linking it to the
/// `VecConnection` that produced it, so it can outlive that connection.
pub struct VecStmt {
raw: *mut ffi::sqlite3_stmt,
}
impl VecStmt {
    /// Binds a 64-bit integer to parameter `idx` (SQLite numbers parameters from 1).
    ///
    /// Like the other `bind_*` methods, the SQLite result code is not inspected.
    pub fn bind_int64(&self, idx: i32, val: i64) {
        // SAFETY: `self.raw` came from `sqlite3_prepare_v2` and is only
        // finalized in `Drop`.
        unsafe { ffi::sqlite3_bind_int64(self.raw, idx, val) };
    }

    /// Binds `data` as a BLOB to parameter `idx`.
    ///
    /// SQLite copies the bytes (`SQLITE_TRANSIENT`), so `data` only has to live
    /// for the duration of the call.
    pub fn bind_blob(&self, idx: i32, data: &[u8]) {
        // Use the 64-bit-length variant. A `len as i32` cast wraps to a negative
        // length for blobs over `i32::MAX` bytes, which `sqlite3_bind_blob`
        // documents as undefined behavior. `usize -> u64` never truncates.
        // SAFETY: the pointer and length describe `data`, which is readable for
        // the whole call; SQLITE_TRANSIENT makes SQLite take its own copy.
        unsafe {
            ffi::sqlite3_bind_blob64(
                self.raw,
                idx,
                data.as_ptr().cast(),
                data.len() as u64,
                ffi::SQLITE_TRANSIENT(),
            );
        }
    }

    /// Binds `val` as UTF-8 TEXT to parameter `idx`.
    ///
    /// SQLite copies the bytes (`SQLITE_TRANSIENT`).
    pub fn bind_text(&self, idx: i32, val: &str) {
        // Pass the UTF-8 bytes with an explicit length rather than round-tripping
        // through `CString`. That conversion rejects strings with an interior
        // NUL, which were then silently bound as an empty string.
        // SAFETY: the pointer and length describe `val`, which is readable for
        // the whole call; SQLITE_TRANSIENT makes SQLite take its own copy.
        unsafe {
            ffi::sqlite3_bind_text64(
                self.raw,
                idx,
                val.as_ptr().cast(),
                val.len() as u64,
                ffi::SQLITE_TRANSIENT(),
                ffi::SQLITE_UTF8 as u8,
            );
        }
    }

    /// Advances the statement by one step.
    ///
    /// Returns `Ok(true)` when a row is available (`SQLITE_ROW`) and
    /// `Ok(false)` when execution finished (`SQLITE_DONE`).
    ///
    /// # Errors
    ///
    /// Returns an error carrying the numeric result code for any other outcome.
    pub fn step(&self) -> Result<bool, String> {
        // SAFETY: `self.raw` came from `sqlite3_prepare_v2` and is only
        // finalized in `Drop`.
        let rc = unsafe { ffi::sqlite3_step(self.raw) };
        if rc == ffi::SQLITE_ROW {
            Ok(true)
        } else if rc == ffi::SQLITE_DONE {
            Ok(false)
        } else {
            Err(format!("step failed: {rc}"))
        }
    }

    /// Returns column `idx` of the current row via `sqlite3_column_int64`.
    pub fn column_int64(&self, idx: i32) -> i64 {
        // SAFETY: `self.raw` is a live statement owned by `self`.
        unsafe { ffi::sqlite3_column_int64(self.raw, idx) }
    }

    /// Returns column `idx` of the current row via `sqlite3_column_double`.
    pub fn column_double(&self, idx: i32) -> f64 {
        // SAFETY: `self.raw` is a live statement owned by `self`.
        unsafe { ffi::sqlite3_column_double(self.raw, idx) }
    }

    /// Returns column `idx` of the current row as an owned `String`.
    ///
    /// Returns `None` when `sqlite3_column_text` yields a null pointer (as it
    /// does for SQL NULL) or when the bytes are not valid UTF-8.
    pub fn column_text(&self, idx: i32) -> Option<String> {
        // SAFETY: `self.raw` is a live statement owned by `self`.
        let ptr = unsafe { ffi::sqlite3_column_text(self.raw, idx) };
        if ptr.is_null() {
            return None;
        }
        // Query the length *after* `sqlite3_column_text` so it describes the
        // converted text. Using it instead of scanning for a NUL terminator
        // keeps values with embedded NUL bytes intact.
        // SAFETY: same statement/column as above; no other call intervenes.
        let len = unsafe { ffi::sqlite3_column_bytes(self.raw, idx) };
        let bytes: &[u8] = match usize::try_from(len) {
            // SAFETY: SQLite guarantees `ptr` addresses `len` readable bytes
            // until the next step/reset/finalize; they are copied below.
            Ok(n) if n > 0 => unsafe { std::slice::from_raw_parts(ptr, n) },
            _ => &[],
        };
        std::str::from_utf8(bytes).ok().map(str::to_owned)
    }

    /// Returns column `idx` of the current row as a byte vector.
    ///
    /// A null pointer or a non-positive length yields an empty vector.
    pub fn column_blob(&self, idx: i32) -> Vec<u8> {
        // SAFETY: `self.raw` is a live statement; `sqlite3_column_bytes` is
        // called after `sqlite3_column_blob`, as SQLite recommends.
        let ptr = unsafe { ffi::sqlite3_column_blob(self.raw, idx) };
        let len = unsafe { ffi::sqlite3_column_bytes(self.raw, idx) };
        if ptr.is_null() || len <= 0 {
            Vec::new()
        } else {
            // SAFETY: `ptr` addresses `len` readable bytes until the next
            // step/reset/finalize; they are copied immediately.
            unsafe { std::slice::from_raw_parts(ptr as *const u8, len as usize) }.to_vec()
        }
    }
}
impl Drop for VecStmt {
    /// Hands the prepared statement back to SQLite via `sqlite3_finalize`.
    fn drop(&mut self) {
        if self.raw.is_null() {
            return;
        }
        // SAFETY: a non-null `raw` is a statement from `sqlite3_prepare_v2`, and
        // it is finalized exactly once, here.
        unsafe {
            ffi::sqlite3_finalize(self.raw);
        }
    }
}
/// Free-function form of [`VecConnection::open`]; returns `None` in exactly
/// the same cases.
pub fn open_vec_connection(db_path: &Path) -> Option<VecConnection> {
    VecConnection::open(db_path)
}

/// Reports whether `SELECT vec_version()` can be run on `conn`.
///
/// Yields `false` if issuing the query fails, or if fetching its first row
/// returns an error.
pub async fn vec_available(conn: &libsql::Connection) -> bool {
    let Ok(mut rows) = conn.query("SELECT vec_version()", ()).await else {
        return false;
    };
    rows.next().await.is_ok()
}