// ai_sdk_openai/image/model.rs

1use ai_sdk_core::JsonValue;
2use ai_sdk_provider::{
3    image_model, ImageCallWarning, ImageData, ImageGenerateOptions, ImageGenerateResponse,
4    ImageModel, ImageProviderMetadata, JsonObject, Result,
5};
6use ai_sdk_provider_utils::merge_headers_reqwest;
7use async_trait::async_trait;
8use reqwest::Client;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12use crate::openai_config::{OpenAIConfig, OpenAIUrlOptions};
13
14/// OpenAI implementation of image model.
/// OpenAI implementation of image model.
///
/// Holds the target model identifier, a dedicated HTTP client, and the
/// provider configuration (base-URL resolver and header factory).
pub struct OpenAIImageModel {
    // Model identifier sent in requests (e.g. "dall-e-3", "gpt-image-1").
    model_id: String,
    // Reqwest client reused across calls on this model instance.
    client: Client,
    // Supplies the endpoint URL and request headers per call.
    config: OpenAIConfig,
}
20
21impl OpenAIImageModel {
22    /// Creates a new image model with the specified model ID and API key.
23    pub fn new(model_id: impl Into<String>, config: impl Into<OpenAIConfig>) -> Self {
24        Self {
25            model_id: model_id.into(),
26            client: Client::new(),
27            config: config.into(),
28        }
29    }
30
31    /// Check if this model has a default response format
32    fn has_default_response_format(&self) -> bool {
33        matches!(self.model_id.as_str(), "gpt-image-1" | "gpt-image-1-mini")
34    }
35}
36
/// JSON request body for the OpenAI image generation endpoint.
///
/// Optional fields set to `None` are omitted from the serialized payload
/// entirely, so the API applies its own defaults for them.
#[derive(Serialize)]
struct ImageRequest {
    // Model identifier, forwarded verbatim (e.g. "dall-e-3").
    model: String,
    // Text prompt describing the image(s) to generate.
    prompt: String,
    // Number of images requested per call.
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<usize>,
    // Image size string — presumably "WIDTHxHEIGHT" per the OpenAI API;
    // passed through unvalidated.
    #[serde(skip_serializing_if = "Option::is_none")]
    size: Option<String>,
    // Set to "b64_json" except for models with a fixed default format.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<String>,
}
48
/// Top-level shape of a successful images-API response:
/// one `data` entry per generated image.
#[derive(Deserialize, Debug)]
struct ImageApiResponse {
    data: Vec<ImageResponseData>,
}
53
/// A single generated image as returned by the API.
#[derive(Deserialize, Debug)]
struct ImageResponseData {
    // Base64-encoded image payload (the request asks for `b64_json`).
    b64_json: String,
    // Prompt as rewritten by the model, when the API supplies one;
    // `#[serde(default)]` maps a missing field to `None`.
    #[serde(default)]
    revised_prompt: Option<String>,
}
60
61#[async_trait]
62impl ImageModel for OpenAIImageModel {
63    fn provider(&self) -> &str {
64        "openai"
65    }
66
67    fn model_id(&self) -> &str {
68        &self.model_id
69    }
70
71    async fn max_images_per_call(&self) -> Option<usize> {
72        match self.model_id.as_str() {
73            "dall-e-3" => Some(1),
74            "dall-e-2" | "gpt-image-1" | "gpt-image-1-mini" => Some(10),
75            _ => Some(1),
76        }
77    }
78
79    async fn do_generate(&self, options: ImageGenerateOptions) -> Result<ImageGenerateResponse> {
80        let mut warnings = Vec::new();
81
82        // Check unsupported settings
83        if options.aspect_ratio.is_some() {
84            warnings.push(ImageCallWarning::UnsupportedSetting {
85                setting: "aspectRatio".into(),
86                details: Some(
87                    "This model does not support aspect ratio. Use `size` instead.".into(),
88                ),
89            });
90        }
91
92        if options.seed.is_some() {
93            warnings.push(ImageCallWarning::UnsupportedSetting {
94                setting: "seed".into(),
95                details: None,
96            });
97        }
98
99        // let url = format!("{}/images/generations", (self.config.url)());
100        let url = (self.config.url)(OpenAIUrlOptions {
101            model_id: self.model_id.clone(),
102            path: "/images/generations".into(),
103        });
104
105        // Build request body
106        let request_body = ImageRequest {
107            model: self.model_id.clone(),
108            prompt: options.prompt,
109            n: options.n,
110            size: options.size,
111            response_format: if !self.has_default_response_format() {
112                Some("b64_json".into())
113            } else {
114                None
115            },
116        };
117
118        let response = self
119            .client
120            .post(&url)
121            .header("Content-Type", "application/json")
122            .headers(merge_headers_reqwest(
123                (self.config.headers)(),
124                options.headers.as_ref(),
125            ))
126            .json(&request_body)
127            .send()
128            .await?;
129
130        let status = response.status();
131        let response_headers: HashMap<String, String> = response
132            .headers()
133            .iter()
134            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
135            .collect();
136
137        if !status.is_success() {
138            let error_text = response.text().await?;
139            return Err(format!("API error {}: {}", status, error_text).into());
140        }
141
142        let api_response: ImageApiResponse = response.json().await?;
143
144        // Build provider metadata
145        let mut openai_metadata = JsonObject::new();
146        let images_metadata: Vec<_> = api_response
147            .data
148            .iter()
149            .map(|d| {
150                d.revised_prompt.as_ref().map(|p| {
151                    let mut map = JsonObject::new();
152                    map.insert("revisedPrompt".to_string(), JsonValue::String(p.clone()));
153                    JsonValue::Object(map)
154                })
155            })
156            .map(|opt| opt.unwrap_or(JsonValue::Null))
157            .collect();
158
159        openai_metadata.insert("images".to_string(), JsonValue::Array(images_metadata));
160
161        let mut provider_metadata = HashMap::new();
162        provider_metadata.insert("openai".to_string(), openai_metadata);
163
164        Ok(ImageGenerateResponse {
165            images: api_response
166                .data
167                .into_iter()
168                .map(|d| ImageData::Base64(d.b64_json))
169                .collect(),
170            warnings,
171            provider_metadata: Some(ImageProviderMetadata {
172                metadata: provider_metadata,
173            }),
174            response: image_model::ResponseInfo {
175                timestamp: std::time::SystemTime::now(),
176                model_id: self.model_id.clone(),
177                headers: Some(response_headers),
178            },
179        })
180    }
181}