use std::ffi::{c_char, c_void, CString};
use std::sync::{Mutex, OnceLock};
use ringo_fm_sys as sys;
use async_trait::async_trait;
use std::collections::HashMap;
use crate::error::{Error, Result};
use crate::generated::{GeneratedContent, GeneratedContentTag};
use crate::handle::{check_error, ManagedRef};
use crate::schema::GenerationSchema;
/// A capability that can be invoked through the native bridge during generation.
///
/// Implementations are registered with [`ToolHandle::register`], which reads
/// [`name`](Tool::name), [`description`](Tool::description) and
/// [`parameters`](Tool::parameters) once while creating the native tool object.
#[async_trait]
pub trait Tool: 'static + Send + Sync {
    /// Identifier of the tool. Registration fails if it contains a NUL byte.
    fn name(&self) -> &str;

    /// Human-readable explanation of the tool. Registration fails if it contains a NUL byte.
    fn description(&self) -> &str;

    /// Schema describing the arguments that [`Tool::call`] receives.
    fn parameters(&self) -> Result<GenerationSchema>;

    /// Runs the tool with the generated arguments and returns its textual result.
    ///
    /// Calls coming from the native bridge are executed on a dedicated thread inside
    /// a current-thread tokio runtime; an `Err` is reported back as a JSON object of
    /// the form `{"error": "..."}`.
    async fn call(&self, input: GeneratedContent) -> Result<String>;
}
/// A tool that has been registered with the native bridge.
///
/// Holds a managed reference to the native tool object; the Rust implementation
/// itself lives in this module's dispatcher registry.
pub struct ToolHandle {
    // Managed reference to the native object returned by `FMBridgedToolCreate`.
    handle: ManagedRef<ToolTag>,
    // Registry key of this tool's dispatcher (kept for bookkeeping; not read here).
    _entry_key: usize,
}
/// Type tag distinguishing `ManagedRef`s that point at native tool objects.
pub(crate) struct ToolTag;
impl ToolHandle {
    /// Registers `tool` with the native bridge and returns a handle to the
    /// resulting native tool object.
    ///
    /// The tool's name, description and parameter schema are passed to
    /// `FMBridgedToolCreate`, with `tool_trampoline` as the native callback.
    /// The tool itself is stored in the process-wide dispatcher registry and
    /// bound to the native tool pointer, so the callback can route calls back to
    /// [`Tool::call`].
    ///
    /// # Errors
    ///
    /// - [`Error::Native`] if the name or description contains an interior NUL byte.
    /// - Any error returned by [`Tool::parameters`].
    /// - Any error reported by `check_error` or `ManagedRef::from_owned` when the
    ///   native tool could not be created.
    ///
    /// On every error path the registry is left untouched.
    pub fn register<T: Tool>(tool: T) -> Result<Self> {
        let name_c = CString::new(tool.name()).map_err(|e| Error::Native(e.to_string()))?;
        let desc_c = CString::new(tool.description()).map_err(|e| Error::Native(e.to_string()))?;
        let schema = tool.parameters()?;

        let mut code: i32 = 0;
        let mut desc: *mut c_char = std::ptr::null_mut();
        // SAFETY: `name_c`, `desc_c` and `schema` are locals that stay alive for the
        // whole call, and `code` / `desc` are valid, writable out-pointers.
        let ptr = unsafe {
            sys::FMBridgedToolCreate(
                name_c.as_ptr(),
                desc_c.as_ptr(),
                schema.as_ptr(),
                Some(tool_trampoline),
                &mut code,
                &mut desc,
            )
        };
        check_error(code, desc)?;
        let handle = ManagedRef::<ToolTag>::from_owned(ptr)?;

        // Publish the dispatcher only once the native tool exists. Inserting it
        // before creation leaked a registry entry whenever creation failed. The
        // callback only looks at bound entries, and binding already happened after
        // creation, so this ordering opens no new window.
        let dispatcher: Box<dyn Dispatcher> = Box::new(ToolDispatcher { tool });
        let key = registry().insert(dispatcher);
        registry().bind_tool(handle.as_ptr() as usize, key);
        Ok(Self { handle, _entry_key: key })
    }

    /// Raw pointer to the native tool object, for passing to other bridge calls.
    pub(crate) fn as_ptr(&self) -> *const c_void {
        self.handle.as_ptr()
    }
}
/// Object-safe, type-erased view of a [`Tool`], so tools of different concrete
/// types can be stored side by side in the registry.
#[async_trait]
trait Dispatcher: Send + Sync {
    /// Forwards `input` to the wrapped tool and returns its result unchanged.
    async fn dispatch(&self, input: GeneratedContent) -> Result<String>;
}

/// Adapter that hides a concrete tool type behind [`Dispatcher`].
struct ToolDispatcher<T> {
    tool: T,
}

#[async_trait]
impl<T> Dispatcher for ToolDispatcher<T>
where
    T: Tool,
{
    async fn dispatch(&self, input: GeneratedContent) -> Result<String> {
        let Self { tool } = self;
        tool.call(input).await
    }
}
/// Process-wide table that maps native tool pointers to their Rust dispatchers.
struct Registry {
    /// All registry state, guarded by a single lock.
    inner: Mutex<RegistryInner>,
}

/// Lock-protected contents of [`Registry`].
#[derive(Default)]
struct RegistryInner {
    /// Dispatchers by registry key. Stored as `Arc` so a caller can keep one alive
    /// after the lock has been released.
    dispatchers: HashMap<usize, std::sync::Arc<dyn Dispatcher>>,
    /// Native tool pointer (as an address) -> registry key of its dispatcher.
    tool_to_key: HashMap<usize, usize>,
    /// Key that the next `Registry::insert` will hand out.
    next_key: usize,
}
impl Registry {
    /// Stores `d` under a freshly allocated key and returns that key.
    fn insert(&self, d: Box<dyn Dispatcher>) -> usize {
        let mut inner = self.inner.lock().unwrap();
        let key = inner.next_key;
        inner.next_key = key + 1;
        let shared: std::sync::Arc<dyn Dispatcher> = d.into();
        inner.dispatchers.insert(key, shared);
        key
    }

    /// Records that the native tool at address `tool_ptr` is served by the
    /// dispatcher stored under `key`.
    fn bind_tool(&self, tool_ptr: usize, key: usize) {
        let mut inner = self.inner.lock().unwrap();
        inner.tool_to_key.insert(tool_ptr, key);
    }

    /// Returns a shared handle to the dispatcher stored under `key`, if there is one.
    fn dispatcher_for(&self, key: usize) -> Option<std::sync::Arc<dyn Dispatcher>> {
        let inner = self.inner.lock().unwrap();
        inner.dispatchers.get(&key).map(std::sync::Arc::clone)
    }
}
/// Returns the process-wide registry, creating an empty one on first use.
fn registry() -> &'static Registry {
    static GLOBAL_REGISTRY: OnceLock<Registry> = OnceLock::new();
    GLOBAL_REGISTRY.get_or_init(|| {
        let inner = Mutex::new(RegistryInner::default());
        Registry { inner }
    })
}
/// Native callback installed by `ToolHandle::register` for every tool.
///
/// Takes ownership of `content`, looks up a dispatcher in the registry, and then
/// runs the tool on a freshly spawned thread inside a current-thread tokio
/// runtime. The result is delivered with `FMBridgedToolFinishCall`. On failure,
/// the payload is a JSON object `{"error": "..."}`. Failures include a missing
/// dispatcher, invalid input, runtime creation failure, a tool error, and a
/// tool panic. Whenever a tool pointer is known, the call is always finished.
///
/// NOTE(review): the callback receives no tool pointer or user data, so it cannot
/// tell which registered tool was invoked. It dispatches to the most recently
/// registered tool (highest registry key). Previously it picked an arbitrary one,
/// because `HashMap` iteration order is unspecified. Supporting several tools
/// correctly needs a per-tool callback or a context argument from the native API.
///
/// # Safety
///
/// Must only be invoked by the native bridge, with a `content` reference whose
/// ownership is transferred to this function.
unsafe extern "C" fn tool_trampoline(content: sys::FMGeneratedContentRef, call_id: u32) {
    // Wrap the incoming reference first so it is dropped through `ManagedRef`
    // on every early return below.
    let input = ManagedRef::<GeneratedContentTag>::from_owned(content)
        .ok()
        .map(|handle| GeneratedContent { handle });

    // Read pointer and key together so they always belong to the same binding.
    let newest = {
        let g = registry().inner.lock().unwrap();
        g.tool_to_key
            .iter()
            .max_by_key(|&(_, &key)| key)
            .map(|(&ptr, &key)| (ptr, key))
    };
    // Nothing is bound yet, so there is no native tool to report a result to.
    let Some((tool_ptr, key)) = newest else { return };

    let Some(dispatcher) = registry().dispatcher_for(key) else {
        finish_tool_call(tool_ptr, call_id, tool_error_json("no dispatcher registered for tool"));
        return;
    };
    let Some(input) = input else {
        finish_tool_call(tool_ptr, call_id, tool_error_json("invalid tool input"));
        return;
    };

    std::thread::spawn(move || {
        let output = match tokio::runtime::Builder::new_current_thread().enable_all().build() {
            Ok(rt) => {
                // A panicking tool must still finish the native call.
                let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                    rt.block_on(dispatcher.dispatch(input))
                }));
                match outcome {
                    Ok(Ok(text)) => text,
                    Ok(Err(e)) => tool_error_json(&e.to_string()),
                    Err(_) => tool_error_json("tool panicked"),
                }
            }
            Err(e) => tool_error_json(&e.to_string()),
        };
        finish_tool_call(tool_ptr, call_id, output);
    });
}

/// Delivers `output` for `call_id` to the native tool at address `tool_ptr`.
///
/// Interior NUL bytes cannot cross the C string boundary, so they are stripped
/// instead of dropping the result, which would leave the call unfinished.
fn finish_tool_call(tool_ptr: usize, call_id: u32, output: String) {
    if tool_ptr == 0 {
        return;
    }
    let out_c = CString::new(output).unwrap_or_else(|err| {
        let mut bytes = err.into_vec();
        bytes.retain(|&b| b != 0);
        CString::new(bytes).expect("all NUL bytes were removed")
    });
    // SAFETY: `tool_ptr` is an address recorded by `bind_tool` for a native tool
    // object and `out_c` outlives the call. NOTE(review): this assumes the owning
    // `ToolHandle` has not been dropped while the call was in flight.
    unsafe { sys::FMBridgedToolFinishCall(tool_ptr as *const c_void, call_id, out_c.as_ptr()) };
}

/// Builds the `{"error": "..."}` payload for a failed call. The message is escaped
/// so the result is always valid JSON: quotes, backslashes and control characters
/// (including NUL) are escaped.
fn tool_error_json(message: &str) -> String {
    let mut escaped = String::with_capacity(message.len() + 2);
    for ch in message.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            c if u32::from(c) < 0x20 => escaped.push_str(&format!("\\u{:04x}", u32::from(c))),
            c => escaped.push(c),
        }
    }
    format!("{{\"error\": \"{escaped}\"}}")
}