Skip to main content

nenjo_models/
openai.rs

1//! OpenAI provider. Authenticates via Bearer token.
2
3use crate::ToolSpec;
4use crate::native::{
5    MediaOutputAsset, MediaOutputFormat, ModelNativeCapabilities, NativeCapabilitiesProvider,
6    NativeExecutionMode, NativeMediaRequest, NativeMediaResponse, NativeOperation,
7    NativeToolSpec as NativeMediaToolSpec, ProviderNativeCapabilities,
8};
9use crate::traits::{ChatMessage, ChatRequest, ChatResponse, ModelProvider, TokenUsage, ToolCall};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14
/// Provider for the OpenAI HTTP API that authenticates with a Bearer token.
pub struct OpenAiProvider {
    /// API key placed in the `Authorization: Bearer …` header. When `None`, `chat` and
    /// image generation return an "API key not set" error before sending anything.
    api_key: Option<String>,
    /// HTTP client reused for every request issued by this provider.
    client: Client,
}
19
/// JSON body posted to the chat completions endpoint.
///
/// Every `Option` field is omitted from the serialized payload when it is `None`.
#[derive(Debug, Serialize)]
struct NativeChatRequest {
    /// Model identifier, copied from the caller's `model` argument.
    model: String,
    /// Conversation history converted to the wire format.
    messages: Vec<NativeMessage>,
    /// Sampling temperature; left as `None` for reasoning models so the API default applies.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Upper bound on the number of completion tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    /// Function tools offered to the model, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec>>,
    /// Tool selection mode; set to `"auto"` whenever `tools` is present.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}
33
/// One chat message in the wire format of the chat completions request.
///
/// Optional fields are omitted from the JSON when `None`.
#[derive(Debug, Serialize)]
struct NativeMessage {
    /// Message role, e.g. `"assistant"` or `"tool"`; other roles are passed through unchanged.
    role: String,
    /// Text content of the message.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    /// Identifier of the tool call a `"tool"` message responds to.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    /// Tool calls requested by an `"assistant"` message.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
}
44
/// Tool declaration entry of the chat completions request.
#[derive(Debug, Serialize)]
struct NativeToolSpec {
    /// Tool kind, serialized under the JSON key `type`; `convert_tools` always sets
    /// it to `"function"`.
    #[serde(rename = "type")]
    kind: String,
    /// Name, description and parameter schema of the function.
    function: NativeToolFunctionSpec,
}
51
/// Function description nested inside a [`NativeToolSpec`].
#[derive(Debug, Serialize)]
struct NativeToolFunctionSpec {
    /// Function name; `convert_tools` fills it with the sanitized tool name.
    name: String,
    /// Human-readable description, copied from the tool spec.
    description: String,
    /// Parameter schema, copied from the tool spec's `parameters` value.
    parameters: serde_json::Value,
}
58
/// A tool call in wire format.
///
/// Serialized when replaying assistant messages and deserialized from API responses.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    /// Call identifier; when a response omits it, `parse_native_response` generates a UUID.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    /// Call kind, stored under the JSON key `type`; set to `"function"` on outgoing calls.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    /// The function name and its arguments.
    function: NativeFunctionCall,
}
67
/// Function invocation carried by a [`NativeToolCall`].
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    /// Name of the function being called.
    name: String,
    /// Arguments as a string; this file passes the value through without parsing it.
    arguments: String,
}
73
/// Token accounting section of a chat completions response.
///
/// Missing counters deserialize to `0`.
#[derive(Debug, Deserialize)]
struct NativeUsage {
    /// Prompt token count; reported by `chat` as `input_tokens`.
    #[serde(default)]
    prompt_tokens: u64,
    /// Completion token count; reported by `chat` as `output_tokens`.
    #[serde(default)]
    completion_tokens: u64,
}
81
/// Top-level body of a chat completions response.
#[derive(Debug, Deserialize)]
struct NativeChatResponse {
    /// Returned choices; `chat` only uses the first one.
    choices: Vec<NativeChoice>,
    /// Token usage, when the response includes it.
    #[serde(default)]
    usage: Option<NativeUsage>,
}
88
/// A single entry of the response's `choices` array.
#[derive(Debug, Deserialize)]
struct NativeChoice {
    /// The assistant message produced for this choice.
    message: NativeResponseMessage,
}
93
/// Assistant message returned inside a [`NativeChoice`].
#[derive(Debug, Deserialize)]
struct NativeResponseMessage {
    /// Text content; becomes the `text` of the resulting `ChatResponse`.
    #[serde(default)]
    content: Option<String>,
    /// Tool calls requested by the model; absent is treated as an empty list.
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
101
/// JSON body posted to the image generation endpoint.
///
/// Borrows its strings from the originating request; `None` fields are not serialized.
#[derive(Debug, Serialize)]
struct ImageGenerationRequest<'a> {
    /// Image model identifier.
    model: &'a str,
    /// Text prompt describing the image.
    prompt: &'a str,
    /// Number of images requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    /// Requested image size.
    #[serde(skip_serializing_if = "Option::is_none")]
    size: Option<&'a str>,
    /// Transport format of the result; `generate_image` only ever sets `"b64_json"` here.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<&'static str>,
    /// `background` value taken from the request's provider options.
    #[serde(skip_serializing_if = "Option::is_none")]
    background: Option<&'a str>,
    /// `output_format` value taken from the provider options; also used to pick the
    /// MIME type attached to returned assets.
    #[serde(skip_serializing_if = "Option::is_none")]
    output_format: Option<&'a str>,
    /// `quality` value taken from the request's provider options.
    #[serde(skip_serializing_if = "Option::is_none")]
    quality: Option<&'a str>,
}
119
/// Top-level body of an image generation response.
#[derive(Debug, Deserialize)]
struct ImageGenerationResponse {
    /// One entry per generated image.
    data: Vec<ImageGenerationData>,
}
124
/// A single generated image as returned by the image generation endpoint.
///
/// All fields are optional on the wire and default to `None`.
#[derive(Debug, Deserialize)]
struct ImageGenerationData {
    /// Download URL of the image; preferred over `b64_json` when both are present.
    #[serde(default)]
    url: Option<String>,
    /// Base64-encoded image payload.
    #[serde(default)]
    b64_json: Option<String>,
    /// Prompt as rewritten by the provider; collected into the response metadata.
    #[serde(default)]
    revised_prompt: Option<String>,
}
134
135fn provider_option_str<'a>(options: &'a Value, key: &str) -> Option<&'a str> {
136    options.get(key).and_then(Value::as_str)
137}
138
139fn openai_generate_image_tool_spec() -> NativeMediaToolSpec {
140    let capability = NativeOperation::GenerateImage;
141    NativeMediaToolSpec {
142        capability,
143        tool_name: capability.tool_name().unwrap().to_string(),
144        description: "Generate an image with the configured OpenAI image model.".to_string(),
145        execution: NativeExecutionMode::Immediate,
146        parameters_schema: json!({
147            "type": "object",
148            "properties": {
149                "prompt": {"type": "string"},
150                "n": {"type": "integer", "minimum": 1},
151                "size": {
152                    "type": "string",
153                    "enum": ["1024x1024", "1024x1536", "1536x1024", "auto"]
154                },
155                "output_format": {"type": "string", "enum": ["url", "base64"]},
156                "provider_options": {
157                    "type": "object",
158                    "properties": {
159                        "background": {
160                            "type": "string",
161                            "enum": ["transparent", "opaque", "auto"]
162                        },
163                        "output_format": {
164                            "type": "string",
165                            "enum": ["png", "webp", "jpeg"]
166                        },
167                        "quality": {
168                            "type": "string",
169                            "enum": ["low", "medium", "high", "auto"]
170                        }
171                    },
172                    "additionalProperties": false
173                }
174            },
175            "required": ["prompt"]
176        }),
177    }
178}
179
/// Maps an image output format name to its MIME type.
///
/// Only the exact names `"jpeg"` and `"webp"` are recognised; anything else, including
/// `None`, yields `image/png`.
fn image_mime_type(output_format: Option<&str>) -> String {
    let subtype = match output_format {
        Some(known @ ("jpeg" | "webp")) => known,
        _ => "png",
    };
    format!("image/{subtype}")
}
188
189impl OpenAiProvider {
190    pub fn new(api_key: Option<&str>) -> Self {
191        Self {
192            api_key: api_key.map(ToString::to_string),
193            client: Client::builder()
194                .timeout(std::time::Duration::from_secs(120))
195                .connect_timeout(std::time::Duration::from_secs(10))
196                .build()
197                .unwrap_or_else(|_| Client::new()),
198        }
199    }
200
201    fn is_reasoning_model(model: &str) -> bool {
202        let m = model.to_lowercase();
203        m.starts_with("o1") || m.starts_with("o3") || m.starts_with("o4")
204    }
205
206    fn is_developer_role_model(model: &str) -> bool {
207        let m = model.to_lowercase();
208        Self::is_reasoning_model(&m)
209            || m.starts_with("gpt-5")
210            || m.starts_with("gpt-4.5")
211            || m.starts_with("gpt-4.1")
212    }
213
214    fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
215        tools.map(|items| {
216            items
217                .iter()
218                .map(|tool| NativeToolSpec {
219                    kind: "function".to_string(),
220                    function: NativeToolFunctionSpec {
221                        name: crate::sanitize_tool_name(&tool.name),
222                        description: tool.description.clone(),
223                        parameters: tool.parameters.clone(),
224                    },
225                })
226                .collect()
227        })
228    }
229
230    fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
231        messages
232            .iter()
233            .map(|m| {
234                if m.role == "assistant"
235                    && let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content)
236                    && let Some(tool_calls_value) = value.get("tool_calls")
237                    && let Ok(parsed_calls) =
238                        serde_json::from_value::<Vec<ToolCall>>(tool_calls_value.clone())
239                {
240                    let tool_calls = parsed_calls
241                        .into_iter()
242                        .map(|tc| NativeToolCall {
243                            id: Some(tc.id),
244                            kind: Some("function".to_string()),
245                            function: NativeFunctionCall {
246                                name: tc.name,
247                                arguments: tc.arguments,
248                            },
249                        })
250                        .collect::<Vec<_>>();
251                    let content = value
252                        .get("content")
253                        .and_then(serde_json::Value::as_str)
254                        .map(ToString::to_string);
255                    return NativeMessage {
256                        role: "assistant".to_string(),
257                        content,
258                        tool_call_id: None,
259                        tool_calls: Some(tool_calls),
260                    };
261                }
262
263                if m.role == "tool"
264                    && let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content)
265                {
266                    let tool_call_id = value
267                        .get("tool_call_id")
268                        .and_then(serde_json::Value::as_str)
269                        .map(ToString::to_string);
270                    let content = value
271                        .get("content")
272                        .and_then(serde_json::Value::as_str)
273                        .map(ToString::to_string);
274                    return NativeMessage {
275                        role: "tool".to_string(),
276                        content,
277                        tool_call_id,
278                        tool_calls: None,
279                    };
280                }
281
282                NativeMessage {
283                    role: m.role.clone(),
284                    content: Some(m.content.clone()),
285                    tool_call_id: None,
286                    tool_calls: None,
287                }
288            })
289            .collect()
290    }
291
292    fn parse_native_response(message: NativeResponseMessage) -> ChatResponse {
293        let tool_calls = message
294            .tool_calls
295            .unwrap_or_default()
296            .into_iter()
297            .map(|tc| ToolCall {
298                id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
299                name: tc.function.name,
300                arguments: tc.function.arguments,
301            })
302            .collect::<Vec<_>>();
303
304        ChatResponse {
305            text: message.content,
306            tool_calls,
307            provider_tool_calls: vec![],
308            usage: TokenUsage::default(),
309        }
310    }
311
312    async fn generate_image(
313        &self,
314        request: crate::native::GenerateImageRequest,
315    ) -> anyhow::Result<NativeMediaResponse> {
316        let api_key = self.api_key.as_ref().ok_or_else(|| {
317            anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
318        })?;
319
320        let response_format = match request.output_format {
321            MediaOutputFormat::Url => None,
322            MediaOutputFormat::Base64 => Some("b64_json"),
323        };
324        let body = ImageGenerationRequest {
325            model: &request.model,
326            prompt: &request.prompt,
327            n: request.n,
328            size: request.size.as_deref(),
329            response_format,
330            background: provider_option_str(&request.provider_options, "background"),
331            output_format: provider_option_str(&request.provider_options, "output_format"),
332            quality: provider_option_str(&request.provider_options, "quality"),
333        };
334        let mime_type = image_mime_type(body.output_format);
335
336        let response = self
337            .client
338            .post("https://api.openai.com/v1/images/generations")
339            .header("Authorization", format!("Bearer {api_key}"))
340            .json(&body)
341            .send()
342            .await?;
343
344        if !response.status().is_success() {
345            return Err(crate::api_error("OpenAI", response).await);
346        }
347
348        let images: ImageGenerationResponse = response.json().await?;
349        let mut assets = Vec::new();
350        let mut revised_prompts = Vec::new();
351
352        for image in images.data {
353            if let Some(prompt) = image.revised_prompt {
354                revised_prompts.push(prompt);
355            }
356            if let Some(url) = image.url {
357                assets.push(MediaOutputAsset::Url {
358                    url,
359                    mime_type: Some(mime_type.clone()),
360                });
361            } else if let Some(data) = image.b64_json {
362                assets.push(MediaOutputAsset::Base64 {
363                    data,
364                    mime_type: Some(mime_type.clone()),
365                });
366            }
367        }
368
369        if assets.is_empty() {
370            anyhow::bail!("OpenAI image generation returned no assets");
371        }
372
373        let metadata = if revised_prompts.is_empty() {
374            None
375        } else {
376            Some(serde_json::json!({ "revised_prompts": revised_prompts }))
377        };
378
379        Ok(NativeMediaResponse::Assets { assets, metadata })
380    }
381}
382
383#[async_trait]
384impl ModelProvider for OpenAiProvider {
385    async fn chat(
386        &self,
387        request: ChatRequest<'_>,
388        model: &str,
389        temperature: f64,
390    ) -> anyhow::Result<ChatResponse> {
391        let api_key = self.api_key.as_ref().ok_or_else(|| {
392            anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
393        })?;
394
395        let is_reasoning = Self::is_reasoning_model(model);
396        let tools = Self::convert_tools(request.tools);
397        let native_request = NativeChatRequest {
398            model: model.to_string(),
399            messages: Self::convert_messages(request.messages),
400            // Reasoning models (o1/o3/o4) require temperature=1; omit it to use the default.
401            temperature: if is_reasoning {
402                None
403            } else {
404                Some(temperature)
405            },
406            max_completion_tokens: Some(if is_reasoning { 65536 } else { 16384 }),
407            tool_choice: tools.as_ref().map(|_| "auto".to_string()),
408            tools,
409        };
410
411        let response = self
412            .client
413            .post("https://api.openai.com/v1/chat/completions")
414            .header("Authorization", format!("Bearer {api_key}"))
415            .json(&native_request)
416            .send()
417            .await?;
418
419        if !response.status().is_success() {
420            return Err(crate::api_error("OpenAI", response).await);
421        }
422
423        let native_response: NativeChatResponse = response.json().await?;
424        let usage = native_response
425            .usage
426            .map(|u| TokenUsage {
427                input_tokens: u.prompt_tokens,
428                output_tokens: u.completion_tokens,
429            })
430            .unwrap_or_default();
431        let message = native_response
432            .choices
433            .into_iter()
434            .next()
435            .map(|c| c.message)
436            .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
437        let mut result = Self::parse_native_response(message);
438        result.usage = usage;
439        Ok(result)
440    }
441
442    fn context_window(&self, model: &str) -> Option<usize> {
443        let m = model.to_lowercase();
444        Some(if m.contains("gpt-5") {
445            // GPT-5.x: 1M
446            1_000_000
447        } else if m.contains("o1") || m.contains("o3") || m.contains("o4") {
448            // Reasoning models: 200K
449            200_000
450        } else if m.contains("gpt-4o") {
451            // GPT-4o / GPT-4o-mini: 128K
452            128_000
453        } else {
454            128_000
455        })
456    }
457
458    fn supports_native_tools(&self) -> bool {
459        true
460    }
461
462    fn supports_developer_role(&self, model: &str) -> bool {
463        Self::is_developer_role_model(model)
464    }
465
466    fn native_capabilities(&self) -> Option<ProviderNativeCapabilities> {
467        Some(NativeCapabilitiesProvider::native_capabilities(self))
468    }
469
470    async fn submit_media(
471        &self,
472        request: NativeMediaRequest,
473    ) -> anyhow::Result<NativeMediaResponse> {
474        NativeCapabilitiesProvider::submit_media(self, request).await
475    }
476}
477
478#[async_trait]
479impl NativeCapabilitiesProvider for OpenAiProvider {
480    fn native_capabilities(&self) -> ProviderNativeCapabilities {
481        ProviderNativeCapabilities {
482            provider: "openai".to_string(),
483            model_tools: Vec::new(),
484            models: vec![ModelNativeCapabilities {
485                model_pattern: "gpt-image-*".to_string(),
486                tools: vec![openai_generate_image_tool_spec()],
487            }],
488        }
489    }
490
491    async fn submit_media(
492        &self,
493        request: NativeMediaRequest,
494    ) -> anyhow::Result<NativeMediaResponse> {
495        let operation = request.operation();
496        match request {
497            NativeMediaRequest::GenerateImage(request) => self.generate_image(request).await,
498            NativeMediaRequest::EditImage(_)
499            | NativeMediaRequest::GenerateVideo(_)
500            | NativeMediaRequest::EditVideo(_)
501            | NativeMediaRequest::ImageToVideo(_)
502            | NativeMediaRequest::ReferenceToVideo(_)
503            | NativeMediaRequest::ExtendVideo(_)
504            | NativeMediaRequest::GenerateSpeech(_)
505            | NativeMediaRequest::TranscribeAudio(_) => {
506                anyhow::bail!(
507                    "OpenAI native operation {operation:?} is declared but not implemented in this pass"
508                )
509            }
510        }
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    // A supplied key is stored verbatim.
    #[test]
    fn creates_with_key() {
        let provider = OpenAiProvider::new(Some("sk-proj-abc123"));
        assert_eq!(provider.api_key.as_deref(), Some("sk-proj-abc123"));
    }

    // Newer model families accept the developer role; GPT-4o does not.
    #[test]
    fn developer_role_supported_for_newer_openai_models() {
        let provider = OpenAiProvider::new(None);
        for model in ["gpt-5.1", "gpt-4.1", "o3"] {
            assert!(provider.supports_developer_role(model));
        }
        assert!(!provider.supports_developer_role("gpt-4o"));
    }

    // Constructing without a key leaves it unset.
    #[test]
    fn creates_without_key() {
        let provider = OpenAiProvider::new(None);
        assert!(provider.api_key.is_none());
    }

    // An empty key is kept as-is rather than normalised to `None`.
    #[test]
    fn creates_with_empty_key() {
        let provider = OpenAiProvider::new(Some(""));
        assert_eq!(provider.api_key.as_deref(), Some(""));
    }

    // Without a key, `chat` fails with the "API key not set" message.
    #[tokio::test]
    async fn chat_fails_without_key() {
        let provider = OpenAiProvider::new(None);
        let history = vec![ChatMessage::user("hello")];
        let outcome = provider
            .chat(
                ChatRequest {
                    messages: &history,
                    tools: None,
                    native_tools: None,
                },
                "gpt-4o",
                0.7,
            )
            .await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("API key not set"));
    }

    // The missing-key failure also applies when a system message is present.
    #[tokio::test]
    async fn chat_with_system_fails_without_key() {
        let provider = OpenAiProvider::new(None);
        let history = vec![
            ChatMessage::system("You are Nenjo"),
            ChatMessage::user("test"),
        ];
        let outcome = provider
            .chat(
                ChatRequest {
                    messages: &history,
                    tools: None,
                    native_tools: None,
                },
                "gpt-4o",
                0.5,
            )
            .await;

        assert!(outcome.is_err());
    }

    // The declared capabilities advertise image generation for at least one model entry.
    #[test]
    fn native_capabilities_include_image_generation() {
        let provider = OpenAiProvider::new(None);
        let declared = NativeCapabilitiesProvider::native_capabilities(&provider);
        assert_eq!(declared.provider, "openai");

        let generates_images = |entry: &ModelNativeCapabilities| {
            entry
                .operations()
                .any(|op| op == NativeOperation::GenerateImage)
        };
        assert!(declared.models.iter().any(generates_images));
    }
}