1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use crate::AofResult;
8
9#[async_trait]
13pub trait Model: Send + Sync {
14 async fn generate(&self, request: &ModelRequest) -> AofResult<ModelResponse>;
16
17 async fn generate_stream(
19 &self,
20 request: &ModelRequest,
21 ) -> AofResult<Pin<Box<dyn futures::Stream<Item = AofResult<StreamChunk>> + Send>>>;
22
23 fn config(&self) -> &ModelConfig;
25
26 fn provider(&self) -> ModelProvider;
28
29 fn count_tokens(&self, text: &str) -> usize {
31 text.len() / 4
33 }
34}
35
/// The backend API that serves a model.
///
/// Serializes as the lowercase variant name (e.g. `"anthropic"`, `"openai"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelProvider {
    Anthropic,
    OpenAI,
    Google,
    Groq,
    Bedrock,
    Azure,
    Ollama,
    /// A provider not covered by the named variants above.
    Custom,
}
49
/// Configuration used to construct a model client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Provider-specific model identifier (e.g. `"gpt-4"`).
    pub model: String,

    /// Which backend API serves this model.
    pub provider: ModelProvider,

    /// API key; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,

    /// Custom endpoint URL; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub endpoint: Option<String>,

    /// Sampling temperature; defaults to 0.7 when missing from input.
    #[serde(default = "default_temperature")]
    pub temperature: f32,

    /// Maximum number of tokens to generate, if capped.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,

    /// Request timeout in seconds; defaults to 60 when missing from input.
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,

    /// Extra HTTP headers to send with requests; defaults to empty.
    #[serde(default)]
    pub headers: HashMap<String, String>,

    /// Provider-specific fields, captured/emitted at the top level via flatten.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
87
/// Serde default for `ModelConfig::temperature`.
fn default_temperature() -> f32 {
    const TEMPERATURE: f32 = 0.7;
    TEMPERATURE
}
91
/// Serde default for `ModelConfig::timeout_secs`, in seconds.
fn default_timeout() -> u64 {
    const TIMEOUT_SECS: u64 = 60;
    TIMEOUT_SECS
}
95
/// A generation request sent to a [`Model`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRequest {
    /// Conversation messages to send to the model.
    pub messages: Vec<RequestMessage>,

    /// Optional system prompt; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,

    /// Tools the model may call; omitted when empty, defaults to empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub tools: Vec<ToolDefinition>,

    /// Per-request sampling temperature override, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// Per-request max-token override, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,

    /// Whether a streaming response is requested; defaults to false.
    #[serde(default)]
    pub stream: bool,

    /// Provider-specific fields, captured/emitted at the top level via flatten.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
126
/// One message in a [`ModelRequest`] conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestMessage {
    /// Who authored this message.
    pub role: MessageRole,
    /// Text content of the message.
    pub content: String,
    /// Tool invocations attached to this message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<crate::ToolCall>>,
    /// Id of the tool call this message answers, if any
    /// (presumably set on `Tool`-role messages — confirm against callers).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
138
/// Author of a [`RequestMessage`]. Serializes as the lowercase variant name.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// End-user input.
    User,
    /// Model output.
    Assistant,
    /// System-level instructions.
    System,
    /// Result of a tool invocation.
    Tool,
}
148
/// A tool the model may call, advertised via [`ModelRequest::tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Name the model uses to reference this tool.
    pub name: String,
    /// Description of what the tool does.
    pub description: String,
    /// Parameter specification as free-form JSON (typically a JSON Schema).
    pub parameters: serde_json::Value,
}
156
/// A complete (non-streaming) model response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
    /// Generated text content.
    pub content: String,

    /// Tool calls requested by the model; omitted when empty, defaults to empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub tool_calls: Vec<crate::ToolCall>,

    /// Why generation stopped.
    pub stop_reason: StopReason,

    /// Token accounting for this exchange.
    pub usage: Usage,

    /// Provider-specific fields, captured/emitted at the top level via flatten.
    #[serde(flatten)]
    pub metadata: HashMap<String, serde_json::Value>,
}
177
/// Why the model stopped generating. Serializes in snake_case
/// (e.g. `"end_turn"`, `"max_tokens"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn naturally.
    EndTurn,
    /// The max-token limit was reached.
    MaxTokens,
    /// A stop sequence was emitted.
    StopSequence,
    /// The model wants to invoke a tool.
    ToolUse,
    /// Output was blocked by a content filter.
    ContentFilter,
}
188
/// Token usage reported for a request/response pair. Defaults to all zeros.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub input_tokens: usize,
    /// Tokens produced in the response.
    pub output_tokens: usize,
}
195
/// One item in a streaming response.
///
/// Internally tagged for serde: each chunk carries a snake_case `"type"`
/// field, e.g. `{"type":"content_delta","delta":"..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamChunk {
    /// An incremental piece of text content.
    ContentDelta { delta: String },
    /// A tool call emitted mid-stream.
    ToolCall { tool_call: crate::ToolCall },
    /// Terminal chunk carrying final usage and the stop reason.
    Done { usage: Usage, stop_reason: StopReason },
}
204
/// Shared, thread-safe handle to a dynamically dispatched [`Model`].
pub type ModelRef = Arc<dyn Model>;
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Providers round-trip through their lowercase serde names.
    #[test]
    fn test_model_provider_serialization() {
        let provider = ModelProvider::Anthropic;
        let json = serde_json::to_string(&provider).unwrap();
        assert_eq!(json, "\"anthropic\"");

        let provider: ModelProvider = serde_json::from_str("\"openai\"").unwrap();
        assert_eq!(provider, ModelProvider::OpenAI);
    }

    /// Fields absent from the JSON fall back to their serde defaults.
    #[test]
    fn test_model_config_defaults() {
        let json = r#"{
            "model": "gpt-4",
            "provider": "openai"
        }"#;
        let config: ModelConfig = serde_json::from_str(json).unwrap();

        assert_eq!(config.model, "gpt-4");
        assert_eq!(config.provider, ModelProvider::OpenAI);
        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.timeout_secs, 60);
        assert!(config.api_key.is_none());
    }

    /// A fully populated config survives a serialize/deserialize round trip.
    #[test]
    fn test_model_config_full() {
        let config = ModelConfig {
            model: "claude-3-5-sonnet".to_string(),
            provider: ModelProvider::Anthropic,
            api_key: Some("test_key".to_string()),
            endpoint: Some("https://api.anthropic.com".to_string()),
            temperature: 0.3,
            max_tokens: Some(4096),
            timeout_secs: 120,
            headers: {
                let mut h = HashMap::new();
                h.insert("X-Custom".to_string(), "value".to_string());
                h
            },
            extra: HashMap::new(),
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ModelConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.model, "claude-3-5-sonnet");
        assert_eq!(deserialized.temperature, 0.3);
        assert_eq!(deserialized.max_tokens, Some(4096));
    }

    /// Roles serialize to their lowercase names.
    #[test]
    fn test_message_role() {
        assert_eq!(
            serde_json::to_string(&MessageRole::User).unwrap(),
            "\"user\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::Assistant).unwrap(),
            "\"assistant\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::System).unwrap(),
            "\"system\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::Tool).unwrap(),
            "\"tool\""
        );
    }

    #[test]
    fn test_model_request() {
        let request = ModelRequest {
            messages: vec![
                RequestMessage {
                    role: MessageRole::User,
                    content: "Hello".to_string(),
                    tool_calls: None,
                    tool_call_id: None,
                },
            ],
            system: Some("You are a helpful assistant.".to_string()),
            tools: vec![],
            temperature: Some(0.5),
            max_tokens: Some(1000),
            stream: false,
            extra: HashMap::new(),
        };

        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.system, Some("You are a helpful assistant.".to_string()));
    }

    /// `tools` and `stream` carry serde defaults; optional fields stay `None`
    /// when absent from the input JSON.
    #[test]
    fn test_model_request_defaults() {
        let json = r#"{"messages": []}"#;
        let request: ModelRequest = serde_json::from_str(json).unwrap();

        assert!(request.messages.is_empty());
        assert!(request.tools.is_empty());
        assert!(!request.stream);
        assert!(request.system.is_none());
        assert!(request.temperature.is_none());
        assert!(request.max_tokens.is_none());
    }

    /// Stop reasons round-trip through their snake_case serde names.
    #[test]
    fn test_stop_reason_serialization() {
        let end_turn = StopReason::EndTurn;
        let json = serde_json::to_string(&end_turn).unwrap();
        assert_eq!(json, "\"end_turn\"");

        let max_tokens: StopReason = serde_json::from_str("\"max_tokens\"").unwrap();
        assert_eq!(max_tokens, StopReason::MaxTokens);
    }

    #[test]
    fn test_usage() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
        };

        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);

        let default_usage = Usage::default();
        assert_eq!(default_usage.input_tokens, 0);
        assert_eq!(default_usage.output_tokens, 0);
    }

    /// Internally tagged chunks embed their snake_case variant name.
    #[test]
    fn test_stream_chunk_content_delta() {
        let chunk = StreamChunk::ContentDelta {
            delta: "Hello".to_string(),
        };

        let json = serde_json::to_string(&chunk).unwrap();
        assert!(json.contains("content_delta"));
        assert!(json.contains("Hello"));
    }

    #[test]
    fn test_stream_chunk_done() {
        let chunk = StreamChunk::Done {
            usage: Usage {
                input_tokens: 10,
                output_tokens: 20,
            },
            stop_reason: StopReason::EndTurn,
        };

        let json = serde_json::to_string(&chunk).unwrap();
        assert!(json.contains("done"));
        assert!(json.contains("end_turn"));
    }

    #[test]
    fn test_model_response() {
        let response = ModelResponse {
            content: "Hello, world!".to_string(),
            tool_calls: vec![],
            stop_reason: StopReason::EndTurn,
            usage: Usage {
                input_tokens: 5,
                output_tokens: 3,
            },
            metadata: HashMap::new(),
        };

        assert_eq!(response.content, "Hello, world!");
        assert!(response.tool_calls.is_empty());
        assert_eq!(response.stop_reason, StopReason::EndTurn);
    }
}