use super::Callable;
use crate::kernel::cost::TokenUsage;
use crate::providers::{
ChatMessage, ChatRequest, ChatTool, ChatToolFunction, ContentPart, MessageToolCall,
ModelProvider, ToolChoice,
};
use crate::routing::{ModelRouter, RoutingDecision, RoutingPolicy};
use crate::streaming::{EventEmitter, StreamEvent};
use crate::tool::{DynTool, Tool};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::{Arc, Mutex};
use tokio::sync::mpsc;
use tokio::time::{interval, Duration};
/// A provider-agnostic tool invocation extracted from a model response.
///
/// Built from native [`MessageToolCall`]s by `message_tool_calls_to_internal`,
/// with the JSON-string arguments already parsed into a [`Value`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call id; echoed back in the tool-result message.
    pub id: String,
    /// Name of the tool to execute (matched against `Tool::name`).
    pub name: String,
    /// Parsed JSON arguments (`Value::Null` if the raw string failed to parse).
    pub arguments: Value,
}
/// Text-plus-images input smuggled through the string-based `run` interface
/// as JSON, tagged with a `"__multimodal__"` marker field so plain text can
/// be distinguished from an encoded multimodal payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalInput {
    // Serialized as "__multimodal__"; `parse` keys off this field to detect
    // a multimodal payload.
    #[serde(rename = "__multimodal__")]
    pub multimodal_marker: bool,
    /// The text portion of the user message.
    pub text: String,
    // Optional so text-only payloads deserialize without an images array.
    #[serde(default)]
    pub images: Vec<MultimodalImage>,
}
/// A single image attachment carried inside a [`MultimodalInput`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalImage {
    /// Image bytes encoded with standard base64 (see `MultimodalInput::new`).
    pub data: String,
    /// MIME type of the image, e.g. "image/png".
    pub mime_type: String,
}
impl MultimodalInput {
    /// Builds a multimodal input from text plus raw image bytes.
    ///
    /// Each `(bytes, mime_type)` pair is base64-encoded with the standard
    /// alphabet and stored as a [`MultimodalImage`]. The marker field is set
    /// so [`MultimodalInput::parse`] can recognize the serialized form.
    pub fn new(text: impl Into<String>, images: Vec<(Vec<u8>, String)>) -> Self {
        use base64::Engine;
        Self {
            multimodal_marker: true,
            text: text.into(),
            images: images
                .into_iter()
                .map(|(data, mime_type)| MultimodalImage {
                    data: base64::engine::general_purpose::STANDARD.encode(&data),
                    mime_type,
                })
                .collect(),
        }
    }
    /// Serializes to the JSON wire format; falls back to the bare text if
    /// serialization fails (so the message is never silently dropped).
    pub fn to_json(&self) -> String {
        serde_json::to_string(self).unwrap_or_else(|_| self.text.clone())
    }
    /// Attempts to interpret `input` as a serialized multimodal payload.
    ///
    /// Returns `None` for ordinary text. The previous implementation required
    /// the exact compact prefix `{"__multimodal__":`, which rejected
    /// otherwise-valid payloads containing whitespace (e.g. pretty-printed
    /// JSON) or reordered fields, and accepted payloads whose marker was
    /// `false`. Now: a cheap pre-filter (object start + marker key present)
    /// gates full deserialization, and the marker must actually be `true`.
    pub fn parse(input: &str) -> Option<Self> {
        let trimmed = input.trim_start();
        // Fast rejection of plain text before paying for deserialization.
        if !trimmed.starts_with('{') || !trimmed.contains("\"__multimodal__\"") {
            return None;
        }
        serde_json::from_str::<Self>(trimmed)
            .ok()
            .filter(|m| m.multimodal_marker)
    }
}
/// OpenAI-style tool descriptor ({"type": "function", "function": {...}}).
///
/// Serialize-only: sent to providers, never read back.
#[derive(Debug, Clone, Serialize)]
pub struct ToolSchema {
    // Serialized as "type"; always "function" (see `from_tool`).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function name/description/parameters payload.
    pub function: FunctionSchema,
}
/// The function portion of a [`ToolSchema`], mirroring the provider wire
/// format for declaring a callable function.
#[derive(Debug, Clone, Serialize)]
pub struct FunctionSchema {
    /// Tool name as exposed to the model.
    pub name: String,
    /// Human-readable description the model uses to decide when to call it.
    pub description: String,
    /// JSON Schema of the tool's arguments.
    pub parameters: Value,
}
impl ToolSchema {
pub fn from_tool(tool: &dyn Tool) -> Self {
Self {
tool_type: "function".to_string(),
function: FunctionSchema {
name: tool.name().to_string(),
description: tool.description().to_string(),
parameters: tool.parameters_schema(),
},
}
}
}
/// An LLM-backed callable: wraps a model provider, a system prompt, and a
/// set of tools, and runs a bounded tool-calling loop per invocation.
pub struct LlmCallable {
    // Identifier returned by `Callable::name`.
    name: String,
    // Optional human-readable description (see `with_description`).
    description: Option<String>,
    // Injected as the first (system) message of every conversation.
    system_prompt: String,
    // Backend used for chat completions.
    provider: Arc<dyn ModelProvider>,
    // Caller-requested model, if any; resolved via `resolve_routing`.
    requested_model: Option<String>,
    // Policy consulted by `ModelRouter::resolve`.
    routing_policy: RoutingPolicy,
    // Tools offered to the model and dispatched by `execute_tool`.
    tools: Vec<DynTool>,
    // Upper bound on chat round-trips in `run` (default 10).
    max_iterations: usize,
    // Optional event sink; drained into the stream channel by `run_streaming`.
    emitter: Option<Arc<EventEmitter>>,
    // Token usage of the most recent `run`, reset at the start of each run.
    last_usage: Mutex<Option<TokenUsage>>,
}
impl LlmCallable {
pub fn with_provider(
name: impl Into<String>,
system_prompt: impl Into<String>,
provider: Arc<dyn ModelProvider>,
) -> Self {
Self {
name: name.into(),
description: None,
system_prompt: system_prompt.into(),
provider,
requested_model: None,
routing_policy: RoutingPolicy::default(),
tools: Vec::new(),
max_iterations: 10,
emitter: None,
last_usage: Mutex::new(None),
}
}
pub fn with_emitter(mut self, emitter: Arc<EventEmitter>) -> Self {
self.emitter = Some(emitter);
self
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.requested_model = Some(model.into());
self
}
pub fn with_routing_policy(mut self, policy: RoutingPolicy) -> Self {
self.routing_policy = policy;
self
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn add_tool(mut self, tool: impl Tool + 'static) -> Self {
self.tools.push(Arc::new(tool));
self
}
pub fn add_tools(mut self, tools: Vec<DynTool>) -> Self {
self.tools.extend(tools);
self
}
pub fn max_iterations(mut self, max: usize) -> Self {
self.max_iterations = max;
self
}
async fn execute_tool(&self, name: &str, args: Value) -> anyhow::Result<Value> {
let tool = self
.tools
.iter()
.find(|t| t.name() == name)
.ok_or_else(|| anyhow::anyhow!("Tool '{}' not found", name))?;
tool.execute(args).await
}
fn build_chat_tools(&self) -> Vec<ChatTool> {
self.tools
.iter()
.map(|t| ChatTool {
tool_type: "function".to_string(),
function: ChatToolFunction {
name: t.name().to_string(),
description: t.description().to_string(),
parameters: t.parameters_schema(),
},
})
.collect()
}
fn message_tool_calls_to_internal(&self, tool_calls: &[MessageToolCall]) -> Vec<ToolCall> {
tool_calls
.iter()
.map(|tc| {
let arguments = serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null);
ToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments,
}
})
.collect()
}
fn resolve_routing(&self) -> RoutingDecision {
ModelRouter::resolve(
self.requested_model.as_deref(),
self.provider.as_ref(),
&self.routing_policy,
)
}
}
#[async_trait]
impl Callable for LlmCallable {
    fn name(&self) -> &str {
        &self.name
    }
    fn description(&self) -> Option<&str> {
        self.description.as_deref()
    }
    /// Runs the callable while forwarding emitter events to `event_tx`.
    ///
    /// A background task drains the emitter every 50ms. Previously that task
    /// only exited when the receiver was dropped (i.e. when `tx.send` failed),
    /// so awaiting its join handle deadlocked whenever the caller kept the
    /// receiver open — the normal streaming usage. A shared shutdown flag now
    /// tells the task to perform one final drain and exit after `run`
    /// completes, which also makes the old manual post-run drain unnecessary.
    async fn run_streaming(
        &self,
        input: &str,
        event_tx: mpsc::Sender<StreamEvent>,
    ) -> anyhow::Result<String> {
        use std::sync::atomic::{AtomicBool, Ordering};
        let done = Arc::new(AtomicBool::new(false));
        let poll_handle = if let Some(emitter) = self.emitter.clone() {
            let tx = event_tx.clone();
            let done = Arc::clone(&done);
            Some(tokio::spawn(async move {
                let mut interval = interval(Duration::from_millis(50));
                loop {
                    interval.tick().await;
                    // Read the flag BEFORE draining: anything emitted before
                    // `done` was set is guaranteed to be in this final drain,
                    // and nothing is emitted after `run` returns.
                    let finished = done.load(Ordering::SeqCst);
                    for ev in emitter.drain() {
                        if tx.send(ev).await.is_err() {
                            // Receiver gone; nobody is listening anymore.
                            return;
                        }
                    }
                    if finished {
                        return;
                    }
                }
            }))
        } else {
            None
        };
        let result = self.run(input).await;
        done.store(true, std::sync::atomic::Ordering::SeqCst);
        if let Some(h) = poll_handle {
            // Wait for the final drain so no trailing events are lost.
            let _ = h.await;
        }
        // Drop our sender so the receiver observes channel closure once the
        // poll task's clone (if any) is gone too.
        drop(event_tx);
        result
    }
    /// Core agentic loop: send the conversation, execute any requested tool
    /// calls, feed results back, and repeat until the model answers with
    /// plain content or `max_iterations` is exhausted.
    ///
    /// # Errors
    /// Fails if tools are registered but the provider lacks native tool
    /// support, if the provider call fails, if a requested tool is missing or
    /// errors, or if the iteration cap is reached.
    async fn run(&self, input: &str) -> anyhow::Result<String> {
        // Reset usage so a failed run doesn't report stale numbers.
        *self.last_usage.lock().expect("last_usage mutex") = None;
        if !self.tools.is_empty() && !self.provider.capabilities().supports_tools {
            anyhow::bail!(
                "Callable has {} tool(s) but provider does not support native tools (supports_tools is false)",
                self.tools.len()
            );
        }
        let routing = self.resolve_routing();
        tracing::info!(
            callable = %self.name,
            logical_model = %routing.logical_model,
            concrete_model = %routing.concrete_model,
            profile = ?routing.profile,
            confidence = routing.confidence,
            used_default_router = routing.used_default_router,
            rationale = %routing.rationale,
            "Model routing decision resolved"
        );
        // Detect a serialized multimodal payload; otherwise treat input as
        // plain text.
        let user_message = if let Some(multimodal) = MultimodalInput::parse(input) {
            tracing::debug!(
                image_count = multimodal.images.len(),
                text_len = multimodal.text.len(),
                "Processing multimodal input with images"
            );
            if !self.provider.capabilities().supports_vision {
                tracing::warn!(
                    "Provider does not support vision, falling back to text-only. \
                     Images will be ignored. Consider using a vision-capable model."
                );
                ChatMessage::user(&multimodal.text)
            } else {
                use base64::Engine;
                let mut parts = vec![ContentPart::text(&multimodal.text)];
                for img in &multimodal.images {
                    // Decode-then-re-encode validates the base64 and
                    // normalizes it to canonical form before sending.
                    if let Ok(data) = base64::engine::general_purpose::STANDARD.decode(&img.data) {
                        parts.push(ContentPart::image_base64(
                            base64::engine::general_purpose::STANDARD.encode(&data),
                            &img.mime_type,
                        ));
                    } else {
                        tracing::warn!(mime_type = %img.mime_type, "Failed to decode image base64 data");
                    }
                }
                ChatMessage {
                    role: "user".to_string(),
                    content: None,
                    multimodal_content: Some(parts),
                    tool_calls: None,
                    tool_call_id: None,
                }
            }
        } else {
            ChatMessage::user(input)
        };
        let mut messages = vec![ChatMessage::system(&self.system_prompt), user_message];
        let (tools, tool_choice) = if self.tools.is_empty() {
            (None, None)
        } else {
            (
                Some(self.build_chat_tools()),
                Some(ToolChoice::String("auto".to_string())),
            )
        };
        let mut accumulated_usage: Option<TokenUsage> = None;
        for iteration in 0..self.max_iterations {
            tracing::debug!(iteration, "Callable iteration");
            let request = ChatRequest {
                messages: messages.clone(),
                max_tokens: Some(4096),
                temperature: Some(0.7),
                tools: tools.clone(),
                tool_choice: tool_choice.clone(),
            };
            let response = self.provider.chat(request).await?;
            // Accumulate prompt/completion tokens across all round-trips.
            if let Some(ref u) = response.usage {
                accumulated_usage = Some(match accumulated_usage {
                    None => TokenUsage::new(u.prompt_tokens, u.completion_tokens),
                    Some(a) => TokenUsage::new(
                        a.prompt_tokens + u.prompt_tokens,
                        a.completion_tokens + u.completion_tokens,
                    ),
                });
            }
            let choice = response
                .choices
                .first()
                .ok_or_else(|| anyhow::anyhow!("Empty choices in chat response"))?;
            let msg = &choice.message;
            let native_tool_calls = msg.tool_calls.as_deref().unwrap_or(&[]);
            // No tool calls means the model produced its final answer.
            if native_tool_calls.is_empty() {
                let content = msg.content.clone().unwrap_or_default();
                *self.last_usage.lock().expect("last_usage mutex") = accumulated_usage;
                return Ok(content);
            }
            let calls = self.message_tool_calls_to_internal(native_tool_calls);
            // Record the assistant turn (with its tool calls) before
            // appending the tool results, as the chat protocol requires.
            messages.push(ChatMessage::assistant_with_tool_calls(
                msg.content.clone(),
                native_tool_calls.to_vec(),
            ));
            for call in &calls {
                tracing::debug!(tool = %call.name, "Executing tool");
                if let Some(ref emitter) = self.emitter {
                    emitter.emit(StreamEvent::ToolInputAvailable {
                        tool_call_id: call.id.clone(),
                        tool_name: call.name.clone(),
                        input: call.arguments.clone(),
                    });
                }
                let tool_start = std::time::Instant::now();
                // NOTE(review): a tool error aborts the whole run via `?`
                // rather than being reported back to the model.
                let result = self
                    .execute_tool(&call.name, call.arguments.clone())
                    .await?;
                let tool_duration_ms = tool_start.elapsed().as_millis() as u64;
                if let Some(ref emitter) = self.emitter {
                    emitter.emit(StreamEvent::ToolOutputAvailable {
                        tool_call_id: call.id.clone(),
                        output: serde_json::json!({
                            "result": result.clone(),
                            "duration_ms": tool_duration_ms,
                        }),
                    });
                }
                let result_str = serde_json::to_string(&result)?;
                messages.push(ChatMessage::tool_result(&call.id, &result_str));
            }
        }
        // Preserve usage even on failure so callers can still account cost.
        *self.last_usage.lock().expect("last_usage mutex") = accumulated_usage;
        anyhow::bail!("Max iterations ({}) reached", self.max_iterations)
    }
    fn last_usage(&self) -> Option<crate::kernel::LlmTokenUsage> {
        self.last_usage.lock().expect("last_usage mutex").clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{ChatChoice, ChatResponse, MessageToolCall, MessageToolCallFunction};
    use crate::tool::Tool;
    use async_trait::async_trait;
    /// Provider stub whose capabilities report `supports_tools: false`;
    /// its `chat` always returns a plain "ok" assistant message.
    struct MockProviderNoTools;
    #[async_trait]
    impl ModelProvider for MockProviderNoTools {
        fn name(&self) -> &str {
            "mock-no-tools"
        }
        fn capabilities(&self) -> crate::providers::ModelCapabilities {
            crate::providers::ModelCapabilities {
                supports_tools: false,
                ..Default::default()
            }
        }
        async fn chat(&self, _request: ChatRequest) -> anyhow::Result<ChatResponse> {
            Ok(ChatResponse {
                id: "id".to_string(),
                choices: vec![ChatChoice {
                    index: 0,
                    message: ChatMessage::assistant("ok"),
                    finish_reason: Some("stop".to_string()),
                }],
                usage: None,
            })
        }
    }
    /// Trivial tool that returns the "x" field of its arguments (or null).
    struct EchoTool;
    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "Echoes input"
        }
        async fn execute(&self, args: Value) -> anyhow::Result<Value> {
            Ok(args.get("x").cloned().unwrap_or(Value::Null))
        }
    }
    // Registering tools on a provider without native tool support must fail
    // fast with a descriptive error (checked at the top of `run`).
    #[tokio::test]
    async fn run_errors_when_tools_registered_but_provider_does_not_support_tools() {
        let provider = Arc::new(MockProviderNoTools);
        let callable =
            LlmCallable::with_provider("test", "You are helpful", provider).add_tool(EchoTool);
        let err = callable.run("hello").await.unwrap_err();
        assert!(
            err.to_string().contains("does not support native tools"),
            "expected error about supports_tools, got: {}",
            err
        );
    }
    /// Provider stub that requests an `echo` tool call on the first chat and
    /// returns a final answer on any subsequent call (or once it sees a tool
    /// result in the transcript).
    struct MockProviderWithToolCalls {
        // Counts chat invocations to distinguish first call from follow-ups.
        call_count: std::sync::atomic::AtomicUsize,
    }
    #[async_trait]
    impl ModelProvider for MockProviderWithToolCalls {
        fn name(&self) -> &str {
            "mock-with-tools"
        }
        fn capabilities(&self) -> crate::providers::ModelCapabilities {
            crate::providers::ModelCapabilities {
                supports_tools: true,
                ..Default::default()
            }
        }
        async fn chat(&self, request: ChatRequest) -> anyhow::Result<ChatResponse> {
            let n = self
                .call_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            let has_tool_result = request.messages.iter().any(|m| m.role == "tool");
            // First call with no tool results yet: ask for the echo tool.
            if !has_tool_result && n == 0 {
                return Ok(ChatResponse {
                    id: "id".to_string(),
                    choices: vec![ChatChoice {
                        index: 0,
                        message: ChatMessage::assistant_with_tool_calls(
                            None,
                            vec![MessageToolCall {
                                id: "call-1".to_string(),
                                call_type: "function".to_string(),
                                function: MessageToolCallFunction {
                                    name: "echo".to_string(),
                                    arguments: r#"{"x": "world"}"#.to_string(),
                                },
                            }],
                        ),
                        finish_reason: Some("tool_calls".to_string()),
                    }],
                    usage: None,
                });
            }
            Ok(ChatResponse {
                id: "id".to_string(),
                choices: vec![ChatChoice {
                    index: 0,
                    message: ChatMessage::assistant("Final: world"),
                    finish_reason: Some("stop".to_string()),
                }],
                usage: None,
            })
        }
    }
    // End-to-end happy path: tool call requested, executed, result fed back,
    // final content returned.
    #[tokio::test]
    async fn run_uses_native_tool_calls_and_returns_final_content() {
        let provider = Arc::new(MockProviderWithToolCalls {
            call_count: std::sync::atomic::AtomicUsize::new(0),
        });
        let callable = LlmCallable::with_provider("test", "You are helpful", provider)
            .add_tool(EchoTool)
            .max_iterations(5);
        let out = callable.run("hello").await.unwrap();
        assert_eq!(out, "Final: world");
    }
}