// meridian/llms/open_ai.rs

//! OpenAI-backed implementation of the `meridian` LLM provider traits
//! (plain chat completions, tool use, and structured outputs).

pub mod messages;

use super::{
    messages::AbstractMessage, LLMProvider, LLMToolUsage, MultiModelLLMProvider,
    StructuredLLMProvider, Tool, ToolChoice, Toolkit,
};
use anyhow::Result;
use log::{debug, info, warn};
use messages::OpenAIMessage;
// Blocking client: every request in this module is synchronous.
use reqwest::blocking::Client;
use schemars::{
    schema::{ObjectValidation, RootSchema, Schema},
    schema_for, JsonSchema,
};
use serde::{Deserialize, Serialize};

/// Synchronous client for the OpenAI chat-completions API.
pub struct OpenAIClient {
    // Bearer token sent in the Authorization header of every request.
    api_key: String,
    // Reusable blocking HTTP client.
    client: Client,
    // Model to use for requests (swappable via `MultiModelLLMProvider::with_model`).
    model: OpenAIModel,
}
22
/// Models selectable for this client. The serde renames map each variant to
/// the exact model identifier string the OpenAI API expects.
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
pub enum OpenAIModel {
    #[serde(rename = "gpt-4o")]
    Gpt4o,
    #[serde(rename = "o1-preview")]
    O1Preview,
}
30
/// JSON body for a plain (tool-free) chat-completions request.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    model: OpenAIModel,
    messages: Vec<OpenAIMessage>,
}
36
37impl CompletionRequest {
38    fn body(model: OpenAIModel, messages: Vec<OpenAIMessage>) -> Self {
39        Self { model, messages }
40    }
41}
42
/// One choice from a chat-completions response.
#[derive(Debug, Deserialize)]
pub struct CompletionChoice {
    // Why generation stopped; deserialized to mirror the API shape but not
    // read by this module yet.
    finish_reason: String,
    // Position of this choice in the response; likewise unused so far.
    index: u64,
    message: OpenAIMessage,
}
49
/// Top-level chat-completions response payload (only the fields used here).
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    id: String,
    object: String,
    created: u64, // unix timestamp
    choices: Vec<CompletionChoice>,
}
57
58fn set_additional_properties_false(root_schema: &mut RootSchema) {
59    // Set root schema
60    if root_schema.schema.object.is_none() {
61        root_schema.schema.object = Some(Box::new(ObjectValidation::default()));
62    }
63    root_schema
64        .schema
65        .object
66        .as_mut()
67        .unwrap()
68        .additional_properties = Some(Box::new(Schema::Bool(false)));
69
70    // Set for properties
71    if let Some(props) = &mut root_schema.schema.object {
72        for schema in props.properties.values_mut() {
73            if let Schema::Object(obj) = schema {
74                if obj.object.is_none() {
75                    obj.object = Some(Box::new(ObjectValidation::default()));
76                }
77                obj.object.as_mut().unwrap().additional_properties =
78                    Some(Box::new(Schema::Bool(false)));
79            }
80        }
81    }
82
83    // Set for definitions
84    for schema in root_schema.definitions.values_mut() {
85        if let Schema::Object(obj) = schema {
86            if obj.object.is_none() {
87                obj.object = Some(Box::new(ObjectValidation::default()));
88            }
89            obj.object.as_mut().unwrap().additional_properties =
90                Some(Box::new(Schema::Bool(false)));
91        }
92    }
93}
94
95impl LLMProvider<OpenAIMessage> for OpenAIClient {
96    fn get_completion(&self, messages: Vec<OpenAIMessage>) -> Result<Vec<OpenAIMessage>> {
97        debug!(
98            "Getting completion from OpenAI with {} messages",
99            messages.len()
100        );
101
102        let mut headers = reqwest::header::HeaderMap::new();
103        headers.insert(
104            "Authorization",
105            format!("Bearer {}", self.api_key)
106                .parse()
107                .expect("Invalid API key"),
108        );
109        headers.insert(
110            "Content-Type",
111            "application/json".parse().expect("Invalid content type"),
112        );
113
114        let request_body = CompletionRequest::body(OpenAIModel::Gpt4o, messages.clone());
115        debug!("Sending request to OpenAI API");
116
117        let result = self
118            .client
119            .post("https://api.openai.com/v1/chat/completions")
120            .headers(headers)
121            .json(&request_body)
122            .send()?;
123
124        if !result.status().is_success() {
125            let status = result.status();
126            let error_text = result.text()?;
127            warn!("OpenAI API error: {} - {}", status, error_text);
128            return Err(anyhow::anyhow!(
129                "Failed to get completion: {:?} {:?}",
130                status,
131                error_text
132            ));
133        }
134
135        let completion_response: CompletionResponse = result.json()?;
136
137        let last_message = completion_response.choices.first().ok_or(anyhow::anyhow!(
138            "No choices returned in the OpenAI response"
139        ))?;
140        debug!("Last message: {:?}", last_message.message);
141
142        Ok(messages
143            .into_iter()
144            .chain(vec![last_message.message.clone()])
145            .collect())
146    }
147
148    fn stream_completion(
149        &self,
150        messages: Vec<OpenAIMessage>,
151    ) -> Result<Box<dyn Iterator<Item = OpenAIMessage>>> {
152        todo!("Implement streaming for the OpenAI client")
153    }
154}
155
156impl MultiModelLLMProvider<OpenAIModel> for OpenAIClient {
157    fn with_model(&self, model: OpenAIModel) -> Self {
158        Self {
159            api_key: self.api_key.clone(),
160            client: self.client.clone(),
161            model,
162        }
163    }
164
165    fn get_model(&self) -> OpenAIModel {
166        self.model
167    }
168}
169
170impl LLMToolUsage<OpenAIMessage> for OpenAIClient {
171    fn do_work_with_tool(
172        &self,
173        messages: Vec<OpenAIMessage>,
174        tool: &dyn Tool,
175    ) -> Result<Vec<OpenAIMessage>> {
176        debug!("Executing tool '{}' with OpenAI", tool.name());
177
178        let mut headers = reqwest::header::HeaderMap::new();
179        headers.insert(
180            "Authorization",
181            format!("Bearer {}", self.api_key).parse().unwrap(),
182        );
183        headers.insert(
184            "Content-Type",
185            "application/json".parse().expect("Invalid content type"),
186        );
187
188        let request_body = serde_json::json!({
189            "model": self.model,
190            "messages": messages,
191            "tools": [{
192                "type": "function",
193                "function": {
194                    "name": tool.name(),
195                    "description": tool.description(),
196                    "parameters": tool.schema()
197                }
198            }],
199            "tool_choice": {
200                "type": "function",
201                "function": { "name": tool.name() }
202            }
203        });
204
205        debug!("Sending tool execution request to OpenAI API");
206        let result = self
207            .client
208            .post("https://api.openai.com/v1/chat/completions")
209            .headers(headers)
210            .json(&request_body)
211            .send()?;
212
213        if !result.status().is_success() {
214            let status = result.status();
215            let error_text = result.text()?;
216            warn!(
217                "OpenAI API error during tool execution: {} - {}",
218                status, error_text
219            );
220            return Err(anyhow::anyhow!("Failed to use tool: {}", error_text));
221        }
222
223        let response: CompletionResponse = result.json()?;
224        // Debugging: Print the raw response
225        println!("Raw response from tool use ask: {:#?}", response);
226
227        let message = response
228            .choices
229            .first()
230            .ok_or_else(|| anyhow::anyhow!("No choices returned in the OpenAI response"))?;
231        debug!("Last message: {:?}", message.message);
232
233        match &message.message {
234            OpenAIMessage::Assistant {
235                tool_calls: Some(tool_calls),
236                ..
237            } => {
238                let tool_call = tool_calls
239                    .first()
240                    .ok_or_else(|| anyhow::anyhow!("No tool calls in assistant message"))?;
241
242                let args = serde_json::from_str(&tool_call.function.arguments)?;
243                let result = tool.execute(args)?;
244
245                Ok(vec![OpenAIMessage::Tool {
246                    content: serde_json::to_string(&result)?,
247                    tool_call_id: tool_call.id.clone(),
248                }])
249            }
250            _ => Err(anyhow::anyhow!(
251                "Expected assistant message with tool calls"
252            )),
253        }
254    }
255
256    fn get_chat_with_tools(
257        &self,
258        messages: Vec<OpenAIMessage>,
259        tool_kit: &Toolkit,
260        force_tool_use: &ToolChoice,
261    ) -> Result<Vec<OpenAIMessage>> {
262        let mut headers = reqwest::header::HeaderMap::new();
263        headers.insert(
264            "Authorization",
265            format!("Bearer {}", self.api_key).parse().unwrap(),
266        );
267        headers.insert(
268            "Content-Type",
269            "application/json".parse().expect("Invalid content type"),
270        );
271
272        debug!("Messages: {:?}", messages);
273
274        let tool_defs: Vec<serde_json::Value> = tool_kit
275            .tools()
276            .iter()
277            .map(|tool| {
278                serde_json::json!({
279                    "type": "function",
280                    "function": {
281                        "name": tool.name(),
282                        "description": tool.description(),
283                        "parameters": tool.schema()
284                    }
285                })
286            })
287            .collect();
288
289        debug!("Tool definitions: {:?}", tool_defs);
290
291        let tool_choice = match force_tool_use {
292            ToolChoice::Specific(name) => serde_json::json!({
293                "type": "function",
294                "function": {
295                    "name": name
296                }
297            }),
298            ToolChoice::Any => serde_json::json!("required"),
299            ToolChoice::SelfSelect => serde_json::json!("auto"),
300        };
301
302        let request_body = serde_json::json!({
303            "model": self.model,
304            "messages": messages,
305            "tools": tool_defs,
306            "tool_choice": tool_choice
307        });
308
309        let result = self
310            .client
311            .post("https://api.openai.com/v1/chat/completions")
312            .headers(headers)
313            .json(&request_body)
314            .send()?;
315
316        if !result.status().is_success() {
317            let status = result.status();
318            let error_text = result.text()?;
319            warn!(
320                "OpenAI API error during chat with tools: {} - {}",
321                status, error_text
322            );
323            return Err(anyhow::anyhow!("Failed to chat with tools: {}", error_text));
324        }
325
326        let response: CompletionResponse = result.json()?;
327
328        let message = response
329            .choices
330            .first()
331            .ok_or_else(|| anyhow::anyhow!("No choices returned in the OpenAI response"))?;
332        debug!("Last message: {:?}", message.message);
333
334        // Return all the messages
335        Ok(messages
336            .into_iter()
337            .chain(vec![message.message.clone()])
338            .collect())
339    }
340
341    fn get_work_result(
342        &self,
343        messages: Vec<OpenAIMessage>,
344        tool_kit: &Toolkit,
345        tool_choice: &ToolChoice,
346    ) -> Result<Vec<OpenAIMessage>> {
347        info!("Getting work result with tool choice: {:?}", tool_choice);
348
349        match tool_choice {
350            ToolChoice::Specific(name) => {
351                debug!("Using specific tool: {}", name);
352                self.do_work_with_tool(
353                    messages,
354                    tool_kit
355                        .get(name)
356                        .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?,
357                )
358            }
359            ToolChoice::Any => {
360                debug!("Getting chat with any tool allowed");
361                let response = self.get_chat_with_tools(messages, tool_kit, tool_choice)?;
362                debug!("Response from chat with tools: {:?}", response);
363
364                if let Some(OpenAIMessage::Assistant {
365                    tool_calls: Some(tool_calls),
366                    ..
367                }) = response.clone().last()
368                {
369                    let mut result_messages = response;
370
371                    // Process all tool calls
372                    for tool_call in tool_calls {
373                        debug!("Processing tool call: {:?}", tool_call);
374                        let tool = tool_kit.get(&tool_call.function.name).ok_or_else(|| {
375                            anyhow::anyhow!("Tool not found: {}", tool_call.function.name)
376                        })?;
377
378                        let args = serde_json::from_str(&tool_call.function.arguments)?;
379                        let result = tool.execute(args)?;
380
381                        result_messages.push(OpenAIMessage::Tool {
382                            content: serde_json::to_string(&result)?,
383                            tool_call_id: tool_call.id.clone(),
384                        });
385                    }
386
387                    debug!("Result messages: {:?}", result_messages);
388
389                    let messages = self.get_work_result(result_messages, tool_kit, tool_choice)?;
390                    Ok(messages)
391                } else {
392                    Err(anyhow::anyhow!("No tool calls in assistant message"))
393                }
394            }
395            ToolChoice::SelfSelect => {
396                debug!("Letting model select tool usage");
397                let response = self.get_chat_with_tools(messages, tool_kit, tool_choice)?;
398                debug!("Response from chat with tools: {:?}", response);
399
400                if let Some(OpenAIMessage::Assistant {
401                    tool_calls: Some(tool_calls),
402                    ..
403                }) = response.clone().last()
404                {
405                    let mut result_messages = response;
406
407                    // Process all tool calls
408                    for tool_call in tool_calls {
409                        debug!("Processing tool call: {:?}", tool_call);
410                        let tool = tool_kit.get(&tool_call.function.name).ok_or_else(|| {
411                            anyhow::anyhow!("Tool not found: {}", tool_call.function.name)
412                        })?;
413
414                        let args = serde_json::from_str(&tool_call.function.arguments)?;
415                        let result = tool.execute(args)?;
416
417                        result_messages.push(OpenAIMessage::Tool {
418                            content: serde_json::to_string(&result)?,
419                            tool_call_id: tool_call.id.clone(),
420                        });
421                    }
422
423                    debug!("Result messages: {:?}", result_messages);
424
425                    let messages = self.get_work_result(result_messages, tool_kit, tool_choice)?;
426                    Ok(messages)
427                } else {
428                    Ok(response) // For SelfSelect, we return the response even if no tools were used
429                }
430            }
431        }
432    }
433}
434
435impl StructuredLLMProvider<OpenAIMessage> for OpenAIClient {
436    fn get_structured_response<
437        DesiredSchema: Serialize + serde::de::DeserializeOwned + JsonSchema,
438    >(
439        &self,
440        messages: Vec<OpenAIMessage>,
441    ) -> Result<DesiredSchema> {
442        let mut headers = reqwest::header::HeaderMap::new();
443        headers.insert(
444            "Authorization",
445            format!("Bearer {}", self.api_key)
446                .parse()
447                .expect("Invalid API key"),
448        );
449        headers.insert(
450            "Content-Type",
451            "application/json".parse().expect("Invalid content type"),
452        );
453
454        let mut schema = schema_for!(DesiredSchema);
455        set_additional_properties_false(&mut schema);
456
457        println!("{}", serde_json::to_string(&schema).unwrap());
458
459        let request_body = serde_json::json!({
460            "model": OpenAIModel::Gpt4o,
461            "messages": messages,
462            "response_format": {
463                "type": "json_schema",
464                "json_schema": {
465                    "name": "desired_schema",
466                    "strict": true,
467                    "schema": schema
468                }
469            }
470        });
471
472        let result = self
473            .client
474            .post("https://api.openai.com/v1/chat/completions")
475            .headers(headers)
476            .json(&request_body)
477            .send()?;
478
479        if !result.status().is_success() {
480            return Err(anyhow::anyhow!(
481                "Failed to get structured response: {:?} {:?}",
482                result.status(),
483                result.text()
484            ));
485        }
486
487        let response: CompletionResponse = result.json()?;
488
489        let content = response.choices[0]
490            .message
491            .get_content()
492            .map_err(|_| anyhow::anyhow!("Failed to get message content"))?;
493
494        Ok(serde_json::from_str(&content)?)
495    }
496}
497
498impl Default for OpenAIClient {
499    fn default() -> Self {
500        Self::new()
501    }
502}
503
504impl OpenAIClient {
505    pub fn new() -> Self {
506        Self {
507            api_key: std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"),
508            client: Client::new(),
509            model: OpenAIModel::Gpt4o,
510        }
511    }
512}