Skip to main content

punch_types/
image_gen.rs

1//! # Image Generation — forging visual strikes from text commands.
2//!
3//! This module provides types and traits for generating images from prompt descriptions,
4//! allowing fighters to conjure visual attacks on demand.
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9use crate::error::PunchResult;
10
11/// Style presets for image generation — the fighting stance of the visual output.
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
13#[serde(rename_all = "snake_case")]
14pub enum ImageStyle {
15    /// Realistic, natural-looking imagery.
16    Natural,
17    /// High-contrast, saturated, dramatic imagery.
18    Vivid,
19    /// Anime/manga-inspired art style.
20    Anime,
21    /// Photo-realistic rendering.
22    Photographic,
23    /// Digital art style with clean lines.
24    DigitalArt,
25    /// Comic book panels and halftone dots.
26    ComicBook,
27}
28
29/// Output format for generated images — the weapon's material form.
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
31#[serde(rename_all = "lowercase")]
32pub enum ImageFormat {
33    /// PNG format (lossless).
34    Png,
35    /// JPEG format (lossy, smaller size).
36    Jpeg,
37    /// WebP format (modern, efficient).
38    Webp,
39}
40
/// A request to generate an image — the battle orders for visual creation.
///
/// When deserializing, `width` and `height` fall back to `default_dimension()`
/// (1024) if they are absent from the input. Requests built in code normally start
/// from [`ImageGenRequest::new`] and are refined with the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageGenRequest {
    /// The prompt describing the desired image.
    pub prompt: String,
    /// Width in pixels (default: 1024).
    // The default is supplied by `default_dimension` when the field is missing.
    #[serde(default = "default_dimension")]
    pub width: u32,
    /// Height in pixels (default: 1024).
    // Shares the same fallback function as `width`.
    #[serde(default = "default_dimension")]
    pub height: u32,
    /// Specific model to use for generation; `None` leaves the choice unspecified.
    pub model: Option<String>,
    /// Visual style preset; `None` means no preset was requested.
    pub style: Option<ImageStyle>,
    /// Negative prompt — what to avoid in the image.
    pub negative_prompt: Option<String>,
    /// Seed for reproducible generation.
    pub seed: Option<u64>,
}
61
/// Fallback edge length, in pixels, used by serde when a request omits `width` or
/// `height`.
fn default_dimension() -> u32 {
    const DEFAULT_SIDE_PX: u32 = 1024;
    DEFAULT_SIDE_PX
}
65
66impl ImageGenRequest {
67    /// Create a new image generation request with default dimensions.
68    pub fn new(prompt: impl Into<String>) -> Self {
69        Self {
70            prompt: prompt.into(),
71            width: 1024,
72            height: 1024,
73            model: None,
74            style: None,
75            negative_prompt: None,
76            seed: None,
77        }
78    }
79
80    /// Set the image dimensions.
81    pub fn with_dimensions(mut self, width: u32, height: u32) -> Self {
82        self.width = width;
83        self.height = height;
84        self
85    }
86
87    /// Set the style preset.
88    pub fn with_style(mut self, style: ImageStyle) -> Self {
89        self.style = Some(style);
90        self
91    }
92
93    /// Set the model to use.
94    pub fn with_model(mut self, model: impl Into<String>) -> Self {
95        self.model = Some(model.into());
96        self
97    }
98
99    /// Set the negative prompt.
100    pub fn with_negative_prompt(mut self, negative_prompt: impl Into<String>) -> Self {
101        self.negative_prompt = Some(negative_prompt.into());
102        self
103    }
104
105    /// Set the seed for reproducible results.
106    pub fn with_seed(mut self, seed: u64) -> Self {
107        self.seed = Some(seed);
108        self
109    }
110}
111
/// The result of an image generation — the visual strike delivered.
///
/// Plain data returned by [`ImageGenerator::generate`]; it derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageGenResult {
    /// Base64-encoded image data.
    pub image_data: String,
    /// Format of the generated image.
    pub format: ImageFormat,
    /// The prompt as revised/interpreted by the model, if one was reported.
    pub revised_prompt: Option<String>,
    /// Which model produced this image.
    pub model_used: String,
    /// Time taken to generate in milliseconds.
    pub generation_ms: u64,
}
126
/// Trait for image generation backends — the forge that creates visual weapons.
///
/// Implementors must be `Send + Sync` (supertrait bounds). The `#[async_trait]`
/// attribute is applied so the trait can declare an `async fn`.
#[async_trait]
pub trait ImageGenerator: Send + Sync {
    /// Generate an image from the given request.
    ///
    /// Takes ownership of `request` and resolves to a `PunchResult` wrapping the
    /// produced [`ImageGenResult`]. The failure conditions are defined by each
    /// implementation and are not visible from this declaration.
    async fn generate(&self, request: ImageGenRequest) -> PunchResult<ImageGenResult>;

    /// Return the list of models supported by this generator.
    ///
    /// NOTE(review): presumably these are the identifiers accepted in
    /// `ImageGenRequest::model` — confirm against the implementations.
    fn supported_models(&self) -> Vec<String>;
}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Each builder method stores its argument on the request.
    #[test]
    fn test_request_creation() {
        let req = ImageGenRequest::new("a fierce warrior")
            .with_style(ImageStyle::Vivid)
            .with_model("dall-e-3")
            .with_negative_prompt("blurry")
            .with_seed(42);

        assert_eq!(req.prompt, "a fierce warrior");
        assert_eq!(req.style, Some(ImageStyle::Vivid));
        assert_eq!(req.model.as_deref(), Some("dall-e-3"));
        assert_eq!(req.negative_prompt.as_deref(), Some("blurry"));
        assert_eq!(req.seed, Some(42));
    }

    /// Multi-word style variants use snake_case on the wire and round-trip.
    #[test]
    fn test_style_serialization() {
        let style = ImageStyle::DigitalArt;
        let json = serde_json::to_string(&style).expect("serialize style");
        assert_eq!(json, "\"digital_art\"");

        let deserialized: ImageStyle = serde_json::from_str(&json).expect("deserialize style");
        assert_eq!(deserialized, ImageStyle::DigitalArt);
    }

    /// A result can be built field-by-field and read back.
    #[test]
    fn test_result_construction() {
        let result = ImageGenResult {
            image_data: "iVBORw0KGgo=".to_string(),
            format: ImageFormat::Png,
            revised_prompt: Some("a fierce warrior in battle".to_string()),
            model_used: "dall-e-3".to_string(),
            generation_ms: 2500,
        };

        assert_eq!(result.model_used, "dall-e-3");
        assert_eq!(result.generation_ms, 2500);
        assert!(result.revised_prompt.is_some());
    }

    /// `new` starts at 1024×1024 and `with_dimensions` overrides both sides.
    #[test]
    fn test_default_dimensions() {
        let req = ImageGenRequest::new("test");
        assert_eq!(req.width, 1024);
        assert_eq!(req.height, 1024);

        let req = req.with_dimensions(512, 768);
        assert_eq!(req.width, 512);
        assert_eq!(req.height, 768);
    }

    /// Omitted `width`/`height` fall back to the serde default when deserializing.
    #[test]
    fn test_deserialize_missing_dimensions() {
        // Optional fields are spelled out as `null` so only the dimension
        // fallback is under test here.
        let json = r#"{
            "prompt": "test",
            "model": null,
            "style": null,
            "negative_prompt": null,
            "seed": null
        }"#;
        let req: ImageGenRequest = serde_json::from_str(json).expect("deserialize request");

        assert_eq!(req.prompt, "test");
        assert_eq!(req.width, 1024);
        assert_eq!(req.height, 1024);
    }

    /// Every format serializes to its lowercase wire name and round-trips.
    #[test]
    fn test_format_variants() {
        // A fixed-size array is enough here; no heap-allocated `Vec` is needed.
        let cases = [
            (ImageFormat::Png, "\"png\""),
            (ImageFormat::Jpeg, "\"jpeg\""),
            (ImageFormat::Webp, "\"webp\""),
        ];

        for (fmt, expected) in &cases {
            let json = serde_json::to_string(fmt).expect("serialize format");
            assert_eq!(json, *expected);

            let deserialized: ImageFormat =
                serde_json::from_str(&json).expect("deserialize format");
            assert_eq!(&deserialized, fmt);
        }
    }
}