1use async_trait::async_trait;
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::{Value, json};
10use tracing::{debug, instrument, warn};
11
12#[allow(unused_imports)]
13use cognee_utils::tracing_keys::{COGNEE_LLM_MODEL, COGNEE_LLM_PROVIDER};
14
15use crate::error::{LlmError, LlmResult};
16use crate::llm_trait::Llm;
17use crate::transcriber::{Transcriber, TranscriptionOutput, validate_audio_format};
18use crate::types::{GenerationOptions, GenerationResponse, Message, MessageRole, TokenUsage};
19
/// Adapter for the OpenAI chat-completions API and OpenAI-compatible endpoints.
///
/// Holds the connection settings, retry budgets, and the model names used for
/// chat and audio-transcription requests.
#[derive(Clone)]
pub struct OpenAIAdapter {
    /// Chat model name sent in request bodies.
    model: String,
    /// API key sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    /// Base URL that endpoint paths such as `/chat/completions` are appended to.
    base_url: String,
    /// HTTP client used for every request made by this adapter.
    client: Client,
    /// Upper bound on attempts for each structured-output strategy.
    structured_output_retries: usize,
    /// Number of retries made after the first failed HTTP attempt.
    network_retries: usize,
    /// Model name sent to the audio transcription endpoint.
    transcription_model: String,
}
57
58impl OpenAIAdapter {
    /// Base URL used when `new` is called without an explicit `base_url`.
    pub const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1";
    /// Value `new` assigns to `structured_output_retries`.
    pub const DEFAULT_STRUCTURED_OUTPUT_RETRIES: usize = 5;
    /// Value `new` assigns to `network_retries`.
    pub const DEFAULT_NETWORK_RETRIES: usize = 3;
70
71 pub fn new(
81 model: impl Into<String>,
82 api_key: impl Into<String>,
83 base_url: Option<String>,
84 ) -> LlmResult<Self> {
85 let client = Client::builder()
86 .timeout(std::time::Duration::from_secs(600))
87 .build()
88 .map_err(|e| LlmError::ConfigError(format!("Failed to create HTTP client: {e}")))?;
89
90 let transcription_model =
91 std::env::var("TRANSCRIPTION_MODEL").unwrap_or_else(|_| "whisper-1".to_string());
92
93 let model: String = model.into();
99 let model = model
100 .strip_prefix("openai/")
101 .map(str::to_string)
102 .unwrap_or(model);
103
104 Ok(Self {
105 model,
106 api_key: api_key.into(),
107 base_url: base_url
112 .map(|u| u.trim_end_matches('/').to_string())
113 .unwrap_or_else(|| Self::DEFAULT_BASE_URL.to_string()),
114 client,
115 structured_output_retries: Self::DEFAULT_STRUCTURED_OUTPUT_RETRIES,
116 network_retries: Self::DEFAULT_NETWORK_RETRIES,
117 transcription_model,
118 })
119 }
120
121 pub fn with_structured_output_retries(mut self, retries: u32) -> Self {
125 let retries = usize::try_from(retries).unwrap_or(usize::MAX);
126 self.structured_output_retries = retries.max(1);
127 self
128 }
129
130 pub fn with_network_retries(mut self, retries: u32) -> Self {
134 self.network_retries = usize::try_from(retries).unwrap_or(usize::MAX);
135 self
136 }
137
138 pub fn with_transcription_model(mut self, model: impl Into<String>) -> Self {
140 self.transcription_model = model.into();
141 self
142 }
143
144 fn auth_header(&self) -> String {
146 format!("Bearer {}", self.api_key)
147 }
148
149 fn should_disable_thinking(&self) -> bool {
151 self.model.to_lowercase().starts_with("qwen") && !self.base_url.contains("api.openai.com")
152 }
153
154 fn is_reasoning_model(&self) -> bool {
162 if !self.base_url.contains("api.openai.com") {
163 return false;
164 }
165 let m = self.model.to_lowercase();
166 m.starts_with("gpt-5") || m.starts_with("o1") || m.starts_with("o3") || m.starts_with("o4")
167 }
168
169 fn write_max_tokens(&self, body: &mut Value, value: Option<u32>) {
172 if let Some(v) = value {
173 let key = if self.is_reasoning_model() {
174 "max_completion_tokens"
175 } else {
176 "max_tokens"
177 };
178 body[key] = json!(v);
179 }
180 }
181
182 #[instrument(
192 name = "llm.api_call",
193 level = "info",
194 skip(self, request_body),
195 fields(
196 url = tracing::field::Empty,
197 cognee.llm.model = self.model.as_str(),
198 cognee.llm.provider = "openai",
199 ),
200 )]
201 async fn call_api(&self, request_body: Value) -> LlmResult<OpenAIResponse> {
202 let url = format!("{}/chat/completions", self.base_url);
203 tracing::Span::current().record("url", url.as_str());
204 let debug_enabled = std::env::var("COGNEE_DEBUG_LLM_REQUEST")
205 .map(|v| cognee_utils::parse_env_bool(&v))
206 .unwrap_or(false);
207
208 if debug_enabled {
209 let pretty_request = serde_json::to_string_pretty(&request_body)
210 .unwrap_or_else(|_| request_body.to_string());
211 eprintln!("\n[COGNEE_DEBUG_LLM_REQUEST] POST {url}\n{pretty_request}\n");
212 }
213
214 let mut last_error = LlmError::NetworkError("No attempt made".to_string());
215
216 for attempt in 0..=self.network_retries {
217 debug!(attempt, "LLM API attempt");
218 if attempt > 0 {
219 let delay_ms = (1_000u64 * 2u64.saturating_pow(attempt as u32 - 1)).min(30_000);
220 warn!(
221 attempt,
222 network_retries = self.network_retries,
223 delay_ms,
224 error = %last_error,
225 "LLM request failed, retrying",
226 );
227 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
228 }
229
230 let response = match self
231 .client
232 .post(&url)
233 .header("Authorization", self.auth_header())
234 .header("Content-Type", "application/json")
235 .json(&request_body)
236 .send()
237 .await
238 {
239 Ok(r) => r,
240 Err(e) => {
241 last_error = LlmError::NetworkError(e.to_string());
242 continue;
243 }
244 };
245
246 let status = response.status();
247
248 if !status.is_success() {
249 let error_body = response
250 .text()
251 .await
252 .unwrap_or_else(|_| "Unknown error".to_string());
253
254 let err = match status.as_u16() {
255 401 => LlmError::AuthenticationError(error_body),
256 429 => LlmError::RateLimitExceeded(error_body),
257 400 => LlmError::InvalidResponse(format!("Bad request: {error_body}")),
258 _ => LlmError::ApiError(format!("HTTP {status}: {error_body}")),
259 };
260
261 if matches!(status.as_u16(), 400 | 401) {
263 return Err(err);
264 }
265
266 last_error = err;
267 continue;
268 }
269
270 let response_body = response.text().await.map_err(|e| {
271 LlmError::DeserializationError(format!("Failed to read response body: {e}"))
272 })?;
273
274 if debug_enabled {
275 eprintln!("\n[COGNEE_DEBUG_LLM_RESPONSE] POST {url}\n{response_body}\n");
276 }
277
278 return serde_json::from_str::<OpenAIResponse>(&response_body).map_err(|e| {
279 LlmError::DeserializationError(format!(
280 "Failed to parse response: {e}. Raw body: {response_body}"
281 ))
282 });
283 }
284
285 Err(LlmError::MaxRetriesExceeded(format!(
286 "LLM request failed after {} attempt(s): {}",
287 self.network_retries + 1,
288 last_error
289 )))
290 }
291
292 fn convert_messages(messages: &[Message]) -> Vec<Value> {
294 messages
295 .iter()
296 .map(|msg| {
297 json!({
298 "role": match msg.role {
299 MessageRole::System => "system",
300 MessageRole::User => "user",
301 MessageRole::Assistant => "assistant",
302 },
303 "content": msg.content
304 })
305 })
306 .collect()
307 }
308
309 fn schema_to_example(schema: &Value) -> String {
312 fn create_example(value: &Value, definitions: Option<&Value>) -> Value {
313 match value {
314 Value::Object(obj) => {
315 if let Some(ref_str) = obj.get("$ref").and_then(|v| v.as_str())
317 && let Some(def_name) = ref_str.strip_prefix("#/definitions/")
318 && let Some(defs) = definitions
319 && let Some(def) = defs.get(def_name)
320 {
321 return create_example(def, definitions);
322 }
323
324 let type_val = obj.get("type");
326
327 if let Some(Value::String(t)) = type_val
329 && t == "array"
330 {
331 if let Some(items) = obj.get("items") {
332 return json!([create_example(items, definitions)]);
334 }
335 return json!([]);
336 }
337
338 if let Some(props) = obj.get("properties")
340 && let Value::Object(props_obj) = props
341 {
342 let mut result = serde_json::Map::new();
343 for (key, val) in props_obj {
344 result.insert(key.clone(), create_example(val, definitions));
345 }
346 return Value::Object(result);
347 }
348
349 if let Some(Value::String(t)) = type_val {
351 return match t.as_str() {
352 "string" => json!("example"),
353 "number" | "integer" => json!(0),
354 "boolean" => json!(false),
355 _ => json!(null),
356 };
357 }
358
359 if let Some(Value::Array(types)) = type_val {
361 for t in types {
362 if let Value::String(type_str) = t
363 && type_str != "null"
364 {
365 return match type_str.as_str() {
366 "string" => json!("example"),
367 "number" | "integer" => json!(0),
368 "boolean" => json!(false),
369 _ => json!(null),
370 };
371 }
372 }
373 }
374
375 json!(null)
376 }
377 _ => value.clone(),
378 }
379 }
380
381 let definitions = schema.get("definitions");
382 let example = create_example(schema, definitions);
383
384 serde_json::to_string_pretty(&example).unwrap_or_else(|_| "{}".to_string())
385 }
386}
387
388fn to_strict_schema(schema: &Value) -> Value {
407 fn walk(value: &mut Value) {
408 match value {
409 Value::Object(obj) => {
410 if let Some(Value::Object(props)) = obj.get("properties") {
411 let keys: Vec<Value> = props.keys().map(|k| Value::String(k.clone())).collect();
413 obj.insert("required".to_string(), Value::Array(keys));
414 obj.insert("additionalProperties".to_string(), Value::Bool(false));
415 }
416 for (_k, v) in obj.iter_mut() {
417 walk(v);
418 }
419 }
420 Value::Array(items) => {
421 for v in items.iter_mut() {
422 walk(v);
423 }
424 }
425 _ => {}
426 }
427 }
428
429 let mut out = schema.clone();
430 walk(&mut out);
431 out
432}
433
434#[async_trait]
435impl Llm for OpenAIAdapter {
436 async fn generate(
437 &self,
438 messages: Vec<Message>,
439 options: Option<GenerationOptions>,
440 ) -> LlmResult<GenerationResponse> {
441 let opts = options.unwrap_or_default();
442
443 let mut request_body = json!({
444 "model": self.model,
445 "messages": Self::convert_messages(&messages),
446 });
447
448 if !self.is_reasoning_model() {
451 if let Some(temp) = opts.temperature {
452 request_body["temperature"] = json!(temp);
453 }
454 if let Some(top_p) = opts.top_p {
455 request_body["top_p"] = json!(top_p);
456 }
457 if let Some(freq_penalty) = opts.frequency_penalty {
458 request_body["frequency_penalty"] = json!(freq_penalty);
459 }
460 if let Some(pres_penalty) = opts.presence_penalty {
461 request_body["presence_penalty"] = json!(pres_penalty);
462 }
463 }
464 self.write_max_tokens(&mut request_body, opts.max_tokens);
465 if let Some(stop) = opts.stop
466 && !stop.is_empty()
467 {
468 request_body["stop"] = json!(stop);
469 }
470
471 if self.should_disable_thinking() {
472 request_body["think"] = json!(false);
473 request_body["reasoning"] = json!({"effort": "none"});
474 }
475
476 let response = self.call_api(request_body).await?;
477
478 let choice = response
480 .choices
481 .first()
482 .ok_or_else(|| LlmError::InvalidResponse("No choices in response".to_string()))?;
483
484 Ok(GenerationResponse {
485 content: choice.message.content.clone().unwrap_or_default(),
486 model: response.model,
487 finish_reason: choice.finish_reason.clone(),
488 usage: response.usage.map(|u| TokenUsage {
489 prompt_tokens: u.prompt_tokens,
490 completion_tokens: u.completion_tokens,
491 total_tokens: u.total_tokens,
492 }),
493 })
494 }
495
    /// Requests a JSON value shaped by `json_schema`, trying three strategies in order.
    ///
    /// 1. `response_format: json_schema` with a strict copy of the schema.
    /// 2. Function calling, with the schema as the function's `parameters`.
    /// 3. `response_format: json_object`, with an example document appended to the
    ///    last user message.
    ///
    /// Each strategy makes at most `structured_output_retries` attempts. Only
    /// `temperature` and `max_tokens` are forwarded from `options`.
    ///
    /// # Errors
    /// * Strategy 1 request failures are logged and trigger the fallback; request
    ///   failures in strategies 2 and 3 are propagated.
    /// * [`LlmError::InvalidResponse`] — a response had no choices, or the JSON-mode
    ///   response had no content.
    /// * [`LlmError::DeserializationError`] — the final JSON-mode attempt returned
    ///   content that does not parse.
    async fn create_structured_output_with_messages_raw(
        &self,
        messages: Vec<Message>,
        json_schema: &Value,
        options: Option<GenerationOptions>,
    ) -> LlmResult<Value> {
        // True when the text is blank or is not valid JSON.
        let is_empty_or_non_json = |raw: &str| {
            let trimmed = raw.trim();
            trimmed.is_empty() || serde_json::from_str::<Value>(trimmed).is_err()
        };

        let parse_json =
            |raw: &str| -> Result<Value, serde_json::Error> { serde_json::from_str(raw) };

        let opts = options.unwrap_or_default();
        let schema = json_schema;

        let strict_schema = to_strict_schema(schema);

        // Strategy 1: strict `json_schema` response format.
        let mut strict_schema_request = json!({
            "model": self.model,
            "messages": Self::convert_messages(&messages),
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "extract_structured_data",
                    "schema": strict_schema,
                    "strict": true
                }
            }
        });

        if !self.is_reasoning_model()
            && let Some(temp) = opts.temperature
        {
            strict_schema_request["temperature"] = json!(temp);
        }
        self.write_max_tokens(&mut strict_schema_request, opts.max_tokens);
        if self.should_disable_thinking() {
            strict_schema_request["think"] = json!(false);
            strict_schema_request["reasoning"] = json!({"effort": "none"});
        }

        for attempt in 0..self.structured_output_retries {
            match self.call_api(strict_schema_request.clone()).await {
                Ok(strict_response) => {
                    let strict_choice = strict_response.choices.first().ok_or_else(|| {
                        LlmError::InvalidResponse(
                            "No choices in strict schema response".to_string(),
                        )
                    })?;

                    if let Some(function_call) = &strict_choice.message.function_call {
                        match parse_json(&function_call.arguments) {
                            Ok(parsed) => return Ok(parsed),
                            Err(e) => {
                                // Retry while attempts remain; otherwise fall back.
                                if attempt + 1 < self.structured_output_retries
                                    && is_empty_or_non_json(&function_call.arguments)
                                {
                                    continue;
                                }
                                // NOTE(review): parsing this same text just failed, so this
                                // branch looks unreachable — confirm.
                                if !is_empty_or_non_json(&function_call.arguments) {
                                    return Err(LlmError::DeserializationError(format!(
                                        "Failed to deserialize strict function call arguments: {}. Raw: {}",
                                        e, function_call.arguments
                                    )));
                                }
                                break;
                            }
                        }
                    }

                    if let Some(content) = strict_choice.message.content.as_ref() {
                        match parse_json(content) {
                            Ok(parsed) => return Ok(parsed),
                            Err(e) => {
                                if attempt + 1 < self.structured_output_retries
                                    && is_empty_or_non_json(content)
                                {
                                    continue;
                                }
                                if !is_empty_or_non_json(content) {
                                    return Err(LlmError::DeserializationError(format!(
                                        "Failed to deserialize strict JSON content: {e}. Raw: {content}"
                                    )));
                                }
                                break;
                            }
                        }
                    }
                    // Neither a function call nor content: the loop moves on to the
                    // next attempt.
                }
                Err(e) => {
                    warn!(error = %e, "strict json_schema request failed; falling back to function/JSON mode");
                    break;
                }
            }
        }

        // Strategy 2: function calling with the caller's original schema.
        let mut request_body = json!({
            "model": self.model,
            "messages": Self::convert_messages(&messages),
            "functions": [{
                "name": "extract_structured_data",
                "description": "Extract structured data from the input",
                "parameters": schema
            }],
            "function_call": {"name": "extract_structured_data"}
        });

        if !self.is_reasoning_model()
            && let Some(temp) = opts.temperature
        {
            request_body["temperature"] = json!(temp);
        }
        self.write_max_tokens(&mut request_body, opts.max_tokens);
        if self.should_disable_thinking() {
            request_body["think"] = json!(false);
            request_body["reasoning"] = json!({"effort": "none"});
        }

        for attempt in 0..self.structured_output_retries {
            // Unlike strategy 1, a request failure here is returned to the caller.
            let response = self.call_api(request_body.clone()).await?;

            let choice = response
                .choices
                .first()
                .ok_or_else(|| LlmError::InvalidResponse("No choices in response".to_string()))?;

            if let Some(function_call) = &choice.message.function_call {
                match parse_json(&function_call.arguments) {
                    Ok(parsed) => return Ok(parsed),
                    Err(e) => {
                        if attempt + 1 < self.structured_output_retries
                            && is_empty_or_non_json(&function_call.arguments)
                        {
                            continue;
                        }
                        if !is_empty_or_non_json(&function_call.arguments) {
                            return Err(LlmError::DeserializationError(format!(
                                "Failed to deserialize function call arguments: {}. Raw: {}",
                                e, function_call.arguments
                            )));
                        }
                        break;
                    }
                }
            }

            // No function call in the response: go straight to JSON mode.
            break;
        }

        // Strategy 3: JSON mode, guided by an example derived from the schema.
        let mut json_messages = Self::convert_messages(&messages);

        let example = Self::schema_to_example(schema);

        // The instructions are only appended when the last message is a user message.
        if let Some(last_msg) = json_messages.last_mut()
            && last_msg["role"] == "user"
        {
            let original_content = last_msg["content"].as_str().unwrap_or("");
            last_msg["content"] = json!(format!(
                "{}\n\n\
                 Extract the information from the text above and return it as JSON.\n\
                 Use this structure as your template (but with actual data from the text):\n\
                 {}",
                original_content, example
            ));
        }

        let mut json_request = json!({
            "model": self.model,
            "messages": json_messages,
            "response_format": {"type": "json_object"}
        });

        if !self.is_reasoning_model()
            && let Some(temp) = opts.temperature
        {
            json_request["temperature"] = json!(temp);
        }
        self.write_max_tokens(&mut json_request, opts.max_tokens);
        if self.should_disable_thinking() {
            json_request["think"] = json!(false);
            json_request["reasoning"] = json!({"effort": "none"});
        }

        for attempt in 0..self.structured_output_retries {
            let mut request_for_attempt = json_request.clone();

            // Retries add a stricter instruction and, for non-reasoning models,
            // force the temperature to 0.
            if attempt > 0 {
                if let Some(messages) = request_for_attempt["messages"].as_array_mut()
                    && let Some(last_msg) = messages.last_mut()
                    && last_msg["role"] == "user"
                {
                    let original_content = last_msg["content"].as_str().unwrap_or("");
                    last_msg["content"] = json!(format!(
                        "{}\n\n/no_think\nReturn ONLY one valid JSON object matching the required schema. No reasoning, no markdown, no extra text.",
                        original_content
                    ));
                }

                if !self.is_reasoning_model() {
                    request_for_attempt["temperature"] = json!(0.0);
                }
            }

            let json_response = self.call_api(request_for_attempt).await?;

            let json_choice = json_response.choices.first().ok_or_else(|| {
                LlmError::InvalidResponse("No choices in JSON mode response".to_string())
            })?;

            let content = json_choice.message.content.as_ref().ok_or_else(|| {
                LlmError::InvalidResponse("No content in JSON mode response".to_string())
            })?;

            match parse_json(content) {
                Ok(parsed) => return Ok(parsed),
                Err(e) => {
                    if attempt + 1 < self.structured_output_retries && is_empty_or_non_json(content)
                    {
                        continue;
                    }
                    return Err(LlmError::DeserializationError(format!(
                        "Failed to deserialize JSON content: {e}. Raw: {content}"
                    )));
                }
            }
        }

        // Reached only if the JSON-mode loop ran zero times.
        Err(LlmError::InvalidResponse(
            "Structured output retries exhausted without a parseable response".to_string(),
        ))
    }
741
742 fn model(&self) -> &str {
743 &self.model
744 }
745
    /// Always reports streaming as supported.
    fn supports_streaming(&self) -> bool {
        true
    }
749
    /// Always reports function calling as supported.
    fn supports_function_calling(&self) -> bool {
        true
    }
753
754 fn max_context_length(&self) -> u32 {
755 match self.model.as_str() {
757 m if m.starts_with("gpt-4-turbo") => 128_000,
758 m if m.starts_with("gpt-4-32k") => 32_768,
759 m if m.starts_with("gpt-4") => 8_192,
760 m if m.starts_with("gpt-3.5-turbo-16k") => 16_384,
761 m if m.starts_with("gpt-3.5-turbo") => 4_096,
762 _ => 4_096, }
764 }
765
766 async fn transcribe_image(
767 &self,
768 image_bytes: &[u8],
769 mime_type: &str,
770 options: Option<GenerationOptions>,
771 ) -> LlmResult<String> {
772 use base64::Engine as _;
773
774 if !mime_type.starts_with("image/") {
775 return Err(LlmError::InvalidResponse(format!(
776 "Expected image/* MIME type, got: {mime_type}"
777 )));
778 }
779
780 let b64 = base64::engine::general_purpose::STANDARD.encode(image_bytes);
781 let data_uri = format!("data:{mime_type};base64,{b64}");
782
783 let vision_model = std::env::var("LLM_VISION_MODEL")
784 .ok()
785 .filter(|s| !s.is_empty())
786 .unwrap_or_else(|| self.model.clone());
787
788 let max_tokens = options.as_ref().and_then(|o| o.max_tokens).unwrap_or(300);
789
790 let mut request_body = json!({
791 "model": vision_model,
792 "messages": [{
793 "role": "user",
794 "content": [
795 { "type": "text", "text": "What's in this image?" },
796 { "type": "image_url", "image_url": { "url": data_uri } }
797 ]
798 }],
799 });
800 self.write_max_tokens(&mut request_body, Some(max_tokens));
801
802 let response = self.call_api(request_body).await?;
803
804 let choice = response.choices.first().ok_or_else(|| {
805 LlmError::InvalidResponse("No choices in vision response".to_string())
806 })?;
807
808 choice.message.content.clone().ok_or_else(|| {
809 LlmError::InvalidResponse("Vision response contained no content".to_string())
810 })
811 }
812
813 fn supports_vision(&self) -> bool {
814 let m = self.model.to_lowercase();
815 m.contains("gpt-4")
816 || m.contains("gpt-5")
817 || m.contains("vision")
818 || m.contains("o1")
819 || m.contains("o3")
820 || m.contains("o4")
821 || m.contains("llava")
822 || m.contains("moondream")
823 || m.contains("llama-3.2-vision")
824 || m.contains("gemma3")
825 }
826}
827
/// JSON body returned by the audio transcription endpoint.
#[derive(Debug, Deserialize)]
struct WhisperResponse {
    /// Transcribed text.
    text: String,
    /// Language reported by the service, when present in the response.
    language: Option<String>,
    /// Duration reported by the service, when present in the response.
    duration: Option<f32>,
}
839
/// Maps a lower-case audio format/extension to the MIME type used for the
/// multipart upload.
///
/// Covers the MPEG, MP4, WAV, WebM, FLAC and Ogg families. Any other format maps
/// to `application/octet-stream`.
fn audio_mime_type(format: &str) -> &'static str {
    match format {
        "mp3" | "mpeg" | "mpga" => "audio/mpeg",
        "mp4" | "m4a" => "audio/mp4",
        "wav" => "audio/wav",
        "webm" => "audio/webm",
        "flac" => "audio/flac",
        "ogg" | "oga" => "audio/ogg",
        _ => "application/octet-stream",
    }
}
851
852impl OpenAIAdapter {
853 #[instrument(
855 name = "llm.transcription_api_call",
856 level = "info",
857 skip(self, form),
858 fields(
859 url = tracing::field::Empty,
860 cognee.llm.model = self.transcription_model.as_str(),
861 cognee.llm.provider = "openai",
862 ),
863 )]
864 async fn call_transcription_api(
865 &self,
866 form: reqwest::multipart::Form,
867 ) -> LlmResult<WhisperResponse> {
868 let url = format!("{}/audio/transcriptions", self.base_url);
869 tracing::Span::current().record("url", url.as_str());
870
871 let response = self
888 .client
889 .post(&url)
890 .header("Authorization", self.auth_header())
891 .multipart(form)
892 .send()
893 .await
894 .map_err(|e| LlmError::NetworkError(e.to_string()))?;
895
896 let status = response.status();
897
898 if !status.is_success() {
899 let error_body = response
900 .text()
901 .await
902 .unwrap_or_else(|_| "Unknown error".to_string());
903
904 return Err(match status.as_u16() {
905 401 => LlmError::AuthenticationError(error_body),
906 429 => LlmError::RateLimitExceeded(error_body),
907 400 => LlmError::InvalidResponse(format!("Bad request: {error_body}")),
908 _ => LlmError::ApiError(format!("HTTP {status}: {error_body}")),
909 });
910 }
911
912 let response_body = response.text().await.map_err(|e| {
913 LlmError::DeserializationError(format!("Failed to read response body: {e}"))
914 })?;
915
916 serde_json::from_str::<WhisperResponse>(&response_body).map_err(|e| {
917 LlmError::DeserializationError(format!(
918 "Failed to parse Whisper response: {e}. Raw body: {response_body}"
919 ))
920 })
921 }
922
923 fn build_transcription_form(
925 &self,
926 audio: &[u8],
927 format: &str,
928 language_hint: Option<&str>,
929 prompt_hint: Option<&str>,
930 ) -> LlmResult<reqwest::multipart::Form> {
931 let mime = audio_mime_type(format);
932 let filename = format!("audio.{format}");
933
934 let file_part = reqwest::multipart::Part::bytes(audio.to_vec())
935 .file_name(filename)
936 .mime_str(mime)
937 .map_err(|e| {
938 LlmError::ConfigError(format!("Failed to set MIME type on multipart part: {e}"))
939 })?;
940
941 let mut form = reqwest::multipart::Form::new()
942 .part("file", file_part)
943 .text("model", self.transcription_model.clone())
944 .text("response_format", "verbose_json");
945
946 if let Some(lang) = language_hint {
947 form = form.text("language", lang.to_string());
948 }
949 if let Some(prompt) = prompt_hint {
950 form = form.text("prompt", prompt.to_string());
951 }
952
953 Ok(form)
954 }
955}
956
957#[async_trait]
958impl Transcriber for OpenAIAdapter {
959 async fn transcribe_audio(
960 &self,
961 audio: &[u8],
962 format: &str,
963 language_hint: Option<&str>,
964 prompt_hint: Option<&str>,
965 ) -> LlmResult<TranscriptionOutput> {
966 let format_lower = format.to_ascii_lowercase();
968 validate_audio_format(&format_lower)?;
969
970 let mut last_error = LlmError::NetworkError("No attempt made".to_string());
971
972 for attempt in 0..=self.network_retries {
973 debug!(attempt, "Transcription API attempt");
974 if attempt > 0 {
975 let delay_ms = (1_000u64 * 2u64.saturating_pow(attempt as u32 - 1)).min(30_000);
976 warn!(
977 attempt,
978 network_retries = self.network_retries,
979 delay_ms,
980 error = %last_error,
981 "Transcription request failed, retrying",
982 );
983 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
984 }
985
986 let form =
987 self.build_transcription_form(audio, &format_lower, language_hint, prompt_hint)?;
988
989 match self.call_transcription_api(form).await {
990 Ok(resp) => {
991 return Ok(TranscriptionOutput {
992 text: resp.text,
993 language: resp.language,
994 duration: resp.duration,
995 });
996 }
997 Err(e) => {
998 if matches!(
1000 e,
1001 LlmError::InvalidResponse(_) | LlmError::AuthenticationError(_)
1002 ) {
1003 return Err(e);
1004 }
1005 last_error = e;
1006 continue;
1007 }
1008 }
1009 }
1010
1011 Err(LlmError::MaxRetriesExceeded(format!(
1012 "Transcription request failed after {} attempt(s): {}",
1013 self.network_retries + 1,
1014 last_error
1015 )))
1016 }
1017
1018 fn transcription_model(&self) -> &str {
1019 &self.transcription_model
1020 }
1021}
1022
1023#[derive(Debug, Deserialize)]
1025#[allow(dead_code)]
1026struct OpenAIResponse {
1027 id: String,
1028 object: String,
1029 created: i64,
1030 model: String,
1031 choices: Vec<OpenAIChoice>,
1032 usage: Option<OpenAIUsage>,
1033}
1034
1035#[derive(Debug, Deserialize)]
1036#[allow(dead_code)]
1037struct OpenAIChoice {
1038 index: u32,
1039 message: OpenAIMessage,
1040 finish_reason: Option<String>,
1041}
1042
1043#[derive(Debug, Deserialize)]
1044#[allow(dead_code)]
1045struct OpenAIMessage {
1046 role: String,
1047 content: Option<String>,
1048 reasoning: Option<String>,
1049 function_call: Option<OpenAIFunctionCall>,
1050}
1051
/// Function call returned inside an assistant message.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct OpenAIFunctionCall {
    /// Name of the called function (not read by this file's visible code).
    name: String,
    /// Call arguments as a string; parsed as JSON by the structured-output path.
    arguments: String,
}
1058
/// Token accounting block of a chat-completions response; copied field by field
/// into `TokenUsage`.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    /// Value of the response's `prompt_tokens` field.
    prompt_tokens: u32,
    /// Value of the response's `completion_tokens` field.
    completion_tokens: u32,
    /// Value of the response's `total_tokens` field.
    total_tokens: u32,
}
1065
/// Unit tests for adapter construction, model classification, request-body
/// helpers, and schema utilities. None of them perform network I/O.
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use super::*;

    /// Only the `openai/` provider prefix is stripped from model names.
    #[test]
    fn test_openai_provider_prefix_is_stripped() {
        let adapter = OpenAIAdapter::new("openai/gpt-5-mini", "test-key", None).unwrap();
        assert_eq!(adapter.model(), "gpt-5-mini");
        let adapter = OpenAIAdapter::new("ollama/llama3", "test-key", None).unwrap();
        assert_eq!(adapter.model(), "ollama/llama3");
    }

    #[test]
    fn test_openai_adapter_creation() {
        let adapter = OpenAIAdapter::new("gpt-4", "test-key", None);
        assert!(adapter.is_ok());

        let adapter = adapter.unwrap();
        assert_eq!(adapter.model(), "gpt-4");
        assert_eq!(adapter.base_url, OpenAIAdapter::DEFAULT_BASE_URL);
        assert_eq!(
            adapter.structured_output_retries,
            OpenAIAdapter::DEFAULT_STRUCTURED_OUTPUT_RETRIES
        );
    }

    #[test]
    fn test_configurable_structured_output_retries() {
        // Use a value different from the default so a no-op builder would fail.
        let custom = 7;
        assert_ne!(
            usize::try_from(custom).unwrap(),
            OpenAIAdapter::DEFAULT_STRUCTURED_OUTPUT_RETRIES
        );
        let adapter = OpenAIAdapter::new("gpt-4", "test-key", None)
            .unwrap()
            .with_structured_output_retries(custom);
        assert_eq!(adapter.structured_output_retries, 7);

        // Zero is raised to one so each strategy runs at least once.
        let adapter = OpenAIAdapter::new("gpt-4", "test-key", None)
            .unwrap()
            .with_structured_output_retries(0);
        assert_eq!(adapter.structured_output_retries, 1);
    }

    #[test]
    fn test_openai_adapter_custom_base_url() {
        let adapter = OpenAIAdapter::new(
            "gpt-4",
            "test-key",
            Some("https://custom.api.com/v1".to_string()),
        );
        assert!(adapter.is_ok());

        let adapter = adapter.unwrap();
        assert_eq!(adapter.base_url, "https://custom.api.com/v1");
    }

    #[test]
    fn test_base_url_trailing_slash_is_normalized() {
        let adapter = OpenAIAdapter::new(
            "gemini-2.0-flash",
            "test-key",
            Some("https://generativelanguage.googleapis.com/v1beta/openai/".to_string()),
        )
        .unwrap();
        assert_eq!(
            adapter.base_url,
            "https://generativelanguage.googleapis.com/v1beta/openai"
        );
    }

    #[test]
    fn test_is_reasoning_model_matches_openai_families() {
        let cases = [
            ("gpt-5", true),
            ("gpt-5-mini", true),
            ("gpt-5-2025-06-01", true),
            ("o1", true),
            ("o1-mini", true),
            ("o3", true),
            ("o3-mini", true),
            ("o4-mini", true),
            ("GPT-5-Mini", true),
            ("gpt-4o-mini", false),
            ("gpt-4-turbo", false),
            ("gpt-3.5-turbo", false),
            ("o-foo", false),
        ];
        for (model, expected) in cases {
            let adapter = OpenAIAdapter::new(model, "test-key", None).unwrap();
            assert_eq!(
                adapter.is_reasoning_model(),
                expected,
                "is_reasoning_model({model})"
            );
        }
    }

    /// Reasoning-model detection only applies to the official OpenAI endpoint.
    #[test]
    fn test_is_reasoning_model_skipped_for_custom_base_url() {
        let adapter = OpenAIAdapter::new(
            "gpt-5-mini",
            "test-key",
            Some("http://localhost:11434/v1".to_string()),
        )
        .unwrap();
        assert!(!adapter.is_reasoning_model());
    }

    #[test]
    fn test_write_max_tokens_renames_key_for_reasoning_models() {
        let mut body = json!({"model": "gpt-5-mini"});
        let reasoning = OpenAIAdapter::new("gpt-5-mini", "test-key", None).unwrap();
        reasoning.write_max_tokens(&mut body, Some(2048));
        assert!(body.get("max_tokens").is_none());
        assert_eq!(body["max_completion_tokens"], 2048);

        let mut body = json!({"model": "gpt-4o-mini"});
        let classic = OpenAIAdapter::new("gpt-4o-mini", "test-key", None).unwrap();
        classic.write_max_tokens(&mut body, Some(2048));
        assert_eq!(body["max_tokens"], 2048);
        assert!(body.get("max_completion_tokens").is_none());

        // `None` must leave the body untouched.
        let mut body = json!({"model": "gpt-5-mini"});
        reasoning.write_max_tokens(&mut body, None);
        assert!(body.get("max_tokens").is_none());
        assert!(body.get("max_completion_tokens").is_none());
    }

    #[test]
    fn test_message_conversion() {
        let messages = vec![
            Message {
                role: MessageRole::System,
                content: "You are helpful".to_string(),
            },
            Message {
                role: MessageRole::User,
                content: "Hello".to_string(),
            },
        ];

        let converted = OpenAIAdapter::convert_messages(&messages);
        assert_eq!(converted.len(), 2);
        assert_eq!(converted[0]["role"], "system");
        assert_eq!(converted[0]["content"], "You are helpful");
        assert_eq!(converted[1]["role"], "user");
        assert_eq!(converted[1]["content"], "Hello");
    }

    #[test]
    fn test_context_length() {
        let adapter = OpenAIAdapter::new("gpt-4-turbo-preview", "key", None).unwrap();
        assert_eq!(adapter.max_context_length(), 128_000);

        let adapter = OpenAIAdapter::new("gpt-4", "key", None).unwrap();
        assert_eq!(adapter.max_context_length(), 8_192);

        let adapter = OpenAIAdapter::new("gpt-3.5-turbo-16k", "key", None).unwrap();
        assert_eq!(adapter.max_context_length(), 16_384);
    }

    #[test]
    fn test_supports_vision_gpt4o() {
        let adapter = OpenAIAdapter::new("gpt-4o", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gpt4_turbo() {
        let adapter = OpenAIAdapter::new("gpt-4-turbo", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gpt4o_mini() {
        let adapter = OpenAIAdapter::new("gpt-4o-mini", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gpt35_is_false() {
        let adapter = OpenAIAdapter::new("gpt-3.5-turbo", "key", None).unwrap();
        assert!(!adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_llava() {
        let adapter = OpenAIAdapter::new("llava:13b", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_o1() {
        let adapter = OpenAIAdapter::new("o1-preview", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gemma3() {
        let adapter = OpenAIAdapter::new("gemma3:12b", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    /// The MIME check runs before any request, so no network access is needed.
    #[tokio::test]
    async fn transcribe_image_rejects_non_image_mime() {
        let adapter = OpenAIAdapter::new("gpt-4o", "fake-key", None).unwrap();
        let result = adapter
            .transcribe_image(b"not-an-image", "text/plain", None)
            .await;
        assert!(result.is_err());
        assert!(
            matches!(result.unwrap_err(), LlmError::InvalidResponse(_)),
            "Expected InvalidResponse for non-image MIME type"
        );
    }

    #[test]
    fn test_transcription_model_default() {
        // SAFETY: NOTE(review) — assumes no non-Rust code reads or writes the
        // process environment concurrently while this test binary runs; confirm,
        // since tests execute on parallel threads.
        unsafe { std::env::remove_var("TRANSCRIPTION_MODEL") };
        let adapter = OpenAIAdapter::new("gpt-4", "key", None).unwrap();
        assert_eq!(adapter.transcription_model(), "whisper-1");
    }

    #[test]
    fn test_transcription_model_custom() {
        let adapter = OpenAIAdapter::new("gpt-4", "key", None)
            .unwrap()
            .with_transcription_model("whisper-large-v3");
        assert_eq!(adapter.transcription_model(), "whisper-large-v3");
    }

    #[test]
    fn test_audio_mime_type_mapping() {
        assert_eq!(audio_mime_type("mp3"), "audio/mpeg");
        assert_eq!(audio_mime_type("mpeg"), "audio/mpeg");
        assert_eq!(audio_mime_type("mpga"), "audio/mpeg");
        assert_eq!(audio_mime_type("mp4"), "audio/mp4");
        assert_eq!(audio_mime_type("m4a"), "audio/mp4");
        assert_eq!(audio_mime_type("wav"), "audio/wav");
        assert_eq!(audio_mime_type("webm"), "audio/webm");
        // Formats without a mapping fall back to a generic binary type.
        assert_eq!(audio_mime_type("unknown"), "application/octet-stream");
    }

    #[test]
    fn test_to_strict_schema_marks_all_required_and_closes_objects() {
        let schema = json!({
            "type": "object",
            "properties": {
                "nodes": { "type": "array", "items": { "$ref": "#/definitions/Node" } }
            },
            "required": ["nodes"],
            "definitions": {
                "Node": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" },
                        "type": { "type": "string" },
                        "description": { "type": ["string", "null"] }
                    },
                    "required": ["name", "type"]
                }
            }
        });

        let strict = to_strict_schema(&schema);

        assert_eq!(strict["additionalProperties"], json!(false));
        assert_eq!(strict["required"], json!(["nodes"]));

        // The optional `description` must now be listed as required too.
        let node = &strict["definitions"]["Node"];
        assert_eq!(node["additionalProperties"], json!(false));
        let mut req: Vec<String> = node["required"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap().to_string())
            .collect();
        req.sort();
        assert_eq!(req, vec!["description", "name", "type"]);
    }
}