1#[cfg(feature = "native")]
8pub mod apr_serve;
9pub mod chat_template;
10pub mod mock;
11#[cfg(feature = "inference")]
12pub mod realizar;
13#[cfg(feature = "native")]
14pub mod remote;
15#[cfg(feature = "native")]
16pub mod router;
17pub mod validate;
18
19use async_trait::async_trait;
20use serde::{Deserialize, Serialize};
21
22use crate::agent::phase::LoopPhase;
23use crate::agent::result::{AgentError, StopReason, TokenUsage};
24use crate::serve::backends::PrivacyTier;
25
/// A single entry in the conversation history carried by
/// [`CompletionRequest::messages`] and handed to an [`LlmDriver`].
///
/// The tool variants wrap [`ToolCall`] and [`ToolResultMsg`]. When rendered for
/// chat templates via `to_chat_message`, tool traffic is flattened into plain
/// text (assistant role for tool use, user role for tool results).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Message {
    /// System-role text (rendered with `ChatMessage::system`).
    System(String),
    /// Text authored by the user (rendered with `ChatMessage::user`).
    User(String),
    /// Plain text produced by the assistant (rendered with `ChatMessage::assistant`).
    Assistant(String),
    /// A tool invocation produced by the assistant.
    AssistantToolUse(ToolCall),
    /// The result of running a tool.
    ToolResult(ToolResultMsg),
}
40
41impl Message {
42 pub fn to_chat_message(&self) -> crate::serve::templates::ChatMessage {
47 use crate::serve::templates::ChatMessage;
48 match self {
49 Self::System(s) => ChatMessage::system(s),
50 Self::User(s) => ChatMessage::user(s),
51 Self::Assistant(s) => ChatMessage::assistant(s),
52 Self::AssistantToolUse(call) => {
53 ChatMessage::assistant(format!("[tool_use: {} {}]", call.name, call.input))
54 }
55 Self::ToolResult(result) => {
56 ChatMessage::user(format!("[tool_result: {}]", result.content))
57 }
58 }
59 }
60}
61
/// A request from the model to invoke a tool.
///
/// Returned in [`CompletionResponse::tool_calls`] and wrapped by
/// [`Message::AssistantToolUse`] in the conversation history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call; presumably echoed back as
    /// [`ToolResultMsg::tool_use_id`] — confirm against the drivers.
    pub id: String,
    /// Name of the tool to invoke; presumably matches a [`ToolDefinition::name`].
    pub name: String,
    /// Tool arguments as arbitrary JSON. Presumably expected to conform to the
    /// tool's `input_schema`; nothing in this type validates that.
    pub input: serde_json::Value,
}
72
/// The output of a tool execution, carried by [`Message::ToolResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMsg {
    /// Id of the tool call this result answers; presumably a [`ToolCall::id`].
    pub tool_use_id: String,
    /// Tool output as text.
    pub content: String,
    /// `true` when `content` describes a failure rather than a successful result.
    ///
    /// NOTE(review): `Message::to_chat_message` renders only `content`, so this
    /// flag is not visible to chat-template backends.
    pub is_error: bool,
}
83
/// Describes a tool the model may call; supplied via [`CompletionRequest::tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Name the model uses to refer to the tool.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Schema of the tool's input as JSON, e.g.
    /// `{"type": "object", "properties": {...}}` (presumably JSON Schema —
    /// confirm with the drivers that consume it).
    pub input_schema: serde_json::Value,
}
94
/// Input to [`LlmDriver::complete`] / [`LlmDriver::stream`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Model identifier requested from the backend (the in-module test passes
    /// an empty string).
    pub model: String,
    /// Conversation history to complete.
    pub messages: Vec<Message>,
    /// Tools the model may call; may be empty.
    pub tools: Vec<ToolDefinition>,
    /// Upper bound on generated tokens.
    pub max_tokens: u32,
    /// Sampling temperature.
    pub temperature: f32,
    /// Optional system prompt, held separately from `messages`.
    ///
    /// NOTE(review): `messages` may also contain `Message::System` entries; how
    /// a driver combines the two is driver-specific — confirm per driver.
    pub system: Option<String>,
}
111
/// Output of [`LlmDriver::complete`] / [`LlmDriver::stream`].
#[derive(Debug, Clone)]
pub struct CompletionResponse {
    /// Text generated by the model.
    pub text: String,
    /// Why generation stopped.
    pub stop_reason: StopReason,
    /// Tool invocations requested by the model; empty when there are none.
    pub tool_calls: Vec<ToolCall>,
    /// Token accounting for this request.
    pub usage: TokenUsage,
}
124
/// Progress events delivered on the channel passed to [`LlmDriver::stream`].
///
/// The default `stream` implementation emits only `TextDelta` followed by
/// `ContentComplete`. The remaining variants are available to drivers and
/// callers (presumably for agent-loop phase and tool-execution reporting).
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// The agent loop moved to a new phase (presumably emitted by the loop,
    /// not by drivers — confirm).
    PhaseChange {
        /// The phase now active.
        phase: LoopPhase,
    },
    /// A chunk of generated text.
    TextDelta {
        /// Text to append to what has been received so far.
        text: String,
    },
    /// A tool call started.
    ToolUseStart {
        /// Identifier of the tool call.
        id: String,
        /// Name of the tool being invoked.
        name: String,
    },
    /// A tool call finished.
    ToolUseEnd {
        /// Identifier of the tool call; presumably matches the earlier `ToolUseStart`.
        id: String,
        /// Name of the tool that ran.
        name: String,
        /// Tool output as text.
        result: String,
    },
    /// Generation finished.
    ContentComplete {
        /// Why generation stopped.
        stop_reason: StopReason,
        /// Token accounting for the request.
        usage: TokenUsage,
    },
}
162
163#[async_trait]
167pub trait LlmDriver: Send + Sync {
168 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError>;
170
171 async fn stream(
176 &self,
177 request: CompletionRequest,
178 tx: tokio::sync::mpsc::Sender<StreamEvent>,
179 ) -> Result<CompletionResponse, AgentError> {
180 let response = self.complete(request).await?;
181 let _ = tx.send(StreamEvent::TextDelta { text: response.text.clone() }).await;
182 let _ = tx
183 .send(StreamEvent::ContentComplete {
184 stop_reason: response.stop_reason.clone(),
185 usage: response.usage.clone(),
186 })
187 .await;
188 Ok(response)
189 }
190
191 fn context_window(&self) -> usize;
193
194 fn privacy_tier(&self) -> PrivacyTier;
196
197 fn estimate_cost(&self, _usage: &TokenUsage) -> f64 {
202 0.0
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal request containing a single user message.
    fn user_request(text: &str) -> CompletionRequest {
        CompletionRequest {
            model: String::new(),
            messages: vec![Message::User(text.into())],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: None,
        }
    }

    /// Every `Message` variant must survive a JSON round-trip unchanged,
    /// including the tool variants that carry structured payloads.
    #[test]
    fn test_message_serialization() {
        let msgs = vec![
            Message::System("sys".into()),
            Message::User("hello".into()),
            Message::Assistant("hi".into()),
            Message::AssistantToolUse(ToolCall {
                id: "call-1".into(),
                name: "rag".into(),
                input: serde_json::json!({"query": "test"}),
            }),
            Message::ToolResult(ToolResultMsg {
                tool_use_id: "call-1".into(),
                content: "result".into(),
                is_error: true,
            }),
        ];
        for msg in &msgs {
            let json = serde_json::to_string(msg).expect("serialize failed");
            let back: Message = serde_json::from_str(&json).expect("deserialize failed");
            match (msg, &back) {
                (Message::System(a), Message::System(b))
                | (Message::User(a), Message::User(b))
                | (Message::Assistant(a), Message::Assistant(b)) => assert_eq!(a, b),
                (Message::AssistantToolUse(a), Message::AssistantToolUse(b)) => {
                    assert_eq!(a.id, b.id);
                    assert_eq!(a.name, b.name);
                    assert_eq!(a.input, b.input);
                }
                (Message::ToolResult(a), Message::ToolResult(b)) => {
                    assert_eq!(a.tool_use_id, b.tool_use_id);
                    assert_eq!(a.content, b.content);
                    assert_eq!(a.is_error, b.is_error);
                }
                _ => panic!("variant mismatch after round-trip: {:?} -> {:?}", msg, back),
            }
        }
    }

    #[test]
    fn test_tool_call_serialization() {
        let call = ToolCall {
            id: "1".into(),
            name: "rag".into(),
            input: serde_json::json!({"query": "test"}),
        };
        let json = serde_json::to_string(&call).expect("serialize failed");
        let back: ToolCall = serde_json::from_str(&json).expect("deserialize failed");
        assert_eq!(back.id, "1");
        assert_eq!(back.name, "rag");
        assert_eq!(back.input, serde_json::json!({"query": "test"}));
    }

    #[test]
    fn test_tool_definition_serialization() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "action": {"type": "string"}
            }
        });
        let def = ToolDefinition {
            name: "memory".into(),
            description: "Read/write memory".into(),
            input_schema: schema.clone(),
        };
        let json = serde_json::to_string(&def).expect("serialize failed");
        assert!(json.contains("memory"));

        let back: ToolDefinition = serde_json::from_str(&json).expect("deserialize failed");
        assert_eq!(back.name, "memory");
        assert_eq!(back.description, "Read/write memory");
        assert_eq!(back.input_schema, schema);
    }

    /// The default `stream` must emit the full text as a `TextDelta` and then
    /// a terminal `ContentComplete`, in that order.
    #[tokio::test]
    async fn test_stream_default_wraps_complete() {
        use crate::agent::driver::mock::MockDriver;
        use tokio::sync::mpsc;

        let driver = MockDriver::single_response("streamed");
        let (tx, mut rx) = mpsc::channel(16);

        let response = driver
            .stream(user_request("hi"), tx)
            .await
            .expect("stream failed");
        assert_eq!(response.text, "streamed");

        let mut events = Vec::new();
        while let Ok(event) = rx.try_recv() {
            events.push(event);
        }

        let mut text_at = None;
        let mut complete_at = None;
        for (i, event) in events.iter().enumerate() {
            match event {
                StreamEvent::TextDelta { text } => {
                    assert_eq!(text, "streamed");
                    text_at = text_at.or(Some(i));
                }
                StreamEvent::ContentComplete { .. } => complete_at = Some(i),
                // Other events are tolerated; only the ordering above is pinned.
                _ => {}
            }
        }
        let text_at = text_at.expect("expected TextDelta event");
        let complete_at = complete_at.expect("expected ContentComplete event");
        assert!(
            text_at < complete_at,
            "ContentComplete must follow the TextDelta: {:?}",
            events
        );
    }

    /// Event delivery is best-effort: a dropped receiver must not turn a
    /// successful completion into an error.
    #[tokio::test]
    async fn test_stream_default_tolerates_dropped_receiver() {
        use crate::agent::driver::mock::MockDriver;
        use tokio::sync::mpsc;

        let driver = MockDriver::single_response("streamed");
        let (tx, rx) = mpsc::channel(16);
        drop(rx);

        let response = driver
            .stream(user_request("hi"), tx)
            .await
            .expect("stream must succeed even when nobody is listening");
        assert_eq!(response.text, "streamed");
    }
}