aof_core/model.rs

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;

use crate::AofResult;

/// Model provider trait - abstraction over LLM providers
///
/// Implementations should minimize allocations and use zero-copy where possible.
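///
/// # Example
///
/// A sketch of driving a provider through this trait (marked `ignore`;
/// `provider_model()` and `request` are hypothetical placeholders):
///
/// ```ignore
/// let model = provider_model();
/// let response = model.generate(&request).await?;
/// println!("{}", response.content);
/// ```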
#[async_trait]
pub trait Model: Send + Sync {
    /// Generate completion (non-streaming)
    async fn generate(&self, request: &ModelRequest) -> AofResult<ModelResponse>;

    /// Generate completion (streaming)
    async fn generate_stream(
        &self,
        request: &ModelRequest,
    ) -> AofResult<Pin<Box<dyn futures::Stream<Item = AofResult<StreamChunk>> + Send>>>;

    /// Model configuration
    fn config(&self) -> &ModelConfig;

    /// Provider type
    fn provider(&self) -> ModelProvider;

    /// Count tokens in text (approximate)
    fn count_tokens(&self, text: &str) -> usize {
        // Default: rough approximation (~4 bytes per token)
        text.len() / 4
    }
}

/// Model provider enum
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelProvider {
    Anthropic,
    OpenAI,
    Google,
    Groq,
    Bedrock,
    Azure,
    Ollama,
    Custom,
}

/// Model configuration
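///
/// # Example
///
/// Unspecified fields fall back to serde defaults (`temperature` 0.7,
/// `timeout_secs` 60). Illustrative sketch mirroring the tests below:
///
/// ```ignore
/// let config: ModelConfig = serde_json::from_str(
///     r#"{ "model": "gpt-4", "provider": "openai" }"#,
/// )?;
/// assert_eq!(config.temperature, 0.7);
/// ```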
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model identifier (e.g., "claude-3-5-sonnet-20241022")
    pub model: String,

    /// Provider
    pub provider: ModelProvider,

    /// API key (optional, can use env var)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,

    /// API endpoint (for custom providers)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub endpoint: Option<String>,

    /// Temperature (0.0-1.0)
    #[serde(default = "default_temperature")]
    pub temperature: f32,

    /// Max tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,

    /// Timeout (seconds)
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,

    /// Custom headers
    #[serde(default)]
    pub headers: HashMap<String, String>,

    /// Extra provider-specific config
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}

fn default_temperature() -> f32 {
    0.7
}

fn default_timeout() -> u64 {
    60
}

/// Model request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRequest {
    /// Messages in conversation
    pub messages: Vec<RequestMessage>,

    /// System prompt (optional)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,

    /// Tools available
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub tools: Vec<ToolDefinition>,

    /// Temperature override
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// Max tokens override
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,

    /// Stream response
    #[serde(default)]
    pub stream: bool,

    /// Extra parameters
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}

/// Message in request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestMessage {
    pub role: MessageRole,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<crate::ToolCall>>,
    /// Tool call ID (required for Tool role messages)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}

/// Message role
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    User,
    Assistant,
    System,
    Tool,
}

/// Tool definition for model
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

/// Model response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
    /// Generated text
    pub content: String,

    /// Tool calls requested by model
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub tool_calls: Vec<crate::ToolCall>,

    /// Stop reason
    pub stop_reason: StopReason,

    /// Usage statistics
    pub usage: Usage,

    /// Provider-specific metadata
    #[serde(flatten)]
    pub metadata: HashMap<String, serde_json::Value>,
}

/// Stop reason
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    StopSequence,
    ToolUse,
    ContentFilter,
}

/// Token usage statistics
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
pub struct Usage {
    pub input_tokens: usize,
    pub output_tokens: usize,
}

/// Stream chunk
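///
/// A stream typically yields `ContentDelta` and/or `ToolCall` chunks followed
/// by a final `Done` chunk carrying usage and the stop reason.
///
/// # Example
///
/// A sketch of handling chunks from `generate_stream` (marked `ignore`;
/// `handle_tool_call` is a hypothetical placeholder):
///
/// ```ignore
/// match chunk {
///     StreamChunk::ContentDelta { delta } => print!("{delta}"),
///     StreamChunk::ToolCall { tool_call } => handle_tool_call(tool_call),
///     StreamChunk::Done { usage, .. } => println!("{} output tokens", usage.output_tokens),
/// }
/// ```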
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamChunk {
    ContentDelta { delta: String },
    ToolCall { tool_call: crate::ToolCall },
    Done { usage: Usage, stop_reason: StopReason },
}

/// Reference-counted model
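///
/// Handy for sharing a single model instance across components. Illustrative
/// sketch only; `MyModel` is a hypothetical `Model` implementation:
///
/// ```ignore
/// let model: ModelRef = Arc::new(MyModel::new(config));
/// ```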
pub type ModelRef = Arc<dyn Model>;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_provider_serialization() {
        let provider = ModelProvider::Anthropic;
        let json = serde_json::to_string(&provider).unwrap();
        assert_eq!(json, "\"anthropic\"");

        let provider: ModelProvider = serde_json::from_str("\"openai\"").unwrap();
        assert_eq!(provider, ModelProvider::OpenAI);
    }

    #[test]
    fn test_model_config_defaults() {
        let json = r#"{
            "model": "gpt-4",
            "provider": "openai"
        }"#;
        let config: ModelConfig = serde_json::from_str(json).unwrap();

        assert_eq!(config.model, "gpt-4");
        assert_eq!(config.provider, ModelProvider::OpenAI);
        assert_eq!(config.temperature, 0.7); // default
        assert_eq!(config.timeout_secs, 60); // default
        assert!(config.api_key.is_none());
    }

    #[test]
    fn test_model_config_full() {
        let config = ModelConfig {
            model: "claude-3-5-sonnet".to_string(),
            provider: ModelProvider::Anthropic,
            api_key: Some("test_key".to_string()),
            endpoint: Some("https://api.anthropic.com".to_string()),
            temperature: 0.3,
            max_tokens: Some(4096),
            timeout_secs: 120,
            headers: {
                let mut h = HashMap::new();
                h.insert("X-Custom".to_string(), "value".to_string());
                h
            },
            extra: HashMap::new(),
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ModelConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.model, "claude-3-5-sonnet");
        assert_eq!(deserialized.temperature, 0.3);
        assert_eq!(deserialized.max_tokens, Some(4096));
    }

    #[test]
    fn test_message_role() {
        assert_eq!(
            serde_json::to_string(&MessageRole::User).unwrap(),
            "\"user\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::Assistant).unwrap(),
            "\"assistant\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::System).unwrap(),
            "\"system\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::Tool).unwrap(),
            "\"tool\""
        );
    }

    #[test]
    fn test_model_request() {
        let request = ModelRequest {
            messages: vec![
                RequestMessage {
                    role: MessageRole::User,
                    content: "Hello".to_string(),
                    tool_calls: None,
                    tool_call_id: None,
                },
            ],
            system: Some("You are a helpful assistant.".to_string()),
            tools: vec![],
            temperature: Some(0.5),
            max_tokens: Some(1000),
            stream: false,
            extra: HashMap::new(),
        };

        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.system, Some("You are a helpful assistant.".to_string()));
    }

    #[test]
    fn test_stop_reason_serialization() {
        let end_turn = StopReason::EndTurn;
        let json = serde_json::to_string(&end_turn).unwrap();
        assert_eq!(json, "\"end_turn\"");

        let max_tokens: StopReason = serde_json::from_str("\"max_tokens\"").unwrap();
        assert_eq!(max_tokens, StopReason::MaxTokens);
    }

    #[test]
    fn test_usage() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
        };

        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);

        let default_usage = Usage::default();
        assert_eq!(default_usage.input_tokens, 0);
        assert_eq!(default_usage.output_tokens, 0);
    }

    #[test]
    fn test_stream_chunk_content_delta() {
        let chunk = StreamChunk::ContentDelta {
            delta: "Hello".to_string(),
        };

        let json = serde_json::to_string(&chunk).unwrap();
        assert!(json.contains("content_delta"));
        assert!(json.contains("Hello"));
    }

    #[test]
    fn test_stream_chunk_done() {
        let chunk = StreamChunk::Done {
            usage: Usage {
                input_tokens: 10,
                output_tokens: 20,
            },
            stop_reason: StopReason::EndTurn,
        };

        let json = serde_json::to_string(&chunk).unwrap();
        assert!(json.contains("done"));
        assert!(json.contains("end_turn"));
    }

    #[test]
    fn test_model_response() {
        let response = ModelResponse {
            content: "Hello, world!".to_string(),
            tool_calls: vec![],
            stop_reason: StopReason::EndTurn,
            usage: Usage {
                input_tokens: 5,
                output_tokens: 3,
            },
            metadata: HashMap::new(),
        };

        assert_eq!(response.content, "Hello, world!");
        assert!(response.tool_calls.is_empty());
        assert_eq!(response.stop_reason, StopReason::EndTurn);
    }
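
    /// Minimal mock `Model` used to exercise the trait surface end to end.
    /// Illustrative sketch only: it assumes `AofResult` is an ordinary `Result`
    /// alias with a `Send + Debug` error type, that `crate::ToolCall` is `Send`,
    /// and that the `futures` executor feature is available for `block_on`.
    struct MockModel {
        config: ModelConfig,
    }

    #[async_trait]
    impl Model for MockModel {
        async fn generate(&self, request: &ModelRequest) -> AofResult<ModelResponse> {
            // Echo the last message back as the "completion".
            let content = request.messages.last().map(|m| m.content.clone()).unwrap_or_default();
            Ok(ModelResponse {
                content,
                tool_calls: vec![],
                stop_reason: StopReason::EndTurn,
                usage: Usage { input_tokens: 1, output_tokens: 1 },
                metadata: HashMap::new(),
            })
        }

        async fn generate_stream(
            &self,
            _request: &ModelRequest,
        ) -> AofResult<Pin<Box<dyn futures::Stream<Item = AofResult<StreamChunk>> + Send>>> {
            // One content delta followed by a terminal Done chunk.
            let chunks: Vec<AofResult<StreamChunk>> = vec![
                Ok(StreamChunk::ContentDelta { delta: "Hello".to_string() }),
                Ok(StreamChunk::Done { usage: Usage::default(), stop_reason: StopReason::EndTurn }),
            ];
            Ok(Box::pin(futures::stream::iter(chunks)))
        }

        fn config(&self) -> &ModelConfig {
            &self.config
        }

        fn provider(&self) -> ModelProvider {
            ModelProvider::Custom
        }
    }

    #[test]
    fn test_mock_model_roundtrip() {
        use futures::StreamExt;

        let model = MockModel {
            config: ModelConfig {
                model: "mock".to_string(),
                provider: ModelProvider::Custom,
                api_key: None,
                endpoint: None,
                temperature: default_temperature(),
                max_tokens: None,
                timeout_secs: default_timeout(),
                headers: HashMap::new(),
                extra: HashMap::new(),
            },
        };
        let request = ModelRequest {
            messages: vec![RequestMessage {
                role: MessageRole::User,
                content: "Hello".to_string(),
                tool_calls: None,
                tool_call_id: None,
            }],
            system: None,
            tools: vec![],
            temperature: None,
            max_tokens: None,
            stream: false,
            extra: HashMap::new(),
        };

        futures::executor::block_on(async {
            let response = model.generate(&request).await.unwrap();
            assert_eq!(response.content, "Hello");
            assert_eq!(response.stop_reason, StopReason::EndTurn);

            let mut stream = model.generate_stream(&request).await.unwrap();
            assert!(matches!(stream.next().await, Some(Ok(StreamChunk::ContentDelta { .. }))));
            assert!(matches!(stream.next().await, Some(Ok(StreamChunk::Done { .. }))));
            assert!(stream.next().await.is_none());
        });

        assert_eq!(model.provider(), ModelProvider::Custom);
        // Default token counting: text.len() / 4.
        assert_eq!(model.count_tokens("12345678"), 2);
    }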
}