use crate::providers::llm::{Message, ToolCallChunk};
use crate::providers::ProviderFactory;
use crate::services::config::AiConfig;
use anyhow::Result;
use async_trait::async_trait;
use futures::StreamExt;
use serde_json::{json, Value};
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedSender;
#[derive(Debug)]
pub enum ChatOutput {
Final { text: String, chat_id: String },
ToolCalls {
_calls: Vec<Value>,
_chat_id: String,
},
}
#[async_trait]
pub trait ChatToolHandler: Send + Sync {
async fn handle_tool_call(&self, tool_name: &str, arguments_json: &str) -> Result<String>;
fn take_pending_tools(&self) -> Option<Vec<Value>> {
None
}
}
pub struct ChatRequest {
pub agent_config: Value,
pub messages: Vec<Value>,
pub tools: Option<Vec<Value>>,
pub token_sender: Option<UnboundedSender<String>>,
pub tool_handler: Option<Arc<dyn ChatToolHandler>>,
}
pub struct LocalRuntime {
factory: ProviderFactory,
default_model: String,
}
impl Default for LocalRuntime {
fn default() -> Self {
Self::new()
}
}
impl LocalRuntime {
pub fn new() -> Self {
Self {
factory: ProviderFactory::new(),
default_model: "gpt-4o-mini".to_string(),
}
}
pub fn new_with_config(ai: &AiConfig) -> Self {
use crate::providers::llm::factory::LlmProviderConfig;
let configs: std::collections::HashMap<String, LlmProviderConfig> = ai
.providers
.iter()
.map(|(name, cfg)| {
(
name.clone(),
LlmProviderConfig {
api_key: cfg.api_key.clone(),
base_url: cfg.base_url.clone(),
},
)
})
.collect();
Self {
factory: ProviderFactory::new_with_config(&configs),
default_model: ai
.default_model
.clone()
.unwrap_or_else(|| "gpt-4o-mini".to_string()),
}
}
pub async fn chat(&self, req: ChatRequest) -> Result<ChatOutput> {
let model = req
.agent_config
.get("model")
.and_then(|v| v.as_str())
.unwrap_or(&self.default_model)
.to_string();
let system_prompt = req
.agent_config
.get("system_prompt")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let mut history: Vec<Message> = req
.messages
.iter()
.filter_map(|v| {
serde_json::from_value(v.clone()).ok().or_else(|| {
let role = v.get("role")?.as_str()?;
let content = v.get("content")?.as_str()?;
Some(Message {
role: role.to_string(),
parts: json!([{"type": "text", "content": content}]),
tool_calls: None,
tool_call_id: None,
})
})
})
.collect();
let mut tools = req.tools.clone();
for _ in 0..50 {
let (provider, actual_model) = self.factory.get_provider(&model);
let mut stream = provider
.stream_chat(
&actual_model,
system_prompt.clone(),
history.clone(),
tools.clone(),
)
.await?;
let mut text_acc = String::new();
let mut tool_accs: Vec<ToolCallAccumulator> = Vec::new();
let mut has_tool_finish = false;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result?;
if let Some(content) = &chunk.content {
text_acc.push_str(content);
if let Some(ref sender) = req.token_sender {
let _ = sender.send(content.clone());
}
}
if !chunk.tool_calls.is_empty() {
accumulate_tool_chunks(&mut tool_accs, &chunk.tool_calls);
}
if let Some(ref reason) = chunk.finish_reason {
let r = reason.to_lowercase();
if (r.contains("tool") || r.contains("end_turn")) && !tool_accs.is_empty() {
has_tool_finish = true;
}
}
}
if !has_tool_finish || tool_accs.is_empty() {
return Ok(ChatOutput::Final {
text: text_acc,
chat_id: String::new(),
});
}
let assembled_calls = accumulators_to_tool_calls(&tool_accs);
let handler = match &req.tool_handler {
Some(h) => h,
None => {
return Ok(ChatOutput::ToolCalls {
_calls: assembled_calls,
_chat_id: String::new(),
});
}
};
history.push(Message {
role: "assistant".to_string(),
parts: if text_acc.is_empty() {
json!([])
} else {
json!([{"type": "text", "content": text_acc}])
},
tool_calls: Some(json!(assembled_calls)),
tool_call_id: None,
});
for call in &assembled_calls {
let tool_name = call
.get("function")
.and_then(|f| f.get("name"))
.and_then(|n| n.as_str())
.unwrap_or("unknown");
let arguments = call
.get("function")
.and_then(|f| f.get("arguments"))
.and_then(|a| a.as_str())
.unwrap_or("{}");
let call_id = call.get("id").and_then(|v| v.as_str()).unwrap_or("");
let result = handler.handle_tool_call(tool_name, arguments).await?;
history.push(Message {
role: "tool".to_string(),
parts: json!([{"type": "text", "content": result}]),
tool_calls: None,
tool_call_id: Some(call_id.to_string()),
});
}
if let Some(new_tools) = handler.take_pending_tools() {
tools = Some(new_tools);
}
}
Err(anyhow::anyhow!(
"LocalRuntime: exceeded maximum tool call iterations (50)"
))
}
}
struct ToolCallAccumulator {
id: String,
name: String,
arguments: String,
}
fn accumulate_tool_chunks(accumulators: &mut Vec<ToolCallAccumulator>, chunks: &[ToolCallChunk]) {
for chunk in chunks {
let idx = chunk.index as usize;
while accumulators.len() <= idx {
accumulators.push(ToolCallAccumulator {
id: String::new(),
name: String::new(),
arguments: String::new(),
});
}
if let Some(id) = &chunk.id {
accumulators[idx].id = id.clone();
}
if let Some(name) = &chunk.name {
accumulators[idx].name = name.clone();
}
if let Some(args) = &chunk.arguments {
accumulators[idx].arguments.push_str(args);
}
}
}
fn accumulators_to_tool_calls(accs: &[ToolCallAccumulator]) -> Vec<Value> {
accs.iter()
.filter(|a| !a.name.is_empty())
.map(|a| {
json!({
"id": a.id,
"type": "function",
"function": {
"name": a.name,
"arguments": a.arguments,
}
})
})
.collect()
}