use std::sync::Arc;
use either::Either;
use indexmap::IndexMap;
use serde_json::Value;
use crate::{
get_mut_arcmutex, search, MessageContent, NormalRequest, RequestMessage, Response,
ToolCallResponse, ToolChoice, WebSearchOptions,
};
use super::Engine;
/// Borrows the mutable chat history stored inside `request`.
///
/// # Panics
///
/// Hits `unreachable!()` when `request.messages` is neither
/// `RequestMessage::Chat` nor `RequestMessage::MultimodalChat`.
fn get_messages_mut(request: &mut NormalRequest) -> &mut Vec<IndexMap<String, MessageContent>> {
    let (RequestMessage::Chat { messages, .. }
    | RequestMessage::MultimodalChat { messages, .. }) = &mut request.messages
    else {
        unreachable!()
    };
    messages
}
fn build_tool_calls_field(tc: &ToolCallResponse) -> MessageContent {
let mut tc_map = IndexMap::new();
tc_map.insert("id".to_string(), Value::String(tc.id.clone()));
tc_map.insert("type".to_string(), Value::String("function".to_string()));
let mut function_map = serde_json::Map::new();
function_map.insert("name".to_string(), Value::String(tc.function.name.clone()));
let args_value = serde_json::from_str(&tc.function.arguments)
.unwrap_or(Value::String(tc.function.arguments.clone()));
function_map.insert("arguments".to_string(), args_value);
tc_map.insert("function".to_string(), Value::Object(function_map));
Either::Right(vec![tc_map])
}
/// Pushes an `assistant` message that records the tool call `tc`.
///
/// The message carries, in this order, `role = "assistant"`, an empty
/// `content` string, and a `tool_calls` entry built by
/// `build_tool_calls_field`.
fn append_assistant_tool_call(
    messages: &mut Vec<IndexMap<String, MessageContent>>,
    tc: &ToolCallResponse,
) {
    let fields: [(String, MessageContent); 3] = [
        ("role".to_string(), Either::Left("assistant".to_string())),
        ("content".to_string(), Either::Left(String::new())),
        ("tool_calls".to_string(), build_tool_calls_field(tc)),
    ];

    let mut entry: IndexMap<String, MessageContent> = IndexMap::new();
    entry.extend(fields);
    messages.push(entry);
}
/// Pushes a `tool` message holding the output of a tool invocation.
///
/// The message carries, in this order, `role = "tool"`, `name = tool_name`
/// and `content = content`.
fn append_tool_response(
    messages: &mut Vec<IndexMap<String, MessageContent>>,
    tool_name: &str,
    content: String,
) {
    let fields: [(String, MessageContent); 3] = [
        ("role".to_string(), Either::Left("tool".to_string())),
        ("name".to_string(), Either::Left(tool_name.to_string())),
        ("content".to_string(), Either::Left(content)),
    ];

    let mut entry: IndexMap<String, MessageContent> = IndexMap::new();
    entry.extend(fields);
    messages.push(entry);
}
/// Guarantees that the conversation opens with a system-level message.
///
/// If the first message's `role` is the plain string `"system"` or
/// `"developer"` the list is left untouched. In every other case (including
/// an empty list) a `system` message with empty `content` is inserted at
/// index 0.
fn ensure_system_message(messages: &mut Vec<IndexMap<String, MessageContent>>) {
    let leading_role = match messages.first().and_then(|msg| msg.get("role")) {
        Some(Either::Left(role)) => Some(role.as_str()),
        _ => None,
    };
    if matches!(leading_role, Some("system" | "developer")) {
        return;
    }

    let fields: [(String, MessageContent); 2] = [
        ("role".to_string(), Either::Left("system".to_string())),
        ("content".to_string(), Either::Left(String::new())),
    ];
    let mut placeholder: IndexMap<String, MessageContent> = IndexMap::new();
    placeholder.extend(fields);
    messages.insert(0, placeholder);
}
/// Separates responses the tool loop must inspect from those it only relays.
///
/// `Response::Done` and `Response::Chunk` are handed back to the caller.
/// Any other variant is sent to `user_sender` (a failed send is ignored) and
/// `None` is returned.
async fn forward_passthrough(
    resp: Response,
    user_sender: &tokio::sync::mpsc::Sender<Response>,
) -> Option<Response> {
    if matches!(resp, Response::Done(_) | Response::Chunk(_)) {
        return Some(resp);
    }
    let _ = user_sender.send(resp).await;
    None
}
use super::tool_dispatch;
/// Executes the search tool call `tc` and returns the follow-up request.
///
/// The assistant's tool call is appended to the history first, then the
/// search runs via `tool_dispatch::execute_search`, and its `content` is
/// appended as a `tool` message. `tool_choice` is set to `ToolChoice::Auto`
/// before the request is returned.
async fn do_search(
    engine: Arc<Engine>,
    mut request: NormalRequest,
    tc: &ToolCallResponse,
    opts: &WebSearchOptions,
) -> NormalRequest {
    append_assistant_tool_call(get_messages_mut(&mut request), tc);

    let outcome = tool_dispatch::execute_search(&engine, tc, opts).await;
    append_tool_response(
        get_messages_mut(&mut request),
        &tc.function.name,
        outcome.content,
    );

    request.tool_choice = Some(ToolChoice::Auto);
    request
}
/// Executes the extraction tool call `tc` and returns the follow-up request.
///
/// The assistant's tool call is appended to the history first, then the
/// extraction runs via `tool_dispatch::execute_extraction`, and its `content`
/// is appended as a `tool` message. `tool_choice` is set to
/// `ToolChoice::Auto` before the request is returned.
async fn do_extraction(
    engine: Arc<Engine>,
    mut request: NormalRequest,
    tc: &ToolCallResponse,
    opts: &WebSearchOptions,
) -> NormalRequest {
    append_assistant_tool_call(get_messages_mut(&mut request), tc);

    let outcome = tool_dispatch::execute_extraction(&engine, tc, opts).await;
    append_tool_response(
        get_messages_mut(&mut request),
        &tc.function.name,
        outcome.content,
    );

    request.tool_choice = Some(ToolChoice::Auto);
    request
}
/// Executes the custom tool call `tc` and returns the follow-up request.
///
/// The assistant's tool call is appended to the history first, then
/// `tool_dispatch::execute_custom_tool` runs, and its `content` is appended
/// as a `tool` message. `tool_choice` is set to `ToolChoice::Auto` before the
/// request is returned.
async fn do_custom_tool(
    engine: Arc<Engine>,
    mut request: NormalRequest,
    tc: &ToolCallResponse,
) -> NormalRequest {
    append_assistant_tool_call(get_messages_mut(&mut request), tc);

    // NOTE(review): the executor is called without `.await`; if it can block
    // for long it will stall this async task — confirm against its definition.
    let outcome = tool_dispatch::execute_custom_tool(&engine, tc);
    append_tool_response(
        get_messages_mut(&mut request),
        &tc.function.name,
        outcome.content,
    );

    request.tool_choice = Some(ToolChoice::Auto);
    request
}
/// Executes the tool call `tc` through the dispatch endpoint `url` and
/// returns the follow-up request.
///
/// The assistant's tool call is appended to the history first, then
/// `tool_dispatch::execute_http_tool` runs, and its `content` is appended as
/// a `tool` message. `tool_choice` is set to `ToolChoice::Auto` before the
/// request is returned.
fn do_http_tool(mut request: NormalRequest, tc: &ToolCallResponse, url: &str) -> NormalRequest {
    append_assistant_tool_call(get_messages_mut(&mut request), tc);

    let outcome = tool_dispatch::execute_http_tool(tc, url);
    append_tool_response(
        get_messages_mut(&mut request),
        &tc.function.name,
        outcome.content,
    );

    request.tool_choice = Some(ToolChoice::Auto);
    request
}
pub(super) async fn search_request(this: Arc<Engine>, request: NormalRequest) {
let web_search_options = request.web_search_options.clone();
let dispatch_url = request.tool_dispatch_url.clone();
let user_sender = request.response.clone();
let is_streaming = request.is_streaming;
let mut probe = request.clone();
if let Some(ref opts) = web_search_options {
probe
.tools
.get_or_insert_with(Vec::new)
.extend(search::get_search_tools(opts).unwrap());
}
if !this.tool_callbacks.is_empty() {
let tools = probe.tools.get_or_insert_with(Vec::new);
let existing_tool_names: Vec<String> =
tools.iter().map(|t| t.function.name.clone()).collect();
for (name, callback_with_tool) in &this.tool_callbacks {
if !existing_tool_names.contains(name) {
tools.push(callback_with_tool.tool.clone());
}
}
}
ensure_system_message(get_messages_mut(&mut probe));
probe.tool_choice = Some(ToolChoice::Auto);
probe.web_search_options = None;
let mut visible_req = probe.clone();
visible_req.response = user_sender.clone();
let this_clone = this.clone();
let handle = tokio::spawn(async move {
let mut current = probe;
let max_rounds = current.max_tool_rounds.unwrap_or(16);
let mut round = 0;
loop {
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
current.response = sender;
current.web_search_options = None;
current.max_tool_rounds = None;
current.tool_dispatch_url = None;
let _ = this_clone
.tx
.send(crate::request::Request::Normal(Box::new(current)))
.await;
if !is_streaming {
let resp = receiver.recv().await.unwrap();
let Some(resp) = forward_passthrough(resp, &user_sender).await else {
return;
};
let done = match resp {
Response::Done(done) => done,
_ => {
let _ = user_sender.send(resp).await;
return;
}
};
let tc_opt = match &done.choices[0].message.tool_calls {
Some(calls) if !calls.is_empty() => {
if calls.len() > 1 {
tracing::warn!(
"Model returned {} tool calls; executing only the first.",
calls.len()
);
}
Some(&calls[0])
}
_ => None,
};
if tc_opt.is_none() || round >= max_rounds {
user_sender
.send(Response::Done(done.clone()))
.await
.unwrap();
return;
}
let tc = tc_opt.unwrap();
let next_visible = if search::search_tool_called(&tc.function.name) {
let web_search_options = web_search_options.as_ref().unwrap();
if tc.function.name == search::SEARCH_TOOL_NAME {
do_search(this_clone.clone(), visible_req, tc, web_search_options).await
} else {
do_extraction(this_clone.clone(), visible_req, tc, web_search_options).await
}
} else if this_clone.tool_callbacks.contains_key(&tc.function.name) {
do_custom_tool(this_clone.clone(), visible_req, tc).await
} else if let Some(ref url) = dispatch_url {
do_http_tool(visible_req, tc, url)
} else {
user_sender
.send(Response::Done(done.clone()))
.await
.unwrap();
return;
};
round += 1;
visible_req = next_visible.clone();
visible_req.response = user_sender.clone();
current = visible_req.clone();
}
else {
let mut last_choice = None;
while let Some(resp) = receiver.recv().await {
let Some(resp) = forward_passthrough(resp, &user_sender).await else {
return;
};
match resp {
Response::Chunk(chunk) => {
let first_choice = &chunk.choices[0];
if first_choice.delta.tool_calls.is_none() {
let _ = user_sender.send(Response::Chunk(chunk.clone())).await;
}
last_choice = Some(first_choice.clone());
if last_choice
.as_ref()
.and_then(|c| c.finish_reason.as_ref())
.is_some()
{
break;
}
}
other => {
let _ = user_sender.send(other).await;
return;
}
}
}
let Some(choice) = last_choice else { break };
let tc_opt = match &choice.delta.tool_calls {
Some(calls) if !calls.is_empty() => {
if calls.len() > 1 {
tracing::warn!(
"Model returned {} tool calls; executing only the first.",
calls.len()
);
}
Some(&calls[0])
}
_ => None,
};
if tc_opt.is_none() || round >= max_rounds {
break;
}
let tc = tc_opt.unwrap();
let next_visible = if search::search_tool_called(&tc.function.name) {
let web_search_options = web_search_options.as_ref().unwrap();
if tc.function.name == search::SEARCH_TOOL_NAME {
do_search(this_clone.clone(), visible_req, tc, web_search_options).await
} else {
do_extraction(this_clone.clone(), visible_req, tc, web_search_options).await
}
} else if this_clone.tool_callbacks.contains_key(&tc.function.name) {
do_custom_tool(this_clone.clone(), visible_req, tc).await
} else if let Some(ref url) = dispatch_url {
do_http_tool(visible_req, tc, url)
} else {
break; };
round += 1;
visible_req = next_visible.clone();
visible_req.response = user_sender.clone();
current = visible_req.clone();
}
}
});
get_mut_arcmutex!(this.handles).push(handle);
}