// batuta/agent/driver/mod.rs
1//! LLM driver abstraction.
2//!
3//! Defines the `LlmDriver` trait — the interface between the agent
4//! loop and LLM inference backends. The default implementation is
5//! `RealizarDriver` (sovereign, local GGUF/APR inference).
6
7#[cfg(feature = "native")]
8pub mod apr_serve;
9pub mod chat_template;
10pub mod mock;
11#[cfg(feature = "inference")]
12pub mod realizar;
13#[cfg(feature = "native")]
14pub mod remote;
15#[cfg(feature = "native")]
16pub mod router;
17pub mod validate;
18
19use async_trait::async_trait;
20use serde::{Deserialize, Serialize};
21
22use crate::agent::phase::LoopPhase;
23use crate::agent::result::{AgentError, StopReason, TokenUsage};
24use crate::serve::backends::PrivacyTier;
25
/// Message in the agent conversation.
///
/// Serde-serializable (see derive) so the message history can be
/// encoded to/from JSON. Plain-text roles carry a `String`; tool
/// interactions carry structured payloads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Message {
    /// System prompt (injected once at start).
    System(String),
    /// User query or follow-up.
    User(String),
    /// Assistant text response.
    Assistant(String),
    /// Assistant tool use request.
    AssistantToolUse(ToolCall),
    /// Tool execution result (paired to a call via `ToolResultMsg::tool_use_id`).
    ToolResult(ToolResultMsg),
}
40
41impl Message {
42    /// Convert to `ChatMessage` for context window truncation.
43    ///
44    /// Tool-use and tool-result messages are serialized as
45    /// assistant/user text so the token estimator can size them.
46    pub fn to_chat_message(&self) -> crate::serve::templates::ChatMessage {
47        use crate::serve::templates::ChatMessage;
48        match self {
49            Self::System(s) => ChatMessage::system(s),
50            Self::User(s) => ChatMessage::user(s),
51            Self::Assistant(s) => ChatMessage::assistant(s),
52            Self::AssistantToolUse(call) => {
53                ChatMessage::assistant(format!("[tool_use: {} {}]", call.name, call.input))
54            }
55            Self::ToolResult(result) => {
56                ChatMessage::user(format!("[tool_result: {}]", result.content))
57            }
58        }
59    }
60}
61
62/// A tool call request from the model.
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct ToolCall {
65    /// Unique ID for this tool call.
66    pub id: String,
67    /// Tool name.
68    pub name: String,
69    /// Tool input as JSON.
70    pub input: serde_json::Value,
71}
72
73/// A tool result message.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ToolResultMsg {
76    /// ID of the tool call this is responding to.
77    pub tool_use_id: String,
78    /// Result content.
79    pub content: String,
80    /// Whether the tool call errored.
81    pub is_error: bool,
82}
83
84/// Tool definition for the LLM (JSON Schema).
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct ToolDefinition {
87    /// Tool name (must match Tool trait name).
88    pub name: String,
89    /// Human-readable description.
90    pub description: String,
91    /// JSON Schema for the tool's input.
92    pub input_schema: serde_json::Value,
93}
94
/// Request to the LLM driver.
///
/// Built by the agent loop and handed to [`LlmDriver::complete`] /
/// [`LlmDriver::stream`]. Not serialized here; backends translate it
/// into their own wire format.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Model identifier.
    pub model: String,
    /// Conversation messages.
    pub messages: Vec<Message>,
    /// Available tools.
    pub tools: Vec<ToolDefinition>,
    /// Maximum tokens to generate.
    pub max_tokens: u32,
    /// Sampling temperature.
    pub temperature: f32,
    /// System prompt (separate from messages).
    pub system: Option<String>,
}
111
/// Response from the LLM driver.
///
/// Returned by [`LlmDriver::complete`] and [`LlmDriver::stream`]
/// once generation finishes.
#[derive(Debug, Clone)]
pub struct CompletionResponse {
    /// Generated text (may be empty if only tool calls).
    pub text: String,
    /// Why the model stopped generating.
    pub stop_reason: StopReason,
    /// Tool calls requested by the model (empty when the turn is pure text).
    pub tool_calls: Vec<ToolCall>,
    /// Token usage for this completion.
    pub usage: TokenUsage,
}
124
/// Streaming event from the LLM driver.
///
/// Sent over the channel passed to [`LlmDriver::stream`] as generation
/// progresses; `ContentComplete` is the terminal event for a completion.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// Agent loop phase changed.
    PhaseChange {
        /// New phase.
        phase: LoopPhase,
    },
    /// Incremental text from the model.
    TextDelta {
        /// Text fragment (a delta, not the accumulated output).
        text: String,
    },
    /// Tool call started.
    ToolUseStart {
        /// Tool call ID.
        id: String,
        /// Tool name.
        name: String,
    },
    /// Tool call completed.
    ToolUseEnd {
        /// Tool call ID (matches the preceding `ToolUseStart`).
        id: String,
        /// Tool name.
        name: String,
        /// Tool result.
        result: String,
    },
    /// Completion finished.
    ContentComplete {
        /// Stop reason.
        stop_reason: StopReason,
        /// Usage for this completion.
        usage: TokenUsage,
    },
}
162
163/// Abstraction over LLM inference backends.
164///
165/// Default implementation: `RealizarDriver` (sovereign, local).
166#[async_trait]
167pub trait LlmDriver: Send + Sync {
168    /// Non-streaming completion.
169    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError>;
170
171    /// Streaming completion with channel-based events.
172    ///
173    /// Default wraps `complete()` for drivers that don't support
174    /// native streaming. Override for token-by-token output.
175    async fn stream(
176        &self,
177        request: CompletionRequest,
178        tx: tokio::sync::mpsc::Sender<StreamEvent>,
179    ) -> Result<CompletionResponse, AgentError> {
180        let response = self.complete(request).await?;
181        let _ = tx.send(StreamEvent::TextDelta { text: response.text.clone() }).await;
182        let _ = tx
183            .send(StreamEvent::ContentComplete {
184                stop_reason: response.stop_reason.clone(),
185                usage: response.usage.clone(),
186            })
187            .await;
188        Ok(response)
189    }
190
191    /// Maximum context window in tokens.
192    fn context_window(&self) -> usize;
193
194    /// Privacy tier this driver operates at.
195    fn privacy_tier(&self) -> PrivacyTier;
196
197    /// Estimate cost in USD for a single completion's token usage.
198    ///
199    /// Default: 0.0 (sovereign/local inference is free).
200    /// Remote drivers override with their pricing model.
201    fn estimate_cost(&self, _usage: &TokenUsage) -> f64 {
202        0.0
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Text-only `Message` variants survive a JSON round trip intact.
    #[test]
    fn test_message_serialization() {
        let msgs = vec![
            Message::System("sys".into()),
            Message::User("hello".into()),
            Message::Assistant("hi".into()),
        ];
        for msg in &msgs {
            let json = serde_json::to_string(msg).expect("serialize failed");
            let back: Message = serde_json::from_str(&json).expect("deserialize failed");
            // Or-patterns merge the three text-carrying variants; a
            // variant flip during the round trip is a hard failure.
            match (msg, &back) {
                (Message::System(a), Message::System(b))
                | (Message::User(a), Message::User(b))
                | (Message::Assistant(a), Message::Assistant(b)) => assert_eq!(a, b),
                _ => panic!("mismatch"),
            }
        }
    }

    /// `ToolCall` survives a JSON round trip.
    #[test]
    fn test_tool_call_serialization() {
        let original = ToolCall {
            id: "1".into(),
            name: "rag".into(),
            input: serde_json::json!({"query": "test"}),
        };
        let encoded = serde_json::to_string(&original).expect("serialize failed");
        let decoded: ToolCall = serde_json::from_str(&encoded).expect("deserialize failed");
        assert_eq!(decoded.name, "rag");
    }

    /// `ToolDefinition` serializes with its schema inline.
    #[test]
    fn test_tool_definition_serialization() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "action": {"type": "string"}
            }
        });
        let def = ToolDefinition {
            name: "memory".into(),
            description: "Read/write memory".into(),
            input_schema: schema,
        };
        let json = serde_json::to_string(&def).expect("serialize failed");
        assert!(json.contains("memory"));
    }

    /// The default `stream` impl wraps `complete()` and emits exactly
    /// one TextDelta plus one ContentComplete.
    #[tokio::test]
    async fn test_stream_default_wraps_complete() {
        use crate::agent::driver::mock::MockDriver;
        use tokio::sync::mpsc;

        let driver = MockDriver::single_response("streamed");
        let (tx, mut rx) = mpsc::channel(16);

        let request = CompletionRequest {
            model: String::new(),
            messages: vec![Message::User("hi".into())],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: None,
        };

        let response = driver.stream(request, tx).await.expect("stream failed");
        assert_eq!(response.text, "streamed");

        // Drain the channel first, then inspect the collected events.
        let mut events = Vec::new();
        while let Ok(event) = rx.try_recv() {
            events.push(event);
        }

        let got_text = events.iter().any(|e| match e {
            StreamEvent::TextDelta { text } => {
                assert_eq!(text, "streamed");
                true
            }
            _ => false,
        });
        let got_complete = events
            .iter()
            .any(|e| matches!(e, StreamEvent::ContentComplete { .. }));

        assert!(got_text, "expected TextDelta event");
        assert!(got_complete, "expected ContentComplete event");
    }
}