#[cfg(feature = "native")]
pub mod apr_serve;
pub mod chat_template;
pub mod mock;
#[cfg(feature = "inference")]
pub mod realizar;
#[cfg(feature = "native")]
pub mod remote;
#[cfg(feature = "native")]
pub mod router;
pub mod validate;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::agent::phase::LoopPhase;
use crate::agent::result::{AgentError, StopReason, TokenUsage};
use crate::serve::backends::PrivacyTier;
/// One entry in the conversation exchanged with an [`LlmDriver`].
///
/// The variant records which role produced the entry.
/// [`Message::to_chat_message`] maps each variant onto the chat-template
/// message type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Message {
/// System-role text.
System(String),
/// User-role text.
User(String),
/// Plain assistant-role text.
Assistant(String),
/// A tool invocation attributed to the assistant.
AssistantToolUse(ToolCall),
/// The outcome of a tool invocation.
ToolResult(ToolResultMsg),
}
impl Message {
pub fn to_chat_message(&self) -> crate::serve::templates::ChatMessage {
use crate::serve::templates::ChatMessage;
match self {
Self::System(s) => ChatMessage::system(s),
Self::User(s) => ChatMessage::user(s),
Self::Assistant(s) => ChatMessage::assistant(s),
Self::AssistantToolUse(call) => {
ChatMessage::assistant(format!("[tool_use: {} {}]", call.name, call.input))
}
Self::ToolResult(result) => {
ChatMessage::user(format!("[tool_result: {}]", result.content))
}
}
}
}
/// A tool invocation, carried by [`Message::AssistantToolUse`] and listed in
/// [`CompletionResponse::tool_calls`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of this call.
/// NOTE(review): presumably the value echoed back in
/// `ToolResultMsg::tool_use_id` — confirm against the callers.
pub id: String,
/// Name of the tool being called.
pub name: String,
/// Arguments for the tool, as an arbitrary JSON value.
pub input: serde_json::Value,
}
/// The outcome of a tool invocation, carried by [`Message::ToolResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMsg {
/// Identifier of the tool call this result belongs to.
/// NOTE(review): presumably matches `ToolCall::id` — confirm against the callers.
pub tool_use_id: String,
/// Textual result. This is the only field rendered by
/// [`Message::to_chat_message`].
pub content: String,
/// Error flag for the result. It is not included in the text produced by
/// [`Message::to_chat_message`].
pub is_error: bool,
}
/// Description of a tool, passed to a driver through
/// [`CompletionRequest::tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Name of the tool.
pub name: String,
/// Free-form description of the tool.
pub description: String,
/// Schema for the tool's input, as a JSON value.
/// NOTE(review): the unit tests in this file use a JSON-Schema-style object
/// (`"type": "object"` with `"properties"`); confirm whether drivers require
/// that shape.
pub input_schema: serde_json::Value,
}
/// Input to [`LlmDriver::complete`] and [`LlmDriver::stream`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
/// Model identifier. The unit test in this file passes an empty string, so
/// drivers evidently may receive one.
pub model: String,
/// Conversation messages, in order.
pub messages: Vec<Message>,
/// Tools offered with this request; may be empty.
pub tools: Vec<ToolDefinition>,
/// Token limit for the completion.
/// NOTE(review): presumably caps generated (output) tokens — confirm per driver.
pub max_tokens: u32,
/// Sampling temperature.
pub temperature: f32,
/// Optional system prompt.
/// NOTE(review): how this combines with `Message::System` entries in
/// `messages` is not visible here — confirm per driver.
pub system: Option<String>,
}
/// Output of [`LlmDriver::complete`] and [`LlmDriver::stream`].
#[derive(Debug, Clone)]
pub struct CompletionResponse {
/// Text of the completion. The default [`LlmDriver::stream`] sends this
/// as a single `StreamEvent::TextDelta`.
pub text: String,
/// Why the completion stopped.
pub stop_reason: StopReason,
/// Tool invocations returned with the completion; may be empty.
pub tool_calls: Vec<ToolCall>,
/// Token usage reported for this completion.
pub usage: TokenUsage,
}
/// Events delivered over the channel passed to [`LlmDriver::stream`].
///
/// The default `stream` implementation in this file emits only `TextDelta`
/// followed by `ContentComplete`; the other variants are not produced here.
#[derive(Debug, Clone)]
pub enum StreamEvent {
/// The agent loop changed phase.
PhaseChange {
/// The phase carried by this event.
phase: LoopPhase,
},
/// A piece of response text.
TextDelta {
/// The text carried by this event. The default `stream` sends the whole
/// response text in one event.
text: String,
},
/// A tool invocation started.
ToolUseStart {
/// Identifier of the tool call.
/// NOTE(review): presumably `ToolCall::id` — confirm against the emitter.
id: String,
/// Name of the tool.
name: String,
},
/// A tool invocation finished.
ToolUseEnd {
/// Identifier of the tool call.
/// NOTE(review): presumably `ToolCall::id` — confirm against the emitter.
id: String,
/// Name of the tool.
name: String,
/// Result of the invocation, as text.
result: String,
},
/// The response is complete.
ContentComplete {
/// Why the completion stopped.
stop_reason: StopReason,
/// Token usage reported for the completion.
usage: TokenUsage,
},
}
/// Interface implemented by every LLM backend.
///
/// Implementors must provide [`complete`](LlmDriver::complete),
/// [`context_window`](LlmDriver::context_window) and
/// [`privacy_tier`](LlmDriver::privacy_tier). `stream` and `estimate_cost`
/// have defaults.
#[async_trait]
pub trait LlmDriver: Send + Sync {
    /// Runs a completion for `request` and returns the finished response.
    ///
    /// # Errors
    ///
    /// Returns an [`AgentError`] when the driver cannot produce a response.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError>;

    /// Runs a completion and reports progress on `tx`.
    ///
    /// The default implementation does not stream incrementally. It awaits
    /// [`complete`](LlmDriver::complete), then sends one
    /// [`StreamEvent::TextDelta`] holding the whole response text, followed by
    /// one [`StreamEvent::ContentComplete`].
    ///
    /// Event delivery is best effort: a failed send is ignored and the
    /// response is still returned.
    ///
    /// # Errors
    ///
    /// Propagates any error from `complete`; in that case no event is sent.
    async fn stream(
        &self,
        request: CompletionRequest,
        tx: tokio::sync::mpsc::Sender<StreamEvent>,
    ) -> Result<CompletionResponse, AgentError> {
        let completed = self.complete(request).await?;

        // Send results are discarded on purpose: a failed send must not fail
        // the completion.
        let delta = StreamEvent::TextDelta { text: completed.text.clone() };
        let _ = tx.send(delta).await;

        let finished = StreamEvent::ContentComplete {
            stop_reason: completed.stop_reason.clone(),
            usage: completed.usage.clone(),
        };
        let _ = tx.send(finished).await;

        Ok(completed)
    }

    /// Size of the driver's context window.
    /// NOTE(review): presumably measured in tokens — confirm per driver.
    fn context_window(&self) -> usize;

    /// Privacy tier this driver operates under.
    fn privacy_tier(&self) -> PrivacyTier;

    /// Estimated cost of the given usage. The default is `0.0`.
    /// NOTE(review): the currency/unit is not defined here — confirm with
    /// drivers that override it.
    fn estimate_cost(&self, _usage: &TokenUsage) -> f64 {
        0.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `Message` variant, including the tool-carrying ones, must
    /// survive a JSON round-trip with its payload intact.
    #[test]
    fn test_message_serialization() {
        let msgs = vec![
            Message::System("sys".into()),
            Message::User("hello".into()),
            Message::Assistant("hi".into()),
            Message::AssistantToolUse(ToolCall {
                id: "call-1".into(),
                name: "rag".into(),
                input: serde_json::json!({"query": "test"}),
            }),
            // `is_error: true` so a field silently falling back to its
            // default (`false`) would be caught.
            Message::ToolResult(ToolResultMsg {
                tool_use_id: "call-1".into(),
                content: "42".into(),
                is_error: true,
            }),
        ];
        for msg in &msgs {
            let json = serde_json::to_string(msg).expect("serialize failed");
            let back: Message = serde_json::from_str(&json).expect("deserialize failed");
            match (msg, &back) {
                (Message::System(a), Message::System(b))
                | (Message::User(a), Message::User(b))
                | (Message::Assistant(a), Message::Assistant(b)) => assert_eq!(a, b),
                (Message::AssistantToolUse(a), Message::AssistantToolUse(b)) => {
                    assert_eq!(a.id, b.id);
                    assert_eq!(a.name, b.name);
                    assert_eq!(a.input, b.input);
                }
                (Message::ToolResult(a), Message::ToolResult(b)) => {
                    assert_eq!(a.tool_use_id, b.tool_use_id);
                    assert_eq!(a.content, b.content);
                    assert_eq!(a.is_error, b.is_error);
                }
                _ => panic!("variant changed across round-trip: {:?} -> {:?}", msg, back),
            }
        }
    }

    /// A `ToolCall` round-trips with all three fields preserved.
    #[test]
    fn test_tool_call_serialization() {
        let call = ToolCall {
            id: "1".into(),
            name: "rag".into(),
            input: serde_json::json!({"query": "test"}),
        };
        let json = serde_json::to_string(&call).expect("serialize failed");
        let back: ToolCall = serde_json::from_str(&json).expect("deserialize failed");
        assert_eq!(back.id, "1");
        assert_eq!(back.name, "rag");
        assert_eq!(back.input, call.input);
    }

    /// A `ToolDefinition` serializes its name and round-trips unchanged.
    #[test]
    fn test_tool_definition_serialization() {
        let def = ToolDefinition {
            name: "memory".into(),
            description: "Read/write memory".into(),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "action": {"type": "string"}
                }
            }),
        };
        let json = serde_json::to_string(&def).expect("serialize failed");
        assert!(json.contains("memory"));

        let back: ToolDefinition = serde_json::from_str(&json).expect("deserialize failed");
        assert_eq!(back.name, def.name);
        assert_eq!(back.description, def.description);
        assert_eq!(back.input_schema, def.input_schema);
    }

    /// The default `stream` wraps `complete`: it returns the response and
    /// emits a `TextDelta` with the full text plus a `ContentComplete`.
    #[tokio::test]
    async fn test_stream_default_wraps_complete() {
        use crate::agent::driver::mock::MockDriver;
        use tokio::sync::mpsc;

        let driver = MockDriver::single_response("streamed");
        let (tx, mut rx) = mpsc::channel(16);
        let request = CompletionRequest {
            model: String::new(),
            messages: vec![Message::User("hi".into())],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: None,
        };

        let response = driver.stream(request, tx).await.expect("stream failed");
        assert_eq!(response.text, "streamed");

        // `stream` has returned, so every event it sent is already queued and
        // a non-blocking drain sees all of them.
        let mut got_text = false;
        let mut got_complete = false;
        while let Ok(event) = rx.try_recv() {
            match event {
                StreamEvent::TextDelta { text } => {
                    assert_eq!(text, "streamed");
                    got_text = true;
                }
                StreamEvent::ContentComplete { .. } => {
                    got_complete = true;
                }
                // Other events are tolerated. They are listed explicitly so a
                // newly added variant forces this test to be revisited.
                StreamEvent::PhaseChange { .. }
                | StreamEvent::ToolUseStart { .. }
                | StreamEvent::ToolUseEnd { .. } => {}
            }
        }
        assert!(got_text, "expected TextDelta event");
        assert!(got_complete, "expected ContentComplete event");
    }
}