1pub mod messages;
2
3use super::{
4 messages::AbstractMessage, LLMProvider, LLMToolUsage, MultiModelLLMProvider,
5 StructuredLLMProvider, Tool, ToolChoice, Toolkit,
6};
7use anyhow::Result;
8use log::{debug, info, warn};
9use messages::OpenAIMessage;
10use reqwest::blocking::Client;
11use schemars::{
12 schema::{ObjectValidation, RootSchema, Schema},
13 schema_for, JsonSchema,
14};
15use serde::{Deserialize, Serialize};
16
17pub struct OpenAIClient {
18 api_key: String,
19 client: Client,
20 model: OpenAIModel,
21}
22
23#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
24pub enum OpenAIModel {
25 #[serde(rename = "gpt-4o")]
26 Gpt4o,
27 #[serde(rename = "o1-preview")]
28 O1Preview,
29}
30
31#[derive(Debug, Serialize, Deserialize)]
32pub struct CompletionRequest {
33 model: OpenAIModel,
34 messages: Vec<OpenAIMessage>,
35}
36
37impl CompletionRequest {
38 fn body(model: OpenAIModel, messages: Vec<OpenAIMessage>) -> Self {
39 Self { model, messages }
40 }
41}
42
43#[derive(Debug, Deserialize)]
44pub struct CompletionChoice {
45 finish_reason: String,
46 index: u64,
47 message: OpenAIMessage,
48}
49
50#[derive(Debug, Deserialize)]
51pub struct CompletionResponse {
52 id: String,
53 object: String,
54 created: u64, choices: Vec<CompletionChoice>,
56}
57
58fn set_additional_properties_false(root_schema: &mut RootSchema) {
59 if root_schema.schema.object.is_none() {
61 root_schema.schema.object = Some(Box::new(ObjectValidation::default()));
62 }
63 root_schema
64 .schema
65 .object
66 .as_mut()
67 .unwrap()
68 .additional_properties = Some(Box::new(Schema::Bool(false)));
69
70 if let Some(props) = &mut root_schema.schema.object {
72 for schema in props.properties.values_mut() {
73 if let Schema::Object(obj) = schema {
74 if obj.object.is_none() {
75 obj.object = Some(Box::new(ObjectValidation::default()));
76 }
77 obj.object.as_mut().unwrap().additional_properties =
78 Some(Box::new(Schema::Bool(false)));
79 }
80 }
81 }
82
83 for schema in root_schema.definitions.values_mut() {
85 if let Schema::Object(obj) = schema {
86 if obj.object.is_none() {
87 obj.object = Some(Box::new(ObjectValidation::default()));
88 }
89 obj.object.as_mut().unwrap().additional_properties =
90 Some(Box::new(Schema::Bool(false)));
91 }
92 }
93}
94
95impl LLMProvider<OpenAIMessage> for OpenAIClient {
96 fn get_completion(&self, messages: Vec<OpenAIMessage>) -> Result<Vec<OpenAIMessage>> {
97 debug!(
98 "Getting completion from OpenAI with {} messages",
99 messages.len()
100 );
101
102 let mut headers = reqwest::header::HeaderMap::new();
103 headers.insert(
104 "Authorization",
105 format!("Bearer {}", self.api_key)
106 .parse()
107 .expect("Invalid API key"),
108 );
109 headers.insert(
110 "Content-Type",
111 "application/json".parse().expect("Invalid content type"),
112 );
113
114 let request_body = CompletionRequest::body(OpenAIModel::Gpt4o, messages.clone());
115 debug!("Sending request to OpenAI API");
116
117 let result = self
118 .client
119 .post("https://api.openai.com/v1/chat/completions")
120 .headers(headers)
121 .json(&request_body)
122 .send()?;
123
124 if !result.status().is_success() {
125 let status = result.status();
126 let error_text = result.text()?;
127 warn!("OpenAI API error: {} - {}", status, error_text);
128 return Err(anyhow::anyhow!(
129 "Failed to get completion: {:?} {:?}",
130 status,
131 error_text
132 ));
133 }
134
135 let completion_response: CompletionResponse = result.json()?;
136
137 let last_message = completion_response.choices.first().ok_or(anyhow::anyhow!(
138 "No choices returned in the OpenAI response"
139 ))?;
140 debug!("Last message: {:?}", last_message.message);
141
142 Ok(messages
143 .into_iter()
144 .chain(vec![last_message.message.clone()])
145 .collect())
146 }
147
148 fn stream_completion(
149 &self,
150 messages: Vec<OpenAIMessage>,
151 ) -> Result<Box<dyn Iterator<Item = OpenAIMessage>>> {
152 todo!("Implement streaming for the OpenAI client")
153 }
154}
155
156impl MultiModelLLMProvider<OpenAIModel> for OpenAIClient {
157 fn with_model(&self, model: OpenAIModel) -> Self {
158 Self {
159 api_key: self.api_key.clone(),
160 client: self.client.clone(),
161 model,
162 }
163 }
164
165 fn get_model(&self) -> OpenAIModel {
166 self.model
167 }
168}
169
170impl LLMToolUsage<OpenAIMessage> for OpenAIClient {
171 fn do_work_with_tool(
172 &self,
173 messages: Vec<OpenAIMessage>,
174 tool: &dyn Tool,
175 ) -> Result<Vec<OpenAIMessage>> {
176 debug!("Executing tool '{}' with OpenAI", tool.name());
177
178 let mut headers = reqwest::header::HeaderMap::new();
179 headers.insert(
180 "Authorization",
181 format!("Bearer {}", self.api_key).parse().unwrap(),
182 );
183 headers.insert(
184 "Content-Type",
185 "application/json".parse().expect("Invalid content type"),
186 );
187
188 let request_body = serde_json::json!({
189 "model": self.model,
190 "messages": messages,
191 "tools": [{
192 "type": "function",
193 "function": {
194 "name": tool.name(),
195 "description": tool.description(),
196 "parameters": tool.schema()
197 }
198 }],
199 "tool_choice": {
200 "type": "function",
201 "function": { "name": tool.name() }
202 }
203 });
204
205 debug!("Sending tool execution request to OpenAI API");
206 let result = self
207 .client
208 .post("https://api.openai.com/v1/chat/completions")
209 .headers(headers)
210 .json(&request_body)
211 .send()?;
212
213 if !result.status().is_success() {
214 let status = result.status();
215 let error_text = result.text()?;
216 warn!(
217 "OpenAI API error during tool execution: {} - {}",
218 status, error_text
219 );
220 return Err(anyhow::anyhow!("Failed to use tool: {}", error_text));
221 }
222
223 let response: CompletionResponse = result.json()?;
224 println!("Raw response from tool use ask: {:#?}", response);
226
227 let message = response
228 .choices
229 .first()
230 .ok_or_else(|| anyhow::anyhow!("No choices returned in the OpenAI response"))?;
231 debug!("Last message: {:?}", message.message);
232
233 match &message.message {
234 OpenAIMessage::Assistant {
235 tool_calls: Some(tool_calls),
236 ..
237 } => {
238 let tool_call = tool_calls
239 .first()
240 .ok_or_else(|| anyhow::anyhow!("No tool calls in assistant message"))?;
241
242 let args = serde_json::from_str(&tool_call.function.arguments)?;
243 let result = tool.execute(args)?;
244
245 Ok(vec![OpenAIMessage::Tool {
246 content: serde_json::to_string(&result)?,
247 tool_call_id: tool_call.id.clone(),
248 }])
249 }
250 _ => Err(anyhow::anyhow!(
251 "Expected assistant message with tool calls"
252 )),
253 }
254 }
255
256 fn get_chat_with_tools(
257 &self,
258 messages: Vec<OpenAIMessage>,
259 tool_kit: &Toolkit,
260 force_tool_use: &ToolChoice,
261 ) -> Result<Vec<OpenAIMessage>> {
262 let mut headers = reqwest::header::HeaderMap::new();
263 headers.insert(
264 "Authorization",
265 format!("Bearer {}", self.api_key).parse().unwrap(),
266 );
267 headers.insert(
268 "Content-Type",
269 "application/json".parse().expect("Invalid content type"),
270 );
271
272 debug!("Messages: {:?}", messages);
273
274 let tool_defs: Vec<serde_json::Value> = tool_kit
275 .tools()
276 .iter()
277 .map(|tool| {
278 serde_json::json!({
279 "type": "function",
280 "function": {
281 "name": tool.name(),
282 "description": tool.description(),
283 "parameters": tool.schema()
284 }
285 })
286 })
287 .collect();
288
289 debug!("Tool definitions: {:?}", tool_defs);
290
291 let tool_choice = match force_tool_use {
292 ToolChoice::Specific(name) => serde_json::json!({
293 "type": "function",
294 "function": {
295 "name": name
296 }
297 }),
298 ToolChoice::Any => serde_json::json!("required"),
299 ToolChoice::SelfSelect => serde_json::json!("auto"),
300 };
301
302 let request_body = serde_json::json!({
303 "model": self.model,
304 "messages": messages,
305 "tools": tool_defs,
306 "tool_choice": tool_choice
307 });
308
309 let result = self
310 .client
311 .post("https://api.openai.com/v1/chat/completions")
312 .headers(headers)
313 .json(&request_body)
314 .send()?;
315
316 if !result.status().is_success() {
317 let status = result.status();
318 let error_text = result.text()?;
319 warn!(
320 "OpenAI API error during chat with tools: {} - {}",
321 status, error_text
322 );
323 return Err(anyhow::anyhow!("Failed to chat with tools: {}", error_text));
324 }
325
326 let response: CompletionResponse = result.json()?;
327
328 let message = response
329 .choices
330 .first()
331 .ok_or_else(|| anyhow::anyhow!("No choices returned in the OpenAI response"))?;
332 debug!("Last message: {:?}", message.message);
333
334 Ok(messages
336 .into_iter()
337 .chain(vec![message.message.clone()])
338 .collect())
339 }
340
341 fn get_work_result(
342 &self,
343 messages: Vec<OpenAIMessage>,
344 tool_kit: &Toolkit,
345 tool_choice: &ToolChoice,
346 ) -> Result<Vec<OpenAIMessage>> {
347 info!("Getting work result with tool choice: {:?}", tool_choice);
348
349 match tool_choice {
350 ToolChoice::Specific(name) => {
351 debug!("Using specific tool: {}", name);
352 self.do_work_with_tool(
353 messages,
354 tool_kit
355 .get(name)
356 .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?,
357 )
358 }
359 ToolChoice::Any => {
360 debug!("Getting chat with any tool allowed");
361 let response = self.get_chat_with_tools(messages, tool_kit, tool_choice)?;
362 debug!("Response from chat with tools: {:?}", response);
363
364 if let Some(OpenAIMessage::Assistant {
365 tool_calls: Some(tool_calls),
366 ..
367 }) = response.clone().last()
368 {
369 let mut result_messages = response;
370
371 for tool_call in tool_calls {
373 debug!("Processing tool call: {:?}", tool_call);
374 let tool = tool_kit.get(&tool_call.function.name).ok_or_else(|| {
375 anyhow::anyhow!("Tool not found: {}", tool_call.function.name)
376 })?;
377
378 let args = serde_json::from_str(&tool_call.function.arguments)?;
379 let result = tool.execute(args)?;
380
381 result_messages.push(OpenAIMessage::Tool {
382 content: serde_json::to_string(&result)?,
383 tool_call_id: tool_call.id.clone(),
384 });
385 }
386
387 debug!("Result messages: {:?}", result_messages);
388
389 let messages = self.get_work_result(result_messages, tool_kit, tool_choice)?;
390 Ok(messages)
391 } else {
392 Err(anyhow::anyhow!("No tool calls in assistant message"))
393 }
394 }
395 ToolChoice::SelfSelect => {
396 debug!("Letting model select tool usage");
397 let response = self.get_chat_with_tools(messages, tool_kit, tool_choice)?;
398 debug!("Response from chat with tools: {:?}", response);
399
400 if let Some(OpenAIMessage::Assistant {
401 tool_calls: Some(tool_calls),
402 ..
403 }) = response.clone().last()
404 {
405 let mut result_messages = response;
406
407 for tool_call in tool_calls {
409 debug!("Processing tool call: {:?}", tool_call);
410 let tool = tool_kit.get(&tool_call.function.name).ok_or_else(|| {
411 anyhow::anyhow!("Tool not found: {}", tool_call.function.name)
412 })?;
413
414 let args = serde_json::from_str(&tool_call.function.arguments)?;
415 let result = tool.execute(args)?;
416
417 result_messages.push(OpenAIMessage::Tool {
418 content: serde_json::to_string(&result)?,
419 tool_call_id: tool_call.id.clone(),
420 });
421 }
422
423 debug!("Result messages: {:?}", result_messages);
424
425 let messages = self.get_work_result(result_messages, tool_kit, tool_choice)?;
426 Ok(messages)
427 } else {
428 Ok(response) }
430 }
431 }
432 }
433}
434
435impl StructuredLLMProvider<OpenAIMessage> for OpenAIClient {
436 fn get_structured_response<
437 DesiredSchema: Serialize + serde::de::DeserializeOwned + JsonSchema,
438 >(
439 &self,
440 messages: Vec<OpenAIMessage>,
441 ) -> Result<DesiredSchema> {
442 let mut headers = reqwest::header::HeaderMap::new();
443 headers.insert(
444 "Authorization",
445 format!("Bearer {}", self.api_key)
446 .parse()
447 .expect("Invalid API key"),
448 );
449 headers.insert(
450 "Content-Type",
451 "application/json".parse().expect("Invalid content type"),
452 );
453
454 let mut schema = schema_for!(DesiredSchema);
455 set_additional_properties_false(&mut schema);
456
457 println!("{}", serde_json::to_string(&schema).unwrap());
458
459 let request_body = serde_json::json!({
460 "model": OpenAIModel::Gpt4o,
461 "messages": messages,
462 "response_format": {
463 "type": "json_schema",
464 "json_schema": {
465 "name": "desired_schema",
466 "strict": true,
467 "schema": schema
468 }
469 }
470 });
471
472 let result = self
473 .client
474 .post("https://api.openai.com/v1/chat/completions")
475 .headers(headers)
476 .json(&request_body)
477 .send()?;
478
479 if !result.status().is_success() {
480 return Err(anyhow::anyhow!(
481 "Failed to get structured response: {:?} {:?}",
482 result.status(),
483 result.text()
484 ));
485 }
486
487 let response: CompletionResponse = result.json()?;
488
489 let content = response.choices[0]
490 .message
491 .get_content()
492 .map_err(|_| anyhow::anyhow!("Failed to get message content"))?;
493
494 Ok(serde_json::from_str(&content)?)
495 }
496}
497
498impl Default for OpenAIClient {
499 fn default() -> Self {
500 Self::new()
501 }
502}
503
504impl OpenAIClient {
505 pub fn new() -> Self {
506 Self {
507 api_key: std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"),
508 client: Client::new(),
509 model: OpenAIModel::Gpt4o,
510 }
511 }
512}