ferrox_openai_api/completions.rs

use crate::models::{CompletionRequest, CompletionResponse, Message, Model, Tool};
use anyhow::Result;
use serde::Serialize;
use serde_json::json;

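/// Chat-completions client that can target either the OpenAI or the Anthropic
/// API; the provider, base URL, and endpoint are derived from the configured
/// [`Model`].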
#[derive(Clone)]
pub struct Client {
    api_key: String,
    model: Model,
    client: reqwest::Client,
    base_url: Option<String>,
}

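/// Normalized result of a completion call: either plain assistant text
/// (`tool_call == false`) or the JSON-serialized list of tool calls the model
/// requested (`tool_call == true`).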
#[derive(Debug, Serialize)]
pub struct StructuredResponse {
    pub tool_call: bool,
    pub content: String,
}

impl Client {
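    /// Creates a client for the given API key and model, backed by a fresh
    /// `reqwest::Client`.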
    pub fn new(api_key: String, model: Model) -> Self {
        Self {
            api_key,
            model,
            client: reqwest::Client::new(),
            base_url: None,
        }
    }

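    /// Switches the target model (and therefore the provider and base URL),
    /// builder-style.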
    pub fn with_model(mut self, model: Model) -> Self {
        self.model = model;
        self
    }

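    /// Overrides the base URL so tests can point the client at a mock server.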
    #[cfg(test)]
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = Some(base_url);
        self
    }

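    /// Resolves the base URL: an explicit override wins, otherwise the
    /// provider is inferred from the model.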
    fn get_base_url(&self) -> String {
        if let Some(url) = &self.base_url {
            url.clone()
        } else {
            match self.model {
                Model::OpenAI(_) => "https://api.openai.com".to_string(),
                Model::Anthropic(_) => "https://api.anthropic.com".to_string(),
            }
        }
    }

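    /// Appends `prompt` to `history` as a user message, sends the conversation
    /// (plus any `tools`) to the provider's chat completions endpoint, and
    /// returns either the assistant's text or its serialized tool calls.
    ///
    /// A minimal usage sketch; it assumes the crate exposes this module as
    /// `ferrox_openai_api::completions` and that an `OPENAI_API_KEY`
    /// environment variable is set (adjust the paths to your crate layout):
    ///
    /// ```ignore
    /// use ferrox_openai_api::completions::Client;
    /// use ferrox_openai_api::models::{Model, OpenAIModel};
    ///
    /// # async fn demo() -> anyhow::Result<()> {
    /// let client = Client::new(
    ///     std::env::var("OPENAI_API_KEY")?,
    ///     Model::OpenAI(OpenAIModel::GPT35Turbo),
    /// );
    /// let reply = client
    ///     .send_prompt_with_tools(Some("Hello!".to_string()), vec![], vec![])
    ///     .await?;
    /// println!("tool_call = {}, content = {}", reply.tool_call, reply.content);
    /// # Ok(())
    /// # }
    /// ```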
    pub async fn send_prompt_with_tools(
        &self,
        prompt: Option<String>,
        mut history: Vec<Message>,
        mut tools: Vec<Tool>,
    ) -> Result<StructuredResponse> {
        println!("Sending prompt with tools");
        // Add the user's prompt to the message history
        if let Some(prompt) = prompt {
            history.push(Message {
                role: "user".to_string(),
                content: Some(prompt),
                tool_calls: None,
                tool_call_id: None,
            });
        }

        // Ensure array-typed tool parameters declare an `items` schema,
        // defaulting to an array of strings when none is provided.
        for tool in &mut tools {
            if let Some(properties) = tool.function.parameters.get_mut("properties") {
                if let Some(obj) = properties.as_object_mut() {
                    for value in obj.values_mut() {
                        if let Some(param_obj) = value.as_object_mut() {
                            if param_obj.get("type").and_then(|t| t.as_str()) == Some("array")
                                && !param_obj.contains_key("items")
                            {
                                param_obj.insert("items".to_string(), json!({ "type": "string" }));
                            }
                        }
                    }
                }
            }
        }

        let has_tools = !tools.is_empty();
        let request = CompletionRequest {
            model: self.model.as_str().to_string(),
            messages: history,
            temperature: Some(0.7),
            tool_choice: has_tools.then(|| "auto".to_string()),
            parallel_tool_calls: has_tools.then_some(true),
            tools: has_tools.then_some(tools),
            ..Default::default()
        };

        let endpoint = match self.model {
            Model::OpenAI(_) => "/v1/chat/completions",
            Model::Anthropic(_) => "/v1/messages",
        };

        let mut request_builder = self
            .client
            .post(format!("{}{}", self.get_base_url(), endpoint))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json");

        // Anthropic requires a version header; avoid sending an empty one to OpenAI.
        if matches!(self.model, Model::Anthropic(_)) {
            request_builder = request_builder.header("anthropic-version", "2023-06-01");
        }

        let response = request_builder.json(&request).send().await?;

        let text = response.text().await?;
        let completion: CompletionResponse = serde_json::from_str(&text)
            .map_err(|e| anyhow::anyhow!("Failed to parse completion response: {e}; body: {text}"))?;

        // Handle both regular responses and tool calls
        let first_choice = completion
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No completion choices returned from the API"))?;

        match &first_choice.message.tool_calls {
            Some(tool_calls) if !tool_calls.is_empty() => Ok(StructuredResponse {
                tool_call: true,
                content: serde_json::to_string(&tool_calls)?,
            }),
            _ => Ok(StructuredResponse {
                tool_call: false,
                content: first_choice.message.content.clone().unwrap_or_default(),
            }),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{
        AnthropicModel, Choice, FunctionDefinition, OpenAIModel, ToolCall, ToolDefinition,
    };
    use mockito;
    use serde_json::json;

    #[tokio::test]
    async fn test_new_client() {
        let client = Client::new(
            "test-key".to_string(),
            Model::OpenAI(OpenAIModel::GPT35Turbo),
        );
        assert_eq!(client.api_key, "test-key");
        assert!(matches!(
            client.model,
            Model::OpenAI(OpenAIModel::GPT35Turbo)
        ));
    }

    #[tokio::test]
    async fn test_with_model() {
        let client = Client::new(
            "test-key".to_string(),
            Model::OpenAI(OpenAIModel::GPT35Turbo),
        )
        .with_model(Model::Anthropic(AnthropicModel::Claude3Sonnet));

        assert!(matches!(
            client.model,
            Model::Anthropic(AnthropicModel::Claude3Sonnet)
        ));
    }

    #[tokio::test]
    async fn test_send_prompt_with_tools() {
        let mut server = mockito::Server::new_async().await;
        let url = server.url();

        let mock = server
            .mock("POST", "/v1/chat/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "id": "chatcmpl-123",
                    "object": "chat.completion",
                    "created": 1677652288,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Hello! How can I help you today?",
                            "tool_calls": null
                        },
                        "finish_reason": "stop"
                    }]
                })
                .to_string(),
            )
            .create();

        let client = Client::new(
            "test-key".to_string(),
            Model::OpenAI(OpenAIModel::GPT35Turbo),
        )
        .with_base_url(url);

        let history = vec![Message {
            role: "system".to_string(),
            content: Some("You are a helpful assistant.".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }];

        let tools = vec![]; // Empty tools for this test

        let result = client
            .send_prompt_with_tools(Some("Hello!".to_string()), history, tools)
            .await
            .unwrap();

        assert_eq!(result.content, "Hello! How can I help you today?");
        mock.assert();
    }

    #[tokio::test]
    async fn test_send_prompt_with_tool_call_response() {
        let mut server = mockito::Server::new_async().await;
        let url = server.url();

        let mock = server
            .mock("POST", "/v1/chat/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::to_value(CompletionResponse {
                    id: "chatcmpl-123".to_string(),
                    choices: vec![Choice {
                        index: 0,
                        message: Message {
                            role: "assistant".to_string(),
                            content: Some("Hello! How can I help you today?".to_string()),
                            tool_calls: Some(vec![ToolCall {
                                id: "call_123".to_string(),
                                tool_type: "function".to_string(),
                                function: ToolDefinition {
                                    name: "calculator".to_string(),
                                    arguments: "{\"a\":5,\"b\":3,\"operation\":\"add\"}"
                                        .to_string(),
                                },
                            }]),
                            tool_call_id: None,
                        },
                        finish_reason: "stop".to_string(),
                    }],
                })
                .unwrap()
                .to_string(),
            )
            .create();

        let client = Client::new(
            "test-key".to_string(),
            Model::OpenAI(OpenAIModel::GPT35Turbo),
        )
        .with_base_url(url);

        let history = vec![Message {
            role: "system".to_string(),
            content: Some("You are a helpful assistant.".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }];

        let tools = vec![Tool {
            tool_type: "function".to_string(),
            function: FunctionDefinition {
                name: "calculator".to_string(),
                description: "Calculate two numbers".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "a": {"type": "number"},
                        "b": {"type": "number"},
                        "operation": {"type": "string"}
                    },
                    "required": ["a", "b", "operation"]
                }),
            },
        }];

        let result = client
            .send_prompt_with_tools(Some("Calculate 5 plus 3".to_string()), history, tools)
            .await
            .unwrap();

        // The result should be the JSON string of the tool calls
        assert!(result.content.contains("calculator"));
        assert!(result.content.contains("add"));
        mock.assert();
    }

    #[tokio::test]
    async fn test_model_string_conversion() {
        assert_eq!(Model::OpenAI(OpenAIModel::GPT4).as_str(), "gpt-4");
        assert_eq!(
            Model::OpenAI(OpenAIModel::GPT35Turbo).as_str(),
            "gpt-3.5-turbo"
        );
        assert_eq!(
            Model::Anthropic(AnthropicModel::Claude3Sonnet).as_str(),
            "claude-3-sonnet"
        );
    }

    #[tokio::test]
    async fn test_base_url_selection() {
        let openai_client = Client::new(
            "test-key".to_string(),
            Model::OpenAI(OpenAIModel::GPT35Turbo),
        );
        assert_eq!(openai_client.get_base_url(), "https://api.openai.com");

        let anthropic_client = Client::new(
            "test-key".to_string(),
            Model::Anthropic(AnthropicModel::Claude3Sonnet),
        );
        assert_eq!(anthropic_client.get_base_url(), "https://api.anthropic.com");
    }
}