1use std::collections::HashMap;
19use std::os::raw::c_char;
20use std::sync::Mutex;
21
22use once_cell::sync::Lazy;
23
24use crate::errors::{catch_panic, FfiError, FfiResult};
25use crate::memory::FfiBuffer;
26use crate::types::cstr_to_string;
27
28type CallbackFn = Box<dyn Fn(FfiBuffer) -> FfiResult + Send + Sync>;
31
32static CALLBACKS: Lazy<Mutex<HashMap<String, CallbackFn>>> =
33 Lazy::new(|| Mutex::new(HashMap::new()));
34
35pub fn register_callback<F>(name: &str, f: F) -> Result<(), FfiError>
46where
47 F: Fn(FfiBuffer) -> FfiResult + Send + Sync + 'static,
48{
49 let mut guard = CALLBACKS.lock().map_err(|_| FfiError::LockPoisoned)?;
50 guard.insert(name.to_string(), Box::new(f));
51 Ok(())
52}
53
54pub fn unregister_callback(name: &str) -> Result<bool, FfiError> {
58 let mut guard = CALLBACKS.lock().map_err(|_| FfiError::LockPoisoned)?;
59 Ok(guard.remove(name).is_some())
60}
61
62pub fn callback_count() -> usize {
64 CALLBACKS.lock().map(|g| g.len()).unwrap_or(0)
65}
66
67#[no_mangle]
81pub unsafe extern "C" fn ffi_register_callback(
82 name: *const c_char,
83 cb: extern "C" fn(FfiBuffer) -> FfiResult,
84) -> i32 {
85 let name_str = match cstr_to_string(name) {
86 Ok(s) => s,
87 Err(_) => return -1,
88 };
89 let wrapped = move |buf: FfiBuffer| cb(buf);
90 match CALLBACKS.lock() {
91 Ok(mut guard) => {
92 guard.insert(name_str, Box::new(wrapped));
93 0
94 }
95 Err(_) => -2,
96 }
97}
98
99#[no_mangle]
110pub unsafe extern "C" fn ffi_invoke_callback(name: *const c_char, input: FfiBuffer) -> FfiResult {
111 let name_str = match cstr_to_string(name) {
113 Ok(s) => s,
114 Err(e) => return FfiResult::err(e),
115 };
116
117 catch_panic(move || {
118 let guard = CALLBACKS.lock().map_err(|_| FfiError::LockPoisoned)?;
119
120 let cb = guard
121 .get(&name_str)
122 .ok_or_else(|| FfiError::NotFound(name_str.clone()))?;
123
124 let result = cb(input);
125
126 if result.is_ok() {
127 let FfiResult { payload, .. } = result;
131 Ok(payload)
132 } else {
133 Err(FfiError::Unknown(format!(
134 "callback '{}' returned error code {:?}",
135 name_str, result.error_code
136 )))
137 }
138 })
139}
140
141#[no_mangle]
151pub unsafe extern "C" fn ffi_unregister_callback(name: *const c_char) -> i32 {
152 let name_str = match cstr_to_string(name) {
153 Ok(s) => s,
154 Err(_) => return -1,
155 };
156 match CALLBACKS.lock() {
157 Ok(mut guard) => {
158 if guard.remove(&name_str).is_some() {
159 0
160 } else {
161 -1
162 }
163 }
164 Err(_) => -2,
165 }
166}
167
168#[no_mangle]
172pub extern "C" fn ffi_callback_count() -> usize {
173 callback_count()
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::{ffi_result_free, FfiErrorCode};

    /// Returns a registry name that no other test in this process will use.
    fn unique_name(prefix: &str) -> String {
        use std::sync::atomic::{AtomicU64, Ordering};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        format!("{prefix}_{}", COUNTER.fetch_add(1, Ordering::Relaxed))
    }

    /// Serialises the tests in this module that add or remove registry entries.
    ///
    /// `CALLBACKS` is process-global, and the test harness runs tests on parallel
    /// threads. Without this lock, a test comparing `callback_count()` before and
    /// after its own registration could observe another test's insertions.
    /// A poisoned lock (left by a failed test) is recovered so that one failure
    /// does not cascade. Tests in other modules that touch the registry are not
    /// covered by this lock.
    fn registry_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
        LOCK.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Echo callback with the C ABI, used to exercise `ffi_register_callback`.
    extern "C" fn echo_c(buf: FfiBuffer) -> FfiResult {
        FfiResult::ok(buf)
    }

    #[test]
    fn register_and_invoke_rust_callback() {
        let _serial = registry_lock();
        let name = unique_name("test_echo");
        register_callback(&name, |buf| {
            let bytes = unsafe { buf.as_slice() }.to_vec();
            FfiResult::ok(FfiBuffer::from_vec(bytes))
        })
        .unwrap();

        let c_name = std::ffi::CString::new(name.as_str()).unwrap();
        let input = FfiBuffer::from_vec(b"test payload".to_vec());
        let result = unsafe { ffi_invoke_callback(c_name.as_ptr(), input) };
        assert!(result.is_ok());
        let slice = unsafe { result.payload.as_slice() };
        assert_eq!(slice, b"test payload");
        ffi_result_free(result);

        // Leave the global registry as we found it.
        assert!(unregister_callback(&name).unwrap());
    }

    #[test]
    fn invoke_unknown_callback_returns_not_found() {
        let result = unsafe {
            let c_name = std::ffi::CString::new("__nonexistent_callback__").unwrap();
            ffi_invoke_callback(c_name.as_ptr(), FfiBuffer::null())
        };
        assert_eq!(result.error_code, FfiErrorCode::NotFound);
        ffi_result_free(result);
    }

    #[test]
    fn unregister_removes_callback() {
        let _serial = registry_lock();
        let name = unique_name("test_unregister");
        register_callback(&name, |buf| FfiResult::ok(buf)).unwrap();

        let removed = unregister_callback(&name).unwrap();
        assert!(removed);

        let not_removed = unregister_callback(&name).unwrap();
        assert!(!not_removed);
    }

    #[test]
    fn callback_count_tracks_registrations() {
        let _serial = registry_lock();
        let name = unique_name("test_count");
        let before = callback_count();
        register_callback(&name, |buf| FfiResult::ok(buf)).unwrap();
        assert_eq!(callback_count(), before + 1);
        unregister_callback(&name).unwrap();
        assert_eq!(callback_count(), before);
    }

    #[test]
    fn ffi_register_invoke_unregister_round_trip() {
        let _serial = registry_lock();
        let name = unique_name("test_ffi_echo");
        let c_name = std::ffi::CString::new(name.as_str()).unwrap();

        assert_eq!(unsafe { ffi_register_callback(c_name.as_ptr(), echo_c) }, 0);

        let input = FfiBuffer::from_vec(b"via the C ABI".to_vec());
        let result = unsafe { ffi_invoke_callback(c_name.as_ptr(), input) };
        assert!(result.is_ok());
        assert_eq!(unsafe { result.payload.as_slice() }, b"via the C ABI");
        ffi_result_free(result);

        assert_eq!(unsafe { ffi_unregister_callback(c_name.as_ptr()) }, 0);
        // A second removal finds nothing and reports -1.
        assert_eq!(unsafe { ffi_unregister_callback(c_name.as_ptr()) }, -1);
    }
}