1use std::time::Duration;
6
7#[cfg(feature = "openai")]
8use crate::{
9 chat::Tool,
10 chat::{ChatMessage, ChatProvider, ChatRole, MessageType, StructuredOutputFormat},
11 completion::{CompletionProvider, CompletionRequest, CompletionResponse},
12 embedding::EmbeddingProvider,
13 error::LLMError,
14 stt::SpeechToTextProvider,
15 tts::TextToSpeechProvider,
16 LLMProvider,
17};
18use crate::{
19 chat::{ChatResponse, ToolChoice},
20 FunctionCall, ToolCall,
21};
22use async_trait::async_trait;
23use either::*;
24use futures::stream::Stream;
25use reqwest::{Client, Url};
26use serde::{Deserialize, Serialize};
27
/// Client for OpenAI-compatible HTTP APIs (chat, embeddings, audio).
///
/// Optional fields set to `None` are simply omitted from outgoing requests.
/// The struct doubles as the request configuration: one instance is built
/// once (see [`OpenAI::new`]) and reused for every call.
pub struct OpenAI {
    // Bearer token sent on every request.
    pub api_key: String,
    // API root; endpoint paths like `chat/completions` are joined onto this.
    pub base_url: Url,
    pub model: String,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    // System prompt inserted as the first chat message when present.
    pub system: Option<String>,
    // Per-request timeout; also applied to the underlying reqwest client.
    pub timeout_seconds: Option<u64>,
    pub stream: Option<bool>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    // Default tools used when the caller passes none to `chat_with_tools`.
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
    // Embedding-only options.
    pub embedding_encoding_format: Option<String>,
    pub embedding_dimensions: Option<u32>,
    pub reasoning_effort: Option<String>,
    // When set, sent as a `json_schema` response_format on chat requests.
    pub json_schema: Option<StructuredOutputFormat>,
    // TTS voice; defaults to "alloy" when unset.
    pub voice: Option<String>,
    client: Client,
}
53
/// Wire format for one message in a chat-completions request.
#[derive(Serialize, Debug)]
struct OpenAIChatMessage<'a> {
    #[allow(dead_code)]
    role: &'a str,
    // Either a list of typed content parts (text / image_url) or a plain
    // string; serialized untagged so both shapes match the OpenAI schema.
    #[serde(
        skip_serializing_if = "Option::is_none",
        with = "either::serde_untagged_optional"
    )]
    content: Option<Either<Vec<MessageContent<'a>>, String>>,
    // Tool invocations issued by the assistant.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIFunctionCall<'a>>>,
    // Links a `role == "tool"` result message back to the call it answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
69
/// `function` payload inside a serialized tool call.
#[derive(Serialize, Debug)]
struct OpenAIFunctionPayload<'a> {
    name: &'a str,
    // JSON-encoded argument object, passed through verbatim.
    arguments: &'a str,
}

/// One entry of an assistant message's `tool_calls` array.
#[derive(Serialize, Debug)]
struct OpenAIFunctionCall<'a> {
    id: &'a str,
    // Always "function" in this codebase (see chat_message_to_api_message).
    #[serde(rename = "type")]
    content_type: &'a str,
    function: OpenAIFunctionPayload<'a>,
}
83
/// One typed content part of a chat message (text, image URL, or tool output).
/// Exactly one of the optional fields is populated per instance.
#[derive(Serialize, Debug)]
struct MessageContent<'a> {
    // "text" or "image_url"; omitted for tool-output parts.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    message_type: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    image_url: Option<ImageUrlContent<'a>>,
    #[serde(skip_serializing_if = "Option::is_none", rename = "tool_call_id")]
    tool_call_id: Option<&'a str>,
    // Serialized as "content" to match the tool-result message schema.
    #[serde(skip_serializing_if = "Option::is_none", rename = "content")]
    tool_output: Option<&'a str>,
}

/// Wrapper for an `image_url` content part.
#[derive(Serialize, Debug)]
struct ImageUrlContent<'a> {
    url: &'a str,
}
103
/// Request body for the `embeddings` endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest {
    model: String,
    input: Vec<String>,
    // e.g. "float"; defaulted by the caller when unset on the client.
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}
113
/// Request body for the `chat/completions` endpoint.
#[derive(Serialize, Debug)]
struct OpenAIChatRequest<'a> {
    model: &'a str,
    messages: Vec<OpenAIChatMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    // NOTE(review): `top_k` is not part of the official OpenAI chat API —
    // presumably kept for OpenAI-compatible providers; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<OpenAIResponseFormat>,
}
137
138impl std::fmt::Display for ToolCall {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 write!(
141 f,
142 "{{\n \"id\": \"{}\",\n \"type\": \"{}\",\n \"function\": {}\n}}",
143 self.id, self.call_type, self.function
144 )
145 }
146}
147
148impl std::fmt::Display for FunctionCall {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 write!(
151 f,
152 "{{\n \"name\": \"{}\",\n \"arguments\": {}\n}}",
153 self.name, self.arguments
154 )
155 }
156}
157
/// Top-level (non-streaming) chat-completions response body.
#[derive(Deserialize, Debug)]
struct OpenAIChatResponse {
    choices: Vec<OpenAIChatChoice>,
}

/// One completion choice; only the first is ever consumed here.
#[derive(Deserialize, Debug)]
struct OpenAIChatChoice {
    message: OpenAIChatMsg,
}

/// Assistant message inside a choice.
#[derive(Deserialize, Debug)]
struct OpenAIChatMsg {
    #[allow(dead_code)]
    role: String,
    // Absent when the model responds purely with tool calls.
    content: Option<String>,
    tool_calls: Option<Vec<ToolCall>>,
}
178
/// One embedding vector in an `embeddings` response.
#[derive(Deserialize, Debug)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
}
/// Top-level `embeddings` response body (one entry per input string).
#[derive(Deserialize, Debug)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
}
187
/// One parsed SSE event from a streaming chat response.
#[derive(Deserialize, Debug)]
struct OpenAIChatStreamResponse {
    choices: Vec<OpenAIChatStreamChoice>,
}

/// One streaming choice carrying an incremental delta.
#[derive(Deserialize, Debug)]
struct OpenAIChatStreamChoice {
    delta: OpenAIChatStreamDelta,
}

/// Incremental message fragment; `content` is absent on role/finish deltas.
#[derive(Deserialize, Debug)]
struct OpenAIChatStreamDelta {
    content: Option<String>,
}
205
/// Discriminant for the `response_format.type` field of a chat request.
#[derive(Deserialize, Debug, Serialize)]
enum OpenAIResponseType {
    #[serde(rename = "text")]
    Text,
    #[serde(rename = "json_schema")]
    JsonSchema,
    #[serde(rename = "json_object")]
    JsonObject,
}

/// `response_format` payload controlling structured output.
#[derive(Deserialize, Debug, Serialize)]
struct OpenAIResponseFormat {
    #[serde(rename = "type")]
    response_type: OpenAIResponseType,
    // Populated only when `response_type` is `JsonSchema`.
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<StructuredOutputFormat>,
}
226
227impl From<StructuredOutputFormat> for OpenAIResponseFormat {
228 fn from(structured_response_format: StructuredOutputFormat) -> Self {
230 match structured_response_format.schema {
233 None => OpenAIResponseFormat {
234 response_type: OpenAIResponseType::JsonSchema,
235 json_schema: Some(structured_response_format),
236 },
237 Some(mut schema) => {
238 schema = if schema.get("additionalProperties").is_none() {
241 schema["additionalProperties"] = serde_json::json!(false);
242 schema
243 } else {
244 schema
245 };
246
247 OpenAIResponseFormat {
248 response_type: OpenAIResponseType::JsonSchema,
249 json_schema: Some(StructuredOutputFormat {
250 name: structured_response_format.name,
251 description: structured_response_format.description,
252 schema: Some(schema),
253 strict: structured_response_format.strict,
254 }),
255 }
256 }
257 }
258 }
259}
260
261impl ChatResponse for OpenAIChatResponse {
262 fn text(&self) -> Option<String> {
263 self.choices.first().and_then(|c| c.message.content.clone())
264 }
265
266 fn tool_calls(&self) -> Option<Vec<ToolCall>> {
267 self.choices
268 .first()
269 .and_then(|c| c.message.tool_calls.clone())
270 }
271}
272
273impl std::fmt::Display for OpenAIChatResponse {
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 match (
276 &self.choices.first().unwrap().message.content,
277 &self.choices.first().unwrap().message.tool_calls,
278 ) {
279 (Some(content), Some(tool_calls)) => {
280 for tool_call in tool_calls {
281 write!(f, "{}", tool_call)?;
282 }
283 write!(f, "{}", content)
284 }
285 (Some(content), None) => write!(f, "{}", content),
286 (None, Some(tool_calls)) => {
287 for tool_call in tool_calls {
288 write!(f, "{}", tool_call)?;
289 }
290 Ok(())
291 }
292 (None, None) => write!(f, ""),
293 }
294 }
295}
296
297impl OpenAI {
298 #[allow(clippy::too_many_arguments)]
318 pub fn new(
319 api_key: impl Into<String>,
320 base_url: Option<String>,
321 model: Option<String>,
322 max_tokens: Option<u32>,
323 temperature: Option<f32>,
324 timeout_seconds: Option<u64>,
325 system: Option<String>,
326 stream: Option<bool>,
327 top_p: Option<f32>,
328 top_k: Option<u32>,
329 embedding_encoding_format: Option<String>,
330 embedding_dimensions: Option<u32>,
331 tools: Option<Vec<Tool>>,
332 tool_choice: Option<ToolChoice>,
333 reasoning_effort: Option<String>,
334 json_schema: Option<StructuredOutputFormat>,
335 voice: Option<String>,
336 ) -> Self {
337 let mut builder = Client::builder();
338 if let Some(sec) = timeout_seconds {
339 builder = builder.timeout(std::time::Duration::from_secs(sec));
340 }
341 Self {
342 api_key: api_key.into(),
343 base_url: Url::parse(
344 &base_url.unwrap_or_else(|| "https://api.openai.com/v1/".to_owned()),
345 )
346 .expect("Failed to prase base Url"),
347 model: model.unwrap_or("gpt-3.5-turbo".to_string()),
348 max_tokens,
349 temperature,
350 system,
351 timeout_seconds,
352 stream,
353 top_p,
354 top_k,
355 tools,
356 tool_choice,
357 embedding_encoding_format,
358 embedding_dimensions,
359 client: builder.build().expect("Failed to build reqwest Client"),
360 reasoning_effort,
361 json_schema,
362 voice,
363 }
364 }
365}
366
#[async_trait]
impl ChatProvider for OpenAI {
    /// Sends a chat request, with optional per-call tool overrides.
    ///
    /// Tool resolution: tools passed here take precedence over the tools
    /// configured on the client; `tool_choice` is only sent when some tools
    /// end up in the request.
    ///
    /// # Errors
    /// - `LLMError::AuthError` when no API key is configured
    /// - `LLMError::HttpError` on URL join failures
    /// - `LLMError::ResponseFormatError` on non-2xx status or undecodable body
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: Option<&[Tool]>,
    ) -> Result<Box<dyn ChatResponse>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing OpenAI API key".to_string()));
        }

        let messages = messages.to_vec();

        let mut openai_msgs: Vec<OpenAIChatMessage> = vec![];

        for msg in messages {
            if let MessageType::ToolResult(ref results) = msg.message_type {
                // Each tool result becomes its own `role == "tool"` message,
                // echoing the originating call id so the API can pair them.
                for result in results {
                    openai_msgs.push(
                        OpenAIChatMessage {
                            role: "tool",
                            tool_call_id: Some(result.id.clone()),
                            tool_calls: None,
                            content: Some(Right(result.function.arguments.clone())),
                        },
                    );
                }
            } else {
                openai_msgs.push(chat_message_to_api_message(msg))
            }
        }

        // The configured system prompt, if any, always goes first.
        if let Some(system) = &self.system {
            openai_msgs.insert(
                0,
                OpenAIChatMessage {
                    role: "system",
                    content: Some(Left(vec![MessageContent {
                        message_type: Some("text"),
                        text: Some(system),
                        image_url: None,
                        tool_call_id: None,
                        tool_output: None,
                    }])),
                    tool_calls: None,
                    tool_call_id: None,
                },
            );
        }

        let response_format: Option<OpenAIResponseFormat> =
            self.json_schema.clone().map(|s| s.into());

        // Per-call tools override the client-level configuration.
        let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone());

        // Only send tool_choice when tools are actually present.
        let request_tool_choice = if request_tools.is_some() {
            self.tool_choice.clone()
        } else {
            None
        };

        let body = OpenAIChatRequest {
            model: &self.model,
            messages: openai_msgs,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            stream: self.stream.unwrap_or(false),
            top_p: self.top_p,
            top_k: self.top_k,
            tools: request_tools,
            tool_choice: request_tool_choice,
            reasoning_effort: self.reasoning_effort.clone(),
            response_format,
        };

        let url = self
            .base_url
            .join("chat/completions")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;

        let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body);

        // Serialize the payload for logging only when trace level is active.
        if log::log_enabled!(log::Level::Trace) {
            if let Ok(json) = serde_json::to_string(&body) {
                log::trace!("OpenAI request payload: {}", json);
            }
        }

        if let Some(timeout) = self.timeout_seconds {
            request = request.timeout(std::time::Duration::from_secs(timeout));
        }

        let response = request.send().await?;

        log::debug!("OpenAI HTTP status: {}", response.status());

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(LLMError::ResponseFormatError {
                message: format!("OpenAI API returned error status: {}", status),
                raw_response: error_text,
            });
        }

        // Decode manually (rather than `.json()`) so the raw body can be
        // attached to the error when decoding fails.
        let resp_text = response.text().await?;
        let json_resp: Result<OpenAIChatResponse, serde_json::Error> =
            serde_json::from_str(&resp_text);

        match json_resp {
            Ok(response) => Ok(Box::new(response)),
            Err(e) => Err(LLMError::ResponseFormatError {
                message: format!("Failed to decode OpenAI API response: {}", e),
                raw_response: resp_text,
            }),
        }
    }

    /// Convenience wrapper: chat without per-call tool overrides.
    async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
        self.chat_with_tools(messages, None).await
    }

    /// Streams a chat response as server-sent events, yielding text deltas.
    ///
    /// NOTE(review): unlike `chat_with_tools`, this always uses the
    /// client-level tools and ignores `json_schema`/`response_format` —
    /// confirm this asymmetry is intended.
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing OpenAI API key".to_string()));
        }

        let messages = messages.to_vec();
        let mut openai_msgs: Vec<OpenAIChatMessage> = vec![];

        // Same message conversion as chat_with_tools.
        for msg in messages {
            if let MessageType::ToolResult(ref results) = msg.message_type {
                for result in results {
                    openai_msgs.push(OpenAIChatMessage {
                        role: "tool",
                        tool_call_id: Some(result.id.clone()),
                        tool_calls: None,
                        content: Some(Right(result.function.arguments.clone())),
                    });
                }
            } else {
                openai_msgs.push(chat_message_to_api_message(msg))
            }
        }

        if let Some(system) = &self.system {
            openai_msgs.insert(
                0,
                OpenAIChatMessage {
                    role: "system",
                    content: Some(Left(vec![MessageContent {
                        message_type: Some("text"),
                        text: Some(system),
                        image_url: None,
                        tool_call_id: None,
                        tool_output: None,
                    }])),
                    tool_calls: None,
                    tool_call_id: None,
                },
            );
        }

        let body = OpenAIChatRequest {
            model: &self.model,
            messages: openai_msgs,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            // Streaming is forced on regardless of the client's `stream` flag.
            stream: true,
            top_p: self.top_p,
            top_k: self.top_k,
            tools: self.tools.clone(),
            tool_choice: self.tool_choice.clone(),
            reasoning_effort: self.reasoning_effort.clone(),
            response_format: None,
        };

        let url = self
            .base_url
            .join("chat/completions")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;

        let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body);

        if let Some(timeout) = self.timeout_seconds {
            request = request.timeout(std::time::Duration::from_secs(timeout));
        }

        let response = request.send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(LLMError::ResponseFormatError {
                message: format!("OpenAI API returned error status: {}", status),
                raw_response: error_text,
            });
        }

        // Hand the open response body to the shared SSE adapter; each chunk
        // is decoded by `parse_sse_chunk`.
        Ok(crate::chat::create_sse_stream(response, parse_sse_chunk))
    }
}
593
/// Converts a library-level `ChatMessage` into the wire representation.
///
/// Returns a `'static` message: borrowed string data (image URLs, tool-call
/// ids/names/arguments) is obtained via `Box::leak`, which deliberately leaks
/// each string for the remainder of the process to satisfy the `'static`
/// lifetime of the return type. NOTE(review): this leaks per converted
/// message — fine for short-lived processes, worth revisiting for
/// long-running services.
///
/// Panics (`unreachable!`/`unimplemented!`) on `Image`/`Pdf` messages, which
/// this provider path does not support.
fn chat_message_to_api_message(chat_msg: ChatMessage) -> OpenAIChatMessage<'static> {
    OpenAIChatMessage {
        role: match chat_msg.role {
            ChatRole::User => "user",
            ChatRole::Assistant => "assistant",
        },
        tool_call_id: None,
        content: match &chat_msg.message_type {
            MessageType::Text => Some(Right(chat_msg.content.clone())),
            // Raw image bytes are not supported here; callers use ImageURL.
            MessageType::Image(_) => unreachable!(),
            MessageType::Pdf(_) => unimplemented!(),
            MessageType::ImageURL(url) => {
                // Leak the URL to obtain a &'static str for the wire struct.
                let owned_url = url.clone();
                let url_str = Box::leak(owned_url.into_boxed_str());
                Some(Left(vec![MessageContent {
                    message_type: Some("image_url"),
                    text: None,
                    image_url: Some(ImageUrlContent { url: url_str }),
                    tool_output: None,
                    tool_call_id: None,
                }]))
            }
            // Tool use/results carry no `content`; they are encoded via
            // `tool_calls` below or as separate "tool" messages by the caller.
            MessageType::ToolUse(_) => None,
            MessageType::ToolResult(_) => None,
        },
        tool_calls: match &chat_msg.message_type {
            MessageType::ToolUse(calls) => {
                let owned_calls: Vec<OpenAIFunctionCall<'static>> = calls
                    .iter()
                    .map(|c| {
                        let owned_id = c.id.clone();
                        let owned_name = c.function.name.clone();
                        let owned_args = c.function.arguments.clone();

                        // Same leak-for-'static trick as the URL above,
                        // applied per tool call.
                        let id_str = Box::leak(owned_id.into_boxed_str());
                        let name_str = Box::leak(owned_name.into_boxed_str());
                        let args_str = Box::leak(owned_args.into_boxed_str());

                        OpenAIFunctionCall {
                            id: id_str,
                            content_type: "function",
                            function: OpenAIFunctionPayload {
                                name: name_str,
                                arguments: args_str,
                            },
                        }
                    })
                    .collect();
                Some(owned_calls)
            }
            _ => None,
        },
    }
}
656
657#[async_trait]
658impl CompletionProvider for OpenAI {
659 async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
663 Ok(CompletionResponse {
664 text: "OpenAI completion not implemented.".into(),
665 })
666 }
667}
668
669#[async_trait]
670impl SpeechToTextProvider for OpenAI {
671 async fn transcribe(&self, audio: Vec<u8>) -> Result<String, LLMError> {
682 let url = self
683 .base_url
684 .join("audio/transcriptions")
685 .map_err(|e| LLMError::HttpError(e.to_string()))?;
686
687 let part = reqwest::multipart::Part::bytes(audio).file_name("audio.m4a");
688 let form = reqwest::multipart::Form::new()
689 .text("model", self.model.clone())
690 .text("response_format", "text")
691 .part("file", part);
692
693 let mut req = self
694 .client
695 .post(url)
696 .bearer_auth(&self.api_key)
697 .multipart(form);
698
699 if let Some(t) = self.timeout_seconds {
700 req = req.timeout(Duration::from_secs(t));
701 }
702
703 let resp = req.send().await?;
704 let text = resp.text().await?;
705 let raw = text.clone();
706 Ok(raw)
707 }
708
709 async fn transcribe_file(&self, file_path: &str) -> Result<String, LLMError> {
720 let url = self
721 .base_url
722 .join("audio/transcriptions")
723 .map_err(|e| LLMError::HttpError(e.to_string()))?;
724
725 let form = reqwest::multipart::Form::new()
726 .text("model", self.model.clone())
727 .text("response_format", "text")
728 .file("file", file_path)
729 .await
730 .map_err(|e| LLMError::HttpError(e.to_string()))?;
731
732 let mut req = self
733 .client
734 .post(url)
735 .bearer_auth(&self.api_key)
736 .multipart(form);
737
738 if let Some(t) = self.timeout_seconds {
739 req = req.timeout(Duration::from_secs(t));
740 }
741
742 let resp = req.send().await?;
743 let text = resp.text().await?;
744 let raw = text.clone();
745 Ok(raw)
746 }
747}
748
749#[cfg(feature = "openai")]
750#[async_trait]
751impl EmbeddingProvider for OpenAI {
752 async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
753 if self.api_key.is_empty() {
754 return Err(LLMError::AuthError("Missing OpenAI API key".into()));
755 }
756
757 let emb_format = self
758 .embedding_encoding_format
759 .clone()
760 .unwrap_or_else(|| "float".to_string());
761
762 let body = OpenAIEmbeddingRequest {
763 model: self.model.clone(),
764 input,
765 encoding_format: Some(emb_format),
766 dimensions: self.embedding_dimensions,
767 };
768
769 let url = self
770 .base_url
771 .join("embeddings")
772 .map_err(|e| LLMError::HttpError(e.to_string()))?;
773
774 let resp = self
775 .client
776 .post(url)
777 .bearer_auth(&self.api_key)
778 .json(&body)
779 .send()
780 .await?
781 .error_for_status()?;
782
783 let json_resp: OpenAIEmbeddingResponse = resp.json().await?;
784
785 let embeddings = json_resp.data.into_iter().map(|d| d.embedding).collect();
786 Ok(embeddings)
787 }
788}
789
impl LLMProvider for OpenAI {
    /// Exposes the tools configured at construction time, if any.
    fn tools(&self) -> Option<&[Tool]> {
        self.tools.as_deref()
    }
}
795
796#[async_trait]
797impl TextToSpeechProvider for OpenAI {
798 async fn speech(&self, text: &str) -> Result<Vec<u8>, LLMError> {
806 if self.api_key.is_empty() {
807 return Err(LLMError::AuthError("Missing OpenAI API key".into()));
808 }
809
810 let url = self
811 .base_url
812 .join("audio/speech")
813 .map_err(|e| LLMError::HttpError(e.to_string()))?;
814
815 #[derive(Serialize)]
816 struct SpeechRequest {
817 model: String,
818 input: String,
819 voice: String,
820 }
821
822 let body = SpeechRequest {
823 model: self.model.clone(),
824 input: text.to_string(),
825 voice: self.voice.clone().unwrap_or("alloy".to_string()),
826 };
827
828 let mut req = self.client.post(url).bearer_auth(&self.api_key).json(&body);
829
830 if let Some(t) = self.timeout_seconds {
831 req = req.timeout(Duration::from_secs(t));
832 }
833
834 let resp = req.send().await?;
835
836 if !resp.status().is_success() {
837 let status = resp.status();
838 let error_text = resp.text().await?;
839 return Err(LLMError::ResponseFormatError {
840 message: format!("OpenAI API returned error status: {}", status),
841 raw_response: error_text,
842 });
843 }
844
845 Ok(resp.bytes().await?.to_vec())
846 }
847}
848
849fn parse_sse_chunk(chunk: &str) -> Result<Option<String>, LLMError> {
861 for line in chunk.lines() {
862 let line = line.trim();
863
864 if line.starts_with("data: ") {
865 let data = &line[6..];
866
867 if data == "[DONE]" {
868 return Ok(None);
869 }
870
871 match serde_json::from_str::<OpenAIChatStreamResponse>(data) {
872 Ok(response) => {
873 if let Some(choice) = response.choices.first() {
874 if let Some(content) = &choice.delta.content {
875 return Ok(Some(content.clone()));
876 }
877 }
878 return Ok(None);
879 }
880 Err(_) => continue,
881 }
882 }
883 }
884
885 Ok(None)
886}