1use std::fmt;
32
33use crate::error::LlmError;
34use crate::tool_desc::build_tool_description;
35use base64::{Engine, engine::general_purpose::STANDARD};
36use serde::{Deserialize, Serialize};
37
/// Connection and model settings used to build an [`OpenAiProvider`].
///
/// `Debug` is implemented by hand so the API key never ends up in logs or
/// panic output; `OpenAiProvider`'s own `Debug` redacts the key the same way.
#[derive(Clone)]
pub struct OpenAiConfig {
    /// API key, sent as the bearer token on every request.
    pub api_key: String,
    /// Base URL of the API. Trailing slashes are removed by `OpenAiProvider::new`.
    pub base_url: String,
    /// Model name placed in chat requests.
    pub model: String,
    /// Completion-token budget handed to `CompletionTokens::for_model`.
    pub max_tokens: u32,
    /// Model used for embedding requests; `None` disables embedding support.
    pub embedding_model: Option<String>,
    /// Optional reasoning-effort value forwarded in the request's `reasoning` field.
    pub reasoning_effort: Option<String>,
}

impl fmt::Debug for OpenAiConfig {
    /// Formats every field verbatim except `api_key`, which is shown as `<redacted>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OpenAiConfig")
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("embedding_model", &self.embedding_model)
            .field("reasoning_effort", &self.reasoning_effort)
            .finish()
    }
}
75
76use crate::provider::{
77 ChatExtras, ChatResponse, ChatStream, GenerationOverrides, LlmProvider, Message, MessagePart,
78 Role, StatusTx, ToolDefinition, ToolUseRequest,
79};
80use crate::retry::send_with_retry;
81use crate::sse::openai_sse_to_stream;
82use crate::usage::UsageTracker;
83
/// Retry budget handed to `send_with_retry` for chat completion requests.
const MAX_RETRIES: u32 = 3;
85
86pub struct OpenAiProvider {
95 client: reqwest::Client,
96 api_key: String,
97 base_url: String,
98 model: String,
99 max_tokens: u32,
100 embedding_model: Option<String>,
101 reasoning_effort: Option<String>,
103 pub(crate) status_tx: Option<StatusTx>,
104 usage: UsageTracker,
105 generation_overrides: Option<GenerationOverrides>,
106 forward_output_schema: bool,
108 output_schema_hint_bytes: usize,
110 max_tool_description_bytes: usize,
112}
113
114impl fmt::Debug for OpenAiProvider {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 f.debug_struct("OpenAiProvider")
117 .field("client", &"<reqwest::Client>")
118 .field("api_key", &"<redacted>")
119 .field("base_url", &self.base_url)
120 .field("model", &self.model)
121 .field("max_tokens", &self.max_tokens)
122 .field("embedding_model", &self.embedding_model)
123 .field("reasoning_effort", &self.reasoning_effort)
124 .field("status_tx", &self.status_tx.is_some())
125 .field("usage", &self.usage)
126 .field("generation_overrides", &self.generation_overrides)
127 .field("forward_output_schema", &self.forward_output_schema)
128 .field("output_schema_hint_bytes", &self.output_schema_hint_bytes)
129 .field(
130 "max_tool_description_bytes",
131 &self.max_tool_description_bytes,
132 )
133 .finish()
134 }
135}
136
137impl Clone for OpenAiProvider {
138 fn clone(&self) -> Self {
139 Self {
140 client: self.client.clone(),
141 api_key: self.api_key.clone(),
142 base_url: self.base_url.clone(),
143 model: self.model.clone(),
144 max_tokens: self.max_tokens,
145 embedding_model: self.embedding_model.clone(),
146 reasoning_effort: self.reasoning_effort.clone(),
147 status_tx: self.status_tx.clone(),
148 usage: UsageTracker::default(),
149 generation_overrides: self.generation_overrides.clone(),
150 forward_output_schema: self.forward_output_schema,
151 output_schema_hint_bytes: self.output_schema_hint_bytes,
152 max_tool_description_bytes: self.max_tool_description_bytes,
153 }
154 }
155}
156
157impl OpenAiProvider {
158 #[must_use]
160 pub fn new(cfg: OpenAiConfig) -> Self {
161 let mut base_url = cfg.base_url;
162 while base_url.ends_with('/') {
163 base_url.pop();
164 }
165 Self {
166 client: crate::http::llm_client(600),
167 api_key: cfg.api_key,
168 base_url,
169 model: cfg.model,
170 max_tokens: cfg.max_tokens,
171 embedding_model: cfg.embedding_model,
172 reasoning_effort: cfg.reasoning_effort,
173 status_tx: None,
174 usage: UsageTracker::default(),
175 generation_overrides: None,
176 forward_output_schema: false,
177 output_schema_hint_bytes: 1024,
178 max_tool_description_bytes: usize::MAX,
179 }
180 }
181
182 #[must_use]
184 pub fn with_generation_overrides(mut self, overrides: GenerationOverrides) -> Self {
185 self.generation_overrides = Some(overrides);
186 self
187 }
188
189 #[must_use]
193 pub fn with_output_schema_forwarding(
194 mut self,
195 enabled: bool,
196 hint_bytes: usize,
197 max_description_bytes: usize,
198 ) -> Self {
199 self.forward_output_schema = enabled;
200 self.output_schema_hint_bytes = hint_bytes;
201 self.max_tool_description_bytes = max_description_bytes;
202 self
203 }
204
205 #[must_use]
207 pub fn with_client(mut self, client: reqwest::Client) -> Self {
208 self.client = client;
209 self
210 }
211
212 #[must_use]
214 pub fn with_status_tx(mut self, tx: StatusTx) -> Self {
215 self.status_tx = Some(tx);
216 self
217 }
218
219 #[must_use]
224 pub fn cache_slug(&self) -> String {
225 let host = self
226 .base_url
227 .trim_start_matches("https://")
228 .trim_start_matches("http://")
229 .split('/')
230 .next()
231 .unwrap_or("openai")
232 .split(':')
233 .next()
234 .unwrap_or("openai");
235 let slug: String = host
236 .chars()
237 .map(|c| if c == '.' || c == '-' { '_' } else { c })
238 .filter(|c| c.is_ascii_alphanumeric() || *c == '_')
239 .collect();
240 if slug.is_empty() {
241 "openai".to_string()
242 } else {
243 slug
244 }
245 }
246
247 pub async fn list_models_remote(
253 &self,
254 ) -> Result<Vec<crate::model_cache::RemoteModelInfo>, LlmError> {
255 let url = format!("{}/models", self.base_url);
256 let resp = self
257 .client
258 .get(&url)
259 .bearer_auth(&self.api_key)
260 .send()
261 .await?;
262
263 if !resp.status().is_success() {
264 let status = resp.status();
265 let body = resp.text().await.unwrap_or_default();
266 tracing::debug!(status = %status, body = %body, "OpenAI list_models_remote error body");
267 return Err(LlmError::ApiError {
268 provider: "openai".into(),
269 status: status.as_u16(),
270 });
271 }
272
273 let page: serde_json::Value = resp.json().await?;
274 let models: Vec<crate::model_cache::RemoteModelInfo> = page
275 .get("data")
276 .and_then(|v| v.as_array())
277 .map(|arr| {
278 arr.iter()
279 .filter_map(|item| {
280 let id = item.get("id")?.as_str()?.to_string();
281 let created_at = item.get("created").and_then(serde_json::Value::as_i64);
282 Some(crate::model_cache::RemoteModelInfo {
283 display_name: id.clone(),
284 id,
285 context_window: None,
286 created_at,
287 })
288 })
289 .collect()
290 })
291 .unwrap_or_default();
292
293 let slug = self.cache_slug();
294 let cache = crate::model_cache::ModelCache::for_slug(&slug);
295 cache.save(&models)?;
296 Ok(models)
297 }
298
299 fn store_cache_usage(&self, usage: &OpenAiUsage) {
300 self.usage
301 .record_usage(usage.prompt_tokens, usage.completion_tokens);
302 let cached = usage
303 .prompt_tokens_details
304 .as_ref()
305 .map_or(0, |d| d.cached_tokens);
306 if cached > 0 {
307 self.usage.record_cache(0, cached);
308 }
309 let reasoning = usage
310 .completion_tokens_details
311 .as_ref()
312 .map_or(0, |d| d.reasoning_tokens);
313 if reasoning > 0 {
314 self.usage.record_reasoning(reasoning);
315 }
316 tracing::debug!(
317 prompt_tokens = usage.prompt_tokens,
318 cached_tokens = cached,
319 completion_tokens = usage.completion_tokens,
320 reasoning_tokens = reasoning,
321 "OpenAI API usage"
322 );
323 }
324
325 async fn send_request(&self, messages: &[Message]) -> Result<String, LlmError> {
326 let reasoning = self
327 .reasoning_effort
328 .as_deref()
329 .map(|effort| Reasoning { effort });
330
331 let (temperature, top_p, frequency_penalty, presence_penalty) =
332 if let Some(ref ov) = self.generation_overrides {
333 (
334 ov.temperature,
335 ov.top_p,
336 ov.frequency_penalty,
337 ov.presence_penalty,
338 )
339 } else {
340 (None, None, None, None)
341 };
342
343 let response = if has_image_parts(messages) {
344 let vision_messages = convert_messages_vision(messages);
345 let body = VisionChatRequest {
346 model: &self.model,
347 messages: vision_messages,
348 completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
349 stream: false,
350 reasoning,
351 temperature,
352 top_p,
353 frequency_penalty,
354 presence_penalty,
355 };
356 send_with_retry("OpenAI", MAX_RETRIES, self.status_tx.as_ref(), || {
357 self.openai_post(format!("{}/chat/completions", self.base_url))
358 .json(&body)
359 .send()
360 })
361 .await?
362 } else {
363 let api_messages = convert_messages(messages);
364 let body = ChatRequest {
365 model: &self.model,
366 messages: &api_messages,
367 completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
368 stream: false,
369 reasoning,
370 temperature,
371 top_p,
372 frequency_penalty,
373 presence_penalty,
374 };
375 send_with_retry("OpenAI", MAX_RETRIES, self.status_tx.as_ref(), || {
376 self.openai_post(format!("{}/chat/completions", self.base_url))
377 .json(&body)
378 .send()
379 })
380 .await?
381 };
382
383 let status = response.status();
384 let text = response.text().await.map_err(LlmError::Http)?;
385
386 if !status.is_success() {
387 tracing::error!("OpenAI API error {status}: {text}");
388 return Err(crate::http::map_error_response(status, &text, "openai"));
389 }
390
391 let resp: OpenAiChatResponse = serde_json::from_str(&text)?;
392
393 if let Some(ref usage) = resp.usage {
394 self.store_cache_usage(usage);
395 }
396
397 resp.choices
398 .first()
399 .map(|c| c.message.content.clone())
400 .ok_or(LlmError::EmptyResponse {
401 provider: "openai".into(),
402 })
403 }
404
405 async fn send_stream_request(
406 &self,
407 messages: &[Message],
408 ) -> Result<reqwest::Response, LlmError> {
409 let api_messages = convert_messages(messages);
410 let reasoning = self
411 .reasoning_effort
412 .as_deref()
413 .map(|effort| Reasoning { effort });
414
415 let (temperature, top_p, frequency_penalty, presence_penalty) =
416 if let Some(ref ov) = self.generation_overrides {
417 (
418 ov.temperature,
419 ov.top_p,
420 ov.frequency_penalty,
421 ov.presence_penalty,
422 )
423 } else {
424 (None, None, None, None)
425 };
426
427 let body = ChatRequest {
428 model: &self.model,
429 messages: &api_messages,
430 completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
431 stream: true,
432 reasoning,
433 temperature,
434 top_p,
435 frequency_penalty,
436 presence_penalty,
437 };
438
439 let response = send_with_retry("OpenAI", MAX_RETRIES, self.status_tx.as_ref(), || {
440 self.openai_post(format!("{}/chat/completions", self.base_url))
441 .json(&body)
442 .send()
443 })
444 .await?;
445
446 let status = response.status();
447
448 if !status.is_success() {
449 let text = response.text().await.map_err(LlmError::Http)?;
450 tracing::error!("OpenAI API streaming request error {status}: {text}");
451 return Err(crate::http::map_error_response(status, &text, "openai"));
452 }
453
454 Ok(response)
455 }
456
457 fn openai_post(&self, url: String) -> reqwest::RequestBuilder {
459 self.client
460 .post(url)
461 .bearer_auth(&self.api_key)
462 .header("Content-Type", "application/json")
463 }
464}
465
466impl LlmProvider for OpenAiProvider {
467 fn context_window(&self) -> Option<usize> {
468 if self.model.starts_with("gpt-4o") || self.model.starts_with("gpt-4") {
469 Some(128_000)
470 } else if self.model.starts_with("gpt-3.5") {
471 Some(16_385)
472 } else if self.model.starts_with("gpt-5") {
473 Some(1_000_000)
474 } else if starts_with_o_digit(&self.model) {
475 Some(200_000)
476 } else {
477 None
478 }
479 }
480
481 #[cfg_attr(
482 feature = "profiling",
483 tracing::instrument(
484 name = "llm.chat",
485 skip_all,
486 fields(provider = self.name(), model = self.model_identifier())
487 )
488 )]
489 async fn chat(&self, messages: &[Message]) -> Result<String, LlmError> {
490 self.send_request(messages).await
491 }
492
493 async fn chat_with_extras(
494 &self,
495 messages: &[Message],
496 ) -> Result<(String, ChatExtras), LlmError> {
497 Ok((self.send_request(messages).await?, ChatExtras::default()))
498 }
499
500 #[cfg_attr(
501 feature = "profiling",
502 tracing::instrument(
503 name = "llm.chat_stream",
504 skip_all,
505 fields(provider = self.name(), model = self.model_identifier())
506 )
507 )]
508 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
509 let response = self.send_stream_request(messages).await?;
510 Ok(openai_sse_to_stream(response))
511 }
512
513 fn supports_streaming(&self) -> bool {
514 true
515 }
516
517 #[cfg_attr(
518 feature = "profiling",
519 tracing::instrument(
520 name = "llm.embed",
521 skip_all,
522 fields(provider = self.name(), model = self.model_identifier())
523 )
524 )]
525 async fn embed(&self, text: &str) -> Result<Vec<f32>, LlmError> {
526 use crate::embed::truncate_for_embed;
527
528 let model = self
529 .embedding_model
530 .as_deref()
531 .ok_or(LlmError::EmbedUnsupported {
532 provider: "openai".into(),
533 })?;
534
535 let text = truncate_for_embed(text);
536 let body = EmbeddingRequest {
537 input: &text,
538 model,
539 };
540
541 let response = self
542 .openai_post(format!("{}/embeddings", self.base_url))
543 .json(&body)
544 .send()
545 .await?;
546
547 let status = response.status();
548 let body_text = response.text().await.map_err(LlmError::Http)?;
549
550 if !status.is_success() {
551 tracing::error!("OpenAI embedding API error {status}: {body_text}");
552 if status == reqwest::StatusCode::BAD_REQUEST {
553 return Err(LlmError::InvalidInput {
554 provider: "openai".into(),
555 message: body_text,
556 });
557 }
558 return Err(LlmError::ApiError {
559 provider: "openai".into(),
560 status: status.as_u16(),
561 });
562 }
563
564 let resp: EmbeddingResponse = serde_json::from_str(&body_text)?;
565
566 resp.data
567 .first()
568 .map(|d| d.embedding.clone())
569 .ok_or(LlmError::EmptyResponse {
570 provider: "openai".into(),
571 })
572 }
573
574 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, LlmError> {
575 use crate::embed::truncate_for_embed;
576
577 if texts.is_empty() {
578 return Ok(Vec::new());
579 }
580
581 let model = self
582 .embedding_model
583 .as_deref()
584 .ok_or(LlmError::EmbedUnsupported {
585 provider: "openai".into(),
586 })?;
587
588 let truncated: Vec<std::borrow::Cow<'_, str>> =
589 texts.iter().map(|t| truncate_for_embed(t)).collect();
590 let refs: Vec<&str> = truncated.iter().map(std::convert::AsRef::as_ref).collect();
591
592 let body = EmbeddingBatchRequest { model, input: refs };
593
594 let response = self
595 .openai_post(format!("{}/embeddings", self.base_url))
596 .json(&body)
597 .send()
598 .await?;
599
600 let status = response.status();
601 let body_text = response.text().await.map_err(LlmError::Http)?;
602
603 if !status.is_success() {
604 tracing::error!("OpenAI batch embedding API error {status}: {body_text}");
605 if status == reqwest::StatusCode::BAD_REQUEST {
606 return Err(LlmError::InvalidInput {
607 provider: "openai".into(),
608 message: body_text,
609 });
610 }
611 return Err(LlmError::ApiError {
612 provider: "openai".into(),
613 status: status.as_u16(),
614 });
615 }
616
617 let resp: EmbeddingResponse = serde_json::from_str(&body_text)?;
618
619 if resp.data.len() != texts.len() {
620 return Err(LlmError::Other(format!(
621 "OpenAI returned {} embeddings for {} inputs",
622 resp.data.len(),
623 texts.len()
624 )));
625 }
626
627 let mut data = resp.data;
629 data.sort_unstable_by_key(|d| d.index);
630
631 Ok(data.into_iter().map(|d| d.embedding).collect())
632 }
633
634 fn supports_embeddings(&self) -> bool {
635 self.embedding_model.is_some()
636 }
637
638 #[allow(clippy::unnecessary_literal_bound)]
639 fn name(&self) -> &str {
640 "openai"
641 }
642
643 fn model_identifier(&self) -> &str {
644 &self.model
645 }
646
647 fn list_models(&self) -> Vec<String> {
648 vec![self.model.clone()]
649 }
650
651 fn last_cache_usage(&self) -> Option<(u64, u64)> {
652 self.usage.last_cache_usage()
653 }
654
655 fn last_usage(&self) -> Option<(u64, u64)> {
656 self.usage.last_usage()
657 }
658
659 fn last_reasoning_tokens(&self) -> Option<u64> {
660 self.usage.last_reasoning()
661 }
662
663 fn debug_request_json(
664 &self,
665 messages: &[Message],
666 tools: &[ToolDefinition],
667 stream: bool,
668 ) -> serde_json::Value {
669 let reasoning = self
670 .reasoning_effort
671 .as_deref()
672 .map(|effort| Reasoning { effort });
673 let (temperature, top_p, frequency_penalty, presence_penalty) = self
674 .generation_overrides
675 .as_ref()
676 .map(|ov| {
677 (
678 ov.temperature,
679 ov.top_p,
680 ov.frequency_penalty,
681 ov.presence_penalty,
682 )
683 })
684 .unwrap_or_default();
685
686 if !tools.is_empty() {
687 let api_messages = convert_messages_structured(messages);
688 let descriptions: Vec<String> = tools
689 .iter()
690 .map(|t| {
691 build_tool_description(
692 &t.description,
693 t.output_schema.as_ref(),
694 self.forward_output_schema,
695 self.output_schema_hint_bytes,
696 self.max_tool_description_bytes,
697 t.name.as_str(),
698 )
699 })
700 .collect();
701 let api_tools: Vec<OpenAiTool<'_>> = tools
702 .iter()
703 .zip(descriptions.iter())
704 .map(|(t, desc)| OpenAiTool {
705 r#type: "function",
706 function: OpenAiFunction {
707 name: t.name.as_str(),
708 description: desc.as_str(),
709 parameters: prepare_tool_params(&t.parameters),
710 },
711 })
712 .collect();
713 let body = ToolChatRequest {
714 model: &self.model,
715 messages: &api_messages,
716 completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
717 tools: &api_tools,
718 reasoning,
719 temperature,
720 top_p,
721 frequency_penalty,
722 presence_penalty,
723 };
724 return serde_json::to_value(&body)
725 .unwrap_or_else(|e| serde_json::json!({ "serialization_error": e.to_string() }));
726 }
727
728 if has_image_parts(messages) {
729 let vision_messages = convert_messages_vision(messages);
730 let body = VisionChatRequest {
731 model: &self.model,
732 messages: vision_messages,
733 completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
734 stream,
735 reasoning,
736 temperature,
737 top_p,
738 frequency_penalty,
739 presence_penalty,
740 };
741 return serde_json::to_value(&body)
742 .unwrap_or_else(|e| serde_json::json!({ "serialization_error": e.to_string() }));
743 }
744
745 let api_messages = convert_messages(messages);
746 let body = ChatRequest {
747 model: &self.model,
748 messages: &api_messages,
749 completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
750 stream,
751 reasoning,
752 temperature,
753 top_p,
754 frequency_penalty,
755 presence_penalty,
756 };
757 serde_json::to_value(&body)
758 .unwrap_or_else(|e| serde_json::json!({ "serialization_error": e.to_string() }))
759 }
760
761 fn supports_structured_output(&self) -> bool {
762 true
763 }
764
765 async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
766 where
767 T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
768 Self: Sized,
769 {
770 let (raw_schema, _) = crate::provider::cached_schema::<T>()?;
771 let mut schema_value = raw_schema;
772 inline_refs_openai(&mut schema_value, 8);
773 normalize_for_openai_strict(&mut schema_value, 16);
774 let type_name = crate::provider::short_type_name::<T>();
775
776 let api_messages = convert_messages(messages);
777 let body = TypedChatRequest {
778 model: &self.model,
779 messages: &api_messages,
780 completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
781 response_format: ResponseFormat {
782 r#type: "json_schema",
783 json_schema: JsonSchemaFormat {
784 name: type_name,
785 schema: schema_value,
786 strict: true,
787 },
788 },
789 };
790
791 let response = self
792 .openai_post(format!("{}/chat/completions", self.base_url))
793 .json(&body)
794 .send()
795 .await?;
796
797 let status = response.status();
798 let text = response.text().await.map_err(LlmError::Http)?;
799
800 if !status.is_success() {
801 return Err(crate::http::map_error_response(status, &text, "openai"));
802 }
803
804 let resp: OpenAiChatResponse = serde_json::from_str(&text)?;
805
806 if let Some(ref usage) = resp.usage {
807 self.store_cache_usage(usage);
808 }
809
810 let content = resp
811 .choices
812 .first()
813 .map(|c| c.message.content.as_str())
814 .ok_or(LlmError::EmptyResponse {
815 provider: "openai".into(),
816 })?;
817
818 serde_json::from_str::<T>(content).map_err(|e| LlmError::StructuredParse(e.to_string()))
819 }
820
821 fn supports_vision(&self) -> bool {
822 true
823 }
824
825 #[cfg_attr(
826 feature = "profiling",
827 tracing::instrument(
828 name = "llm.chat_with_tools",
829 skip_all,
830 fields(provider = self.name(), model = self.model_identifier(), tool_count = tools.len())
831 )
832 )]
833 async fn chat_with_tools(
834 &self,
835 messages: &[Message],
836 tools: &[ToolDefinition],
837 ) -> Result<ChatResponse, LlmError> {
838 let api_messages = convert_messages_structured(messages);
839 let reasoning = self
840 .reasoning_effort
841 .as_deref()
842 .map(|effort| Reasoning { effort });
843
844 let descriptions: Vec<String> = tools
845 .iter()
846 .map(|t| {
847 build_tool_description(
848 &t.description,
849 t.output_schema.as_ref(),
850 self.forward_output_schema,
851 self.output_schema_hint_bytes,
852 self.max_tool_description_bytes,
853 t.name.as_str(),
854 )
855 })
856 .collect();
857 let api_tools: Vec<OpenAiTool> = tools
858 .iter()
859 .zip(descriptions.iter())
860 .map(|(t, desc)| OpenAiTool {
861 r#type: "function",
862 function: OpenAiFunction {
863 name: t.name.as_str(),
864 description: desc.as_str(),
865 parameters: prepare_tool_params(&t.parameters),
866 },
867 })
868 .collect();
869
870 let (temperature, top_p, frequency_penalty, presence_penalty) = self
871 .generation_overrides
872 .as_ref()
873 .map(|ov| {
874 (
875 ov.temperature,
876 ov.top_p,
877 ov.frequency_penalty,
878 ov.presence_penalty,
879 )
880 })
881 .unwrap_or_default();
882 let body = ToolChatRequest {
883 model: &self.model,
884 messages: &api_messages,
885 completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
886 tools: &api_tools,
887 reasoning,
888 temperature,
889 top_p,
890 frequency_penalty,
891 presence_penalty,
892 };
893
894 let response = self
895 .openai_post(format!("{}/chat/completions", self.base_url))
896 .json(&body)
897 .send()
898 .await?;
899
900 let status = response.status();
901 let text = response.text().await.map_err(LlmError::Http)?;
902
903 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
904 return Err(LlmError::RateLimited);
905 }
906
907 if status == reqwest::StatusCode::BAD_REQUEST {
908 tracing::warn!("OpenAI tool chat 400 bad request: {text}");
909 if crate::error::body_is_context_length_error(&text) {
910 return Err(LlmError::ContextLengthExceeded);
911 }
912 return Err(LlmError::InvalidInput {
913 provider: self.name().to_owned(),
914 message: text,
915 });
916 }
917
918 if !status.is_success() {
919 tracing::error!("OpenAI API error {status}: {text}");
920 return Err(LlmError::ApiError {
921 provider: "openai".into(),
922 status: status.as_u16(),
923 });
924 }
925
926 self.decode_tool_chat_response(&text, "openai")
927 }
928}
929
930impl OpenAiProvider {
931 pub(crate) fn decode_tool_chat_response(
936 &self,
937 text: &str,
938 provider_name: &str,
939 ) -> Result<ChatResponse, LlmError> {
940 let resp: ToolChatResponse = serde_json::from_str(text)?;
941
942 if let Some(ref usage) = resp.usage {
943 self.store_cache_usage(usage);
944 }
945
946 let choice = resp
947 .choices
948 .into_iter()
949 .next()
950 .ok_or(LlmError::EmptyResponse {
951 provider: provider_name.into(),
952 })?;
953
954 if let Some(tool_calls) = choice.message.tool_calls
955 && !tool_calls.is_empty()
956 {
957 let text = if choice.message.content.is_empty() {
958 None
959 } else {
960 Some(choice.message.content)
961 };
962 let calls = tool_calls
963 .into_iter()
964 .map(|tc| {
965 let input = serde_json::from_str(&tc.function.arguments)
966 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
967 ToolUseRequest {
968 id: tc.id,
969 name: tc.function.name.into(),
970 input,
971 }
972 })
973 .collect();
974 return Ok(ChatResponse::ToolUse {
975 text,
976 tool_calls: calls,
977 thinking_blocks: vec![],
978 });
979 }
980
981 let content = if choice.finish_reason.as_deref() == Some("length") {
984 let truncation_marker = crate::provider::MAX_TOKENS_TRUNCATION_MARKER;
985 if choice.message.content.is_empty() {
986 format!(
987 "[Response truncated: {truncation_marker}. Please reduce the request scope.]"
988 )
989 } else {
990 format!(
991 "{}\n[Response truncated: {truncation_marker}.]",
992 choice.message.content
993 )
994 }
995 } else {
996 choice.message.content
997 };
998 Ok(ChatResponse::Text(content))
999 }
1000
/// Serializes the request body for a schema-constrained chat completion for
/// `T` without sending it.
///
/// The schema is prepared the same way as in `chat_typed` (refs inlined,
/// then normalised for strict mode).
///
/// # Errors
///
/// Propagates schema errors and returns `LlmError::StructuredParse` when the
/// body cannot be serialized.
#[cfg(any(feature = "gonka", feature = "cocoon"))]
pub(crate) fn build_typed_chat_body<T>(&self, messages: &[Message]) -> Result<Vec<u8>, LlmError>
where
    T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
{
    let (mut schema, _) = crate::provider::cached_schema::<T>()?;
    inline_refs_openai(&mut schema, 8);
    normalize_for_openai_strict(&mut schema, 16);

    let response_format = ResponseFormat {
        r#type: "json_schema",
        json_schema: JsonSchemaFormat {
            name: crate::provider::short_type_name::<T>(),
            schema,
            strict: true,
        },
    };

    let api_messages = convert_messages(messages);
    let body = TypedChatRequest {
        model: &self.model,
        messages: &api_messages,
        completion_tokens: CompletionTokens::for_model(&self.model, self.max_tokens),
        response_format,
    };

    match serde_json::to_vec(&body) {
        Ok(bytes) => Ok(bytes),
        Err(e) => Err(LlmError::StructuredParse(e.to_string())),
    }
}
1038}
1039
1040#[derive(Serialize)]
1041#[serde(tag = "type", rename_all = "snake_case")]
1042enum OpenAiContentPart {
1043 Text { text: String },
1044 ImageUrl { image_url: ImageUrlDetail },
1045}
1046
1047#[derive(Serialize)]
1048struct ImageUrlDetail {
1049 url: String,
1050}
1051
1052#[derive(Serialize)]
1053struct VisionApiMessage {
1054 role: String,
1055 content: Vec<OpenAiContentPart>,
1056}
1057
1058#[derive(Serialize)]
1059struct VisionChatRequest<'a> {
1060 model: &'a str,
1061 messages: Vec<VisionApiMessage>,
1062 #[serde(flatten)]
1063 completion_tokens: CompletionTokens,
1064 #[serde(skip_serializing_if = "std::ops::Not::not")]
1065 stream: bool,
1066 #[serde(skip_serializing_if = "Option::is_none")]
1067 reasoning: Option<Reasoning<'a>>,
1068 #[serde(skip_serializing_if = "Option::is_none")]
1069 temperature: Option<f64>,
1070 #[serde(skip_serializing_if = "Option::is_none")]
1071 top_p: Option<f64>,
1072 #[serde(skip_serializing_if = "Option::is_none")]
1073 frequency_penalty: Option<f64>,
1074 #[serde(skip_serializing_if = "Option::is_none")]
1075 presence_penalty: Option<f64>,
1076}
1077
1078fn has_image_parts(messages: &[Message]) -> bool {
1079 messages
1080 .iter()
1081 .any(|m| m.parts.iter().any(|p| matches!(p, MessagePart::Image(_))))
1082}
1083
1084fn convert_messages_vision(messages: &[Message]) -> Vec<VisionApiMessage> {
1085 messages
1086 .iter()
1087 .map(|msg| {
1088 let role = match msg.role {
1089 Role::System => "system",
1090 Role::User => "user",
1091 Role::Assistant => "assistant",
1092 };
1093 let has_images = msg.parts.iter().any(|p| matches!(p, MessagePart::Image(_)));
1094 if has_images {
1095 let mut parts = Vec::new();
1096 let text_str: String = msg
1097 .parts
1098 .iter()
1099 .filter_map(MessagePart::as_plain_text)
1100 .collect::<Vec<_>>()
1101 .join("");
1102 if !text_str.is_empty() {
1103 parts.push(OpenAiContentPart::Text { text: text_str });
1104 }
1105 for part in &msg.parts {
1106 if let Some(img) = part.as_image() {
1107 let b64 = STANDARD.encode(&img.data);
1108 parts.push(OpenAiContentPart::ImageUrl {
1109 image_url: ImageUrlDetail {
1110 url: format!("data:{};base64,{b64}", img.mime_type),
1111 },
1112 });
1113 }
1114 }
1115 if parts.is_empty() {
1116 parts.push(OpenAiContentPart::Text {
1117 text: msg.to_llm_content().to_owned(),
1118 });
1119 }
1120 VisionApiMessage {
1121 role: role.to_owned(),
1122 content: parts,
1123 }
1124 } else {
1125 VisionApiMessage {
1126 role: role.to_owned(),
1127 content: vec![OpenAiContentPart::Text {
1128 text: msg.to_llm_content().to_owned(),
1129 }],
1130 }
1131 }
1132 })
1133 .collect()
1134}
1135
1136fn convert_messages(messages: &[Message]) -> Vec<ApiMessage<'_>> {
1137 messages
1138 .iter()
1139 .map(|msg| {
1140 let role = match msg.role {
1141 Role::System => "system",
1142 Role::User => "user",
1143 Role::Assistant => "assistant",
1144 };
1145 ApiMessage {
1146 role,
1147 content: msg.to_llm_content(),
1148 }
1149 })
1150 .collect()
1151}
1152
1153#[derive(Serialize)]
1154struct ChatRequest<'a> {
1155 model: &'a str,
1156 messages: &'a [ApiMessage<'a>],
1157 #[serde(flatten)]
1158 completion_tokens: CompletionTokens,
1159 #[serde(skip_serializing_if = "std::ops::Not::not")]
1160 stream: bool,
1161 #[serde(skip_serializing_if = "Option::is_none")]
1162 reasoning: Option<Reasoning<'a>>,
1163 #[serde(skip_serializing_if = "Option::is_none")]
1164 temperature: Option<f64>,
1165 #[serde(skip_serializing_if = "Option::is_none")]
1166 top_p: Option<f64>,
1167 #[serde(skip_serializing_if = "Option::is_none")]
1168 frequency_penalty: Option<f64>,
1169 #[serde(skip_serializing_if = "Option::is_none")]
1170 presence_penalty: Option<f64>,
1171}
1172
1173#[derive(Serialize)]
1174struct Reasoning<'a> {
1175 effort: &'a str,
1176}
1177
1178#[derive(Serialize)]
1179struct ApiMessage<'a> {
1180 role: &'a str,
1181 content: &'a str,
1182}
1183
1184#[derive(Deserialize)]
1185struct OpenAiChatResponse {
1186 choices: Vec<ChatChoice>,
1187 #[serde(default)]
1188 usage: Option<OpenAiUsage>,
1189}
1190
1191#[derive(Deserialize)]
1192struct OpenAiUsage {
1193 #[serde(default)]
1194 prompt_tokens: u64,
1195 #[serde(default)]
1196 completion_tokens: u64,
1197 #[serde(default)]
1198 prompt_tokens_details: Option<PromptTokensDetails>,
1199 #[serde(default)]
1200 completion_tokens_details: Option<CompletionTokensDetails>,
1201}
1202
1203#[derive(Deserialize)]
1204struct PromptTokensDetails {
1205 #[serde(default)]
1206 cached_tokens: u64,
1207}
1208
1209#[derive(Deserialize)]
1210struct CompletionTokensDetails {
1211 #[serde(default)]
1213 reasoning_tokens: u64,
1214}
1215
1216#[derive(Deserialize)]
1217struct ChatChoice {
1218 message: ChatMessage,
1219}
1220
1221#[derive(Deserialize)]
1222struct ChatMessage {
1223 content: String,
1224}
1225
1226#[derive(Serialize)]
1227struct OpenAiTool<'a> {
1228 r#type: &'a str,
1229 function: OpenAiFunction<'a>,
1230}
1231
1232#[derive(Serialize)]
1233struct OpenAiFunction<'a> {
1234 name: &'a str,
1235 description: &'a str,
1236 #[serde(skip_serializing_if = "Option::is_none")]
1237 parameters: Option<serde_json::Value>,
1238}
1239
1240#[derive(Serialize)]
1241struct ToolChatRequest<'a> {
1242 model: &'a str,
1243 messages: &'a [StructuredApiMessage],
1244 #[serde(flatten)]
1245 completion_tokens: CompletionTokens,
1246 tools: &'a [OpenAiTool<'a>],
1247 #[serde(skip_serializing_if = "Option::is_none")]
1248 reasoning: Option<Reasoning<'a>>,
1249 #[serde(skip_serializing_if = "Option::is_none")]
1250 temperature: Option<f64>,
1251 #[serde(skip_serializing_if = "Option::is_none")]
1252 top_p: Option<f64>,
1253 #[serde(skip_serializing_if = "Option::is_none")]
1254 frequency_penalty: Option<f64>,
1255 #[serde(skip_serializing_if = "Option::is_none")]
1256 presence_penalty: Option<f64>,
1257}
1258
1259#[derive(Serialize)]
1260struct StructuredApiMessage {
1261 role: String,
1262 #[serde(skip_serializing_if = "Option::is_none")]
1263 content: Option<String>,
1264 #[serde(skip_serializing_if = "Option::is_none")]
1265 tool_calls: Option<Vec<OpenAiToolCallOut>>,
1266 #[serde(skip_serializing_if = "Option::is_none")]
1267 tool_call_id: Option<String>,
1268}
1269
1270#[derive(Serialize)]
1271struct OpenAiToolCallOut {
1272 id: String,
1273 r#type: String,
1274 function: OpenAiFunctionCall,
1275}
1276
1277#[derive(Serialize)]
1278struct OpenAiFunctionCall {
1279 name: String,
1280 arguments: String,
1281}
1282
1283#[derive(Deserialize)]
1284struct ToolChatResponse {
1285 choices: Vec<ToolChatChoice>,
1286 #[serde(default)]
1287 usage: Option<OpenAiUsage>,
1288}
1289
1290#[derive(Deserialize)]
1291struct ToolChatChoice {
1292 message: ToolChatMessage,
1293 #[serde(default)]
1294 finish_reason: Option<String>,
1295}
1296
/// Assistant message as returned by a tool-enabled chat call.
#[derive(Deserialize)]
struct ToolChatMessage {
    /// Text content. A missing key and an explicit JSON `null` both become an
    /// empty string (see `deserialize_null_string_as_default`).
    #[serde(default, deserialize_with = "deserialize_null_string_as_default")]
    content: String,
    /// Tool calls requested by the model; `None` when the key is missing.
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCall>>,
}
1304
/// Incoming (deserialized) tool call requested by the model.
#[derive(Deserialize)]
struct OpenAiToolCall {
    /// Identifier of the tool call.
    id: String,
    /// Name and raw arguments of the requested function.
    function: OpenAiToolCallFunction,
}
1310
/// Function name plus arguments of an incoming tool call.
#[derive(Deserialize)]
struct OpenAiToolCallFunction {
    /// Name of the requested function.
    name: String,
    /// Arguments kept as the raw string received; nothing is parsed here.
    arguments: String,
}
1316
1317fn deserialize_null_string_as_default<'de, D>(deserializer: D) -> Result<String, D::Error>
1318where
1319 D: serde::Deserializer<'de>,
1320{
1321 Ok(Option::<String>::deserialize(deserializer)?.unwrap_or_default())
1322}
1323
1324fn convert_messages_structured(messages: &[Message]) -> Vec<StructuredApiMessage> {
1325 let mut result = Vec::new();
1326
1327 for msg in messages {
1328 let has_tool_parts = msg.parts.iter().any(|p| {
1329 matches!(
1330 p,
1331 MessagePart::ToolUse { .. } | MessagePart::ToolResult { .. }
1332 )
1333 });
1334
1335 if has_tool_parts {
1336 if msg.role == Role::Assistant {
1338 let text_content: String = msg
1339 .parts
1340 .iter()
1341 .filter_map(|p| p.as_plain_text())
1342 .collect::<Vec<_>>()
1343 .join("");
1344
1345 let tool_calls: Vec<OpenAiToolCallOut> = msg
1346 .parts
1347 .iter()
1348 .filter_map(|p| match p {
1349 MessagePart::ToolUse { id, name, input } => Some(OpenAiToolCallOut {
1350 id: id.clone(),
1351 r#type: "function".to_owned(),
1352 function: OpenAiFunctionCall {
1353 name: name.clone(),
1354 arguments: serde_json::to_string(input)
1355 .unwrap_or_else(|_| "{}".to_owned()),
1356 },
1357 }),
1358 _ => None,
1359 })
1360 .collect();
1361
1362 result.push(StructuredApiMessage {
1363 role: "assistant".to_owned(),
1364 content: if text_content.is_empty() {
1365 None
1366 } else {
1367 Some(text_content)
1368 },
1369 tool_calls: if tool_calls.is_empty() {
1370 None
1371 } else {
1372 Some(tool_calls)
1373 },
1374 tool_call_id: None,
1375 });
1376 } else {
1377 for part in &msg.parts {
1379 match part {
1380 MessagePart::ToolResult {
1381 tool_use_id,
1382 content,
1383 ..
1384 } => {
1385 result.push(StructuredApiMessage {
1386 role: "tool".to_owned(),
1387 content: Some(content.clone()),
1388 tool_calls: None,
1389 tool_call_id: Some(tool_use_id.clone()),
1390 });
1391 }
1392 other => {
1393 if let Some(text) = other.as_plain_text().filter(|t| !t.is_empty()) {
1394 result.push(StructuredApiMessage {
1395 role: "user".to_owned(),
1396 content: Some(text.to_owned()),
1397 tool_calls: None,
1398 tool_call_id: None,
1399 });
1400 }
1401 }
1402 }
1403 }
1404 }
1405 } else {
1406 let role = match msg.role {
1407 Role::System => "system",
1408 Role::User => "user",
1409 Role::Assistant => "assistant",
1410 };
1411 result.push(StructuredApiMessage {
1412 role: role.to_owned(),
1413 content: Some(msg.to_llm_content().to_owned()),
1414 tool_calls: None,
1415 tool_call_id: None,
1416 });
1417 }
1418 }
1419
1420 result
1421}
1422
/// Serializable request body for embedding a single input string.
#[derive(Serialize)]
struct EmbeddingRequest<'a> {
    /// Text to embed.
    input: &'a str,
    /// Embedding model identifier.
    model: &'a str,
}
1428
/// Deserialized response body of an embedding call.
#[derive(Deserialize)]
struct EmbeddingResponse {
    /// Returned embedding entries.
    data: Vec<EmbeddingData>,
}
1433
/// One embedding entry inside an `EmbeddingResponse`.
#[derive(Deserialize)]
struct EmbeddingData {
    /// Index reported for this entry; falls back to `0` when the key is
    /// missing. Presumably the position of the matching input — confirm
    /// against the caller.
    #[serde(default)]
    index: usize,
    /// The embedding vector.
    embedding: Vec<f32>,
}
1440
/// Serializable request body for embedding several input strings at once.
#[derive(Serialize)]
struct EmbeddingBatchRequest<'a> {
    /// Embedding model identifier.
    model: &'a str,
    /// Texts to embed, serialized as a JSON array of strings.
    input: Vec<&'a str>,
}
1446
/// Serializable request body for a chat call that carries a `response_format`.
#[derive(Serialize)]
struct TypedChatRequest<'a> {
    /// Model identifier.
    model: &'a str,
    /// Conversation history.
    messages: &'a [ApiMessage<'a>],
    /// Token limit. Flattened, so the single key of the chosen
    /// `CompletionTokens` variant is emitted at the top level of the request.
    #[serde(flatten)]
    completion_tokens: CompletionTokens,
    /// Output format descriptor; always serialized.
    response_format: ResponseFormat<'a>,
}
1455
/// Token-limit parameter, serialized under one of two key names.
///
/// Because the enum is `untagged`, each variant serializes as a bare object
/// with a single key (`max_tokens` or `max_completion_tokens`). Request
/// structs flatten it so that key lands at the top level of the body.
#[derive(Serialize)]
#[serde(untagged)]
enum CompletionTokens {
    /// Emits `{"max_tokens": n}`.
    MaxTokens { max_tokens: u32 },
    /// Emits `{"max_completion_tokens": n}`.
    MaxCompletionTokens { max_completion_tokens: u32 },
}
1462
1463impl CompletionTokens {
1464 fn for_model(model: &str, max_tokens: u32) -> Self {
1465 if model.starts_with("gpt-5") || starts_with_o_digit(model) {
1468 Self::MaxCompletionTokens {
1469 max_completion_tokens: max_tokens,
1470 }
1471 } else {
1472 Self::MaxTokens { max_tokens }
1473 }
1474 }
1475}
1476
/// Returns `true` when `model` begins with a lowercase `o` immediately
/// followed by an ASCII digit (for example `o1` or `o3-mini`).
fn starts_with_o_digit(model: &str) -> bool {
    // `o` is a single UTF-8 byte and ASCII digits never appear inside a
    // multi-byte sequence, so inspecting the first two bytes is equivalent to
    // inspecting the first two characters.
    matches!(model.as_bytes(), [b'o', second, ..] if second.is_ascii_digit())
}
1481
/// Serializable `response_format` value carrying a JSON-schema payload.
#[derive(Serialize)]
struct ResponseFormat<'a> {
    /// Format discriminator, serialized under the key `type`.
    r#type: &'a str,
    /// Schema payload describing the expected output.
    json_schema: JsonSchemaFormat<'a>,
}
1487
/// Named JSON schema embedded in a `ResponseFormat`.
#[derive(Serialize)]
struct JsonSchemaFormat<'a> {
    /// Name given to the schema.
    name: &'a str,
    /// The schema itself as raw JSON.
    schema: serde_json::Value,
    /// Value of the `strict` flag sent with the schema.
    strict: bool,
}
1494
1495fn inline_refs_openai(schema: &mut serde_json::Value, depth: u8) {
1497 if depth == 0 {
1498 return;
1499 }
1500 let defs = if let Some(obj) = schema.as_object() {
1501 obj.get("$defs")
1502 .or_else(|| obj.get("definitions"))
1503 .cloned()
1504 .unwrap_or(serde_json::Value::Object(serde_json::Map::default()))
1505 } else {
1506 serde_json::Value::Object(serde_json::Map::default())
1507 };
1508 inline_refs_openai_inner(schema, &defs, depth);
1509 if let Some(obj) = schema.as_object_mut() {
1510 obj.remove("$defs");
1511 obj.remove("definitions");
1512 }
1513}
1514
1515fn inline_refs_openai_inner(schema: &mut serde_json::Value, defs: &serde_json::Value, depth: u8) {
1516 if depth == 0 {
1517 return;
1518 }
1519 if let Some(obj) = schema.as_object()
1520 && let Some(ref_val) = obj.get("$ref").and_then(|v| v.as_str())
1521 {
1522 let name = ref_val
1523 .trim_start_matches("#/$defs/")
1524 .trim_start_matches("#/definitions/");
1525 if let Some(resolved) = defs.get(name) {
1526 let mut resolved = resolved.clone();
1527 inline_refs_openai_inner(&mut resolved, defs, depth - 1);
1528 *schema = resolved;
1529 return;
1530 }
1531 *schema = serde_json::json!({"type": "object"});
1532 return;
1533 }
1534 if let Some(obj) = schema.as_object_mut() {
1535 for v in obj.values_mut() {
1536 inline_refs_openai_inner(v, defs, depth - 1);
1537 }
1538 } else if let Some(arr) = schema.as_array_mut() {
1539 for v in arr.iter_mut() {
1540 inline_refs_openai_inner(v, defs, depth - 1);
1541 }
1542 }
1543}
1544
1545fn is_empty_params_schema(schema: &serde_json::Value) -> bool {
1549 schema.get("type").and_then(|t| t.as_str()) == Some("object")
1550 && schema
1551 .get("properties")
1552 .and_then(|p| p.as_object())
1553 .is_none_or(serde_json::Map::is_empty)
1554}
1555
1556fn prepare_tool_params(params: &serde_json::Value) -> Option<serde_json::Value> {
1562 if is_empty_params_schema(params) {
1563 return None;
1564 }
1565 let mut schema = params.clone();
1566 inline_refs_openai(&mut schema, 8);
1567 normalize_for_openai_strict(&mut schema, 16);
1568 Some(schema)
1569}
1570
1571struct OpenAiStrictVisitor;
1572
1573impl crate::schema::SchemaVisitor for OpenAiStrictVisitor {
1574 fn visit(&mut self, schema: &mut serde_json::Value) -> bool {
1575 let Some(obj) = schema.as_object_mut() else {
1576 return false;
1577 };
1578 let remove_keys: &[&str] = &["$schema", "title", "format", "default", "examples", "$id"];
1579 for key in remove_keys {
1580 obj.remove(*key);
1581 }
1582 let is_object = obj.get("type").and_then(|t| t.as_str()) == Some("object");
1583 if is_object {
1584 obj.insert(
1585 "additionalProperties".to_owned(),
1586 serde_json::Value::Bool(false),
1587 );
1588 let prop_keys: Vec<String> = obj
1589 .get("properties")
1590 .and_then(|p| p.as_object())
1591 .map(|p| p.keys().cloned().collect())
1592 .unwrap_or_default();
1593 if !prop_keys.is_empty() {
1594 obj.insert(
1595 "required".to_owned(),
1596 serde_json::Value::Array(
1597 prop_keys
1598 .into_iter()
1599 .map(serde_json::Value::String)
1600 .collect(),
1601 ),
1602 );
1603 }
1604 }
1605 true
1606 }
1607}
1608
1609fn normalize_for_openai_strict(schema: &mut serde_json::Value, depth: u8) {
1616 crate::schema::walk_schema(schema, &mut OpenAiStrictVisitor, depth);
1617}
1618
1619#[cfg(test)]
1620mod tests;