use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::errors::ModelError;

/// Provider-agnostic interface over chat-completion backends. `Send + Sync`
/// lets one model instance be shared across async tasks.
#[async_trait]
pub trait LanguageModel: Send + Sync {
    /// Completes a plain-text prompt.
    async fn complete(&self, prompt: &str) -> Result<ModelResponse, ModelError>;
    /// Completes a prompt while exposing the given tool definitions.
    async fn complete_with_tools(&self, prompt: &str, tools: &[ToolDefinition]) -> Result<ModelResponse, ModelError>;
    /// Backend-specific model identifier.
    fn model_name(&self) -> &str;
    /// Whether this backend can emit tool calls.
    fn supports_tools(&self) -> bool;
}
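
// Illustrative dispatch helper (an assumption of this sketch, not part of the
// original API): route a prompt through the tool-aware path only when the
// backend advertises tool support and tools were actually supplied.
#[allow(dead_code)]
async fn run_prompt(
    model: &dyn LanguageModel,
    prompt: &str,
    tools: &[ToolDefinition],
) -> Result<ModelResponse, ModelError> {
    if model.supports_tools() && !tools.is_empty() {
        model.complete_with_tools(prompt, tools).await
    } else {
        model.complete(prompt).await
    }
}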

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
    /// The model's text output.
    pub content: String,
    /// Tool invocations requested by the model; empty for plain completions.
    pub tool_calls: Vec<ToolCall>,
    /// Token accounting, when the backend reports it.
    pub usage: Option<TokenUsage>,
    /// Backend-specific extras that do not fit the typed fields.
    pub metadata: HashMap<String, serde_json::Value>,
}

impl ModelResponse {
    /// Convenience constructor for a text-only response.
    pub fn text(content: String) -> Self {
        Self {
            content,
            tool_calls: vec![],
            usage: None,
            metadata: HashMap::new(),
        }
    }
}

/// A single tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub name: String,
    pub arguments: HashMap<String, serde_json::Value>,
    /// Provider-assigned call id, when the backend supplies one.
    pub id: Option<String>,
}

/// Describes a tool the model may call; `parameters` holds the argument
/// schema as raw JSON (typically a JSON Schema object).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}
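
// Illustrative only: a `ToolDefinition` whose `parameters` field carries a
// JSON Schema object, the shape used by OpenAI-style tool calling. The
// weather tool (name, fields) is invented for demonstration.
#[allow(dead_code)]
fn example_weather_tool() -> ToolDefinition {
    ToolDefinition {
        name: "get_weather".to_string(),
        description: "Look up the current weather for a city".to_string(),
        parameters: serde_json::json!({
            "type": "object",
            "properties": {
                "city": { "type": "string", "description": "City name" }
            },
            "required": ["city"]
        }),
    }
}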

/// Token counts as reported by the backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Static capability flags for a backend. Nothing in this module consumes it
/// yet; callers can use it to pick an appropriate model.
#[derive(Debug, Clone)]
pub struct ModelCapabilities {
    pub max_tokens: u32,
    pub supports_tools: bool,
    pub supports_streaming: bool,
    pub supports_vision: bool,
}

impl Default for ModelCapabilities {
    fn default() -> Self {
        // Conservative defaults: a small context window and no optional features.
        Self {
            max_tokens: 4096,
            supports_tools: false,
            supports_streaming: false,
            supports_vision: false,
        }
    }
}
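
// Example: describe a tool-capable backend by overriding only the fields that
// differ from the defaults (struct update syntax). Purely illustrative.
#[allow(dead_code)]
fn tool_capable_capabilities() -> ModelCapabilities {
    ModelCapabilities {
        supports_tools: true,
        ..ModelCapabilities::default()
    }
}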

/// In-memory fake for tests and examples: echoes the prompt back without any
/// network I/O.
pub struct MockModel {
    name: String,
}

impl MockModel {
    pub fn new(name: String) -> Self {
        Self { name }
    }
}

#[async_trait]
impl LanguageModel for MockModel {
    async fn complete(&self, prompt: &str) -> Result<ModelResponse, ModelError> {
        Ok(ModelResponse::text(format!("Mock response from {}: {}", self.name, prompt)))
    }

    async fn complete_with_tools(&self, prompt: &str, _tools: &[ToolDefinition]) -> Result<ModelResponse, ModelError> {
        // The mock acknowledges tools in its reply text but never emits tool calls.
        Ok(ModelResponse::text(format!("Mock response with tools from {}: {}", self.name, prompt)))
    }

    fn model_name(&self) -> &str {
        &self.name
    }

    fn supports_tools(&self) -> bool {
        true
    }
}

/// Placeholder OpenAI backend: stores the credential but currently returns
/// canned text instead of calling the API.
pub struct OpenAIModel {
    _api_key: String,
    model: String,
}

impl OpenAIModel {
    pub fn new(api_key: String, model: String) -> Self {
        Self { _api_key: api_key, model }
    }
}

#[async_trait]
impl LanguageModel for OpenAIModel {
    async fn complete(&self, prompt: &str) -> Result<ModelResponse, ModelError> {
        Ok(ModelResponse::text(format!("OpenAI response from {}: {}", self.model, prompt)))
    }

    async fn complete_with_tools(&self, prompt: &str, _tools: &[ToolDefinition]) -> Result<ModelResponse, ModelError> {
        Ok(ModelResponse::text(format!("OpenAI response with tools from {}: {}", self.model, prompt)))
    }

    fn model_name(&self) -> &str {
        &self.model
    }

    fn supports_tools(&self) -> bool {
        true
    }
}

/// Backend for Zhipu AI's GLM chat-completions API. Unlike the stub backends
/// in this module, this one performs a real HTTP request.
pub struct ZhipuModel {
    api_key: String,
    model: String,
    endpoint: Option<String>,
}

impl ZhipuModel {
    /// `endpoint` overrides the default base URL (useful for proxies or tests).
    pub fn new(api_key: String, model: String, endpoint: Option<String>) -> Self {
        Self { api_key, model, endpoint }
    }
}

#[async_trait]
impl LanguageModel for ZhipuModel {
    async fn complete(&self, prompt: &str) -> Result<ModelResponse, ModelError> {
        let client = reqwest::Client::new();
        let endpoint = self.endpoint.as_ref()
            .map(|e| format!("{}/chat/completions", e))
            .unwrap_or_else(|| "https://open.bigmodel.cn/api/paas/v4/chat/completions".to_string());

        let request_body = serde_json::json!({
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": 4000,
            "temperature": 0.7
        });

        let response = client
            .post(&endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| ModelError::APIError(e.to_string()))?;

        // Surface non-2xx responses with the server's error body attached.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(ModelError::APIError(format!("API request failed with status {}: {}", status, error_text)));
        }

        let response_json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| ModelError::APIError(e.to_string()))?;

        // Pull the assistant message out of the OpenAI-compatible payload.
        let content = response_json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| ModelError::APIError("Invalid response format".to_string()))?;

        Ok(ModelResponse::text(content.to_string()))
    }

    async fn complete_with_tools(&self, prompt: &str, _tools: &[ToolDefinition]) -> Result<ModelResponse, ModelError> {
        // Tool definitions are not yet forwarded to the API; fall back to a
        // plain completion.
        self.complete(prompt).await
    }

    fn model_name(&self) -> &str {
        &self.model
    }

    fn supports_tools(&self) -> bool {
        true
    }
}
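
// Sketch (assumption): if the Zhipu payload carries an OpenAI-style `usage`
// object, it deserializes directly into `TokenUsage`. Not wired into
// `complete` above, so the behavior there is unchanged.
#[allow(dead_code)]
fn parse_zhipu_usage(response_json: &serde_json::Value) -> Option<TokenUsage> {
    serde_json::from_value(response_json.get("usage")?.clone()).ok()
}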

/// Placeholder Anthropic backend: stores the credential but currently returns
/// canned text instead of calling the API.
pub struct AnthropicModel {
    _api_key: String,
    model: String,
}

impl AnthropicModel {
    pub fn new(api_key: String, model: String) -> Self {
        Self { _api_key: api_key, model }
    }
}

#[async_trait]
impl LanguageModel for AnthropicModel {
    async fn complete(&self, prompt: &str) -> Result<ModelResponse, ModelError> {
        Ok(ModelResponse::text(format!("Anthropic response from {}: {}", self.model, prompt)))
    }

    async fn complete_with_tools(&self, prompt: &str, _tools: &[ToolDefinition]) -> Result<ModelResponse, ModelError> {
        Ok(ModelResponse::text(format!("Anthropic response with tools from {}: {}", self.model, prompt)))
    }

    fn model_name(&self) -> &str {
        &self.model
    }

    fn supports_tools(&self) -> bool {
        true
    }
}

/// Placeholder for a locally hosted model: stores the endpoint but currently
/// returns canned text instead of contacting it.
pub struct LocalModel {
    _endpoint: String,
    model: String,
}

impl LocalModel {
    pub fn new(endpoint: String, model: String) -> Self {
        Self { _endpoint: endpoint, model }
    }
}

#[async_trait]
impl LanguageModel for LocalModel {
    async fn complete(&self, prompt: &str) -> Result<ModelResponse, ModelError> {
        Ok(ModelResponse::text(format!("Local model response from {}: {}", self.model, prompt)))
    }

    async fn complete_with_tools(&self, prompt: &str, _tools: &[ToolDefinition]) -> Result<ModelResponse, ModelError> {
        Ok(ModelResponse::text(format!("Local model response with tools from {}: {}", self.model, prompt)))
    }

    fn model_name(&self) -> &str {
        &self.model
    }

    fn supports_tools(&self) -> bool {
        // Local backends are assumed not to support structured tool calls.
        false
    }
}
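
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test for the mock backend. `#[tokio::test]` assumes tokio is a
    // dev-dependency; any async test runtime would do.
    #[tokio::test]
    async fn mock_model_echoes_prompt() {
        let model = MockModel::new("test-model".to_string());
        let response = model.complete("hello").await.unwrap();
        assert!(response.content.contains("test-model"));
        assert!(response.content.contains("hello"));
        assert!(response.tool_calls.is_empty());
        assert!(response.usage.is_none());
    }

    // The dispatch sketch near the trait should pick the tool-aware path when
    // tools are supplied and the backend reports support.
    #[tokio::test]
    async fn run_prompt_prefers_tool_path() {
        let model = MockModel::new("test-model".to_string());
        let tools = vec![example_weather_tool()];
        let response = run_prompt(&model, "hello", &tools).await.unwrap();
        assert!(response.content.contains("with tools"));
    }
}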