Skip to main content

ffi_bridge/
callback.rs

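//! # callback — Named callback registry
//!
//! Provides a thread-safe, global registry of named callbacks.
//! Callbacks can be registered either from Rust (typed closure, via
//! [`register_callback`]) or from the C/Go side via `ffi_register_callback`
//! (C function pointer).
//!
//! ## Design
//!
//! The registry is a `Mutex<HashMap<String, CallbackFn>>` initialized lazily
//! via `once_cell`. `ffi_invoke_callback` runs the callback inside
//! [`catch_panic`], so a panicking callback is reported to the caller as an
//! error result instead of unwinding across the FFI boundary.
//!
//! ## Thread safety
//!
//! The registry is protected by a `std::sync::Mutex`. If a thread panics while
//! holding the lock, the lock becomes poisoned and subsequent registry
//! operations fail:
//!
//! - `register_callback`, `unregister_callback` and `ffi_invoke_callback`
//!   report [`FfiError::LockPoisoned`].
//! - `ffi_register_callback` and `ffi_unregister_callback` return `-2`.
//! - The count functions report `0`.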
17
use std::collections::HashMap;
use std::os::raw::c_char;
use std::sync::{Arc, Mutex};

use once_cell::sync::Lazy;

use crate::errors::{catch_panic, FfiError, FfiResult};
use crate::memory::FfiBuffer;
use crate::types::cstr_to_string;
27
28// ─── Registry ─────────────────────────────────────────────────────────────────
29
30type CallbackFn = Box<dyn Fn(FfiBuffer) -> FfiResult + Send + Sync>;
31
32static CALLBACKS: Lazy<Mutex<HashMap<String, CallbackFn>>> =
33    Lazy::new(|| Mutex::new(HashMap::new()));
34
35// ─── Rust-native API ──────────────────────────────────────────────────────────
36
37/// Register a named callback from Rust code.
38///
39/// The closure receives an [`FfiBuffer`] input and must return an [`FfiResult`].
40///
41/// # Thread safety
42///
43/// This function acquires the global callback registry lock.
44/// Returns `Err` if the lock is poisoned.
45pub fn register_callback<F>(name: &str, f: F) -> Result<(), FfiError>
46where
47    F: Fn(FfiBuffer) -> FfiResult + Send + Sync + 'static,
48{
49    let mut guard = CALLBACKS.lock().map_err(|_| FfiError::LockPoisoned)?;
50    guard.insert(name.to_string(), Box::new(f));
51    Ok(())
52}
53
54/// Remove a callback by name from Rust code.
55///
56/// Returns `Ok(true)` if the callback was removed, `Ok(false)` if not found.
57pub fn unregister_callback(name: &str) -> Result<bool, FfiError> {
58    let mut guard = CALLBACKS.lock().map_err(|_| FfiError::LockPoisoned)?;
59    Ok(guard.remove(name).is_some())
60}
61
62/// Return the number of currently registered callbacks.
63pub fn callback_count() -> usize {
64    CALLBACKS.lock().map(|g| g.len()).unwrap_or(0)
65}
66
67// ─── FFI-exported API ─────────────────────────────────────────────────────────
68
69/// Register a C function pointer as a named callback.
70///
71/// `name` must be a valid null-terminated UTF-8 string.
72/// Returns `0` on success, `-1` if `name` is null or not valid UTF-8,
73/// `-2` if the registry lock is poisoned.
74///
75/// **Exported as:** `ffi_register_callback`
76///
77/// # Safety
78///
79/// `name` must be a valid null-terminated C string.
80#[no_mangle]
81pub unsafe extern "C" fn ffi_register_callback(
82    name: *const c_char,
83    cb: extern "C" fn(FfiBuffer) -> FfiResult,
84) -> i32 {
85    let name_str = match cstr_to_string(name) {
86        Ok(s) => s,
87        Err(_) => return -1,
88    };
89    let wrapped = move |buf: FfiBuffer| cb(buf);
90    match CALLBACKS.lock() {
91        Ok(mut guard) => {
92            guard.insert(name_str, Box::new(wrapped));
93            0
94        }
95        Err(_) => -2,
96    }
97}
98
99/// Invoke a registered callback by name.
100///
101/// Returns `FFI_ERR_NOT_FOUND` if no callback with the given name is registered.
102/// Returns `FFI_ERR_PANIC` if the callback panics.
103///
104/// **Exported as:** `ffi_invoke_callback`
105///
106/// # Safety
107///
108/// `name` must be a valid null-terminated C string.
109#[no_mangle]
110pub unsafe extern "C" fn ffi_invoke_callback(name: *const c_char, input: FfiBuffer) -> FfiResult {
111    // We must resolve the name before entering catch_panic (CStr isn't UnwindSafe).
112    let name_str = match cstr_to_string(name) {
113        Ok(s) => s,
114        Err(e) => return FfiResult::err(e),
115    };
116
117    catch_panic(move || {
118        let guard = CALLBACKS.lock().map_err(|_| FfiError::LockPoisoned)?;
119
120        let cb = guard
121            .get(&name_str)
122            .ok_or_else(|| FfiError::NotFound(name_str.clone()))?;
123
124        let result = cb(input);
125
126        if result.is_ok() {
127            // Extract the payload, leaving the (empty) error_message as-is.
128            // FfiResult has no Drop, so the struct fields are just stack values;
129            // we take ownership directly without mem::forget.
130            let FfiResult { payload, .. } = result;
131            Ok(payload)
132        } else {
133            Err(FfiError::Unknown(format!(
134                "callback '{}' returned error code {:?}",
135                name_str, result.error_code
136            )))
137        }
138    })
139}
140
141/// Remove a registered callback by name.
142///
143/// Returns `0` if removed, `-1` if not found, `-2` if lock is poisoned.
144///
145/// **Exported as:** `ffi_unregister_callback`
146///
147/// # Safety
148///
149/// `name` must be a valid null-terminated C string.
150#[no_mangle]
151pub unsafe extern "C" fn ffi_unregister_callback(name: *const c_char) -> i32 {
152    let name_str = match cstr_to_string(name) {
153        Ok(s) => s,
154        Err(_) => return -1,
155    };
156    match CALLBACKS.lock() {
157        Ok(mut guard) => {
158            if guard.remove(&name_str).is_some() {
159                0
160            } else {
161                -1
162            }
163        }
164        Err(_) => -2,
165    }
166}
167
168/// Return the number of registered callbacks.
169///
170/// **Exported as:** `ffi_callback_count`
171#[no_mangle]
172pub extern "C" fn ffi_callback_count() -> usize {
173    callback_count()
174}
175
176// ─── Tests ────────────────────────────────────────────────────────────────────
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::{ffi_result_free, FfiErrorCode};
    use std::ffi::CString;
    use std::sync::MutexGuard;

    /// Serialise tests that mutate the process-global registry.
    ///
    /// The test harness runs tests on parallel threads by default. Without this
    /// lock, a registration made by one test can land between another test's
    /// `callback_count()` reads and break its exact-count assertions.
    /// NOTE(review): this only serialises tests in this module. Registry users
    /// elsewhere in the crate's test suite can still interfere.
    fn registry_guard() -> MutexGuard<'static, ()> {
        static REGISTRY_TESTS: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
        // A failed test poisons this lock; the `()` it guards has no state to corrupt.
        REGISTRY_TESTS
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Produce a registry name that no other test in this process will use.
    fn unique_name(prefix: &str) -> String {
        use std::sync::atomic::{AtomicU64, Ordering};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        format!("{prefix}_{}", COUNTER.fetch_add(1, Ordering::Relaxed))
    }

    fn to_c_string(name: &str) -> CString {
        CString::new(name).expect("test callback names contain no interior NUL")
    }

    /// C-ABI echo callback used to exercise `ffi_register_callback`.
    extern "C" fn echo_c(buf: FfiBuffer) -> FfiResult {
        FfiResult::ok(buf)
    }

    #[test]
    fn register_and_invoke_rust_callback() {
        let _serial = registry_guard();
        let name = unique_name("test_echo");
        register_callback(&name, |buf| {
            let bytes = unsafe { buf.as_slice() }.to_vec();
            FfiResult::ok(FfiBuffer::from_vec(bytes))
        })
        .unwrap();

        let c_name = to_c_string(&name);
        let input = FfiBuffer::from_vec(b"test payload".to_vec());
        let result = unsafe { ffi_invoke_callback(c_name.as_ptr(), input) };
        assert!(result.is_ok());
        let slice = unsafe { result.payload.as_slice() };
        assert_eq!(slice, b"test payload");
        ffi_result_free(result);

        // Leave the process-global registry as we found it.
        assert!(unregister_callback(&name).unwrap());
    }

    #[test]
    fn ffi_register_invoke_and_unregister_c_callback() {
        let _serial = registry_guard();
        let c_name = to_c_string(&unique_name("test_ffi_echo"));
        let before = callback_count();

        assert_eq!(unsafe { ffi_register_callback(c_name.as_ptr(), echo_c) }, 0);
        assert_eq!(ffi_callback_count(), before + 1);

        let input = FfiBuffer::from_vec(b"via C".to_vec());
        let result = unsafe { ffi_invoke_callback(c_name.as_ptr(), input) };
        assert!(result.is_ok());
        assert_eq!(unsafe { result.payload.as_slice() }, b"via C");
        ffi_result_free(result);

        assert_eq!(unsafe { ffi_unregister_callback(c_name.as_ptr()) }, 0);
        assert_eq!(unsafe { ffi_unregister_callback(c_name.as_ptr()) }, -1);
        assert_eq!(ffi_callback_count(), before);
    }

    #[test]
    fn invoke_unknown_callback_returns_not_found() {
        let c_name = to_c_string("__nonexistent_callback__");
        let result = unsafe { ffi_invoke_callback(c_name.as_ptr(), FfiBuffer::null()) };
        assert_eq!(result.error_code, FfiErrorCode::NotFound);
        ffi_result_free(result);
    }

    #[test]
    fn unregister_removes_callback() {
        let _serial = registry_guard();
        let name = unique_name("test_unregister");
        register_callback(&name, |buf| FfiResult::ok(buf)).unwrap();

        assert!(unregister_callback(&name).unwrap());
        assert!(!unregister_callback(&name).unwrap());
    }

    #[test]
    fn callback_count_tracks_registrations() {
        let _serial = registry_guard();
        let name = unique_name("test_count");
        let before = callback_count();
        register_callback(&name, |buf| FfiResult::ok(buf)).unwrap();
        assert_eq!(callback_count(), before + 1);
        unregister_callback(&name).unwrap();
        assert_eq!(callback_count(), before);
    }
}