Skip to main content

honker_ext/
lib.rs

1//! honker SQLite loadable extension.
2//!
3//! Thin wrapper around `honker-core`. Registers:
4//!
5//!   * `notify()` SQL scalar function + `_honker_notifications`
6//!     table — via `honker_core::attach_notify`.
7//!   * Every `honker_*` queue / lock / rate-limit / scheduler / result
8//!     function — via `honker_core::attach_honker_functions`.
9//!
10//!     .load ./libhonker_ext
11//!     SELECT honker_bootstrap();
12//!     INSERT INTO _honker_live (queue, payload)
13//!     VALUES ('emails', '{"to": "alice"}');
14//!     SELECT honker_claim_batch('emails', 'worker-1', 32, 300);
15//!     SELECT honker_ack_batch('[1,2,3]', 'worker-1');
16//!     SELECT notify('orders', '{"id": 42}');
17//!
18//! Actual SQL implementations live in `honker_core::honker_ops`
19//! so the Python (PyO3) and Node (napi-rs) bindings can register the
20//! same functions on their own connections without loading this
21//! `.dylib`. One source of truth for the SQL.
22
23use rusqlite::ffi;
24use rusqlite::functions::FunctionFlags;
25use rusqlite::{Connection, Error, Result};
26use std::collections::HashMap;
27use std::ffi::CStr;
28use std::os::raw::{c_char, c_int};
29use std::panic::{AssertUnwindSafe, catch_unwind};
30use std::path::PathBuf;
31use std::sync::Arc;
32use std::sync::Mutex as StdMutex;
33use std::sync::atomic::{AtomicU64, Ordering};
34use std::sync::mpsc::{Receiver, RecvTimeoutError};
35use std::time::Duration;
36use std::{ptr, sync::LazyLock};
37
38fn panic_error(payload: Box<dyn std::any::Any + Send>) -> Error {
39    let msg = if let Some(s) = payload.downcast_ref::<&str>() {
40        *s
41    } else if let Some(s) = payload.downcast_ref::<String>() {
42        s.as_str()
43    } else {
44        "non-string panic payload"
45    };
46    Error::UserFunctionError(Box::new(std::io::Error::other(format!(
47        "honker extension initialization panicked: {msg}"
48    ))))
49}
50
51fn extension_init(conn: Connection) -> Result<bool> {
52    match catch_unwind(AssertUnwindSafe(|| {
53        honker_core::attach_notify(&conn).map_err(|e| {
54            Error::UserFunctionError(Box::new(std::io::Error::other(e.to_string())))
55        })?;
56        honker_core::attach_honker_functions(&conn)?;
57        attach_watcher_sql_functions(&conn)?;
58        Ok(true)
59    })) {
60        Ok(result) => result,
61        Err(payload) => Err(panic_error(payload)),
62    }
63}
64
/// Process-wide registry of watchers opened through the
/// `honker_update_watcher_*` SQL functions, keyed by the integer id that
/// `honker_update_watcher_open` hands back to SQL. Lazily created on first
/// use and guarded by a mutex.
static SQL_WATCHERS: LazyLock<StdMutex<HashMap<u64, HonkerWatcherHandle>>> =
    LazyLock::new(|| StdMutex::new(HashMap::new()));
/// Source of ids for `SQL_WATCHERS`; starts at 1 and is incremented once per
/// successfully opened watcher.
static NEXT_SQL_WATCHER_ID: AtomicU64 = AtomicU64::new(1);
68
69fn open_watcher_handle(
70    db_path: &str,
71    backend: Option<&str>,
72) -> std::result::Result<HonkerWatcherHandle, String> {
73    let backend = honker_core::WatcherBackend::parse(backend.filter(|s| !s.is_empty()))?;
74    backend.probe(PathBuf::from(db_path).as_path())?;
75    let shared = Arc::new(honker_core::SharedUpdateWatcher::new_with_config(
76        PathBuf::from(db_path),
77        honker_core::WatcherConfig { backend },
78    ));
79    let (sub_id, rx) = shared.subscribe();
80    Ok(HonkerWatcherHandle { shared, sub_id, rx })
81}
82
83fn attach_watcher_sql_functions(conn: &Connection) -> Result<()> {
84    conn.create_scalar_function(
85        "honker_update_watcher_open",
86        2,
87        FunctionFlags::SQLITE_UTF8,
88        |ctx| {
89            let db_path: String = ctx.get(0)?;
90            let backend: Option<String> = ctx.get(1)?;
91            let handle = open_watcher_handle(&db_path, backend.as_deref()).map_err(|e| {
92                rusqlite::Error::UserFunctionError(Box::new(std::io::Error::other(e)))
93            })?;
94            let id = NEXT_SQL_WATCHER_ID.fetch_add(1, Ordering::Relaxed);
95            SQL_WATCHERS.lock().unwrap().insert(id, handle);
96            Ok(id as i64)
97        },
98    )?;
99    conn.create_scalar_function(
100        "honker_update_watcher_wait",
101        2,
102        FunctionFlags::SQLITE_UTF8,
103        |ctx| {
104            let id: i64 = ctx.get(0)?;
105            let timeout_ms: i64 = ctx.get(1)?;
106            let Some(handle) = SQL_WATCHERS.lock().unwrap().remove(&(id as u64)) else {
107                return Ok(-1);
108            };
109            let timeout_ms = timeout_ms.max(0) as u64;
110            let code = match handle.rx.recv_timeout(Duration::from_millis(timeout_ms)) {
111                Ok(()) => 1,
112                Err(RecvTimeoutError::Timeout) => 0,
113                Err(RecvTimeoutError::Disconnected) => -1,
114            };
115            if code != -1 {
116                SQL_WATCHERS.lock().unwrap().insert(id as u64, handle);
117            } else {
118                handle.shared.unsubscribe(handle.sub_id);
119                let _ = handle.shared.close();
120            }
121            Ok(code)
122        },
123    )?;
124    conn.create_scalar_function(
125        "honker_update_watcher_close",
126        1,
127        FunctionFlags::SQLITE_UTF8,
128        |ctx| {
129            let id: i64 = ctx.get(0)?;
130            if let Some(handle) = SQL_WATCHERS.lock().unwrap().remove(&(id as u64)) {
131                handle.shared.unsubscribe(handle.sub_id);
132                let _ = handle.shared.close();
133            }
134            Ok(1)
135        },
136    )?;
137    Ok(())
138}
139
140unsafe fn set_error_msg(
141    pz_err_msg: *mut *mut c_char,
142    p_api: *mut ffi::sqlite3_api_routines,
143    message: &str,
144) {
145    if pz_err_msg.is_null() || p_api.is_null() {
146        return;
147    }
148    let Some(malloc) = (unsafe { (*p_api).malloc }) else {
149        return;
150    };
151    let len = match message.len().checked_add(1) {
152        Some(len) if c_int::try_from(len).is_ok() => len,
153        _ => return,
154    };
155    let ptr = unsafe { malloc(len as c_int) }.cast::<c_char>();
156    if ptr.is_null() {
157        return;
158    }
159    unsafe {
160        ptr::copy_nonoverlapping(message.as_ptr().cast::<c_char>(), ptr, message.len());
161        *ptr.add(message.len()) = 0;
162        *pz_err_msg = ptr;
163    }
164}
165
166unsafe fn extension_init2(
167    db: *mut ffi::sqlite3,
168    pz_err_msg: *mut *mut c_char,
169    p_api: *mut ffi::sqlite3_api_routines,
170) -> c_int {
171    if p_api.is_null() {
172        return ffi::SQLITE_ERROR;
173    }
174    let result = unsafe { ffi::rusqlite_extension_init2(p_api) }
175        .map_err(Error::from)
176        .and_then(|()| unsafe { Connection::from_handle(db) })
177        .and_then(extension_init);
178    match result {
179        Ok(true) => ffi::SQLITE_OK_LOAD_PERMANENTLY,
180        Ok(false) => ffi::SQLITE_OK,
181        Err(err) => {
182            unsafe { set_error_msg(pz_err_msg, p_api, &err.to_string()) };
183            ffi::SQLITE_ERROR
184        }
185    }
186}
187
188/// SQLite entry point. Name must match `sqlite3_<extname>_init`; SQLite
189/// derives `<extname>` from the filename — stripping the `lib` prefix
190/// and any non-alphabetic characters:
191/// `libhonker_ext.dylib` -> `honker_ext` -> `honkerext`
192/// -> `sqlite3_honkerext_init`.
193///
194/// # Safety
195/// Called by SQLite. All pointers are SQLite-owned.
196#[unsafe(no_mangle)]
197pub unsafe extern "C" fn sqlite3_honkerext_init(
198    db: *mut ffi::sqlite3,
199    pz_err_msg: *mut *mut c_char,
200    p_api: *mut ffi::sqlite3_api_routines,
201) -> c_int {
202    match catch_unwind(AssertUnwindSafe(|| unsafe {
203        extension_init2(db, pz_err_msg, p_api)
204    })) {
205        Ok(code) => code,
206        Err(payload) => {
207            let err = panic_error(payload);
208            unsafe { set_error_msg(pz_err_msg, p_api, &err.to_string()) };
209            ffi::SQLITE_ERROR
210        }
211    }
212}
213
/// One subscription to a shared update watcher.
///
/// Handed to C callers as an opaque pointer by `honker_watcher_open`, and
/// stored by id in the SQL watcher registry for the
/// `honker_update_watcher_*` SQL functions.
pub struct HonkerWatcherHandle {
    /// The watcher this subscription belongs to.
    shared: Arc<honker_core::SharedUpdateWatcher>,
    /// Subscription id returned by `subscribe()`; passed back to
    /// `unsubscribe()` when the handle is closed.
    sub_id: u64,
    /// Receiving end of the subscription channel; the wait functions block
    /// on it and treat each received `()` as one observed update.
    rx: Receiver<()>,
}
219
/// Copy an optional C string into an owned `String`.
///
/// Returns `Ok(None)` for a null pointer or an empty string, `Ok(Some(_))`
/// for non-empty valid UTF-8, and `Err` with a description when the bytes
/// are not valid UTF-8.
///
/// # Safety
/// `ptr` must be null or point to a valid NUL-terminated string.
unsafe fn cstr_to_string(ptr: *const c_char) -> std::result::Result<Option<String>, String> {
    if ptr.is_null() {
        return Ok(None);
    }
    let raw = unsafe { CStr::from_ptr(ptr) };
    match raw.to_str() {
        Ok("") => Ok(None),
        Ok(text) => Ok(Some(text.to_owned())),
        Err(e) => Err(format!("invalid UTF-8: {e}")),
    }
}
233
/// Copy `message` into the caller-supplied buffer as a NUL-terminated string.
///
/// Does nothing when `buf` is null or `len` is zero. When the message does
/// not fit in `len - 1` bytes it is truncated, and the cut is moved back to
/// the nearest character boundary so the buffer never ends in a partial
/// multi-byte UTF-8 sequence.
///
/// # Safety
/// `buf` must be null or valid for writes of `len` bytes.
unsafe fn write_error(buf: *mut c_char, len: usize, message: &str) {
    if buf.is_null() || len == 0 {
        return;
    }
    // Reserve the final byte for the terminator (`len` is non-zero here).
    let mut copy_len = message.len().min(len - 1);
    // Index 0 is always a boundary, so this loop terminates.
    while !message.is_char_boundary(copy_len) {
        copy_len -= 1;
    }
    unsafe {
        ptr::copy_nonoverlapping(message.as_ptr().cast::<c_char>(), buf, copy_len);
        *buf.add(copy_len) = 0;
    }
}
245
246/// Open a core-backed update watcher over `db_path`.
247///
248/// Returns null on error and writes a NUL-terminated diagnostic into
249/// `err_buf` when provided. `backend` accepts the same exact aliases as
250/// `honker_core::WatcherBackend::parse`; null / empty means polling.
251///
252/// # Safety
253/// All pointers must be valid NUL-terminated strings when non-null.
254#[unsafe(no_mangle)]
255pub unsafe extern "C" fn honker_watcher_open(
256    db_path: *const c_char,
257    backend: *const c_char,
258    err_buf: *mut c_char,
259    err_buf_len: usize,
260) -> *mut HonkerWatcherHandle {
261    match catch_unwind(AssertUnwindSafe(|| {
262        if db_path.is_null() {
263            return Err("db_path is null".to_string());
264        }
265        let path = unsafe { CStr::from_ptr(db_path) }
266            .to_str()
267            .map_err(|e| format!("invalid db_path UTF-8: {e}"))?;
268        let backend = unsafe { cstr_to_string(backend) }?;
269        let handle = open_watcher_handle(path, backend.as_deref())?;
270        Ok(Box::into_raw(Box::new(handle)))
271    })) {
272        Ok(Ok(ptr)) => ptr,
273        Ok(Err(err)) => {
274            unsafe { write_error(err_buf, err_buf_len, &err) };
275            ptr::null_mut()
276        }
277        Err(payload) => {
278            let err = panic_error(payload).to_string();
279            unsafe { write_error(err_buf, err_buf_len, &err) };
280            ptr::null_mut()
281        }
282    }
283}
284
285/// Wait for the next database update.
286///
287/// Returns:
288/// * `1` when an update was observed
289/// * `0` on timeout
290/// * `-1` when the watcher/subscription has closed or died
291/// * `-2` if this function catches an internal panic
292///
293/// # Safety
294/// `handle` must be a pointer returned by `honker_watcher_open` and not
295/// yet passed to `honker_watcher_close`.
296#[unsafe(no_mangle)]
297pub unsafe extern "C" fn honker_watcher_wait(
298    handle: *mut HonkerWatcherHandle,
299    timeout_ms: u64,
300) -> c_int {
301    if handle.is_null() {
302        return -1;
303    }
304    match catch_unwind(AssertUnwindSafe(|| {
305        let handle = unsafe { &mut *handle };
306        match handle.rx.recv_timeout(Duration::from_millis(timeout_ms)) {
307            Ok(()) => 1,
308            Err(RecvTimeoutError::Timeout) => 0,
309            Err(RecvTimeoutError::Disconnected) => -1,
310        }
311    })) {
312        Ok(code) => code,
313        Err(_) => -2,
314    }
315}
316
317/// Close a watcher opened by `honker_watcher_open`.
318///
319/// # Safety
320/// `handle` must be null or a pointer returned by `honker_watcher_open`.
321/// Passing the same non-null pointer twice is undefined behavior.
322#[unsafe(no_mangle)]
323pub unsafe extern "C" fn honker_watcher_close(handle: *mut HonkerWatcherHandle) {
324    if handle.is_null() {
325        return;
326    }
327    let handle = unsafe { Box::from_raw(handle) };
328    handle.shared.unsubscribe(handle.sub_id);
329    let _ = handle.shared.close();
330}