use std::sync::Arc;
use dashmap::DashMap;
use serde_json::Value;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use rs_genai::prelude::FunctionCall;
use crate::error::ToolError;
/// Builds the JSON payloads reported for a background tool call.
///
/// Implementors must be `Send + Sync + 'static` so that a formatter can be
/// shared as an `Arc<dyn ResultFormatter>` (as `ToolExecutionMode::Background`
/// stores it).
pub trait ResultFormatter
where
    Self: Send + Sync + 'static,
{
    /// Payload describing `call` while it is still in progress.
    fn format_running(&self, call: &FunctionCall) -> Value;

    /// Payload describing the outcome of `call`, whether it succeeded or failed.
    fn format_result(&self, call: &FunctionCall, result: Result<Value, ToolError>) -> Value;

    /// Payload describing a cancelled call; only its id is supplied.
    fn format_cancelled(&self, call_id: &str) -> Value;
}
/// The built-in [`ResultFormatter`]: emits a flat JSON object carrying a
/// `status` string, the tool name, and the result value or error message.
///
/// It is a stateless unit type, so it is cheap to copy and trivially
/// default-constructible; `Debug` is derived so it can appear in diagnostics.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultResultFormatter;
impl ResultFormatter for DefaultResultFormatter {
fn format_running(&self, call: &FunctionCall) -> Value {
serde_json::json!({
"status": "running",
"tool": call.name,
})
}
fn format_result(&self, call: &FunctionCall, result: Result<Value, ToolError>) -> Value {
match result {
Ok(value) => serde_json::json!({
"status": "completed",
"tool": call.name,
"result": value,
}),
Err(e) => serde_json::json!({
"status": "error",
"tool": call.name,
"error": e.to_string(),
}),
}
}
fn format_cancelled(&self, call_id: &str) -> Value {
serde_json::json!({
"status": "cancelled",
"call_id": call_id,
})
}
}
/// Selects how a tool call is executed.
///
/// `Debug` is implemented by hand because `dyn ResultFormatter` has no `Debug` impl.
#[derive(Clone, Default)]
pub enum ToolExecutionMode {
    /// The default mode; presumably the ordinary inline call path — confirm at the dispatch site.
    #[default]
    Standard,
    /// Run the tool in the background.
    Background {
        /// Custom payload formatter; when `None` the caller presumably falls
        /// back to `DefaultResultFormatter` — TODO confirm at the dispatch site.
        formatter: Option<Arc<dyn ResultFormatter>>,
        /// Optional scheduling hint forwarded with the function response.
        scheduling: Option<rs_genai::prelude::FunctionResponseScheduling>,
    },
}
impl std::fmt::Debug for ToolExecutionMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Standard => write!(f, "Standard"),
Self::Background {
formatter,
scheduling,
} => {
write!(
f,
"Background(formatter={}, scheduling={:?})",
formatter.is_some(),
scheduling
)
}
}
}
}
/// Registry of in-flight background tool tasks, keyed by call id.
///
/// Each entry pairs the task's [`JoinHandle`] with the [`CancellationToken`]
/// registered alongside it, so a tracked task can be both signalled and aborted.
pub struct BackgroundToolTracker {
    /// `call_id -> (task handle, cancellation token)`.
    tasks: DashMap<String, (JoinHandle<()>, CancellationToken)>,
}
impl BackgroundToolTracker {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self {
            tasks: DashMap::new(),
        }
    }

    /// Registers `task` and its `cancel` token under `call_id`.
    ///
    /// If another task is already registered under the same `call_id`, that
    /// previous task is cancelled (token triggered, handle aborted) before it
    /// is replaced. Merely dropping its `JoinHandle` would detach the tokio task
    /// rather than stop it, leaving it running with no way to reach it through
    /// [`cancel`](Self::cancel) or [`cancel_all`](Self::cancel_all).
    pub fn spawn(&self, call_id: String, task: JoinHandle<()>, cancel: CancellationToken) {
        // `insert` returns the displaced value after its shard lock is released.
        if let Some((prev_handle, prev_token)) = self.tasks.insert(call_id, (task, cancel)) {
            Self::stop(prev_handle, prev_token);
        }
    }

    /// Cancels and forgets the task registered under each id in `call_ids`.
    ///
    /// Ids with no registered task are ignored.
    pub fn cancel(&self, call_ids: &[String]) {
        for id in call_ids {
            if let Some((_, (handle, token))) = self.tasks.remove(id) {
                Self::stop(handle, token);
            }
        }
    }

    /// Cancels and forgets every task registered at the time of the call.
    ///
    /// Tasks registered concurrently, after the key snapshot is taken, are left untouched.
    pub fn cancel_all(&self) {
        // Snapshot the keys first: DashMap may deadlock if the map is mutated
        // while an iterator over it is still alive.
        for key in self.active_ids() {
            if let Some((_, (handle, token))) = self.tasks.remove(&key) {
                Self::stop(handle, token);
            }
        }
    }

    /// Returns the ids of all currently registered tasks, in unspecified order.
    pub fn active_ids(&self) -> Vec<String> {
        self.tasks.iter().map(|r| r.key().clone()).collect()
    }

    /// Forgets the entry for `call_id` without cancelling or aborting it.
    ///
    /// Presumably intended for tasks that have already finished — confirm with callers.
    pub fn remove(&self, call_id: &str) {
        self.tasks.remove(call_id);
    }

    /// Number of currently registered tasks.
    pub fn active_count(&self) -> usize {
        self.tasks.len()
    }

    /// Signals `token` and then aborts `handle`; shared by every cancellation path.
    fn stop(handle: JoinHandle<()>, token: CancellationToken) {
        token.cancel();
        handle.abort();
    }
}
impl Default for BackgroundToolTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns a task that idles until its token is cancelled, registers it with
    /// `tracker` under `id`, and returns the token. Must run inside a tokio runtime.
    fn spawn_waiting(tracker: &BackgroundToolTracker, id: &str) -> CancellationToken {
        let token = CancellationToken::new();
        let waiter = token.clone();
        let handle = tokio::spawn(async move { waiter.cancelled().await });
        tracker.spawn(id.to_string(), handle, token.clone());
        token
    }

    /// Builds a `FunctionCall` fixture with fixed args and id.
    fn make_call(name: &str) -> FunctionCall {
        FunctionCall {
            name: name.to_string(),
            args: serde_json::json!({"query": "test"}),
            id: Some("fc_123".to_string()),
        }
    }

    #[test]
    fn tracker_new_is_empty() {
        let tracker = BackgroundToolTracker::new();
        assert_eq!(tracker.active_count(), 0);
        assert!(tracker.active_ids().is_empty());
    }

    #[tokio::test]
    async fn spawn_shows_active_id() {
        let tracker = BackgroundToolTracker::new();
        let token = spawn_waiting(&tracker, "call1");
        assert_eq!(tracker.active_ids(), vec!["call1".to_string()]);
        token.cancel();
    }

    #[tokio::test]
    async fn spawn_increments_active_count() {
        let tracker = BackgroundToolTracker::new();
        let tokens = [
            spawn_waiting(&tracker, "call1"),
            spawn_waiting(&tracker, "call2"),
        ];
        assert_eq!(tracker.active_count(), 2);
        for token in &tokens {
            token.cancel();
        }
    }

    #[tokio::test]
    async fn cancel_removes_task_and_cancels_token() {
        let tracker = BackgroundToolTracker::new();
        let token = spawn_waiting(&tracker, "call1");
        assert_eq!(tracker.active_count(), 1);
        tracker.cancel(&["call1".to_string()]);
        assert_eq!(tracker.active_count(), 0);
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn cancel_all_clears_all_tasks() {
        let tracker = BackgroundToolTracker::new();
        let tokens: Vec<CancellationToken> = ["call1", "call2", "call3"]
            .iter()
            .map(|id| spawn_waiting(&tracker, id))
            .collect();
        assert_eq!(tracker.active_count(), 3);
        tracker.cancel_all();
        assert_eq!(tracker.active_count(), 0);
        assert!(tokens.iter().all(CancellationToken::is_cancelled));
    }

    #[tokio::test]
    async fn remove_cleans_up_completed_task() {
        let tracker = BackgroundToolTracker::new();
        let token = spawn_waiting(&tracker, "call1");
        assert_eq!(tracker.active_count(), 1);
        tracker.remove("call1");
        assert_eq!(tracker.active_count(), 0);
        assert!(tracker.active_ids().is_empty());
        token.cancel();
    }

    #[test]
    fn cancel_nonexistent_id_is_noop() {
        let tracker = BackgroundToolTracker::new();
        tracker.cancel(&["nonexistent".to_string()]);
        assert_eq!(tracker.active_count(), 0);
    }

    #[test]
    fn format_running_output() {
        let out = DefaultResultFormatter.format_running(&make_call("search"));
        assert_eq!(out["status"], "running");
        assert_eq!(out["tool"], "search");
    }

    #[test]
    fn format_result_ok() {
        let payload = serde_json::json!({"items": [1, 2, 3]});
        let out = DefaultResultFormatter.format_result(&make_call("search"), Ok(payload.clone()));
        assert_eq!(out["status"], "completed");
        assert_eq!(out["tool"], "search");
        assert_eq!(out["result"], payload);
    }

    #[test]
    fn format_result_err() {
        let err = ToolError::ExecutionFailed("connection timeout".into());
        let out = DefaultResultFormatter.format_result(&make_call("search"), Err(err));
        assert_eq!(out["status"], "error");
        assert_eq!(out["tool"], "search");
        let message = out["error"].as_str().expect("error field should be a string");
        assert!(message.contains("connection timeout"));
    }

    #[test]
    fn format_cancelled_output() {
        let out = DefaultResultFormatter.format_cancelled("fc_456");
        assert_eq!(out["status"], "cancelled");
        assert_eq!(out["call_id"], "fc_456");
    }

    #[test]
    fn tool_execution_mode_default_is_standard() {
        assert!(matches!(
            ToolExecutionMode::default(),
            ToolExecutionMode::Standard
        ));
    }

    #[test]
    fn tool_execution_mode_debug_standard() {
        assert_eq!(format!("{:?}", ToolExecutionMode::Standard), "Standard");
    }

    #[test]
    fn tool_execution_mode_debug_background_none() {
        let rendered = format!(
            "{:?}",
            ToolExecutionMode::Background {
                formatter: None,
                scheduling: None,
            }
        );
        assert_eq!(rendered, "Background(formatter=false, scheduling=None)");
    }

    #[test]
    fn tool_execution_mode_debug_background_some() {
        let rendered = format!(
            "{:?}",
            ToolExecutionMode::Background {
                formatter: Some(Arc::new(DefaultResultFormatter)),
                scheduling: None,
            }
        );
        assert_eq!(rendered, "Background(formatter=true, scheduling=None)");
    }
}