use super::options::LLMEventHandlers;
use super::tool_registry::ToolRegistry;
use super::types::{Choice, LLMChunkResponse, LLMResponse, Message};
use crate::Error;
use regex::Regex;
use serde_json::Value;
pub fn extract_answer(response: LLMResponse) -> String {
response
.choices
.get(0)
.and_then(|choice| choice.message.as_ref().and_then(|msg| msg.content.clone()))
.unwrap_or_default()
}
/// Folds a list of streamed chunks into a single non-streaming [`LLMResponse`].
///
/// `final_content` is the already-concatenated text of the stream. It becomes the content of
/// the single returned choice. Response metadata (`id`, `object`, `created`, `model`,
/// `system_fingerprint`) is taken from the last chunk. The message role and finish reason come
/// from that chunk's first choice, and the role defaults to `"assistant"` when absent.
///
/// With no chunks at all, a placeholder response is returned: id/model `"unknown"`,
/// `created` 0, role `"assistant"`, and finish reason `"stop"`.
///
/// # Errors
///
/// The current implementation never returns `Err`.
pub fn parse_chunks_to_llm_response(
    mut chunks: Vec<LLMChunkResponse>,
    final_content: String,
) -> Result<LLMResponse, Error> {
    let tail = match chunks.pop() {
        Some(chunk) => chunk,
        None => {
            // Nothing was streamed: synthesise a minimal, finished assistant reply.
            return Ok(LLMResponse {
                id: "unknown".to_string(),
                object: "chat.completion".to_string(),
                created: 0,
                model: "unknown".to_string(),
                choices: vec![Choice {
                    index: 0,
                    message: Some(Message {
                        role: "assistant".to_string(),
                        content: Some(final_content),
                        tool_calls: None,
                        tool_call_id: None,
                    }),
                    finish_reason: Some("stop".to_string()),
                }],
                system_fingerprint: None,
            });
        }
    };

    // Only the first choice of the closing chunk carries the metadata we keep.
    let (finish_reason, streamed_role) = tail
        .choices
        .into_iter()
        .next()
        .map(|choice| (choice.finish_reason, choice.delta.role))
        .unwrap_or((None, None));

    let assembled = Message {
        role: streamed_role.unwrap_or_else(|| "assistant".to_string()),
        content: Some(final_content),
        tool_calls: None,
        tool_call_id: None,
    };

    Ok(LLMResponse {
        id: tail.id,
        object: tail.object,
        created: tail.created,
        model: tail.model,
        choices: vec![Choice {
            index: 0,
            message: Some(assembled),
            finish_reason,
        }],
        system_fingerprint: tail.system_fingerprint,
    })
}
pub fn strip_thinking(answer: &str) -> String {
if answer.contains("<think>") && answer.contains("</think>") {
let re_reasoning = Regex::new(r"(?s)<think>(.*?)</think>").unwrap();
re_reasoning.replace_all(answer, "").to_string()
} else {
answer.to_string()
}
}
/// Reports whether `resp` is asking the caller to run tools.
///
/// This is true only when the first choice has finish reason `"tool_calls"` and its message
/// carries a non-empty `tool_calls` list. A missing choice, message, or tool-call list
/// yields `false`.
pub(super) fn finish_reason_is_tool_calls(resp: &LLMResponse) -> bool {
    match resp.choices.first() {
        Some(choice) if choice.finish_reason.as_deref() == Some("tool_calls") => matches!(
            choice.message.as_ref().and_then(|m| m.tool_calls.as_ref()),
            Some(calls) if !calls.is_empty()
        ),
        _ => false,
    }
}
pub async fn handle_tool_calls(
final_response: &mut LLMResponse,
mut messages: Vec<Message>,
registry: Option<&ToolRegistry>,
handlers: &LLMEventHandlers,
) -> Result<Vec<Message>, Error> {
tracing::info!("handling tool calls");
let msg = final_response
.choices
.get(0)
.and_then(|c| c.message.as_ref())
.ok_or_else(|| Error::Message("No message in final_response".into()))?;
let tool_calls = msg
.tool_calls
.as_ref()
.ok_or_else(|| Error::Message("No tool calls in message".into()))?;
tracing::info!("{} call(s)", tool_calls.len());
messages.push(msg.clone());
for call in tool_calls {
tracing::debug!("tool call {}", call.function.name);
let args_str = call.function.arguments.as_str().unwrap_or_default();
tracing::trace!("tool args: {}", args_str);
if let Some(ref cb) = handlers.on_tool_call {
cb(call);
}
let args: Value = serde_json::from_str(args_str).unwrap_or_default();
let tool_result = match registry {
Some(r) => r.call(&call.function.name, args).await,
None => Err("No ToolRegistry provided but LLM requested a tool call".to_string()),
};
tracing::debug!("tool result: {:?}", tool_result);
if let Some(ref cb) = handlers.on_tool_result {
cb(call, &tool_result);
}
let tool_msg = Message {
role: "tool".to_string(),
content: Some(format!("{:?}", tool_result)),
tool_calls: None,
tool_call_id: Some(call.id.clone()),
};
messages.push(tool_msg);
}
Ok(messages)
}