// cognee_llm/adapters/openai.rs

1//! OpenAI API adapter with JSON schema support for structured outputs.
2//!
3//! This adapter uses OpenAI's function calling or JSON mode to generate
4//! structured outputs based on JSON schemas derived from Rust types.
5
6use async_trait::async_trait;
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::{Value, json};
10use tracing::{debug, instrument, warn};
11
12#[allow(unused_imports)]
13use cognee_utils::tracing_keys::{COGNEE_LLM_MODEL, COGNEE_LLM_PROVIDER};
14
15use crate::error::{LlmError, LlmResult};
16use crate::llm_trait::Llm;
17use crate::transcriber::{Transcriber, TranscriptionOutput, validate_audio_format};
18use crate::types::{GenerationOptions, GenerationResponse, Message, MessageRole, TokenUsage};
19
/// OpenAI API adapter.
///
/// Supports structured output generation via:
/// - Strict JSON schema mode (response_format with type: "json_schema")
/// - Function calling (for GPT-4 and GPT-3.5-turbo)
/// - JSON mode (response_format with type: "json_object")
/// - JSON schema validation (via function parameters)
///
/// # Example
/// ```ignore
/// use cognee_llm::adapters::OpenAIAdapter;
/// use cognee_llm::Llm;
///
/// let adapter = OpenAIAdapter::new(
///     "gpt-4-turbo-preview",
///     "sk-...",
///     None, // Use default base URL
/// )?;
///
/// let result: MyStruct = adapter.create_structured_output(
///     "Extract information from this text",
///     "You are a helpful assistant",
///     None,
/// ).await?;
/// ```
#[derive(Clone)]
pub struct OpenAIAdapter {
    /// Model identifier sent as `"model"` in chat-completion requests.
    /// [`OpenAIAdapter::new`] strips a leading `openai/` provider prefix.
    model: String,
    /// API key, sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    /// API base URL with trailing slashes removed; request paths such as
    /// `/chat/completions` are appended to it.
    base_url: String,
    /// HTTP client used for every request; built in [`OpenAIAdapter::new`]
    /// with a 600-second timeout.
    client: Client,
    /// Number of attempts each structured-output strategy may make before
    /// giving up on unparseable model output.
    structured_output_retries: usize,
    /// Number of times to retry the HTTP request on transient network/server errors.
    network_retries: usize,
    /// Model name for audio transcription (e.g. `"whisper-1"`).
    transcription_model: String,
}
57
58impl OpenAIAdapter {
    /// Default OpenAI API base URL, used when no custom `base_url` is supplied.
    pub const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1";
    /// Default retry attempts for structured output parsing paths.
    ///
    /// Python parity: instructor's `acreate_structured_output` retries up to
    /// `MAX_RETRIES = 5` times on a parse/validation failure. We match that
    /// count so transient malformed responses get the same number of repair
    /// chances before the cognify pipeline gives up.
    ///
    /// The structured-output loops iterate `0..structured_output_retries`, so
    /// this is the total number of attempts per strategy, not the number of
    /// extra attempts after the first.
    pub const DEFAULT_STRUCTURED_OUTPUT_RETRIES: usize = 5;
    /// Default retry attempts for transient network/server errors.
    ///
    /// These are retries made in addition to the initial request: the HTTP
    /// loop iterates `0..=network_retries`.
    pub const DEFAULT_NETWORK_RETRIES: usize = 3;
70
71    /// Create a new OpenAI adapter.
72    ///
73    /// # Arguments
74    /// * `model` - Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
75    /// * `api_key` - OpenAI API key
76    /// * `base_url` - Optional custom base URL (defaults to OpenAI's API)
77    ///
78    /// # Returns
79    /// A new OpenAI adapter instance
80    pub fn new(
81        model: impl Into<String>,
82        api_key: impl Into<String>,
83        base_url: Option<String>,
84    ) -> LlmResult<Self> {
85        let client = Client::builder()
86            .timeout(std::time::Duration::from_secs(600))
87            .build()
88            .map_err(|e| LlmError::ConfigError(format!("Failed to create HTTP client: {e}")))?;
89
90        let transcription_model =
91            std::env::var("TRANSCRIPTION_MODEL").unwrap_or_else(|_| "whisper-1".to_string());
92
93        // Strip a leading litellm-style "openai/" provider prefix. Python's
94        // litellm accepts provider-qualified names (e.g. "openai/gpt-5-mini")
95        // and strips the provider before calling the OpenAI-native API, which
96        // itself rejects the prefix. Strip it here for parity so a
97        // provider-qualified config value works against real OpenAI.
98        let model: String = model.into();
99        let model = model
100            .strip_prefix("openai/")
101            .map(str::to_string)
102            .unwrap_or(model);
103
104        Ok(Self {
105            model,
106            api_key: api_key.into(),
107            // Normalise a trailing slash so request URLs built as
108            // `{base_url}/chat/completions` never produce a double slash. The
109            // Gemini OpenAI-compat base ends in `/v1beta/openai/`, and a
110            // user-supplied endpoint may too; both would otherwise 404.
111            base_url: base_url
112                .map(|u| u.trim_end_matches('/').to_string())
113                .unwrap_or_else(|| Self::DEFAULT_BASE_URL.to_string()),
114            client,
115            structured_output_retries: Self::DEFAULT_STRUCTURED_OUTPUT_RETRIES,
116            network_retries: Self::DEFAULT_NETWORK_RETRIES,
117            transcription_model,
118        })
119    }
120
121    /// Configure retry attempts for structured output extraction.
122    ///
123    /// Values lower than 1 are coerced to 1.
124    pub fn with_structured_output_retries(mut self, retries: u32) -> Self {
125        let retries = usize::try_from(retries).unwrap_or(usize::MAX);
126        self.structured_output_retries = retries.max(1);
127        self
128    }
129
130    /// Configure retry attempts for transient network and server errors (HTTP 429, 5xx).
131    ///
132    /// Each retry uses exponential backoff starting at 1 s, doubling up to 30 s.
133    pub fn with_network_retries(mut self, retries: u32) -> Self {
134        self.network_retries = usize::try_from(retries).unwrap_or(usize::MAX);
135        self
136    }
137
138    /// Configure the model used for audio transcription (default: `"whisper-1"`).
139    pub fn with_transcription_model(mut self, model: impl Into<String>) -> Self {
140        self.transcription_model = model.into();
141        self
142    }
143
144    /// Build the authorization header value
145    fn auth_header(&self) -> String {
146        format!("Bearer {}", self.api_key)
147    }
148
149    /// Whether to request non-thinking mode for local Qwen OpenAI-compatible endpoints.
150    fn should_disable_thinking(&self) -> bool {
151        self.model.to_lowercase().starts_with("qwen") && !self.base_url.contains("api.openai.com")
152    }
153
154    /// True for OpenAI reasoning-model families (`gpt-5*`, `o1*`, `o3*`, `o4*`)
155    /// that reject `temperature`/`top_p`/`frequency_penalty`/`presence_penalty`
156    /// overrides and require `max_completion_tokens` in place of `max_tokens`.
157    ///
158    /// Gated on the official `api.openai.com` base URL so custom OpenAI-compatible
159    /// proxies (Ollama, vLLM, …) keep accepting legacy parameters even when the
160    /// configured model name happens to match a reasoning-family prefix.
161    fn is_reasoning_model(&self) -> bool {
162        if !self.base_url.contains("api.openai.com") {
163            return false;
164        }
165        let m = self.model.to_lowercase();
166        m.starts_with("gpt-5") || m.starts_with("o1") || m.starts_with("o3") || m.starts_with("o4")
167    }
168
169    /// Insert `max_tokens` (or `max_completion_tokens` on reasoning models) into a
170    /// request body if `value` is `Some`.
171    fn write_max_tokens(&self, body: &mut Value, value: Option<u32>) {
172        if let Some(v) = value {
173            let key = if self.is_reasoning_model() {
174                "max_completion_tokens"
175            } else {
176                "max_tokens"
177            };
178            body[key] = json!(v);
179        }
180    }
181
182    /// Call the OpenAI chat completions API, retrying on transient network/server errors.
183    ///
184    /// Retries up to `self.network_retries` times with exponential backoff (1 s, 2 s, 4 s …
185    /// capped at 30 s) on:
186    /// - Network-level failures (connection refused, timeout, etc.)
187    /// - HTTP 429 (rate limit exceeded)
188    /// - HTTP 5xx (server errors)
189    ///
190    /// Errors on HTTP 400 and 401 are returned immediately without retrying.
191    #[instrument(
192        name = "llm.api_call",
193        level = "info",
194        skip(self, request_body),
195        fields(
196            url = tracing::field::Empty,
197            cognee.llm.model = self.model.as_str(),
198            cognee.llm.provider = "openai",
199        ),
200    )]
201    async fn call_api(&self, request_body: Value) -> LlmResult<OpenAIResponse> {
202        let url = format!("{}/chat/completions", self.base_url);
203        tracing::Span::current().record("url", url.as_str());
204        let debug_enabled = std::env::var("COGNEE_DEBUG_LLM_REQUEST")
205            .map(|v| cognee_utils::parse_env_bool(&v))
206            .unwrap_or(false);
207
208        if debug_enabled {
209            let pretty_request = serde_json::to_string_pretty(&request_body)
210                .unwrap_or_else(|_| request_body.to_string());
211            eprintln!("\n[COGNEE_DEBUG_LLM_REQUEST] POST {url}\n{pretty_request}\n");
212        }
213
214        let mut last_error = LlmError::NetworkError("No attempt made".to_string());
215
216        for attempt in 0..=self.network_retries {
217            debug!(attempt, "LLM API attempt");
218            if attempt > 0 {
219                let delay_ms = (1_000u64 * 2u64.saturating_pow(attempt as u32 - 1)).min(30_000);
220                warn!(
221                    attempt,
222                    network_retries = self.network_retries,
223                    delay_ms,
224                    error = %last_error,
225                    "LLM request failed, retrying",
226                );
227                tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
228            }
229
230            let response = match self
231                .client
232                .post(&url)
233                .header("Authorization", self.auth_header())
234                .header("Content-Type", "application/json")
235                .json(&request_body)
236                .send()
237                .await
238            {
239                Ok(r) => r,
240                Err(e) => {
241                    last_error = LlmError::NetworkError(e.to_string());
242                    continue;
243                }
244            };
245
246            let status = response.status();
247
248            if !status.is_success() {
249                let error_body = response
250                    .text()
251                    .await
252                    .unwrap_or_else(|_| "Unknown error".to_string());
253
254                let err = match status.as_u16() {
255                    401 => LlmError::AuthenticationError(error_body),
256                    429 => LlmError::RateLimitExceeded(error_body),
257                    400 => LlmError::InvalidResponse(format!("Bad request: {error_body}")),
258                    _ => LlmError::ApiError(format!("HTTP {status}: {error_body}")),
259                };
260
261                // Non-retryable: bad request or authentication failure.
262                if matches!(status.as_u16(), 400 | 401) {
263                    return Err(err);
264                }
265
266                last_error = err;
267                continue;
268            }
269
270            let response_body = response.text().await.map_err(|e| {
271                LlmError::DeserializationError(format!("Failed to read response body: {e}"))
272            })?;
273
274            if debug_enabled {
275                eprintln!("\n[COGNEE_DEBUG_LLM_RESPONSE] POST {url}\n{response_body}\n");
276            }
277
278            return serde_json::from_str::<OpenAIResponse>(&response_body).map_err(|e| {
279                LlmError::DeserializationError(format!(
280                    "Failed to parse response: {e}. Raw body: {response_body}"
281                ))
282            });
283        }
284
285        Err(LlmError::MaxRetriesExceeded(format!(
286            "LLM request failed after {} attempt(s): {}",
287            self.network_retries + 1,
288            last_error
289        )))
290    }
291
292    /// Convert our Message type to OpenAI's format
293    fn convert_messages(messages: &[Message]) -> Vec<Value> {
294        messages
295            .iter()
296            .map(|msg| {
297                json!({
298                    "role": match msg.role {
299                        MessageRole::System => "system",
300                        MessageRole::User => "user",
301                        MessageRole::Assistant => "assistant",
302                    },
303                    "content": msg.content
304                })
305            })
306            .collect()
307    }
308
309    /// Convert JSON Schema to an example JSON with placeholder values
310    /// This is clearer for LLMs than showing the full schema
311    fn schema_to_example(schema: &Value) -> String {
312        fn create_example(value: &Value, definitions: Option<&Value>) -> Value {
313            match value {
314                Value::Object(obj) => {
315                    // Handle $ref references
316                    if let Some(ref_str) = obj.get("$ref").and_then(|v| v.as_str())
317                        && let Some(def_name) = ref_str.strip_prefix("#/definitions/")
318                        && let Some(defs) = definitions
319                        && let Some(def) = defs.get(def_name)
320                    {
321                        return create_example(def, definitions);
322                    }
323
324                    // Get the type of this field
325                    let type_val = obj.get("type");
326
327                    // Handle arrays
328                    if let Some(Value::String(t)) = type_val
329                        && t == "array"
330                    {
331                        if let Some(items) = obj.get("items") {
332                            // Return array with one example item
333                            return json!([create_example(items, definitions)]);
334                        }
335                        return json!([]);
336                    }
337
338                    // Handle objects with properties
339                    if let Some(props) = obj.get("properties")
340                        && let Value::Object(props_obj) = props
341                    {
342                        let mut result = serde_json::Map::new();
343                        for (key, val) in props_obj {
344                            result.insert(key.clone(), create_example(val, definitions));
345                        }
346                        return Value::Object(result);
347                    }
348
349                    // Handle primitive types
350                    if let Some(Value::String(t)) = type_val {
351                        return match t.as_str() {
352                            "string" => json!("example"),
353                            "number" | "integer" => json!(0),
354                            "boolean" => json!(false),
355                            _ => json!(null),
356                        };
357                    }
358
359                    // Handle union types (e.g., ["string", "null"])
360                    if let Some(Value::Array(types)) = type_val {
361                        for t in types {
362                            if let Value::String(type_str) = t
363                                && type_str != "null"
364                            {
365                                return match type_str.as_str() {
366                                    "string" => json!("example"),
367                                    "number" | "integer" => json!(0),
368                                    "boolean" => json!(false),
369                                    _ => json!(null),
370                                };
371                            }
372                        }
373                    }
374
375                    json!(null)
376                }
377                _ => value.clone(),
378            }
379        }
380
381        let definitions = schema.get("definitions");
382        let example = create_example(schema, definitions);
383
384        serde_json::to_string_pretty(&example).unwrap_or_else(|_| "{}".to_string())
385    }
386}
387
388/// Rewrite a `schemars`-generated JSON schema so it satisfies OpenAI's
389/// **strict** structured-output requirements.
390///
391/// OpenAI's `response_format: {type: "json_schema", strict: true}` rejects any
392/// schema where an object lacks `"additionalProperties": false` or whose
393/// `"required"` array does not list *every* declared property. `schemars`
394/// (0.8, draft-07) emits neither guarantee — optional (`Option<T>`) fields are
395/// omitted from `required` and `additionalProperties` is left unset. When the
396/// strict request 400s, [`OpenAIAdapter::create_structured_output_with_messages_raw`]
397/// silently falls back to lenient JSON mode, where the model is free to drop
398/// required fields (e.g. a `Node` without its `type`), causing downstream
399/// deserialization failures.
400///
401/// This walks the schema (including `definitions`/`$defs`, `properties`,
402/// `items`, and the `anyOf`/`allOf`/`oneOf` combinators) and, for every object
403/// that declares `properties`, forces `additionalProperties: false` and sets
404/// `required` to the full set of property keys. The `Value` is cloned and
405/// returned unchanged for non-object schemas.
406fn to_strict_schema(schema: &Value) -> Value {
407    fn walk(value: &mut Value) {
408        match value {
409            Value::Object(obj) => {
410                if let Some(Value::Object(props)) = obj.get("properties") {
411                    // Every declared property must be required under strict mode.
412                    let keys: Vec<Value> = props.keys().map(|k| Value::String(k.clone())).collect();
413                    obj.insert("required".to_string(), Value::Array(keys));
414                    obj.insert("additionalProperties".to_string(), Value::Bool(false));
415                }
416                for (_k, v) in obj.iter_mut() {
417                    walk(v);
418                }
419            }
420            Value::Array(items) => {
421                for v in items.iter_mut() {
422                    walk(v);
423                }
424            }
425            _ => {}
426        }
427    }
428
429    let mut out = schema.clone();
430    walk(&mut out);
431    out
432}
433
434#[async_trait]
435impl Llm for OpenAIAdapter {
436    async fn generate(
437        &self,
438        messages: Vec<Message>,
439        options: Option<GenerationOptions>,
440    ) -> LlmResult<GenerationResponse> {
441        let opts = options.unwrap_or_default();
442
443        let mut request_body = json!({
444            "model": self.model,
445            "messages": Self::convert_messages(&messages),
446        });
447
448        // Add optional parameters. Reasoning models (gpt-5*/o1*/o3*/o4*)
449        // reject sampling overrides and only accept `max_completion_tokens`.
450        if !self.is_reasoning_model() {
451            if let Some(temp) = opts.temperature {
452                request_body["temperature"] = json!(temp);
453            }
454            if let Some(top_p) = opts.top_p {
455                request_body["top_p"] = json!(top_p);
456            }
457            if let Some(freq_penalty) = opts.frequency_penalty {
458                request_body["frequency_penalty"] = json!(freq_penalty);
459            }
460            if let Some(pres_penalty) = opts.presence_penalty {
461                request_body["presence_penalty"] = json!(pres_penalty);
462            }
463        }
464        self.write_max_tokens(&mut request_body, opts.max_tokens);
465        if let Some(stop) = opts.stop
466            && !stop.is_empty()
467        {
468            request_body["stop"] = json!(stop);
469        }
470
471        if self.should_disable_thinking() {
472            request_body["think"] = json!(false);
473            request_body["reasoning"] = json!({"effort": "none"});
474        }
475
476        let response = self.call_api(request_body).await?;
477
478        // Extract the first choice
479        let choice = response
480            .choices
481            .first()
482            .ok_or_else(|| LlmError::InvalidResponse("No choices in response".to_string()))?;
483
484        Ok(GenerationResponse {
485            content: choice.message.content.clone().unwrap_or_default(),
486            model: response.model,
487            finish_reason: choice.finish_reason.clone(),
488            usage: response.usage.map(|u| TokenUsage {
489                prompt_tokens: u.prompt_tokens,
490                completion_tokens: u.completion_tokens,
491                total_tokens: u.total_tokens,
492            }),
493        })
494    }
495
    /// Produce a JSON value matching `json_schema`, trying three request
    /// strategies in order:
    ///
    /// 1. **Strict JSON schema mode** (`response_format.type = "json_schema"`,
    ///    using the schema rewritten by `to_strict_schema`). Any `call_api`
    ///    error here is logged with `warn!` and triggers the fallback.
    /// 2. **Function calling** with the original schema as the function
    ///    parameters. `call_api` errors propagate to the caller.
    /// 3. **JSON mode** (`response_format.type = "json_object"`) with an
    ///    example object derived from the schema appended to the last user
    ///    message. `call_api` errors propagate to the caller.
    ///
    /// Each strategy makes at most `self.structured_output_retries` attempts.
    /// Output that is empty or not valid JSON is retried while attempts
    /// remain. Only `temperature` and `max_tokens` are taken from `options`.
    ///
    /// # Errors
    /// - [`LlmError::InvalidResponse`] when a response has no choices, when a
    ///   JSON-mode response has no content, or when no attempt was possible.
    /// - [`LlmError::DeserializationError`] when JSON-mode output is still
    ///   unparseable on the last attempt.
    /// - Any error returned by `call_api` in strategies 2 and 3.
    async fn create_structured_output_with_messages_raw(
        &self,
        messages: Vec<Message>,
        json_schema: &Value,
        options: Option<GenerationOptions>,
    ) -> LlmResult<Value> {
        // True when the text is blank or does not parse as JSON after trimming.
        let is_empty_or_non_json = |raw: &str| {
            let trimmed = raw.trim();
            trimmed.is_empty() || serde_json::from_str::<Value>(trimmed).is_err()
        };

        let parse_json =
            |raw: &str| -> Result<Value, serde_json::Error> { serde_json::from_str(raw) };

        let opts = options.unwrap_or_default();
        let schema = json_schema;

        // OpenAI strict mode requires `additionalProperties: false` and that
        // every property appear in `required` on every object; the raw
        // schemars schema satisfies neither, which would 400 and silently
        // drop us into lenient mode (where required fields can go missing).
        let strict_schema = to_strict_schema(schema);

        // Try strict JSON schema mode first (new OpenAI API behavior).
        let mut strict_schema_request = json!({
            "model": self.model,
            "messages": Self::convert_messages(&messages),
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "extract_structured_data",
                    "schema": strict_schema,
                    "strict": true
                }
            }
        });

        // Reasoning models reject a `temperature` override.
        if !self.is_reasoning_model()
            && let Some(temp) = opts.temperature
        {
            strict_schema_request["temperature"] = json!(temp);
        }
        self.write_max_tokens(&mut strict_schema_request, opts.max_tokens);
        if self.should_disable_thinking() {
            strict_schema_request["think"] = json!(false);
            strict_schema_request["reasoning"] = json!({"effort": "none"});
        }

        for attempt in 0..self.structured_output_retries {
            match self.call_api(strict_schema_request.clone()).await {
                Ok(strict_response) => {
                    let strict_choice = strict_response.choices.first().ok_or_else(|| {
                        LlmError::InvalidResponse(
                            "No choices in strict schema response".to_string(),
                        )
                    })?;

                    // A function-call payload, when present, takes precedence
                    // over the message content.
                    if let Some(function_call) = &strict_choice.message.function_call {
                        match parse_json(&function_call.arguments) {
                            Ok(parsed) => return Ok(parsed),
                            Err(e) => {
                                // Unparseable output: retry while attempts
                                // remain, otherwise leave strict mode and
                                // fall through to the next strategy.
                                if attempt + 1 < self.structured_output_retries
                                    && is_empty_or_non_json(&function_call.arguments)
                                {
                                    continue;
                                }
                                if !is_empty_or_non_json(&function_call.arguments) {
                                    return Err(LlmError::DeserializationError(format!(
                                        "Failed to deserialize strict function call arguments: {}. Raw: {}",
                                        e, function_call.arguments
                                    )));
                                }
                                break;
                            }
                        }
                    }

                    if let Some(content) = strict_choice.message.content.as_ref() {
                        match parse_json(content) {
                            Ok(parsed) => return Ok(parsed),
                            Err(e) => {
                                // Same retry-then-fall-back handling as for
                                // function-call arguments above.
                                if attempt + 1 < self.structured_output_retries
                                    && is_empty_or_non_json(content)
                                {
                                    continue;
                                }
                                if !is_empty_or_non_json(content) {
                                    return Err(LlmError::DeserializationError(format!(
                                        "Failed to deserialize strict JSON content: {e}. Raw: {content}"
                                    )));
                                }
                                break;
                            }
                        }
                    }
                    // Neither a function call nor content: the loop moves on
                    // to the next attempt.
                }
                Err(e) => {
                    // Strict json_schema mode is unsupported by this
                    // model/endpoint (or the schema was rejected). Fall back to
                    // function calling / JSON mode below, but make the reason
                    // visible — a silent fallback is how required fields end up
                    // missing from the model's output.
                    warn!(error = %e, "strict json_schema request failed; falling back to function/JSON mode");
                    break;
                }
            }
        }

        // Second strategy: legacy function calling (works with OpenAI). The
        // original, non-strict schema is passed as the function parameters.
        let mut request_body = json!({
            "model": self.model,
            "messages": Self::convert_messages(&messages),
            "functions": [{
                "name": "extract_structured_data",
                "description": "Extract structured data from the input",
                "parameters": schema
            }],
            "function_call": {"name": "extract_structured_data"}
        });

        if !self.is_reasoning_model()
            && let Some(temp) = opts.temperature
        {
            request_body["temperature"] = json!(temp);
        }
        self.write_max_tokens(&mut request_body, opts.max_tokens);
        if self.should_disable_thinking() {
            request_body["think"] = json!(false);
            request_body["reasoning"] = json!({"effort": "none"});
        }

        for attempt in 0..self.structured_output_retries {
            // Unlike the strict stage, an API error here is returned to the caller.
            let response = self.call_api(request_body.clone()).await?;

            let choice = response
                .choices
                .first()
                .ok_or_else(|| LlmError::InvalidResponse("No choices in response".to_string()))?;

            if let Some(function_call) = &choice.message.function_call {
                match parse_json(&function_call.arguments) {
                    Ok(parsed) => return Ok(parsed),
                    Err(e) => {
                        if attempt + 1 < self.structured_output_retries
                            && is_empty_or_non_json(&function_call.arguments)
                        {
                            continue;
                        }
                        if !is_empty_or_non_json(&function_call.arguments) {
                            return Err(LlmError::DeserializationError(format!(
                                "Failed to deserialize function call arguments: {}. Raw: {}",
                                e, function_call.arguments
                            )));
                        }
                        break;
                    }
                }
            }

            // No function call in the response: stop retrying this strategy
            // and move on to JSON mode.
            break;
        }

        // Fallback to JSON mode (works with Ollama and other providers)
        let mut json_messages = Self::convert_messages(&messages);

        let example = Self::schema_to_example(schema);

        // Append the extraction instructions and example structure to the
        // final message, but only when that message comes from the user.
        if let Some(last_msg) = json_messages.last_mut()
            && last_msg["role"] == "user"
        {
            let original_content = last_msg["content"].as_str().unwrap_or("");
            last_msg["content"] = json!(format!(
                "{}\n\n\
                    Extract the information from the text above and return it as JSON.\n\
                    Use this structure as your template (but with actual data from the text):\n\
                    {}",
                original_content, example
            ));
        }

        let mut json_request = json!({
            "model": self.model,
            "messages": json_messages,
            "response_format": {"type": "json_object"}
        });

        if !self.is_reasoning_model()
            && let Some(temp) = opts.temperature
        {
            json_request["temperature"] = json!(temp);
        }
        self.write_max_tokens(&mut json_request, opts.max_tokens);
        if self.should_disable_thinking() {
            json_request["think"] = json!(false);
            json_request["reasoning"] = json!({"effort": "none"});
        }

        for attempt in 0..self.structured_output_retries {
            // Each attempt starts from the pristine request so retry
            // instructions are not appended more than once.
            let mut request_for_attempt = json_request.clone();

            if attempt > 0 {
                // On retries, add an explicit "JSON only" instruction to the
                // last user message and pin the temperature to 0 on models
                // that accept the override.
                if let Some(messages) = request_for_attempt["messages"].as_array_mut()
                    && let Some(last_msg) = messages.last_mut()
                    && last_msg["role"] == "user"
                {
                    let original_content = last_msg["content"].as_str().unwrap_or("");
                    last_msg["content"] = json!(format!(
                        "{}\n\n/no_think\nReturn ONLY one valid JSON object matching the required schema. No reasoning, no markdown, no extra text.",
                        original_content
                    ));
                }

                if !self.is_reasoning_model() {
                    request_for_attempt["temperature"] = json!(0.0);
                }
            }

            let json_response = self.call_api(request_for_attempt).await?;

            let json_choice = json_response.choices.first().ok_or_else(|| {
                LlmError::InvalidResponse("No choices in JSON mode response".to_string())
            })?;

            let content = json_choice.message.content.as_ref().ok_or_else(|| {
                LlmError::InvalidResponse("No content in JSON mode response".to_string())
            })?;

            match parse_json(content) {
                Ok(parsed) => return Ok(parsed),
                Err(e) => {
                    if attempt + 1 < self.structured_output_retries && is_empty_or_non_json(content)
                    {
                        continue;
                    }
                    // Last strategy and no attempts left: report the failure.
                    return Err(LlmError::DeserializationError(format!(
                        "Failed to deserialize JSON content: {e}. Raw: {content}"
                    )));
                }
            }
        }

        // Only reachable when the JSON-mode loop never ran, i.e. when
        // `structured_output_retries` is 0.
        Err(LlmError::InvalidResponse(
            "Structured output retries exhausted without a parseable response".to_string(),
        ))
    }
741
742    fn model(&self) -> &str {
743        &self.model
744    }
745
    /// Reports whether streaming is supported; this adapter always answers `true`.
    fn supports_streaming(&self) -> bool {
        true
    }
749
    /// Reports whether function calling is supported; this adapter always
    /// answers `true`, regardless of the configured model.
    fn supports_function_calling(&self) -> bool {
        true
    }
753
754    fn max_context_length(&self) -> u32 {
755        // Context lengths for common OpenAI models
756        match self.model.as_str() {
757            m if m.starts_with("gpt-4-turbo") => 128_000,
758            m if m.starts_with("gpt-4-32k") => 32_768,
759            m if m.starts_with("gpt-4") => 8_192,
760            m if m.starts_with("gpt-3.5-turbo-16k") => 16_384,
761            m if m.starts_with("gpt-3.5-turbo") => 4_096,
762            _ => 4_096, // Conservative default
763        }
764    }
765
766    async fn transcribe_image(
767        &self,
768        image_bytes: &[u8],
769        mime_type: &str,
770        options: Option<GenerationOptions>,
771    ) -> LlmResult<String> {
772        use base64::Engine as _;
773
774        if !mime_type.starts_with("image/") {
775            return Err(LlmError::InvalidResponse(format!(
776                "Expected image/* MIME type, got: {mime_type}"
777            )));
778        }
779
780        let b64 = base64::engine::general_purpose::STANDARD.encode(image_bytes);
781        let data_uri = format!("data:{mime_type};base64,{b64}");
782
783        let vision_model = std::env::var("LLM_VISION_MODEL")
784            .ok()
785            .filter(|s| !s.is_empty())
786            .unwrap_or_else(|| self.model.clone());
787
788        let max_tokens = options.as_ref().and_then(|o| o.max_tokens).unwrap_or(300);
789
790        let mut request_body = json!({
791            "model": vision_model,
792            "messages": [{
793                "role": "user",
794                "content": [
795                    { "type": "text", "text": "What's in this image?" },
796                    { "type": "image_url", "image_url": { "url": data_uri } }
797                ]
798            }],
799        });
800        self.write_max_tokens(&mut request_body, Some(max_tokens));
801
802        let response = self.call_api(request_body).await?;
803
804        let choice = response.choices.first().ok_or_else(|| {
805            LlmError::InvalidResponse("No choices in vision response".to_string())
806        })?;
807
808        choice.message.content.clone().ok_or_else(|| {
809            LlmError::InvalidResponse("Vision response contained no content".to_string())
810        })
811    }
812
813    fn supports_vision(&self) -> bool {
814        let m = self.model.to_lowercase();
815        m.contains("gpt-4")
816            || m.contains("gpt-5")
817            || m.contains("vision")
818            || m.contains("o1")
819            || m.contains("o3")
820            || m.contains("o4")
821            || m.contains("llava")
822            || m.contains("moondream")
823            || m.contains("llama-3.2-vision")
824            || m.contains("gemma3")
825    }
826}
827
828// ---------------------------------------------------------------------------
829// Whisper transcription support
830// ---------------------------------------------------------------------------
831
/// Response from the OpenAI Whisper `verbose_json` endpoint.
///
/// Only the three fields declared here are deserialized; the struct carries no
/// `deny_unknown_fields` attribute, so any other keys in the body are ignored.
#[derive(Debug, Deserialize)]
struct WhisperResponse {
    /// Transcribed text. Required: a body without it fails to deserialize.
    text: String,
    /// Language reported by the endpoint, if the body includes one.
    language: Option<String>,
    /// Audio duration reported by the endpoint, if the body includes one
    /// (presumably seconds — confirm against the API reference).
    duration: Option<f32>,
}
839
/// Map an audio format extension to the MIME type used for the upload part.
///
/// `format` is matched exactly, so callers must pass it already lowercased
/// and without a leading dot (e.g. `"mp3"`, not `".MP3"`).
///
/// Returns `"application/octet-stream"` for any extension not listed, so an
/// unrecognised format still yields a valid (if generic) MIME type.
fn audio_mime_type(format: &str) -> &'static str {
    match format {
        "mp3" | "mpeg" | "mpga" => "audio/mpeg",
        "mp4" | "m4a" => "audio/mp4",
        "wav" => "audio/wav",
        "webm" => "audio/webm",
        // Also accepted by the transcription endpoint; previously these fell
        // through to the generic binary type.
        "ogg" => "audio/ogg",
        "flac" => "audio/flac",
        // Generic fallback for anything else that reaches this function.
        _ => "application/octet-stream",
    }
}
851
impl OpenAIAdapter {
    /// Send one transcription request to `{base_url}/audio/transcriptions`.
    ///
    /// This method performs a single attempt and never retries: it takes the
    /// multipart `form` by value and hands it to the request builder, so the
    /// form is gone after the first send. Retrying is the job of
    /// `transcribe_audio`, which builds a fresh form for every attempt.
    ///
    /// # Errors
    /// * `LlmError::NetworkError` - the request could not be sent.
    /// * `LlmError::AuthenticationError` - HTTP 401.
    /// * `LlmError::RateLimitExceeded` - HTTP 429.
    /// * `LlmError::InvalidResponse` - HTTP 400.
    /// * `LlmError::ApiError` - any other non-success status.
    /// * `LlmError::DeserializationError` - the success body could not be read
    ///   or did not parse as a `WhisperResponse`.
    #[instrument(
        name = "llm.transcription_api_call",
        level = "info",
        skip(self, form),
        fields(
            url = tracing::field::Empty,
            cognee.llm.model = self.transcription_model.as_str(),
            cognee.llm.provider = "openai",
        ),
    )]
    async fn call_transcription_api(
        &self,
        form: reqwest::multipart::Form,
    ) -> LlmResult<WhisperResponse> {
        let url = format!("{}/audio/transcriptions", self.base_url);
        // The span declares `url` as empty up front; fill it in now that it is known.
        tracing::Span::current().record("url", url.as_str());

        // `reqwest::multipart::Form` is not `Clone`, so the same form cannot be
        // sent twice. That is why this is a single-shot request and the retry
        // loop lives in the caller, which rebuilds the form on each attempt.

        let response = self
            .client
            .post(&url)
            .header("Authorization", self.auth_header())
            .multipart(form)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

        let status = response.status();

        if !status.is_success() {
            // Best effort: if the error body itself cannot be read, fall back
            // to a placeholder rather than masking the HTTP status.
            let error_body = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            return Err(match status.as_u16() {
                401 => LlmError::AuthenticationError(error_body),
                429 => LlmError::RateLimitExceeded(error_body),
                400 => LlmError::InvalidResponse(format!("Bad request: {error_body}")),
                _ => LlmError::ApiError(format!("HTTP {status}: {error_body}")),
            });
        }

        // Read the body as text first so the raw payload can be included in
        // the error message if JSON parsing fails.
        let response_body = response.text().await.map_err(|e| {
            LlmError::DeserializationError(format!("Failed to read response body: {e}"))
        })?;

        serde_json::from_str::<WhisperResponse>(&response_body).map_err(|e| {
            LlmError::DeserializationError(format!(
                "Failed to parse Whisper response: {e}. Raw body: {response_body}"
            ))
        })
    }

    /// Build a `reqwest::multipart::Form` for a Whisper transcription request.
    ///
    /// The form always contains the `file` part (a copy of `audio`, named
    /// `audio.{format}` with the MIME type from `audio_mime_type`), the `model`
    /// field set to this adapter's transcription model, and
    /// `response_format=verbose_json`. The `language` and `prompt` fields are
    /// added only when the corresponding hint is `Some`.
    ///
    /// # Errors
    /// Returns `LlmError::ConfigError` if the MIME type is rejected when it is
    /// attached to the file part.
    fn build_transcription_form(
        &self,
        audio: &[u8],
        format: &str,
        language_hint: Option<&str>,
        prompt_hint: Option<&str>,
    ) -> LlmResult<reqwest::multipart::Form> {
        let mime = audio_mime_type(format);
        let filename = format!("audio.{format}");

        // The part owns its bytes, so the audio is copied for every form built.
        let file_part = reqwest::multipart::Part::bytes(audio.to_vec())
            .file_name(filename)
            .mime_str(mime)
            .map_err(|e| {
                LlmError::ConfigError(format!("Failed to set MIME type on multipart part: {e}"))
            })?;

        let mut form = reqwest::multipart::Form::new()
            .part("file", file_part)
            .text("model", self.transcription_model.clone())
            .text("response_format", "verbose_json");

        if let Some(lang) = language_hint {
            form = form.text("language", lang.to_string());
        }
        if let Some(prompt) = prompt_hint {
            form = form.text("prompt", prompt.to_string());
        }

        Ok(form)
    }
}
956
957#[async_trait]
958impl Transcriber for OpenAIAdapter {
959    async fn transcribe_audio(
960        &self,
961        audio: &[u8],
962        format: &str,
963        language_hint: Option<&str>,
964        prompt_hint: Option<&str>,
965    ) -> LlmResult<TranscriptionOutput> {
966        // Normalize and validate before any network I/O.
967        let format_lower = format.to_ascii_lowercase();
968        validate_audio_format(&format_lower)?;
969
970        let mut last_error = LlmError::NetworkError("No attempt made".to_string());
971
972        for attempt in 0..=self.network_retries {
973            debug!(attempt, "Transcription API attempt");
974            if attempt > 0 {
975                let delay_ms = (1_000u64 * 2u64.saturating_pow(attempt as u32 - 1)).min(30_000);
976                warn!(
977                    attempt,
978                    network_retries = self.network_retries,
979                    delay_ms,
980                    error = %last_error,
981                    "Transcription request failed, retrying",
982                );
983                tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
984            }
985
986            let form =
987                self.build_transcription_form(audio, &format_lower, language_hint, prompt_hint)?;
988
989            match self.call_transcription_api(form).await {
990                Ok(resp) => {
991                    return Ok(TranscriptionOutput {
992                        text: resp.text,
993                        language: resp.language,
994                        duration: resp.duration,
995                    });
996                }
997                Err(e) => {
998                    // Non-retryable errors: bad request or authentication failure.
999                    if matches!(
1000                        e,
1001                        LlmError::InvalidResponse(_) | LlmError::AuthenticationError(_)
1002                    ) {
1003                        return Err(e);
1004                    }
1005                    last_error = e;
1006                    continue;
1007                }
1008            }
1009        }
1010
1011        Err(LlmError::MaxRetriesExceeded(format!(
1012            "Transcription request failed after {} attempt(s): {}",
1013            self.network_retries + 1,
1014            last_error
1015        )))
1016    }
1017
1018    fn transcription_model(&self) -> &str {
1019        &self.transcription_model
1020    }
1021}
1022
1023// OpenAI API response types
/// Top-level response body with a list of choices and optional token usage.
///
/// Every field except `usage` is required by the derived `Deserialize`, so a
/// body missing any of them fails to parse.
/// NOTE(review): OpenAI-compatible servers that omit `id`, `object`, `created`
/// or `model` would be rejected here — confirm whether that strictness is
/// intended for custom base URLs.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct OpenAIResponse {
    /// Response identifier as sent by the server.
    id: String,
    /// Object type tag as sent by the server.
    object: String,
    /// Creation time as sent by the server (presumably a Unix timestamp — confirm).
    created: i64,
    /// Model name echoed back by the server.
    model: String,
    /// Generated alternatives.
    choices: Vec<OpenAIChoice>,
    /// Token accounting, when the server includes it.
    usage: Option<OpenAIUsage>,
}
1034
/// One entry of `OpenAIResponse::choices`.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct OpenAIChoice {
    /// Position of this choice as sent by the server. Required.
    index: u32,
    /// The generated message for this choice.
    message: OpenAIMessage,
    /// Why generation stopped, when the server reports it.
    finish_reason: Option<String>,
}
1042
/// Message payload carried by an `OpenAIChoice`.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct OpenAIMessage {
    /// Role string as sent by the server. Required.
    role: String,
    /// Text content; `None` when the field is absent or `null`.
    content: Option<String>,
    /// Optional reasoning text (presumably emitted by some OpenAI-compatible
    /// servers alongside `content` — confirm which ones).
    reasoning: Option<String>,
    /// Function call requested by the model, when present.
    function_call: Option<OpenAIFunctionCall>,
}
1051
/// Function call attached to an `OpenAIMessage`.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct OpenAIFunctionCall {
    /// Name of the function to call.
    name: String,
    /// Call arguments, kept as the raw string sent by the server (presumably
    /// JSON-encoded — the caller is responsible for parsing it).
    arguments: String,
}
1058
/// Token counts reported in `OpenAIResponse::usage`.
///
/// All three counters are required whenever the `usage` object is present.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    /// Tokens counted for the prompt.
    prompt_tokens: u32,
    /// Tokens counted for the completion.
    completion_tokens: u32,
    /// Total as reported by the server (not recomputed here).
    total_tokens: u32,
}
1065
/// Unit tests for adapter construction, model-name heuristics, request-body
/// helpers and schema conversion. None of them perform network I/O.
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use super::*;

    /// The `openai/` provider prefix is removed; other prefixes are kept.
    #[test]
    fn test_openai_provider_prefix_is_stripped() {
        // litellm-style "openai/<model>" must be sent as bare "<model>".
        let adapter = OpenAIAdapter::new("openai/gpt-5-mini", "test-key", None).unwrap();
        assert_eq!(adapter.model(), "gpt-5-mini");
        // Non-openai provider prefixes (custom endpoints) are left intact.
        let adapter = OpenAIAdapter::new("ollama/llama3", "test-key", None).unwrap();
        assert_eq!(adapter.model(), "ollama/llama3");
    }

    /// Construction without a base URL uses the documented defaults.
    #[test]
    fn test_openai_adapter_creation() {
        let adapter = OpenAIAdapter::new("gpt-4", "test-key", None);
        assert!(adapter.is_ok());

        let adapter = adapter.unwrap();
        assert_eq!(adapter.model(), "gpt-4");
        assert_eq!(adapter.base_url, OpenAIAdapter::DEFAULT_BASE_URL);
        assert_eq!(
            adapter.structured_output_retries,
            OpenAIAdapter::DEFAULT_STRUCTURED_OUTPUT_RETRIES
        );
    }

    /// The retry count is configurable, and a request for 0 is raised to 1.
    #[test]
    fn test_configurable_structured_output_retries() {
        let adapter = OpenAIAdapter::new("gpt-4", "test-key", None)
            .unwrap()
            .with_structured_output_retries(5);
        assert_eq!(adapter.structured_output_retries, 5);

        // Zero would mean "never call the API", so the builder clamps it.
        let adapter = OpenAIAdapter::new("gpt-4", "test-key", None)
            .unwrap()
            .with_structured_output_retries(0);
        assert_eq!(adapter.structured_output_retries, 1);
    }

    /// A caller-supplied base URL is stored as given.
    #[test]
    fn test_openai_adapter_custom_base_url() {
        let adapter = OpenAIAdapter::new(
            "gpt-4",
            "test-key",
            Some("https://custom.api.com/v1".to_string()),
        );
        assert!(adapter.is_ok());

        let adapter = adapter.unwrap();
        assert_eq!(adapter.base_url, "https://custom.api.com/v1");
    }

    /// A trailing slash on the base URL is removed at construction time.
    #[test]
    fn test_base_url_trailing_slash_is_normalized() {
        // The Gemini OpenAI-compat base ends in `/`; without normalisation the
        // request URL would be `.../openai//chat/completions` and 404.
        let adapter = OpenAIAdapter::new(
            "gemini-2.0-flash",
            "test-key",
            Some("https://generativelanguage.googleapis.com/v1beta/openai/".to_string()),
        )
        .unwrap();
        assert_eq!(
            adapter.base_url,
            "https://generativelanguage.googleapis.com/v1beta/openai"
        );
    }

    /// Table-driven check of reasoning-model detection on the default base URL,
    /// including a mixed-case name and near-miss names that must not match.
    #[test]
    fn test_is_reasoning_model_matches_openai_families() {
        let cases = [
            ("gpt-5", true),
            ("gpt-5-mini", true),
            ("gpt-5-2025-06-01", true),
            ("o1", true),
            ("o1-mini", true),
            ("o3", true),
            ("o3-mini", true),
            ("o4-mini", true),
            ("GPT-5-Mini", true),
            ("gpt-4o-mini", false),
            ("gpt-4-turbo", false),
            ("gpt-3.5-turbo", false),
            ("o-foo", false),
        ];
        for (model, expected) in cases {
            let adapter = OpenAIAdapter::new(model, "test-key", None).unwrap();
            assert_eq!(
                adapter.is_reasoning_model(),
                expected,
                "is_reasoning_model({model})"
            );
        }
    }

    /// Reasoning-model detection is disabled when a custom base URL is set.
    #[test]
    fn test_is_reasoning_model_skipped_for_custom_base_url() {
        // Custom OpenAI-compatible endpoints (Ollama, vLLM, …) may have
        // model names that look like reasoning families but still accept
        // legacy sampling parameters — the gate is conservative.
        let adapter = OpenAIAdapter::new(
            "gpt-5-mini",
            "test-key",
            Some("http://localhost:11434/v1".to_string()),
        )
        .unwrap();
        assert!(!adapter.is_reasoning_model());
    }

    /// `write_max_tokens` writes `max_completion_tokens` for reasoning models,
    /// `max_tokens` otherwise, and nothing at all when given `None`.
    #[test]
    fn test_write_max_tokens_renames_key_for_reasoning_models() {
        let mut body = json!({"model": "gpt-5-mini"});
        let reasoning = OpenAIAdapter::new("gpt-5-mini", "test-key", None).unwrap();
        reasoning.write_max_tokens(&mut body, Some(2048));
        assert!(body.get("max_tokens").is_none());
        assert_eq!(body["max_completion_tokens"], 2048);

        let mut body = json!({"model": "gpt-4o-mini"});
        let classic = OpenAIAdapter::new("gpt-4o-mini", "test-key", None).unwrap();
        classic.write_max_tokens(&mut body, Some(2048));
        assert_eq!(body["max_tokens"], 2048);
        assert!(body.get("max_completion_tokens").is_none());

        // None leaves body untouched.
        let mut body = json!({"model": "gpt-5-mini"});
        reasoning.write_max_tokens(&mut body, None);
        assert!(body.get("max_tokens").is_none());
        assert!(body.get("max_completion_tokens").is_none());
    }

    /// `convert_messages` keeps order and maps roles to their wire names.
    #[test]
    fn test_message_conversion() {
        let messages = vec![
            Message {
                role: MessageRole::System,
                content: "You are helpful".to_string(),
            },
            Message {
                role: MessageRole::User,
                content: "Hello".to_string(),
            },
        ];

        let converted = OpenAIAdapter::convert_messages(&messages);
        assert_eq!(converted.len(), 2);
        assert_eq!(converted[0]["role"], "system");
        assert_eq!(converted[0]["content"], "You are helpful");
        assert_eq!(converted[1]["role"], "user");
        assert_eq!(converted[1]["content"], "Hello");
    }

    /// Spot-checks the prefix-based context-length table.
    #[test]
    fn test_context_length() {
        let adapter = OpenAIAdapter::new("gpt-4-turbo-preview", "key", None).unwrap();
        assert_eq!(adapter.max_context_length(), 128_000);

        let adapter = OpenAIAdapter::new("gpt-4", "key", None).unwrap();
        assert_eq!(adapter.max_context_length(), 8_192);

        let adapter = OpenAIAdapter::new("gpt-3.5-turbo-16k", "key", None).unwrap();
        assert_eq!(adapter.max_context_length(), 16_384);
    }

    // The `supports_vision` tests below each pin one model name against the
    // substring heuristic.

    #[test]
    fn test_supports_vision_gpt4o() {
        let adapter = OpenAIAdapter::new("gpt-4o", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gpt4_turbo() {
        let adapter = OpenAIAdapter::new("gpt-4-turbo", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gpt4o_mini() {
        let adapter = OpenAIAdapter::new("gpt-4o-mini", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gpt35_is_false() {
        let adapter = OpenAIAdapter::new("gpt-3.5-turbo", "key", None).unwrap();
        assert!(!adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_llava() {
        let adapter = OpenAIAdapter::new("llava:13b", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_o1() {
        let adapter = OpenAIAdapter::new("o1-preview", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gemma3() {
        let adapter = OpenAIAdapter::new("gemma3:12b", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    /// A non-`image/*` MIME type is rejected before any request is sent, so
    /// the fake API key is never used.
    #[tokio::test]
    async fn transcribe_image_rejects_non_image_mime() {
        let adapter = OpenAIAdapter::new("gpt-4o", "fake-key", None).unwrap();
        let result = adapter
            .transcribe_image(b"not-an-image", "text/plain", None)
            .await;
        assert!(result.is_err());
        assert!(
            matches!(result.unwrap_err(), LlmError::InvalidResponse(_)),
            "Expected InvalidResponse for non-image MIME type"
        );
    }

    /// With `TRANSCRIPTION_MODEL` unset, the adapter defaults to `whisper-1`.
    #[test]
    fn test_transcription_model_default() {
        // Clear the env var to test the default value.
        // SAFETY: sound only if no other thread reads or writes the process
        // environment while this call runs.
        // NOTE(review): nothing in this module enforces that. The default test
        // harness runs tests on multiple threads, and `OpenAIAdapter::new`
        // presumably reads this variable (this test depends on it) while other
        // tests call `new` concurrently — confirm, or serialize the
        // env-touching tests.
        unsafe { std::env::remove_var("TRANSCRIPTION_MODEL") };
        let adapter = OpenAIAdapter::new("gpt-4", "key", None).unwrap();
        assert_eq!(adapter.transcription_model(), "whisper-1");
    }

    /// The builder override replaces the transcription model.
    #[test]
    fn test_transcription_model_custom() {
        let adapter = OpenAIAdapter::new("gpt-4", "key", None)
            .unwrap()
            .with_transcription_model("whisper-large-v3");
        assert_eq!(adapter.transcription_model(), "whisper-large-v3");
    }

    /// Every explicitly mapped extension yields its expected MIME type.
    #[test]
    fn test_audio_mime_type_mapping() {
        assert_eq!(audio_mime_type("mp3"), "audio/mpeg");
        assert_eq!(audio_mime_type("mpeg"), "audio/mpeg");
        assert_eq!(audio_mime_type("mpga"), "audio/mpeg");
        assert_eq!(audio_mime_type("mp4"), "audio/mp4");
        assert_eq!(audio_mime_type("m4a"), "audio/mp4");
        assert_eq!(audio_mime_type("wav"), "audio/wav");
        assert_eq!(audio_mime_type("webm"), "audio/webm");
    }

    /// `to_strict_schema` must close every object and list every property as
    /// required, including objects reached through `definitions`.
    #[test]
    fn test_to_strict_schema_marks_all_required_and_closes_objects() {
        // Mirrors the schemars-0.8 shape: an optional field omitted from
        // `required`, nested object behind `definitions`/`$ref`, and no
        // `additionalProperties` set anywhere.
        let schema = json!({
            "type": "object",
            "properties": {
                "nodes": { "type": "array", "items": { "$ref": "#/definitions/Node" } }
            },
            "required": ["nodes"],
            "definitions": {
                "Node": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" },
                        "type": { "type": "string" },
                        "description": { "type": ["string", "null"] }
                    },
                    "required": ["name", "type"]
                }
            }
        });

        let strict = to_strict_schema(&schema);

        // Root object closed + all props required.
        assert_eq!(strict["additionalProperties"], json!(false));
        assert_eq!(strict["required"], json!(["nodes"]));

        // Nested object inside definitions: every property now required
        // (including the previously-optional `description`) and closed.
        let node = &strict["definitions"]["Node"];
        assert_eq!(node["additionalProperties"], json!(false));
        // Sort before comparing so the assertion does not depend on the order
        // in which `required` entries are emitted.
        let mut req: Vec<String> = node["required"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap().to_string())
            .collect();
        req.sort();
        assert_eq!(req, vec!["description", "name", "type"]);
    }
}