1use async_openai::{
6 config::OpenAIConfig,
7 types::chat::{
8 ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
9 ChatCompletionNamedToolChoice, ChatCompletionRequestAssistantMessageArgs,
10 ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
11 ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessageArgs,
12 ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
13 ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
14 ChatCompletionStreamOptions, ChatCompletionTool, ChatCompletionToolChoiceOption,
15 ChatCompletionTools, CompletionUsage, CreateChatCompletionRequestArgs, FinishReason,
16 FunctionCall, FunctionName, FunctionObjectArgs, ImageDetail, ImageUrl, ToolChoiceOptions,
17 },
18 Client,
19};
20use async_trait::async_trait;
21use futures::StreamExt;
22use std::collections::HashMap;
23use tracing::debug;
24
25use crate::error::{LlmError, Result};
26use crate::traits::FunctionCall as TraitFunctionCall;
27use crate::traits::ImageData;
28use crate::traits::ToolCall;
29use crate::traits::{
30 ChatMessage, ChatRole, CompletionOptions, EmbeddingProvider, LLMProvider, LLMResponse,
31 StreamChunk, StreamUsage, ToolChoice, ToolDefinition,
32};
33
/// OpenAI (and OpenAI-compatible) provider for chat completions and embeddings.
///
/// Chat requests go through the typed `async_openai` client; embeddings are sent
/// as a plain HTTP request built from `raw_api_key` / `raw_base_url`.
pub struct OpenAIProvider {
    /// Typed `async_openai` client used for (streaming) chat completions.
    client: Client<OpenAIConfig>,
    /// Chat model name sent with every chat request.
    model: String,
    /// Model name sent to the `/embeddings` endpoint.
    embedding_model: String,
    /// Context window reported by `LLMProvider::max_context_length`; re-derived
    /// from `model` whenever `with_model` is called.
    max_context_length: usize,
    /// Vector size reported by `EmbeddingProvider::dimension`; re-derived from
    /// `embedding_model` whenever `with_embedding_model` is called.
    embedding_dimension: usize,
    /// API key for the hand-built embeddings request; when empty, no
    /// `Authorization` header is sent.
    raw_api_key: String,
    /// Base URL for the embeddings request; when empty, the default
    /// `https://api.openai.com/v1` is used.
    raw_base_url: String,
}
52
53impl OpenAIProvider {
54 pub fn new(api_key: impl Into<String>) -> Self {
56 let key = api_key.into();
57 let config = OpenAIConfig::new().with_api_key(&key);
58 Self::with_config_and_key(config, key, String::new())
59 }
60
61 pub fn with_config(config: OpenAIConfig) -> Self {
64 Self::with_config_and_key(config, String::new(), String::new())
67 }
68
69 fn with_config_and_key(config: OpenAIConfig, api_key: String, base_url: String) -> Self {
71 Self {
72 client: Client::with_config(config),
73 model: "gpt-5-mini".to_string(),
74 embedding_model: "text-embedding-3-small".to_string(),
75 max_context_length: 200000,
76 embedding_dimension: 1536,
77 raw_api_key: api_key,
78 raw_base_url: base_url,
79 }
80 }
81
82 pub fn compatible(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
84 let key = api_key.into();
85 let url = base_url.into();
86 let config = OpenAIConfig::new().with_api_key(&key).with_api_base(&url);
87 Self::with_config_and_key(config, key, url)
88 }
89
90 pub fn from_env() -> crate::error::Result<Self> {
102 let _ = dotenvy::dotenv();
103 let api_key = std::env::var("OPENAI_API_KEY")
104 .map_err(|_| crate::error::LlmError::ConfigError("OPENAI_API_KEY not set".into()))?;
105 let base_url = std::env::var("OPENAI_BASE_URL").unwrap_or_default();
106 let mut config = OpenAIConfig::new().with_api_key(&api_key);
107 if !base_url.is_empty() {
108 config = config.with_api_base(&base_url);
109 }
110 let mut provider = Self::with_config_and_key(config, api_key, base_url);
111 if let Ok(model) = std::env::var("OPENAI_MODEL") {
112 provider = provider.with_model(model);
113 }
114 Ok(provider)
115 }
116
117 pub fn with_model(mut self, model: impl Into<String>) -> Self {
119 self.model = model.into();
120 self.max_context_length = Self::context_length_for_model(&self.model);
121 self
122 }
123
124 pub fn with_embedding_model(mut self, model: impl Into<String>) -> Self {
126 self.embedding_model = model.into();
127 self.embedding_dimension = Self::dimension_for_model(&self.embedding_model);
128 self
129 }
130
131 fn context_length_for_model(model: &str) -> usize {
133 match model {
134 m if m.contains("gpt-5.2") || m.contains("gpt-5.1") => 200000,
136 m if m.contains("gpt-5-nano") => 128000,
137 m if m.contains("gpt-5-mini") || m.contains("gpt-5") => 200000,
138
139 m if m.contains("gpt-4.1") => 128000,
141
142 m if m.contains("o4") || m.contains("o3") => 200000,
144 m if m.contains("o1") => 200000,
145
146 m if m.contains("gpt-4o") => 128000,
148 m if m.contains("gpt-4-turbo") => 128000,
149 m if m.contains("gpt-4-32k") => 32768,
150 m if m.contains("gpt-4") => 8192,
151
152 m if m.contains("gpt-3.5-turbo-16k") => 16384,
154 m if m.contains("gpt-3.5") => 4096,
155
156 m if m.contains("codex") => 200000,
158
159 m if m.contains("gpt-realtime") || m.contains("gpt-audio") => 128000,
161
162 _ => 128000, }
164 }
165
166 fn dimension_for_model(model: &str) -> usize {
168 match model {
169 m if m.contains("text-embedding-3-large") => 3072,
170 m if m.contains("text-embedding-3-small") => 1536,
171 m if m.contains("text-embedding-ada") => 1536,
172 _ => 1536, }
174 }
175
176 fn extract_usage(
177 usage: Option<CompletionUsage>,
178 ) -> (usize, usize, usize, Option<usize>, Option<usize>) {
179 let usage = usage.unwrap_or(CompletionUsage {
180 prompt_tokens: 0,
181 completion_tokens: 0,
182 total_tokens: 0,
183 prompt_tokens_details: None,
184 completion_tokens_details: None,
185 });
186
187 let cache_hit_tokens = usage
188 .prompt_tokens_details
189 .as_ref()
190 .and_then(|d| d.cached_tokens)
191 .map(|t| t as usize);
192 let thinking_tokens = usage
193 .completion_tokens_details
194 .as_ref()
195 .and_then(|d| d.reasoning_tokens)
196 .map(|t| t as usize);
197
198 (
199 usage.prompt_tokens as usize,
200 usage.completion_tokens as usize,
201 usage.total_tokens as usize,
202 cache_hit_tokens,
203 thinking_tokens,
204 )
205 }
206
207 fn extract_stream_usage(usage: Option<CompletionUsage>) -> Option<StreamUsage> {
208 let (prompt_tokens, completion_tokens, _total_tokens, cache_hit_tokens, thinking_tokens) =
209 Self::extract_usage(usage);
210
211 if prompt_tokens == 0
212 && completion_tokens == 0
213 && cache_hit_tokens.is_none()
214 && thinking_tokens.is_none()
215 {
216 return None;
217 }
218
219 let mut usage = StreamUsage::new(prompt_tokens, completion_tokens);
220 if let Some(tokens) = cache_hit_tokens {
221 usage = usage.with_cache_hit_tokens(tokens);
222 }
223 if let Some(tokens) = thinking_tokens {
224 usage = usage.with_thinking_tokens(tokens);
225 }
226 Some(usage)
227 }
228
229 fn convert_messages(messages: &[ChatMessage]) -> Result<Vec<ChatCompletionRequestMessage>> {
241 messages
242 .iter()
243 .map(|msg| {
244 match msg.role {
245 ChatRole::System => ChatCompletionRequestSystemMessageArgs::default()
246 .content(msg.content.as_str())
247 .build()
248 .map(Into::into)
249 .map_err(|e| LlmError::InvalidRequest(e.to_string())),
250
251 ChatRole::User => {
252 let content = Self::build_user_content(msg);
253 ChatCompletionRequestUserMessageArgs::default()
254 .content(content)
255 .build()
256 .map(Into::into)
257 .map_err(|e| LlmError::InvalidRequest(e.to_string()))
258 }
259
260 ChatRole::Assistant => {
261 let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
262 if !msg.content.is_empty() {
264 builder.content(msg.content.clone());
265 }
266 if let Some(ref tool_calls) = msg.tool_calls {
269 let openai_calls: Vec<ChatCompletionMessageToolCalls> = tool_calls
270 .iter()
271 .map(|tc| {
272 ChatCompletionMessageToolCalls::Function(
273 ChatCompletionMessageToolCall {
274 id: tc.id.clone(),
275 function: FunctionCall {
276 name: tc.function.name.clone(),
277 arguments: tc.function.arguments.clone(),
278 },
279 },
280 )
281 })
282 .collect();
283 builder.tool_calls(openai_calls);
284 }
285 builder
286 .build()
287 .map(Into::into)
288 .map_err(|e| LlmError::InvalidRequest(e.to_string()))
289 }
290
291 ChatRole::Tool => {
292 let tool_call_id = msg.tool_call_id.clone().ok_or_else(|| {
296 LlmError::InvalidRequest(
297 "Tool message missing required tool_call_id".into(),
298 )
299 })?;
300 ChatCompletionRequestToolMessageArgs::default()
301 .content(msg.content.clone())
302 .tool_call_id(tool_call_id)
303 .build()
304 .map(Into::into)
305 .map_err(|e| LlmError::InvalidRequest(e.to_string()))
306 }
307
308 ChatRole::Function => {
309 ChatCompletionRequestUserMessageArgs::default()
312 .content(msg.content.as_str())
313 .build()
314 .map(Into::into)
315 .map_err(|e| LlmError::InvalidRequest(e.to_string()))
316 }
317 }
318 })
319 .collect()
320 }
321
322 fn build_user_content(msg: &ChatMessage) -> ChatCompletionRequestUserMessageContent {
328 if msg.has_images() {
329 let mut parts: Vec<ChatCompletionRequestUserMessageContentPart> = Vec::new();
330
331 if !msg.content.is_empty() {
333 parts.push(ChatCompletionRequestUserMessageContentPart::Text(
334 ChatCompletionRequestMessageContentPartText {
335 text: msg.content.clone(),
336 },
337 ));
338 }
339
340 if let Some(ref images) = msg.images {
342 for img in images {
343 let detail = Self::parse_image_detail(img);
344 parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
345 ChatCompletionRequestMessageContentPartImage {
346 image_url: ImageUrl {
347 url: img.to_api_url(),
349 detail,
350 },
351 },
352 ));
353 }
354 }
355
356 ChatCompletionRequestUserMessageContent::Array(parts)
357 } else {
358 ChatCompletionRequestUserMessageContent::Text(msg.content.clone())
359 }
360 }
361
362 fn parse_image_detail(img: &ImageData) -> Option<ImageDetail> {
364 match img.detail.as_deref() {
365 Some("low") => Some(ImageDetail::Low),
366 Some("high") => Some(ImageDetail::High),
367 Some("auto") => Some(ImageDetail::Auto),
368 _ => None,
369 }
370 }
371}
372
373#[async_trait]
374impl LLMProvider for OpenAIProvider {
375 fn name(&self) -> &str {
376 "openai"
377 }
378
379 fn model(&self) -> &str {
380 &self.model
381 }
382
383 fn max_context_length(&self) -> usize {
384 self.max_context_length
385 }
386
387 async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
388 self.complete_with_options(prompt, &CompletionOptions::default())
389 .await
390 }
391
392 async fn complete_with_options(
393 &self,
394 prompt: &str,
395 options: &CompletionOptions,
396 ) -> Result<LLMResponse> {
397 let mut messages = Vec::new();
398
399 if let Some(system) = &options.system_prompt {
400 messages.push(ChatMessage::system(system));
401 }
402 messages.push(ChatMessage::user(prompt));
403
404 self.chat(&messages, Some(options)).await
405 }
406
407 async fn chat(
408 &self,
409 messages: &[ChatMessage],
410 options: Option<&CompletionOptions>,
411 ) -> Result<LLMResponse> {
412 let openai_messages = Self::convert_messages(messages)?;
413 let options = options.cloned().unwrap_or_default();
414
415 let mut request_builder = CreateChatCompletionRequestArgs::default();
416 request_builder.model(&self.model).messages(openai_messages);
417
418 if let Some(max_tokens) = options.max_tokens {
419 request_builder.max_completion_tokens(max_tokens as u32);
423 }
424
425 if let Some(temp) = options.temperature {
426 if (temp - 1.0_f32).abs() > f32::EPSILON {
429 request_builder.temperature(temp);
430 }
431 }
432
433 if let Some(top_p) = options.top_p {
434 request_builder.top_p(top_p);
435 }
436
437 if let Some(stop) = options.stop {
438 request_builder.stop(stop);
439 }
440
441 if let Some(freq_penalty) = options.frequency_penalty {
442 request_builder.frequency_penalty(freq_penalty);
443 }
444
445 if let Some(pres_penalty) = options.presence_penalty {
446 request_builder.presence_penalty(pres_penalty);
447 }
448
449 let request = request_builder
450 .build()
451 .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
452
453 let response = self.client.chat().create(request).await?;
454
455 debug!(
457 "OpenAI response - usage: {:?}, model: {}",
458 response.usage, response.model
459 );
460
461 let choice = response
462 .choices
463 .first()
464 .ok_or_else(|| LlmError::ApiError("No choices in response".to_string()))?;
465
466 if let Some(FinishReason::ContentFilter) = choice.finish_reason {
468 return Err(LlmError::ApiError(
469 "Response blocked by OpenAI content filter (finish_reason=content_filter)".into(),
470 ));
471 }
472
473 let content = choice.message.content.clone().unwrap_or_default();
474
475 let (prompt_tokens, completion_tokens, total_tokens, cache_hit_tokens, thinking_tokens) =
476 Self::extract_usage(response.usage.clone());
477
478 debug!(
480 "OpenAI token usage - prompt: {}, completion: {}, total: {}, cached: {:?}, reasoning: {:?}",
481 prompt_tokens, completion_tokens, total_tokens,
482 cache_hit_tokens, thinking_tokens
483 );
484
485 let mut metadata = HashMap::new();
486 metadata.insert("response_id".to_string(), serde_json::json!(response.id));
487
488 Ok(LLMResponse {
489 content,
490 prompt_tokens,
491 completion_tokens,
492 total_tokens,
493 model: response.model,
494 finish_reason: choice.finish_reason.map(|r| format!("{:?}", r)),
495 tool_calls: Vec::new(),
496 metadata,
497 cache_hit_tokens,
498 cache_write_tokens: None,
499 thinking_tokens,
500 thinking_content: None,
501 })
502 }
503
504 async fn chat_with_tools(
512 &self,
513 messages: &[ChatMessage],
514 tools: &[ToolDefinition],
515 tool_choice: Option<ToolChoice>,
516 options: Option<&CompletionOptions>,
517 ) -> Result<LLMResponse> {
518 let openai_messages = Self::convert_messages(messages)?;
519 let opts = options.cloned().unwrap_or_default();
520
521 let openai_tools: Vec<ChatCompletionTools> = tools
522 .iter()
523 .map(|t| {
524 ChatCompletionTools::Function(ChatCompletionTool {
525 function: FunctionObjectArgs::default()
526 .name(&t.function.name)
527 .description(&t.function.description)
528 .parameters(t.function.parameters.clone())
529 .build()
530 .expect("Invalid tool definition"),
531 })
532 })
533 .collect();
534
535 let mut request_builder = CreateChatCompletionRequestArgs::default();
536 request_builder
537 .model(&self.model)
538 .messages(openai_messages)
539 .tools(openai_tools);
540
541 if let Some(tc) = tool_choice {
542 match tc {
543 ToolChoice::Auto(_) => {
544 request_builder.tool_choice(ChatCompletionToolChoiceOption::Mode(
545 ToolChoiceOptions::Auto,
546 ));
547 }
548 ToolChoice::Required(_) => {
549 request_builder.tool_choice(ChatCompletionToolChoiceOption::Mode(
550 ToolChoiceOptions::Required,
551 ));
552 }
553 ToolChoice::Function { ref function, .. } => {
554 request_builder.tool_choice(ChatCompletionToolChoiceOption::Function(
555 ChatCompletionNamedToolChoice {
556 function: FunctionName {
557 name: function.name.clone(),
558 },
559 },
560 ));
561 }
562 }
563 }
564
565 if let Some(max_tokens) = opts.max_tokens {
566 request_builder.max_completion_tokens(max_tokens as u32);
567 }
568
569 if let Some(temp) = opts.temperature {
570 if (temp - 1.0_f32).abs() > f32::EPSILON {
572 request_builder.temperature(temp);
573 }
574 }
575
576 let request = request_builder
577 .build()
578 .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
579
580 let response = self.client.chat().create(request).await?;
581
582 debug!(
583 "OpenAI chat_with_tools response id={} model={}",
584 response.id, response.model
585 );
586
587 let choice = response
588 .choices
589 .first()
590 .ok_or_else(|| LlmError::ApiError("No choices in response".to_string()))?;
591
592 if let Some(FinishReason::ContentFilter) = choice.finish_reason {
593 return Err(LlmError::ApiError(
594 "Response blocked by OpenAI content filter (finish_reason=content_filter)".into(),
595 ));
596 }
597
598 let tool_calls: Vec<ToolCall> = choice
600 .message
601 .tool_calls
602 .as_deref()
603 .unwrap_or_default()
604 .iter()
605 .filter_map(|tc| {
606 if let ChatCompletionMessageToolCalls::Function(f) = tc {
607 Some(ToolCall {
608 id: f.id.clone(),
609 call_type: "function".to_string(),
610 function: TraitFunctionCall {
611 name: f.function.name.clone(),
612 arguments: f.function.arguments.clone(),
613 },
614 thought_signature: None,
615 })
616 } else {
617 None
618 }
619 })
620 .collect();
621
622 let content = choice.message.content.clone().unwrap_or_default();
623
624 let (prompt_tokens, completion_tokens, total_tokens, cache_hit_tokens, thinking_tokens) =
625 Self::extract_usage(response.usage.clone());
626
627 let mut metadata = HashMap::new();
628 metadata.insert("response_id".to_string(), serde_json::json!(response.id));
629
630 Ok(LLMResponse {
631 content,
632 prompt_tokens,
633 completion_tokens,
634 total_tokens,
635 model: response.model,
636 finish_reason: choice.finish_reason.map(|r| format!("{:?}", r)),
637 tool_calls,
638 metadata,
639 cache_hit_tokens,
640 cache_write_tokens: None,
641 thinking_tokens,
642 thinking_content: None,
643 })
644 }
645
646 fn supports_function_calling(&self) -> bool {
647 true
648 }
649
650 async fn stream(
651 &self,
652 prompt: &str,
653 ) -> Result<futures::stream::BoxStream<'static, Result<String>>> {
654 let request = ChatCompletionRequestUserMessageArgs::default()
655 .content(prompt)
656 .build()
657 .map(Into::into)
658 .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
659
660 let request = CreateChatCompletionRequestArgs::default()
661 .model(&self.model)
662 .messages(vec![request])
663 .stream(true)
664 .build()
665 .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
666
667 let stream = self.client.chat().create_stream(request).await?;
668
669 let mapped_stream = stream.map(|res| match res {
670 Ok(response) => {
671 let content = response
672 .choices
673 .first()
674 .and_then(|c| c.delta.content.clone())
675 .unwrap_or_default();
676 Ok(content)
677 }
678 Err(e) => Err(LlmError::from(e)),
684 });
685
686 Ok(mapped_stream.boxed())
687 }
688
689 fn supports_streaming(&self) -> bool {
690 true
691 }
692
693 async fn chat_with_tools_stream(
694 &self,
695 messages: &[ChatMessage],
696 tools: &[ToolDefinition],
697 tool_choice: Option<ToolChoice>,
698 options: Option<&CompletionOptions>,
699 ) -> Result<futures::stream::BoxStream<'static, Result<StreamChunk>>> {
700 let openai_messages = Self::convert_messages(messages)?;
701 let options = options.cloned().unwrap_or_default();
702
703 let openai_tools: Vec<ChatCompletionTools> = tools
706 .iter()
707 .map(|tool| {
708 ChatCompletionTools::Function(ChatCompletionTool {
709 function: FunctionObjectArgs::default()
710 .name(&tool.function.name)
711 .description(&tool.function.description)
712 .parameters(tool.function.parameters.clone())
713 .build()
714 .expect("Invalid tool definition"),
715 })
716 })
717 .collect();
718
719 let mut request_builder = CreateChatCompletionRequestArgs::default();
721 request_builder
722 .model(&self.model)
723 .messages(openai_messages)
724 .tools(openai_tools)
725 .stream(true)
726 .stream_options(ChatCompletionStreamOptions {
727 include_usage: Some(true),
728 include_obfuscation: None,
729 }); if let Some(tc) = tool_choice {
734 match tc {
735 ToolChoice::Auto(_) => {
736 request_builder.tool_choice(ChatCompletionToolChoiceOption::Mode(
737 ToolChoiceOptions::Auto,
738 ));
739 }
740 ToolChoice::Required(_) => {
741 request_builder.tool_choice(ChatCompletionToolChoiceOption::Mode(
742 ToolChoiceOptions::Required,
743 ));
744 }
745 ToolChoice::Function { ref function, .. } => {
746 request_builder.tool_choice(ChatCompletionToolChoiceOption::Function(
748 ChatCompletionNamedToolChoice {
749 function: FunctionName {
750 name: function.name.clone(),
751 },
752 },
753 ));
754 }
755 }
756 }
757
758 if let Some(temp) = options.temperature {
759 if (temp - 1.0_f32).abs() > f32::EPSILON {
762 request_builder.temperature(temp);
763 }
764 }
765
766 if let Some(max_tokens) = options.max_tokens {
767 request_builder.max_completion_tokens(max_tokens as u32);
769 }
770
771 let request = request_builder
772 .build()
773 .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
774
775 let stream = self.client.chat().create_stream(request).await?;
776
777 let mapped_stream = stream.map(|result| {
779 match result {
780 Ok(response) => {
781 let stream_usage = Self::extract_stream_usage(response.usage.clone());
782
783 let choice = response.choices.first();
784 if let Some(choice) = choice {
785 if let Some(content) = &choice.delta.content {
787 return Ok(StreamChunk::Content(content.clone()));
788 }
789
790 if let Some(tool_call_chunks) = &choice.delta.tool_calls {
792 if let Some(chunk) = tool_call_chunks.first() {
793 return Ok(StreamChunk::ToolCallDelta {
794 index: chunk.index as usize,
795 id: chunk.id.clone(),
796 function_name: chunk
797 .function
798 .as_ref()
799 .and_then(|f| f.name.clone()),
800 function_arguments: chunk
801 .function
802 .as_ref()
803 .and_then(|f| f.arguments.clone()),
804 thought_signature: None,
805 });
806 }
807 }
808
809 if let Some(finish_reason) = &choice.finish_reason {
811 let reason = match finish_reason {
812 FinishReason::Stop => "stop",
813 FinishReason::Length => "length",
814 FinishReason::ToolCalls => "tool_calls",
815 FinishReason::ContentFilter => "content_filter",
816 FinishReason::FunctionCall => "function_call",
817 };
818 return Ok(StreamChunk::Finished {
819 reason: reason.to_string(),
820 ttft_ms: None,
821 usage: stream_usage,
822 });
823 }
824 }
825 if stream_usage.is_some() {
826 return Ok(StreamChunk::Finished {
827 reason: "stop".to_string(),
828 ttft_ms: None,
829 usage: stream_usage,
830 });
831 }
832 Ok(StreamChunk::Content(String::new()))
834 }
835 Err(e) => Err(LlmError::from(e)),
839 }
840 });
841
842 Ok(mapped_stream.boxed())
843 }
844
845 fn supports_tool_streaming(&self) -> bool {
846 true
847 }
848
849 fn supports_json_mode(&self) -> bool {
850 let m = &self.model;
853 m.contains("gpt-4")
854 || m.contains("gpt-3.5-turbo")
855 || m.contains("gpt-5")
856 || m.starts_with("o1")
857 || m.starts_with("o3")
858 || m.starts_with("o4")
859 }
860}
861
862#[async_trait]
863impl EmbeddingProvider for OpenAIProvider {
864 fn name(&self) -> &str {
865 "openai"
866 }
867
868 #[allow(clippy::misnamed_getters)]
872 fn model(&self) -> &str {
873 &self.embedding_model
874 }
875
876 fn dimension(&self) -> usize {
877 self.embedding_dimension
878 }
879
880 fn max_tokens(&self) -> usize {
881 8191 }
883
884 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
885 if texts.is_empty() {
886 return Ok(Vec::new());
887 }
888
889 let base_url = if self.raw_base_url.is_empty() {
892 "https://api.openai.com/v1".to_string()
893 } else {
894 self.raw_base_url.trim_end_matches('/').to_string()
895 };
896 let url = format!("{}/embeddings", base_url);
897
898 let request_body = serde_json::json!({
899 "model": self.embedding_model,
900 "input": texts,
901 "encoding_format": "float"
902 });
903
904 let http_client = reqwest::Client::new();
905 let mut req = http_client.post(&url).json(&request_body);
906 if !self.raw_api_key.is_empty() {
907 req = req.header("Authorization", format!("Bearer {}", self.raw_api_key));
908 }
909
910 let response = req
911 .send()
912 .await
913 .map_err(|e| LlmError::NetworkError(format!("Embedding request failed: {}", e)))?;
914
915 let status = response.status();
916 let body = response.text().await.map_err(|e| {
917 LlmError::NetworkError(format!("Failed to read embedding response: {}", e))
918 })?;
919
920 if !status.is_success() {
921 return Err(LlmError::ApiError(format!(
922 "Embedding API returned {} {}: {}",
923 status.as_u16(),
924 status.canonical_reason().unwrap_or(""),
925 &body[..body.len().min(500)]
926 )));
927 }
928
929 #[derive(serde::Deserialize)]
931 struct LenientEmbeddingResponse {
932 data: Vec<LenientEmbeddingObject>,
933 }
934 #[derive(serde::Deserialize)]
935 struct LenientEmbeddingObject {
936 embedding: Vec<f32>,
937 }
938
939 let parsed: LenientEmbeddingResponse = serde_json::from_str(&body).map_err(|e| {
940 LlmError::InvalidRequest(format!(
941 "Failed to parse embedding response: {} – body: {}",
942 e,
943 &body[..body.len().min(500)]
944 ))
945 })?;
946
947 Ok(parsed.data.into_iter().map(|o| o.embedding).collect())
948 }
949}
950
951#[cfg(test)]
952mod tests {
953 use super::*;
954
955 #[test]
956 fn test_context_length_detection() {
957 assert_eq!(OpenAIProvider::context_length_for_model("gpt-4o"), 128000);
958 assert_eq!(OpenAIProvider::context_length_for_model("gpt-4"), 8192);
959 assert_eq!(
960 OpenAIProvider::context_length_for_model("gpt-3.5-turbo"),
961 4096
962 );
963 }
964
965 #[test]
966 fn test_embedding_dimension_detection() {
967 assert_eq!(
968 OpenAIProvider::dimension_for_model("text-embedding-3-large"),
969 3072
970 );
971 assert_eq!(
972 OpenAIProvider::dimension_for_model("text-embedding-3-small"),
973 1536
974 );
975 }
976
977 #[test]
978 fn test_provider_builder() {
979 let provider = OpenAIProvider::new("test-key")
980 .with_model("gpt-4")
981 .with_embedding_model("text-embedding-3-large");
982
983 assert_eq!(LLMProvider::model(&provider), "gpt-4");
984 assert_eq!(provider.dimension(), 3072);
985 }
986
987 #[test]
988 fn test_message_conversion() {
989 let messages = vec![
990 ChatMessage::system("You are helpful"),
991 ChatMessage::user("Hello"),
992 ChatMessage::assistant("Hi there!"),
993 ];
994
995 let converted = OpenAIProvider::convert_messages(&messages).unwrap();
996 assert_eq!(converted.len(), 3);
997 }
998
999 #[test]
1002 fn test_context_length_gpt5_series() {
1003 assert_eq!(
1004 OpenAIProvider::context_length_for_model("gpt-5.2-turbo"),
1005 200000
1006 );
1007 assert_eq!(
1008 OpenAIProvider::context_length_for_model("gpt-5.1-preview"),
1009 200000
1010 );
1011 assert_eq!(
1012 OpenAIProvider::context_length_for_model("gpt-5-nano"),
1013 128000
1014 );
1015 assert_eq!(
1016 OpenAIProvider::context_length_for_model("gpt-5-mini"),
1017 200000
1018 );
1019 assert_eq!(OpenAIProvider::context_length_for_model("gpt-5"), 200000);
1020 }
1021
1022 #[test]
1023 fn test_context_length_o_series() {
1024 assert_eq!(OpenAIProvider::context_length_for_model("o4-mini"), 200000);
1025 assert_eq!(
1026 OpenAIProvider::context_length_for_model("o3-preview"),
1027 200000
1028 );
1029 assert_eq!(
1030 OpenAIProvider::context_length_for_model("o1-preview"),
1031 200000
1032 );
1033 }
1034
1035 #[test]
1036 fn test_context_length_gpt4_variants() {
1037 assert_eq!(
1038 OpenAIProvider::context_length_for_model("gpt-4-turbo-preview"),
1039 128000
1040 );
1041 assert_eq!(
1042 OpenAIProvider::context_length_for_model("gpt-4-32k-0613"),
1043 32768
1044 );
1045 assert_eq!(OpenAIProvider::context_length_for_model("gpt-4-0613"), 8192);
1046 }
1047
1048 #[test]
1049 fn test_context_length_gpt35_variants() {
1050 assert_eq!(
1051 OpenAIProvider::context_length_for_model("gpt-3.5-turbo-16k"),
1052 16384
1053 );
1054 assert_eq!(
1055 OpenAIProvider::context_length_for_model("gpt-3.5-turbo-1106"),
1056 4096
1057 );
1058 }
1059
1060 #[test]
1061 fn test_context_length_unknown_defaults_high() {
1062 assert_eq!(
1064 OpenAIProvider::context_length_for_model("unknown-future-model"),
1065 128000
1066 );
1067 }
1068
1069 #[test]
1070 fn test_dimension_ada_model() {
1071 assert_eq!(
1072 OpenAIProvider::dimension_for_model("text-embedding-ada-002"),
1073 1536
1074 );
1075 }
1076
1077 #[test]
1078 fn test_dimension_unknown_defaults() {
1079 assert_eq!(
1080 OpenAIProvider::dimension_for_model("unknown-embedding"),
1081 1536
1082 );
1083 }
1084
1085 #[test]
1086 fn test_provider_name() {
1087 let provider = OpenAIProvider::new("test-key");
1088 assert_eq!(LLMProvider::name(&provider), "openai");
1089 }
1090
1091 #[test]
1092 fn test_provider_max_context_length() {
1093 let provider = OpenAIProvider::new("test-key").with_model("gpt-4");
1094 assert_eq!(provider.max_context_length(), 8192);
1095 }
1096
1097 #[test]
1098 fn test_provider_dimension() {
1099 let provider =
1100 OpenAIProvider::new("test-key").with_embedding_model("text-embedding-3-large");
1101 assert_eq!(provider.dimension(), 3072);
1102 }
1103
1104 #[test]
1105 fn test_provider_embedding_model() {
1106 let provider =
1107 OpenAIProvider::new("test-key").with_embedding_model("text-embedding-3-small");
1108 assert_eq!(
1109 EmbeddingProvider::model(&provider),
1110 "text-embedding-3-small"
1111 );
1112 }
1113
1114 #[test]
1115 fn test_message_conversion_tool_role() {
1116 let messages = vec![ChatMessage::tool_result("call_abc", "result data")];
1118 let converted = OpenAIProvider::convert_messages(&messages).unwrap();
1119 assert_eq!(converted.len(), 1);
1120 match &converted[0] {
1121 ChatCompletionRequestMessage::Tool(m) => {
1122 assert_eq!(m.tool_call_id, "call_abc");
1123 }
1124 other => panic!("Expected Tool message, got {:?}", other),
1125 }
1126 }
1127
1128 #[test]
1129 fn test_tool_message_missing_id_returns_err() {
1130 let mut msg = ChatMessage::user("orphan");
1131 msg.role = ChatRole::Tool;
1132 msg.tool_call_id = None;
1133 let r = OpenAIProvider::convert_messages(&[msg]);
1134 assert!(
1135 r.is_err(),
1136 "Expected Err for tool message without tool_call_id"
1137 );
1138 }
1139
1140 #[test]
1141 fn test_assistant_with_tool_calls_serialized() {
1142 let calls = vec![ToolCall {
1144 id: "call_xyz".to_string(),
1145 call_type: "function".to_string(),
1146 function: TraitFunctionCall {
1147 name: "get_weather".to_string(),
1148 arguments: r#"{"city":"Paris"}"#.to_string(),
1149 },
1150 thought_signature: None,
1151 }];
1152 let msg = ChatMessage::assistant_with_tools("", calls);
1153 let converted = OpenAIProvider::convert_messages(&[msg]).unwrap();
1154 assert_eq!(converted.len(), 1);
1155 match &converted[0] {
1156 ChatCompletionRequestMessage::Assistant(m) => {
1157 let tcs = m.tool_calls.as_ref().expect("tool_calls must be present");
1158 assert_eq!(tcs.len(), 1);
1159 if let ChatCompletionMessageToolCalls::Function(f) = &tcs[0] {
1160 assert_eq!(f.id, "call_xyz");
1161 assert_eq!(f.function.name, "get_weather");
1162 } else {
1163 panic!("Expected Function tool call");
1164 }
1165 }
1166 other => panic!("Expected Assistant message, got {:?}", other),
1167 }
1168 }
1169
1170 #[test]
1171 fn test_supports_streaming() {
1172 let provider = OpenAIProvider::new("test-key");
1173 assert!(provider.supports_streaming());
1174 }
1175
1176 #[test]
1177 fn test_supports_json_mode_gpt4() {
1178 let provider = OpenAIProvider::new("test-key").with_model("gpt-4o");
1179 assert!(provider.supports_json_mode());
1180 }
1181
1182 #[test]
1183 fn test_supports_json_mode_gpt35() {
1184 let provider = OpenAIProvider::new("test-key").with_model("gpt-3.5-turbo");
1185 assert!(provider.supports_json_mode());
1186 }
1187
1188 #[test]
1189 fn test_supports_json_mode_default_is_false() {
1190 let provider = OpenAIProvider::new("test-key").with_model("davinci-002");
1192 assert!(!provider.supports_json_mode());
1193 }
1194
1195 #[test]
1198 fn test_build_user_content_text_only() {
1199 let msg = ChatMessage::user("Hello");
1200 let content = OpenAIProvider::build_user_content(&msg);
1201 match content {
1202 ChatCompletionRequestUserMessageContent::Text(t) => assert_eq!(t, "Hello"),
1203 _ => panic!("Expected text content"),
1204 }
1205 }
1206
1207 #[test]
1208 fn test_build_user_content_with_image() {
1209 use crate::traits::ImageData;
1210 let img = ImageData::new("base64data", "image/png");
1211 let msg = ChatMessage::user_with_images("Describe this", vec![img]);
1212 let content = OpenAIProvider::build_user_content(&msg);
1213 match content {
1214 ChatCompletionRequestUserMessageContent::Array(parts) => {
1215 assert_eq!(parts.len(), 2, "Should have text + image parts");
1216 assert!(
1217 matches!(
1218 parts[0],
1219 ChatCompletionRequestUserMessageContentPart::Text(_)
1220 ),
1221 "First part should be text"
1222 );
1223 assert!(
1224 matches!(
1225 parts[1],
1226 ChatCompletionRequestUserMessageContentPart::ImageUrl(_)
1227 ),
1228 "Second part should be image_url"
1229 );
1230 }
1231 _ => panic!("Expected array content for vision message"),
1232 }
1233 }
1234
#[test]
fn test_build_user_content_image_data_uri() {
    use crate::traits::ImageData;
    // The image part's URL must be a `data:` URI built from the MIME type and base64 payload.
    let message = ChatMessage::user_with_images(
        "What's here?",
        vec![ImageData::new("abc123", "image/jpeg")],
    );
    let parts = match OpenAIProvider::build_user_content(&message) {
        ChatCompletionRequestUserMessageContent::Array(parts) => parts,
        _ => panic!("Expected array content"),
    };
    let url = match &parts[1] {
        ChatCompletionRequestUserMessageContentPart::ImageUrl(image) => &image.image_url.url,
        _ => panic!("Expected ImageUrl part"),
    };
    assert_eq!(
        url, "data:image/jpeg;base64,abc123",
        "Data URI should be correct"
    );
}
1254
#[test]
fn test_build_user_content_image_with_detail() {
    use crate::traits::ImageData;
    // An image carrying detail "high" must be parsed into `ImageDetail::High`, and that
    // detail must also appear on the image part produced by `build_user_content`.
    let img = ImageData::new("data", "image/png").with_detail("high");

    // The detail string parses to the typed enum.
    assert!(matches!(
        OpenAIProvider::parse_image_detail(&img),
        Some(ImageDetail::High)
    ));

    // The same detail is carried into the request content. Previously the message was
    // built and then discarded, so `build_user_content` was never exercised.
    let msg = ChatMessage::user_with_images("Analyze", vec![img]);
    let ChatCompletionRequestUserMessageContent::Array(parts) =
        OpenAIProvider::build_user_content(&msg)
    else {
        panic!("Expected array content for vision message");
    };
    let Some(ChatCompletionRequestUserMessageContentPart::ImageUrl(img_part)) = parts.get(1)
    else {
        panic!("Expected ImageUrl part after the text part");
    };
    assert!(
        matches!(img_part.image_url.detail, Some(ImageDetail::High)),
        "image detail should be forwarded to the request"
    );
}
1265
#[test]
fn test_parse_image_detail_low() {
    use crate::traits::ImageData;
    // The detail string "low" maps to `ImageDetail::Low`.
    let parsed = OpenAIProvider::parse_image_detail(
        &ImageData::new("x", "image/png").with_detail("low"),
    );
    assert!(matches!(parsed, Some(ImageDetail::Low)));
}
1273
#[test]
fn test_parse_image_detail_auto() {
    use crate::traits::ImageData;
    // The detail string "auto" maps to `ImageDetail::Auto`.
    let auto_detail = OpenAIProvider::parse_image_detail(
        &ImageData::new("x", "image/png").with_detail("auto"),
    );
    assert!(matches!(auto_detail, Some(ImageDetail::Auto)));
}
1281
#[test]
fn test_parse_image_detail_none() {
    use crate::traits::ImageData;
    // Without an explicit detail, no `ImageDetail` is produced.
    let undetailed = ImageData::new("x", "image/png");
    assert!(OpenAIProvider::parse_image_detail(&undetailed).is_none());
}
1289
#[test]
fn test_convert_messages_with_image_produces_array_content() {
    use crate::traits::ImageData;
    // Checks the serialized wire format: a user message with an image must serialize its
    // `content` as an array of typed parts, not as a string.
    let history = vec![
        ChatMessage::system("You are a vision assistant"),
        ChatMessage::user_with_images(
            "What is in this image?",
            vec![ImageData::new("iVBORw0KGgo", "image/png")],
        ),
    ];
    let converted = OpenAIProvider::convert_messages(&history).unwrap();
    assert_eq!(converted.len(), 2);

    let user_json = serde_json::to_value(&converted[1]).unwrap();
    let content = &user_json["content"];
    assert!(
        content.is_array(),
        "Vision user message content must be a JSON array, got: {:?}",
        content
    );

    let parts = content.as_array().unwrap();
    assert_eq!(parts.len(), 2, "Should have text + image parts");
    assert_eq!(parts[0]["type"], "text");
    assert_eq!(parts[1]["type"], "image_url");

    let url = parts[1]["image_url"]["url"].as_str().unwrap();
    assert!(url.starts_with("data:image/png;base64,"));
}
1318
#[test]
fn test_convert_messages_without_image_produces_text_content() {
    // A text-only user message must serialize its `content` as a plain JSON string.
    let plain = vec![ChatMessage::user("Just text")];
    let first =
        serde_json::to_value(&OpenAIProvider::convert_messages(&plain).unwrap()[0]).unwrap();
    let content = &first["content"];
    assert!(
        content.is_string(),
        "Plain text user message content must be a JSON string"
    );
    assert_eq!(content.as_str().unwrap(), "Just text");
}
1331
#[test]
fn test_chat_completion_tools_function_wrapping() {
    use crate::traits::FunctionDefinition;
    // Builds an OpenAI function object from a crate-level `ToolDefinition`, wraps it in
    // `ChatCompletionTools::Function`, and checks the serialized tag and function fields.
    let schema = serde_json::json!({
        "type": "object",
        "properties": {
            "location": { "type": "string" }
        },
        "required": ["location"]
    });
    let tool_def = ToolDefinition {
        tool_type: "function".to_string(),
        function: FunctionDefinition {
            name: "get_weather".to_string(),
            description: "Get the current weather".to_string(),
            parameters: schema,
            strict: None,
        },
    };

    let function = FunctionObjectArgs::default()
        .name(&tool_def.function.name)
        .description(&tool_def.function.description)
        .parameters(tool_def.function.parameters.clone())
        .build()
        .unwrap();
    let json =
        serde_json::to_value(ChatCompletionTools::Function(ChatCompletionTool { function }))
            .unwrap();

    assert_eq!(json["type"], "function");
    assert_eq!(json["function"]["name"], "get_weather");
    assert_eq!(json["function"]["description"], "Get the current weather");
}
1372
#[test]
fn test_tool_choice_auto_serialization() {
    // The `Auto` mode serializes as the bare JSON string "auto".
    let serialized =
        serde_json::to_value(ChatCompletionToolChoiceOption::Mode(ToolChoiceOptions::Auto))
            .unwrap();
    assert_eq!(serialized, "auto");
}
1381
#[test]
fn test_tool_choice_required_serialization() {
    // The `Required` mode serializes as the bare JSON string "required".
    let serialized =
        serde_json::to_value(ChatCompletionToolChoiceOption::Mode(ToolChoiceOptions::Required))
            .unwrap();
    assert_eq!(serialized, "required");
}
1388
#[test]
fn test_tool_choice_none_serialization() {
    // The `None` mode serializes as the bare JSON string "none".
    let serialized =
        serde_json::to_value(ChatCompletionToolChoiceOption::Mode(ToolChoiceOptions::None))
            .unwrap();
    assert_eq!(serialized, "none");
}
1395
#[test]
fn test_max_completion_tokens_in_request_serialization() {
    // A request built with `max_completion_tokens` serializes that field and leaves the
    // deprecated `max_tokens` absent (so indexing it yields JSON null).
    let hello = ChatCompletionRequestUserMessageArgs::default()
        .content("Hello")
        .build()
        .unwrap();
    let request = CreateChatCompletionRequestArgs::default()
        .model("o3-mini")
        .messages(vec![hello.into()])
        .max_completion_tokens(1024u32)
        .build()
        .unwrap();

    let body = serde_json::to_value(&request).unwrap();
    assert_eq!(
        body["max_completion_tokens"], 1024,
        "max_completion_tokens should be set in request"
    );
    assert!(
        body["max_tokens"].is_null(),
        "deprecated max_tokens should NOT be set"
    );
}
1422
#[test]
fn test_max_completion_tokens_works_for_all_models() {
    // For every model name listed, the token cap must serialize as `max_completion_tokens`
    // and the deprecated `max_tokens` must stay unset.
    const MODELS: [&str; 5] = [
        "gpt-4o",
        "gpt-3.5-turbo",
        "o1-preview",
        "o3-mini",
        "gpt-4.1-nano",
    ];
    for model in MODELS {
        let request = CreateChatCompletionRequestArgs::default()
            .model(model)
            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
                .content("Test")
                .build()
                .unwrap()
                .into()])
            .max_completion_tokens(512u32)
            .build()
            .unwrap();

        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(
            json["max_completion_tokens"], 512,
            "max_completion_tokens should be set for model {}",
            model
        );
        // Previously only the new field was checked. A regression that also emitted
        // `max_tokens` for some models would have gone unnoticed.
        assert!(
            json["max_tokens"].is_null(),
            "deprecated max_tokens should NOT be set for model {}",
            model
        );
    }
}
1453
#[test]
fn test_cache_hit_token_extraction() {
    use async_openai::types::chat::PromptTokensDetails;

    // Cached prompt tokens are read from the optional prompt details and widened to usize.
    let details = PromptTokensDetails {
        cached_tokens: Some(80),
        audio_tokens: None,
    };
    let usage = CompletionUsage {
        prompt_tokens: 100,
        completion_tokens: 50,
        total_tokens: 150,
        prompt_tokens_details: Some(details),
        completion_tokens_details: None,
    };

    let cache_hit_tokens = usage
        .prompt_tokens_details
        .as_ref()
        .and_then(|prompt| prompt.cached_tokens.map(|n| n as usize));

    assert_eq!(cache_hit_tokens, Some(80));
}
1479
#[test]
fn test_reasoning_token_extraction() {
    use async_openai::types::chat::CompletionTokensDetails;

    // Reasoning tokens are read from the optional completion details and widened to usize.
    let details = CompletionTokensDetails {
        reasoning_tokens: Some(150),
        audio_tokens: None,
        accepted_prediction_tokens: None,
        rejected_prediction_tokens: None,
    };
    let usage = CompletionUsage {
        prompt_tokens: 50,
        completion_tokens: 200,
        total_tokens: 250,
        prompt_tokens_details: None,
        completion_tokens_details: Some(details),
    };

    let thinking_tokens = usage
        .completion_tokens_details
        .as_ref()
        .and_then(|completion| completion.reasoning_tokens.map(|n| n as usize));

    assert_eq!(thinking_tokens, Some(150));
}
1507
#[test]
fn test_token_details_none_is_safe() {
    // When both detail blocks are absent, the extraction chains yield None instead of panicking.
    let usage = CompletionUsage {
        prompt_tokens: 10,
        completion_tokens: 20,
        total_tokens: 30,
        prompt_tokens_details: None,
        completion_tokens_details: None,
    };

    let (cache_hit, reasoning) = (
        usage
            .prompt_tokens_details
            .as_ref()
            .and_then(|prompt| prompt.cached_tokens.map(|n| n as usize)),
        usage
            .completion_tokens_details
            .as_ref()
            .and_then(|completion| completion.reasoning_tokens.map(|n| n as usize)),
    );

    assert_eq!(cache_hit, None);
    assert_eq!(reasoning, None);
}
1534
#[test]
fn test_finish_reason_variants() {
    // Each `FinishReason` variant's Debug output is exactly its variant name.
    // A fixed-size array is enough here; nothing needs a heap-allocated Vec.
    let cases = [
        (FinishReason::Stop, "Stop"),
        (FinishReason::Length, "Length"),
        (FinishReason::ToolCalls, "ToolCalls"),
        (FinishReason::ContentFilter, "ContentFilter"),
        (FinishReason::FunctionCall, "FunctionCall"),
    ];

    for (reason, expected_debug) in cases {
        assert_eq!(
            format!("{:?}", reason),
            expected_debug,
            "FinishReason::{} should format as {:?}",
            expected_debug,
            expected_debug
        );
    }
}
1556
#[test]
fn test_json_deserialize_error_conversion() {
    use crate::error::LlmError;
    // An `OpenAIError::JSONDeserialize` must map onto `LlmError::SerializationError`.
    const BAD_JSON: &str = "invalid json {{";
    let serde_err = serde_json::from_str::<serde_json::Value>(BAD_JSON).unwrap_err();
    let converted: LlmError =
        async_openai::error::OpenAIError::JSONDeserialize(serde_err, BAD_JSON.to_string()).into();
    assert!(
        matches!(converted, LlmError::SerializationError(_)),
        "JSONDeserialize error should convert to SerializationError"
    );
}
1574
#[test]
fn test_chat_completion_tool_serialization() {
    // Wrapping a function tool in `ChatCompletionTools::Function` must emit the
    // `"type": "function"` tag alongside the nested function object.
    let function = FunctionObjectArgs::default()
        .name("my_func")
        .description("A test function")
        .parameters(serde_json::json!({"type": "object"}))
        .build()
        .unwrap();
    let json =
        serde_json::to_value(ChatCompletionTools::Function(ChatCompletionTool { function }))
            .unwrap();

    assert_eq!(json["type"], "function");
    assert_eq!(json["function"]["name"], "my_func");
}
1595}