use std::collections::HashMap;
use std::os::raw::c_char;
use std::sync::Mutex;
use once_cell::sync::Lazy;
use crate::errors::{catch_panic, FfiError, FfiResult};
use crate::memory::FfiBuffer;
use crate::types::cstr_to_string;
/// Heap-allocated, type-erased callback as stored in the registry.
///
/// A callback takes an [`FfiBuffer`] by value and returns an [`FfiResult`].
/// The `Send + Sync` bounds are what allow it to live inside the global,
/// mutex-guarded `CALLBACKS` static.
type CallbackFn = Box<dyn Fn(FfiBuffer) -> FfiResult + Send + Sync>;
/// Global registry mapping a callback name to its boxed handler.
///
/// The map is created empty on first access (`Lazy`) and every access goes
/// through the surrounding `Mutex`.
static CALLBACKS: Lazy<Mutex<HashMap<String, CallbackFn>>> =
Lazy::new(|| Mutex::new(HashMap::new()));
pub fn register_callback<F>(name: &str, f: F) -> Result<(), FfiError>
where
F: Fn(FfiBuffer) -> FfiResult + Send + Sync + 'static,
{
let mut guard = CALLBACKS.lock().map_err(|_| FfiError::LockPoisoned)?;
guard.insert(name.to_string(), Box::new(f));
Ok(())
}
/// Removes the callback registered under `name`, if there is one.
///
/// Returns `Ok(true)` when an entry was removed and `Ok(false)` when no
/// callback was registered under that name.
///
/// # Errors
///
/// Returns [`FfiError::LockPoisoned`] when the registry mutex is poisoned.
pub fn unregister_callback(name: &str) -> Result<bool, FfiError> {
    match CALLBACKS.lock() {
        Ok(mut registry) => {
            let was_registered = registry.remove(name).is_some();
            Ok(was_registered)
        }
        Err(_) => Err(FfiError::LockPoisoned),
    }
}
/// Returns how many callbacks are currently registered.
///
/// A poisoned registry mutex is reported as `0`, so a zero result does not
/// by itself distinguish "empty" from "lock poisoned".
pub fn callback_count() -> usize {
    match CALLBACKS.lock() {
        Ok(registry) => registry.len(),
        Err(_) => 0,
    }
}
#[no_mangle]
pub unsafe extern "C" fn ffi_register_callback(
name: *const c_char,
cb: extern "C" fn(FfiBuffer) -> FfiResult,
) -> i32 {
let name_str = match cstr_to_string(name) {
Ok(s) => s,
Err(_) => return -1,
};
let wrapped = move |buf: FfiBuffer| cb(buf);
match CALLBACKS.lock() {
Ok(mut guard) => {
guard.insert(name_str, Box::new(wrapped));
0
}
Err(_) => -2,
}
}
/// C entry point: invokes the callback registered under `name` with `input`.
///
/// If `name` is rejected by `cstr_to_string`, that error is returned through
/// `FfiResult::err` and the registry is not touched. Otherwise the lookup and
/// the call run inside [`catch_panic`], which turns the closure's outcome
/// into the returned [`FfiResult`]:
///
/// * the callback reports success → its payload is passed through,
/// * no callback is registered under `name` → `FfiError::NotFound`,
/// * the registry mutex is poisoned → `FfiError::LockPoisoned`,
/// * the callback reports failure → `FfiError::Unknown` carrying the
///   callback name and its error code.
///
/// The registry lock is held while the callback runs. A callback must
/// therefore not call back into this registry (register, unregister, invoke
/// or count): re-locking a `std::sync::Mutex` on the thread that already
/// holds it deadlocks or panics.
///
/// # Safety
///
/// `name` is forwarded to `cstr_to_string`; it presumably has to be a valid
/// NUL-terminated C string (confirm against that helper).
#[no_mangle]
pub unsafe extern "C" fn ffi_invoke_callback(name: *const c_char, input: FfiBuffer) -> FfiResult {
    let name_str = match cstr_to_string(name) {
        Ok(s) => s,
        Err(e) => return FfiResult::err(e),
    };
    catch_panic(move || {
        let guard = CALLBACKS.lock().map_err(|_| FfiError::LockPoisoned)?;
        let cb = guard
            .get(&name_str)
            .ok_or_else(|| FfiError::NotFound(name_str.clone()))?;

        // A panic that unwinds through a live `MutexGuard` poisons the mutex,
        // which would make every later registry call fail with
        // `LockPoisoned`. Stop the unwind here, release the guard on the
        // normal path, then re-raise so the enclosing `catch_panic` still
        // observes the callback's panic.
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| cb(input)));
        drop(guard);
        let result = match outcome {
            Ok(result) => result,
            Err(panic_payload) => std::panic::resume_unwind(panic_payload),
        };

        if result.is_ok() {
            let FfiResult { payload, .. } = result;
            Ok(payload)
        } else {
            // NOTE(review): the failed `result` is simply dropped here;
            // confirm it owns nothing that must be released explicitly.
            Err(FfiError::Unknown(format!(
                "callback '{}' returned error code {:?}",
                name_str, result.error_code
            )))
        }
    })
}
#[no_mangle]
pub unsafe extern "C" fn ffi_unregister_callback(name: *const c_char) -> i32 {
let name_str = match cstr_to_string(name) {
Ok(s) => s,
Err(_) => return -1,
};
match CALLBACKS.lock() {
Ok(mut guard) => {
if guard.remove(&name_str).is_some() {
0
} else {
-1
}
}
Err(_) => -2,
}
}
/// C entry point: returns the number of registered callbacks.
///
/// Thin wrapper around [`callback_count`]; it takes no pointers and is
/// therefore a safe function.
#[no_mangle]
pub extern "C" fn ffi_callback_count() -> usize {
callback_count()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::{ffi_result_free, FfiErrorCode};

    /// Serialises the tests in this module that mutate the process-wide
    /// registry. The test harness runs tests on parallel threads by default,
    /// so without this the exact-count assertions below can observe another
    /// test's registration and fail spuriously.
    static REGISTRY_TEST_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));

    /// Takes the test lock, ignoring poisoning so that one failing test does
    /// not cascade into failures of the others.
    fn lock_registry_for_test() -> std::sync::MutexGuard<'static, ()> {
        REGISTRY_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Produces a registry key that no other test in this module uses.
    fn unique_name(prefix: &str) -> String {
        use std::sync::atomic::{AtomicU64, Ordering};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        format!("{prefix}_{}", COUNTER.fetch_add(1, Ordering::Relaxed))
    }

    #[test]
    fn register_and_invoke_rust_callback() {
        let _serial = lock_registry_for_test();
        let name = unique_name("test_echo");
        register_callback(&name, |buf| {
            // SAFETY: `buf` is the buffer this test built with `from_vec` below.
            let bytes = unsafe { buf.as_slice() }.to_vec();
            FfiResult::ok(FfiBuffer::from_vec(bytes))
        })
        .unwrap();

        let input = FfiBuffer::from_vec(b"test payload".to_vec());
        let c_name = std::ffi::CString::new(name.as_str()).unwrap();
        // SAFETY: `c_name` is a live, NUL-terminated string for the whole call.
        let result = unsafe { ffi_invoke_callback(c_name.as_ptr(), input) };

        assert!(result.is_ok());
        // SAFETY: the payload was produced by the echo callback via `from_vec`.
        let slice = unsafe { result.payload.as_slice() };
        assert_eq!(slice, b"test payload");
        ffi_result_free(result);

        // Leave the shared registry as we found it.
        assert!(unregister_callback(&name).unwrap());
    }

    #[test]
    fn invoke_unknown_callback_returns_not_found() {
        let c_name = std::ffi::CString::new("__nonexistent_callback__").unwrap();
        // SAFETY: `c_name` is a live, NUL-terminated string for the whole call.
        let result = unsafe { ffi_invoke_callback(c_name.as_ptr(), FfiBuffer::null()) };
        assert_eq!(result.error_code, FfiErrorCode::NotFound);
        ffi_result_free(result);
    }

    #[test]
    fn unregister_removes_callback() {
        let _serial = lock_registry_for_test();
        let name = unique_name("test_unregister");
        register_callback(&name, |buf| FfiResult::ok(buf)).unwrap();

        let removed = unregister_callback(&name).unwrap();
        assert!(removed);

        // A second removal finds nothing.
        let not_removed = unregister_callback(&name).unwrap();
        assert!(!not_removed);
    }

    #[test]
    fn callback_count_tracks_registrations() {
        // Exact-count assertions are only meaningful while no other test is
        // adding or removing entries.
        let _serial = lock_registry_for_test();
        let name = unique_name("test_count");
        let before = callback_count();

        register_callback(&name, |buf| FfiResult::ok(buf)).unwrap();
        assert_eq!(callback_count(), before + 1);

        unregister_callback(&name).unwrap();
        assert_eq!(callback_count(), before);
    }
}