// enact_core/callable/llm.rs

1//! LLM Callable - LLM-powered execution with tool loop
2//!
3//! This is an "agentic" callable - it runs an LLM with tools in a loop
4//! until the LLM produces a final response.
5
6use super::Callable;
7use crate::kernel::cost::TokenUsage;
8use crate::providers::{
9    ChatMessage, ChatRequest, ChatTool, ChatToolFunction, ContentPart, MessageToolCall,
10    ModelProvider, ToolChoice,
11};
12use crate::routing::{ModelRouter, RoutingDecision, RoutingPolicy};
13use crate::streaming::{EventEmitter, StreamEvent};
14use crate::tool::{DynTool, Tool};
15use async_trait::async_trait;
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::sync::{Arc, Mutex};
19use tokio::sync::mpsc;
20use tokio::time::{interval, Duration};
21
/// Tool call from the LLM
///
/// Internal representation of a single tool invocation requested by the
/// model. `arguments` has already been parsed from the provider's raw
/// JSON-string form (see `message_tool_calls_to_internal`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call id; echoed back in the matching tool-result message
    pub id: String,
    /// Name of the tool to invoke (must match a registered `Tool::name`)
    pub name: String,
    /// Parsed JSON arguments for the tool (`Value::Null` if parsing failed)
    pub arguments: Value,
}
29
/// Multimodal input format for encoding images in string input
///
/// When sending images to an LLM callable, encode the input as:
/// ```json
/// {
///   "__multimodal__": true,
///   "text": "Describe this image",
///   "images": [
///     {"data": "<base64>", "mime_type": "image/jpeg"}
///   ]
/// }
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalInput {
    /// Marker distinguishing multimodal payloads from plain text input;
    /// set to `true` by `MultimodalInput::new`
    #[serde(rename = "__multimodal__")]
    pub multimodal_marker: bool,
    /// The text portion of the message
    pub text: String,
    /// Base64-encoded images with mime types (defaults to empty if absent)
    #[serde(default)]
    pub images: Vec<MultimodalImage>,
}
52
/// An image in multimodal input
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalImage {
    /// Base64-encoded image data (standard alphabet; see `MultimodalInput::new`)
    pub data: String,
    /// MIME type (e.g., "image/jpeg", "image/png")
    pub mime_type: String,
}
61
62impl MultimodalInput {
63    /// Create a new multimodal input with text and images
64    pub fn new(text: impl Into<String>, images: Vec<(Vec<u8>, String)>) -> Self {
65        use base64::Engine;
66        Self {
67            multimodal_marker: true,
68            text: text.into(),
69            images: images
70                .into_iter()
71                .map(|(data, mime_type)| MultimodalImage {
72                    data: base64::engine::general_purpose::STANDARD.encode(&data),
73                    mime_type,
74                })
75                .collect(),
76        }
77    }
78
79    /// Encode to JSON string for passing through the runner
80    pub fn to_json(&self) -> String {
81        serde_json::to_string(self).unwrap_or_else(|_| self.text.clone())
82    }
83
84    /// Try to parse multimodal input from a string
85    /// Returns None if not multimodal format
86    pub fn parse(input: &str) -> Option<Self> {
87        if !input.trim_start().starts_with(r#"{"__multimodal__":"#) {
88            return None;
89        }
90        serde_json::from_str(input).ok()
91    }
92}
93
/// Tool schema for LLM
///
/// Serialized in the OpenAI function-calling shape:
/// `{"type": "function", "function": {...}}`.
#[derive(Debug, Clone, Serialize)]
pub struct ToolSchema {
    /// Wire discriminator; set to `"function"` by `from_tool`
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: FunctionSchema,
}
101
/// Function portion of a `ToolSchema` (name, description, JSON-Schema parameters)
#[derive(Debug, Clone, Serialize)]
pub struct FunctionSchema {
    pub name: String,
    pub description: String,
    /// JSON Schema describing the tool's arguments (from `Tool::parameters_schema`)
    pub parameters: Value,
}
108
109impl ToolSchema {
110    pub fn from_tool(tool: &dyn Tool) -> Self {
111        Self {
112            tool_type: "function".to_string(),
113            function: FunctionSchema {
114                name: tool.name().to_string(),
115                description: tool.description().to_string(),
116                parameters: tool.parameters_schema(),
117            },
118        }
119    }
120}
121
/// LLM-powered callable with tool execution
///
/// This is the "agentic" callable - it runs the LLM in a loop,
/// executing tools as requested until a final response is produced.
///
/// Note: The loop is controlled by `max_iterations` to prevent runaway.
pub struct LlmCallable {
    /// Name reported via `Callable::name`
    name: String,
    /// Optional human-readable description (`Callable::description`)
    description: Option<String>,
    /// System prompt sent as the first message of every run
    system_prompt: String,
    /// Provider used for chat completions
    provider: Arc<dyn ModelProvider>,
    /// Logical model id pinned via `with_model`; resolved by `ModelRouter` when unset
    requested_model: Option<String>,
    /// Policy consulted by `ModelRouter::resolve`
    routing_policy: RoutingPolicy,
    /// Tools the model may call, exposed in OpenAI function shape
    tools: Vec<DynTool>,
    /// Upper bound on tool-loop iterations (default 10)
    max_iterations: usize,
    /// Optional event emitter for streaming tool events
    emitter: Option<Arc<EventEmitter>>,
    /// Token usage from the last run (accumulated across tool-loop iterations)
    last_usage: Mutex<Option<TokenUsage>>,
}
142
143impl LlmCallable {
144    /// Create with a custom provider
145    pub fn with_provider(
146        name: impl Into<String>,
147        system_prompt: impl Into<String>,
148        provider: Arc<dyn ModelProvider>,
149    ) -> Self {
150        Self {
151            name: name.into(),
152            description: None,
153            system_prompt: system_prompt.into(),
154            provider,
155            requested_model: None,
156            routing_policy: RoutingPolicy::default(),
157            tools: Vec::new(),
158            max_iterations: 10,
159            emitter: None,
160            last_usage: Mutex::new(None),
161        }
162    }
163
164    /// Set the event emitter for streaming tool events
165    pub fn with_emitter(mut self, emitter: Arc<EventEmitter>) -> Self {
166        self.emitter = Some(emitter);
167        self
168    }
169
170    /// Pin a logical model id for this callable.
171    pub fn with_model(mut self, model: impl Into<String>) -> Self {
172        self.requested_model = Some(model.into());
173        self
174    }
175
176    /// Override routing policy for this callable.
177    pub fn with_routing_policy(mut self, policy: RoutingPolicy) -> Self {
178        self.routing_policy = policy;
179        self
180    }
181
182    pub fn with_description(mut self, description: impl Into<String>) -> Self {
183        self.description = Some(description.into());
184        self
185    }
186
187    /// Add a tool to the callable
188    pub fn add_tool(mut self, tool: impl Tool + 'static) -> Self {
189        self.tools.push(Arc::new(tool));
190        self
191    }
192
193    /// Add multiple tools
194    pub fn add_tools(mut self, tools: Vec<DynTool>) -> Self {
195        self.tools.extend(tools);
196        self
197    }
198
199    /// Set max iterations for tool loop
200    pub fn max_iterations(mut self, max: usize) -> Self {
201        self.max_iterations = max;
202        self
203    }
204
205    /// Execute a tool by name
206    async fn execute_tool(&self, name: &str, args: Value) -> anyhow::Result<Value> {
207        let tool = self
208            .tools
209            .iter()
210            .find(|t| t.name() == name)
211            .ok_or_else(|| anyhow::anyhow!("Tool '{}' not found", name))?;
212
213        tool.execute(args).await
214    }
215
216    /// Build ChatTool list for request (OpenAI shape)
217    fn build_chat_tools(&self) -> Vec<ChatTool> {
218        self.tools
219            .iter()
220            .map(|t| ChatTool {
221                tool_type: "function".to_string(),
222                function: ChatToolFunction {
223                    name: t.name().to_string(),
224                    description: t.description().to_string(),
225                    parameters: t.parameters_schema(),
226                },
227            })
228            .collect()
229    }
230
231    /// Map native message tool_calls to internal ToolCall (parse arguments JSON)
232    fn message_tool_calls_to_internal(&self, tool_calls: &[MessageToolCall]) -> Vec<ToolCall> {
233        tool_calls
234            .iter()
235            .map(|tc| {
236                let arguments = serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null);
237                ToolCall {
238                    id: tc.id.clone(),
239                    name: tc.function.name.clone(),
240                    arguments,
241                }
242            })
243            .collect()
244    }
245
246    fn resolve_routing(&self) -> RoutingDecision {
247        ModelRouter::resolve(
248            self.requested_model.as_deref(),
249            self.provider.as_ref(),
250            &self.routing_policy,
251        )
252    }
253}
254
255#[async_trait]
256impl Callable for LlmCallable {
257    fn name(&self) -> &str {
258        &self.name
259    }
260
261    fn description(&self) -> Option<&str> {
262        self.description.as_deref()
263    }
264
265    async fn run_streaming(
266        &self,
267        input: &str,
268        event_tx: mpsc::Sender<StreamEvent>,
269    ) -> anyhow::Result<String> {
270        let emitter = self.emitter.clone();
271        let tx = event_tx.clone();
272        let poll_handle = if emitter.is_some() {
273            Some(tokio::spawn(async move {
274                let emitter = match &emitter {
275                    Some(e) => e,
276                    None => return,
277                };
278                let mut interval = interval(Duration::from_millis(50));
279                loop {
280                    interval.tick().await;
281                    let events = emitter.drain();
282                    for ev in events {
283                        if tx.send(ev).await.is_err() {
284                            return;
285                        }
286                    }
287                }
288            }))
289        } else {
290            None
291        };
292
293        let result = self.run(input).await;
294
295        if let Some(ref e) = self.emitter {
296            for ev in e.drain() {
297                let _ = event_tx.send(ev).await;
298            }
299        }
300        drop(event_tx);
301        if let Some(h) = poll_handle {
302            let _ = h.await;
303        }
304
305        result
306    }
307
308    async fn run(&self, input: &str) -> anyhow::Result<String> {
309        *self.last_usage.lock().expect("last_usage mutex") = None;
310
311        if !self.tools.is_empty() && !self.provider.capabilities().supports_tools {
312            anyhow::bail!(
313                "Callable has {} tool(s) but provider does not support native tools (supports_tools is false)",
314                self.tools.len()
315            );
316        }
317
318        let routing = self.resolve_routing();
319        tracing::info!(
320            callable = %self.name,
321            logical_model = %routing.logical_model,
322            concrete_model = %routing.concrete_model,
323            profile = ?routing.profile,
324            confidence = routing.confidence,
325            used_default_router = routing.used_default_router,
326            rationale = %routing.rationale,
327            "Model routing decision resolved"
328        );
329
330        // Check if input is multimodal (contains images)
331        let user_message = if let Some(multimodal) = MultimodalInput::parse(input) {
332            tracing::debug!(
333                image_count = multimodal.images.len(),
334                text_len = multimodal.text.len(),
335                "Processing multimodal input with images"
336            );
337
338            if !self.provider.capabilities().supports_vision {
339                tracing::warn!(
340                    "Provider does not support vision, falling back to text-only. \
341                     Images will be ignored. Consider using a vision-capable model."
342                );
343                ChatMessage::user(&multimodal.text)
344            } else {
345                // Build multimodal message with images
346                use base64::Engine;
347                let mut parts = vec![ContentPart::text(&multimodal.text)];
348                for img in &multimodal.images {
349                    // Decode and re-encode to ensure valid base64
350                    if let Ok(data) = base64::engine::general_purpose::STANDARD.decode(&img.data) {
351                        parts.push(ContentPart::image_base64(
352                            base64::engine::general_purpose::STANDARD.encode(&data),
353                            &img.mime_type,
354                        ));
355                    } else {
356                        tracing::warn!(mime_type = %img.mime_type, "Failed to decode image base64 data");
357                    }
358                }
359
360                ChatMessage {
361                    role: "user".to_string(),
362                    content: None,
363                    multimodal_content: Some(parts),
364                    tool_calls: None,
365                    tool_call_id: None,
366                }
367            }
368        } else {
369            ChatMessage::user(input)
370        };
371
372        let mut messages = vec![ChatMessage::system(&self.system_prompt), user_message];
373
374        let (tools, tool_choice) = if self.tools.is_empty() {
375            (None, None)
376        } else {
377            (
378                Some(self.build_chat_tools()),
379                Some(ToolChoice::String("auto".to_string())),
380            )
381        };
382
383        let mut accumulated_usage: Option<TokenUsage> = None;
384
385        for iteration in 0..self.max_iterations {
386            tracing::debug!(iteration, "Callable iteration");
387
388            let request = ChatRequest {
389                messages: messages.clone(),
390                max_tokens: Some(4096),
391                temperature: Some(0.7),
392                tools: tools.clone(),
393                tool_choice: tool_choice.clone(),
394            };
395
396            let response = self.provider.chat(request).await?;
397
398            if let Some(ref u) = response.usage {
399                accumulated_usage = Some(match accumulated_usage {
400                    None => TokenUsage::new(u.prompt_tokens, u.completion_tokens),
401                    Some(a) => TokenUsage::new(
402                        a.prompt_tokens + u.prompt_tokens,
403                        a.completion_tokens + u.completion_tokens,
404                    ),
405                });
406            }
407
408            let choice = response
409                .choices
410                .first()
411                .ok_or_else(|| anyhow::anyhow!("Empty choices in chat response"))?;
412            let msg = &choice.message;
413
414            let native_tool_calls = msg.tool_calls.as_deref().unwrap_or(&[]);
415            if native_tool_calls.is_empty() {
416                let content = msg.content.clone().unwrap_or_default();
417                *self.last_usage.lock().expect("last_usage mutex") = accumulated_usage;
418                return Ok(content);
419            }
420
421            let calls = self.message_tool_calls_to_internal(native_tool_calls);
422            messages.push(ChatMessage::assistant_with_tool_calls(
423                msg.content.clone(),
424                native_tool_calls.to_vec(),
425            ));
426
427            for call in &calls {
428                tracing::debug!(tool = %call.name, "Executing tool");
429
430                if let Some(ref emitter) = self.emitter {
431                    emitter.emit(StreamEvent::ToolInputAvailable {
432                        tool_call_id: call.id.clone(),
433                        tool_name: call.name.clone(),
434                        input: call.arguments.clone(),
435                    });
436                }
437
438                let tool_start = std::time::Instant::now();
439                let result = self
440                    .execute_tool(&call.name, call.arguments.clone())
441                    .await?;
442                let tool_duration_ms = tool_start.elapsed().as_millis() as u64;
443
444                if let Some(ref emitter) = self.emitter {
445                    emitter.emit(StreamEvent::ToolOutputAvailable {
446                        tool_call_id: call.id.clone(),
447                        output: serde_json::json!({
448                            "result": result.clone(),
449                            "duration_ms": tool_duration_ms,
450                        }),
451                    });
452                }
453
454                let result_str = serde_json::to_string(&result)?;
455                messages.push(ChatMessage::tool_result(&call.id, &result_str));
456            }
457        }
458
459        *self.last_usage.lock().expect("last_usage mutex") = accumulated_usage;
460        anyhow::bail!("Max iterations ({}) reached", self.max_iterations)
461    }
462
463    fn last_usage(&self) -> Option<crate::kernel::LlmTokenUsage> {
464        self.last_usage.lock().expect("last_usage mutex").clone()
465    }
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{ChatChoice, ChatResponse, MessageToolCall, MessageToolCallFunction};
    use crate::tool::Tool;
    use async_trait::async_trait;

    /// Provider whose capabilities advertise no native tool support.
    struct MockProviderNoTools;

    #[async_trait]
    impl ModelProvider for MockProviderNoTools {
        fn name(&self) -> &str {
            "mock-no-tools"
        }

        fn capabilities(&self) -> crate::providers::ModelCapabilities {
            crate::providers::ModelCapabilities {
                supports_tools: false,
                ..Default::default()
            }
        }

        async fn chat(&self, _request: ChatRequest) -> anyhow::Result<ChatResponse> {
            let only_choice = ChatChoice {
                index: 0,
                message: ChatMessage::assistant("ok"),
                finish_reason: Some("stop".to_string()),
            };
            Ok(ChatResponse {
                id: "id".to_string(),
                choices: vec![only_choice],
                usage: None,
            })
        }
    }

    /// Trivial tool that returns the `x` field of its arguments (or null).
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes input"
        }

        async fn execute(&self, args: Value) -> anyhow::Result<Value> {
            Ok(args.get("x").cloned().unwrap_or(Value::Null))
        }
    }

    #[tokio::test]
    async fn run_errors_when_tools_registered_but_provider_does_not_support_tools() {
        let callable =
            LlmCallable::with_provider("test", "You are helpful", Arc::new(MockProviderNoTools))
                .add_tool(EchoTool);

        let err = callable.run("hello").await.unwrap_err();
        assert!(
            err.to_string().contains("does not support native tools"),
            "expected error about supports_tools, got: {}",
            err
        );
    }

    /// Mock provider that returns tool_calls on first call, then final content when request includes tool results.
    struct MockProviderWithToolCalls {
        call_count: std::sync::atomic::AtomicUsize,
    }

    #[async_trait]
    impl ModelProvider for MockProviderWithToolCalls {
        fn name(&self) -> &str {
            "mock-with-tools"
        }

        fn capabilities(&self) -> crate::providers::ModelCapabilities {
            crate::providers::ModelCapabilities {
                supports_tools: true,
                ..Default::default()
            }
        }

        async fn chat(&self, request: ChatRequest) -> anyhow::Result<ChatResponse> {
            let previous_calls = self
                .call_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            let saw_tool_result = request.messages.iter().any(|m| m.role == "tool");

            // First call with no tool results yet: ask for the echo tool.
            if previous_calls == 0 && !saw_tool_result {
                let tool_call = MessageToolCall {
                    id: "call-1".to_string(),
                    call_type: "function".to_string(),
                    function: MessageToolCallFunction {
                        name: "echo".to_string(),
                        arguments: r#"{"x": "world"}"#.to_string(),
                    },
                };
                return Ok(ChatResponse {
                    id: "id".to_string(),
                    choices: vec![ChatChoice {
                        index: 0,
                        message: ChatMessage::assistant_with_tool_calls(None, vec![tool_call]),
                        finish_reason: Some("tool_calls".to_string()),
                    }],
                    usage: None,
                });
            }

            // Subsequent call (tool result present): produce the final answer.
            Ok(ChatResponse {
                id: "id".to_string(),
                choices: vec![ChatChoice {
                    index: 0,
                    message: ChatMessage::assistant("Final: world"),
                    finish_reason: Some("stop".to_string()),
                }],
                usage: None,
            })
        }
    }

    #[tokio::test]
    async fn run_uses_native_tool_calls_and_returns_final_content() {
        let mock = MockProviderWithToolCalls {
            call_count: std::sync::atomic::AtomicUsize::new(0),
        };
        let callable = LlmCallable::with_provider("test", "You are helpful", Arc::new(mock))
            .add_tool(EchoTool)
            .max_iterations(5);

        assert_eq!(callable.run("hello").await.unwrap(), "Final: world");
    }
}