use crate::error::{Error, Result};
use crate::message::{Content, ContentPart, Message};
use crate::provider::HTTPProvider;
use crate::{Chat, LlmToolInfo, ModelInfo};
use reqwest::{Method, Request, Url};
use serde::{Deserialize, Serialize};
use std::env;
use tracing::{debug, error, info, instrument, trace, warn};

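/// Configuration for the OpenAI provider: the API key, the base URL of the
/// API, and an optional organization ID. `Default` reads `OPENAI_API_KEY`
/// and `OPENAI_ORGANIZATION` from the environment and points at
/// `https://api.openai.com/v1`.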
#[derive(Debug, Clone)]
pub struct OpenAIConfig {
    pub api_key: String,
    pub base_url: String,
    pub organization: Option<String>,
}

impl Default for OpenAIConfig {
    fn default() -> Self {
        Self {
            api_key: env::var("OPENAI_API_KEY").unwrap_or_default(),
            base_url: "https://api.openai.com/v1".to_string(),
            organization: env::var("OPENAI_ORGANIZATION").ok(),
        }
    }
}

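/// Provider for the OpenAI Chat Completions API. Implements
/// [`HTTPProvider`] by building `POST {base_url}/chat/completions`
/// requests and parsing their responses.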
#[derive(Debug, Clone)]
pub struct OpenAIProvider {
    config: OpenAIConfig,
}

impl OpenAIProvider {
    #[instrument(level = "debug")]
    pub fn new() -> Self {
        info!("Creating new OpenAIProvider with default configuration");
        let config = OpenAIConfig::default();
        debug!("API key set: {}", !config.api_key.is_empty());
        debug!("Base URL: {}", config.base_url);
        debug!("Organization set: {}", config.organization.is_some());

        Self { config }
    }

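    /// Creates a provider from an explicit configuration instead of reading
    /// the environment. A minimal sketch (the key shown is a placeholder,
    /// not a real credential):
    ///
    /// ```ignore
    /// let provider = OpenAIProvider::with_config(OpenAIConfig {
    ///     api_key: "sk-...".to_string(),
    ///     base_url: "https://api.openai.com/v1".to_string(),
    ///     organization: None,
    /// });
    /// ```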
    #[instrument(skip(config), level = "debug")]
    pub fn with_config(config: OpenAIConfig) -> Self {
        info!("Creating new OpenAIProvider with custom configuration");
        debug!("API key set: {}", !config.api_key.is_empty());
        debug!("Base URL: {}", config.base_url);
        debug!("Organization set: {}", config.organization.is_some());

        Self { config }
    }
}

impl Default for OpenAIProvider {
    fn default() -> Self {
        Self::new()
    }
}

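// `accept` builds the outgoing HTTP request for a chat; `parse` turns the
// raw response body back into a `Message`.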
impl<M: ModelInfo + OpenAIModelInfo> HTTPProvider<M> for OpenAIProvider {
    fn accept(&self, chat: Chat<M>) -> Result<Request> {
        info!("Creating request for OpenAI model: {:?}", chat.model);
        debug!("Messages in chat history: {}", chat.history.len());

        let url_str = format!("{}/chat/completions", self.config.base_url);
        debug!("Parsing URL: {}", url_str);
        let url = match Url::parse(&url_str) {
            Ok(url) => {
                debug!("URL parsed successfully: {}", url);
                url
            }
            Err(e) => {
                error!("Failed to parse URL '{}': {}", url_str, e);
                return Err(e.into());
            }
        };

        let mut request = Request::new(Method::POST, url);
        debug!("Created request: {} {}", request.method(), request.url());

        debug!("Setting request headers");

        let auth_header = match format!("Bearer {}", self.config.api_key).parse() {
            Ok(header) => header,
            Err(e) => {
                error!("Invalid API key format: {}", e);
                return Err(Error::Authentication("Invalid API key format".into()));
            }
        };

        let content_type_header = match "application/json".parse() {
            Ok(header) => header,
            Err(e) => {
                error!("Failed to set content type: {}", e);
                return Err(Error::Other("Failed to set content type".into()));
            }
        };

        request.headers_mut().insert("Authorization", auth_header);
        request
            .headers_mut()
            .insert("Content-Type", content_type_header);

        if let Some(org) = &self.config.organization {
            match org.parse() {
                Ok(header) => {
                    request.headers_mut().insert("OpenAI-Organization", header);
                    debug!("Added organization header");
                }
                Err(e) => {
                    warn!("Failed to set organization header: {}", e);
                }
            }
        }

        trace!("Request headers set: {:#?}", request.headers());

        debug!("Creating request payload");
        let payload = match self.create_request_payload(&chat) {
            Ok(payload) => {
                debug!("Request payload created successfully");
                trace!("Model: {}", payload.model);
                trace!("Max tokens: {:?}", payload.max_tokens);
                trace!("Number of messages: {}", payload.messages.len());
                payload
            }
            Err(e) => {
                error!("Failed to create request payload: {}", e);
                return Err(e);
            }
        };

        debug!("Serializing request payload");
        let body_bytes = match serde_json::to_vec(&payload) {
            Ok(bytes) => {
                debug!("Payload serialized successfully ({} bytes)", bytes.len());
                bytes
            }
            Err(e) => {
                error!("Failed to serialize payload: {}", e);
                return Err(Error::Serialization(e));
            }
        };

        *request.body_mut() = Some(body_bytes.into());
        info!("Request created successfully");

        Ok(request)
    }

    fn parse(&self, raw_response_text: String) -> Result<Message> {
        info!("Parsing response from OpenAI API");
        trace!("Raw response: {}", raw_response_text);

        if let Ok(error_response) = serde_json::from_str::<OpenAIErrorResponse>(&raw_response_text)
        {
            if let Some(error) = error_response.error {
                error!("OpenAI API returned an error: {}", error.message);
                return Err(Error::ProviderUnavailable(error.message));
            }
        }

        debug!("Deserializing response JSON");
        let openai_response = match serde_json::from_str::<OpenAIResponse>(&raw_response_text) {
            Ok(response) => {
                debug!("Response deserialized successfully");
                debug!("Response model: {}", response.model);
                if !response.choices.is_empty() {
                    debug!("Number of choices: {}", response.choices.len());
                    debug!(
                        "First choice finish reason: {:?}",
                        response.choices[0].finish_reason
                    );
                }
                if let Some(usage) = &response.usage {
                    debug!(
                        "Token usage - prompt: {}, completion: {}, total: {}",
                        usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
                    );
                }
                response
            }
            Err(e) => {
                error!("Failed to deserialize response: {}", e);
                error!("Raw response: {}", raw_response_text);
                return Err(Error::Serialization(e));
            }
        };

        debug!("Converting OpenAI response to Message");
        let message = Message::from(&openai_response);

        info!("Response parsed successfully");
        trace!("Response message processed");

        Ok(message)
    }
}

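/// Maps a model type to the model identifier sent in the request's
/// `model` field. An illustrative implementation (the `Gpt4` type is
/// hypothetical):
///
/// ```ignore
/// struct Gpt4;
///
/// impl OpenAIModelInfo for Gpt4 {
///     fn openai_model_id(&self) -> String {
///         "gpt-4".to_string()
///     }
/// }
/// ```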
pub trait OpenAIModelInfo {
    fn openai_model_id(&self) -> String;
}

impl OpenAIProvider {
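    /// Builds the JSON payload for a chat: an optional system message
    /// first, then the converted history, plus any tool definitions with
    /// `tool_choice` set to `"auto"` whenever tools are present.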
    #[instrument(skip(self, chat), level = "debug")]
    fn create_request_payload<M: ModelInfo + OpenAIModelInfo>(
        &self,
        chat: &Chat<M>,
    ) -> Result<OpenAIRequest> {
        info!("Creating request payload for chat with OpenAI model");
        debug!("System prompt length: {}", chat.system_prompt.len());
        debug!("Messages in history: {}", chat.history.len());
        debug!("Max output tokens: {}", chat.max_output_tokens);

        let model_id = chat.model.openai_model_id();
        debug!("Using model ID: {}", model_id);

        debug!("Converting messages to OpenAI format");
        let mut messages: Vec<OpenAIMessage> = Vec::new();

        if !chat.system_prompt.is_empty() {
            debug!("Adding system prompt");
            messages.push(OpenAIMessage {
                role: "system".to_string(),
                content: Some(chat.system_prompt.clone()),
                function_call: None,
                name: None,
                tool_calls: None,
                tool_call_id: None,
            });
        }

        for msg in &chat.history {
            debug!("Converting message with role: {}", msg.role_str());
            messages.push(OpenAIMessage::from(msg));
        }

        debug!("Converted {} messages for the request", messages.len());

        let tools = chat
            .tools
            .as_ref()
            .map(|tools| tools.iter().map(OpenAITool::from).collect());

        let tool_choice = if tools.is_some() {
            Some("auto".to_string())
        } else {
            None
        };

        debug!("Creating OpenAIRequest");
        let request = OpenAIRequest {
            model: model_id,
            messages,
            temperature: None,
            top_p: None,
            n: None,
            max_tokens: Some(chat.max_output_tokens),
            presence_penalty: None,
            frequency_penalty: None,
            stream: None,
            tools,
            tool_choice,
        };

        info!("Request payload created successfully");
        Ok(request)
    }
}

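// Wire-format types mirroring the OpenAI Chat Completions API. Optional
// fields use `skip_serializing_if` so unset values are omitted from the
// serialized JSON.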
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OpenAIMessage {
    pub role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<OpenAIFunctionCall>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OpenAIToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIFunction {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAITool {
    pub r#type: String,
    pub function: OpenAIFunction,
}

impl From<&LlmToolInfo> for OpenAITool {
    fn from(value: &LlmToolInfo) -> Self {
        OpenAITool {
            r#type: "function".to_string(),
            function: OpenAIFunction {
                name: value.name.clone(),
                description: value.description.clone(),
                parameters: value.parameters.clone(),
            },
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OpenAIFunctionCall {
    pub name: String,
    pub arguments: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OpenAIToolCall {
    pub id: String,
    pub r#type: String,
    pub function: OpenAIFunctionCall,
}

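/// Request body for `POST /chat/completions`.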
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIRequest {
    pub model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<OpenAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
}

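/// Response body returned by the Chat Completions endpoint.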
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<OpenAIChoice>,
    pub usage: Option<OpenAIUsage>,
}

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIChoice {
    pub index: usize,
    pub message: OpenAIMessage,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIErrorResponse {
    pub error: Option<OpenAIError>,
}

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIError {
    pub message: String,
    #[serde(rename = "type")]
    pub error_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
}

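// Conversion into the OpenAI wire format. Multi-part content is flattened
// to a single string: text parts are joined with newlines and non-text
// parts are dropped.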
impl From<&Message> for OpenAIMessage {
    fn from(msg: &Message) -> Self {
        let role = match msg {
            Message::System { .. } => "system",
            Message::User { .. } => "user",
            Message::Assistant { .. } => "assistant",
            Message::Tool { .. } => "tool",
        }
        .to_string();

        let (content, name, function_call, tool_calls, tool_call_id) = match msg {
            Message::System { content, .. } => (Some(content.clone()), None, None, None, None),
            Message::User { content, name, .. } => {
                let content_str = match content {
                    Content::Text(text) => Some(text.clone()),
                    Content::Parts(parts) => {
                        let combined_text = parts
                            .iter()
                            .filter_map(|part| match part {
                                ContentPart::Text { text } => Some(text.clone()),
                                _ => None,
                            })
                            .collect::<Vec<String>>()
                            .join("\n");

                        if combined_text.is_empty() {
                            None
                        } else {
                            Some(combined_text)
                        }
                    }
                };
                (content_str, name.clone(), None, None, None)
            }
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let content_str = match content {
                    Some(Content::Text(text)) => Some(text.clone()),
                    Some(Content::Parts(parts)) => {
                        let combined_text = parts
                            .iter()
                            .filter_map(|part| match part {
                                ContentPart::Text { text } => Some(text.clone()),
                                _ => None,
                            })
                            .collect::<Vec<String>>()
                            .join("\n");

                        if combined_text.is_empty() {
                            None
                        } else {
                            Some(combined_text)
                        }
                    }
                    None => None,
                };

                let openai_tool_calls = if !tool_calls.is_empty() {
                    let mut calls = Vec::with_capacity(tool_calls.len());

                    for tc in tool_calls {
                        calls.push(OpenAIToolCall {
                            id: tc.id.clone(),
                            r#type: tc.tool_type.clone(),
                            function: OpenAIFunctionCall {
                                name: tc.function.name.clone(),
                                arguments: tc.function.arguments.clone(),
                            },
                        });
                    }

                    Some(calls)
                } else {
                    None
                };

                (content_str, None, None, openai_tool_calls, None)
            }
            Message::Tool {
                tool_call_id,
                content,
                ..
            } => (
                Some(content.clone()),
                None,
                None,
                None,
                Some(tool_call_id.clone()),
            ),
        };

        OpenAIMessage {
            role,
            content,
            function_call,
            name,
            tool_calls,
            tool_call_id,
        }
    }
}

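// Conversion back into a crate `Message`. Only the first choice is used;
// token usage, when present, is attached as message metadata.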
impl From<&OpenAIResponse> for Message {
    fn from(response: &OpenAIResponse) -> Self {
        if response.choices.is_empty() {
            return Message::assistant("No response generated");
        }

        let choice = &response.choices[0];
        let message = &choice.message;

        let mut msg = match message.role.as_str() {
            "assistant" => {
                let content = message
                    .content
                    .as_ref()
                    .map(|text| Content::Text(text.clone()));

                if let Some(openai_tool_calls) = &message.tool_calls {
                    if !openai_tool_calls.is_empty() {
                        let mut tool_calls = Vec::with_capacity(openai_tool_calls.len());

                        for call in openai_tool_calls {
                            let tool_call = crate::message::ToolCall {
                                id: call.id.clone(),
                                tool_type: call.r#type.clone(),
                                function: crate::message::Function {
                                    name: call.function.name.clone(),
                                    arguments: call.function.arguments.clone(),
                                },
                            };
                            tool_calls.push(tool_call);
                        }

                        Message::Assistant {
                            content,
                            tool_calls,
                            metadata: Default::default(),
                        }
                    } else if let Some(Content::Text(text)) = content {
                        Message::assistant(text)
                    } else {
                        Message::Assistant {
                            content,
                            tool_calls: Vec::new(),
                            metadata: Default::default(),
                        }
                    }
                } else if let Some(fc) = &message.function_call {
                    let tool_call = crate::message::ToolCall {
                        id: format!("legacy_function_{}", fc.name),
                        tool_type: "function".to_string(),
                        function: crate::message::Function {
                            name: fc.name.clone(),
                            arguments: fc.arguments.clone(),
                        },
                    };

                    Message::Assistant {
                        content,
                        tool_calls: vec![tool_call],
                        metadata: Default::default(),
                    }
                } else if let Some(Content::Text(text)) = content {
                    Message::assistant(text)
                } else {
                    Message::Assistant {
                        content,
                        tool_calls: Vec::new(),
                        metadata: Default::default(),
                    }
                }
            }
            "user" => {
                if let Some(name) = &message.name {
                    if let Some(content) = &message.content {
                        Message::user_with_name(name, content)
                    } else {
                        Message::user_with_name(name, "")
                    }
                } else if let Some(content) = &message.content {
                    Message::user(content)
                } else {
                    Message::user("")
                }
            }
            "system" => {
                if let Some(content) = &message.content {
                    Message::system(content)
                } else {
                    Message::system("")
                }
            }
            "tool" => {
                if let Some(tool_call_id) = &message.tool_call_id {
                    if let Some(content) = &message.content {
                        Message::tool(tool_call_id, content)
                    } else {
                        Message::tool(tool_call_id, "")
                    }
                } else if let Some(content) = &message.content {
                    Message::user(content)
                } else {
                    Message::user("")
                }
            }
            _ => {
                if let Some(content) = &message.content {
                    Message::user(content)
                } else {
                    Message::user("")
                }
            }
        };

        if let Some(usage) = &response.usage {
            msg = msg.with_metadata(
                "prompt_tokens",
                serde_json::Value::Number(usage.prompt_tokens.into()),
            );
            msg = msg.with_metadata(
                "completion_tokens",
                serde_json::Value::Number(usage.completion_tokens.into()),
            );
            msg = msg.with_metadata(
                "total_tokens",
                serde_json::Value::Number(usage.total_tokens.into()),
            );
        }

        msg
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_conversion() {
        let msg = Message::user("Hello, world!");
        let openai_msg = OpenAIMessage::from(&msg);

        assert_eq!(openai_msg.role, "user");
        assert_eq!(openai_msg.content, Some("Hello, world!".to_string()));

        let msg = Message::system("You are a helpful assistant.");
        let openai_msg = OpenAIMessage::from(&msg);

        assert_eq!(openai_msg.role, "system");
        assert_eq!(
            openai_msg.content,
            Some("You are a helpful assistant.".to_string())
        );

        let msg = Message::assistant("I can help with that.");
        let openai_msg = OpenAIMessage::from(&msg);

        assert_eq!(openai_msg.role, "assistant");
        assert_eq!(
            openai_msg.content,
            Some("I can help with that.".to_string())
        );

        let tool_call = crate::message::ToolCall {
            id: "tool_123".to_string(),
            tool_type: "function".to_string(),
            function: crate::message::Function {
                name: "get_weather".to_string(),
                arguments: "{\"location\":\"San Francisco\"}".to_string(),
            },
        };

        let msg = Message::Assistant {
            content: Some(Content::Text("I'll check the weather".to_string())),
            tool_calls: vec![tool_call],
            metadata: Default::default(),
        };

        let openai_msg = OpenAIMessage::from(&msg);

        assert_eq!(openai_msg.role, "assistant");
        assert_eq!(
            openai_msg.content,
            Some("I'll check the weather".to_string())
        );
        assert!(openai_msg.tool_calls.is_some());
        let tool_calls = openai_msg.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "tool_123");
        assert_eq!(tool_calls[0].function.name, "get_weather");
    }

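    // A sketch verifying that `skip_serializing_if` keeps unset optional
    // fields (including unset message fields) out of the serialized request
    // body.
    #[test]
    fn test_request_serialization_skips_unset_fields() {
        let request = OpenAIRequest {
            model: "gpt-4".to_string(),
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: Some("Hi".to_string()),
                function_call: None,
                name: None,
                tool_calls: None,
                tool_call_id: None,
            }],
            temperature: None,
            top_p: None,
            n: None,
            max_tokens: Some(256),
            presence_penalty: None,
            frequency_penalty: None,
            stream: None,
            tools: None,
            tool_choice: None,
        };

        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["model"], "gpt-4");
        assert_eq!(json["max_tokens"], 256);
        // Unset fields are omitted entirely rather than serialized as null.
        assert!(json.get("temperature").is_none());
        assert!(json.get("stream").is_none());
        assert!(json["messages"][0].get("tool_calls").is_none());
    }
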
    #[test]
    fn test_error_response_parsing() {
        let error_json = r#"{
            "error": {
                "message": "The model does not exist",
                "type": "invalid_request_error",
                "code": "model_not_found"
            }
        }"#;

        let error_response: OpenAIErrorResponse = serde_json::from_str(error_json).unwrap();
        assert!(error_response.error.is_some());
        let error = error_response.error.unwrap();
        assert_eq!(error.error_type, "invalid_request_error");
        assert_eq!(error.code, Some("model_not_found".to_string()));
    }
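
    // A sketch of response deserialization and conversion, assuming the
    // sample JSON below is representative of the API's response shape.
    #[test]
    fn test_response_conversion() {
        let response_json = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1700000000,
            "model": "gpt-4",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": "Hello there!"},
                    "finish_reason": "stop"
                }
            ],
            "usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12}
        }"#;

        let response: OpenAIResponse = serde_json::from_str(response_json).unwrap();
        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.choices[0].finish_reason, Some("stop".to_string()));

        // A plain text reply converts into an assistant message.
        let message = Message::from(&response);
        assert!(matches!(message, Message::Assistant { .. }));
    }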
}