use std::{
collections::HashMap,
ffi::{c_char, c_int},
sync::{mpsc, Mutex, OnceLock},
};
use crate::{error::BamlError, ffi::bindings};
pub enum CallbackResult {
Partial(Vec<u8>),
Final(Vec<u8>),
Error(BamlError),
}
struct SyncCallbackData {
sender: mpsc::Sender<CallbackResult>,
}
struct AsyncCallbackData {
sender: async_channel::Sender<CallbackResult>,
}
enum CallbackData {
Sync(SyncCallbackData),
Async(AsyncCallbackData),
}
static CALLBACKS: OnceLock<Mutex<HashMap<u32, CallbackData>>> = OnceLock::new();
static NEXT_ID: OnceLock<Mutex<u32>> = OnceLock::new();
fn get_callbacks() -> &'static Mutex<HashMap<u32, CallbackData>> {
CALLBACKS.get_or_init(|| Mutex::new(HashMap::new()))
}
fn get_next_id() -> &'static Mutex<u32> {
NEXT_ID.get_or_init(|| Mutex::new(1))
}
pub fn initialize_callbacks() -> Result<(), baml_sys::BamlSysError> {
static INIT_ERROR: OnceLock<Option<String>> = OnceLock::new();
let error_msg = INIT_ERROR.get_or_init(|| {
#[allow(unsafe_code)]
match unsafe {
bindings::register_callbacks(result_callback, error_callback, on_tick_callback)
} {
Ok(()) => None,
Err(e) => Some(e.to_string()),
}
});
match error_msg {
None => Ok(()),
Some(msg) => Err(baml_sys::BamlSysError::LibraryNotFound {
searched_paths: vec![std::path::PathBuf::from(msg.clone())],
}),
}
}
fn allocate_callback_id(callbacks: &mut HashMap<u32, CallbackData>) -> u32 {
let mut next_id = get_next_id().lock().unwrap();
let mut id = *next_id;
loop {
if id != 0 && !callbacks.contains_key(&id) {
break;
}
id = id.wrapping_add(1);
assert!(id != *next_id, "callback ID space exhausted");
}
*next_id = id.wrapping_add(1);
id
}
pub fn create_callback() -> (u32, mpsc::Receiver<CallbackResult>) {
let (sender, receiver) = mpsc::channel();
let mut callbacks = get_callbacks().lock().unwrap();
let id = allocate_callback_id(&mut callbacks);
callbacks.insert(id, CallbackData::Sync(SyncCallbackData { sender }));
drop(callbacks);
(id, receiver)
}
pub fn create_async_callback() -> (u32, async_channel::Receiver<CallbackResult>) {
let (sender, receiver) = async_channel::unbounded();
let mut callbacks = get_callbacks().lock().unwrap();
let id = allocate_callback_id(&mut callbacks);
callbacks.insert(id, CallbackData::Async(AsyncCallbackData { sender }));
drop(callbacks);
(id, receiver)
}
pub fn remove_callback(id: u32) {
let mut callbacks = get_callbacks().lock().unwrap();
callbacks.remove(&id);
}
extern "C" fn result_callback(call_id: u32, is_done: c_int, content: *const c_char, length: usize) {
let data = if !content.is_null() && length > 0 {
#[allow(unsafe_code)]
let slice = unsafe { std::slice::from_raw_parts(content.cast::<u8>(), length) };
slice.to_vec()
} else {
Vec::new()
};
let result = if is_done != 0 {
CallbackResult::Final(data)
} else {
CallbackResult::Partial(data)
};
let callbacks = get_callbacks().lock().unwrap();
if let Some(cb_data) = callbacks.get(&call_id) {
match cb_data {
CallbackData::Sync(sync_data) => {
let _ = sync_data.sender.send(result);
}
CallbackData::Async(async_data) => {
let _ = async_data.sender.send_blocking(result);
}
}
}
drop(callbacks);
if is_done != 0 {
remove_callback(call_id);
}
}
extern "C" fn error_callback(call_id: u32, _is_done: c_int, content: *const c_char, length: usize) {
let error_msg = if !content.is_null() && length > 0 {
#[allow(unsafe_code)]
let slice = unsafe { std::slice::from_raw_parts(content.cast::<u8>(), length) };
String::from_utf8_lossy(slice).into_owned()
} else {
"Unknown error".to_string()
};
let callbacks = get_callbacks().lock().unwrap();
if let Some(cb_data) = callbacks.get(&call_id) {
let error = CallbackResult::Error(BamlError::internal(error_msg));
match cb_data {
CallbackData::Sync(sync_data) => {
let _ = sync_data.sender.send(error);
}
CallbackData::Async(async_data) => {
let _ = async_data.sender.send_blocking(error);
}
}
}
drop(callbacks);
remove_callback(call_id);
}
extern "C" fn on_tick_callback(_call_id: u32) {
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_callback_id_generation() {
let (id1, _rx1) = create_callback();
let (id2, _rx2) = create_callback();
let (id3, _rx3) = create_callback();
assert_ne!(id1, id2);
assert_ne!(id2, id3);
assert_ne!(id1, id3);
remove_callback(id1);
remove_callback(id2);
remove_callback(id3);
}
#[test]
fn test_callback_removal() {
let (id, _rx) = create_callback();
{
let callbacks = get_callbacks().lock().unwrap();
assert!(callbacks.contains_key(&id));
}
remove_callback(id);
{
let callbacks = get_callbacks().lock().unwrap();
assert!(!callbacks.contains_key(&id));
}
}
#[test]
fn test_async_callback_id_generation() {
let (id1, _rx1) = create_async_callback();
let (id2, _rx2) = create_async_callback();
let (id3, _rx3) = create_async_callback();
assert_ne!(id1, id2);
assert_ne!(id2, id3);
assert_ne!(id1, id3);
remove_callback(id1);
remove_callback(id2);
remove_callback(id3);
}
#[test]
fn test_mixed_sync_async_callbacks() {
let (sync_id, _rx_sync) = create_callback();
let (async_id, _rx_async) = create_async_callback();
let (sync_id2, _rx_sync2) = create_callback();
assert_ne!(sync_id, async_id);
assert_ne!(async_id, sync_id2);
assert_ne!(sync_id, sync_id2);
remove_callback(sync_id);
remove_callback(async_id);
remove_callback(sync_id2);
}
}