Skip to main content

cognee_llm/adapters/
openai.rs

1//! OpenAI API adapter with JSON schema support for structured outputs.
2//!
//! This adapter requests structured outputs from OpenAI-compatible chat APIs,
//! trying strict JSON-schema mode first and falling back to function calling
//! and then JSON mode. The JSON schemas are derived from Rust types.
5
6use async_trait::async_trait;
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::{Value, json};
10use tracing::{debug, instrument, warn};
11
12#[allow(unused_imports)]
13use cognee_utils::tracing_keys::{COGNEE_LLM_MODEL, COGNEE_LLM_PROVIDER};
14
15use crate::error::{LlmError, LlmResult};
16use crate::llm_trait::Llm;
17use crate::transcriber::{Transcriber, TranscriptionOutput, validate_audio_format};
18use crate::types::{GenerationOptions, GenerationResponse, Message, MessageRole, TokenUsage};
19
20/// OpenAI API adapter.
21///
22/// Supports structured output generation via:
23/// - Strict JSON schema mode (response_format with type: "json_schema")
24/// - Function calling (for GPT-4 and GPT-3.5-turbo)
25/// - JSON mode (response_format with type: "json_object")
26/// - JSON schema validation (via function parameters)
27///
28/// # Example
29/// ```ignore
30/// use cognee_llm::adapters::OpenAIAdapter;
31/// use cognee_llm::Llm;
32///
33/// let adapter = OpenAIAdapter::new(
34///     "gpt-4-turbo-preview",
35///     "sk-...",
36///     None, // Use default base URL
37/// )?;
38///
39/// let result: MyStruct = adapter.create_structured_output(
40///     "Extract information from this text",
41///     "You are a helpful assistant",
42///     None,
43/// ).await?;
44/// ```
45#[derive(Clone)]
46pub struct OpenAIAdapter {
47    model: String,
48    api_key: String,
49    base_url: String,
50    client: Client,
51    structured_output_retries: usize,
52    /// Number of times to retry the HTTP request on transient network/server errors.
53    network_retries: usize,
54    /// Model name for audio transcription (e.g. `"whisper-1"`).
55    transcription_model: String,
56}
57
58impl OpenAIAdapter {
59    /// Default OpenAI API base URL
60    pub const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1";
61    /// Default retry attempts for structured output parsing paths.
62    ///
63    /// Python parity: instructor's `acreate_structured_output` retries up to
64    /// `MAX_RETRIES = 5` times on a parse/validation failure. We match that
65    /// count so transient malformed responses get the same number of repair
66    /// chances before the cognify pipeline gives up.
67    pub const DEFAULT_STRUCTURED_OUTPUT_RETRIES: usize = 5;
68    /// Default retry attempts for transient network/server errors.
69    pub const DEFAULT_NETWORK_RETRIES: usize = 3;
70
71    /// Create a new OpenAI adapter.
72    ///
73    /// # Arguments
74    /// * `model` - Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
75    /// * `api_key` - OpenAI API key
76    /// * `base_url` - Optional custom base URL (defaults to OpenAI's API)
77    ///
78    /// # Returns
79    /// A new OpenAI adapter instance
80    pub fn new(
81        model: impl Into<String>,
82        api_key: impl Into<String>,
83        base_url: Option<String>,
84    ) -> LlmResult<Self> {
85        let client = Client::builder()
86            .timeout(std::time::Duration::from_secs(600))
87            .build()
88            .map_err(|e| LlmError::ConfigError(format!("Failed to create HTTP client: {e}")))?;
89
90        let transcription_model =
91            std::env::var("TRANSCRIPTION_MODEL").unwrap_or_else(|_| "whisper-1".to_string());
92
93        // Strip a leading litellm-style "openai/" provider prefix. Python's
94        // litellm accepts provider-qualified names (e.g. "openai/gpt-5-mini")
95        // and strips the provider before calling the OpenAI-native API, which
96        // itself rejects the prefix. Strip it here for parity so a
97        // provider-qualified config value works against real OpenAI.
98        let model: String = model.into();
99        let model = model
100            .strip_prefix("openai/")
101            .map(str::to_string)
102            .unwrap_or(model);
103
104        Ok(Self {
105            model,
106            api_key: api_key.into(),
107            base_url: base_url.unwrap_or_else(|| Self::DEFAULT_BASE_URL.to_string()),
108            client,
109            structured_output_retries: Self::DEFAULT_STRUCTURED_OUTPUT_RETRIES,
110            network_retries: Self::DEFAULT_NETWORK_RETRIES,
111            transcription_model,
112        })
113    }
114
115    /// Configure retry attempts for structured output extraction.
116    ///
117    /// Values lower than 1 are coerced to 1.
118    pub fn with_structured_output_retries(mut self, retries: u32) -> Self {
119        let retries = usize::try_from(retries).unwrap_or(usize::MAX);
120        self.structured_output_retries = retries.max(1);
121        self
122    }
123
124    /// Configure retry attempts for transient network and server errors (HTTP 429, 5xx).
125    ///
126    /// Each retry uses exponential backoff starting at 1 s, doubling up to 30 s.
127    pub fn with_network_retries(mut self, retries: u32) -> Self {
128        self.network_retries = usize::try_from(retries).unwrap_or(usize::MAX);
129        self
130    }
131
132    /// Configure the model used for audio transcription (default: `"whisper-1"`).
133    pub fn with_transcription_model(mut self, model: impl Into<String>) -> Self {
134        self.transcription_model = model.into();
135        self
136    }
137
138    /// Build the authorization header value
139    fn auth_header(&self) -> String {
140        format!("Bearer {}", self.api_key)
141    }
142
143    /// Whether to request non-thinking mode for local Qwen OpenAI-compatible endpoints.
144    fn should_disable_thinking(&self) -> bool {
145        self.model.to_lowercase().starts_with("qwen") && !self.base_url.contains("api.openai.com")
146    }
147
148    /// True for OpenAI reasoning-model families (`gpt-5*`, `o1*`, `o3*`, `o4*`)
149    /// that reject `temperature`/`top_p`/`frequency_penalty`/`presence_penalty`
150    /// overrides and require `max_completion_tokens` in place of `max_tokens`.
151    ///
152    /// Gated on the official `api.openai.com` base URL so custom OpenAI-compatible
153    /// proxies (Ollama, vLLM, …) keep accepting legacy parameters even when the
154    /// configured model name happens to match a reasoning-family prefix.
155    fn is_reasoning_model(&self) -> bool {
156        if !self.base_url.contains("api.openai.com") {
157            return false;
158        }
159        let m = self.model.to_lowercase();
160        m.starts_with("gpt-5") || m.starts_with("o1") || m.starts_with("o3") || m.starts_with("o4")
161    }
162
163    /// Insert `max_tokens` (or `max_completion_tokens` on reasoning models) into a
164    /// request body if `value` is `Some`.
165    fn write_max_tokens(&self, body: &mut Value, value: Option<u32>) {
166        if let Some(v) = value {
167            let key = if self.is_reasoning_model() {
168                "max_completion_tokens"
169            } else {
170                "max_tokens"
171            };
172            body[key] = json!(v);
173        }
174    }
175
176    /// Call the OpenAI chat completions API, retrying on transient network/server errors.
177    ///
178    /// Retries up to `self.network_retries` times with exponential backoff (1 s, 2 s, 4 s …
179    /// capped at 30 s) on:
180    /// - Network-level failures (connection refused, timeout, etc.)
181    /// - HTTP 429 (rate limit exceeded)
182    /// - HTTP 5xx (server errors)
183    ///
184    /// Errors on HTTP 400 and 401 are returned immediately without retrying.
185    #[instrument(
186        name = "llm.api_call",
187        level = "info",
188        skip(self, request_body),
189        fields(
190            url = tracing::field::Empty,
191            cognee.llm.model = self.model.as_str(),
192            cognee.llm.provider = "openai",
193        ),
194    )]
195    async fn call_api(&self, request_body: Value) -> LlmResult<OpenAIResponse> {
196        let url = format!("{}/chat/completions", self.base_url);
197        tracing::Span::current().record("url", url.as_str());
198        let debug_enabled = std::env::var("COGNEE_DEBUG_LLM_REQUEST")
199            .map(|v| cognee_utils::parse_env_bool(&v))
200            .unwrap_or(false);
201
202        if debug_enabled {
203            let pretty_request = serde_json::to_string_pretty(&request_body)
204                .unwrap_or_else(|_| request_body.to_string());
205            eprintln!("\n[COGNEE_DEBUG_LLM_REQUEST] POST {url}\n{pretty_request}\n");
206        }
207
208        let mut last_error = LlmError::NetworkError("No attempt made".to_string());
209
210        for attempt in 0..=self.network_retries {
211            debug!(attempt, "LLM API attempt");
212            if attempt > 0 {
213                let delay_ms = (1_000u64 * 2u64.saturating_pow(attempt as u32 - 1)).min(30_000);
214                warn!(
215                    attempt,
216                    network_retries = self.network_retries,
217                    delay_ms,
218                    error = %last_error,
219                    "LLM request failed, retrying",
220                );
221                tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
222            }
223
224            let response = match self
225                .client
226                .post(&url)
227                .header("Authorization", self.auth_header())
228                .header("Content-Type", "application/json")
229                .json(&request_body)
230                .send()
231                .await
232            {
233                Ok(r) => r,
234                Err(e) => {
235                    last_error = LlmError::NetworkError(e.to_string());
236                    continue;
237                }
238            };
239
240            let status = response.status();
241
242            if !status.is_success() {
243                let error_body = response
244                    .text()
245                    .await
246                    .unwrap_or_else(|_| "Unknown error".to_string());
247
248                let err = match status.as_u16() {
249                    401 => LlmError::AuthenticationError(error_body),
250                    429 => LlmError::RateLimitExceeded(error_body),
251                    400 => LlmError::InvalidResponse(format!("Bad request: {error_body}")),
252                    _ => LlmError::ApiError(format!("HTTP {status}: {error_body}")),
253                };
254
255                // Non-retryable: bad request or authentication failure.
256                if matches!(status.as_u16(), 400 | 401) {
257                    return Err(err);
258                }
259
260                last_error = err;
261                continue;
262            }
263
264            let response_body = response.text().await.map_err(|e| {
265                LlmError::DeserializationError(format!("Failed to read response body: {e}"))
266            })?;
267
268            if debug_enabled {
269                eprintln!("\n[COGNEE_DEBUG_LLM_RESPONSE] POST {url}\n{response_body}\n");
270            }
271
272            return serde_json::from_str::<OpenAIResponse>(&response_body).map_err(|e| {
273                LlmError::DeserializationError(format!(
274                    "Failed to parse response: {e}. Raw body: {response_body}"
275                ))
276            });
277        }
278
279        Err(LlmError::MaxRetriesExceeded(format!(
280            "LLM request failed after {} attempt(s): {}",
281            self.network_retries + 1,
282            last_error
283        )))
284    }
285
286    /// Convert our Message type to OpenAI's format
287    fn convert_messages(messages: &[Message]) -> Vec<Value> {
288        messages
289            .iter()
290            .map(|msg| {
291                json!({
292                    "role": match msg.role {
293                        MessageRole::System => "system",
294                        MessageRole::User => "user",
295                        MessageRole::Assistant => "assistant",
296                    },
297                    "content": msg.content
298                })
299            })
300            .collect()
301    }
302
303    /// Convert JSON Schema to an example JSON with placeholder values
304    /// This is clearer for LLMs than showing the full schema
305    fn schema_to_example(schema: &Value) -> String {
306        fn create_example(value: &Value, definitions: Option<&Value>) -> Value {
307            match value {
308                Value::Object(obj) => {
309                    // Handle $ref references
310                    if let Some(ref_str) = obj.get("$ref").and_then(|v| v.as_str())
311                        && let Some(def_name) = ref_str.strip_prefix("#/definitions/")
312                        && let Some(defs) = definitions
313                        && let Some(def) = defs.get(def_name)
314                    {
315                        return create_example(def, definitions);
316                    }
317
318                    // Get the type of this field
319                    let type_val = obj.get("type");
320
321                    // Handle arrays
322                    if let Some(Value::String(t)) = type_val
323                        && t == "array"
324                    {
325                        if let Some(items) = obj.get("items") {
326                            // Return array with one example item
327                            return json!([create_example(items, definitions)]);
328                        }
329                        return json!([]);
330                    }
331
332                    // Handle objects with properties
333                    if let Some(props) = obj.get("properties")
334                        && let Value::Object(props_obj) = props
335                    {
336                        let mut result = serde_json::Map::new();
337                        for (key, val) in props_obj {
338                            result.insert(key.clone(), create_example(val, definitions));
339                        }
340                        return Value::Object(result);
341                    }
342
343                    // Handle primitive types
344                    if let Some(Value::String(t)) = type_val {
345                        return match t.as_str() {
346                            "string" => json!("example"),
347                            "number" | "integer" => json!(0),
348                            "boolean" => json!(false),
349                            _ => json!(null),
350                        };
351                    }
352
353                    // Handle union types (e.g., ["string", "null"])
354                    if let Some(Value::Array(types)) = type_val {
355                        for t in types {
356                            if let Value::String(type_str) = t
357                                && type_str != "null"
358                            {
359                                return match type_str.as_str() {
360                                    "string" => json!("example"),
361                                    "number" | "integer" => json!(0),
362                                    "boolean" => json!(false),
363                                    _ => json!(null),
364                                };
365                            }
366                        }
367                    }
368
369                    json!(null)
370                }
371                _ => value.clone(),
372            }
373        }
374
375        let definitions = schema.get("definitions");
376        let example = create_example(schema, definitions);
377
378        serde_json::to_string_pretty(&example).unwrap_or_else(|_| "{}".to_string())
379    }
380}
381
382/// Rewrite a `schemars`-generated JSON schema so it satisfies OpenAI's
383/// **strict** structured-output requirements.
384///
385/// OpenAI's `response_format: {type: "json_schema", strict: true}` rejects any
386/// schema where an object lacks `"additionalProperties": false` or whose
387/// `"required"` array does not list *every* declared property. `schemars`
388/// (0.8, draft-07) emits neither guarantee — optional (`Option<T>`) fields are
389/// omitted from `required` and `additionalProperties` is left unset. When the
390/// strict request 400s, [`OpenAIAdapter::create_structured_output_with_messages_raw`]
391/// silently falls back to lenient JSON mode, where the model is free to drop
392/// required fields (e.g. a `Node` without its `type`), causing downstream
393/// deserialization failures.
394///
395/// This walks the schema (including `definitions`/`$defs`, `properties`,
396/// `items`, and the `anyOf`/`allOf`/`oneOf` combinators) and, for every object
397/// that declares `properties`, forces `additionalProperties: false` and sets
398/// `required` to the full set of property keys. The `Value` is cloned and
399/// returned unchanged for non-object schemas.
400fn to_strict_schema(schema: &Value) -> Value {
401    fn walk(value: &mut Value) {
402        match value {
403            Value::Object(obj) => {
404                if let Some(Value::Object(props)) = obj.get("properties") {
405                    // Every declared property must be required under strict mode.
406                    let keys: Vec<Value> = props.keys().map(|k| Value::String(k.clone())).collect();
407                    obj.insert("required".to_string(), Value::Array(keys));
408                    obj.insert("additionalProperties".to_string(), Value::Bool(false));
409                }
410                for (_k, v) in obj.iter_mut() {
411                    walk(v);
412                }
413            }
414            Value::Array(items) => {
415                for v in items.iter_mut() {
416                    walk(v);
417                }
418            }
419            _ => {}
420        }
421    }
422
423    let mut out = schema.clone();
424    walk(&mut out);
425    out
426}
427
428#[async_trait]
429impl Llm for OpenAIAdapter {
430    async fn generate(
431        &self,
432        messages: Vec<Message>,
433        options: Option<GenerationOptions>,
434    ) -> LlmResult<GenerationResponse> {
435        let opts = options.unwrap_or_default();
436
437        let mut request_body = json!({
438            "model": self.model,
439            "messages": Self::convert_messages(&messages),
440        });
441
442        // Add optional parameters. Reasoning models (gpt-5*/o1*/o3*/o4*)
443        // reject sampling overrides and only accept `max_completion_tokens`.
444        if !self.is_reasoning_model() {
445            if let Some(temp) = opts.temperature {
446                request_body["temperature"] = json!(temp);
447            }
448            if let Some(top_p) = opts.top_p {
449                request_body["top_p"] = json!(top_p);
450            }
451            if let Some(freq_penalty) = opts.frequency_penalty {
452                request_body["frequency_penalty"] = json!(freq_penalty);
453            }
454            if let Some(pres_penalty) = opts.presence_penalty {
455                request_body["presence_penalty"] = json!(pres_penalty);
456            }
457        }
458        self.write_max_tokens(&mut request_body, opts.max_tokens);
459        if let Some(stop) = opts.stop
460            && !stop.is_empty()
461        {
462            request_body["stop"] = json!(stop);
463        }
464
465        if self.should_disable_thinking() {
466            request_body["think"] = json!(false);
467            request_body["reasoning"] = json!({"effort": "none"});
468        }
469
470        let response = self.call_api(request_body).await?;
471
472        // Extract the first choice
473        let choice = response
474            .choices
475            .first()
476            .ok_or_else(|| LlmError::InvalidResponse("No choices in response".to_string()))?;
477
478        Ok(GenerationResponse {
479            content: choice.message.content.clone().unwrap_or_default(),
480            model: response.model,
481            finish_reason: choice.finish_reason.clone(),
482            usage: response.usage.map(|u| TokenUsage {
483                prompt_tokens: u.prompt_tokens,
484                completion_tokens: u.completion_tokens,
485                total_tokens: u.total_tokens,
486            }),
487        })
488    }
489
490    async fn create_structured_output_with_messages_raw(
491        &self,
492        messages: Vec<Message>,
493        json_schema: &Value,
494        options: Option<GenerationOptions>,
495    ) -> LlmResult<Value> {
496        let is_empty_or_non_json = |raw: &str| {
497            let trimmed = raw.trim();
498            trimmed.is_empty() || serde_json::from_str::<Value>(trimmed).is_err()
499        };
500
501        let parse_json =
502            |raw: &str| -> Result<Value, serde_json::Error> { serde_json::from_str(raw) };
503
504        let opts = options.unwrap_or_default();
505        let schema = json_schema;
506
507        // OpenAI strict mode requires `additionalProperties: false` and that
508        // every property appear in `required` on every object; the raw
509        // schemars schema satisfies neither, which would 400 and silently
510        // drop us into lenient mode (where required fields can go missing).
511        let strict_schema = to_strict_schema(schema);
512
513        // Try strict JSON schema mode first (new OpenAI API behavior).
514        let mut strict_schema_request = json!({
515            "model": self.model,
516            "messages": Self::convert_messages(&messages),
517            "response_format": {
518                "type": "json_schema",
519                "json_schema": {
520                    "name": "extract_structured_data",
521                    "schema": strict_schema,
522                    "strict": true
523                }
524            }
525        });
526
527        if !self.is_reasoning_model()
528            && let Some(temp) = opts.temperature
529        {
530            strict_schema_request["temperature"] = json!(temp);
531        }
532        self.write_max_tokens(&mut strict_schema_request, opts.max_tokens);
533        if self.should_disable_thinking() {
534            strict_schema_request["think"] = json!(false);
535            strict_schema_request["reasoning"] = json!({"effort": "none"});
536        }
537
538        for attempt in 0..self.structured_output_retries {
539            match self.call_api(strict_schema_request.clone()).await {
540                Ok(strict_response) => {
541                    let strict_choice = strict_response.choices.first().ok_or_else(|| {
542                        LlmError::InvalidResponse(
543                            "No choices in strict schema response".to_string(),
544                        )
545                    })?;
546
547                    if let Some(function_call) = &strict_choice.message.function_call {
548                        match parse_json(&function_call.arguments) {
549                            Ok(parsed) => return Ok(parsed),
550                            Err(e) => {
551                                if attempt + 1 < self.structured_output_retries
552                                    && is_empty_or_non_json(&function_call.arguments)
553                                {
554                                    continue;
555                                }
556                                if !is_empty_or_non_json(&function_call.arguments) {
557                                    return Err(LlmError::DeserializationError(format!(
558                                        "Failed to deserialize strict function call arguments: {}. Raw: {}",
559                                        e, function_call.arguments
560                                    )));
561                                }
562                                break;
563                            }
564                        }
565                    }
566
567                    if let Some(content) = strict_choice.message.content.as_ref() {
568                        match parse_json(content) {
569                            Ok(parsed) => return Ok(parsed),
570                            Err(e) => {
571                                if attempt + 1 < self.structured_output_retries
572                                    && is_empty_or_non_json(content)
573                                {
574                                    continue;
575                                }
576                                if !is_empty_or_non_json(content) {
577                                    return Err(LlmError::DeserializationError(format!(
578                                        "Failed to deserialize strict JSON content: {e}. Raw: {content}"
579                                    )));
580                                }
581                                break;
582                            }
583                        }
584                    }
585                }
586                Err(e) => {
587                    // Strict json_schema mode is unsupported by this
588                    // model/endpoint (or the schema was rejected). Fall back to
589                    // function calling / JSON mode below, but make the reason
590                    // visible — a silent fallback is how required fields end up
591                    // missing from the model's output.
592                    warn!(error = %e, "strict json_schema request failed; falling back to function/JSON mode");
593                    break;
594                }
595            }
596        }
597
598        // Try function calling first (works with OpenAI)
599        let mut request_body = json!({
600            "model": self.model,
601            "messages": Self::convert_messages(&messages),
602            "functions": [{
603                "name": "extract_structured_data",
604                "description": "Extract structured data from the input",
605                "parameters": schema
606            }],
607            "function_call": {"name": "extract_structured_data"}
608        });
609
610        if !self.is_reasoning_model()
611            && let Some(temp) = opts.temperature
612        {
613            request_body["temperature"] = json!(temp);
614        }
615        self.write_max_tokens(&mut request_body, opts.max_tokens);
616        if self.should_disable_thinking() {
617            request_body["think"] = json!(false);
618            request_body["reasoning"] = json!({"effort": "none"});
619        }
620
621        for attempt in 0..self.structured_output_retries {
622            let response = self.call_api(request_body.clone()).await?;
623
624            let choice = response
625                .choices
626                .first()
627                .ok_or_else(|| LlmError::InvalidResponse("No choices in response".to_string()))?;
628
629            if let Some(function_call) = &choice.message.function_call {
630                match parse_json(&function_call.arguments) {
631                    Ok(parsed) => return Ok(parsed),
632                    Err(e) => {
633                        if attempt + 1 < self.structured_output_retries
634                            && is_empty_or_non_json(&function_call.arguments)
635                        {
636                            continue;
637                        }
638                        if !is_empty_or_non_json(&function_call.arguments) {
639                            return Err(LlmError::DeserializationError(format!(
640                                "Failed to deserialize function call arguments: {}. Raw: {}",
641                                e, function_call.arguments
642                            )));
643                        }
644                        break;
645                    }
646                }
647            }
648
649            break;
650        }
651
652        // Fallback to JSON mode (works with Ollama and other providers)
653        let mut json_messages = Self::convert_messages(&messages);
654
655        let example = Self::schema_to_example(schema);
656
657        if let Some(last_msg) = json_messages.last_mut()
658            && last_msg["role"] == "user"
659        {
660            let original_content = last_msg["content"].as_str().unwrap_or("");
661            last_msg["content"] = json!(format!(
662                "{}\n\n\
663                    Extract the information from the text above and return it as JSON.\n\
664                    Use this structure as your template (but with actual data from the text):\n\
665                    {}",
666                original_content, example
667            ));
668        }
669
670        let mut json_request = json!({
671            "model": self.model,
672            "messages": json_messages,
673            "response_format": {"type": "json_object"}
674        });
675
676        if !self.is_reasoning_model()
677            && let Some(temp) = opts.temperature
678        {
679            json_request["temperature"] = json!(temp);
680        }
681        self.write_max_tokens(&mut json_request, opts.max_tokens);
682        if self.should_disable_thinking() {
683            json_request["think"] = json!(false);
684            json_request["reasoning"] = json!({"effort": "none"});
685        }
686
687        for attempt in 0..self.structured_output_retries {
688            let mut request_for_attempt = json_request.clone();
689
690            if attempt > 0 {
691                if let Some(messages) = request_for_attempt["messages"].as_array_mut()
692                    && let Some(last_msg) = messages.last_mut()
693                    && last_msg["role"] == "user"
694                {
695                    let original_content = last_msg["content"].as_str().unwrap_or("");
696                    last_msg["content"] = json!(format!(
697                        "{}\n\n/no_think\nReturn ONLY one valid JSON object matching the required schema. No reasoning, no markdown, no extra text.",
698                        original_content
699                    ));
700                }
701
702                if !self.is_reasoning_model() {
703                    request_for_attempt["temperature"] = json!(0.0);
704                }
705            }
706
707            let json_response = self.call_api(request_for_attempt).await?;
708
709            let json_choice = json_response.choices.first().ok_or_else(|| {
710                LlmError::InvalidResponse("No choices in JSON mode response".to_string())
711            })?;
712
713            let content = json_choice.message.content.as_ref().ok_or_else(|| {
714                LlmError::InvalidResponse("No content in JSON mode response".to_string())
715            })?;
716
717            match parse_json(content) {
718                Ok(parsed) => return Ok(parsed),
719                Err(e) => {
720                    if attempt + 1 < self.structured_output_retries && is_empty_or_non_json(content)
721                    {
722                        continue;
723                    }
724                    return Err(LlmError::DeserializationError(format!(
725                        "Failed to deserialize JSON content: {e}. Raw: {content}"
726                    )));
727                }
728            }
729        }
730
731        Err(LlmError::InvalidResponse(
732            "Structured output retries exhausted without a parseable response".to_string(),
733        ))
734    }
735
736    fn model(&self) -> &str {
737        &self.model
738    }
739
740    fn supports_streaming(&self) -> bool {
741        true
742    }
743
744    fn supports_function_calling(&self) -> bool {
745        true
746    }
747
748    fn max_context_length(&self) -> u32 {
749        // Context lengths for common OpenAI models
750        match self.model.as_str() {
751            m if m.starts_with("gpt-4-turbo") => 128_000,
752            m if m.starts_with("gpt-4-32k") => 32_768,
753            m if m.starts_with("gpt-4") => 8_192,
754            m if m.starts_with("gpt-3.5-turbo-16k") => 16_384,
755            m if m.starts_with("gpt-3.5-turbo") => 4_096,
756            _ => 4_096, // Conservative default
757        }
758    }
759
760    async fn transcribe_image(
761        &self,
762        image_bytes: &[u8],
763        mime_type: &str,
764        options: Option<GenerationOptions>,
765    ) -> LlmResult<String> {
766        use base64::Engine as _;
767
768        if !mime_type.starts_with("image/") {
769            return Err(LlmError::InvalidResponse(format!(
770                "Expected image/* MIME type, got: {mime_type}"
771            )));
772        }
773
774        let b64 = base64::engine::general_purpose::STANDARD.encode(image_bytes);
775        let data_uri = format!("data:{mime_type};base64,{b64}");
776
777        let vision_model = std::env::var("LLM_VISION_MODEL")
778            .ok()
779            .filter(|s| !s.is_empty())
780            .unwrap_or_else(|| self.model.clone());
781
782        let max_tokens = options.as_ref().and_then(|o| o.max_tokens).unwrap_or(300);
783
784        let mut request_body = json!({
785            "model": vision_model,
786            "messages": [{
787                "role": "user",
788                "content": [
789                    { "type": "text", "text": "What's in this image?" },
790                    { "type": "image_url", "image_url": { "url": data_uri } }
791                ]
792            }],
793        });
794        self.write_max_tokens(&mut request_body, Some(max_tokens));
795
796        let response = self.call_api(request_body).await?;
797
798        let choice = response.choices.first().ok_or_else(|| {
799            LlmError::InvalidResponse("No choices in vision response".to_string())
800        })?;
801
802        choice.message.content.clone().ok_or_else(|| {
803            LlmError::InvalidResponse("Vision response contained no content".to_string())
804        })
805    }
806
807    fn supports_vision(&self) -> bool {
808        let m = self.model.to_lowercase();
809        m.contains("gpt-4")
810            || m.contains("gpt-5")
811            || m.contains("vision")
812            || m.contains("o1")
813            || m.contains("o3")
814            || m.contains("o4")
815            || m.contains("llava")
816            || m.contains("moondream")
817            || m.contains("llama-3.2-vision")
818            || m.contains("gemma3")
819    }
820}
821
822// ---------------------------------------------------------------------------
823// Whisper transcription support
824// ---------------------------------------------------------------------------
825
/// Parsed JSON body of an OpenAI `/audio/transcriptions` response.
///
/// Only the fields the adapter surfaces are declared. Any other keys in the
/// body are ignored, because the struct does not set `deny_unknown_fields`.
#[derive(Debug, Deserialize)]
struct WhisperResponse {
    /// The transcribed text (required; deserialization fails without it).
    text: String,
    /// Detected or hinted language, if the response includes one; `None` when
    /// the key is absent or `null`.
    language: Option<String>,
    /// Audio duration, if the response includes one — presumably in seconds;
    /// TODO confirm against the API reference.
    duration: Option<f32>,
}
833
/// Map a validated, lower-case audio format extension to its MIME type.
///
/// Covers every container the OpenAI transcription endpoint documents
/// (flac, m4a, mp3, mp4, mpeg, mpga, oga, ogg, wav, webm). If
/// `validate_audio_format` accepts one of these, it is uploaded with a proper
/// audio MIME type rather than a generic one. Anything else falls back to
/// `application/octet-stream`.
fn audio_mime_type(format: &str) -> &'static str {
    match format {
        "mp3" | "mpeg" | "mpga" => "audio/mpeg",
        "mp4" | "m4a" => "audio/mp4",
        "wav" => "audio/wav",
        "webm" => "audio/webm",
        "flac" => "audio/flac",
        "ogg" | "oga" => "audio/ogg",
        _ => "application/octet-stream",
    }
}
844}
845
846impl OpenAIAdapter {
847    /// Call the Whisper transcription API with the same retry logic as `call_api`.
848    #[instrument(
849        name = "llm.transcription_api_call",
850        level = "info",
851        skip(self, form),
852        fields(
853            url = tracing::field::Empty,
854            cognee.llm.model = self.transcription_model.as_str(),
855            cognee.llm.provider = "openai",
856        ),
857    )]
858    async fn call_transcription_api(
859        &self,
860        form: reqwest::multipart::Form,
861    ) -> LlmResult<WhisperResponse> {
862        let url = format!("{}/audio/transcriptions", self.base_url);
863        tracing::Span::current().record("url", url.as_str());
864
865        // We cannot clone a multipart Form, so the first attempt uses the
866        // original form and retries are not possible for the multipart body.
867        // However, we keep the retry loop for network errors that occur
868        // *before* the body is consumed (connection refused, DNS failure).
869        // For simplicity and matching the guide's design, we rebuild the form
870        // if needed by storing the bytes. But since `Form` doesn't support
871        // Clone, we perform a single attempt with the form and rely on the
872        // caller to retry externally if needed.
873        //
874        // Actually, the simplest approach is to send the form once and
875        // handle retries at a higher level. But the guide says to mirror
876        // call_api's retry. Since reqwest::multipart::Form is not Clone,
877        // we accept `form` by value and do a single-shot request here,
878        // while the `transcribe_audio` impl handles retry by rebuilding
879        // the form on each attempt.
880
881        let response = self
882            .client
883            .post(&url)
884            .header("Authorization", self.auth_header())
885            .multipart(form)
886            .send()
887            .await
888            .map_err(|e| LlmError::NetworkError(e.to_string()))?;
889
890        let status = response.status();
891
892        if !status.is_success() {
893            let error_body = response
894                .text()
895                .await
896                .unwrap_or_else(|_| "Unknown error".to_string());
897
898            return Err(match status.as_u16() {
899                401 => LlmError::AuthenticationError(error_body),
900                429 => LlmError::RateLimitExceeded(error_body),
901                400 => LlmError::InvalidResponse(format!("Bad request: {error_body}")),
902                _ => LlmError::ApiError(format!("HTTP {status}: {error_body}")),
903            });
904        }
905
906        let response_body = response.text().await.map_err(|e| {
907            LlmError::DeserializationError(format!("Failed to read response body: {e}"))
908        })?;
909
910        serde_json::from_str::<WhisperResponse>(&response_body).map_err(|e| {
911            LlmError::DeserializationError(format!(
912                "Failed to parse Whisper response: {e}. Raw body: {response_body}"
913            ))
914        })
915    }
916
917    /// Build a `reqwest::multipart::Form` for a Whisper transcription request.
918    fn build_transcription_form(
919        &self,
920        audio: &[u8],
921        format: &str,
922        language_hint: Option<&str>,
923        prompt_hint: Option<&str>,
924    ) -> LlmResult<reqwest::multipart::Form> {
925        let mime = audio_mime_type(format);
926        let filename = format!("audio.{format}");
927
928        let file_part = reqwest::multipart::Part::bytes(audio.to_vec())
929            .file_name(filename)
930            .mime_str(mime)
931            .map_err(|e| {
932                LlmError::ConfigError(format!("Failed to set MIME type on multipart part: {e}"))
933            })?;
934
935        let mut form = reqwest::multipart::Form::new()
936            .part("file", file_part)
937            .text("model", self.transcription_model.clone())
938            .text("response_format", "verbose_json");
939
940        if let Some(lang) = language_hint {
941            form = form.text("language", lang.to_string());
942        }
943        if let Some(prompt) = prompt_hint {
944            form = form.text("prompt", prompt.to_string());
945        }
946
947        Ok(form)
948    }
949}
950
951#[async_trait]
952impl Transcriber for OpenAIAdapter {
953    async fn transcribe_audio(
954        &self,
955        audio: &[u8],
956        format: &str,
957        language_hint: Option<&str>,
958        prompt_hint: Option<&str>,
959    ) -> LlmResult<TranscriptionOutput> {
960        // Normalize and validate before any network I/O.
961        let format_lower = format.to_ascii_lowercase();
962        validate_audio_format(&format_lower)?;
963
964        let mut last_error = LlmError::NetworkError("No attempt made".to_string());
965
966        for attempt in 0..=self.network_retries {
967            debug!(attempt, "Transcription API attempt");
968            if attempt > 0 {
969                let delay_ms = (1_000u64 * 2u64.saturating_pow(attempt as u32 - 1)).min(30_000);
970                warn!(
971                    attempt,
972                    network_retries = self.network_retries,
973                    delay_ms,
974                    error = %last_error,
975                    "Transcription request failed, retrying",
976                );
977                tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
978            }
979
980            let form =
981                self.build_transcription_form(audio, &format_lower, language_hint, prompt_hint)?;
982
983            match self.call_transcription_api(form).await {
984                Ok(resp) => {
985                    return Ok(TranscriptionOutput {
986                        text: resp.text,
987                        language: resp.language,
988                        duration: resp.duration,
989                    });
990                }
991                Err(e) => {
992                    // Non-retryable errors: bad request or authentication failure.
993                    if matches!(
994                        e,
995                        LlmError::InvalidResponse(_) | LlmError::AuthenticationError(_)
996                    ) {
997                        return Err(e);
998                    }
999                    last_error = e;
1000                    continue;
1001                }
1002            }
1003        }
1004
1005        Err(LlmError::MaxRetriesExceeded(format!(
1006            "Transcription request failed after {} attempt(s): {}",
1007            self.network_retries + 1,
1008            last_error
1009        )))
1010    }
1011
1012    fn transcription_model(&self) -> &str {
1013        &self.transcription_model
1014    }
1015}
1016
1017// OpenAI API response types
1018#[derive(Debug, Deserialize)]
1019#[allow(dead_code)]
1020struct OpenAIResponse {
1021    id: String,
1022    object: String,
1023    created: i64,
1024    model: String,
1025    choices: Vec<OpenAIChoice>,
1026    usage: Option<OpenAIUsage>,
1027}
1028
1029#[derive(Debug, Deserialize)]
1030#[allow(dead_code)]
1031struct OpenAIChoice {
1032    index: u32,
1033    message: OpenAIMessage,
1034    finish_reason: Option<String>,
1035}
1036
/// The message carried by an `OpenAIChoice`.
///
/// `dead_code` is allowed because not every field is read by the adapter.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct OpenAIMessage {
    /// Author role as reported by the API (required).
    role: String,
    /// Text content. `None` when the key is absent or `null`; the callers
    /// visible in this file treat that as an error.
    content: Option<String>,
    /// Separate reasoning text, if the backend returns a `reasoning` field.
    /// NOTE(review): presumably only some OpenAI-compatible servers send
    /// this — confirm.
    reasoning: Option<String>,
    /// Function-call payload, if the model invoked a function.
    function_call: Option<OpenAIFunctionCall>,
}
1045
/// A function call requested by the model in an `OpenAIMessage`.
///
/// `dead_code` is allowed because not every field is read by the adapter.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct OpenAIFunctionCall {
    /// Name of the function the model chose to call.
    name: String,
    /// Arguments exactly as the API returned them, as an unparsed string.
    /// Presumably JSON-encoded; parse before use.
    arguments: String,
}
1052
/// Token accounting attached to a chat-completions response.
///
/// All three counters are required when a `usage` object is present.
/// NOTE(review): presumably converted into the crate's `TokenUsage` elsewhere
/// in this file — confirm.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    /// Tokens consumed by the request messages.
    prompt_tokens: u32,
    /// Tokens generated in the reply.
    completion_tokens: u32,
    /// Total tokens as reported by the API.
    total_tokens: u32,
}
1059
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use super::*;

    #[test]
    fn test_openai_provider_prefix_is_stripped() {
        // litellm-style "openai/<model>" must be sent as bare "<model>".
        let adapter = OpenAIAdapter::new("openai/gpt-5-mini", "test-key", None).unwrap();
        assert_eq!(adapter.model(), "gpt-5-mini");
        // Non-openai provider prefixes (custom endpoints) are left intact.
        let adapter = OpenAIAdapter::new("ollama/llama3", "test-key", None).unwrap();
        assert_eq!(adapter.model(), "ollama/llama3");
    }

    #[test]
    fn test_openai_adapter_creation() {
        let adapter = OpenAIAdapter::new("gpt-4", "test-key", None);
        assert!(adapter.is_ok());

        let adapter = adapter.unwrap();
        assert_eq!(adapter.model(), "gpt-4");
        assert_eq!(adapter.base_url, OpenAIAdapter::DEFAULT_BASE_URL);
        assert_eq!(
            adapter.structured_output_retries,
            OpenAIAdapter::DEFAULT_STRUCTURED_OUTPUT_RETRIES
        );
    }

    #[test]
    fn test_configurable_structured_output_retries() {
        let adapter = OpenAIAdapter::new("gpt-4", "test-key", None)
            .unwrap()
            .with_structured_output_retries(5);
        assert_eq!(adapter.structured_output_retries, 5);

        let adapter = OpenAIAdapter::new("gpt-4", "test-key", None)
            .unwrap()
            .with_structured_output_retries(0);
        assert_eq!(adapter.structured_output_retries, 1);
    }

    #[test]
    fn test_openai_adapter_custom_base_url() {
        let adapter = OpenAIAdapter::new(
            "gpt-4",
            "test-key",
            Some("https://custom.api.com/v1".to_string()),
        );
        assert!(adapter.is_ok());

        let adapter = adapter.unwrap();
        assert_eq!(adapter.base_url, "https://custom.api.com/v1");
    }

    #[test]
    fn test_is_reasoning_model_matches_openai_families() {
        let cases = [
            ("gpt-5", true),
            ("gpt-5-mini", true),
            ("gpt-5-2025-06-01", true),
            ("o1", true),
            ("o1-mini", true),
            ("o3", true),
            ("o3-mini", true),
            ("o4-mini", true),
            ("GPT-5-Mini", true),
            ("gpt-4o-mini", false),
            ("gpt-4-turbo", false),
            ("gpt-3.5-turbo", false),
            ("o-foo", false),
        ];
        for (model, expected) in cases {
            let adapter = OpenAIAdapter::new(model, "test-key", None).unwrap();
            assert_eq!(
                adapter.is_reasoning_model(),
                expected,
                "is_reasoning_model({model})"
            );
        }
    }

    #[test]
    fn test_is_reasoning_model_skipped_for_custom_base_url() {
        // Custom OpenAI-compatible endpoints (Ollama, vLLM, …) may have
        // model names that look like reasoning families but still accept
        // legacy sampling parameters — the gate is conservative.
        let adapter = OpenAIAdapter::new(
            "gpt-5-mini",
            "test-key",
            Some("http://localhost:11434/v1".to_string()),
        )
        .unwrap();
        assert!(!adapter.is_reasoning_model());
    }

    #[test]
    fn test_write_max_tokens_renames_key_for_reasoning_models() {
        let mut body = json!({"model": "gpt-5-mini"});
        let reasoning = OpenAIAdapter::new("gpt-5-mini", "test-key", None).unwrap();
        reasoning.write_max_tokens(&mut body, Some(2048));
        assert!(body.get("max_tokens").is_none());
        assert_eq!(body["max_completion_tokens"], 2048);

        let mut body = json!({"model": "gpt-4o-mini"});
        let classic = OpenAIAdapter::new("gpt-4o-mini", "test-key", None).unwrap();
        classic.write_max_tokens(&mut body, Some(2048));
        assert_eq!(body["max_tokens"], 2048);
        assert!(body.get("max_completion_tokens").is_none());

        // None leaves body untouched.
        let mut body = json!({"model": "gpt-5-mini"});
        reasoning.write_max_tokens(&mut body, None);
        assert!(body.get("max_tokens").is_none());
        assert!(body.get("max_completion_tokens").is_none());
    }

    #[test]
    fn test_message_conversion() {
        let messages = vec![
            Message {
                role: MessageRole::System,
                content: "You are helpful".to_string(),
            },
            Message {
                role: MessageRole::User,
                content: "Hello".to_string(),
            },
        ];

        let converted = OpenAIAdapter::convert_messages(&messages);
        assert_eq!(converted.len(), 2);
        assert_eq!(converted[0]["role"], "system");
        assert_eq!(converted[0]["content"], "You are helpful");
        assert_eq!(converted[1]["role"], "user");
        assert_eq!(converted[1]["content"], "Hello");
    }

    #[test]
    fn test_context_length() {
        let adapter = OpenAIAdapter::new("gpt-4-turbo-preview", "key", None).unwrap();
        assert_eq!(adapter.max_context_length(), 128_000);

        let adapter = OpenAIAdapter::new("gpt-4", "key", None).unwrap();
        assert_eq!(adapter.max_context_length(), 8_192);

        let adapter = OpenAIAdapter::new("gpt-3.5-turbo-16k", "key", None).unwrap();
        assert_eq!(adapter.max_context_length(), 16_384);
    }

    #[test]
    fn test_supports_vision_gpt4o() {
        let adapter = OpenAIAdapter::new("gpt-4o", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gpt4_turbo() {
        let adapter = OpenAIAdapter::new("gpt-4-turbo", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gpt4o_mini() {
        let adapter = OpenAIAdapter::new("gpt-4o-mini", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gpt35_is_false() {
        let adapter = OpenAIAdapter::new("gpt-3.5-turbo", "key", None).unwrap();
        assert!(!adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_llava() {
        let adapter = OpenAIAdapter::new("llava:13b", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_o1() {
        let adapter = OpenAIAdapter::new("o1-preview", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[test]
    fn test_supports_vision_gemma3() {
        let adapter = OpenAIAdapter::new("gemma3:12b", "key", None).unwrap();
        assert!(adapter.supports_vision());
    }

    #[tokio::test]
    async fn transcribe_image_rejects_non_image_mime() {
        let adapter = OpenAIAdapter::new("gpt-4o", "fake-key", None).unwrap();
        let result = adapter
            .transcribe_image(b"not-an-image", "text/plain", None)
            .await;
        assert!(result.is_err());
        assert!(
            matches!(result.unwrap_err(), LlmError::InvalidResponse(_)),
            "Expected InvalidResponse for non-image MIME type"
        );
    }

    #[test]
    fn test_transcription_model_default() {
        // Cargo runs tests on parallel threads, and the other tests here
        // construct adapters (which appear to consult TRANSCRIPTION_MODEL).
        // Calling `remove_var` would therefore race with those reads, which is
        // undefined behaviour on some platforms. Instead of mutating the
        // environment, only assert the built-in default when the variable is
        // not set for this process.
        if std::env::var_os("TRANSCRIPTION_MODEL").is_some() {
            return;
        }
        let adapter = OpenAIAdapter::new("gpt-4", "key", None).unwrap();
        assert_eq!(adapter.transcription_model(), "whisper-1");
    }

    #[test]
    fn test_transcription_model_custom() {
        let adapter = OpenAIAdapter::new("gpt-4", "key", None)
            .unwrap()
            .with_transcription_model("whisper-large-v3");
        assert_eq!(adapter.transcription_model(), "whisper-large-v3");
    }

    #[test]
    fn test_audio_mime_type_mapping() {
        assert_eq!(audio_mime_type("mp3"), "audio/mpeg");
        assert_eq!(audio_mime_type("mpeg"), "audio/mpeg");
        assert_eq!(audio_mime_type("mpga"), "audio/mpeg");
        assert_eq!(audio_mime_type("mp4"), "audio/mp4");
        assert_eq!(audio_mime_type("m4a"), "audio/mp4");
        assert_eq!(audio_mime_type("wav"), "audio/wav");
        assert_eq!(audio_mime_type("webm"), "audio/webm");
    }

    #[test]
    fn test_to_strict_schema_marks_all_required_and_closes_objects() {
        // Mirrors the schemars-0.8 shape: an optional field omitted from
        // `required`, nested object behind `definitions`/`$ref`, and no
        // `additionalProperties` set anywhere.
        let schema = json!({
            "type": "object",
            "properties": {
                "nodes": { "type": "array", "items": { "$ref": "#/definitions/Node" } }
            },
            "required": ["nodes"],
            "definitions": {
                "Node": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" },
                        "type": { "type": "string" },
                        "description": { "type": ["string", "null"] }
                    },
                    "required": ["name", "type"]
                }
            }
        });

        let strict = to_strict_schema(&schema);

        // Root object closed + all props required.
        assert_eq!(strict["additionalProperties"], json!(false));
        assert_eq!(strict["required"], json!(["nodes"]));

        // Nested object inside definitions: every property now required
        // (including the previously-optional `description`) and closed.
        let node = &strict["definitions"]["Node"];
        assert_eq!(node["additionalProperties"], json!(false));
        let mut req: Vec<String> = node["required"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap().to_string())
            .collect();
        req.sort();
        assert_eq!(req, vec!["description", "name", "type"]);
    }
}