use rusqlite::ffi;
use rusqlite::functions::FunctionFlags;
use rusqlite::{Connection, Error, Result};
use std::collections::HashMap;
use std::ffi::CStr;
use std::os::raw::{c_char, c_int};
use std::panic::{AssertUnwindSafe, catch_unwind};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::mpsc::{Receiver, RecvTimeoutError};
use std::time::Duration;
use std::{ptr, sync::LazyLock};
fn panic_error(payload: Box<dyn std::any::Any + Send>) -> Error {
let msg = if let Some(s) = payload.downcast_ref::<&str>() {
*s
} else if let Some(s) = payload.downcast_ref::<String>() {
s.as_str()
} else {
"non-string panic payload"
};
Error::UserFunctionError(Box::new(std::io::Error::other(format!(
"honker extension initialization panicked: {msg}"
))))
}
fn extension_init(conn: Connection) -> Result<bool> {
match catch_unwind(AssertUnwindSafe(|| {
honker_core::attach_notify(&conn).map_err(|e| {
Error::UserFunctionError(Box::new(std::io::Error::other(e.to_string())))
})?;
honker_core::attach_honker_functions(&conn)?;
attach_watcher_sql_functions(&conn)?;
Ok(true)
})) {
Ok(result) => result,
Err(payload) => Err(panic_error(payload)),
}
}
/// Process-wide registry of watchers opened by the `honker_update_watcher_open`
/// SQL function, keyed by the integer id returned to the SQL caller.
static SQL_WATCHERS: LazyLock<StdMutex<HashMap<u64, HonkerWatcherHandle>>> =
    LazyLock::new(Default::default);
/// Next id to hand out from `honker_update_watcher_open`; the first id is 1.
/// `Relaxed` is enough because the counter only has to produce distinct values.
static NEXT_SQL_WATCHER_ID: AtomicU64 = AtomicU64::new(1);
/// Opens a shared update watcher on `db_path` and subscribes to it.
///
/// A `backend` of `None` or `Some("")` is passed to `WatcherBackend::parse` as
/// `None`. The parsed backend is probed against the path before the watcher is
/// built. Parse and probe failures are propagated as `String` errors.
fn open_watcher_handle(
    db_path: &str,
    backend: Option<&str>,
) -> std::result::Result<HonkerWatcherHandle, String> {
    let requested = match backend {
        Some(name) if !name.is_empty() => Some(name),
        _ => None,
    };
    let backend = honker_core::WatcherBackend::parse(requested)?;
    let path = PathBuf::from(db_path);
    backend.probe(path.as_path())?;
    let config = honker_core::WatcherConfig { backend };
    let shared = Arc::new(honker_core::SharedUpdateWatcher::new_with_config(path, config));
    let (sub_id, rx) = shared.subscribe();
    Ok(HonkerWatcherHandle { shared, sub_id, rx })
}
fn attach_watcher_sql_functions(conn: &Connection) -> Result<()> {
conn.create_scalar_function(
"honker_update_watcher_open",
2,
FunctionFlags::SQLITE_UTF8,
|ctx| {
let db_path: String = ctx.get(0)?;
let backend: Option<String> = ctx.get(1)?;
let handle = open_watcher_handle(&db_path, backend.as_deref()).map_err(|e| {
rusqlite::Error::UserFunctionError(Box::new(std::io::Error::other(e)))
})?;
let id = NEXT_SQL_WATCHER_ID.fetch_add(1, Ordering::Relaxed);
SQL_WATCHERS.lock().unwrap().insert(id, handle);
Ok(id as i64)
},
)?;
conn.create_scalar_function(
"honker_update_watcher_wait",
2,
FunctionFlags::SQLITE_UTF8,
|ctx| {
let id: i64 = ctx.get(0)?;
let timeout_ms: i64 = ctx.get(1)?;
let Some(handle) = SQL_WATCHERS.lock().unwrap().remove(&(id as u64)) else {
return Ok(-1);
};
let timeout_ms = timeout_ms.max(0) as u64;
let code = match handle.rx.recv_timeout(Duration::from_millis(timeout_ms)) {
Ok(()) => 1,
Err(RecvTimeoutError::Timeout) => 0,
Err(RecvTimeoutError::Disconnected) => -1,
};
if code != -1 {
SQL_WATCHERS.lock().unwrap().insert(id as u64, handle);
} else {
handle.shared.unsubscribe(handle.sub_id);
let _ = handle.shared.close();
}
Ok(code)
},
)?;
conn.create_scalar_function(
"honker_update_watcher_close",
1,
FunctionFlags::SQLITE_UTF8,
|ctx| {
let id: i64 = ctx.get(0)?;
if let Some(handle) = SQL_WATCHERS.lock().unwrap().remove(&(id as u64)) {
handle.shared.unsubscribe(handle.sub_id);
let _ = handle.shared.close();
}
Ok(1)
},
)?;
Ok(())
}
/// Stores a NUL-terminated copy of `message` in `*pz_err_msg` for SQLite.
///
/// SQLite releases an extension's error message with its own allocator, so the
/// buffer is obtained from the `malloc` routine in `p_api`. Nothing is written
/// in any of these cases:
/// - either pointer is null;
/// - `malloc` is missing from the table or fails;
/// - the message plus its terminator does not fit in a `c_int`.
///
/// # Safety
///
/// `pz_err_msg` must be null or valid for a pointer write, and `p_api` must be
/// null or point to a valid `sqlite3_api_routines` table.
unsafe fn set_error_msg(
    pz_err_msg: *mut *mut c_char,
    p_api: *mut ffi::sqlite3_api_routines,
    message: &str,
) {
    if p_api.is_null() || pz_err_msg.is_null() {
        return;
    }
    // SAFETY: `p_api` is non-null and, per the caller contract, points to a valid table.
    let malloc = match unsafe { (*p_api).malloc } {
        Some(routine) => routine,
        None => return,
    };
    let body_len = message.len();
    let Some(alloc_len) = body_len
        .checked_add(1)
        .and_then(|total| c_int::try_from(total).ok())
    else {
        return;
    };
    // SAFETY: `malloc` is SQLite's allocator taken from the routine table.
    let buf: *mut c_char = unsafe { malloc(alloc_len) }.cast();
    if buf.is_null() {
        return;
    }
    // SAFETY: `buf` holds `body_len + 1` bytes and does not overlap `message`.
    unsafe {
        ptr::copy_nonoverlapping(message.as_ptr().cast::<c_char>(), buf, body_len);
        buf.add(body_len).write(0);
        pz_err_msg.write(buf);
    }
}
unsafe fn extension_init2(
db: *mut ffi::sqlite3,
pz_err_msg: *mut *mut c_char,
p_api: *mut ffi::sqlite3_api_routines,
) -> c_int {
if p_api.is_null() {
return ffi::SQLITE_ERROR;
}
let result = unsafe { ffi::rusqlite_extension_init2(p_api) }
.map_err(Error::from)
.and_then(|()| unsafe { Connection::from_handle(db) })
.and_then(extension_init);
match result {
Ok(true) => ffi::SQLITE_OK_LOAD_PERMANENTLY,
Ok(false) => ffi::SQLITE_OK,
Err(err) => {
unsafe { set_error_msg(pz_err_msg, p_api, &err.to_string()) };
ffi::SQLITE_ERROR
}
}
}
/// SQLite loadable-extension entry point (named per SQLite's `sqlite3_<name>_init` convention).
///
/// A panic escaping initialization is caught here so it never unwinds across
/// the C boundary. It is reported as `SQLITE_ERROR`, with the panic text
/// stored in `*pz_err_msg`.
///
/// # Safety
///
/// Must only be invoked by SQLite's extension loader with the standard
/// entry-point arguments.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sqlite3_honkerext_init(
    db: *mut ffi::sqlite3,
    pz_err_msg: *mut *mut c_char,
    p_api: *mut ffi::sqlite3_api_routines,
) -> c_int {
    // SAFETY: the arguments are forwarded unchanged from SQLite.
    let guarded = catch_unwind(AssertUnwindSafe(|| unsafe {
        extension_init2(db, pz_err_msg, p_api)
    }));
    guarded.unwrap_or_else(|payload| {
        let message = panic_error(payload).to_string();
        // SAFETY: same SQLite-provided pointers; `set_error_msg` tolerates nulls.
        unsafe { set_error_msg(pz_err_msg, p_api, &message) };
        ffi::SQLITE_ERROR
    })
}
/// One subscription to a shared update watcher.
///
/// C callers receive it as an owned raw pointer from `honker_watcher_open` and
/// release it with `honker_watcher_close`. The `honker_update_watcher_*` SQL
/// functions store these in `SQL_WATCHERS` instead. All fields are private, so
/// the handle is opaque outside this module.
pub struct HonkerWatcherHandle {
/// The watcher this handle is subscribed to.
shared: Arc<honker_core::SharedUpdateWatcher>,
/// Subscription id from `shared.subscribe()`, passed back to `shared.unsubscribe`.
sub_id: u64,
/// Receiving end of the subscription. Waiters treat a received `()` as "update
/// arrived" and a disconnect as "watcher gone". NOTE(review): exactly when
/// honker_core sends or disconnects is defined there — confirm.
rx: Receiver<()>,
}
/// Converts an optional C string argument into an owned Rust string.
///
/// A null pointer and an empty string both yield `Ok(None)`. Invalid UTF-8
/// yields `Err("invalid UTF-8: …")`.
///
/// # Safety
///
/// `ptr` must be null or point to a NUL-terminated string that stays valid for
/// the duration of the call.
unsafe fn cstr_to_string(ptr: *const c_char) -> std::result::Result<Option<String>, String> {
    if ptr.is_null() {
        return Ok(None);
    }
    // SAFETY: non-null, and the caller guarantees NUL termination.
    let raw = unsafe { CStr::from_ptr(ptr) };
    match raw.to_str() {
        Ok("") => Ok(None),
        Ok(text) => Ok(Some(text.to_owned())),
        Err(e) => Err(format!("invalid UTF-8: {e}")),
    }
}
/// Copies `message` into a caller-provided C buffer as a NUL-terminated string.
///
/// The message is truncated to at most `len - 1` bytes. Truncation backs off
/// to the nearest UTF-8 character boundary, so the buffer never ends in a
/// partial multi-byte sequence. Does nothing when `buf` is null or `len` is 0.
///
/// # Safety
///
/// `buf` must be null or valid for writes of `len` bytes.
unsafe fn write_error(buf: *mut c_char, len: usize, message: &str) {
    if buf.is_null() || len == 0 {
        return;
    }
    // Reserve room for the terminator, then avoid splitting a multi-byte
    // character: a dangling UTF-8 lead byte would make the whole message
    // undecodable for callers that read it as UTF-8. Index 0 is always a char
    // boundary, so the loop terminates.
    let mut copy_len = message.len().min(len - 1);
    while !message.is_char_boundary(copy_len) {
        copy_len -= 1;
    }
    // SAFETY: `copy_len + 1 <= len`, and the caller guarantees `buf` holds `len` bytes.
    unsafe {
        ptr::copy_nonoverlapping(message.as_ptr().cast::<c_char>(), buf, copy_len);
        *buf.add(copy_len) = 0;
    }
}
/// C entry point: opens an update watcher on `db_path`.
///
/// `backend` may be null or empty to leave the backend unspecified. On success
/// it returns an owned handle that must be released with `honker_watcher_close`.
/// On failure it returns null and, when `err_buf` is non-null and
/// `err_buf_len` is non-zero, writes a NUL-terminated error description into
/// `err_buf`. Panics are caught and reported the same way instead of unwinding
/// into C.
///
/// # Safety
///
/// `db_path` and `backend` must each be null or point to a NUL-terminated
/// string; `err_buf` must be null or valid for writes of `err_buf_len` bytes.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn honker_watcher_open(
    db_path: *const c_char,
    backend: *const c_char,
    err_buf: *mut c_char,
    err_buf_len: usize,
) -> *mut HonkerWatcherHandle {
    let outcome = catch_unwind(AssertUnwindSafe(|| {
        if db_path.is_null() {
            return Err("db_path is null".to_string());
        }
        // SAFETY: non-null and NUL-terminated per the caller contract.
        let path = unsafe { CStr::from_ptr(db_path) }
            .to_str()
            .map_err(|e| format!("invalid db_path UTF-8: {e}"))?;
        // SAFETY: null or NUL-terminated per the caller contract.
        let backend = unsafe { cstr_to_string(backend) }?;
        let handle = open_watcher_handle(path, backend.as_deref())?;
        Ok(Box::into_raw(Box::new(handle)))
    }));
    let message = match outcome {
        Ok(Ok(handle)) => return handle,
        Ok(Err(err)) => err,
        // Name this function in the report: the shared `panic_error` helper
        // labels every panic as an extension-initialization panic, which is
        // misleading for a watcher-open failure.
        Err(payload) => {
            let detail = payload
                .downcast_ref::<&str>()
                .copied()
                .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
                .unwrap_or("non-string panic payload");
            format!("honker_watcher_open panicked: {detail}")
        }
    };
    // SAFETY: `err_buf` is null or valid for `err_buf_len` bytes per the caller contract.
    unsafe { write_error(err_buf, err_buf_len, &message) };
    ptr::null_mut()
}
/// C entry point: blocks until the watcher signals an update or the timeout elapses.
///
/// Return values:
/// - `1` when a notification was received;
/// - `0` when `timeout_ms` milliseconds passed without one;
/// - `-1` when `handle` is null or the watcher's channel has disconnected;
/// - `-2` if a panic was caught.
///
/// # Safety
///
/// `handle` must be null or a pointer returned by `honker_watcher_open` that
/// has not yet been passed to `honker_watcher_close`. The handle must not be
/// used from another thread while this call is in progress.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn honker_watcher_wait(
    handle: *mut HonkerWatcherHandle,
    timeout_ms: u64,
) -> c_int {
    if handle.is_null() {
        return -1;
    }
    // SAFETY: non-null and live per the caller contract. `recv_timeout` only needs
    // `&self`, so take a shared reference rather than asserting exclusive `&mut` access.
    let handle = unsafe { &*handle };
    let waited = catch_unwind(AssertUnwindSafe(|| {
        match handle.rx.recv_timeout(Duration::from_millis(timeout_ms)) {
            Ok(()) => 1,
            Err(RecvTimeoutError::Timeout) => 0,
            Err(RecvTimeoutError::Disconnected) => -1,
        }
    }));
    waited.unwrap_or(-2)
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn honker_watcher_close(handle: *mut HonkerWatcherHandle) {
if handle.is_null() {
return;
}
let handle = unsafe { Box::from_raw(handle) };
handle.shared.unsubscribe(handle.sub_id);
let _ = handle.shared.close();
}