1use rusqlite::ffi;
24use rusqlite::functions::FunctionFlags;
25use rusqlite::{Connection, Error, Result};
26use std::collections::HashMap;
27use std::ffi::CStr;
28use std::os::raw::{c_char, c_int};
29use std::panic::{AssertUnwindSafe, catch_unwind};
30use std::path::PathBuf;
31use std::sync::Arc;
32use std::sync::Mutex as StdMutex;
33use std::sync::atomic::{AtomicU64, Ordering};
34use std::sync::mpsc::{Receiver, RecvTimeoutError};
35use std::time::Duration;
36use std::{ptr, sync::LazyLock};
37
38fn panic_error(payload: Box<dyn std::any::Any + Send>) -> Error {
39 let msg = if let Some(s) = payload.downcast_ref::<&str>() {
40 *s
41 } else if let Some(s) = payload.downcast_ref::<String>() {
42 s.as_str()
43 } else {
44 "non-string panic payload"
45 };
46 Error::UserFunctionError(Box::new(std::io::Error::other(format!(
47 "honker extension initialization panicked: {msg}"
48 ))))
49}
50
51fn extension_init(conn: Connection) -> Result<bool> {
52 match catch_unwind(AssertUnwindSafe(|| {
53 honker_core::attach_notify(&conn).map_err(|e| {
54 Error::UserFunctionError(Box::new(std::io::Error::other(e.to_string())))
55 })?;
56 honker_core::attach_honker_functions(&conn)?;
57 attach_watcher_sql_functions(&conn)?;
58 Ok(true)
59 })) {
60 Ok(result) => result,
61 Err(payload) => Err(panic_error(payload)),
62 }
63}
64
65static SQL_WATCHERS: LazyLock<StdMutex<HashMap<u64, HonkerWatcherHandle>>> =
66 LazyLock::new(|| StdMutex::new(HashMap::new()));
67static NEXT_SQL_WATCHER_ID: AtomicU64 = AtomicU64::new(1);
68
69fn open_watcher_handle(
70 db_path: &str,
71 backend: Option<&str>,
72) -> std::result::Result<HonkerWatcherHandle, String> {
73 let backend = honker_core::WatcherBackend::parse(backend.filter(|s| !s.is_empty()))?;
74 backend.probe(PathBuf::from(db_path).as_path())?;
75 let shared = Arc::new(honker_core::SharedUpdateWatcher::new_with_config(
76 PathBuf::from(db_path),
77 honker_core::WatcherConfig { backend },
78 ));
79 let (sub_id, rx) = shared.subscribe();
80 Ok(HonkerWatcherHandle { shared, sub_id, rx })
81}
82
83fn attach_watcher_sql_functions(conn: &Connection) -> Result<()> {
84 conn.create_scalar_function(
85 "honker_update_watcher_open",
86 2,
87 FunctionFlags::SQLITE_UTF8,
88 |ctx| {
89 let db_path: String = ctx.get(0)?;
90 let backend: Option<String> = ctx.get(1)?;
91 let handle = open_watcher_handle(&db_path, backend.as_deref()).map_err(|e| {
92 rusqlite::Error::UserFunctionError(Box::new(std::io::Error::other(e)))
93 })?;
94 let id = NEXT_SQL_WATCHER_ID.fetch_add(1, Ordering::Relaxed);
95 SQL_WATCHERS.lock().unwrap().insert(id, handle);
96 Ok(id as i64)
97 },
98 )?;
99 conn.create_scalar_function(
100 "honker_update_watcher_wait",
101 2,
102 FunctionFlags::SQLITE_UTF8,
103 |ctx| {
104 let id: i64 = ctx.get(0)?;
105 let timeout_ms: i64 = ctx.get(1)?;
106 let Some(handle) = SQL_WATCHERS.lock().unwrap().remove(&(id as u64)) else {
107 return Ok(-1);
108 };
109 let timeout_ms = timeout_ms.max(0) as u64;
110 let code = match handle.rx.recv_timeout(Duration::from_millis(timeout_ms)) {
111 Ok(()) => 1,
112 Err(RecvTimeoutError::Timeout) => 0,
113 Err(RecvTimeoutError::Disconnected) => -1,
114 };
115 if code != -1 {
116 SQL_WATCHERS.lock().unwrap().insert(id as u64, handle);
117 } else {
118 handle.shared.unsubscribe(handle.sub_id);
119 let _ = handle.shared.close();
120 }
121 Ok(code)
122 },
123 )?;
124 conn.create_scalar_function(
125 "honker_update_watcher_close",
126 1,
127 FunctionFlags::SQLITE_UTF8,
128 |ctx| {
129 let id: i64 = ctx.get(0)?;
130 if let Some(handle) = SQL_WATCHERS.lock().unwrap().remove(&(id as u64)) {
131 handle.shared.unsubscribe(handle.sub_id);
132 let _ = handle.shared.close();
133 }
134 Ok(1)
135 },
136 )?;
137 Ok(())
138}
139
140unsafe fn set_error_msg(
141 pz_err_msg: *mut *mut c_char,
142 p_api: *mut ffi::sqlite3_api_routines,
143 message: &str,
144) {
145 if pz_err_msg.is_null() || p_api.is_null() {
146 return;
147 }
148 let Some(malloc) = (unsafe { (*p_api).malloc }) else {
149 return;
150 };
151 let len = match message.len().checked_add(1) {
152 Some(len) if c_int::try_from(len).is_ok() => len,
153 _ => return,
154 };
155 let ptr = unsafe { malloc(len as c_int) }.cast::<c_char>();
156 if ptr.is_null() {
157 return;
158 }
159 unsafe {
160 ptr::copy_nonoverlapping(message.as_ptr().cast::<c_char>(), ptr, message.len());
161 *ptr.add(message.len()) = 0;
162 *pz_err_msg = ptr;
163 }
164}
165
166unsafe fn extension_init2(
167 db: *mut ffi::sqlite3,
168 pz_err_msg: *mut *mut c_char,
169 p_api: *mut ffi::sqlite3_api_routines,
170) -> c_int {
171 if p_api.is_null() {
172 return ffi::SQLITE_ERROR;
173 }
174 let result = unsafe { ffi::rusqlite_extension_init2(p_api) }
175 .map_err(Error::from)
176 .and_then(|()| unsafe { Connection::from_handle(db) })
177 .and_then(extension_init);
178 match result {
179 Ok(true) => ffi::SQLITE_OK_LOAD_PERMANENTLY,
180 Ok(false) => ffi::SQLITE_OK,
181 Err(err) => {
182 unsafe { set_error_msg(pz_err_msg, p_api, &err.to_string()) };
183 ffi::SQLITE_ERROR
184 }
185 }
186}
187
188#[unsafe(no_mangle)]
197pub unsafe extern "C" fn sqlite3_honkerext_init(
198 db: *mut ffi::sqlite3,
199 pz_err_msg: *mut *mut c_char,
200 p_api: *mut ffi::sqlite3_api_routines,
201) -> c_int {
202 match catch_unwind(AssertUnwindSafe(|| unsafe {
203 extension_init2(db, pz_err_msg, p_api)
204 })) {
205 Ok(code) => code,
206 Err(payload) => {
207 let err = panic_error(payload);
208 unsafe { set_error_msg(pz_err_msg, p_api, &err.to_string()) };
209 ffi::SQLITE_ERROR
210 }
211 }
212}
213
214pub struct HonkerWatcherHandle {
215 shared: Arc<honker_core::SharedUpdateWatcher>,
216 sub_id: u64,
217 rx: Receiver<()>,
218}
219
/// Reads an optional C string argument.
///
/// A null pointer and an empty string both yield `Ok(None)`; any other valid
/// UTF-8 string is returned as an owned `String`.
///
/// # Errors
///
/// Returns a message starting with `invalid UTF-8: ` when the bytes are not
/// valid UTF-8.
///
/// # Safety
///
/// `ptr` must be null or point to a NUL-terminated string that stays valid for
/// the duration of the call.
unsafe fn cstr_to_string(ptr: *const c_char) -> std::result::Result<Option<String>, String> {
    if ptr.is_null() {
        return Ok(None);
    }
    // SAFETY: `ptr` is non-null and, per the contract, NUL-terminated.
    let raw = unsafe { CStr::from_ptr(ptr) };
    match raw.to_str() {
        Ok("") => Ok(None),
        Ok(text) => Ok(Some(text.to_owned())),
        Err(e) => Err(format!("invalid UTF-8: {e}")),
    }
}
233
/// Copies `message` into the caller-provided buffer as a NUL-terminated
/// C string, truncating it when it does not fit.
///
/// At most `len - 1` bytes of text are written, followed by a terminating NUL.
/// Truncation always happens on a UTF-8 character boundary, so the buffer
/// never ends with half of a multi-byte character and stays valid UTF-8 for
/// callers that decode it. Nothing is written when `buf` is null or `len` is 0.
///
/// # Safety
///
/// `buf`, if non-null, must be valid for writes of `len` bytes.
unsafe fn write_error(buf: *mut c_char, len: usize, message: &str) {
    if buf.is_null() || len == 0 {
        return;
    }
    // Reserve one byte for the terminator, then back up until the cut does not
    // fall inside a multi-byte sequence. Index 0 is always a boundary, so the
    // loop terminates.
    let mut copy_len = message.len().min(len - 1);
    while !message.is_char_boundary(copy_len) {
        copy_len -= 1;
    }
    // SAFETY: `copy_len + 1 <= len`, and the caller guarantees `buf` is valid
    // for `len` bytes; `message` cannot overlap a `*mut` buffer we write to.
    unsafe {
        ptr::copy_nonoverlapping(message.as_ptr().cast::<c_char>(), buf, copy_len);
        *buf.add(copy_len) = 0;
    }
}
245
246#[unsafe(no_mangle)]
255pub unsafe extern "C" fn honker_watcher_open(
256 db_path: *const c_char,
257 backend: *const c_char,
258 err_buf: *mut c_char,
259 err_buf_len: usize,
260) -> *mut HonkerWatcherHandle {
261 match catch_unwind(AssertUnwindSafe(|| {
262 if db_path.is_null() {
263 return Err("db_path is null".to_string());
264 }
265 let path = unsafe { CStr::from_ptr(db_path) }
266 .to_str()
267 .map_err(|e| format!("invalid db_path UTF-8: {e}"))?;
268 let backend = unsafe { cstr_to_string(backend) }?;
269 let handle = open_watcher_handle(path, backend.as_deref())?;
270 Ok(Box::into_raw(Box::new(handle)))
271 })) {
272 Ok(Ok(ptr)) => ptr,
273 Ok(Err(err)) => {
274 unsafe { write_error(err_buf, err_buf_len, &err) };
275 ptr::null_mut()
276 }
277 Err(payload) => {
278 let err = panic_error(payload).to_string();
279 unsafe { write_error(err_buf, err_buf_len, &err) };
280 ptr::null_mut()
281 }
282 }
283}
284
285#[unsafe(no_mangle)]
297pub unsafe extern "C" fn honker_watcher_wait(
298 handle: *mut HonkerWatcherHandle,
299 timeout_ms: u64,
300) -> c_int {
301 if handle.is_null() {
302 return -1;
303 }
304 match catch_unwind(AssertUnwindSafe(|| {
305 let handle = unsafe { &mut *handle };
306 match handle.rx.recv_timeout(Duration::from_millis(timeout_ms)) {
307 Ok(()) => 1,
308 Err(RecvTimeoutError::Timeout) => 0,
309 Err(RecvTimeoutError::Disconnected) => -1,
310 }
311 })) {
312 Ok(code) => code,
313 Err(_) => -2,
314 }
315}
316
317#[unsafe(no_mangle)]
323pub unsafe extern "C" fn honker_watcher_close(handle: *mut HonkerWatcherHandle) {
324 if handle.is_null() {
325 return;
326 }
327 let handle = unsafe { Box::from_raw(handle) };
328 handle.shared.unsubscribe(handle.sub_id);
329 let _ = handle.shared.close();
330}