1use super::Callable;
7use crate::kernel::cost::TokenUsage;
8use crate::providers::{
9 ChatMessage, ChatRequest, ChatTool, ChatToolFunction, ContentPart, MessageToolCall,
10 ModelProvider, ToolChoice,
11};
12use crate::routing::{ModelRouter, RoutingDecision, RoutingPolicy};
13use crate::streaming::{EventEmitter, StreamEvent};
14use crate::tool::{DynTool, Tool};
15use async_trait::async_trait;
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::sync::{Arc, Mutex};
19use tokio::sync::mpsc;
20use tokio::time::{interval, Duration};
21
/// A tool invocation requested by the model, normalized from the
/// provider-specific wire format (see `message_tool_calls_to_internal`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    // Provider-assigned call id; echoed back in the tool-result message.
    pub id: String,
    // Name of the registered tool to execute.
    pub name: String,
    // Arguments already parsed from the provider's JSON string.
    pub arguments: Value,
}
29
/// Text-plus-images input encoded as JSON so it can travel through the
/// plain-string `Callable::run` interface and be recognized on the far side
/// by `MultimodalInput::parse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalInput {
    // Serialized under the sentinel key "__multimodal__" so parse() can
    // cheaply detect a multimodal payload versus ordinary text.
    #[serde(rename = "__multimodal__")]
    pub multimodal_marker: bool,
    // The textual part of the prompt.
    pub text: String,
    // Attached images; defaults to empty when the key is absent.
    #[serde(default)]
    pub images: Vec<MultimodalImage>,
}
52
/// A single image attachment, carried as base64 text (standard alphabet,
/// encoded in `MultimodalInput::new`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalImage {
    // Base64-encoded image bytes.
    pub data: String,
    // MIME type of the image — presumably e.g. "image/png"; the provider
    // receives it verbatim via ContentPart::image_base64.
    pub mime_type: String,
}
61
62impl MultimodalInput {
63 pub fn new(text: impl Into<String>, images: Vec<(Vec<u8>, String)>) -> Self {
65 use base64::Engine;
66 Self {
67 multimodal_marker: true,
68 text: text.into(),
69 images: images
70 .into_iter()
71 .map(|(data, mime_type)| MultimodalImage {
72 data: base64::engine::general_purpose::STANDARD.encode(&data),
73 mime_type,
74 })
75 .collect(),
76 }
77 }
78
79 pub fn to_json(&self) -> String {
81 serde_json::to_string(self).unwrap_or_else(|_| self.text.clone())
82 }
83
84 pub fn parse(input: &str) -> Option<Self> {
87 if !input.trim_start().starts_with(r#"{"__multimodal__":"#) {
88 return None;
89 }
90 serde_json::from_str(input).ok()
91 }
92}
93
/// OpenAI-style tool schema wrapper: `{"type": "function", "function": {...}}`.
#[derive(Debug, Clone, Serialize)]
pub struct ToolSchema {
    // Always "function" when built via `from_tool`.
    #[serde(rename = "type")]
    pub tool_type: String,
    // The function's name, description, and JSON-schema parameters.
    pub function: FunctionSchema,
}
101
/// Function descriptor inside a `ToolSchema`, mirroring a `Tool`'s metadata.
#[derive(Debug, Clone, Serialize)]
pub struct FunctionSchema {
    // Tool name as reported by `Tool::name`.
    pub name: String,
    // Human-readable description from `Tool::description`.
    pub description: String,
    // JSON schema of the tool's arguments from `Tool::parameters_schema`.
    pub parameters: Value,
}
108
109impl ToolSchema {
110 pub fn from_tool(tool: &dyn Tool) -> Self {
111 Self {
112 tool_type: "function".to_string(),
113 function: FunctionSchema {
114 name: tool.name().to_string(),
115 description: tool.description().to_string(),
116 parameters: tool.parameters_schema(),
117 },
118 }
119 }
120}
121
/// A `Callable` backed by an LLM provider, supporting native tool use,
/// model routing, streaming event emission, and token-usage tracking.
pub struct LlmCallable {
    // Name reported via `Callable::name`.
    name: String,
    // Optional description reported via `Callable::description`.
    description: Option<String>,
    // System prompt prepended to every conversation in `run`.
    system_prompt: String,
    // Provider that performs the actual chat completions.
    provider: Arc<dyn ModelProvider>,
    // Explicitly requested model, if any; otherwise routing decides.
    requested_model: Option<String>,
    // Policy consulted by `ModelRouter::resolve`.
    routing_policy: RoutingPolicy,
    // Tools the model is allowed to call natively.
    tools: Vec<DynTool>,
    // Cap on model/tool round-trips per `run` invocation (default 10).
    max_iterations: usize,
    // Optional emitter for streaming tool input/output events.
    emitter: Option<Arc<EventEmitter>>,
    // Usage accumulated by the most recent `run`; read by `last_usage()`.
    last_usage: Mutex<Option<TokenUsage>>,
}
142
143impl LlmCallable {
144 pub fn with_provider(
146 name: impl Into<String>,
147 system_prompt: impl Into<String>,
148 provider: Arc<dyn ModelProvider>,
149 ) -> Self {
150 Self {
151 name: name.into(),
152 description: None,
153 system_prompt: system_prompt.into(),
154 provider,
155 requested_model: None,
156 routing_policy: RoutingPolicy::default(),
157 tools: Vec::new(),
158 max_iterations: 10,
159 emitter: None,
160 last_usage: Mutex::new(None),
161 }
162 }
163
164 pub fn with_emitter(mut self, emitter: Arc<EventEmitter>) -> Self {
166 self.emitter = Some(emitter);
167 self
168 }
169
170 pub fn with_model(mut self, model: impl Into<String>) -> Self {
172 self.requested_model = Some(model.into());
173 self
174 }
175
176 pub fn with_routing_policy(mut self, policy: RoutingPolicy) -> Self {
178 self.routing_policy = policy;
179 self
180 }
181
182 pub fn with_description(mut self, description: impl Into<String>) -> Self {
183 self.description = Some(description.into());
184 self
185 }
186
187 pub fn add_tool(mut self, tool: impl Tool + 'static) -> Self {
189 self.tools.push(Arc::new(tool));
190 self
191 }
192
193 pub fn add_tools(mut self, tools: Vec<DynTool>) -> Self {
195 self.tools.extend(tools);
196 self
197 }
198
199 pub fn max_iterations(mut self, max: usize) -> Self {
201 self.max_iterations = max;
202 self
203 }
204
205 async fn execute_tool(&self, name: &str, args: Value) -> anyhow::Result<Value> {
207 let tool = self
208 .tools
209 .iter()
210 .find(|t| t.name() == name)
211 .ok_or_else(|| anyhow::anyhow!("Tool '{}' not found", name))?;
212
213 tool.execute(args).await
214 }
215
216 fn build_chat_tools(&self) -> Vec<ChatTool> {
218 self.tools
219 .iter()
220 .map(|t| ChatTool {
221 tool_type: "function".to_string(),
222 function: ChatToolFunction {
223 name: t.name().to_string(),
224 description: t.description().to_string(),
225 parameters: t.parameters_schema(),
226 },
227 })
228 .collect()
229 }
230
231 fn message_tool_calls_to_internal(&self, tool_calls: &[MessageToolCall]) -> Vec<ToolCall> {
233 tool_calls
234 .iter()
235 .map(|tc| {
236 let arguments = serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null);
237 ToolCall {
238 id: tc.id.clone(),
239 name: tc.function.name.clone(),
240 arguments,
241 }
242 })
243 .collect()
244 }
245
246 fn resolve_routing(&self) -> RoutingDecision {
247 ModelRouter::resolve(
248 self.requested_model.as_deref(),
249 self.provider.as_ref(),
250 &self.routing_policy,
251 )
252 }
253}
254
255#[async_trait]
256impl Callable for LlmCallable {
257 fn name(&self) -> &str {
258 &self.name
259 }
260
261 fn description(&self) -> Option<&str> {
262 self.description.as_deref()
263 }
264
265 async fn run_streaming(
266 &self,
267 input: &str,
268 event_tx: mpsc::Sender<StreamEvent>,
269 ) -> anyhow::Result<String> {
270 let emitter = self.emitter.clone();
271 let tx = event_tx.clone();
272 let poll_handle = if emitter.is_some() {
273 Some(tokio::spawn(async move {
274 let emitter = match &emitter {
275 Some(e) => e,
276 None => return,
277 };
278 let mut interval = interval(Duration::from_millis(50));
279 loop {
280 interval.tick().await;
281 let events = emitter.drain();
282 for ev in events {
283 if tx.send(ev).await.is_err() {
284 return;
285 }
286 }
287 }
288 }))
289 } else {
290 None
291 };
292
293 let result = self.run(input).await;
294
295 if let Some(ref e) = self.emitter {
296 for ev in e.drain() {
297 let _ = event_tx.send(ev).await;
298 }
299 }
300 drop(event_tx);
301 if let Some(h) = poll_handle {
302 let _ = h.await;
303 }
304
305 result
306 }
307
308 async fn run(&self, input: &str) -> anyhow::Result<String> {
309 *self.last_usage.lock().expect("last_usage mutex") = None;
310
311 if !self.tools.is_empty() && !self.provider.capabilities().supports_tools {
312 anyhow::bail!(
313 "Callable has {} tool(s) but provider does not support native tools (supports_tools is false)",
314 self.tools.len()
315 );
316 }
317
318 let routing = self.resolve_routing();
319 tracing::info!(
320 callable = %self.name,
321 logical_model = %routing.logical_model,
322 concrete_model = %routing.concrete_model,
323 profile = ?routing.profile,
324 confidence = routing.confidence,
325 used_default_router = routing.used_default_router,
326 rationale = %routing.rationale,
327 "Model routing decision resolved"
328 );
329
330 let user_message = if let Some(multimodal) = MultimodalInput::parse(input) {
332 tracing::debug!(
333 image_count = multimodal.images.len(),
334 text_len = multimodal.text.len(),
335 "Processing multimodal input with images"
336 );
337
338 if !self.provider.capabilities().supports_vision {
339 tracing::warn!(
340 "Provider does not support vision, falling back to text-only. \
341 Images will be ignored. Consider using a vision-capable model."
342 );
343 ChatMessage::user(&multimodal.text)
344 } else {
345 use base64::Engine;
347 let mut parts = vec![ContentPart::text(&multimodal.text)];
348 for img in &multimodal.images {
349 if let Ok(data) = base64::engine::general_purpose::STANDARD.decode(&img.data) {
351 parts.push(ContentPart::image_base64(
352 base64::engine::general_purpose::STANDARD.encode(&data),
353 &img.mime_type,
354 ));
355 } else {
356 tracing::warn!(mime_type = %img.mime_type, "Failed to decode image base64 data");
357 }
358 }
359
360 ChatMessage {
361 role: "user".to_string(),
362 content: None,
363 multimodal_content: Some(parts),
364 tool_calls: None,
365 tool_call_id: None,
366 }
367 }
368 } else {
369 ChatMessage::user(input)
370 };
371
372 let mut messages = vec![ChatMessage::system(&self.system_prompt), user_message];
373
374 let (tools, tool_choice) = if self.tools.is_empty() {
375 (None, None)
376 } else {
377 (
378 Some(self.build_chat_tools()),
379 Some(ToolChoice::String("auto".to_string())),
380 )
381 };
382
383 let mut accumulated_usage: Option<TokenUsage> = None;
384
385 for iteration in 0..self.max_iterations {
386 tracing::debug!(iteration, "Callable iteration");
387
388 let request = ChatRequest {
389 messages: messages.clone(),
390 max_tokens: Some(4096),
391 temperature: Some(0.7),
392 tools: tools.clone(),
393 tool_choice: tool_choice.clone(),
394 };
395
396 let response = self.provider.chat(request).await?;
397
398 if let Some(ref u) = response.usage {
399 accumulated_usage = Some(match accumulated_usage {
400 None => TokenUsage::new(u.prompt_tokens, u.completion_tokens),
401 Some(a) => TokenUsage::new(
402 a.prompt_tokens + u.prompt_tokens,
403 a.completion_tokens + u.completion_tokens,
404 ),
405 });
406 }
407
408 let choice = response
409 .choices
410 .first()
411 .ok_or_else(|| anyhow::anyhow!("Empty choices in chat response"))?;
412 let msg = &choice.message;
413
414 let native_tool_calls = msg.tool_calls.as_deref().unwrap_or(&[]);
415 if native_tool_calls.is_empty() {
416 let content = msg.content.clone().unwrap_or_default();
417 *self.last_usage.lock().expect("last_usage mutex") = accumulated_usage;
418 return Ok(content);
419 }
420
421 let calls = self.message_tool_calls_to_internal(native_tool_calls);
422 messages.push(ChatMessage::assistant_with_tool_calls(
423 msg.content.clone(),
424 native_tool_calls.to_vec(),
425 ));
426
427 for call in &calls {
428 tracing::debug!(tool = %call.name, "Executing tool");
429
430 if let Some(ref emitter) = self.emitter {
431 emitter.emit(StreamEvent::ToolInputAvailable {
432 tool_call_id: call.id.clone(),
433 tool_name: call.name.clone(),
434 input: call.arguments.clone(),
435 });
436 }
437
438 let tool_start = std::time::Instant::now();
439 let result = self
440 .execute_tool(&call.name, call.arguments.clone())
441 .await?;
442 let tool_duration_ms = tool_start.elapsed().as_millis() as u64;
443
444 if let Some(ref emitter) = self.emitter {
445 emitter.emit(StreamEvent::ToolOutputAvailable {
446 tool_call_id: call.id.clone(),
447 output: serde_json::json!({
448 "result": result.clone(),
449 "duration_ms": tool_duration_ms,
450 }),
451 });
452 }
453
454 let result_str = serde_json::to_string(&result)?;
455 messages.push(ChatMessage::tool_result(&call.id, &result_str));
456 }
457 }
458
459 *self.last_usage.lock().expect("last_usage mutex") = accumulated_usage;
460 anyhow::bail!("Max iterations ({}) reached", self.max_iterations)
461 }
462
463 fn last_usage(&self) -> Option<crate::kernel::LlmTokenUsage> {
464 self.last_usage.lock().expect("last_usage mutex").clone()
465 }
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{ChatChoice, ChatResponse, MessageToolCall, MessageToolCallFunction};
    use crate::tool::Tool;
    use async_trait::async_trait;

    // Provider whose capabilities report no native tool support; used to
    // verify that `run` fails fast instead of sending tools it cannot use.
    struct MockProviderNoTools;
    #[async_trait]
    impl ModelProvider for MockProviderNoTools {
        fn name(&self) -> &str {
            "mock-no-tools"
        }
        fn capabilities(&self) -> crate::providers::ModelCapabilities {
            crate::providers::ModelCapabilities {
                supports_tools: false,
                ..Default::default()
            }
        }
        // Always returns a plain "ok" completion; never reached by the
        // tool-rejection test but required by the trait.
        async fn chat(&self, _request: ChatRequest) -> anyhow::Result<ChatResponse> {
            Ok(ChatResponse {
                id: "id".to_string(),
                choices: vec![ChatChoice {
                    index: 0,
                    message: ChatMessage::assistant("ok"),
                    finish_reason: Some("stop".to_string()),
                }],
                usage: None,
            })
        }
    }

    // Minimal tool that echoes back its "x" argument (Null when absent).
    // Relies on the Tool trait's default parameters_schema — TODO confirm.
    struct EchoTool;
    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "Echoes input"
        }
        async fn execute(&self, args: Value) -> anyhow::Result<Value> {
            Ok(args.get("x").cloned().unwrap_or(Value::Null))
        }
    }

    #[tokio::test]
    async fn run_errors_when_tools_registered_but_provider_does_not_support_tools() {
        let provider = Arc::new(MockProviderNoTools);
        let callable =
            LlmCallable::with_provider("test", "You are helpful", provider).add_tool(EchoTool);

        let err = callable.run("hello").await.unwrap_err();
        assert!(
            err.to_string().contains("does not support native tools"),
            "expected error about supports_tools, got: {}",
            err
        );
    }

    // Provider that requests an "echo" tool call on the first chat turn and
    // returns a final completion on subsequent turns; exercises the full
    // tool-call round-trip in `run`.
    struct MockProviderWithToolCalls {
        // Counts chat invocations so only the first turn emits a tool call.
        call_count: std::sync::atomic::AtomicUsize,
    }
    #[async_trait]
    impl ModelProvider for MockProviderWithToolCalls {
        fn name(&self) -> &str {
            "mock-with-tools"
        }
        fn capabilities(&self) -> crate::providers::ModelCapabilities {
            crate::providers::ModelCapabilities {
                supports_tools: true,
                ..Default::default()
            }
        }
        async fn chat(&self, request: ChatRequest) -> anyhow::Result<ChatResponse> {
            let n = self
                .call_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            // First turn with no prior tool result: ask for the echo tool.
            let has_tool_result = request.messages.iter().any(|m| m.role == "tool");
            if !has_tool_result && n == 0 {
                return Ok(ChatResponse {
                    id: "id".to_string(),
                    choices: vec![ChatChoice {
                        index: 0,
                        message: ChatMessage::assistant_with_tool_calls(
                            None,
                            vec![MessageToolCall {
                                id: "call-1".to_string(),
                                call_type: "function".to_string(),
                                function: MessageToolCallFunction {
                                    name: "echo".to_string(),
                                    arguments: r#"{"x": "world"}"#.to_string(),
                                },
                            }],
                        ),
                        finish_reason: Some("tool_calls".to_string()),
                    }],
                    usage: None,
                });
            }
            // Any later turn: final plain-content answer.
            Ok(ChatResponse {
                id: "id".to_string(),
                choices: vec![ChatChoice {
                    index: 0,
                    message: ChatMessage::assistant("Final: world"),
                    finish_reason: Some("stop".to_string()),
                }],
                usage: None,
            })
        }
    }

    #[tokio::test]
    async fn run_uses_native_tool_calls_and_returns_final_content() {
        let provider = Arc::new(MockProviderWithToolCalls {
            call_count: std::sync::atomic::AtomicUsize::new(0),
        });
        let callable = LlmCallable::with_provider("test", "You are helpful", provider)
            .add_tool(EchoTool)
            .max_iterations(5);

        let out = callable.run("hello").await.unwrap();
        assert_eq!(out, "Final: world");
    }
}