use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use rs_genai::prelude::FunctionResponse;
use rs_genai::session::SessionWriter;
use crate::state::State;
use crate::tool::ToolDispatcher;
use crate::live::background_tool::BackgroundToolTracker;
use crate::live::callbacks::EventCallbacks;
use crate::live::events::LiveEvent;
use crate::live::extractor::{ExtractionTrigger, TurnExtractor};
use crate::live::phase::PhaseMachine;
use crate::live::transcript::TranscriptBuffer;
use super::extractors::run_extractors;
pub(in crate::live) async fn handle_tool_calls(
calls: Vec<rs_genai::prelude::FunctionCall>,
callbacks: &EventCallbacks,
dispatcher: &Option<Arc<ToolDispatcher>>,
writer: &Arc<dyn SessionWriter>,
state: &State,
phase_machine: &Option<tokio::sync::Mutex<PhaseMachine>>,
transcript_buffer: &mut TranscriptBuffer,
execution_modes: &std::collections::HashMap<
String,
crate::live::background_tool::ToolExecutionMode,
>,
background_tracker: &Option<Arc<BackgroundToolTracker>>,
extractors: &[Arc<dyn TurnExtractor>],
event_tx: &tokio::sync::broadcast::Sender<LiveEvent>,
) {
let (allowed_calls, rejected_responses) = if let Some(ref pm) = phase_machine {
let active_tools = {
let pm_guard = pm.lock().await;
pm_guard.active_tools().map(|t| t.to_vec())
};
if let Some(active_tools) = active_tools {
let mut allowed = Vec::new();
let mut rejected = Vec::new();
for call in calls {
if active_tools.iter().any(|t| t == &call.name) {
allowed.push(call);
} else {
rejected.push(FunctionResponse {
name: call.name.clone(),
response: serde_json::json!({
"error": format!(
"Tool '{}' is not available in the current conversation phase.",
call.name
)
}),
id: call.id.clone(),
scheduling: None,
});
}
}
(allowed, rejected)
} else {
(calls, Vec::new())
}
} else {
(calls, Vec::new())
};
let responses = if allowed_calls.is_empty() && !rejected_responses.is_empty() {
Some(rejected_responses.clone())
} else if let Some(cb) = &callbacks.on_tool_call {
let mut result = cb(allowed_calls.clone(), state.clone()).await;
if !rejected_responses.is_empty() {
let r = result.get_or_insert_with(Vec::new);
r.extend(rejected_responses.clone());
}
result
} else {
None
};
let (responses, background_spawns) = match responses {
Some(r) => (r, Vec::new()),
None => {
let mut results: Vec<FunctionResponse> = rejected_responses;
let mut bg_spawns: Vec<(
rs_genai::prelude::FunctionCall,
Option<Arc<dyn crate::live::background_tool::ResultFormatter>>,
)> = Vec::new();
if let Some(ref disp) = dispatcher {
for call in &allowed_calls {
let mode = execution_modes.get(&call.name);
match mode {
Some(crate::live::background_tool::ToolExecutionMode::Background {
formatter,
scheduling,
}) => {
let fmt: &dyn crate::live::background_tool::ResultFormatter = formatter
.as_ref()
.map(|f| f.as_ref())
.unwrap_or(&crate::live::background_tool::DefaultResultFormatter);
let ack = fmt.format_running(call);
results.push(FunctionResponse {
name: call.name.clone(),
response: ack,
id: call.id.clone(),
scheduling: *scheduling,
});
bg_spawns.push((call.clone(), formatter.clone()));
}
_ => {
match disp.call_function(&call.name, call.args.clone()).await {
Ok(result) => results.push(FunctionResponse {
name: call.name.clone(),
response: result,
id: call.id.clone(),
scheduling: None,
}),
Err(e) => results.push(FunctionResponse {
name: call.name.clone(),
response: serde_json::json!({"error": e.to_string()}),
id: call.id.clone(),
scheduling: None,
}),
}
}
}
}
} else if results.is_empty() {
#[cfg(feature = "tracing-support")]
tracing::warn!("Tool call received but no dispatcher or callback registered");
}
(results, bg_spawns)
}
};
let responses = if let Some(cb) = &callbacks.before_tool_response {
cb(responses, state.clone()).await
} else {
responses
};
for resp in &responses {
let args = allowed_calls
.iter()
.find(|c| c.name == resp.name)
.map(|c| &c.args)
.unwrap_or(&serde_json::Value::Null);
transcript_buffer.push_tool_call(resp.name.clone(), args, &resp.response);
let _ = event_tx.send(LiveEvent::ToolExecution {
name: resp.name.clone(),
args: args.clone(),
result: resp.response.clone(),
});
}
if !responses.is_empty() {
if let Err(_e) = writer.send_tool_response(responses).await {
#[cfg(feature = "tracing-support")]
tracing::error!("Failed to send tool response: {_e}");
}
}
for (call, formatter) in background_spawns {
let disp = dispatcher.clone();
let bg_writer = writer.clone();
let tracker = background_tracker.clone();
let call_id = call.id.clone().unwrap_or_default();
let cancel = CancellationToken::new();
let handle = tokio::spawn(async move {
let result = if let Some(ref d) = disp {
d.call_function(&call.name, call.args.clone())
.await
.map_err(|e| crate::error::ToolError::ExecutionFailed(e.to_string()))
} else {
Err(crate::error::ToolError::NotFound(call.name.clone()))
};
let fmt: &dyn crate::live::background_tool::ResultFormatter = formatter
.as_ref()
.map(|f| f.as_ref())
.unwrap_or(&crate::live::background_tool::DefaultResultFormatter);
let formatted = fmt.format_result(&call, result);
bg_writer
.send_tool_response(vec![FunctionResponse {
name: call.name.clone(),
response: formatted,
id: call.id.clone(),
scheduling: None,
}])
.await
.ok();
if let Some(ref t) = tracker {
t.remove(&call.id.clone().unwrap_or_default());
}
});
if let Some(ref t) = background_tracker {
t.spawn(call_id, handle, cancel);
}
}
let after_tool_extractors: Vec<Arc<dyn TurnExtractor>> = extractors
.iter()
.filter(|e| matches!(e.trigger(), ExtractionTrigger::AfterToolCall))
.cloned()
.collect();
run_extractors(
&after_tool_extractors,
transcript_buffer,
state,
callbacks,
event_tx,
)
.await;
}