1use crate::ToolSpec;
4use crate::native::{
5 MediaOutputAsset, MediaOutputFormat, ModelNativeCapabilities, NativeCapabilitiesProvider,
6 NativeExecutionMode, NativeMediaRequest, NativeMediaResponse, NativeOperation,
7 NativeToolSpec as NativeMediaToolSpec, ProviderNativeCapabilities,
8};
9use crate::traits::{ChatMessage, ChatRequest, ChatResponse, ModelProvider, TokenUsage, ToolCall};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14
15pub struct OpenAiProvider {
16 api_key: Option<String>,
17 client: Client,
18}
19
/// JSON body sent to `POST /v1/chat/completions`.
///
/// Every `Option` field is omitted from the serialized JSON when it is `None`.
#[derive(Debug, Serialize)]
struct NativeChatRequest {
    /// Model identifier, passed through unchanged.
    model: String,
    /// Conversation history in OpenAI wire format.
    messages: Vec<NativeMessage>,
    // Left as `None` for reasoning models by `chat`, so the key is not sent at all.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    /// Function tools offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec>>,
    // `chat` sets this to "auto" exactly when `tools` is `Some`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}
33
/// One chat message in OpenAI wire format.
///
/// Unset optional fields are omitted from the JSON rather than serialized as `null`.
#[derive(Debug, Serialize)]
struct NativeMessage {
    /// Message role string (e.g. "assistant", "tool", or whatever role the caller supplied).
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    /// For `tool` messages: id of the assistant tool call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    /// For `assistant` messages: tool calls previously requested by the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
}
44
/// A tool entry in the request's `tools` array.
#[derive(Debug, Serialize)]
struct NativeToolSpec {
    /// Serialized as `type`; this file always sets it to "function".
    #[serde(rename = "type")]
    kind: String,
    function: NativeToolFunctionSpec,
}
51
/// Function definition nested inside a [`NativeToolSpec`].
#[derive(Debug, Serialize)]
struct NativeToolFunctionSpec {
    /// Tool name (sanitized via `crate::sanitize_tool_name` when built from a `ToolSpec`).
    name: String,
    description: String,
    /// JSON Schema describing the function arguments, copied verbatim from the tool spec.
    parameters: serde_json::Value,
}
58
/// A tool call, both as replayed in outgoing assistant messages and as parsed from responses.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    // Optional on the wire; when a response omits it, a UUID is generated during parsing.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    /// Serialized as `type` ("function" when this file builds the value).
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    function: NativeFunctionCall,
}
67
/// Function name plus its arguments as an (unparsed) JSON-encoded string.
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    name: String,
    arguments: String,
}
73
/// Token accounting from a chat completion response; missing counters default to 0.
#[derive(Debug, Deserialize)]
struct NativeUsage {
    #[serde(default)]
    prompt_tokens: u64,
    #[serde(default)]
    completion_tokens: u64,
}
81
/// Subset of the chat completion response body that this provider reads.
#[derive(Debug, Deserialize)]
struct NativeChatResponse {
    /// Candidate completions; only the first one is used.
    choices: Vec<NativeChoice>,
    /// Absent usage deserializes to `None`.
    #[serde(default)]
    usage: Option<NativeUsage>,
}
88
/// One entry of `choices`; only the message is read.
#[derive(Debug, Deserialize)]
struct NativeChoice {
    message: NativeResponseMessage,
}
93
/// Assistant message returned by the API; text and tool calls are both optional.
#[derive(Debug, Deserialize)]
struct NativeResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
101
/// JSON body sent to `POST /v1/images/generations`, borrowing from the caller's request.
///
/// Every `Option` field is omitted from the JSON when `None`.
#[derive(Debug, Serialize)]
struct ImageGenerationRequest<'a> {
    model: &'a str,
    prompt: &'a str,
    /// Number of images to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    size: Option<&'a str>,
    // `'static` because the value is always picked from a fixed literal.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<&'static str>,
    // The three fields below are forwarded from `provider_options` as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    background: Option<&'a str>,
    /// Encoded file type (png/webp/jpeg per the advertised schema).
    #[serde(skip_serializing_if = "Option::is_none")]
    output_format: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    quality: Option<&'a str>,
}
119
/// Image generation response body: one entry per generated image.
#[derive(Debug, Deserialize)]
struct ImageGenerationResponse {
    data: Vec<ImageGenerationData>,
}
124
/// A single generated image.
///
/// All fields are optional. The caller prefers `url` and falls back to `b64_json`.
#[derive(Debug, Deserialize)]
struct ImageGenerationData {
    #[serde(default)]
    url: Option<String>,
    /// Base64-encoded image bytes.
    #[serde(default)]
    b64_json: Option<String>,
    /// Prompt as rewritten by the service, when it reports one.
    #[serde(default)]
    revised_prompt: Option<String>,
}
134
135fn provider_option_str<'a>(options: &'a Value, key: &str) -> Option<&'a str> {
136 options.get(key).and_then(Value::as_str)
137}
138
139fn openai_generate_image_tool_spec() -> NativeMediaToolSpec {
140 let capability = NativeOperation::GenerateImage;
141 NativeMediaToolSpec {
142 capability,
143 tool_name: capability.tool_name().unwrap().to_string(),
144 description: "Generate an image with the configured OpenAI image model.".to_string(),
145 execution: NativeExecutionMode::Immediate,
146 parameters_schema: json!({
147 "type": "object",
148 "properties": {
149 "prompt": {"type": "string"},
150 "n": {"type": "integer", "minimum": 1},
151 "size": {
152 "type": "string",
153 "enum": ["1024x1024", "1024x1536", "1536x1024", "auto"]
154 },
155 "output_format": {"type": "string", "enum": ["url", "base64"]},
156 "provider_options": {
157 "type": "object",
158 "properties": {
159 "background": {
160 "type": "string",
161 "enum": ["transparent", "opaque", "auto"]
162 },
163 "output_format": {
164 "type": "string",
165 "enum": ["png", "webp", "jpeg"]
166 },
167 "quality": {
168 "type": "string",
169 "enum": ["low", "medium", "high", "auto"]
170 }
171 },
172 "additionalProperties": false
173 }
174 },
175 "required": ["prompt"]
176 }),
177 }
178}
179
/// Maps an image `output_format` option to the MIME type attached to returned assets.
///
/// Only the exact strings `"jpeg"` and `"webp"` are recognised. Any other value, including
/// `None`, falls back to `image/png`.
fn image_mime_type(output_format: Option<&str>) -> String {
    let subtype = match output_format {
        Some(known @ ("jpeg" | "webp")) => known,
        _ => "png",
    };
    format!("image/{subtype}")
}
188
189impl OpenAiProvider {
190 pub fn new(api_key: Option<&str>) -> Self {
191 Self {
192 api_key: api_key.map(ToString::to_string),
193 client: Client::builder()
194 .timeout(std::time::Duration::from_secs(120))
195 .connect_timeout(std::time::Duration::from_secs(10))
196 .build()
197 .unwrap_or_else(|_| Client::new()),
198 }
199 }
200
201 fn is_reasoning_model(model: &str) -> bool {
202 let m = model.to_lowercase();
203 m.starts_with("o1") || m.starts_with("o3") || m.starts_with("o4")
204 }
205
206 fn is_developer_role_model(model: &str) -> bool {
207 let m = model.to_lowercase();
208 Self::is_reasoning_model(&m)
209 || m.starts_with("gpt-5")
210 || m.starts_with("gpt-4.5")
211 || m.starts_with("gpt-4.1")
212 }
213
214 fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
215 tools.map(|items| {
216 items
217 .iter()
218 .map(|tool| NativeToolSpec {
219 kind: "function".to_string(),
220 function: NativeToolFunctionSpec {
221 name: crate::sanitize_tool_name(&tool.name),
222 description: tool.description.clone(),
223 parameters: tool.parameters.clone(),
224 },
225 })
226 .collect()
227 })
228 }
229
230 fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
231 messages
232 .iter()
233 .map(|m| {
234 if m.role == "assistant"
235 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content)
236 && let Some(tool_calls_value) = value.get("tool_calls")
237 && let Ok(parsed_calls) =
238 serde_json::from_value::<Vec<ToolCall>>(tool_calls_value.clone())
239 {
240 let tool_calls = parsed_calls
241 .into_iter()
242 .map(|tc| NativeToolCall {
243 id: Some(tc.id),
244 kind: Some("function".to_string()),
245 function: NativeFunctionCall {
246 name: tc.name,
247 arguments: tc.arguments,
248 },
249 })
250 .collect::<Vec<_>>();
251 let content = value
252 .get("content")
253 .and_then(serde_json::Value::as_str)
254 .map(ToString::to_string);
255 return NativeMessage {
256 role: "assistant".to_string(),
257 content,
258 tool_call_id: None,
259 tool_calls: Some(tool_calls),
260 };
261 }
262
263 if m.role == "tool"
264 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content)
265 {
266 let tool_call_id = value
267 .get("tool_call_id")
268 .and_then(serde_json::Value::as_str)
269 .map(ToString::to_string);
270 let content = value
271 .get("content")
272 .and_then(serde_json::Value::as_str)
273 .map(ToString::to_string);
274 return NativeMessage {
275 role: "tool".to_string(),
276 content,
277 tool_call_id,
278 tool_calls: None,
279 };
280 }
281
282 NativeMessage {
283 role: m.role.clone(),
284 content: Some(m.content.clone()),
285 tool_call_id: None,
286 tool_calls: None,
287 }
288 })
289 .collect()
290 }
291
292 fn parse_native_response(message: NativeResponseMessage) -> ChatResponse {
293 let tool_calls = message
294 .tool_calls
295 .unwrap_or_default()
296 .into_iter()
297 .map(|tc| ToolCall {
298 id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
299 name: tc.function.name,
300 arguments: tc.function.arguments,
301 })
302 .collect::<Vec<_>>();
303
304 ChatResponse {
305 text: message.content,
306 tool_calls,
307 provider_tool_calls: vec![],
308 usage: TokenUsage::default(),
309 }
310 }
311
312 async fn generate_image(
313 &self,
314 request: crate::native::GenerateImageRequest,
315 ) -> anyhow::Result<NativeMediaResponse> {
316 let api_key = self.api_key.as_ref().ok_or_else(|| {
317 anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
318 })?;
319
320 let response_format = match request.output_format {
321 MediaOutputFormat::Url => None,
322 MediaOutputFormat::Base64 => Some("b64_json"),
323 };
324 let body = ImageGenerationRequest {
325 model: &request.model,
326 prompt: &request.prompt,
327 n: request.n,
328 size: request.size.as_deref(),
329 response_format,
330 background: provider_option_str(&request.provider_options, "background"),
331 output_format: provider_option_str(&request.provider_options, "output_format"),
332 quality: provider_option_str(&request.provider_options, "quality"),
333 };
334 let mime_type = image_mime_type(body.output_format);
335
336 let response = self
337 .client
338 .post("https://api.openai.com/v1/images/generations")
339 .header("Authorization", format!("Bearer {api_key}"))
340 .json(&body)
341 .send()
342 .await?;
343
344 if !response.status().is_success() {
345 return Err(crate::api_error("OpenAI", response).await);
346 }
347
348 let images: ImageGenerationResponse = response.json().await?;
349 let mut assets = Vec::new();
350 let mut revised_prompts = Vec::new();
351
352 for image in images.data {
353 if let Some(prompt) = image.revised_prompt {
354 revised_prompts.push(prompt);
355 }
356 if let Some(url) = image.url {
357 assets.push(MediaOutputAsset::Url {
358 url,
359 mime_type: Some(mime_type.clone()),
360 });
361 } else if let Some(data) = image.b64_json {
362 assets.push(MediaOutputAsset::Base64 {
363 data,
364 mime_type: Some(mime_type.clone()),
365 });
366 }
367 }
368
369 if assets.is_empty() {
370 anyhow::bail!("OpenAI image generation returned no assets");
371 }
372
373 let metadata = if revised_prompts.is_empty() {
374 None
375 } else {
376 Some(serde_json::json!({ "revised_prompts": revised_prompts }))
377 };
378
379 Ok(NativeMediaResponse::Assets { assets, metadata })
380 }
381}
382
383#[async_trait]
384impl ModelProvider for OpenAiProvider {
385 async fn chat(
386 &self,
387 request: ChatRequest<'_>,
388 model: &str,
389 temperature: f64,
390 ) -> anyhow::Result<ChatResponse> {
391 let api_key = self.api_key.as_ref().ok_or_else(|| {
392 anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
393 })?;
394
395 let is_reasoning = Self::is_reasoning_model(model);
396 let tools = Self::convert_tools(request.tools);
397 let native_request = NativeChatRequest {
398 model: model.to_string(),
399 messages: Self::convert_messages(request.messages),
400 temperature: if is_reasoning {
402 None
403 } else {
404 Some(temperature)
405 },
406 max_completion_tokens: Some(if is_reasoning { 65536 } else { 16384 }),
407 tool_choice: tools.as_ref().map(|_| "auto".to_string()),
408 tools,
409 };
410
411 let response = self
412 .client
413 .post("https://api.openai.com/v1/chat/completions")
414 .header("Authorization", format!("Bearer {api_key}"))
415 .json(&native_request)
416 .send()
417 .await?;
418
419 if !response.status().is_success() {
420 return Err(crate::api_error("OpenAI", response).await);
421 }
422
423 let native_response: NativeChatResponse = response.json().await?;
424 let usage = native_response
425 .usage
426 .map(|u| TokenUsage {
427 input_tokens: u.prompt_tokens,
428 output_tokens: u.completion_tokens,
429 })
430 .unwrap_or_default();
431 let message = native_response
432 .choices
433 .into_iter()
434 .next()
435 .map(|c| c.message)
436 .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
437 let mut result = Self::parse_native_response(message);
438 result.usage = usage;
439 Ok(result)
440 }
441
442 fn context_window(&self, model: &str) -> Option<usize> {
443 let m = model.to_lowercase();
444 Some(if m.contains("gpt-5") {
445 1_000_000
447 } else if m.contains("o1") || m.contains("o3") || m.contains("o4") {
448 200_000
450 } else if m.contains("gpt-4o") {
451 128_000
453 } else {
454 128_000
455 })
456 }
457
458 fn supports_native_tools(&self) -> bool {
459 true
460 }
461
462 fn supports_developer_role(&self, model: &str) -> bool {
463 Self::is_developer_role_model(model)
464 }
465
466 fn native_capabilities(&self) -> Option<ProviderNativeCapabilities> {
467 Some(NativeCapabilitiesProvider::native_capabilities(self))
468 }
469
470 async fn submit_media(
471 &self,
472 request: NativeMediaRequest,
473 ) -> anyhow::Result<NativeMediaResponse> {
474 NativeCapabilitiesProvider::submit_media(self, request).await
475 }
476}
477
478#[async_trait]
479impl NativeCapabilitiesProvider for OpenAiProvider {
480 fn native_capabilities(&self) -> ProviderNativeCapabilities {
481 ProviderNativeCapabilities {
482 provider: "openai".to_string(),
483 model_tools: Vec::new(),
484 models: vec![ModelNativeCapabilities {
485 model_pattern: "gpt-image-*".to_string(),
486 tools: vec![openai_generate_image_tool_spec()],
487 }],
488 }
489 }
490
491 async fn submit_media(
492 &self,
493 request: NativeMediaRequest,
494 ) -> anyhow::Result<NativeMediaResponse> {
495 let operation = request.operation();
496 match request {
497 NativeMediaRequest::GenerateImage(request) => self.generate_image(request).await,
498 NativeMediaRequest::EditImage(_)
499 | NativeMediaRequest::GenerateVideo(_)
500 | NativeMediaRequest::EditVideo(_)
501 | NativeMediaRequest::ImageToVideo(_)
502 | NativeMediaRequest::ReferenceToVideo(_)
503 | NativeMediaRequest::ExtendVideo(_)
504 | NativeMediaRequest::GenerateSpeech(_)
505 | NativeMediaRequest::TranscribeAudio(_) => {
506 anyhow::bail!(
507 "OpenAI native operation {operation:?} is declared but not implemented in this pass"
508 )
509 }
510 }
511 }
512}
513
514#[cfg(test)]
515mod tests {
516 use super::*;
517
518 #[test]
519 fn creates_with_key() {
520 let p = OpenAiProvider::new(Some("sk-proj-abc123"));
521 assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
522 }
523
524 #[test]
525 fn developer_role_supported_for_newer_openai_models() {
526 let p = OpenAiProvider::new(None);
527 assert!(p.supports_developer_role("gpt-5.1"));
528 assert!(p.supports_developer_role("gpt-4.1"));
529 assert!(p.supports_developer_role("o3"));
530 assert!(!p.supports_developer_role("gpt-4o"));
531 }
532
533 #[test]
534 fn creates_without_key() {
535 let p = OpenAiProvider::new(None);
536 assert!(p.api_key.is_none());
537 }
538
539 #[test]
540 fn creates_with_empty_key() {
541 let p = OpenAiProvider::new(Some(""));
542 assert_eq!(p.api_key.as_deref(), Some(""));
543 }
544
545 #[tokio::test]
546 async fn chat_fails_without_key() {
547 let p = OpenAiProvider::new(None);
548 let messages = vec![ChatMessage::user("hello")];
549 let request = ChatRequest {
550 messages: &messages,
551 tools: None,
552 native_tools: None,
553 };
554 let result = p.chat(request, "gpt-4o", 0.7).await;
555 assert!(result.is_err());
556 assert!(result.unwrap_err().to_string().contains("API key not set"));
557 }
558
559 #[tokio::test]
560 async fn chat_with_system_fails_without_key() {
561 let p = OpenAiProvider::new(None);
562 let messages = vec![
563 ChatMessage::system("You are Nenjo"),
564 ChatMessage::user("test"),
565 ];
566 let request = ChatRequest {
567 messages: &messages,
568 tools: None,
569 native_tools: None,
570 };
571 let result = p.chat(request, "gpt-4o", 0.5).await;
572 assert!(result.is_err());
573 }
574
575 #[test]
576 fn native_capabilities_include_image_generation() {
577 let p = OpenAiProvider::new(None);
578 let capabilities = NativeCapabilitiesProvider::native_capabilities(&p);
579 assert_eq!(capabilities.provider, "openai");
580 assert!(capabilities.models.iter().any(|model| {
581 model
582 .operations()
583 .any(|op| op == NativeOperation::GenerateImage)
584 }));
585 }
586}