1#![deny(missing_docs)]
2use std::{collections::HashMap, future::Future};
5
6use data_encoding::BASE64;
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use thiserror::Error;
9
10pub use image;
11
12#[derive(Error, Debug)]
14pub enum ClientError {
15 #[error("invalid url; make sure it starts with http")]
17 InvalidUrl,
18 #[error("Not authenticated")]
20 NotAuthenticated,
21 #[error("error: {message}")]
23 Error {
24 message: String,
26 },
27 #[error("invalid response body (expected {expected:?})")]
29 InvalidResponse {
30 expected: String,
32 },
33 #[error("internal server error")]
35 InternalServerError,
36 #[error("Config not available")]
39 ConfigNotAvailable,
40
41 #[error("reqwest error")]
43 ReqwestError(#[from] reqwest::Error),
44 #[error("serde json error")]
46 SerdeJsonError(#[from] serde_json::Error),
47 #[error("base64 decode error")]
49 Base64DecodeError(#[from] data_encoding::DecodeError),
50 #[error("image error")]
52 ImageError(#[from] image::ImageError),
53 #[error("tokio join error")]
55 TokioJoinError(#[from] tokio::task::JoinError),
56 #[error("chrono parse error")]
58 ChronoParseError {
59 timestamp: String,
61 error: chrono::format::ParseError,
63 },
64}
65impl ClientError {
66 fn invalid_response(expected: &str) -> Self {
67 Self::InvalidResponse {
68 expected: expected.to_string(),
69 }
70 }
71}
72pub type Result<T> = core::result::Result<T, ClientError>;
74
/// How [`Client::new`] should authenticate against the server.
pub enum Authentication<'a> {
    /// Send no credentials.
    None,
    /// `(username, password)` sent on every request as an HTTP Basic
    /// `Authorization` header.
    ApiAuth(&'a str, &'a str),
    /// `(username, password)` posted once as a form to the `login` endpoint;
    /// the resulting session is kept by the client's cookie store.
    GradioAuth(&'a str, &'a str),
}
84
85pub struct Client {
87 client: RequestClient,
88}
89impl Client {
90 pub async fn new(url: &str, authentication: Authentication<'_>) -> Result<Self> {
92 let mut client = RequestClient::new(url).await?;
93
94 match authentication {
95 Authentication::None => {}
96 Authentication::ApiAuth(username, password) => {
97 client.set_authentication_token(
98 BASE64.encode(format!("{username}:{password}").as_bytes()),
99 );
100 }
101 Authentication::GradioAuth(username, password) => {
102 client
103 .post_raw("login")
104 .form(&HashMap::<&str, &str>::from_iter([
105 ("username", username),
106 ("password", password),
107 ]))
108 .send()
109 .await?;
110 }
111 }
112
113 Ok(Self { client })
114 }
115
116 pub fn generate_from_text(
118 &self,
119 request: &TextToImageGenerationRequest,
120 ) -> impl Future<Output = Result<GenerationResult>> {
121 let client = self.client.clone();
122
123 #[derive(Serialize)]
124 struct Request {
125 batch_size: i32,
126 cfg_scale: f32,
127 denoising_strength: f32,
128 eta: f32,
129 height: u32,
130 n_iter: u32,
131 negative_prompt: String,
132 prompt: String,
133 restore_faces: bool,
134 s_churn: f32,
135 s_noise: f32,
136 s_tmax: f32,
137 s_tmin: f32,
138 sampler_index: String,
139 seed: i64,
140 seed_resize_from_h: i32,
141 seed_resize_from_w: i32,
142 steps: u32,
143 styles: Vec<String>,
144 subseed: i64,
145 subseed_strength: f32,
146 tiling: bool,
147 width: u32,
148 override_settings: OverrideSettings,
149 override_settings_restore_afterwards: bool,
150
151 enable_hr: bool,
152 firstphase_height: u32,
153 firstphase_width: u32,
154 }
155
156 let json_request = {
157 let d = Request {
158 enable_hr: false,
159 denoising_strength: 0.0,
160 firstphase_width: 0,
161 firstphase_height: 0,
162 prompt: String::new(),
163 styles: vec![],
164 seed: -1,
165 subseed: -1,
166 subseed_strength: 0.0,
167 seed_resize_from_h: -1,
168 seed_resize_from_w: -1,
169 batch_size: 1,
170 n_iter: 1,
171 steps: 20,
172 cfg_scale: 7.0,
173 width: 512,
174 height: 512,
175 restore_faces: false,
176 tiling: false,
177 negative_prompt: String::new(),
178 eta: 0.0,
179 s_churn: 0.0,
180 s_tmax: 0.0,
181 s_tmin: 0.0,
182 s_noise: 1.0,
183 sampler_index: Sampler::EulerA.to_string(),
184 override_settings: OverrideSettings::default(),
185 override_settings_restore_afterwards: false,
186 };
187 let r = &request;
188 let b = &request.base;
189 Request {
190 batch_size: b.batch_size.map(|i| i as i32).unwrap_or(d.batch_size),
191 cfg_scale: b.cfg_scale.unwrap_or(d.cfg_scale),
192 denoising_strength: b.denoising_strength.unwrap_or(d.denoising_strength),
193 enable_hr: r.enable_hr.unwrap_or(d.enable_hr),
194 eta: b.eta.unwrap_or(d.eta),
195 firstphase_height: r.firstphase_height.unwrap_or(d.firstphase_height),
196 firstphase_width: r.firstphase_width.unwrap_or(d.firstphase_width),
197 height: b.height.unwrap_or(d.height),
198 n_iter: b.batch_count.unwrap_or(d.n_iter),
199 negative_prompt: b.negative_prompt.clone().unwrap_or(d.negative_prompt),
200 prompt: b.prompt.to_owned(),
201 restore_faces: b.restore_faces.unwrap_or(d.restore_faces),
202 s_churn: b.s_churn.unwrap_or(d.s_churn),
203 s_noise: b.s_noise.unwrap_or(d.s_noise),
204 s_tmax: b.s_tmax.unwrap_or(d.s_tmax),
205 s_tmin: b.s_tmin.unwrap_or(d.s_tmin),
206 sampler_index: b.sampler.map(|s| s.to_string()).unwrap_or(d.sampler_index),
207 seed: b.seed.unwrap_or(d.seed),
208 seed_resize_from_h: b
209 .seed_resize_from_h
210 .map(|i| i as i32)
211 .unwrap_or(d.seed_resize_from_h),
212 seed_resize_from_w: b
213 .seed_resize_from_w
214 .map(|i| i as i32)
215 .unwrap_or(d.seed_resize_from_w),
216 steps: b.steps.unwrap_or(d.steps),
217 styles: b.styles.clone().unwrap_or(d.styles),
218 subseed: b.subseed.unwrap_or(d.subseed),
219 subseed_strength: b.subseed_strength.unwrap_or(d.subseed_strength),
220 tiling: b.tiling.unwrap_or(d.tiling),
221 width: b.width.unwrap_or(d.width),
222 override_settings: OverrideSettings::from_model(r.base.model.as_ref()),
223 override_settings_restore_afterwards: false,
224 }
225 };
226
227 async move {
228 let tiling = json_request.tiling;
229 Self::issue_generation_task(
230 client,
231 "sdapi/v1/txt2img".to_string(),
232 json_request,
233 tiling,
234 )
235 .await
236 }
237 }
238
239 pub fn generate_from_image_and_text(
241 &self,
242 request: &ImageToImageGenerationRequest,
243 ) -> impl Future<Output = Result<GenerationResult>> {
244 let client = self.client.clone();
245
246 #[derive(Serialize)]
247 struct Request {
248 batch_size: i32,
249 cfg_scale: f32,
250 denoising_strength: f32,
251 eta: f32,
252 height: u32,
253 n_iter: u32,
254 negative_prompt: String,
255 prompt: String,
256 restore_faces: bool,
257 s_churn: f32,
258 s_noise: f32,
259 s_tmax: f32,
260 s_tmin: f32,
261 sampler_index: String,
262 seed: i64,
263 seed_resize_from_h: i32,
264 seed_resize_from_w: i32,
265 steps: u32,
266 styles: Vec<String>,
267 subseed: i64,
268 subseed_strength: f32,
269 tiling: bool,
270 width: u32,
271 override_settings: OverrideSettings,
272 override_settings_restore_afterwards: bool,
273
274 init_images: Vec<String>,
275 resize_mode: u32,
276 mask: Option<String>,
277 mask_blur: u32,
278 inpainting_fill: u32,
279 inpaint_full_res: bool,
280 inpaint_full_res_padding: u32,
281 inpainting_mask_invert: u32,
282 include_init_images: bool,
283 }
284
285 let json_request = (|| {
286 let d = Request {
287 denoising_strength: 0.0,
288 prompt: String::new(),
289 styles: vec![],
290 seed: -1,
291 subseed: -1,
292 subseed_strength: 0.0,
293 seed_resize_from_h: -1,
294 seed_resize_from_w: -1,
295 batch_size: 1,
296 n_iter: 1,
297 steps: 20,
298 cfg_scale: 7.0,
299 width: 512,
300 height: 512,
301 restore_faces: false,
302 tiling: false,
303 negative_prompt: String::new(),
304 eta: 0.0,
305 s_churn: 0.0,
306 s_tmax: 0.0,
307 s_tmin: 0.0,
308 s_noise: 1.0,
309 sampler_index: Sampler::EulerA.to_string(),
310 override_settings: OverrideSettings::default(),
311 override_settings_restore_afterwards: true,
312
313 init_images: vec![],
314 resize_mode: 0,
315 mask: None,
316 mask_blur: 4,
317 inpainting_fill: 0,
318 inpaint_full_res: true,
319 inpaint_full_res_padding: 0,
320 inpainting_mask_invert: 0,
321 include_init_images: false,
322 };
323 let r = &request;
324 let b = &request.base;
325 Ok::<_, ClientError>(Request {
326 batch_size: b.batch_size.map(|i| i as i32).unwrap_or(d.batch_size),
327 cfg_scale: b.cfg_scale.unwrap_or(d.cfg_scale),
328 denoising_strength: b.denoising_strength.unwrap_or(d.denoising_strength),
329 eta: b.eta.unwrap_or(d.eta),
330 height: b.height.unwrap_or(d.height),
331 n_iter: b.batch_count.unwrap_or(d.n_iter),
332 negative_prompt: b.negative_prompt.clone().unwrap_or(d.negative_prompt),
333 prompt: b.prompt.to_owned(),
334 restore_faces: b.restore_faces.unwrap_or(d.restore_faces),
335 s_churn: b.s_churn.unwrap_or(d.s_churn),
336 s_noise: b.s_noise.unwrap_or(d.s_noise),
337 s_tmax: b.s_tmax.unwrap_or(d.s_tmax),
338 s_tmin: b.s_tmin.unwrap_or(d.s_tmin),
339 sampler_index: b.sampler.map(|s| s.to_string()).unwrap_or(d.sampler_index),
340 seed: b.seed.unwrap_or(d.seed),
341 seed_resize_from_h: b
342 .seed_resize_from_h
343 .map(|i| i as i32)
344 .unwrap_or(d.seed_resize_from_h),
345 seed_resize_from_w: b
346 .seed_resize_from_w
347 .map(|i| i as i32)
348 .unwrap_or(d.seed_resize_from_w),
349 steps: b.steps.unwrap_or(d.steps),
350 styles: b.styles.clone().unwrap_or(d.styles),
351 subseed: b.subseed.unwrap_or(d.subseed),
352 subseed_strength: b.subseed_strength.unwrap_or(d.subseed_strength),
353 tiling: b.tiling.unwrap_or(d.tiling),
354 width: b.width.unwrap_or(d.width),
355 override_settings: OverrideSettings::from_model(r.base.model.as_ref()),
356 override_settings_restore_afterwards: false,
357
358 init_images: r
359 .images
360 .iter()
361 .map(encode_image_to_base64)
362 .collect::<core::result::Result<Vec<_>, _>>()?,
363 resize_mode: r.resize_mode.unwrap_or_default().into(),
364 mask: r.mask.as_ref().map(encode_image_to_base64).transpose()?,
365 mask_blur: r.mask_blur.unwrap_or(d.mask_blur),
366 inpainting_fill: match r.inpainting_fill_mode.unwrap_or_default() {
367 InpaintingFillMode::Fill => 0,
368 InpaintingFillMode::Original => 1,
369 InpaintingFillMode::LatentNoise => 2,
370 InpaintingFillMode::LatentNothing => 3,
371 },
372 inpaint_full_res: r.inpaint_full_resolution,
373 inpaint_full_res_padding: r
374 .inpaint_full_resolution_padding
375 .unwrap_or(d.inpaint_full_res_padding),
376 inpainting_mask_invert: r.inpainting_mask_invert as _,
377 include_init_images: false,
378 })
379 })();
380
381 async move {
382 let json_request = json_request?;
383 let tiling = json_request.tiling;
384 Self::issue_generation_task(
385 client,
386 "sdapi/v1/img2img".to_string(),
387 json_request,
388 tiling,
389 )
390 .await
391 }
392 }
393
394 pub async fn progress(&self) -> Result<GenerationProgress> {
401 #[derive(Deserialize)]
402 struct State {
403 job_timestamp: String,
404 }
405
406 #[derive(Deserialize)]
407 struct Response {
408 eta_relative: f32,
409 progress: f32,
410 current_image: Option<String>,
411 state: Option<State>,
412 }
413
414 let response: Response = self.client.get("sdapi/v1/progress").await?;
415 Ok(GenerationProgress {
416 eta_seconds: response.eta_relative.max(0.0),
417 progress_factor: response.progress.clamp(0.0, 1.0),
418 current_image: response
419 .current_image
420 .map(|i| decode_image_from_base64(&i))
421 .transpose()?,
422 job_timestamp: response
423 .state
424 .map(|s| {
425 chrono::NaiveDateTime::parse_from_str(&s.job_timestamp, "%Y%m%d%H%M%S").map_err(
426 |error| ClientError::ChronoParseError {
427 timestamp: s.job_timestamp.clone(),
428 error,
429 },
430 )
431 })
432 .transpose()?
433 .map(|dt| dt.and_local_timezone(chrono::Local).unwrap()),
434 })
435 }
436
437 pub async fn postprocess(
439 &self,
440 image: &image::DynamicImage,
441 request: &PostprocessRequest,
442 ) -> Result<image::DynamicImage> {
443 #[derive(Serialize)]
444 struct RequestRaw<'a> {
445 image: &'a str,
446 resize_mode: u32,
447 upscaler_1: &'a str,
448 upscaler_2: &'a str,
449 upscaling_resize: f32,
450
451 #[serde(skip_serializing_if = "Option::is_none")]
452 codeformer_visibility: Option<f32>,
453 #[serde(skip_serializing_if = "Option::is_none")]
454 codeformer_weight: Option<f32>,
455 #[serde(skip_serializing_if = "Option::is_none")]
456 extras_upscaler_2_visibility: Option<f32>,
457 #[serde(skip_serializing_if = "Option::is_none")]
458 gfpgan_visibility: Option<f32>,
459 #[serde(skip_serializing_if = "Option::is_none")]
460 upscale_first: Option<bool>,
461 }
462
463 #[derive(Deserialize)]
464 struct ResponseRaw {
465 image: String,
466 }
467
468 let response: ResponseRaw = self
469 .client
470 .post(
471 "sdapi/v1/extra-single-image",
472 &RequestRaw {
473 image: &encode_image_to_base64(image)?,
474 resize_mode: request.resize_mode.into(),
475 upscaler_1: &request.upscaler_1.to_string(),
476 upscaler_2: &request.upscaler_2.to_string(),
477 upscaling_resize: request.scale_factor,
478
479 codeformer_visibility: request.codeformer_visibility,
480 codeformer_weight: request.codeformer_weight,
481 extras_upscaler_2_visibility: request.upscaler_2_visibility,
482 gfpgan_visibility: request.gfpgan_visibility,
483 upscale_first: request.upscale_first,
484 },
485 )
486 .await?;
487
488 decode_image_from_base64(&response.image)
489 }
490
491 pub async fn interrogate(
493 &self,
494 image: &image::DynamicImage,
495 interrogator: Interrogator,
496 ) -> Result<String> {
497 #[derive(Serialize)]
498 struct RequestRaw<'a> {
499 image: &'a str,
500 model: &'a str,
501 }
502
503 #[derive(Deserialize)]
504 struct ResponseRaw {
505 caption: String,
506 }
507
508 let response: ResponseRaw = self
509 .client
510 .post(
511 "sdapi/v1/interrogate",
512 &RequestRaw {
513 image: &encode_image_to_base64(image)?,
514 model: match interrogator {
515 Interrogator::Clip => "clip",
516 Interrogator::DeepDanbooru => "deepdanbooru",
517 },
518 },
519 )
520 .await?;
521
522 Ok(response.caption)
523 }
524
525 pub async fn png_info(&self, image_bytes: &[u8]) -> Result<String> {
527 #[derive(Serialize)]
528 struct RequestRaw<'a> {
529 image: &'a str,
530 }
531
532 #[derive(Deserialize)]
533 struct ResponseRaw {
534 info: String,
535 }
536
537 let response: ResponseRaw = self
538 .client
539 .post(
540 "sdapi/v1/png-info",
541 &RequestRaw {
542 image: &BASE64.encode(image_bytes),
543 },
544 )
545 .await?;
546
547 Ok(response.info)
548 }
549
550 pub async fn embeddings(&self) -> Result<Embeddings> {
552 #[derive(Deserialize)]
553 struct EmbeddingRaw {
554 step: Option<u32>,
555 sd_checkpoint: Option<String>,
556 sd_checkpoint_name: Option<String>,
557 shape: u32,
558 vectors: u32,
559 }
560
561 #[derive(Deserialize)]
562 struct ResponseRaw {
563 loaded: HashMap<String, EmbeddingRaw>,
564 skipped: HashMap<String, EmbeddingRaw>,
565 }
566
567 fn convert_embeddings(hm: HashMap<String, EmbeddingRaw>) -> HashMap<String, Embedding> {
568 hm.into_iter()
569 .map(|(k, v)| {
570 (
571 k,
572 Embedding {
573 step: v.step,
574 sd_checkpoint: v.sd_checkpoint,
575 sd_checkpoint_name: v.sd_checkpoint_name,
576 shape: v.shape,
577 vectors: v.vectors,
578 },
579 )
580 })
581 .collect()
582 }
583
584 let response: ResponseRaw = self.client.get("sdapi/v1/embeddings").await?;
585
586 Ok(Embeddings {
587 loaded: convert_embeddings(response.loaded),
588 skipped: convert_embeddings(response.skipped),
589 })
590 }
591
592 pub async fn options(&self) -> Result<Options> {
594 #[derive(Deserialize)]
595 struct OptionsRaw {
596 s_churn: f32,
597 s_noise: f32,
598 s_tmin: f32,
599 sd_model_checkpoint: String,
600 }
601
602 self.client
603 .get::<OptionsRaw>("sdapi/v1/options")
604 .await
605 .map(|r| Options {
606 model: r.sd_model_checkpoint,
607 s_churn: r.s_churn,
608 s_noise: r.s_noise,
609 s_tmin: r.s_tmin,
610 })
611 }
612
613 pub async fn samplers(&self) -> Result<Vec<Sampler>> {
615 #[derive(Serialize, Deserialize)]
616 struct SamplerRaw {
617 aliases: Vec<String>,
618 name: String,
619 }
620
621 self.client
622 .get::<Vec<SamplerRaw>>("sdapi/v1/samplers")
623 .await
624 .map(|r| {
625 r.into_iter()
626 .filter_map(|s| Sampler::try_from(s.name.as_str()).ok())
627 .collect::<Vec<_>>()
628 })
629 }
630
631 pub async fn upscalers(&self) -> Result<Vec<Upscaler>> {
633 #[derive(Serialize, Deserialize)]
634 struct UpscalerRaw {
635 model_name: Option<String>,
636 model_path: Option<String>,
637 model_url: Option<String>,
638 name: String,
639 }
640
641 self.client
642 .get::<Vec<UpscalerRaw>>("sdapi/v1/upscalers")
643 .await
644 .map(|r| {
645 r.into_iter()
646 .flat_map(|r| {
647 Upscaler::try_from(r.model_name.as_deref().unwrap_or(r.name.as_str())).ok()
648 })
649 .collect()
650 })
651 }
652
653 pub async fn models(&self) -> Result<Vec<Model>> {
655 #[derive(Serialize, Deserialize)]
656 struct ModelRaw {
657 config: Option<String>,
658 filename: String,
659 hash: Option<String>,
660 #[serde(default)]
661 sha256: Option<String>,
662 model_name: String,
663 title: String,
664 }
665
666 self.client
667 .get::<Vec<ModelRaw>>("sdapi/v1/sd-models")
668 .await
669 .map(|r| {
670 r.into_iter()
671 .map(|r| Model {
672 title: r.title,
673 name: r.model_name,
674 hash_short: r.hash,
675 hash_sha256: r.sha256,
676 })
677 .collect()
678 })
679 }
680
681 pub async fn hypernetworks(&self) -> Result<Vec<String>> {
683 #[derive(Serialize, Deserialize)]
684 struct HypernetworkRaw {
685 name: String,
686 path: String,
687 }
688
689 self.client
690 .get::<Vec<HypernetworkRaw>>("sdapi/v1/hypernetworks")
691 .await
692 .map(|r| r.into_iter().map(|s| s.name).collect())
693 }
694
695 pub async fn face_restorers(&self) -> Result<Vec<String>> {
697 #[derive(Serialize, Deserialize)]
698 struct FaceRestorerRaw {
699 cmd_dir: Option<String>,
700 name: String,
701 }
702
703 self.client
704 .get::<Vec<FaceRestorerRaw>>("sdapi/v1/face-restorers")
705 .await
706 .map(|r| r.into_iter().map(|s| s.name).collect())
707 }
708
709 pub async fn realesrgan_models(&self) -> Result<Vec<String>> {
711 #[derive(Serialize, Deserialize)]
712 struct RealEsrganModelRaw {
713 name: String,
714 path: Option<String>,
715 scale: i64,
716 }
717
718 self.client
719 .get::<Vec<RealEsrganModelRaw>>("sdapi/v1/realesrgan-models")
720 .await
721 .map(|r| r.into_iter().map(|s| s.name).collect())
722 }
723
724 pub async fn prompt_styles(&self) -> Result<Vec<PromptStyle>> {
726 #[derive(Serialize, Deserialize)]
727 struct PromptStyleRaw {
728 name: String,
729 negative_prompt: Option<String>,
730 prompt: Option<String>,
731 }
732
733 self.client
734 .get::<Vec<PromptStyleRaw>>("sdapi/v1/prompt-styles")
735 .await
736 .map(|r| {
737 r.into_iter()
738 .map(|r| PromptStyle {
739 name: r.name,
740 prompt: r.prompt,
741 negative_prompt: r.negative_prompt,
742 })
743 .collect()
744 })
745 }
746
747 pub async fn artist_categories(&self) -> Result<Vec<String>> {
749 self.client
750 .get::<Vec<String>>("sdapi/v1/artist-categories")
751 .await
752 }
753
754 pub async fn artists(&self) -> Result<Vec<Artist>> {
756 #[derive(Serialize, Deserialize)]
757 struct ArtistRaw {
758 category: String,
759 name: String,
760 score: f32,
761 }
762
763 self.client
764 .get::<Vec<ArtistRaw>>("sdapi/v1/artists")
765 .await
766 .map(|r| {
767 r.into_iter()
768 .map(|r| Artist {
769 name: r.name,
770 category: r.category,
771 })
772 .collect()
773 })
774 }
775}
776impl Client {
777 async fn issue_generation_task<R: Serialize + Send + Sync + 'static>(
778 client: RequestClient,
779 url: String,
780 request: R,
781 tiling: bool,
782 ) -> Result<GenerationResult> {
783 #[derive(Deserialize)]
784 struct Response {
785 images: Vec<String>,
786 info: String,
787 }
788
789 #[derive(Deserialize)]
790 pub struct InfoResponse {
791 all_negative_prompts: Vec<String>,
792 all_prompts: Vec<String>,
793
794 all_seeds: Vec<i64>,
795 seed_resize_from_h: i32,
796 seed_resize_from_w: i32,
797
798 all_subseeds: Vec<i64>,
799 subseed_strength: f32,
800
801 cfg_scale: f32,
802 clip_skip: usize,
803 denoising_strength: f32,
804 face_restoration_model: Option<String>,
805 is_using_inpainting_conditioning: bool,
806 job_timestamp: String,
807 restore_faces: bool,
808 sd_model_hash: String,
809 styles: Vec<String>,
810
811 width: u32,
812 height: u32,
813 sampler_name: String,
814 steps: u32,
815 }
816
817 let response: Response = client.post(&url, &request).await?;
818 let pngs = response
819 .images
820 .iter()
821 .map(|b64| {
822 BASE64
823 .decode(b64.as_bytes())
824 .map_err(|e| ClientError::from(e))
825 })
826 .collect::<Result<Vec<_>>>()?;
827 let info = {
828 let raw: InfoResponse = serde_json::from_str(&response.info)?;
829 GenerationInfo {
830 prompts: raw.all_prompts,
831 negative_prompts: raw.all_negative_prompts,
832 seeds: raw.all_seeds,
833 subseeds: raw.all_subseeds,
834 subseed_strength: raw.subseed_strength,
835 width: raw.width,
836 height: raw.height,
837 sampler: Sampler::try_from(raw.sampler_name.as_str()).unwrap(),
838 steps: raw.steps,
839 tiling,
840
841 cfg_scale: raw.cfg_scale,
842 denoising_strength: raw.denoising_strength,
843 restore_faces: raw.restore_faces,
844 seed_resize_from_w: Some(raw.seed_resize_from_w)
845 .filter(|v| *v > 0)
846 .map(|v| v as u32),
847 seed_resize_from_h: Some(raw.seed_resize_from_h)
848 .filter(|v| *v > 0)
849 .map(|v| v as u32),
850 styles: raw.styles,
851
852 clip_skip: raw.clip_skip,
853 face_restoration_model: raw.face_restoration_model,
854 is_using_inpainting_conditioning: raw.is_using_inpainting_conditioning,
855 job_timestamp: chrono::NaiveDateTime::parse_from_str(
856 &raw.job_timestamp,
857 "%Y%m%d%H%M%S",
858 )
859 .map_err(|error| ClientError::ChronoParseError {
860 timestamp: raw.job_timestamp.clone(),
861 error,
862 })?
863 .and_local_timezone(chrono::Local)
864 .unwrap(),
865 model_hash: raw.sd_model_hash,
866 }
867 };
868
869 Ok(GenerationResult { pngs, info })
870 }
871}
872
873fn decode_image_from_base64(b64: &str) -> Result<image::DynamicImage> {
874 Ok(image::load_from_memory(&BASE64.decode(b64.as_bytes())?)?)
875}
876
877fn encode_image_to_base64(image: &image::DynamicImage) -> image::ImageResult<String> {
878 let mut bytes: Vec<u8> = Vec::new();
879 let mut cursor = std::io::Cursor::new(&mut bytes);
880 image.write_to(&mut cursor, image::ImageOutputFormat::Png)?;
881 Ok(BASE64.encode(&bytes))
882}
883
884pub struct GenerationProgress {
889 pub eta_seconds: f32,
891 pub progress_factor: f32,
893 pub current_image: Option<image::DynamicImage>,
895 pub job_timestamp: Option<chrono::DateTime<chrono::Local>>,
897}
898impl GenerationProgress {
899 pub fn is_finished(&self) -> bool {
901 self.progress_factor >= 1.0
902 }
903}
904
905#[derive(Default, Clone, Debug)]
911pub struct BaseGenerationRequest {
912 pub prompt: String,
914 pub negative_prompt: Option<String>,
916
917 pub batch_size: Option<u32>,
919 pub batch_count: Option<u32>,
921
922 pub width: Option<u32>,
924 pub height: Option<u32>,
926
927 pub cfg_scale: Option<f32>,
930 pub denoising_strength: Option<f32>,
932 pub eta: Option<f32>,
934 pub sampler: Option<Sampler>,
936 pub steps: Option<u32>,
938 pub model: Option<Model>,
940
941 pub tiling: Option<bool>,
943 pub restore_faces: Option<bool>,
945
946 pub s_churn: Option<f32>,
948 pub s_noise: Option<f32>,
950 pub s_tmax: Option<f32>,
952 pub s_tmin: Option<f32>,
954
955 pub seed: Option<i64>,
958 pub seed_resize_from_w: Option<u32>,
960 pub seed_resize_from_h: Option<u32>,
962 pub subseed: Option<i64>,
964 pub subseed_strength: Option<f32>,
966
967 pub styles: Option<Vec<String>>,
969}
970
971#[derive(Default, Clone, Debug)]
976pub struct TextToImageGenerationRequest {
977 pub base: BaseGenerationRequest,
979
980 pub firstphase_width: Option<u32>,
982 pub firstphase_height: Option<u32>,
984
985 pub enable_hr: Option<bool>,
987}
988
989#[derive(Default, Clone, Debug)]
994pub struct ImageToImageGenerationRequest {
995 pub base: BaseGenerationRequest,
997
998 pub images: Vec<image::DynamicImage>,
1000
1001 pub resize_mode: Option<ResizeMode>,
1003
1004 pub mask: Option<image::DynamicImage>,
1006
1007 pub mask_blur: Option<u32>,
1009
1010 pub inpainting_fill_mode: Option<InpaintingFillMode>,
1012
1013 pub inpaint_full_resolution: bool,
1015
1016 pub inpaint_full_resolution_padding: Option<u32>,
1018
1019 pub inpainting_mask_invert: bool,
1022}
1023
1024pub struct GenerationResult {
1026 pub pngs: Vec<Vec<u8>>,
1031 pub info: GenerationInfo,
1033}
1034impl GenerationResult {
1035 pub fn images(&self) -> Result<Vec<image::DynamicImage>> {
1039 self.pngs
1040 .iter()
1041 .map(|p| image::load_from_memory(&p).map_err(|e| ClientError::from(e)))
1042 .collect()
1043 }
1044}
1045
1046#[derive(Debug, Clone)]
1048pub struct GenerationInfo {
1049 pub prompts: Vec<String>,
1051 pub negative_prompts: Vec<String>,
1053 pub seeds: Vec<i64>,
1055 pub subseeds: Vec<i64>,
1057 pub subseed_strength: f32,
1059 pub width: u32,
1061 pub height: u32,
1063 pub sampler: Sampler,
1065 pub steps: u32,
1067 pub tiling: bool,
1069
1070 pub cfg_scale: f32,
1073 pub denoising_strength: f32,
1075
1076 pub restore_faces: bool,
1078
1079 pub seed_resize_from_w: Option<u32>,
1081 pub seed_resize_from_h: Option<u32>,
1083
1084 pub styles: Vec<String>,
1086
1087 pub clip_skip: usize,
1089 pub face_restoration_model: Option<String>,
1091 pub is_using_inpainting_conditioning: bool,
1093 pub job_timestamp: chrono::DateTime<chrono::Local>,
1095 pub model_hash: String,
1098}
1099
1100#[derive(Default)]
1102pub struct PostprocessRequest {
1103 pub resize_mode: ResizeMode,
1106 pub upscaler_1: Upscaler,
1108 pub upscaler_2: Upscaler,
1110 pub scale_factor: f32,
1112
1113 pub codeformer_visibility: Option<f32>,
1115 pub codeformer_weight: Option<f32>,
1117 pub upscaler_2_visibility: Option<f32>,
1119 pub gfpgan_visibility: Option<f32>,
1121 pub upscale_first: Option<bool>,
1123}
1124
/// Defines a fieldless, `Copy` enum whose variants each carry a human-readable
/// name.
///
/// For `$enum_name` it generates:
/// - `Display` and `Debug` impls that print the friendly name;
/// - a `TryFrom<&str>` impl that parses an exact friendly name back, returning
///   `Err(())` otherwise;
/// - a documented `VALUES` constant listing every variant in declaration order.
///
/// The enum also derives `Hash` so it can be used as a map or set key. A
/// trailing comma after the last variant is accepted.
macro_rules! define_user_friendly_enum {
    ($enum_name:ident, $doc:literal, {$(($name:ident, $friendly_name:literal)),* $(,)?}) => {
        #[doc = $doc]
        #[derive(Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
        pub enum $enum_name {
            $(
                #[doc = $friendly_name]
                $name
            ),*
        }
        impl std::fmt::Display for $enum_name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", match self {
                    $(
                        Self::$name => $friendly_name
                    ),*
                })
            }
        }
        impl std::fmt::Debug for $enum_name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                std::fmt::Display::fmt(self, f)
            }
        }
        impl TryFrom<&str> for $enum_name {
            type Error = ();

            fn try_from(s: &str) -> core::result::Result<Self, ()> {
                match s {
                    $(
                        $friendly_name => Ok(Self::$name),
                    )*
                    _ => Err(()),
                }
            }
        }
        impl $enum_name {
            /// Every variant of this enum, in declaration order.
            pub const VALUES: &[Self] = &[
                $(Self::$name),*
            ];
        }
    };
}
1169
1170define_user_friendly_enum!(
1171 Sampler,
1172 "The sampler to use for the generation.",
1173 {
1174 (EulerA, "Euler a"),
1175 (Euler, "Euler"),
1176 (Lms, "LMS"),
1177 (Heun, "Heun"),
1178 (Dpm2, "DPM2"),
1179 (Dpm2A, "DPM2 a"),
1180 (DpmPP2SA, "DPM++ 2S a"),
1181 (DpmPP2M, "DPM++ 2M"),
1182 (DpmPPSDE, "DPM++ SDE"),
1183 (DpmFast, "DPM fast"),
1184 (DpmAdaptive, "DPM adaptive"),
1185 (LmsKarras, "LMS Karras"),
1186 (Dpm2Karras, "DPM2 Karras"),
1187 (Dpm2AKarras, "DPM2 a Karras"),
1188 (DpmPP2SAKarras, "DPM++ 2S a Karras"),
1189 (DpmPP2MKarras, "DPM++ 2M Karras"),
1190 (DpmPPSDEKarras, "DPM++ SDE Karras"),
1191 (Ddim, "DDIM"),
1192 (Plms, "PLMS")
1193 }
1194);
1195
1196define_user_friendly_enum!(
1197 Interrogator,
1198 "Supported interrogators for [Client::interrogate]",
1199 {
1200 (Clip, "CLIP"),
1201 (DeepDanbooru, "DeepDanbooru")
1202 }
1203);
1204
1205define_user_friendly_enum!(
1206 ResizeMode,
1207 "How to resize the image for image-to-image generation",
1208 {
1209 (Resize, "Just resize"),
1210 (CropAndResize, "Crop and resize"),
1211 (ResizeAndFill, "Resize and fill")
1212 }
1213);
1214impl Default for ResizeMode {
1215 fn default() -> Self {
1216 Self::Resize
1217 }
1218}
1219impl From<ResizeMode> for u32 {
1220 fn from(mode: ResizeMode) -> Self {
1221 match mode {
1222 ResizeMode::Resize => 0,
1223 ResizeMode::CropAndResize => 1,
1224 ResizeMode::ResizeAndFill => 2,
1225 }
1226 }
1227}
1228
1229define_user_friendly_enum!(
1230 InpaintingFillMode,
1231 "How the area to be inpainted will be initialized",
1232 {
1233 (Fill, "Fill"),
1234 (Original, "Original"),
1235 (LatentNoise, "Latent noise"),
1236 (LatentNothing, "Latent nothing")
1237 }
1238);
1239impl Default for InpaintingFillMode {
1240 fn default() -> Self {
1241 Self::Original
1242 }
1243}
1244
1245define_user_friendly_enum!(
1246 Upscaler,
1247 "Upscaler",
1248 {
1249 (None, "None"),
1250 (Lanczos, "Lanczos"),
1251 (Nearest, "Nearest"),
1252 (Ldsr, "LDSR"),
1253 (ScuNetGan, "ScuNET GAN"),
1254 (ScuNetPSNR, "ScuNET PSNR"),
1255 (SwinIR4x, "SwinIR 4x"),
1256 (ESRGAN4x, "ESRGAN_4x")
1257 }
1258);
1259impl Default for Upscaler {
1260 fn default() -> Self {
1261 Self::None
1262 }
1263}
1264
/// A subset of the server options, as returned by [`Client::options`].
#[derive(Debug, Clone)]
pub struct Options {
    /// The currently selected model checkpoint (`sd_model_checkpoint`).
    pub model: String,

    /// The server's `s_churn` setting.
    pub s_churn: f32,
    /// The server's `s_noise` setting.
    pub s_noise: f32,
    /// The server's `s_tmin` setting.
    pub s_tmin: f32,
}
1278
/// A Stable Diffusion checkpoint known to the server.
#[derive(Debug, Clone)]
pub struct Model {
    /// Full title; this is what is sent as the `sd_model_checkpoint` override.
    pub title: String,
    /// Model name as reported by the server (`model_name`).
    pub name: String,
    /// Short hash, if the server reported one.
    pub hash_short: Option<String>,
    /// SHA-256 hash, if the server reported one.
    pub hash_sha256: Option<String>,
}
1291
/// A saved prompt style.
#[derive(Debug, Clone)]
pub struct PromptStyle {
    /// Name of the style.
    pub name: String,
    /// Prompt text contributed by the style, if any.
    pub prompt: Option<String>,
    /// Negative prompt text contributed by the style, if any.
    pub negative_prompt: Option<String>,
}
1302
/// An artist entry reported by the server.
#[derive(Debug, Clone)]
pub struct Artist {
    /// Name of the artist.
    pub name: String,
    /// Category the artist belongs to.
    pub category: String,
}
1311
/// Metadata for a textual-inversion embedding, as reported by the server.
#[derive(Debug, Clone)]
pub struct Embedding {
    /// Training step the embedding was saved at, if known.
    pub step: Option<u32>,
    /// Checkpoint hash the embedding was trained with, if known.
    pub sd_checkpoint: Option<String>,
    /// Checkpoint name the embedding was trained with, if known.
    pub sd_checkpoint_name: Option<String>,
    /// The server-reported `shape` value.
    pub shape: u32,
    /// The server-reported number of vectors.
    pub vectors: u32,
}
1328
1329#[derive(Debug, Clone)]
1331pub struct Embeddings {
1332 pub loaded: HashMap<String, Embedding>,
1334 pub skipped: HashMap<String, Embedding>,
1336}
1337impl Embeddings {
1338 pub fn all(&self) -> impl Iterator<Item = (&String, &Embedding)> {
1340 self.loaded.iter().chain(self.skipped.iter())
1341 }
1342}
1343
1344#[allow(dead_code)]
1345mod config {
1346 use crate::{ClientError, RequestClient, Result};
1347 use std::collections::HashMap;
1348
1349 #[derive(Debug)]
1350 enum ConfigComponent {
1351 Dropdown { choices: Vec<String>, id: String },
1352 Radio { choices: Vec<String>, id: String },
1353 }
1354
1355 #[derive(Debug)]
1359 struct Config(HashMap<u32, ConfigComponent>);
1360 impl Config {
1361 async fn new(client: &RequestClient) -> Result<Self> {
1362 Ok(Self(
1363 client
1364 .get::<HashMap<String, serde_json::Value>>("config")
1365 .await?
1366 .get("components")
1367 .ok_or_else(|| ClientError::invalid_response("components"))?
1368 .as_array()
1369 .ok_or_else(|| ClientError::invalid_response("components to be an array"))?
1370 .iter()
1371 .filter_map(|v| v.as_object())
1372 .filter_map(|o| {
1373 let id = o.get("id")?.as_u64()? as u32;
1374 let comp_type = o.get("type")?.as_str()?;
1375 let props = o.get("props")?.as_object()?;
1376 match comp_type {
1377 "dropdown" => Some((
1378 id,
1379 ConfigComponent::Dropdown {
1380 choices: extract_string_array(props.get("choices")?)?,
1381 id: props.get("elem_id")?.as_str()?.to_owned(),
1382 },
1383 )),
1384 "radio" => Some((
1385 id,
1386 ConfigComponent::Radio {
1387 choices: extract_string_array(props.get("choices")?)?,
1388 id: props.get("elem_id")?.as_str()?.to_owned(),
1389 },
1390 )),
1391 _ => None,
1392 }
1393 })
1394 .collect(),
1395 ))
1396 }
1397
1398 fn embeddings(&self) -> Result<Vec<String>> {
1400 self.get_dropdown("train_embedding")
1401 }
1402
1403 fn values(&self) -> impl Iterator<Item = &ConfigComponent> {
1404 self.0.values()
1405 }
1406 fn get_dropdown(&self, target_id: &str) -> Result<Vec<String>> {
1407 self.values()
1408 .find_map(|comp| match comp {
1409 ConfigComponent::Dropdown { id, choices, .. } if id == target_id => {
1410 Some(choices.clone())
1411 }
1412 _ => None,
1413 })
1414 .ok_or_else(|| {
1415 ClientError::invalid_response(&format!("no {target_id} dropdown component"))
1416 })
1417 }
1418 fn get_radio(&self, target_id: &str) -> Result<Vec<String>> {
1419 self.values()
1420 .find_map(|comp| match comp {
1421 ConfigComponent::Radio { id, choices, .. } if id == target_id => {
1422 Some(choices.clone())
1423 }
1424 _ => None,
1425 })
1426 .ok_or_else(|| {
1427 ClientError::invalid_response(&format!("no {target_id} radio component"))
1428 })
1429 }
1430 }
1431
1432 fn extract_string_array(value: &serde_json::Value) -> Option<Vec<String>> {
1433 Some(
1434 value
1435 .as_array()?
1436 .iter()
1437 .flat_map(|s| Some(s.as_str()?.to_owned()))
1438 .collect(),
1439 )
1440 }
1441}
1442
1443#[derive(Clone)]
1444struct RequestClient {
1445 url: String,
1446 client: reqwest::Client,
1447 authentication_token: Option<String>,
1448}
1449impl RequestClient {
1450 async fn new(url: &str) -> Result<Self> {
1451 if !url.starts_with("http") {
1452 return Err(ClientError::InvalidUrl);
1453 }
1454
1455 let url = url.strip_suffix('/').unwrap_or(url).to_owned();
1456 let client = reqwest::ClientBuilder::new().cookie_store(true).build()?;
1457
1458 Ok(Self {
1459 url,
1460 client,
1461 authentication_token: None,
1462 })
1463 }
1464
1465 fn set_authentication_token(&mut self, token: String) {
1466 self.authentication_token = Some(format!("Basic {token}"));
1467 }
1468
1469 fn url(&self, endpoint: &str) -> String {
1470 format!("{}/{}", self.url, endpoint)
1471 }
1472
1473 fn check_for_authentication<R: DeserializeOwned>(body: String) -> Result<R> {
1474 if body.trim() == "Internal Server Error" {
1475 return Err(ClientError::InternalServerError);
1476 }
1477
1478 match serde_json::from_str::<HashMap<String, serde_json::Value>>(&body) {
1479 Ok(json_body) => match json_body.get("detail") {
1480 Some(serde_json::Value::String(message)) => {
1481 if message == "Not authenticated" {
1482 Err(ClientError::NotAuthenticated)
1483 } else {
1484 Err(ClientError::Error {
1485 message: message.clone(),
1486 })
1487 }
1488 }
1489 Some(other_error) => Err(ClientError::Error {
1490 message: other_error.to_string(),
1491 }),
1492 _ => Ok(serde_json::from_str(&body)?),
1493 },
1494 Err(_) => Ok(serde_json::from_str(&body)?),
1495 }
1496 }
1497
1498 async fn send<R: DeserializeOwned>(&self, builder: reqwest::RequestBuilder) -> Result<R> {
1499 let builder = if let Some(token) = &self.authentication_token {
1500 builder.header("Authorization", token)
1501 } else {
1502 builder
1503 };
1504 Self::check_for_authentication(builder.send().await?.text().await?)
1505 }
1506 async fn get<R: DeserializeOwned>(&self, endpoint: &str) -> Result<R> {
1507 self.send(self.client.get(self.url(endpoint))).await
1508 }
1509 async fn post<R: DeserializeOwned, T: Serialize>(&self, endpoint: &str, body: &T) -> Result<R> {
1510 self.send(self.client.post(self.url(endpoint)).json(body))
1511 .await
1512 }
1513 fn post_raw(&self, endpoint: &str) -> reqwest::RequestBuilder {
1514 self.client.post(self.url(endpoint))
1515 }
1516}
1517
1518#[derive(Serialize, Deserialize, Default)]
1519struct OverrideSettings {
1520 #[serde(skip_serializing_if = "Option::is_none")]
1521 sd_model_checkpoint: Option<String>,
1522}
1523impl OverrideSettings {
1524 fn from_model(model: Option<&Model>) -> Self {
1525 Self {
1526 sd_model_checkpoint: model.map(|m| m.title.clone()),
1527 }
1528 }
1529}