1use crate::{
6 builder::LLMBackend,
7 chat::Tool,
8 chat::{ChatMessage, ChatProvider, ChatRole, MessageType, StructuredOutputFormat},
9 completion::{CompletionProvider, CompletionRequest, CompletionResponse},
10 embedding::EmbeddingProvider,
11 error::LLMError,
12 models::{ModelListRawEntry, ModelListRequest, ModelListResponse, ModelsProvider},
13 LLMProvider,
14};
15use crate::{
16 builder::LLMBuilder,
17 chat::{ChatResponse, ToolChoice},
18 FunctionCall, ToolCall,
19};
20use async_trait::async_trait;
21use chrono::{DateTime, Utc};
22use either::*;
23use futures::stream::Stream;
24use reqwest::{Client, Url};
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::sync::Arc;
28
/// OpenAI API provider: per-request configuration plus the shared HTTP client.
///
/// Field values are forwarded directly into request payloads built by the
/// `ChatProvider` / `EmbeddingProvider` implementations below.
pub struct OpenAI {
    /// Bearer token sent with every request.
    pub api_key: String,
    /// API root; endpoint paths such as `chat/completions` are joined onto it.
    pub base_url: Url,
    /// Model identifier, e.g. `gpt-3.5-turbo`.
    pub model: String,
    /// Maximum number of tokens to generate, when set.
    pub max_tokens: Option<u32>,
    /// Sampling temperature, when set.
    pub temperature: Option<f32>,
    /// System prompt inserted as the first message of every conversation.
    pub system: Option<String>,
    /// Request timeout in seconds; applied both to the client and per request.
    pub timeout_seconds: Option<u64>,
    /// Default value of the `stream` flag for non-streaming chat calls.
    pub stream: Option<bool>,
    /// Nucleus-sampling parameter.
    pub top_p: Option<f32>,
    /// Top-k sampling parameter.
    /// NOTE(review): not part of the official OpenAI chat API — presumably
    /// for OpenAI-compatible backends; confirm.
    pub top_k: Option<u32>,
    /// Tools made available to the model by default.
    pub tools: Option<Vec<Tool>>,
    /// Tool-choice policy; only sent when tools are present.
    pub tool_choice: Option<ToolChoice>,
    /// Embedding encoding format; `embed` defaults it to "float".
    pub embedding_encoding_format: Option<String>,
    /// Requested embedding dimensionality.
    pub embedding_dimensions: Option<u32>,
    /// Reasoning-effort hint, forwarded verbatim in chat requests.
    pub reasoning_effort: Option<String>,
    /// Structured-output (JSON schema) configuration.
    pub json_schema: Option<StructuredOutputFormat>,
    /// Voice setting; not referenced by the chat paths visible in this file.
    pub voice: Option<String>,
    /// When true, a `web_search_options` block is added to chat requests.
    pub enable_web_search: Option<bool>,
    /// Web-search context size, forwarded verbatim.
    pub web_search_context_size: Option<String>,
    /// User-location type; only "exact"/"approximate" are forwarded.
    pub web_search_user_location_type: Option<String>,
    /// Approximate-location country.
    pub web_search_user_location_approximate_country: Option<String>,
    /// Approximate-location city.
    pub web_search_user_location_approximate_city: Option<String>,
    /// Approximate-location region.
    pub web_search_user_location_approximate_region: Option<String>,
    // Shared reqwest client, built once in `new` with the optional timeout.
    client: Client,
}
60
/// One entry in the OpenAI `messages` array.
#[derive(Serialize, Debug)]
struct OpenAIChatMessage<'a> {
    /// Sender role: "system", "user", "assistant" or "tool".
    #[allow(dead_code)]
    role: &'a str,
    /// Either a list of structured parts (Left) or a plain string (Right);
    /// serialized untagged so both wire shapes are accepted by the API.
    #[serde(
        skip_serializing_if = "Option::is_none",
        with = "either::serde_untagged_optional"
    )]
    content: Option<Either<Vec<MessageContent<'a>>, String>>,
    /// Tool invocations attached to an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIFunctionCall<'a>>>,
    /// For role "tool": id of the call this message responds to.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

/// Function name plus raw JSON argument string inside a tool call.
#[derive(Serialize, Debug)]
struct OpenAIFunctionPayload<'a> {
    name: &'a str,
    arguments: &'a str,
}

/// A single tool/function call as OpenAI serializes it.
#[derive(Serialize, Debug)]
struct OpenAIFunctionCall<'a> {
    id: &'a str,
    /// Always set to "function" where this struct is constructed below.
    #[serde(rename = "type")]
    content_type: &'a str,
    function: OpenAIFunctionPayload<'a>,
}

/// One structured content part: text, image URL, or tool output.
#[derive(Serialize, Debug)]
struct MessageContent<'a> {
    /// Part discriminator, e.g. "text" or "image_url".
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    message_type: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    image_url: Option<ImageUrlContent<'a>>,
    #[serde(skip_serializing_if = "Option::is_none", rename = "tool_call_id")]
    tool_call_id: Option<&'a str>,
    /// Tool result payload, serialized under the key "content".
    #[serde(skip_serializing_if = "Option::is_none", rename = "content")]
    tool_output: Option<&'a str>,
}

/// Wrapper for an image-URL content part.
#[derive(Serialize, Debug)]
struct ImageUrlContent<'a> {
    url: &'a str,
}
110
/// Request body for the `embeddings` endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest {
    model: String,
    input: Vec<String>,
    /// Encoding format, e.g. "float"; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
    /// Requested vector dimensionality; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}
120
/// Request body for the `chat/completions` endpoint.
///
/// Optional fields are omitted from the JSON entirely when `None`, so only
/// explicitly configured parameters reach the API.
#[derive(Serialize, Debug)]
struct OpenAIChatRequest<'a> {
    model: &'a str,
    messages: Vec<OpenAIChatMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Always serialized; `true` selects the SSE streaming response.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// NOTE(review): `top_k` is not in the official OpenAI API — presumably
    /// honored only by compatible backends; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<OpenAIResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<OpenAIWebSearchOptions>,
}
146
147impl std::fmt::Display for ToolCall {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 write!(
150 f,
151 "{{\n \"id\": \"{}\",\n \"type\": \"{}\",\n \"function\": {}\n}}",
152 self.id, self.call_type, self.function
153 )
154 }
155}
156
157impl std::fmt::Display for FunctionCall {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 write!(
160 f,
161 "{{\n \"name\": \"{}\",\n \"arguments\": {}\n}}",
162 self.name, self.arguments
163 )
164 }
165}
166
/// Response body of `chat/completions`.
#[derive(Deserialize, Debug)]
struct OpenAIChatResponse {
    choices: Vec<OpenAIChatChoice>,
}

/// One generated completion choice.
#[derive(Deserialize, Debug)]
struct OpenAIChatChoice {
    message: OpenAIChatMsg,
}

/// The assistant message inside a choice.
#[derive(Deserialize, Debug)]
struct OpenAIChatMsg {
    #[allow(dead_code)]
    role: String,
    /// Text content; absent for pure tool-call responses.
    content: Option<String>,
    /// Tool calls requested by the model, if any.
    tool_calls: Option<Vec<ToolCall>>,
}

/// One embedding vector in an embeddings response.
#[derive(Deserialize, Debug)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
}
/// Response body of the `embeddings` endpoint.
#[derive(Deserialize, Debug)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
}
196
/// One decoded SSE `data:` payload of a streaming chat response.
#[derive(Deserialize, Debug)]
struct OpenAIChatStreamResponse {
    choices: Vec<OpenAIChatStreamChoice>,
}

/// A streaming choice carrying an incremental delta.
#[derive(Deserialize, Debug)]
struct OpenAIChatStreamChoice {
    delta: OpenAIChatStreamDelta,
}

/// Incremental text fragment; `None` for role-only / keep-alive deltas.
#[derive(Deserialize, Debug)]
struct OpenAIChatStreamDelta {
    content: Option<String>,
}
214
/// Allowed `response_format.type` values.
#[derive(Deserialize, Debug, Serialize)]
enum OpenAIResponseType {
    #[serde(rename = "text")]
    Text,
    #[serde(rename = "json_schema")]
    JsonSchema,
    #[serde(rename = "json_object")]
    JsonObject,
}

/// The `response_format` object of a chat request.
#[derive(Deserialize, Debug, Serialize)]
struct OpenAIResponseFormat {
    #[serde(rename = "type")]
    response_type: OpenAIResponseType,
    /// Schema payload; only meaningful with `response_type: JsonSchema`.
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<StructuredOutputFormat>,
}

/// The `web_search_options` block of a chat request.
#[derive(Deserialize, Debug, Serialize)]
struct OpenAIWebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    user_location: Option<UserLocation>,
    #[serde(skip_serializing_if = "Option::is_none")]
    search_context_size: Option<String>,
}

/// User location hint for web search.
#[derive(Deserialize, Debug, Serialize)]
struct UserLocation {
    /// "exact" or "approximate" — the only values forwarded by this provider.
    #[serde(rename = "type")]
    location_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    approximate: Option<ApproximateLocation>,
}

/// Coarse location; unset parts are sent as empty strings.
#[derive(Deserialize, Debug, Serialize)]
struct ApproximateLocation {
    country: String,
    city: String,
    region: String,
}
258
259impl From<StructuredOutputFormat> for OpenAIResponseFormat {
260 fn from(structured_response_format: StructuredOutputFormat) -> Self {
262 match structured_response_format.schema {
265 None => OpenAIResponseFormat {
266 response_type: OpenAIResponseType::JsonSchema,
267 json_schema: Some(structured_response_format),
268 },
269 Some(mut schema) => {
270 schema = if schema.get("additionalProperties").is_none() {
273 schema["additionalProperties"] = serde_json::json!(false);
274 schema
275 } else {
276 schema
277 };
278
279 OpenAIResponseFormat {
280 response_type: OpenAIResponseType::JsonSchema,
281 json_schema: Some(StructuredOutputFormat {
282 name: structured_response_format.name,
283 description: structured_response_format.description,
284 schema: Some(schema),
285 strict: structured_response_format.strict,
286 }),
287 }
288 }
289 }
290 }
291}
292
293impl ChatResponse for OpenAIChatResponse {
294 fn text(&self) -> Option<String> {
295 self.choices.first().and_then(|c| c.message.content.clone())
296 }
297
298 fn tool_calls(&self) -> Option<Vec<ToolCall>> {
299 self.choices
300 .first()
301 .and_then(|c| c.message.tool_calls.clone())
302 }
303}
304
305impl std::fmt::Display for OpenAIChatResponse {
306 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307 match (
308 &self.choices.first().unwrap().message.content,
309 &self.choices.first().unwrap().message.tool_calls,
310 ) {
311 (Some(content), Some(tool_calls)) => {
312 for tool_call in tool_calls {
313 write!(f, "{}", tool_call)?;
314 }
315 write!(f, "{}", content)
316 }
317 (Some(content), None) => write!(f, "{}", content),
318 (None, Some(tool_calls)) => {
319 for tool_call in tool_calls {
320 write!(f, "{}", tool_call)?;
321 }
322 Ok(())
323 }
324 (None, None) => write!(f, ""),
325 }
326 }
327}
328
329impl OpenAI {
330 #[allow(clippy::too_many_arguments)]
350 pub fn new(
351 api_key: impl Into<String>,
352 base_url: Option<String>,
353 model: Option<String>,
354 max_tokens: Option<u32>,
355 temperature: Option<f32>,
356 timeout_seconds: Option<u64>,
357 system: Option<String>,
358 stream: Option<bool>,
359 top_p: Option<f32>,
360 top_k: Option<u32>,
361 embedding_encoding_format: Option<String>,
362 embedding_dimensions: Option<u32>,
363 tools: Option<Vec<Tool>>,
364 tool_choice: Option<ToolChoice>,
365 reasoning_effort: Option<String>,
366 json_schema: Option<StructuredOutputFormat>,
367 voice: Option<String>,
368 enable_web_search: Option<bool>,
369 web_search_context_size: Option<String>,
370 web_search_user_location_type: Option<String>,
371 web_search_user_location_approximate_country: Option<String>,
372 web_search_user_location_approximate_city: Option<String>,
373 web_search_user_location_approximate_region: Option<String>,
374 ) -> Self {
375 let mut builder = Client::builder();
376 if let Some(sec) = timeout_seconds {
377 builder = builder.timeout(std::time::Duration::from_secs(sec));
378 }
379 Self {
380 api_key: api_key.into(),
381 base_url: Url::parse(
382 &base_url.unwrap_or_else(|| "https://api.openai.com/v1/".to_owned()),
383 )
384 .expect("Failed to prase base Url"),
385 model: model.unwrap_or("gpt-3.5-turbo".to_string()),
386 max_tokens,
387 temperature,
388 system,
389 timeout_seconds,
390 stream,
391 top_p,
392 top_k,
393 tools,
394 tool_choice,
395 embedding_encoding_format,
396 embedding_dimensions,
397 client: builder.build().expect("Failed to build reqwest Client"),
398 reasoning_effort,
399 json_schema,
400 voice,
401 enable_web_search,
402 web_search_context_size,
403 web_search_user_location_type,
404 web_search_user_location_approximate_country,
405 web_search_user_location_approximate_city,
406 web_search_user_location_approximate_region,
407 }
408 }
409}
410
#[async_trait]
impl ChatProvider for OpenAI {
    /// Sends a chat completion request, optionally exposing tools.
    ///
    /// Call-site `tools` take precedence over the tools configured on the
    /// provider. On HTTP or decode failure, returns a
    /// `ResponseFormatError` carrying the raw response body.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: Option<&[Tool]>,
    ) -> Result<Box<dyn ChatResponse>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing OpenAI API key".to_string()));
        }

        let messages = messages.to_vec();

        let mut openai_msgs: Vec<OpenAIChatMessage> = vec![];

        // Each tool result fans out into its own "tool"-role message; all
        // other message types convert 1:1.
        for msg in messages {
            if let MessageType::ToolResult(ref results) = msg.message_type {
                for result in results {
                    openai_msgs.push(
                        OpenAIChatMessage {
                            role: "tool",
                            tool_call_id: Some(result.id.clone()),
                            tool_calls: None,
                            content: Some(Right(result.function.arguments.clone())),
                        },
                    );
                }
            } else {
                openai_msgs.push(chat_message_to_api_message(msg))
            }
        }

        // The configured system prompt always goes first.
        if let Some(system) = &self.system {
            openai_msgs.insert(
                0,
                OpenAIChatMessage {
                    role: "system",
                    content: Some(Left(vec![MessageContent {
                        message_type: Some("text"),
                        text: Some(system),
                        image_url: None,
                        tool_call_id: None,
                        tool_output: None,
                    }])),
                    tool_calls: None,
                    tool_call_id: None,
                },
            );
        }

        let response_format: Option<OpenAIResponseFormat> =
            self.json_schema.clone().map(|s| s.into());

        // Explicit call-site tools win over provider-level configuration.
        let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone());

        // tool_choice is only meaningful when tools are actually sent.
        let request_tool_choice = if request_tools.is_some() {
            self.tool_choice.clone()
        } else {
            None
        };

        // Assemble web-search options only when the feature is enabled.
        let web_search_options = if self.enable_web_search.unwrap_or(false) {
            // Only "exact"/"approximate" are valid location types; anything
            // else is silently dropped.
            let loc_type_opt = self
                .web_search_user_location_type
                .as_ref()
                .filter(|t| matches!(t.as_str(), "exact" | "approximate"));

            let country = self.web_search_user_location_approximate_country.as_ref();
            let city = self.web_search_user_location_approximate_city.as_ref();
            let region = self.web_search_user_location_approximate_region.as_ref();

            // Any one approximate field present triggers the struct; the
            // missing parts default to "".
            let approximate = if [country, city, region].iter().any(|v| v.is_some()) {
                Some(ApproximateLocation {
                    country: country.cloned().unwrap_or_default(),
                    city: city.cloned().unwrap_or_default(),
                    region: region.cloned().unwrap_or_default(),
                })
            } else {
                None
            };

            let user_location = loc_type_opt.map(|loc_type| UserLocation {
                location_type: loc_type.clone(),
                approximate,
            });

            Some(OpenAIWebSearchOptions {
                search_context_size: self.web_search_context_size.clone(),
                user_location,
            })
        } else {
            None
        };

        let body = OpenAIChatRequest {
            model: &self.model,
            messages: openai_msgs,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            stream: self.stream.unwrap_or(false),
            top_p: self.top_p,
            top_k: self.top_k,
            tools: request_tools,
            tool_choice: request_tool_choice,
            reasoning_effort: self.reasoning_effort.clone(),
            response_format,
            web_search_options,
        };

        let url = self
            .base_url
            .join("chat/completions")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;

        let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body);

        // Serialize the payload for logging only when trace level is on, to
        // avoid the cost otherwise.
        if log::log_enabled!(log::Level::Trace) {
            if let Ok(json) = serde_json::to_string(&body) {
                log::trace!("OpenAI request payload: {}", json);
            }
        }

        // Per-request timeout on top of the client-level one set in `new`.
        if let Some(timeout) = self.timeout_seconds {
            request = request.timeout(std::time::Duration::from_secs(timeout));
        }

        let response = request.send().await?;

        log::debug!("OpenAI HTTP status: {}", response.status());

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(LLMError::ResponseFormatError {
                message: format!("OpenAI API returned error status: {}", status),
                raw_response: error_text,
            });
        }

        // Decode from raw text so the body can be surfaced in the error.
        let resp_text = response.text().await?;
        let json_resp: Result<OpenAIChatResponse, serde_json::Error> =
            serde_json::from_str(&resp_text);

        match json_resp {
            Ok(response) => Ok(Box::new(response)),
            Err(e) => Err(LLMError::ResponseFormatError {
                message: format!("Failed to decode OpenAI API response: {}", e),
                raw_response: resp_text,
            }),
        }
    }

    /// Convenience wrapper: chat without call-site tools.
    async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
        self.chat_with_tools(messages, None).await
    }

    /// Streams a chat completion as an SSE-backed stream of text fragments.
    ///
    /// Mirrors the message assembly of `chat_with_tools`, but always sets
    /// `stream: true` and omits `response_format` / web-search options.
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
    {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing OpenAI API key".to_string()));
        }

        let messages = messages.to_vec();
        let mut openai_msgs: Vec<OpenAIChatMessage> = vec![];

        // Same fan-out of tool results as in `chat_with_tools`.
        for msg in messages {
            if let MessageType::ToolResult(ref results) = msg.message_type {
                for result in results {
                    openai_msgs.push(OpenAIChatMessage {
                        role: "tool",
                        tool_call_id: Some(result.id.clone()),
                        tool_calls: None,
                        content: Some(Right(result.function.arguments.clone())),
                    });
                }
            } else {
                openai_msgs.push(chat_message_to_api_message(msg))
            }
        }

        if let Some(system) = &self.system {
            openai_msgs.insert(
                0,
                OpenAIChatMessage {
                    role: "system",
                    content: Some(Left(vec![MessageContent {
                        message_type: Some("text"),
                        text: Some(system),
                        image_url: None,
                        tool_call_id: None,
                        tool_output: None,
                    }])),
                    tool_calls: None,
                    tool_call_id: None,
                },
            );
        }

        let body = OpenAIChatRequest {
            model: &self.model,
            messages: openai_msgs,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            stream: true,
            top_p: self.top_p,
            top_k: self.top_k,
            tools: self.tools.clone(),
            tool_choice: self.tool_choice.clone(),
            reasoning_effort: self.reasoning_effort.clone(),
            response_format: None,
            web_search_options: None,
        };

        let url = self
            .base_url
            .join("chat/completions")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;

        let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body);

        if let Some(timeout) = self.timeout_seconds {
            request = request.timeout(std::time::Duration::from_secs(timeout));
        }

        let response = request.send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(LLMError::ResponseFormatError {
                message: format!("OpenAI API returned error status: {}", status),
                raw_response: error_text,
            });
        }

        // Hand the raw response to the shared SSE machinery; chunks are
        // decoded by `parse_sse_chunk` below.
        Ok(crate::chat::create_sse_stream(response, parse_sse_chunk))
    }
}
673
/// Converts a library `ChatMessage` into the wire-format message struct.
///
/// The returned message carries a `'static` lifetime; to satisfy it, image
/// URLs and tool-call fields are copied into `Box::leak`ed strings.
/// NOTE(review): this leaks a small allocation per converted message —
/// acceptable only for bounded message volume; confirm this trade-off.
fn chat_message_to_api_message(chat_msg: ChatMessage) -> OpenAIChatMessage<'static> {
    OpenAIChatMessage {
        role: match chat_msg.role {
            ChatRole::User => "user",
            ChatRole::Assistant => "assistant",
        },
        tool_call_id: None,
        content: match &chat_msg.message_type {
            MessageType::Text => Some(Right(chat_msg.content.clone())),
            // Raw image bytes are presumably converted to URLs before this
            // point — TODO confirm the invariant behind this `unreachable!`.
            MessageType::Image(_) => unreachable!(),
            MessageType::Pdf(_) => unimplemented!(),
            MessageType::ImageURL(url) => {
                // Leak the URL to obtain the `&'static str` the struct needs.
                let owned_url = url.clone();
                let url_str = Box::leak(owned_url.into_boxed_str());
                Some(Left(vec![MessageContent {
                    message_type: Some("image_url"),
                    text: None,
                    image_url: Some(ImageUrlContent { url: url_str }),
                    tool_output: None,
                    tool_call_id: None,
                }]))
            }
            // Tool use/results carry no `content`; they are expressed via
            // `tool_calls` below or handled by the callers.
            MessageType::ToolUse(_) => None,
            MessageType::ToolResult(_) => None,
        },
        tool_calls: match &chat_msg.message_type {
            MessageType::ToolUse(calls) => {
                let owned_calls: Vec<OpenAIFunctionCall<'static>> = calls
                    .iter()
                    .map(|c| {
                        let owned_id = c.id.clone();
                        let owned_name = c.function.name.clone();
                        let owned_args = c.function.arguments.clone();

                        // Same leak-for-'static trick for each call field.
                        let id_str = Box::leak(owned_id.into_boxed_str());
                        let name_str = Box::leak(owned_name.into_boxed_str());
                        let args_str = Box::leak(owned_args.into_boxed_str());

                        OpenAIFunctionCall {
                            id: id_str,
                            content_type: "function",
                            function: OpenAIFunctionPayload {
                                name: name_str,
                                arguments: args_str,
                            },
                        }
                    })
                    .collect();
                Some(owned_calls)
            }
            _ => None,
        },
    }
}
736
737#[async_trait]
738impl CompletionProvider for OpenAI {
739 async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
743 Ok(CompletionResponse {
744 text: "OpenAI completion not implemented.".into(),
745 })
746 }
747}
748
#[cfg(feature = "openai")]
#[async_trait]
impl EmbeddingProvider for OpenAI {
    /// Requests embedding vectors for `input` from the `embeddings`
    /// endpoint, one vector per input string.
    async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing OpenAI API key".into()));
        }

        // Fall back to "float" when no encoding format was configured.
        let encoding_format = self
            .embedding_encoding_format
            .clone()
            .unwrap_or_else(|| "float".to_string());

        let request_body = OpenAIEmbeddingRequest {
            model: self.model.clone(),
            input,
            encoding_format: Some(encoding_format),
            dimensions: self.embedding_dimensions,
        };

        let endpoint = self
            .base_url
            .join("embeddings")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;

        let response = self
            .client
            .post(endpoint)
            .bearer_auth(&self.api_key)
            .json(&request_body)
            .send()
            .await?
            .error_for_status()?;

        let parsed: OpenAIEmbeddingResponse = response.json().await?;

        Ok(parsed.data.into_iter().map(|entry| entry.embedding).collect())
    }
}
789
/// One entry of the `models` listing.
#[derive(Clone, Debug, Deserialize)]
pub struct OpenAIModelEntry {
    /// Model identifier, e.g. "gpt-4o".
    pub id: String,
    /// Creation time as a Unix timestamp, when the API provides one.
    pub created: Option<u64>,
    /// All remaining JSON fields, captured verbatim.
    #[serde(flatten)]
    pub extra: Value,
}
797
798impl ModelListRawEntry for OpenAIModelEntry {
799 fn get_id(&self) -> String {
800 self.id.clone()
801 }
802
803 fn get_created_at(&self) -> DateTime<Utc> {
804 self.created
805 .map(|t| chrono::DateTime::from_timestamp(t as i64, 0).unwrap_or_default())
806 .unwrap_or_default()
807 }
808
809 fn get_raw(&self) -> Value {
810 self.extra.clone()
811 }
812}
813
/// Response body of the `models` endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct OpenAIModelListResponse {
    pub data: Vec<OpenAIModelEntry>,
}
818
819impl ModelListResponse for OpenAIModelListResponse {
820 fn get_models(&self) -> Vec<String> {
821 self.data.iter().map(|e| e.id.clone()).collect()
822 }
823
824 fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
825 self.data
826 .iter()
827 .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
828 .collect()
829 }
830
831 fn get_backend(&self) -> LLMBackend {
832 LLMBackend::OpenAI
833 }
834}
835
836#[async_trait]
837impl ModelsProvider for OpenAI {
838 async fn list_models(
839 &self,
840 _request: Option<&ModelListRequest>,
841 ) -> Result<Box<dyn ModelListResponse>, LLMError> {
842 let url = self
843 .base_url
844 .join("models")
845 .map_err(|e| LLMError::HttpError(e.to_string()))?;
846
847 let resp = self
848 .client
849 .get(url)
850 .bearer_auth(&self.api_key)
851 .send()
852 .await?
853 .error_for_status()?;
854
855 let result = resp.json::<OpenAIModelListResponse>().await?;
856
857 Ok(Box::new(result))
858 }
859}
860
861impl LLMProvider for OpenAI {
862 fn tools(&self) -> Option<&[Tool]> {
863 self.tools.as_deref()
864 }
865}
866
867fn parse_sse_chunk(chunk: &str) -> Result<Option<String>, LLMError> {
879 let mut collected_content = String::new();
880
881 for line in chunk.lines() {
882 let line = line.trim();
883
884 if let Some(data) = line.strip_prefix("data: ") {
885 if data == "[DONE]" {
886 if collected_content.is_empty() {
887 return Ok(None);
888 } else {
889 return Ok(Some(collected_content));
890 }
891 }
892
893 match serde_json::from_str::<OpenAIChatStreamResponse>(data) {
894 Ok(response) => {
895 if let Some(choice) = response.choices.first() {
896 if let Some(content) = &choice.delta.content {
897 collected_content.push_str(content);
898 }
899 }
900 }
901 Err(_) => continue,
902 }
903 }
904 }
905
906 if collected_content.is_empty() {
907 Ok(None)
908 } else {
909 Ok(Some(collected_content))
910 }
911}
912
impl LLMBuilder<OpenAI> {
    /// Sets the voice option forwarded to `OpenAI::new`.
    pub fn voice(mut self, voice: impl Into<String>) -> Self {
        self.voice = Some(voice.into());
        self
    }

    /// Enables or disables the web-search request block.
    pub fn openai_enable_web_search(mut self, enable: bool) -> Self {
        self.openai_enable_web_search = Some(enable);
        self
    }

    /// Sets the web-search context size, forwarded verbatim.
    pub fn openai_web_search_context_size(mut self, context_size: impl Into<String>) -> Self {
        self.openai_web_search_context_size = Some(context_size.into());
        self
    }

    /// Sets the user-location type; only "exact"/"approximate" are
    /// forwarded to the API by the chat path.
    pub fn openai_web_search_user_location_type(
        mut self,
        location_type: impl Into<String>,
    ) -> Self {
        self.openai_web_search_user_location_type = Some(location_type.into());
        self
    }

    /// Sets the approximate-location country for web search.
    pub fn openai_web_search_user_location_approximate_country(
        mut self,
        country: impl Into<String>,
    ) -> Self {
        self.openai_web_search_user_location_approximate_country = Some(country.into());
        self
    }

    /// Sets the approximate-location city for web search.
    pub fn openai_web_search_user_location_approximate_city(
        mut self,
        city: impl Into<String>,
    ) -> Self {
        self.openai_web_search_user_location_approximate_city = Some(city.into());
        self
    }

    /// Sets the approximate-location region for web search.
    pub fn openai_web_search_user_location_approximate_region(
        mut self,
        region: impl Into<String>,
    ) -> Self {
        self.openai_web_search_user_location_approximate_region = Some(region.into());
        self
    }

    /// Builds the configured OpenAI provider.
    ///
    /// # Errors
    /// Returns `InvalidRequest` when no API key was provided, or whatever
    /// `validate_tool_config` reports for an inconsistent tool setup.
    pub fn build(self) -> Result<Arc<dyn LLMProvider>, LLMError> {
        let (tools, tool_choice) = self.validate_tool_config()?;
        let key = self.api_key.ok_or_else(|| {
            LLMError::InvalidRequest("No API key provided for OpenAI".to_string())
        })?;
        // NOTE(review): this positional call must stay in sync with the
        // 23-parameter signature of `OpenAI::new` — double-check ordering
        // whenever a parameter is added.
        let openai = OpenAI::new(
            key,
            self.base_url,
            self.model,
            self.max_tokens,
            self.temperature,
            self.timeout_seconds,
            self.system,
            self.stream,
            self.top_p,
            self.top_k,
            self.embedding_encoding_format,
            self.embedding_dimensions,
            tools,
            tool_choice,
            self.reasoning_effort,
            self.json_schema,
            self.voice,
            self.openai_enable_web_search,
            self.openai_web_search_context_size,
            self.openai_web_search_user_location_type,
            self.openai_web_search_user_location_approximate_country,
            self.openai_web_search_user_location_approximate_city,
            self.openai_web_search_user_location_approximate_region,
        );

        Ok(Arc::new(openai))
    }
}
1001}