// ai_sdk_openai/image/model.rs

use ai_sdk_core::JsonValue;
use ai_sdk_provider::{
    image_model, ImageCallWarning, ImageData, ImageGenerateOptions, ImageGenerateResponse,
    ImageModel, ImageProviderMetadata, JsonObject, Result,
};
use ai_sdk_provider_utils::merge_headers_reqwest;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::openai_config::{OpenAIConfig, OpenAIUrlOptions};

/// OpenAI image-generation model (e.g. `dall-e-2`, `dall-e-3`, `gpt-image-1`)
/// implementing the provider-agnostic [`ImageModel`] trait.
pub struct OpenAIImageModel {
    // Model identifier sent as the `model` field of each request body.
    model_id: String,
    // HTTP client owned by this instance; created once in `new`.
    client: Client,
    // Provider configuration supplying the URL builder and default headers.
    config: OpenAIConfig,
}
20
21impl OpenAIImageModel {
22 pub fn new(model_id: impl Into<String>, config: impl Into<OpenAIConfig>) -> Self {
24 Self {
25 model_id: model_id.into(),
26 client: Client::new(),
27 config: config.into(),
28 }
29 }
30
31 fn has_default_response_format(&self) -> bool {
33 matches!(self.model_id.as_str(), "gpt-image-1" | "gpt-image-1-mini")
34 }
35}
36
/// JSON request body for `POST /images/generations`.
#[derive(Serialize)]
struct ImageRequest {
    // Target model identifier.
    model: String,
    // Text prompt describing the image(s) to generate.
    prompt: String,
    // Number of images requested; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<usize>,
    // Requested image size string, passed through from the call options;
    // omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    size: Option<String>,
    // Set to "b64_json" for models that need it requested explicitly;
    // `None` for models whose default output is already base64.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<String>,
}
48
/// Top-level success payload returned by the images endpoint.
#[derive(Deserialize, Debug)]
struct ImageApiResponse {
    // One entry per generated image.
    data: Vec<ImageResponseData>,
}
53
/// A single generated image within [`ImageApiResponse`].
#[derive(Deserialize, Debug)]
struct ImageResponseData {
    // Base64-encoded image bytes (the request ensures base64 output either
    // explicitly via `response_format` or through the model's default).
    b64_json: String,
    // Rewritten prompt the API may attach to an image; `None` when absent.
    #[serde(default)]
    revised_prompt: Option<String>,
}
60
61#[async_trait]
62impl ImageModel for OpenAIImageModel {
63 fn provider(&self) -> &str {
64 "openai"
65 }
66
67 fn model_id(&self) -> &str {
68 &self.model_id
69 }
70
71 async fn max_images_per_call(&self) -> Option<usize> {
72 match self.model_id.as_str() {
73 "dall-e-3" => Some(1),
74 "dall-e-2" | "gpt-image-1" | "gpt-image-1-mini" => Some(10),
75 _ => Some(1),
76 }
77 }
78
79 async fn do_generate(&self, options: ImageGenerateOptions) -> Result<ImageGenerateResponse> {
80 let mut warnings = Vec::new();
81
82 if options.aspect_ratio.is_some() {
84 warnings.push(ImageCallWarning::UnsupportedSetting {
85 setting: "aspectRatio".into(),
86 details: Some(
87 "This model does not support aspect ratio. Use `size` instead.".into(),
88 ),
89 });
90 }
91
92 if options.seed.is_some() {
93 warnings.push(ImageCallWarning::UnsupportedSetting {
94 setting: "seed".into(),
95 details: None,
96 });
97 }
98
99 let url = (self.config.url)(OpenAIUrlOptions {
101 model_id: self.model_id.clone(),
102 path: "/images/generations".into(),
103 });
104
105 let request_body = ImageRequest {
107 model: self.model_id.clone(),
108 prompt: options.prompt,
109 n: options.n,
110 size: options.size,
111 response_format: if !self.has_default_response_format() {
112 Some("b64_json".into())
113 } else {
114 None
115 },
116 };
117
118 let response = self
119 .client
120 .post(&url)
121 .header("Content-Type", "application/json")
122 .headers(merge_headers_reqwest(
123 (self.config.headers)(),
124 options.headers.as_ref(),
125 ))
126 .json(&request_body)
127 .send()
128 .await?;
129
130 let status = response.status();
131 let response_headers: HashMap<String, String> = response
132 .headers()
133 .iter()
134 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
135 .collect();
136
137 if !status.is_success() {
138 let error_text = response.text().await?;
139 return Err(format!("API error {}: {}", status, error_text).into());
140 }
141
142 let api_response: ImageApiResponse = response.json().await?;
143
144 let mut openai_metadata = JsonObject::new();
146 let images_metadata: Vec<_> = api_response
147 .data
148 .iter()
149 .map(|d| {
150 d.revised_prompt.as_ref().map(|p| {
151 let mut map = JsonObject::new();
152 map.insert("revisedPrompt".to_string(), JsonValue::String(p.clone()));
153 JsonValue::Object(map)
154 })
155 })
156 .map(|opt| opt.unwrap_or(JsonValue::Null))
157 .collect();
158
159 openai_metadata.insert("images".to_string(), JsonValue::Array(images_metadata));
160
161 let mut provider_metadata = HashMap::new();
162 provider_metadata.insert("openai".to_string(), openai_metadata);
163
164 Ok(ImageGenerateResponse {
165 images: api_response
166 .data
167 .into_iter()
168 .map(|d| ImageData::Base64(d.b64_json))
169 .collect(),
170 warnings,
171 provider_metadata: Some(ImageProviderMetadata {
172 metadata: provider_metadata,
173 }),
174 response: image_model::ResponseInfo {
175 timestamp: std::time::SystemTime::now(),
176 model_id: self.model_id.clone(),
177 headers: Some(response_headers),
178 },
179 })
180 }
181}