1use async_trait::async_trait;
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::{Value, json};
10use tracing::{debug, instrument, warn};
11
12#[allow(unused_imports)]
13use cognee_utils::tracing_keys::{COGNEE_LLM_MODEL, COGNEE_LLM_PROVIDER};
14
15use crate::error::{LlmError, LlmResult};
16use crate::llm_trait::Llm;
17use crate::transcriber::{Transcriber, TranscriptionOutput, validate_audio_format};
18use crate::types::{GenerationOptions, GenerationResponse, Message, MessageRole, TokenUsage};
19
20#[derive(Clone)]
46pub struct OpenAIAdapter {
47 model: String,
48 api_key: String,
49 base_url: String,
50 client: Client,
51 structured_output_retries: usize,
52 network_retries: usize,
54 transcription_model: String,
56}
57
58impl OpenAIAdapter {
59 pub const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1";
61 pub const DEFAULT_STRUCTURED_OUTPUT_RETRIES: usize = 5;
68 pub const DEFAULT_NETWORK_RETRIES: usize = 3;
70
71 pub fn new(
81 model: impl Into<String>,
82 api_key: impl Into<String>,
83 base_url: Option<String>,
84 ) -> LlmResult<Self> {
85 let client = Client::builder()
86 .timeout(std::time::Duration::from_secs(600))
87 .build()
88 .map_err(|e| LlmError::ConfigError(format!("Failed to create HTTP client: {e}")))?;
89
90 let transcription_model =
91 std::env::var("TRANSCRIPTION_MODEL").unwrap_or_else(|_| "whisper-1".to_string());
92
93 let model: String = model.into();
99 let model = model
100 .strip_prefix("openai/")
101 .map(str::to_string)
102 .unwrap_or(model);
103
104 Ok(Self {
105 model,
106 api_key: api_key.into(),
107 base_url: base_url.unwrap_or_else(|| Self::DEFAULT_BASE_URL.to_string()),
108 client,
109 structured_output_retries: Self::DEFAULT_STRUCTURED_OUTPUT_RETRIES,
110 network_retries: Self::DEFAULT_NETWORK_RETRIES,
111 transcription_model,
112 })
113 }
114
115 pub fn with_structured_output_retries(mut self, retries: u32) -> Self {
119 let retries = usize::try_from(retries).unwrap_or(usize::MAX);
120 self.structured_output_retries = retries.max(1);
121 self
122 }
123
124 pub fn with_network_retries(mut self, retries: u32) -> Self {
128 self.network_retries = usize::try_from(retries).unwrap_or(usize::MAX);
129 self
130 }
131
132 pub fn with_transcription_model(mut self, model: impl Into<String>) -> Self {
134 self.transcription_model = model.into();
135 self
136 }
137
138 fn auth_header(&self) -> String {
140 format!("Bearer {}", self.api_key)
141 }
142
143 fn should_disable_thinking(&self) -> bool {
145 self.model.to_lowercase().starts_with("qwen") && !self.base_url.contains("api.openai.com")
146 }
147
148 fn is_reasoning_model(&self) -> bool {
156 if !self.base_url.contains("api.openai.com") {
157 return false;
158 }
159 let m = self.model.to_lowercase();
160 m.starts_with("gpt-5") || m.starts_with("o1") || m.starts_with("o3") || m.starts_with("o4")
161 }
162
163 fn write_max_tokens(&self, body: &mut Value, value: Option<u32>) {
166 if let Some(v) = value {
167 let key = if self.is_reasoning_model() {
168 "max_completion_tokens"
169 } else {
170 "max_tokens"
171 };
172 body[key] = json!(v);
173 }
174 }
175
176 #[instrument(
186 name = "llm.api_call",
187 level = "info",
188 skip(self, request_body),
189 fields(
190 url = tracing::field::Empty,
191 cognee.llm.model = self.model.as_str(),
192 cognee.llm.provider = "openai",
193 ),
194 )]
195 async fn call_api(&self, request_body: Value) -> LlmResult<OpenAIResponse> {
196 let url = format!("{}/chat/completions", self.base_url);
197 tracing::Span::current().record("url", url.as_str());
198 let debug_enabled = std::env::var("COGNEE_DEBUG_LLM_REQUEST")
199 .map(|v| cognee_utils::parse_env_bool(&v))
200 .unwrap_or(false);
201
202 if debug_enabled {
203 let pretty_request = serde_json::to_string_pretty(&request_body)
204 .unwrap_or_else(|_| request_body.to_string());
205 eprintln!("\n[COGNEE_DEBUG_LLM_REQUEST] POST {url}\n{pretty_request}\n");
206 }
207
208 let mut last_error = LlmError::NetworkError("No attempt made".to_string());
209
210 for attempt in 0..=self.network_retries {
211 debug!(attempt, "LLM API attempt");
212 if attempt > 0 {
213 let delay_ms = (1_000u64 * 2u64.saturating_pow(attempt as u32 - 1)).min(30_000);
214 warn!(
215 attempt,
216 network_retries = self.network_retries,
217 delay_ms,
218 error = %last_error,
219 "LLM request failed, retrying",
220 );
221 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
222 }
223
224 let response = match self
225 .client
226 .post(&url)
227 .header("Authorization", self.auth_header())
228 .header("Content-Type", "application/json")
229 .json(&request_body)
230 .send()
231 .await
232 {
233 Ok(r) => r,
234 Err(e) => {
235 last_error = LlmError::NetworkError(e.to_string());
236 continue;
237 }
238 };
239
240 let status = response.status();
241
242 if !status.is_success() {
243 let error_body = response
244 .text()
245 .await
246 .unwrap_or_else(|_| "Unknown error".to_string());
247
248 let err = match status.as_u16() {
249 401 => LlmError::AuthenticationError(error_body),
250 429 => LlmError::RateLimitExceeded(error_body),
251 400 => LlmError::InvalidResponse(format!("Bad request: {error_body}")),
252 _ => LlmError::ApiError(format!("HTTP {status}: {error_body}")),
253 };
254
255 if matches!(status.as_u16(), 400 | 401) {
257 return Err(err);
258 }
259
260 last_error = err;
261 continue;
262 }
263
264 let response_body = response.text().await.map_err(|e| {
265 LlmError::DeserializationError(format!("Failed to read response body: {e}"))
266 })?;
267
268 if debug_enabled {
269 eprintln!("\n[COGNEE_DEBUG_LLM_RESPONSE] POST {url}\n{response_body}\n");
270 }
271
272 return serde_json::from_str::<OpenAIResponse>(&response_body).map_err(|e| {
273 LlmError::DeserializationError(format!(
274 "Failed to parse response: {e}. Raw body: {response_body}"
275 ))
276 });
277 }
278
279 Err(LlmError::MaxRetriesExceeded(format!(
280 "LLM request failed after {} attempt(s): {}",
281 self.network_retries + 1,
282 last_error
283 )))
284 }
285
286 fn convert_messages(messages: &[Message]) -> Vec<Value> {
288 messages
289 .iter()
290 .map(|msg| {
291 json!({
292 "role": match msg.role {
293 MessageRole::System => "system",
294 MessageRole::User => "user",
295 MessageRole::Assistant => "assistant",
296 },
297 "content": msg.content
298 })
299 })
300 .collect()
301 }
302
303 fn schema_to_example(schema: &Value) -> String {
306 fn create_example(value: &Value, definitions: Option<&Value>) -> Value {
307 match value {
308 Value::Object(obj) => {
309 if let Some(ref_str) = obj.get("$ref").and_then(|v| v.as_str())
311 && let Some(def_name) = ref_str.strip_prefix("#/definitions/")
312 && let Some(defs) = definitions
313 && let Some(def) = defs.get(def_name)
314 {
315 return create_example(def, definitions);
316 }
317
318 let type_val = obj.get("type");
320
321 if let Some(Value::String(t)) = type_val
323 && t == "array"
324 {
325 if let Some(items) = obj.get("items") {
326 return json!([create_example(items, definitions)]);
328 }
329 return json!([]);
330 }
331
332 if let Some(props) = obj.get("properties")
334 && let Value::Object(props_obj) = props
335 {
336 let mut result = serde_json::Map::new();
337 for (key, val) in props_obj {
338 result.insert(key.clone(), create_example(val, definitions));
339 }
340 return Value::Object(result);
341 }
342
343 if let Some(Value::String(t)) = type_val {
345 return match t.as_str() {
346 "string" => json!("example"),
347 "number" | "integer" => json!(0),
348 "boolean" => json!(false),
349 _ => json!(null),
350 };
351 }
352
353 if let Some(Value::Array(types)) = type_val {
355 for t in types {
356 if let Value::String(type_str) = t
357 && type_str != "null"
358 {
359 return match type_str.as_str() {
360 "string" => json!("example"),
361 "number" | "integer" => json!(0),
362 "boolean" => json!(false),
363 _ => json!(null),
364 };
365 }
366 }
367 }
368
369 json!(null)
370 }
371 _ => value.clone(),
372 }
373 }
374
375 let definitions = schema.get("definitions");
376 let example = create_example(schema, definitions);
377
378 serde_json::to_string_pretty(&example).unwrap_or_else(|_| "{}".to_string())
379 }
380}
381
382fn to_strict_schema(schema: &Value) -> Value {
401 fn walk(value: &mut Value) {
402 match value {
403 Value::Object(obj) => {
404 if let Some(Value::Object(props)) = obj.get("properties") {
405 let keys: Vec<Value> = props.keys().map(|k| Value::String(k.clone())).collect();
407 obj.insert("required".to_string(), Value::Array(keys));
408 obj.insert("additionalProperties".to_string(), Value::Bool(false));
409 }
410 for (_k, v) in obj.iter_mut() {
411 walk(v);
412 }
413 }
414 Value::Array(items) => {
415 for v in items.iter_mut() {
416 walk(v);
417 }
418 }
419 _ => {}
420 }
421 }
422
423 let mut out = schema.clone();
424 walk(&mut out);
425 out
426}
427
428#[async_trait]
429impl Llm for OpenAIAdapter {
430 async fn generate(
431 &self,
432 messages: Vec<Message>,
433 options: Option<GenerationOptions>,
434 ) -> LlmResult<GenerationResponse> {
435 let opts = options.unwrap_or_default();
436
437 let mut request_body = json!({
438 "model": self.model,
439 "messages": Self::convert_messages(&messages),
440 });
441
442 if !self.is_reasoning_model() {
445 if let Some(temp) = opts.temperature {
446 request_body["temperature"] = json!(temp);
447 }
448 if let Some(top_p) = opts.top_p {
449 request_body["top_p"] = json!(top_p);
450 }
451 if let Some(freq_penalty) = opts.frequency_penalty {
452 request_body["frequency_penalty"] = json!(freq_penalty);
453 }
454 if let Some(pres_penalty) = opts.presence_penalty {
455 request_body["presence_penalty"] = json!(pres_penalty);
456 }
457 }
458 self.write_max_tokens(&mut request_body, opts.max_tokens);
459 if let Some(stop) = opts.stop
460 && !stop.is_empty()
461 {
462 request_body["stop"] = json!(stop);
463 }
464
465 if self.should_disable_thinking() {
466 request_body["think"] = json!(false);
467 request_body["reasoning"] = json!({"effort": "none"});
468 }
469
470 let response = self.call_api(request_body).await?;
471
472 let choice = response
474 .choices
475 .first()
476 .ok_or_else(|| LlmError::InvalidResponse("No choices in response".to_string()))?;
477
478 Ok(GenerationResponse {
479 content: choice.message.content.clone().unwrap_or_default(),
480 model: response.model,
481 finish_reason: choice.finish_reason.clone(),
482 usage: response.usage.map(|u| TokenUsage {
483 prompt_tokens: u.prompt_tokens,
484 completion_tokens: u.completion_tokens,
485 total_tokens: u.total_tokens,
486 }),
487 })
488 }
489
490 async fn create_structured_output_with_messages_raw(
491 &self,
492 messages: Vec<Message>,
493 json_schema: &Value,
494 options: Option<GenerationOptions>,
495 ) -> LlmResult<Value> {
496 let is_empty_or_non_json = |raw: &str| {
497 let trimmed = raw.trim();
498 trimmed.is_empty() || serde_json::from_str::<Value>(trimmed).is_err()
499 };
500
501 let parse_json =
502 |raw: &str| -> Result<Value, serde_json::Error> { serde_json::from_str(raw) };
503
504 let opts = options.unwrap_or_default();
505 let schema = json_schema;
506
507 let strict_schema = to_strict_schema(schema);
512
513 let mut strict_schema_request = json!({
515 "model": self.model,
516 "messages": Self::convert_messages(&messages),
517 "response_format": {
518 "type": "json_schema",
519 "json_schema": {
520 "name": "extract_structured_data",
521 "schema": strict_schema,
522 "strict": true
523 }
524 }
525 });
526
527 if !self.is_reasoning_model()
528 && let Some(temp) = opts.temperature
529 {
530 strict_schema_request["temperature"] = json!(temp);
531 }
532 self.write_max_tokens(&mut strict_schema_request, opts.max_tokens);
533 if self.should_disable_thinking() {
534 strict_schema_request["think"] = json!(false);
535 strict_schema_request["reasoning"] = json!({"effort": "none"});
536 }
537
538 for attempt in 0..self.structured_output_retries {
539 match self.call_api(strict_schema_request.clone()).await {
540 Ok(strict_response) => {
541 let strict_choice = strict_response.choices.first().ok_or_else(|| {
542 LlmError::InvalidResponse(
543 "No choices in strict schema response".to_string(),
544 )
545 })?;
546
547 if let Some(function_call) = &strict_choice.message.function_call {
548 match parse_json(&function_call.arguments) {
549 Ok(parsed) => return Ok(parsed),
550 Err(e) => {
551 if attempt + 1 < self.structured_output_retries
552 && is_empty_or_non_json(&function_call.arguments)
553 {
554 continue;
555 }
556 if !is_empty_or_non_json(&function_call.arguments) {
557 return Err(LlmError::DeserializationError(format!(
558 "Failed to deserialize strict function call arguments: {}. Raw: {}",
559 e, function_call.arguments
560 )));
561 }
562 break;
563 }
564 }
565 }
566
567 if let Some(content) = strict_choice.message.content.as_ref() {
568 match parse_json(content) {
569 Ok(parsed) => return Ok(parsed),
570 Err(e) => {
571 if attempt + 1 < self.structured_output_retries
572 && is_empty_or_non_json(content)
573 {
574 continue;
575 }
576 if !is_empty_or_non_json(content) {
577 return Err(LlmError::DeserializationError(format!(
578 "Failed to deserialize strict JSON content: {e}. Raw: {content}"
579 )));
580 }
581 break;
582 }
583 }
584 }
585 }
586 Err(e) => {
587 warn!(error = %e, "strict json_schema request failed; falling back to function/JSON mode");
593 break;
594 }
595 }
596 }
597
598 let mut request_body = json!({
600 "model": self.model,
601 "messages": Self::convert_messages(&messages),
602 "functions": [{
603 "name": "extract_structured_data",
604 "description": "Extract structured data from the input",
605 "parameters": schema
606 }],
607 "function_call": {"name": "extract_structured_data"}
608 });
609
610 if !self.is_reasoning_model()
611 && let Some(temp) = opts.temperature
612 {
613 request_body["temperature"] = json!(temp);
614 }
615 self.write_max_tokens(&mut request_body, opts.max_tokens);
616 if self.should_disable_thinking() {
617 request_body["think"] = json!(false);
618 request_body["reasoning"] = json!({"effort": "none"});
619 }
620
621 for attempt in 0..self.structured_output_retries {
622 let response = self.call_api(request_body.clone()).await?;
623
624 let choice = response
625 .choices
626 .first()
627 .ok_or_else(|| LlmError::InvalidResponse("No choices in response".to_string()))?;
628
629 if let Some(function_call) = &choice.message.function_call {
630 match parse_json(&function_call.arguments) {
631 Ok(parsed) => return Ok(parsed),
632 Err(e) => {
633 if attempt + 1 < self.structured_output_retries
634 && is_empty_or_non_json(&function_call.arguments)
635 {
636 continue;
637 }
638 if !is_empty_or_non_json(&function_call.arguments) {
639 return Err(LlmError::DeserializationError(format!(
640 "Failed to deserialize function call arguments: {}. Raw: {}",
641 e, function_call.arguments
642 )));
643 }
644 break;
645 }
646 }
647 }
648
649 break;
650 }
651
652 let mut json_messages = Self::convert_messages(&messages);
654
655 let example = Self::schema_to_example(schema);
656
657 if let Some(last_msg) = json_messages.last_mut()
658 && last_msg["role"] == "user"
659 {
660 let original_content = last_msg["content"].as_str().unwrap_or("");
661 last_msg["content"] = json!(format!(
662 "{}\n\n\
663 Extract the information from the text above and return it as JSON.\n\
664 Use this structure as your template (but with actual data from the text):\n\
665 {}",
666 original_content, example
667 ));
668 }
669
670 let mut json_request = json!({
671 "model": self.model,
672 "messages": json_messages,
673 "response_format": {"type": "json_object"}
674 });
675
676 if !self.is_reasoning_model()
677 && let Some(temp) = opts.temperature
678 {
679 json_request["temperature"] = json!(temp);
680 }
681 self.write_max_tokens(&mut json_request, opts.max_tokens);
682 if self.should_disable_thinking() {
683 json_request["think"] = json!(false);
684 json_request["reasoning"] = json!({"effort": "none"});
685 }
686
687 for attempt in 0..self.structured_output_retries {
688 let mut request_for_attempt = json_request.clone();
689
690 if attempt > 0 {
691 if let Some(messages) = request_for_attempt["messages"].as_array_mut()
692 && let Some(last_msg) = messages.last_mut()
693 && last_msg["role"] == "user"
694 {
695 let original_content = last_msg["content"].as_str().unwrap_or("");
696 last_msg["content"] = json!(format!(
697 "{}\n\n/no_think\nReturn ONLY one valid JSON object matching the required schema. No reasoning, no markdown, no extra text.",
698 original_content
699 ));
700 }
701
702 if !self.is_reasoning_model() {
703 request_for_attempt["temperature"] = json!(0.0);
704 }
705 }
706
707 let json_response = self.call_api(request_for_attempt).await?;
708
709 let json_choice = json_response.choices.first().ok_or_else(|| {
710 LlmError::InvalidResponse("No choices in JSON mode response".to_string())
711 })?;
712
713 let content = json_choice.message.content.as_ref().ok_or_else(|| {
714 LlmError::InvalidResponse("No content in JSON mode response".to_string())
715 })?;
716
717 match parse_json(content) {
718 Ok(parsed) => return Ok(parsed),
719 Err(e) => {
720 if attempt + 1 < self.structured_output_retries && is_empty_or_non_json(content)
721 {
722 continue;
723 }
724 return Err(LlmError::DeserializationError(format!(
725 "Failed to deserialize JSON content: {e}. Raw: {content}"
726 )));
727 }
728 }
729 }
730
731 Err(LlmError::InvalidResponse(
732 "Structured output retries exhausted without a parseable response".to_string(),
733 ))
734 }
735
736 fn model(&self) -> &str {
737 &self.model
738 }
739
740 fn supports_streaming(&self) -> bool {
741 true
742 }
743
744 fn supports_function_calling(&self) -> bool {
745 true
746 }
747
748 fn max_context_length(&self) -> u32 {
749 match self.model.as_str() {
751 m if m.starts_with("gpt-4-turbo") => 128_000,
752 m if m.starts_with("gpt-4-32k") => 32_768,
753 m if m.starts_with("gpt-4") => 8_192,
754 m if m.starts_with("gpt-3.5-turbo-16k") => 16_384,
755 m if m.starts_with("gpt-3.5-turbo") => 4_096,
756 _ => 4_096, }
758 }
759
760 async fn transcribe_image(
761 &self,
762 image_bytes: &[u8],
763 mime_type: &str,
764 options: Option<GenerationOptions>,
765 ) -> LlmResult<String> {
766 use base64::Engine as _;
767
768 if !mime_type.starts_with("image/") {
769 return Err(LlmError::InvalidResponse(format!(
770 "Expected image/* MIME type, got: {mime_type}"
771 )));
772 }
773
774 let b64 = base64::engine::general_purpose::STANDARD.encode(image_bytes);
775 let data_uri = format!("data:{mime_type};base64,{b64}");
776
777 let vision_model = std::env::var("LLM_VISION_MODEL")
778 .ok()
779 .filter(|s| !s.is_empty())
780 .unwrap_or_else(|| self.model.clone());
781
782 let max_tokens = options.as_ref().and_then(|o| o.max_tokens).unwrap_or(300);
783
784 let mut request_body = json!({
785 "model": vision_model,
786 "messages": [{
787 "role": "user",
788 "content": [
789 { "type": "text", "text": "What's in this image?" },
790 { "type": "image_url", "image_url": { "url": data_uri } }
791 ]
792 }],
793 });
794 self.write_max_tokens(&mut request_body, Some(max_tokens));
795
796 let response = self.call_api(request_body).await?;
797
798 let choice = response.choices.first().ok_or_else(|| {
799 LlmError::InvalidResponse("No choices in vision response".to_string())
800 })?;
801
802 choice.message.content.clone().ok_or_else(|| {
803 LlmError::InvalidResponse("Vision response contained no content".to_string())
804 })
805 }
806
807 fn supports_vision(&self) -> bool {
808 let m = self.model.to_lowercase();
809 m.contains("gpt-4")
810 || m.contains("gpt-5")
811 || m.contains("vision")
812 || m.contains("o1")
813 || m.contains("o3")
814 || m.contains("o4")
815 || m.contains("llava")
816 || m.contains("moondream")
817 || m.contains("llama-3.2-vision")
818 || m.contains("gemma3")
819 }
820}
821
/// Subset of the `/audio/transcriptions` reply requested with `response_format=verbose_json`.
#[derive(Debug, Deserialize)]
struct WhisperResponse {
    /// Transcribed text.
    text: String,
    /// Language reported by the service, when present.
    language: Option<String>,
    /// Audio duration reported by the service, when present (presumably seconds — TODO confirm).
    duration: Option<f32>,
}
833
/// Maps an audio file extension to the MIME type sent with the multipart upload.
///
/// Matching is case-sensitive, so callers pass a lower-cased extension. Unknown extensions
/// fall back to `application/octet-stream`.
fn audio_mime_type(format: &str) -> &'static str {
    match format {
        "mp3" | "mpeg" | "mpga" => "audio/mpeg",
        "mp4" | "m4a" => "audio/mp4",
        "wav" => "audio/wav",
        "webm" => "audio/webm",
        "flac" => "audio/flac",
        "ogg" | "oga" => "audio/ogg",
        _ => "application/octet-stream",
    }
}
845
846impl OpenAIAdapter {
847 #[instrument(
849 name = "llm.transcription_api_call",
850 level = "info",
851 skip(self, form),
852 fields(
853 url = tracing::field::Empty,
854 cognee.llm.model = self.transcription_model.as_str(),
855 cognee.llm.provider = "openai",
856 ),
857 )]
858 async fn call_transcription_api(
859 &self,
860 form: reqwest::multipart::Form,
861 ) -> LlmResult<WhisperResponse> {
862 let url = format!("{}/audio/transcriptions", self.base_url);
863 tracing::Span::current().record("url", url.as_str());
864
865 let response = self
882 .client
883 .post(&url)
884 .header("Authorization", self.auth_header())
885 .multipart(form)
886 .send()
887 .await
888 .map_err(|e| LlmError::NetworkError(e.to_string()))?;
889
890 let status = response.status();
891
892 if !status.is_success() {
893 let error_body = response
894 .text()
895 .await
896 .unwrap_or_else(|_| "Unknown error".to_string());
897
898 return Err(match status.as_u16() {
899 401 => LlmError::AuthenticationError(error_body),
900 429 => LlmError::RateLimitExceeded(error_body),
901 400 => LlmError::InvalidResponse(format!("Bad request: {error_body}")),
902 _ => LlmError::ApiError(format!("HTTP {status}: {error_body}")),
903 });
904 }
905
906 let response_body = response.text().await.map_err(|e| {
907 LlmError::DeserializationError(format!("Failed to read response body: {e}"))
908 })?;
909
910 serde_json::from_str::<WhisperResponse>(&response_body).map_err(|e| {
911 LlmError::DeserializationError(format!(
912 "Failed to parse Whisper response: {e}. Raw body: {response_body}"
913 ))
914 })
915 }
916
917 fn build_transcription_form(
919 &self,
920 audio: &[u8],
921 format: &str,
922 language_hint: Option<&str>,
923 prompt_hint: Option<&str>,
924 ) -> LlmResult<reqwest::multipart::Form> {
925 let mime = audio_mime_type(format);
926 let filename = format!("audio.{format}");
927
928 let file_part = reqwest::multipart::Part::bytes(audio.to_vec())
929 .file_name(filename)
930 .mime_str(mime)
931 .map_err(|e| {
932 LlmError::ConfigError(format!("Failed to set MIME type on multipart part: {e}"))
933 })?;
934
935 let mut form = reqwest::multipart::Form::new()
936 .part("file", file_part)
937 .text("model", self.transcription_model.clone())
938 .text("response_format", "verbose_json");
939
940 if let Some(lang) = language_hint {
941 form = form.text("language", lang.to_string());
942 }
943 if let Some(prompt) = prompt_hint {
944 form = form.text("prompt", prompt.to_string());
945 }
946
947 Ok(form)
948 }
949}
950
951#[async_trait]
952impl Transcriber for OpenAIAdapter {
953 async fn transcribe_audio(
954 &self,
955 audio: &[u8],
956 format: &str,
957 language_hint: Option<&str>,
958 prompt_hint: Option<&str>,
959 ) -> LlmResult<TranscriptionOutput> {
960 let format_lower = format.to_ascii_lowercase();
962 validate_audio_format(&format_lower)?;
963
964 let mut last_error = LlmError::NetworkError("No attempt made".to_string());
965
966 for attempt in 0..=self.network_retries {
967 debug!(attempt, "Transcription API attempt");
968 if attempt > 0 {
969 let delay_ms = (1_000u64 * 2u64.saturating_pow(attempt as u32 - 1)).min(30_000);
970 warn!(
971 attempt,
972 network_retries = self.network_retries,
973 delay_ms,
974 error = %last_error,
975 "Transcription request failed, retrying",
976 );
977 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
978 }
979
980 let form =
981 self.build_transcription_form(audio, &format_lower, language_hint, prompt_hint)?;
982
983 match self.call_transcription_api(form).await {
984 Ok(resp) => {
985 return Ok(TranscriptionOutput {
986 text: resp.text,
987 language: resp.language,
988 duration: resp.duration,
989 });
990 }
991 Err(e) => {
992 if matches!(
994 e,
995 LlmError::InvalidResponse(_) | LlmError::AuthenticationError(_)
996 ) {
997 return Err(e);
998 }
999 last_error = e;
1000 continue;
1001 }
1002 }
1003 }
1004
1005 Err(LlmError::MaxRetriesExceeded(format!(
1006 "Transcription request failed after {} attempt(s): {}",
1007 self.network_retries + 1,
1008 last_error
1009 )))
1010 }
1011
1012 fn transcription_model(&self) -> &str {
1013 &self.transcription_model
1014 }
1015}
1016
1017#[derive(Debug, Deserialize)]
1019#[allow(dead_code)]
1020struct OpenAIResponse {
1021 id: String,
1022 object: String,
1023 created: i64,
1024 model: String,
1025 choices: Vec<OpenAIChoice>,
1026 usage: Option<OpenAIUsage>,
1027}
1028
1029#[derive(Debug, Deserialize)]
1030#[allow(dead_code)]
1031struct OpenAIChoice {
1032 index: u32,
1033 message: OpenAIMessage,
1034 finish_reason: Option<String>,
1035}
1036
1037#[derive(Debug, Deserialize)]
1038#[allow(dead_code)]
1039struct OpenAIMessage {
1040 role: String,
1041 content: Option<String>,
1042 reasoning: Option<String>,
1043 function_call: Option<OpenAIFunctionCall>,
1044}
1045
/// Legacy `function_call` payload of an assistant message.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] // `name` is never read
struct OpenAIFunctionCall {
    /// Name of the function the model chose to call.
    name: String,
    /// Function arguments as a JSON-encoded string; parsed by the structured-output code.
    arguments: String,
}
1052
/// Token accounting reported with a chat reply; copied into `TokenUsage`.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    /// Tokens in the request prompt.
    prompt_tokens: u32,
    /// Tokens generated in the reply.
    completion_tokens: u32,
    /// Total tokens as reported by the server.
    total_tokens: u32,
}
1059
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use super::*;

    #[test]
    fn test_openai_provider_prefix_is_stripped() {
        let adapter = OpenAIAdapter::new("openai/gpt-5-mini", "test-key", None).unwrap();
        assert_eq!(adapter.model(), "gpt-5-mini");
        let adapter = OpenAIAdapter::new("ollama/llama3", "test-key", None).unwrap();
        assert_eq!(adapter.model(), "ollama/llama3");
    }

    #[test]
    fn test_openai_adapter_creation() {
        let adapter = OpenAIAdapter::new("gpt-4", "test-key", None);
        assert!(adapter.is_ok());

        let adapter = adapter.unwrap();
        assert_eq!(adapter.model(), "gpt-4");
        assert_eq!(adapter.base_url, OpenAIAdapter::DEFAULT_BASE_URL);
        assert_eq!(
            adapter.structured_output_retries,
            OpenAIAdapter::DEFAULT_STRUCTURED_OUTPUT_RETRIES
        );
    }

    #[test]
    fn test_configurable_structured_output_retries() {
        let adapter = OpenAIAdapter::new("gpt-4", "test-key", None)
            .unwrap()
            .with_structured_output_retries(5);
        assert_eq!(adapter.structured_output_retries, 5);

        let adapter = OpenAIAdapter::new("gpt-4", "test-key", None)
            .unwrap()
            .with_structured_output_retries(0);
        assert_eq!(adapter.structured_output_retries, 1);
    }

    #[test]
    fn test_openai_adapter_custom_base_url() {
        let adapter = OpenAIAdapter::new(
            "gpt-4",
            "test-key",
            Some("https://custom.api.com/v1".to_string()),
        );
        assert!(adapter.is_ok());

        let adapter = adapter.unwrap();
        assert_eq!(adapter.base_url, "https://custom.api.com/v1");
    }

    #[test]
    fn test_is_reasoning_model_matches_openai_families() {
        let cases = [
            ("gpt-5", true),
            ("gpt-5-mini", true),
            ("gpt-5-2025-06-01", true),
            ("o1", true),
            ("o1-mini", true),
            ("o3", true),
            ("o3-mini", true),
            ("o4-mini", true),
            ("GPT-5-Mini", true),
            ("gpt-4o-mini", false),
            ("gpt-4-turbo", false),
            ("gpt-3.5-turbo", false),
            ("o-foo", false),
        ];
        for (model, expected) in cases {
            let adapter = OpenAIAdapter::new(model, "test-key", None).unwrap();
            assert_eq!(
                adapter.is_reasoning_model(),
                expected,
                "is_reasoning_model({model})"
            );
        }
    }

    #[test]
    fn test_is_reasoning_model_skipped_for_custom_base_url() {
        let adapter = OpenAIAdapter::new(
            "gpt-5-mini",
            "test-key",
            Some("http://localhost:11434/v1".to_string()),
        )
        .unwrap();
        assert!(!adapter.is_reasoning_model());
    }

    #[test]
    fn test_write_max_tokens_renames_key_for_reasoning_models() {
        let mut body = json!({"model": "gpt-5-mini"});
        let reasoning = OpenAIAdapter::new("gpt-5-mini", "test-key", None).unwrap();
        reasoning.write_max_tokens(&mut body, Some(2048));
        assert!(body.get("max_tokens").is_none());
        assert_eq!(body["max_completion_tokens"], 2048);

        let mut body = json!({"model": "gpt-4o-mini"});
        let classic = OpenAIAdapter::new("gpt-4o-mini", "test-key", None).unwrap();
        classic.write_max_tokens(&mut body, Some(2048));
        assert_eq!(body["max_tokens"], 2048);
        assert!(body.get("max_completion_tokens").is_none());

        let mut body = json!({"model": "gpt-5-mini"});
        reasoning.write_max_tokens(&mut body, None);
        assert!(body.get("max_tokens").is_none());
        assert!(body.get("max_completion_tokens").is_none());
    }

    #[test]
    fn test_message_conversion() {
        let messages = vec![
            Message {
                role: MessageRole::System,
                content: "You are helpful".to_string(),
            },
            Message {
                role: MessageRole::User,
                content: "Hello".to_string(),
            },
        ];

        let converted = OpenAIAdapter::convert_messages(&messages);
        assert_eq!(converted.len(), 2);
        assert_eq!(converted[0]["role"], "system");
        assert_eq!(converted[0]["content"], "You are helpful");
        assert_eq!(converted[1]["role"], "user");
        assert_eq!(converted[1]["content"], "Hello");
    }

    #[test]
    fn test_context_length() {
        let adapter = OpenAIAdapter::new("gpt-4-turbo-preview", "key", None).unwrap();
        assert_eq!(adapter.max_context_length(), 128_000);

        let adapter = OpenAIAdapter::new("gpt-4", "key", None).unwrap();
        assert_eq!(adapter.max_context_length(), 8_192);

        let adapter = OpenAIAdapter::new("gpt-3.5-turbo-16k", "key", None).unwrap();
        assert_eq!(adapter.max_context_length(), 16_384);
    }

    #[test]
    fn test_supports_vision_gpt4o() {
        let adapter = OpenAIAdapter::new("gpt-4o", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gpt4_turbo() {
        let adapter = OpenAIAdapter::new("gpt-4-turbo", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gpt4o_mini() {
        let adapter = OpenAIAdapter::new("gpt-4o-mini", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gpt35_is_false() {
        let adapter = OpenAIAdapter::new("gpt-3.5-turbo", "key", None).unwrap();
        assert!(!adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_llava() {
        let adapter = OpenAIAdapter::new("llava:13b", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_o1() {
        let adapter = OpenAIAdapter::new("o1-preview", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gemma3() {
        let adapter = OpenAIAdapter::new("gemma3:12b", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[tokio::test]
    async fn transcribe_image_rejects_non_image_mime() {
        let adapter = OpenAIAdapter::new("gpt-4o", "fake-key", None).unwrap();
        let result = adapter
            .transcribe_image(b"not-an-image", "text/plain", None)
            .await;
        assert!(result.is_err());
        assert!(
            matches!(result.unwrap_err(), LlmError::InvalidResponse(_)),
            "Expected InvalidResponse for non-image MIME type"
        );
    }

    #[test]
    fn test_transcription_model_default() {
        // The process environment is not mutated here. `remove_var` is `unsafe` because other
        // tests run concurrently on the parallel test harness and construct adapters, which
        // read the environment. The expectation is derived from the current environment
        // instead: normally `TRANSCRIPTION_MODEL` is unset and the default `whisper-1` applies.
        let expected = std::env::var("TRANSCRIPTION_MODEL")
            .unwrap_or_else(|_| "whisper-1".to_string());
        let adapter = OpenAIAdapter::new("gpt-4", "key", None).unwrap();
        assert_eq!(adapter.transcription_model(), expected);
    }

    #[test]
    fn test_transcription_model_custom() {
        let adapter = OpenAIAdapter::new("gpt-4", "key", None)
            .unwrap()
            .with_transcription_model("whisper-large-v3");
        assert_eq!(adapter.transcription_model(), "whisper-large-v3");
    }

    #[test]
    fn test_audio_mime_type_mapping() {
        assert_eq!(audio_mime_type("mp3"), "audio/mpeg");
        assert_eq!(audio_mime_type("mpeg"), "audio/mpeg");
        assert_eq!(audio_mime_type("mpga"), "audio/mpeg");
        assert_eq!(audio_mime_type("mp4"), "audio/mp4");
        assert_eq!(audio_mime_type("m4a"), "audio/mp4");
        assert_eq!(audio_mime_type("wav"), "audio/wav");
        assert_eq!(audio_mime_type("webm"), "audio/webm");
    }

    #[test]
    fn test_to_strict_schema_marks_all_required_and_closes_objects() {
        let schema = json!({
            "type": "object",
            "properties": {
                "nodes": { "type": "array", "items": { "$ref": "#/definitions/Node" } }
            },
            "required": ["nodes"],
            "definitions": {
                "Node": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" },
                        "type": { "type": "string" },
                        "description": { "type": ["string", "null"] }
                    },
                    "required": ["name", "type"]
                }
            }
        });

        let strict = to_strict_schema(&schema);

        assert_eq!(strict["additionalProperties"], json!(false));
        assert_eq!(strict["required"], json!(["nodes"]));

        let node = &strict["definitions"]["Node"];
        assert_eq!(node["additionalProperties"], json!(false));
        let mut req: Vec<String> = node["required"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap().to_string())
            .collect();
        req.sort();
        assert_eq!(req, vec!["description", "name", "type"]);
    }
}