ai_agent/
models.rs

1//! Language model abstraction and implementations
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use crate::errors::ModelError;
7
8/// Language model trait
9#[async_trait]
10pub trait LanguageModel: Send + Sync {
11    async fn complete(&self, prompt: &str) -> Result<ModelResponse, ModelError>;
12    async fn complete_with_tools(&self, prompt: &str, tools: &[ToolDefinition]) -> Result<ModelResponse, ModelError>;
13    fn model_name(&self) -> &str;
14    fn supports_tools(&self) -> bool;
15}
16
17/// Model response
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ModelResponse {
20    pub content: String,
21    pub tool_calls: Vec<ToolCall>,
22    pub usage: Option<TokenUsage>,
23    pub metadata: HashMap<String, serde_json::Value>,
24}
25
26impl ModelResponse {
27    pub fn text(content: String) -> Self {
28        Self {
29            content,
30            tool_calls: vec![],
31            usage: None,
32            metadata: HashMap::new(),
33        }
34    }
35}
36
37/// Tool call
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ToolCall {
40    pub name: String,
41    pub arguments: HashMap<String, serde_json::Value>,
42    pub id: Option<String>,
43}
44
45/// Tool definition
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ToolDefinition {
48    pub name: String,
49    pub description: String,
50    pub parameters: serde_json::Value,
51}
52
53/// Token usage information
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct TokenUsage {
56    pub prompt_tokens: u32,
57    pub completion_tokens: u32,
58    pub total_tokens: u32,
59}
60
/// Static description of what a language-model backend supports.
///
/// Derives `PartialEq`/`Eq` so capability sets can be compared directly
/// (e.g. in tests or when selecting a backend).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelCapabilities {
    /// Token limit advertised for the model.
    // NOTE(review): whether this is a context or output limit is not defined here.
    pub max_tokens: u32,
    /// Whether tool calling is supported.
    pub supports_tools: bool,
    /// Whether streamed responses are supported.
    pub supports_streaming: bool,
    /// Whether image inputs are supported.
    pub supports_vision: bool,
}

impl Default for ModelCapabilities {
    /// Conservative baseline: a 4096-token limit and every optional feature
    /// disabled.
    fn default() -> Self {
        Self {
            max_tokens: 4096,
            supports_tools: false,
            supports_streaming: false,
            supports_vision: false,
        }
    }
}
80
81// Mock model for testing
82pub struct MockModel {
83    name: String,
84}
85
86impl MockModel {
87    pub fn new(name: String) -> Self {
88        Self { name }
89    }
90}
91
92#[async_trait]
93impl LanguageModel for MockModel {
94    async fn complete(&self, prompt: &str) -> Result<ModelResponse, ModelError> {
95        Ok(ModelResponse::text(format!("Mock response from {}: {}", self.name, prompt)))
96    }
97
98    async fn complete_with_tools(&self, prompt: &str, _tools: &[ToolDefinition]) -> Result<ModelResponse, ModelError> {
99        let response = format!("Mock response with tools from {}: {}", self.name, prompt);
100        Ok(ModelResponse {
101            content: response,
102            tool_calls: vec![],
103            usage: None,
104            metadata: HashMap::new(),
105        })
106    }
107
108    fn model_name(&self) -> &str {
109        &self.name
110    }
111
112    fn supports_tools(&self) -> bool {
113        true
114    }
115}
116
117// OpenAI model implementation (placeholder)
118pub struct OpenAIModel {
119    _api_key: String,
120    model: String,
121}
122
123impl OpenAIModel {
124    pub fn new(api_key: String, model: String) -> Self {
125        Self { _api_key: api_key, model }
126    }
127}
128
129#[async_trait]
130impl LanguageModel for OpenAIModel {
131    async fn complete(&self, prompt: &str) -> Result<ModelResponse, ModelError> {
132        // TODO: Implement OpenAI API call
133        Ok(ModelResponse::text(format!("OpenAI response from {}: {}", self.model, prompt)))
134    }
135
136    async fn complete_with_tools(&self, prompt: &str, _tools: &[ToolDefinition]) -> Result<ModelResponse, ModelError> {
137        // TODO: Implement OpenAI API call with tools
138        Ok(ModelResponse::text(format!("OpenAI response with tools from {}: {}", self.model, prompt)))
139    }
140
141    fn model_name(&self) -> &str {
142        &self.model
143    }
144
145    fn supports_tools(&self) -> bool {
146        true
147    }
148}
149
150// Zhipu model implementation (placeholder)
151pub struct ZhipuModel {
152    _api_key: String,
153    model: String,
154    endpoint: Option<String>,
155}
156
157impl ZhipuModel {
158    pub fn new(api_key: String, model: String, endpoint: Option<String>) -> Self {
159        Self { _api_key: api_key, model, endpoint }
160    }
161}
162
163#[async_trait]
164impl LanguageModel for ZhipuModel {
165    async fn complete(&self, prompt: &str) -> Result<ModelResponse, ModelError> {
166        let client = reqwest::Client::new();
167        let endpoint = self.endpoint.as_ref()
168            .map(|e| format!("{}/chat/completions", e))
169            .unwrap_or_else(|| "https://open.bigmodel.cn/api/paas/v4/chat/completions".to_string());
170
171        let request_body = serde_json::json!({
172            "model": self.model,
173            "messages": [
174                {
175                    "role": "user",
176                    "content": prompt
177                }
178            ],
179            "max_tokens": 4000,
180            "temperature": 0.7
181        });
182
183        let response = client
184            .post(&endpoint)
185            .header("Authorization", format!("Bearer {}", self._api_key))
186            .header("Content-Type", "application/json")
187            .json(&request_body)
188            .send()
189            .await
190            .map_err(|e| ModelError::APIError(e.to_string()))?;
191
192        if !response.status().is_success() {
193            let status = response.status();
194            let error_text = response.text().await.unwrap_or_default();
195            return Err(ModelError::APIError(format!("API request failed with status {}: {}", status, error_text)));
196        }
197
198        let response_json: serde_json::Value = response
199            .json()
200            .await
201            .map_err(|e| ModelError::APIError(e.to_string()))?;
202
203        let content = response_json["choices"][0]["message"]["content"]
204            .as_str()
205            .ok_or_else(|| ModelError::APIError("Invalid response format".to_string()))?;
206
207        Ok(ModelResponse::text(content.to_string()))
208    }
209
210    async fn complete_with_tools(&self, prompt: &str, _tools: &[ToolDefinition]) -> Result<ModelResponse, ModelError> {
211        // For now, just use the regular complete method
212        self.complete(prompt).await
213    }
214
215    fn model_name(&self) -> &str {
216        &self.model
217    }
218
219    fn supports_tools(&self) -> bool {
220        true
221    }
222}
223
224// Anthropic model implementation (placeholder)
225pub struct AnthropicModel {
226    _api_key: String,
227    model: String,
228}
229
230impl AnthropicModel {
231    pub fn new(api_key: String, model: String) -> Self {
232        Self { _api_key: api_key, model }
233    }
234}
235
236#[async_trait]
237impl LanguageModel for AnthropicModel {
238    async fn complete(&self, prompt: &str) -> Result<ModelResponse, ModelError> {
239        // TODO: Implement Anthropic API call
240        Ok(ModelResponse::text(format!("Anthropic response from {}: {}", self.model, prompt)))
241    }
242
243    async fn complete_with_tools(&self, prompt: &str, _tools: &[ToolDefinition]) -> Result<ModelResponse, ModelError> {
244        // TODO: Implement Anthropic API call with tools
245        Ok(ModelResponse::text(format!("Anthropic response with tools from {}: {}", self.model, prompt)))
246    }
247
248    fn model_name(&self) -> &str {
249        &self.model
250    }
251
252    fn supports_tools(&self) -> bool {
253        true
254    }
255}
256
257// Local model implementation (placeholder)
258pub struct LocalModel {
259    _endpoint: String,
260    model: String,
261}
262
263impl LocalModel {
264    pub fn new(endpoint: String, model: String) -> Self {
265        Self { _endpoint: endpoint, model }
266    }
267}
268
269#[async_trait]
270impl LanguageModel for LocalModel {
271    async fn complete(&self, prompt: &str) -> Result<ModelResponse, ModelError> {
272        // TODO: Implement local model API call
273        Ok(ModelResponse::text(format!("Local model response from {}: {}", self.model, prompt)))
274    }
275
276    async fn complete_with_tools(&self, prompt: &str, _tools: &[ToolDefinition]) -> Result<ModelResponse, ModelError> {
277        // TODO: Implement local model API call with tools
278        Ok(ModelResponse::text(format!("Local model response with tools from {}: {}", self.model, prompt)))
279    }
280
281    fn model_name(&self) -> &str {
282        &self.model
283    }
284
285    fn supports_tools(&self) -> bool {
286        false
287    }
288}