1use crate::models::{CompletionRequest, CompletionResponse, Message, Model, Tool};
2use anyhow::Result;
3use serde::Serialize;
4use serde_json::json;
5
6#[derive(Clone)]
7pub struct Client {
8 api_key: String,
9 model: Model,
10 client: reqwest::Client,
11 base_url: Option<String>,
12}
13
14#[derive(Debug, Serialize)]
15pub struct StructuredResponse {
16 pub tool_call: bool,
17 pub content: String,
18}
19
20impl Client {
21 pub fn new(api_key: String, model: Model) -> Self {
22 Self {
23 api_key,
24 model,
25 client: reqwest::Client::new(),
26 base_url: None,
27 }
28 }
29
30 pub fn with_model(mut self, model: Model) -> Self {
31 self.model = model;
32 self
33 }
34
35 #[cfg(test)]
36 pub fn with_base_url(mut self, base_url: String) -> Self {
37 self.base_url = Some(base_url);
38 self
39 }
40
41 fn get_base_url(&self) -> String {
42 if let Some(url) = &self.base_url {
43 url.clone()
44 } else {
45 match self.model {
46 Model::OpenAI(_) => "https://api.openai.com".to_string(),
47 Model::Anthropic(_) => "https://api.anthropic.com".to_string(),
48 }
49 }
50 }
51
52 pub async fn send_prompt_with_tools(
53 &self,
54 prompt: Option<String>,
55 mut history: Vec<Message>,
56 mut tools: Vec<Tool>,
57 ) -> Result<StructuredResponse> {
58 println!("Sending prompt with tools");
59 if let Some(prompt) = prompt {
61 history.push(Message {
62 role: "user".to_string(),
63 content: Some(prompt),
64 tool_calls: None,
65 tool_call_id: None,
66 });
67 }
68
69 for tool in &mut tools {
71 if let Some(properties) = tool.function.parameters.get_mut("properties") {
72 if let Some(obj) = properties.as_object_mut() {
73 for (_, value) in obj.iter_mut() {
74 if let Some(param_obj) = value.as_object_mut() {
75 if param_obj.get("type").and_then(|t| t.as_str()) == Some("array") {
76 if !param_obj.contains_key("items") {
78 param_obj.insert(
79 "items".to_string(),
80 json!({
81 "type": "string" }),
83 );
84 }
85 }
86 }
87 }
88 }
89 }
90 }
91
92 let request = CompletionRequest {
93 model: self.model.as_str().to_string(),
94 messages: history,
95 temperature: Some(0.7),
96 tool_choice: match tools.is_empty() {
97 true => None,
98 false => Some("auto".to_string()),
99 },
100 parallel_tool_calls: match tools.is_empty() {
101 true => None,
102 false => Some(true),
103 },
104 tools: match tools.is_empty() {
105 true => None,
106 false => Some(tools),
107 },
108 ..Default::default()
109 };
110
111 let endpoint = match self.model {
112 Model::OpenAI(_) => "/v1/chat/completions",
113 Model::Anthropic(_) => "/v1/messages",
114 };
115
116 let response = self
117 .client
118 .post(format!("{}{}", self.get_base_url(), endpoint))
119 .header("Authorization", format!("Bearer {}", self.api_key))
120 .header("Content-Type", "application/json")
121 .header(
122 "anthropic-version",
123 if matches!(self.model, Model::Anthropic(_)) {
124 "2023-06-01"
125 } else {
126 ""
127 },
128 )
129 .json(&request)
130 .send()
131 .await?;
132
133 let text = response.text().await?;
134 let completion: CompletionResponse = serde_json::from_str(&text).unwrap();
135 let first_choice = completion
137 .choices
138 .first()
139 .ok_or_else(|| anyhow::anyhow!("No completion choices returned from the API"))?;
140
141 match &first_choice.message.tool_calls {
142 Some(tool_calls) if !tool_calls.is_empty() => Ok(StructuredResponse {
143 tool_call: true,
144 content: serde_json::to_string(&tool_calls)?,
145 }),
146 _ => Ok(StructuredResponse {
147 tool_call: false,
148 content: first_choice
149 .message
150 .content
151 .as_ref()
152 .unwrap_or(&"".to_string())
153 .clone(),
154 }),
155 }
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{
        AnthropicModel, Choice, FunctionDefinition, OpenAIModel, ToolCall, ToolDefinition,
    };
    use serde_json::json;

    #[tokio::test]
    async fn test_new_client() {
        let client = Client::new(
            "test-key".to_string(),
            Model::OpenAI(OpenAIModel::GPT35Turbo),
        );
        assert_eq!(client.api_key, "test-key");
        assert!(matches!(
            client.model,
            Model::OpenAI(OpenAIModel::GPT35Turbo)
        ));
    }

    #[tokio::test]
    async fn test_with_model() {
        let client = Client::new(
            "test-key".to_string(),
            Model::OpenAI(OpenAIModel::GPT35Turbo),
        )
        .with_model(Model::Anthropic(AnthropicModel::Claude3Sonnet));

        assert!(matches!(
            client.model,
            Model::Anthropic(AnthropicModel::Claude3Sonnet)
        ));
    }

    /// A plain-text reply (no tool calls) must come back with
    /// `tool_call == false` and the assistant text as `content`; the OpenAI
    /// request must carry the API key as a Bearer token.
    #[tokio::test]
    async fn test_send_prompt_with_tools() {
        let mut server = mockito::Server::new_async().await;
        let url = server.url();

        let mock = server
            .mock("POST", "/v1/chat/completions")
            .match_header("authorization", "Bearer test-key")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "id": "chatcmpl-123",
                    "object": "chat.completion",
                    "created": 1677652288,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Hello! How can I help you today?",
                            "tool_calls": null
                        },
                        "finish_reason": "stop"
                    }]
                })
                .to_string(),
            )
            .create();

        let client = Client::new(
            "test-key".to_string(),
            Model::OpenAI(OpenAIModel::GPT35Turbo),
        )
        .with_base_url(url);

        let history = vec![Message {
            role: "system".to_string(),
            content: Some("You are a helpful assistant.".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }];

        let tools = vec![];
        let result = client
            .send_prompt_with_tools(Some("Hello!".to_string()), history, tools)
            .await
            .unwrap();

        assert!(!result.tool_call);
        assert_eq!(result.content, "Hello! How can I help you today?");
        mock.assert();
    }

    /// A reply containing tool calls must come back with `tool_call == true`
    /// and the serialized tool calls as `content`.
    #[tokio::test]
    async fn test_send_prompt_with_tool_call_response() {
        let mut server = mockito::Server::new_async().await;
        let url = server.url();

        let mock = server
            .mock("POST", "/v1/chat/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::to_value(CompletionResponse {
                    id: "chatcmpl-123".to_string(),
                    choices: vec![Choice {
                        index: 0,
                        message: Message {
                            role: "assistant".to_string(),
                            content: Some("Hello! How can I help you today?".to_string()),
                            tool_calls: Some(vec![ToolCall {
                                id: "call_123".to_string(),
                                tool_type: "function".to_string(),
                                function: ToolDefinition {
                                    name: "calculator".to_string(),
                                    arguments: "{\"a\":5,\"b\":3,\"operation\":\"add\"}"
                                        .to_string(),
                                },
                            }]),
                            tool_call_id: None,
                        },
                        finish_reason: "stop".to_string(),
                    }],
                })
                .unwrap()
                .to_string(),
            )
            .create();

        let client = Client::new(
            "test-key".to_string(),
            Model::OpenAI(OpenAIModel::GPT35Turbo),
        )
        .with_base_url(url);

        let history = vec![Message {
            role: "system".to_string(),
            content: Some("You are a helpful assistant.".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }];

        let tools = vec![Tool {
            tool_type: "function".to_string(),
            function: FunctionDefinition {
                name: "calculator".to_string(),
                description: "Calculate two numbers".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "a": {"type": "number"},
                        "b": {"type": "number"},
                        "operation": {"type": "string"}
                    },
                    "required": ["a", "b", "operation"]
                }),
            },
        }];

        let result = client
            .send_prompt_with_tools(Some("Calculate 5 plus 3".to_string()), history, tools)
            .await
            .unwrap();

        assert!(result.tool_call);
        assert!(result.content.contains("calculator"));
        assert!(result.content.contains("add"));
        mock.assert();
    }

    #[tokio::test]
    async fn test_model_string_conversion() {
        assert_eq!(Model::OpenAI(OpenAIModel::GPT4).as_str(), "gpt-4");
        assert_eq!(
            Model::OpenAI(OpenAIModel::GPT35Turbo).as_str(),
            "gpt-3.5-turbo"
        );
        assert_eq!(
            Model::Anthropic(AnthropicModel::Claude3Sonnet).as_str(),
            "claude-3-sonnet"
        );
    }

    #[tokio::test]
    async fn test_base_url_selection() {
        let openai_client = Client::new(
            "test-key".to_string(),
            Model::OpenAI(OpenAIModel::GPT35Turbo),
        );
        assert_eq!(openai_client.get_base_url(), "https://api.openai.com");

        let anthropic_client = Client::new(
            "test-key".to_string(),
            Model::Anthropic(AnthropicModel::Claude3Sonnet),
        );
        assert_eq!(anthropic_client.get_base_url(), "https://api.anthropic.com");
    }
}