1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9use crate::error::PunchResult;
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
13#[serde(rename_all = "snake_case")]
14pub enum ImageStyle {
15 Natural,
17 Vivid,
19 Anime,
21 Photographic,
23 DigitalArt,
25 ComicBook,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
31#[serde(rename_all = "lowercase")]
32pub enum ImageFormat {
33 Png,
35 Jpeg,
37 Webp,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ImageGenRequest {
44 pub prompt: String,
46 #[serde(default = "default_dimension")]
48 pub width: u32,
49 #[serde(default = "default_dimension")]
51 pub height: u32,
52 pub model: Option<String>,
54 pub style: Option<ImageStyle>,
56 pub negative_prompt: Option<String>,
58 pub seed: Option<u64>,
60}
61
/// Fallback used by serde for `ImageGenRequest::width` / `height` when they are
/// missing from serialized input.
fn default_dimension() -> u32 {
    const DEFAULT_SIDE: u32 = 1024;
    DEFAULT_SIDE
}
65
66impl ImageGenRequest {
67 pub fn new(prompt: impl Into<String>) -> Self {
69 Self {
70 prompt: prompt.into(),
71 width: 1024,
72 height: 1024,
73 model: None,
74 style: None,
75 negative_prompt: None,
76 seed: None,
77 }
78 }
79
80 pub fn with_dimensions(mut self, width: u32, height: u32) -> Self {
82 self.width = width;
83 self.height = height;
84 self
85 }
86
87 pub fn with_style(mut self, style: ImageStyle) -> Self {
89 self.style = Some(style);
90 self
91 }
92
93 pub fn with_model(mut self, model: impl Into<String>) -> Self {
95 self.model = Some(model.into());
96 self
97 }
98
99 pub fn with_negative_prompt(mut self, negative_prompt: impl Into<String>) -> Self {
101 self.negative_prompt = Some(negative_prompt.into());
102 self
103 }
104
105 pub fn with_seed(mut self, seed: u64) -> Self {
107 self.seed = Some(seed);
108 self
109 }
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct ImageGenResult {
115 pub image_data: String,
117 pub format: ImageFormat,
119 pub revised_prompt: Option<String>,
121 pub model_used: String,
123 pub generation_ms: u64,
125}
126
127#[async_trait]
129pub trait ImageGenerator: Send + Sync {
130 async fn generate(&self, request: ImageGenRequest) -> PunchResult<ImageGenResult>;
132
133 fn supported_models(&self) -> Vec<String>;
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_request_creation() {
        let req = ImageGenRequest::new("a fierce warrior")
            .with_style(ImageStyle::Vivid)
            .with_model("dall-e-3")
            .with_negative_prompt("blurry")
            .with_seed(42);

        assert_eq!(req.prompt, "a fierce warrior");
        assert_eq!(req.style, Some(ImageStyle::Vivid));
        assert_eq!(req.model.as_deref(), Some("dall-e-3"));
        assert_eq!(req.negative_prompt.as_deref(), Some("blurry"));
        assert_eq!(req.seed, Some(42));
    }

    /// Pins the snake_case wire name of every style variant, not just one.
    #[test]
    fn test_style_serialization() {
        let cases = [
            (ImageStyle::Natural, "\"natural\""),
            (ImageStyle::Vivid, "\"vivid\""),
            (ImageStyle::Anime, "\"anime\""),
            (ImageStyle::Photographic, "\"photographic\""),
            (ImageStyle::DigitalArt, "\"digital_art\""),
            (ImageStyle::ComicBook, "\"comic_book\""),
        ];
        for (style, expected) in cases {
            let json = serde_json::to_string(&style).expect("serialize style");
            assert_eq!(json, expected);

            let deserialized: ImageStyle =
                serde_json::from_str(&json).expect("deserialize style");
            assert_eq!(deserialized, style);
        }

        // The Rust variant spelling is not an accepted alias on the wire.
        assert!(serde_json::from_str::<ImageStyle>("\"DigitalArt\"").is_err());
    }

    #[test]
    fn test_result_construction() {
        let result = ImageGenResult {
            image_data: "iVBORw0KGgo=".to_string(),
            format: ImageFormat::Png,
            revised_prompt: Some("a fierce warrior in battle".to_string()),
            model_used: "dall-e-3".to_string(),
            generation_ms: 2500,
        };

        assert_eq!(result.model_used, "dall-e-3");
        assert_eq!(result.generation_ms, 2500);
        assert!(result.revised_prompt.is_some());
    }

    #[test]
    fn test_result_roundtrip() {
        let original = ImageGenResult {
            image_data: "iVBORw0KGgo=".to_string(),
            format: ImageFormat::Webp,
            revised_prompt: None,
            model_used: "dall-e-3".to_string(),
            generation_ms: 2500,
        };
        let json = serde_json::to_string(&original).expect("serialize result");
        let back: ImageGenResult = serde_json::from_str(&json).expect("deserialize result");

        assert_eq!(back.image_data, original.image_data);
        assert_eq!(back.format, original.format);
        assert_eq!(back.revised_prompt, original.revised_prompt);
        assert_eq!(back.model_used, original.model_used);
        assert_eq!(back.generation_ms, original.generation_ms);
    }

    #[test]
    fn test_default_dimensions() {
        let req = ImageGenRequest::new("test");
        assert_eq!(req.width, 1024);
        assert_eq!(req.height, 1024);

        let req = req.with_dimensions(512, 768);
        assert_eq!(req.width, 512);
        assert_eq!(req.height, 768);
    }

    /// Exercises the `#[serde(default = "default_dimension")]` path. It must agree
    /// with what `ImageGenRequest::new` produces.
    #[test]
    fn test_request_deserialize_applies_defaults() {
        let req: ImageGenRequest =
            serde_json::from_str(r#"{"prompt":"test"}"#).expect("deserialize minimal request");
        let built = ImageGenRequest::new("test");

        assert_eq!(req.prompt, "test");
        assert_eq!((req.width, req.height), (1024, 1024));
        assert_eq!((req.width, req.height), (built.width, built.height));
        assert!(req.model.is_none());
        assert!(req.style.is_none());
        assert!(req.negative_prompt.is_none());
        assert!(req.seed.is_none());
    }

    #[test]
    fn test_request_roundtrip() {
        let original = ImageGenRequest::new("a fierce warrior")
            .with_dimensions(512, 768)
            .with_style(ImageStyle::ComicBook)
            .with_model("dall-e-3")
            .with_seed(7);
        let json = serde_json::to_string(&original).expect("serialize request");
        let back: ImageGenRequest = serde_json::from_str(&json).expect("deserialize request");

        assert_eq!(back.prompt, original.prompt);
        assert_eq!((back.width, back.height), (512, 768));
        assert_eq!(back.style, Some(ImageStyle::ComicBook));
        assert_eq!(back.model.as_deref(), Some("dall-e-3"));
        assert_eq!(back.negative_prompt, None);
        assert_eq!(back.seed, Some(7));
    }

    #[test]
    fn test_format_variants() {
        // A plain array: the cases are only iterated, so a `Vec` would be a
        // needless allocation (clippy::useless_vec).
        let cases = [
            (ImageFormat::Png, "\"png\""),
            (ImageFormat::Jpeg, "\"jpeg\""),
            (ImageFormat::Webp, "\"webp\""),
        ];
        for (fmt, expected) in cases {
            let json = serde_json::to_string(&fmt).expect("serialize format");
            assert_eq!(json, expected);

            let deserialized: ImageFormat =
                serde_json::from_str(&json).expect("deserialize format");
            assert_eq!(deserialized, fmt);
        }
    }
}