// stability_rs/api/rest/generation/mod.rs

1pub mod text_to_img;
2pub mod img_to_img;
3pub mod upscale;
4pub mod masking;
5
6use serde::{Serialize, Deserialize};
7use std::collections::HashMap;
8use base64::{engine::general_purpose, Engine as _};
9use crate::prelude::*;
10use crate::error::*;
11use crate::api::rest::client::*;
12use rand::Rng;
13use std::io::{Read, Write};
14use std::fs::File;
15use std::{fmt, io};
16
17
/// Relative path shared by the REST generation endpoints.
///
/// NOTE(review): presumably joined with the API base URL / engine id by the
/// endpoint submodules — confirm against callers.
const GENERATION_PATH: &str = "/generation";

/// Prefix of the `Content-Type` header value for multipart request bodies.
///
/// Despite its name this is not a boundary itself: the value ends in
/// `boundary=`, so the actual boundary string (e.g. `MultipartFormData::boundary`)
/// is expected to be appended to it.
pub const MULTIPART_FORM_DATA_BOUNDARY: &str = "multipart/form-data; boundary=";
20
21
22#[derive(Debug, Deserialize, Serialize)]
23pub struct Image {
24    pub base64: String,
25    #[serde(rename = "finishReason")]
26    pub finish_reason: String,
27    pub seed: u32,
28}
29
30impl Image {
31    pub async fn save(&self, path: &str) -> Result<()> {
32        let mut png_file = tokio::fs::File::create(path).await?;
33        let mut buffer: Vec<u8> = Vec::new();
34        let _dec = general_purpose::STANDARD.decode_vec(&self.base64, &mut buffer)?;
35        png_file.write_all(buffer.as_mut_slice()).await?;
36        Ok(())
37    }
38}
39
40#[derive(Debug, Deserialize, Serialize)]
41pub struct ImageResponse {
42    pub artifacts: Vec<Image>,
43}
44
45
46    #[derive(Debug, Deserialize, Serialize)]
47    struct TextPrompt {
48        text: String,
49        weight: f32,
50    }
51
52    #[derive(Debug, Deserialize, Serialize)]
53    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
54    pub enum ClipGuidancePreset {
55        FastBlue,
56        FastGreen,
57        Simple,
58        Slow,
59        Slower,
60        Slowest,
61        None,
62    }
63
64impl fmt::Display for ClipGuidancePreset {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        match self {
67            ClipGuidancePreset::FastBlue => write!(f, "fast_blue"),
68            ClipGuidancePreset::FastGreen => write!(f, "fast_green"),
69            ClipGuidancePreset::Simple => write!(f, "simple"),
70            ClipGuidancePreset::Slow => write!(f, "slow"),
71            ClipGuidancePreset::Slower => write!(f, "slower"),
72            ClipGuidancePreset::Slowest => write!(f, "slowest"),
73            ClipGuidancePreset::None => write!(f, "none"),
74        }
75    }
76}
77
78    impl ClipGuidancePreset {
79        pub fn is_none(&self) -> bool {
80            match self {
81                ClipGuidancePreset::None => true,
82                _ => false,
83            }
84        }
85
86}
87
88    #[derive(Debug, Deserialize, Serialize)]
89    #[serde(rename_all = "kebab-case")]
90    pub enum StylePreset {
91        #[serde(rename = "3d-model")]
92        ThreeDModel,
93        Anime,
94        AnalogFilm,
95        Cinematic,
96        ComicBook,
97        DigitalArt,
98        Enhance,
99        FantasyArt,
100        Isometric,
101        LineArt,
102        LowPoly,
103        ModelingCompound,
104        NeonPunk,
105        Origami,
106        Photographic,
107        PixelArt,
108        TileTexture,
109    }
110
111impl fmt::Display for StylePreset {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        match self {
114            StylePreset::ThreeDModel => write!(f, "3d-model"),
115            StylePreset::Anime => write!(f, "anime"),
116            StylePreset::AnalogFilm => write!(f, "analog-film"),
117            StylePreset::Cinematic => write!(f, "cinematic"),
118            StylePreset::ComicBook => write!(f, "comic-book"),
119            StylePreset::DigitalArt => write!(f, "digital-art"),
120            StylePreset::Enhance => write!(f, "enhance"),
121            StylePreset::FantasyArt => write!(f, "fantasy-art"),
122            StylePreset::Isometric => write!(f, "isometric"),
123            StylePreset::LineArt => write!(f, "line-art"),
124            StylePreset::LowPoly => write!(f, "low-poly"),
125            StylePreset::ModelingCompound => write!(f, "modeling-compound"),
126            StylePreset::NeonPunk => write!(f, "neon-punk"),
127            StylePreset::Origami => write!(f, "origami"),
128            StylePreset::Photographic => write!(f, "photographic"),
129            StylePreset::PixelArt => write!(f, "pixel-art"),
130            StylePreset::TileTexture => write!(f, "tile-texture"),
131        }
132    }
133
134}
135
136    #[derive(Debug,PartialEq,Deserialize, Serialize)]
137    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
138    // Todo: Add more samplers K_DPMPP_SDE?
139    pub enum Sampler {
140        Ddim,
141        Ddpm,
142        #[serde(rename = "K_DPMPP_2M")]
143        KDpmpp2m,
144        #[serde(rename = "K_DPMPP_2S_ANCESTRAL")]
145        KDpmpp2sAncestral,
146        #[serde(rename = "K_DPMP_2")]
147        KDpm2,
148        #[serde(rename = "K_DPMP_2_ANCESTRAL")]
149        KDpm2Ancestral,
150        #[serde(rename = "K_EULER")]
151        KEuler,
152        #[serde(rename = "K_EULER_ANCESTRAL")]
153        KEAncestral,
154        #[serde(rename = "K_HEUN")]
155        KHeun,
156        #[serde(rename = "K_LMS")]
157        KLms,
158        None,
159    }
160
161impl fmt::Display for Sampler {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        match self {
164            Sampler::Ddim => write!(f, "ddim"),
165            Sampler::Ddpm => write!(f, "ddpm"),
166            Sampler::KDpmpp2m => write!(f, "k_dpmpp_2m"),
167            Sampler::KDpmpp2sAncestral => write!(f, "k_dpmpp_2s_ancestral"),
168            Sampler::KDpm2 => write!(f, "k_dpm_2"),
169            Sampler::KDpm2Ancestral => write!(f, "k_dpm_2_ancestral"),
170            Sampler::KEuler => write!(f, "k_euler"),
171            Sampler::KEAncestral => write!(f, "k_euler_ancestral"),
172            Sampler::KHeun => write!(f, "k_heun"),
173            Sampler::KLms => write!(f, "k_lms"),
174            Sampler::None => write!(f, "none"),
175        }
176    }
177}
178
179    impl Sampler {
180        pub fn is_none(&self) -> bool {
181            match self {
182                Sampler::None => true,
183                _ => false,
184            }
185        }
186    }
187
188pub struct MultipartFormData {
189    pub boundary: String,
190    pub body: Vec<u8>,
191}
192
193impl MultipartFormData {
194    pub fn new() -> Self {
195        Self {
196            boundary: format!(
197                "-----------------------------{}", rand::thread_rng().gen::<u64>()),
198            body: Vec::new(),
199        }
200    }
201
202    pub fn add_text(&mut self, name: &str, value: &str) -> io::Result<()> {
203        write!(self.body, "--{}\r\n", self.boundary)?;
204        write!(self.body, "Content-Disposition: form-data; name=\"{}\"\r\n\r\n{}\r\n", name, value)?;
205        Ok(())
206    }
207
208    pub fn add_file(&mut self, name: &str, path: &str) -> io::Result<()> {
209        if !path.contains(".") {
210            return Err(io::Error::new(io::ErrorKind::Other, "Invalid file path"));
211        }
212        write!(self.body, "--{}\r\n", self.boundary)?;
213        write!(self.body, "Content-Disposition: form-data; name=\"{}\"; filename=\"{}\"\r\n", name, path)?;
214        write!(self.body, "Content-Type: image/{}\r\n\r\n", path.split_once(".").unwrap().1)?;
215        let mut file = File::open(path)?;
216        file.read_to_end(&mut self.body)?;
217        write!(self.body, "\r\n")?;
218
219
220        //write!(self.body, "--{}--\r\n", self.boundary)?;
221        Ok(())
222    }
223
224    pub fn end_body(&mut self) -> io::Result<()> {
225        write!(self.body, "--{}--\r\n", self.boundary)?;
226        Ok(())
227    }
228
229}
230
231