1use crate::error::{Error, Result};
2use crate::message::{Content, ContentPart, Message};
3use crate::provider::HTTPProvider;
4use crate::{Chat, LlmToolInfo, OpenAi};
5use reqwest::{Method, Request, Url};
6use serde::{Deserialize, Serialize};
7use std::env;
8use tracing::{debug, error, info, instrument, trace, warn};
9
/// Connection settings for `OpenAIProvider`.
///
/// `Debug` is implemented by hand (instead of derived) so that the API key is never
/// printed: `OpenAIProvider` derives `Debug` over this struct, and a derived impl would
/// write the secret into any log line or tracing field that formats the provider.
#[derive(Clone)]
pub struct OpenAIConfig {
    /// Secret key, sent as `Authorization: Bearer <api_key>`.
    pub api_key: String,
    /// API root; `/chat/completions` is appended to it when building requests.
    pub base_url: String,
    /// Optional organization id, sent in the `OpenAI-Organization` header when set.
    pub organization: Option<String>,
}

impl std::fmt::Debug for OpenAIConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a key is configured, never its value.
        let api_key = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("OpenAIConfig")
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .field("organization", &self.organization)
            .finish()
    }
}
20
21impl Default for OpenAIConfig {
22 fn default() -> Self {
23 Self {
24 api_key: env::var("OPENAI_API_KEY").unwrap_or_default(),
25 base_url: "https://api.openai.com/v1".to_string(),
26 organization: env::var("OPENAI_ORGANIZATION").ok(),
27 }
28 }
29}
30
31#[derive(Debug, Clone)]
33pub struct OpenAIProvider {
34 config: OpenAIConfig,
36}
37
38impl OpenAIProvider {
39 #[instrument(level = "debug")]
51 pub fn new() -> Self {
52 info!("Creating new OpenAIProvider with default configuration");
53 let config = OpenAIConfig::default();
54 debug!("API key set: {}", !config.api_key.is_empty());
55 debug!("Base URL: {}", config.base_url);
56 debug!("Organization set: {}", config.organization.is_some());
57
58 Self { config }
59 }
60
61 #[instrument(skip(config), level = "debug")]
77 pub fn with_config(config: OpenAIConfig) -> Self {
78 info!("Creating new OpenAIProvider with custom configuration");
79 debug!("API key set: {}", !config.api_key.is_empty());
80 debug!("Base URL: {}", config.base_url);
81 debug!("Organization set: {}", config.organization.is_some());
82
83 Self { config }
84 }
85}
86
87impl Default for OpenAIProvider {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl HTTPProvider<OpenAi> for OpenAIProvider {
94 fn accept(&self, model: OpenAi, chat: &Chat) -> Result<Request> {
95 info!("Creating request for OpenAI model: {:?}", model);
96 debug!("Messages in chat history: {}", chat.history.len());
97
98 let url_str = format!("{}/chat/completions", self.config.base_url);
99 debug!("Parsing URL: {}", url_str);
100 let url = match Url::parse(&url_str) {
101 Ok(url) => {
102 debug!("URL parsed successfully: {}", url);
103 url
104 }
105 Err(e) => {
106 error!("Failed to parse URL '{}': {}", url_str, e);
107 return Err(e.into());
108 }
109 };
110
111 let mut request = Request::new(Method::POST, url);
112 debug!("Created request: {} {}", request.method(), request.url());
113
114 debug!("Setting request headers");
116
117 let auth_header = match format!("Bearer {}", self.config.api_key).parse() {
119 Ok(header) => header,
120 Err(e) => {
121 error!("Invalid API key format: {}", e);
122 return Err(Error::Authentication("Invalid API key format".into()));
123 }
124 };
125
126 let content_type_header = match "application/json".parse() {
127 Ok(header) => header,
128 Err(e) => {
129 error!("Failed to set content type: {}", e);
130 return Err(Error::Other("Failed to set content type".into()));
131 }
132 };
133
134 request.headers_mut().insert("Authorization", auth_header);
135 request
136 .headers_mut()
137 .insert("Content-Type", content_type_header);
138
139 if let Some(org) = &self.config.organization {
141 match org.parse() {
142 Ok(header) => {
143 request.headers_mut().insert("OpenAI-Organization", header);
144 debug!("Added organization header");
145 }
146 Err(e) => {
147 warn!("Failed to set organization header: {}", e);
148 }
150 }
151 }
152
153 trace!("Request headers set: {:#?}", request.headers());
154
155 debug!("Creating request payload");
157 let payload = match self.create_request_payload(model, chat) {
158 Ok(payload) => {
159 debug!("Request payload created successfully");
160 trace!("Model: {}", payload.model);
161 trace!("Max tokens: {:?}", payload.max_tokens);
162 trace!("Number of messages: {}", payload.messages.len());
163 payload
164 }
165 Err(e) => {
166 error!("Failed to create request payload: {}", e);
167 return Err(e);
168 }
169 };
170
171 debug!("Serializing request payload");
173 let body_bytes = match serde_json::to_vec(&payload) {
174 Ok(bytes) => {
175 debug!("Payload serialized successfully ({} bytes)", bytes.len());
176 bytes
177 }
178 Err(e) => {
179 error!("Failed to serialize payload: {}", e);
180 return Err(Error::Serialization(e));
181 }
182 };
183
184 *request.body_mut() = Some(body_bytes.into());
185 info!("Request created successfully");
186
187 Ok(request)
188 }
189
190 fn parse(&self, raw_response_text: String) -> Result<Message> {
191 info!("Parsing response from OpenAI API");
192 trace!("Raw response: {}", raw_response_text);
193
194 if let Ok(error_response) = serde_json::from_str::<OpenAIErrorResponse>(&raw_response_text)
196 {
197 if let Some(error) = error_response.error {
198 error!("OpenAI API returned an error: {}", error.message);
199 return Err(Error::ProviderUnavailable(error.message));
200 }
201 }
202
203 debug!("Deserializing response JSON");
205 let openai_response = match serde_json::from_str::<OpenAIResponse>(&raw_response_text) {
206 Ok(response) => {
207 debug!("Response deserialized successfully");
208 debug!("Response model: {}", response.model);
209 if !response.choices.is_empty() {
210 debug!("Number of choices: {}", response.choices.len());
211 debug!(
212 "First choice finish reason: {:?}",
213 response.choices[0].finish_reason
214 );
215 }
216 if let Some(usage) = &response.usage {
217 debug!(
218 "Token usage - prompt: {}, completion: {}, total: {}",
219 usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
220 );
221 }
222 response
223 }
224 Err(e) => {
225 error!("Failed to deserialize response: {}", e);
226 error!("Raw response: {}", raw_response_text);
227 return Err(Error::Serialization(e));
228 }
229 };
230
231 debug!("Converting OpenAI response to Message");
233 let message = Message::from(&openai_response);
234
235 info!("Response parsed successfully");
236 trace!("Response message processed");
237
238 Ok(message)
239 }
240}
241
/// Maps a model selector to the identifier placed in a request's `model` field.
pub trait OpenAIModelInfo {
    /// The model id string sent to the API.
    fn openai_model_id(&self) -> String;
}
246
247impl OpenAIProvider {
248 #[instrument(skip(self, chat), level = "debug")]
253 fn create_request_payload(&self, model: OpenAi, chat: &Chat) -> Result<OpenAIRequest> {
254 info!("Creating request payload for chat with OpenAI model");
255 debug!("System prompt length: {}", chat.system_prompt.len());
256 debug!("Messages in history: {}", chat.history.len());
257 debug!("Max output tokens: {}", chat.max_output_tokens);
258
259 let model_id = model.openai_model_id();
260 debug!("Using model ID: {}", model_id);
261
262 debug!("Converting messages to OpenAI format");
264 let mut messages: Vec<OpenAIMessage> = Vec::new();
265
266 if !chat.system_prompt.is_empty() {
268 debug!("Adding system prompt");
269 messages.push(OpenAIMessage {
270 role: "system".to_string(),
271 content: Some(chat.system_prompt.clone()),
272 function_call: None,
273 name: None,
274 tool_calls: None,
275 tool_call_id: None,
276 });
277 }
278
279 for msg in &chat.history {
281 debug!("Converting message with role: {}", msg.role_str());
282 messages.push(OpenAIMessage::from(msg));
283 }
284
285 debug!("Converted {} messages for the request", messages.len());
286
287 let tools = chat
289 .tools
290 .as_ref()
291 .map(|tools| tools.iter().map(OpenAITool::from).collect());
292
293 let tool_choice = if let Some(choice) = &chat.tool_choice {
295 match choice {
297 crate::tool::ToolChoice::Auto => Some(serde_json::json!("auto")),
298 crate::tool::ToolChoice::Any => Some(serde_json::json!("required")),
300 crate::tool::ToolChoice::None => Some(serde_json::json!("none")),
301 crate::tool::ToolChoice::Specific(name) => {
302 Some(serde_json::json!({
304 "type": "function",
305 "function": {
306 "name": name
307 }
308 }))
309 }
310 }
311 } else if tools.is_some() {
312 Some(serde_json::json!("auto"))
314 } else {
315 None
316 };
317
318 debug!("Creating OpenAIRequest");
320
321 let is_o_series = model_id.starts_with("o");
323
324 let request = OpenAIRequest {
325 model: model_id,
326 messages,
327 temperature: None,
328 top_p: None,
329 n: None,
330 max_tokens: if is_o_series {
332 None
333 } else {
334 Some(chat.max_output_tokens)
335 },
336 max_completion_tokens: if is_o_series {
337 Some(chat.max_output_tokens)
338 } else {
339 None
340 },
341 presence_penalty: None,
342 frequency_penalty: None,
343 stream: None,
344 tools,
345 tool_choice,
346 };
347
348 info!("Request payload created successfully");
349 Ok(request)
350 }
351}
352
353#[derive(Debug, Clone, Serialize, Deserialize)]
355pub(crate) struct OpenAIMessage {
356 pub role: String,
358 #[serde(skip_serializing_if = "Option::is_none")]
360 pub content: Option<String>,
361 #[serde(skip_serializing_if = "Option::is_none")]
363 pub function_call: Option<OpenAIFunctionCall>,
364 #[serde(skip_serializing_if = "Option::is_none")]
366 pub name: Option<String>,
367 #[serde(skip_serializing_if = "Option::is_none")]
369 pub tool_calls: Option<Vec<OpenAIToolCall>>,
370 #[serde(skip_serializing_if = "Option::is_none")]
372 pub tool_call_id: Option<String>,
373}
374
375#[derive(Debug, Serialize, Deserialize)]
377pub(crate) struct OpenAIFunction {
378 pub name: String,
380 pub description: String,
382 pub parameters: serde_json::Value,
384}
385
386#[derive(Debug, Serialize, Deserialize)]
388pub(crate) struct OpenAITool {
389 pub r#type: String,
391 pub function: OpenAIFunction,
393}
394
395impl From<&LlmToolInfo> for OpenAITool {
396 fn from(value: &LlmToolInfo) -> Self {
397 OpenAITool {
398 r#type: "function".to_string(),
399 function: OpenAIFunction {
400 name: value.name.clone(),
401 description: value.description.clone(),
402 parameters: value.parameters.clone(),
403 },
404 }
405 }
406}
407
408#[derive(Debug, Clone, Serialize, Deserialize)]
410pub(crate) struct OpenAIFunctionCall {
411 pub name: String,
413 pub arguments: String,
415}
416
417#[derive(Debug, Clone, Serialize, Deserialize)]
419pub(crate) struct OpenAIToolCall {
420 pub id: String,
422 pub r#type: String,
424 pub function: OpenAIFunctionCall,
426}
427
428#[derive(Debug, Serialize, Deserialize)]
430pub(crate) struct OpenAIRequest {
431 pub model: String,
433 pub messages: Vec<OpenAIMessage>,
435 #[serde(skip_serializing_if = "Option::is_none")]
437 pub temperature: Option<f32>,
438 #[serde(skip_serializing_if = "Option::is_none")]
440 pub top_p: Option<f32>,
441 #[serde(skip_serializing_if = "Option::is_none")]
443 pub n: Option<usize>,
444 #[serde(skip_serializing_if = "Option::is_none")]
446 pub max_tokens: Option<usize>,
447 #[serde(skip_serializing_if = "Option::is_none")]
449 pub max_completion_tokens: Option<usize>,
450 #[serde(skip_serializing_if = "Option::is_none")]
452 pub presence_penalty: Option<f32>,
453 #[serde(skip_serializing_if = "Option::is_none")]
455 pub frequency_penalty: Option<f32>,
456 #[serde(skip_serializing_if = "Option::is_none")]
458 pub stream: Option<bool>,
459 #[serde(skip_serializing_if = "Option::is_none")]
461 pub tools: Option<Vec<OpenAITool>>,
462 #[serde(skip_serializing_if = "Option::is_none")]
464 pub tool_choice: Option<serde_json::Value>,
465}
466
467#[derive(Debug, Serialize, Deserialize)]
469pub(crate) struct OpenAIResponse {
470 pub id: String,
472 pub object: String,
474 pub created: u64,
476 pub model: String,
478 pub choices: Vec<OpenAIChoice>,
480 pub usage: Option<OpenAIUsage>,
482}
483
484#[derive(Debug, Serialize, Deserialize)]
486pub(crate) struct OpenAIChoice {
487 pub index: usize,
489 pub message: OpenAIMessage,
491 pub finish_reason: Option<String>,
493}
494
495#[derive(Debug, Serialize, Deserialize)]
497pub(crate) struct OpenAIUsage {
498 pub prompt_tokens: u32,
500 pub completion_tokens: u32,
502 pub total_tokens: u32,
504}
505
506#[derive(Debug, Serialize, Deserialize)]
508pub(crate) struct OpenAIErrorResponse {
509 pub error: Option<OpenAIError>,
511}
512
513#[derive(Debug, Serialize, Deserialize)]
515pub(crate) struct OpenAIError {
516 pub message: String,
518 #[serde(rename = "type")]
520 pub error_type: String,
521 #[serde(skip_serializing_if = "Option::is_none")]
523 pub code: Option<String>,
524}
525
526impl From<&Message> for OpenAIMessage {
528 fn from(msg: &Message) -> Self {
529 let role = match msg {
530 Message::System { .. } => "system",
531 Message::User { .. } => "user",
532 Message::Assistant { .. } => "assistant",
533 Message::Tool { .. } => "tool",
534 }
535 .to_string();
536
537 let (content, name, function_call, tool_calls, tool_call_id) = match msg {
538 Message::System { content, .. } => (Some(content.clone()), None, None, None, None),
539 Message::User { content, name, .. } => {
540 let content_str = match content {
541 Content::Text(text) => Some(text.clone()),
542 Content::Parts(parts) => {
543 let combined_text = parts
545 .iter()
546 .filter_map(|part| match part {
547 ContentPart::Text { text } => Some(text.clone()),
548 _ => None,
549 })
550 .collect::<Vec<String>>()
551 .join("\n");
552
553 if combined_text.is_empty() {
554 None
555 } else {
556 Some(combined_text)
557 }
558 }
559 };
560 (content_str, name.clone(), None, None, None)
561 }
562 Message::Assistant {
563 content,
564 tool_calls,
565 ..
566 } => {
567 let content_str = match content {
568 Some(Content::Text(text)) => Some(text.clone()),
569 Some(Content::Parts(parts)) => {
570 let combined_text = parts
572 .iter()
573 .filter_map(|part| match part {
574 ContentPart::Text { text } => Some(text.clone()),
575 _ => None,
576 })
577 .collect::<Vec<String>>()
578 .join("\n");
579
580 if combined_text.is_empty() {
581 None
582 } else {
583 Some(combined_text)
584 }
585 }
586 None => None,
587 };
588
589 let openai_tool_calls = if !tool_calls.is_empty() {
591 let mut calls = Vec::with_capacity(tool_calls.len());
592
593 for tc in tool_calls {
594 calls.push(OpenAIToolCall {
595 id: tc.id.clone(),
596 r#type: tc.tool_type.clone(),
597 function: OpenAIFunctionCall {
598 name: tc.function.name.clone(),
599 arguments: tc.function.arguments.clone(),
600 },
601 });
602 }
603
604 Some(calls)
605 } else {
606 None
607 };
608
609 (content_str, None, None, openai_tool_calls, None)
610 }
611 Message::Tool {
612 tool_call_id,
613 content,
614 ..
615 } => (
616 Some(content.clone()),
617 None,
618 None,
619 None,
620 Some(tool_call_id.clone()),
621 ),
622 };
623
624 OpenAIMessage {
625 role,
626 content,
627 function_call,
628 name,
629 tool_calls,
630 tool_call_id,
631 }
632 }
633}
634
635impl From<&OpenAIResponse> for Message {
637 fn from(response: &OpenAIResponse) -> Self {
638 if response.choices.is_empty() {
640 return Message::assistant("No response generated");
641 }
642
643 let choice = &response.choices[0];
644 let message = &choice.message;
645
646 let mut msg = match message.role.as_str() {
648 "assistant" => {
649 let content = message
650 .content
651 .as_ref()
652 .map(|text| Content::Text(text.clone()));
653
654 if let Some(openai_tool_calls) = &message.tool_calls {
656 if !openai_tool_calls.is_empty() {
657 let mut tool_calls = Vec::with_capacity(openai_tool_calls.len());
658
659 for call in openai_tool_calls {
660 let tool_call = crate::message::ToolCall {
661 id: call.id.clone(),
662 tool_type: call.r#type.clone(),
663 function: crate::message::Function {
664 name: call.function.name.clone(),
665 arguments: call.function.arguments.clone(),
666 },
667 };
668 tool_calls.push(tool_call);
669 }
670
671 Message::Assistant {
672 content,
673 tool_calls,
674 metadata: Default::default(),
675 }
676 } else {
677 if let Some(Content::Text(text)) = content {
679 Message::assistant(text)
680 } else {
681 Message::Assistant {
682 content,
683 tool_calls: Vec::new(),
684 metadata: Default::default(),
685 }
686 }
687 }
688 } else if let Some(fc) = &message.function_call {
689 let tool_call = crate::message::ToolCall {
691 id: format!("legacy_function_{}", fc.name),
692 tool_type: "function".to_string(),
693 function: crate::message::Function {
694 name: fc.name.clone(),
695 arguments: fc.arguments.clone(),
696 },
697 };
698
699 Message::Assistant {
700 content,
701 tool_calls: vec![tool_call],
702 metadata: Default::default(),
703 }
704 } else {
705 if let Some(Content::Text(text)) = content {
707 Message::assistant(text)
708 } else {
709 Message::Assistant {
710 content,
711 tool_calls: Vec::new(),
712 metadata: Default::default(),
713 }
714 }
715 }
716 }
717 "user" => {
718 if let Some(name) = &message.name {
719 if let Some(content) = &message.content {
720 Message::user_with_name(name, content)
721 } else {
722 Message::user_with_name(name, "")
723 }
724 } else if let Some(content) = &message.content {
725 Message::user(content)
726 } else {
727 Message::user("")
728 }
729 }
730 "system" => {
731 if let Some(content) = &message.content {
732 Message::system(content)
733 } else {
734 Message::system("")
735 }
736 }
737 "tool" => {
738 if let Some(tool_call_id) = &message.tool_call_id {
739 if let Some(content) = &message.content {
740 Message::tool(tool_call_id, content)
741 } else {
742 Message::tool(tool_call_id, "")
743 }
744 } else {
745 if let Some(content) = &message.content {
747 Message::user(content)
748 } else {
749 Message::user("")
750 }
751 }
752 }
753 _ => {
754 if let Some(content) = &message.content {
756 Message::user(content)
757 } else {
758 Message::user("")
759 }
760 }
761 };
762
763 if let Some(usage) = &response.usage {
765 msg = msg.with_metadata(
766 "prompt_tokens",
767 serde_json::Value::Number(usage.prompt_tokens.into()),
768 );
769 msg = msg.with_metadata(
770 "completion_tokens",
771 serde_json::Value::Number(usage.completion_tokens.into()),
772 );
773 msg = msg.with_metadata(
774 "total_tokens",
775 serde_json::Value::Number(usage.total_tokens.into()),
776 );
777 }
778
779 msg
780 }
781}
782
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn user_message_maps_to_user_role() {
        let converted = OpenAIMessage::from(&Message::user("Hello, world!"));

        assert_eq!(converted.role, "user");
        assert_eq!(converted.content.as_deref(), Some("Hello, world!"));
    }

    #[test]
    fn system_message_maps_to_system_role() {
        let converted = OpenAIMessage::from(&Message::system("You are a helpful assistant."));

        assert_eq!(converted.role, "system");
        assert_eq!(
            converted.content.as_deref(),
            Some("You are a helpful assistant.")
        );
    }

    #[test]
    fn assistant_text_maps_to_assistant_role() {
        let converted = OpenAIMessage::from(&Message::assistant("I can help with that."));

        assert_eq!(converted.role, "assistant");
        assert_eq!(converted.content.as_deref(), Some("I can help with that."));
    }

    #[test]
    fn assistant_tool_calls_are_forwarded() {
        let weather_call = crate::message::ToolCall {
            id: "tool_123".to_string(),
            tool_type: "function".to_string(),
            function: crate::message::Function {
                name: "get_weather".to_string(),
                arguments: "{\"location\":\"San Francisco\"}".to_string(),
            },
        };
        let msg = Message::Assistant {
            content: Some(Content::Text("I'll check the weather".to_string())),
            tool_calls: vec![weather_call],
            metadata: Default::default(),
        };

        let converted = OpenAIMessage::from(&msg);

        assert_eq!(converted.role, "assistant");
        assert_eq!(converted.content.as_deref(), Some("I'll check the weather"));
        let calls = converted
            .tool_calls
            .expect("non-empty tool calls should be forwarded");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "tool_123");
        assert_eq!(calls[0].function.name, "get_weather");
    }

    #[test]
    fn error_response_fields_are_parsed() {
        let error_json = r#"{
            "error": {
                "message": "The model does not exist",
                "type": "invalid_request_error",
                "code": "model_not_found"
            }
        }"#;

        let parsed: OpenAIErrorResponse = serde_json::from_str(error_json).unwrap();
        let details = parsed.error.expect("error object should be present");

        assert_eq!(details.error_type, "invalid_request_error");
        assert_eq!(details.code.as_deref(), Some("model_not_found"));
    }
}