// stable_diffusion_a1111_webui_client/lib.rs

1#![deny(missing_docs)]
2//! This is a client for the Automatic1111 stable-diffusion web UI.
3
4use std::{collections::HashMap, future::Future};
5
6use data_encoding::BASE64;
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use thiserror::Error;
9
10pub use image;
11
12/// All potential errors that the client can produce.
13#[derive(Error, Debug)]
14pub enum ClientError {
15    /// The URL passed to the `Client` was invalid.
16    #[error("invalid url; make sure it starts with http")]
17    InvalidUrl,
18    /// The credentials for the `Client` are wrong or missing.
19    #[error("Not authenticated")]
20    NotAuthenticated,
21    /// There was an arbitrary error while servicing the request.
22    #[error("error: {message}")]
23    Error {
24        /// The message associated with the error.
25        message: String,
26    },
27    /// The response body was missing some data.
28    #[error("invalid response body (expected {expected:?})")]
29    InvalidResponse {
30        /// The data that was expected to be there, but wasn't.
31        expected: String,
32    },
33    /// The UI experienced an internal server error.
34    #[error("internal server error")]
35    InternalServerError,
36    /// The operation requires access to the SDUI's config, which is not
37    /// accessible through UI auth alone
38    #[error("Config not available")]
39    ConfigNotAvailable,
40
41    /// Error returned by `reqwest`.
42    #[error("reqwest error")]
43    ReqwestError(#[from] reqwest::Error),
44    /// Error returned by `serde_json`.
45    #[error("serde json error")]
46    SerdeJsonError(#[from] serde_json::Error),
47    /// Error returned by `base64`.
48    #[error("base64 decode error")]
49    Base64DecodeError(#[from] data_encoding::DecodeError),
50    /// Error returned by `image`.
51    #[error("image error")]
52    ImageError(#[from] image::ImageError),
53    /// Error returned by `tokio` due to a failure to join the task.
54    #[error("tokio join error")]
55    TokioJoinError(#[from] tokio::task::JoinError),
56    /// Error returned by `chrono` when trying to parse a datetime.
57    #[error("chrono parse error")]
58    ChronoParseError {
59        /// The timestamp that failed to parse
60        timestamp: String,
61        /// The error returned by `chrono`
62        error: chrono::format::ParseError,
63    },
64}
65impl ClientError {
66    fn invalid_response(expected: &str) -> Self {
67        Self::InvalidResponse {
68            expected: expected.to_string(),
69        }
70    }
71}
/// Result type for the `Client`.
///
/// Shorthand for `core::result::Result` with the error type fixed to
/// [`ClientError`].
pub type Result<T> = core::result::Result<T, ClientError>;
74
/// The type of authentication to use with a [Client].
pub enum Authentication<'a> {
    /// The server is unauthenticated
    None,
    /// The server is using API authentication (Authorization header)
    // Fields are (username, password); sent as HTTP Basic auth
    // (base64 of "username:password") on every request.
    ApiAuth(&'a str, &'a str),
    /// The server is using Gradio authentication (/login)
    // Fields are (username, password); submitted once as a form POST to
    // the `/login` endpoint during `Client::new`.
    GradioAuth(&'a str, &'a str),
}
84
/// Interface to the web UI.
pub struct Client {
    // Cloneable HTTP wrapper; generation methods clone it so the futures
    // they return do not borrow `self`.
    client: RequestClient,
}
89impl Client {
90    /// Creates a new `Client` and authenticates to the web UI.
91    pub async fn new(url: &str, authentication: Authentication<'_>) -> Result<Self> {
92        let mut client = RequestClient::new(url).await?;
93
94        match authentication {
95            Authentication::None => {}
96            Authentication::ApiAuth(username, password) => {
97                client.set_authentication_token(
98                    BASE64.encode(format!("{username}:{password}").as_bytes()),
99                );
100            }
101            Authentication::GradioAuth(username, password) => {
102                client
103                    .post_raw("login")
104                    .form(&HashMap::<&str, &str>::from_iter([
105                        ("username", username),
106                        ("password", password),
107                    ]))
108                    .send()
109                    .await?;
110            }
111        }
112
113        Ok(Self { client })
114    }
115
    /// Generates an image from the provided `request`, which contains a prompt.
    ///
    /// The request body is built eagerly and the inner `RequestClient` is
    /// cloned before the `async` block, so the returned future does not
    /// borrow `self` or `request` and can be stored or spawned freely.
    pub fn generate_from_text(
        &self,
        request: &TextToImageGenerationRequest,
    ) -> impl Future<Output = Result<GenerationResult>> {
        let client = self.client.clone();

        // Wire format for `sdapi/v1/txt2img`; field names must match the
        // WebUI API schema exactly.
        #[derive(Serialize)]
        struct Request {
            batch_size: i32,
            cfg_scale: f32,
            denoising_strength: f32,
            eta: f32,
            height: u32,
            n_iter: u32,
            negative_prompt: String,
            prompt: String,
            restore_faces: bool,
            s_churn: f32,
            s_noise: f32,
            s_tmax: f32,
            s_tmin: f32,
            sampler_index: String,
            seed: i64,
            seed_resize_from_h: i32,
            seed_resize_from_w: i32,
            steps: u32,
            styles: Vec<String>,
            subseed: i64,
            subseed_strength: f32,
            tiling: bool,
            width: u32,
            override_settings: OverrideSettings,
            override_settings_restore_afterwards: bool,

            // txt2img-only fields (high-res fix).
            enable_hr: bool,
            firstphase_height: u32,
            firstphase_width: u32,
        }

        let json_request = {
            // `d` holds the default for every field; each optional field of
            // the caller's request falls back to it in the merge below.
            // NOTE(review): seed/subseed default to -1, presumably "let the
            // server randomize" — confirm against the WebUI API docs.
            let d = Request {
                enable_hr: false,
                denoising_strength: 0.0,
                firstphase_width: 0,
                firstphase_height: 0,
                prompt: String::new(),
                styles: vec![],
                seed: -1,
                subseed: -1,
                subseed_strength: 0.0,
                seed_resize_from_h: -1,
                seed_resize_from_w: -1,
                batch_size: 1,
                n_iter: 1,
                steps: 20,
                cfg_scale: 7.0,
                width: 512,
                height: 512,
                restore_faces: false,
                tiling: false,
                negative_prompt: String::new(),
                eta: 0.0,
                s_churn: 0.0,
                s_tmax: 0.0,
                s_tmin: 0.0,
                s_noise: 1.0,
                sampler_index: Sampler::EulerA.to_string(),
                override_settings: OverrideSettings::default(),
                override_settings_restore_afterwards: false,
            };
            // Shorthands: `r` is the full request, `b` its shared base
            // parameters (common to txt2img and img2img).
            let r = &request;
            let b = &request.base;
            Request {
                batch_size: b.batch_size.map(|i| i as i32).unwrap_or(d.batch_size),
                cfg_scale: b.cfg_scale.unwrap_or(d.cfg_scale),
                denoising_strength: b.denoising_strength.unwrap_or(d.denoising_strength),
                enable_hr: r.enable_hr.unwrap_or(d.enable_hr),
                eta: b.eta.unwrap_or(d.eta),
                firstphase_height: r.firstphase_height.unwrap_or(d.firstphase_height),
                firstphase_width: r.firstphase_width.unwrap_or(d.firstphase_width),
                height: b.height.unwrap_or(d.height),
                n_iter: b.batch_count.unwrap_or(d.n_iter),
                negative_prompt: b.negative_prompt.clone().unwrap_or(d.negative_prompt),
                prompt: b.prompt.to_owned(),
                restore_faces: b.restore_faces.unwrap_or(d.restore_faces),
                s_churn: b.s_churn.unwrap_or(d.s_churn),
                s_noise: b.s_noise.unwrap_or(d.s_noise),
                s_tmax: b.s_tmax.unwrap_or(d.s_tmax),
                s_tmin: b.s_tmin.unwrap_or(d.s_tmin),
                sampler_index: b.sampler.map(|s| s.to_string()).unwrap_or(d.sampler_index),
                seed: b.seed.unwrap_or(d.seed),
                seed_resize_from_h: b
                    .seed_resize_from_h
                    .map(|i| i as i32)
                    .unwrap_or(d.seed_resize_from_h),
                seed_resize_from_w: b
                    .seed_resize_from_w
                    .map(|i| i as i32)
                    .unwrap_or(d.seed_resize_from_w),
                steps: b.steps.unwrap_or(d.steps),
                styles: b.styles.clone().unwrap_or(d.styles),
                subseed: b.subseed.unwrap_or(d.subseed),
                subseed_strength: b.subseed_strength.unwrap_or(d.subseed_strength),
                tiling: b.tiling.unwrap_or(d.tiling),
                width: b.width.unwrap_or(d.width),
                // Model selection is applied via an options override rather
                // than a request field.
                override_settings: OverrideSettings::from_model(r.base.model.as_ref()),
                override_settings_restore_afterwards: false,
            }
        };

        async move {
            // `tiling` is echoed into the result because the API response
            // does not report it back (see issue_generation_task).
            let tiling = json_request.tiling;
            Self::issue_generation_task(
                client,
                "sdapi/v1/txt2img".to_string(),
                json_request,
                tiling,
            )
            .await
        }
    }
238
    /// Generates an image from the provided `request`, which contains both an image and a prompt.
    ///
    /// Like [Self::generate_from_text], the request body is built eagerly so
    /// the returned future does not borrow `self` or `request`; any encoding
    /// error is captured and surfaced when the future is awaited.
    pub fn generate_from_image_and_text(
        &self,
        request: &ImageToImageGenerationRequest,
    ) -> impl Future<Output = Result<GenerationResult>> {
        let client = self.client.clone();

        // Wire format for `sdapi/v1/img2img`; field names must match the
        // WebUI API schema exactly.
        #[derive(Serialize)]
        struct Request {
            batch_size: i32,
            cfg_scale: f32,
            denoising_strength: f32,
            eta: f32,
            height: u32,
            n_iter: u32,
            negative_prompt: String,
            prompt: String,
            restore_faces: bool,
            s_churn: f32,
            s_noise: f32,
            s_tmax: f32,
            s_tmin: f32,
            sampler_index: String,
            seed: i64,
            seed_resize_from_h: i32,
            seed_resize_from_w: i32,
            steps: u32,
            styles: Vec<String>,
            subseed: i64,
            subseed_strength: f32,
            tiling: bool,
            width: u32,
            override_settings: OverrideSettings,
            override_settings_restore_afterwards: bool,

            // img2img-only fields (init images, masking / inpainting).
            init_images: Vec<String>,
            resize_mode: u32,
            mask: Option<String>,
            mask_blur: u32,
            inpainting_fill: u32,
            inpaint_full_res: bool,
            inpaint_full_res_padding: u32,
            inpainting_mask_invert: u32,
            include_init_images: bool,
        }

        // Immediately-invoked closure so `?` can be used while building the
        // body (image encoding is fallible).
        let json_request = (|| {
            // `d` holds the default for every field used as a fallback below.
            let d = Request {
                denoising_strength: 0.0,
                prompt: String::new(),
                styles: vec![],
                seed: -1,
                subseed: -1,
                subseed_strength: 0.0,
                seed_resize_from_h: -1,
                seed_resize_from_w: -1,
                batch_size: 1,
                n_iter: 1,
                steps: 20,
                cfg_scale: 7.0,
                width: 512,
                height: 512,
                restore_faces: false,
                tiling: false,
                negative_prompt: String::new(),
                eta: 0.0,
                s_churn: 0.0,
                s_tmax: 0.0,
                s_tmin: 0.0,
                s_noise: 1.0,
                sampler_index: Sampler::EulerA.to_string(),
                override_settings: OverrideSettings::default(),
                // NOTE(review): this default is `true`, but the built request
                // below hard-codes `false` (and txt2img defaults to `false`)
                // — confirm which is intended.
                override_settings_restore_afterwards: true,

                init_images: vec![],
                resize_mode: 0,
                mask: None,
                mask_blur: 4,
                inpainting_fill: 0,
                inpaint_full_res: true,
                inpaint_full_res_padding: 0,
                inpainting_mask_invert: 0,
                include_init_images: false,
            };
            // Shorthands: `r` is the full request, `b` its shared base
            // parameters (common to txt2img and img2img).
            let r = &request;
            let b = &request.base;
            Ok::<_, ClientError>(Request {
                batch_size: b.batch_size.map(|i| i as i32).unwrap_or(d.batch_size),
                cfg_scale: b.cfg_scale.unwrap_or(d.cfg_scale),
                denoising_strength: b.denoising_strength.unwrap_or(d.denoising_strength),
                eta: b.eta.unwrap_or(d.eta),
                height: b.height.unwrap_or(d.height),
                n_iter: b.batch_count.unwrap_or(d.n_iter),
                negative_prompt: b.negative_prompt.clone().unwrap_or(d.negative_prompt),
                prompt: b.prompt.to_owned(),
                restore_faces: b.restore_faces.unwrap_or(d.restore_faces),
                s_churn: b.s_churn.unwrap_or(d.s_churn),
                s_noise: b.s_noise.unwrap_or(d.s_noise),
                s_tmax: b.s_tmax.unwrap_or(d.s_tmax),
                s_tmin: b.s_tmin.unwrap_or(d.s_tmin),
                sampler_index: b.sampler.map(|s| s.to_string()).unwrap_or(d.sampler_index),
                seed: b.seed.unwrap_or(d.seed),
                seed_resize_from_h: b
                    .seed_resize_from_h
                    .map(|i| i as i32)
                    .unwrap_or(d.seed_resize_from_h),
                seed_resize_from_w: b
                    .seed_resize_from_w
                    .map(|i| i as i32)
                    .unwrap_or(d.seed_resize_from_w),
                steps: b.steps.unwrap_or(d.steps),
                styles: b.styles.clone().unwrap_or(d.styles),
                subseed: b.subseed.unwrap_or(d.subseed),
                subseed_strength: b.subseed_strength.unwrap_or(d.subseed_strength),
                tiling: b.tiling.unwrap_or(d.tiling),
                width: b.width.unwrap_or(d.width),
                // Model selection is applied via an options override rather
                // than a request field.
                override_settings: OverrideSettings::from_model(r.base.model.as_ref()),
                override_settings_restore_afterwards: false,

                // All source images (and the optional mask) are sent as
                // base64-encoded PNGs; encoding failures abort the build.
                init_images: r
                    .images
                    .iter()
                    .map(encode_image_to_base64)
                    .collect::<core::result::Result<Vec<_>, _>>()?,
                resize_mode: r.resize_mode.unwrap_or_default().into(),
                mask: r.mask.as_ref().map(encode_image_to_base64).transpose()?,
                mask_blur: r.mask_blur.unwrap_or(d.mask_blur),
                // Map the typed fill mode onto the API's integer encoding.
                inpainting_fill: match r.inpainting_fill_mode.unwrap_or_default() {
                    InpaintingFillMode::Fill => 0,
                    InpaintingFillMode::Original => 1,
                    InpaintingFillMode::LatentNoise => 2,
                    InpaintingFillMode::LatentNothing => 3,
                },
                inpaint_full_res: r.inpaint_full_resolution,
                inpaint_full_res_padding: r
                    .inpaint_full_resolution_padding
                    .unwrap_or(d.inpaint_full_res_padding),
                inpainting_mask_invert: r.inpainting_mask_invert as _,
                include_init_images: false,
            })
        })();

        async move {
            // Surface any build/encoding error at await time.
            let json_request = json_request?;
            // `tiling` is echoed into the result because the API response
            // does not report it back (see issue_generation_task).
            let tiling = json_request.tiling;
            Self::issue_generation_task(
                client,
                "sdapi/v1/img2img".to_string(),
                json_request,
                tiling,
            )
            .await
        }
    }
393
394    /// Retrieves the progress of the current generation.
395    ///
396    /// Note that:
397    ///     - this does not disambiguate between generations (the WebUI does not expose details on its queue)
398    ///     - this will return 0% if there is no generation underway, which can be confusing after a
399    ///       generation finishes
400    pub async fn progress(&self) -> Result<GenerationProgress> {
401        #[derive(Deserialize)]
402        struct State {
403            job_timestamp: String,
404        }
405
406        #[derive(Deserialize)]
407        struct Response {
408            eta_relative: f32,
409            progress: f32,
410            current_image: Option<String>,
411            state: Option<State>,
412        }
413
414        let response: Response = self.client.get("sdapi/v1/progress").await?;
415        Ok(GenerationProgress {
416            eta_seconds: response.eta_relative.max(0.0),
417            progress_factor: response.progress.clamp(0.0, 1.0),
418            current_image: response
419                .current_image
420                .map(|i| decode_image_from_base64(&i))
421                .transpose()?,
422            job_timestamp: response
423                .state
424                .map(|s| {
425                    chrono::NaiveDateTime::parse_from_str(&s.job_timestamp, "%Y%m%d%H%M%S").map_err(
426                        |error| ClientError::ChronoParseError {
427                            timestamp: s.job_timestamp.clone(),
428                            error,
429                        },
430                    )
431                })
432                .transpose()?
433                .map(|dt| dt.and_local_timezone(chrono::Local).unwrap()),
434        })
435    }
436
437    /// Upscales the given `image` and applies additional (optional) post-processing.
438    pub async fn postprocess(
439        &self,
440        image: &image::DynamicImage,
441        request: &PostprocessRequest,
442    ) -> Result<image::DynamicImage> {
443        #[derive(Serialize)]
444        struct RequestRaw<'a> {
445            image: &'a str,
446            resize_mode: u32,
447            upscaler_1: &'a str,
448            upscaler_2: &'a str,
449            upscaling_resize: f32,
450
451            #[serde(skip_serializing_if = "Option::is_none")]
452            codeformer_visibility: Option<f32>,
453            #[serde(skip_serializing_if = "Option::is_none")]
454            codeformer_weight: Option<f32>,
455            #[serde(skip_serializing_if = "Option::is_none")]
456            extras_upscaler_2_visibility: Option<f32>,
457            #[serde(skip_serializing_if = "Option::is_none")]
458            gfpgan_visibility: Option<f32>,
459            #[serde(skip_serializing_if = "Option::is_none")]
460            upscale_first: Option<bool>,
461        }
462
463        #[derive(Deserialize)]
464        struct ResponseRaw {
465            image: String,
466        }
467
468        let response: ResponseRaw = self
469            .client
470            .post(
471                "sdapi/v1/extra-single-image",
472                &RequestRaw {
473                    image: &encode_image_to_base64(image)?,
474                    resize_mode: request.resize_mode.into(),
475                    upscaler_1: &request.upscaler_1.to_string(),
476                    upscaler_2: &request.upscaler_2.to_string(),
477                    upscaling_resize: request.scale_factor,
478
479                    codeformer_visibility: request.codeformer_visibility,
480                    codeformer_weight: request.codeformer_weight,
481                    extras_upscaler_2_visibility: request.upscaler_2_visibility,
482                    gfpgan_visibility: request.gfpgan_visibility,
483                    upscale_first: request.upscale_first,
484                },
485            )
486            .await?;
487
488        decode_image_from_base64(&response.image)
489    }
490
491    /// Interrogates the given `image` with the `interrogator` to generate a caption.
492    pub async fn interrogate(
493        &self,
494        image: &image::DynamicImage,
495        interrogator: Interrogator,
496    ) -> Result<String> {
497        #[derive(Serialize)]
498        struct RequestRaw<'a> {
499            image: &'a str,
500            model: &'a str,
501        }
502
503        #[derive(Deserialize)]
504        struct ResponseRaw {
505            caption: String,
506        }
507
508        let response: ResponseRaw = self
509            .client
510            .post(
511                "sdapi/v1/interrogate",
512                &RequestRaw {
513                    image: &encode_image_to_base64(image)?,
514                    model: match interrogator {
515                        Interrogator::Clip => "clip",
516                        Interrogator::DeepDanbooru => "deepdanbooru",
517                    },
518                },
519            )
520            .await?;
521
522        Ok(response.caption)
523    }
524
525    /// Gets the PNG info for the image (assumed to be valid PNG)
526    pub async fn png_info(&self, image_bytes: &[u8]) -> Result<String> {
527        #[derive(Serialize)]
528        struct RequestRaw<'a> {
529            image: &'a str,
530        }
531
532        #[derive(Deserialize)]
533        struct ResponseRaw {
534            info: String,
535        }
536
537        let response: ResponseRaw = self
538            .client
539            .post(
540                "sdapi/v1/png-info",
541                &RequestRaw {
542                    image: &BASE64.encode(image_bytes),
543                },
544            )
545            .await?;
546
547        Ok(response.info)
548    }
549
550    /// Get the embeddings
551    pub async fn embeddings(&self) -> Result<Embeddings> {
552        #[derive(Deserialize)]
553        struct EmbeddingRaw {
554            step: Option<u32>,
555            sd_checkpoint: Option<String>,
556            sd_checkpoint_name: Option<String>,
557            shape: u32,
558            vectors: u32,
559        }
560
561        #[derive(Deserialize)]
562        struct ResponseRaw {
563            loaded: HashMap<String, EmbeddingRaw>,
564            skipped: HashMap<String, EmbeddingRaw>,
565        }
566
567        fn convert_embeddings(hm: HashMap<String, EmbeddingRaw>) -> HashMap<String, Embedding> {
568            hm.into_iter()
569                .map(|(k, v)| {
570                    (
571                        k,
572                        Embedding {
573                            step: v.step,
574                            sd_checkpoint: v.sd_checkpoint,
575                            sd_checkpoint_name: v.sd_checkpoint_name,
576                            shape: v.shape,
577                            vectors: v.vectors,
578                        },
579                    )
580                })
581                .collect()
582        }
583
584        let response: ResponseRaw = self.client.get("sdapi/v1/embeddings").await?;
585
586        Ok(Embeddings {
587            loaded: convert_embeddings(response.loaded),
588            skipped: convert_embeddings(response.skipped),
589        })
590    }
591
592    /// Get the options
593    pub async fn options(&self) -> Result<Options> {
594        #[derive(Deserialize)]
595        struct OptionsRaw {
596            s_churn: f32,
597            s_noise: f32,
598            s_tmin: f32,
599            sd_model_checkpoint: String,
600        }
601
602        self.client
603            .get::<OptionsRaw>("sdapi/v1/options")
604            .await
605            .map(|r| Options {
606                model: r.sd_model_checkpoint,
607                s_churn: r.s_churn,
608                s_noise: r.s_noise,
609                s_tmin: r.s_tmin,
610            })
611    }
612
613    /// Get the samplers
614    pub async fn samplers(&self) -> Result<Vec<Sampler>> {
615        #[derive(Serialize, Deserialize)]
616        struct SamplerRaw {
617            aliases: Vec<String>,
618            name: String,
619        }
620
621        self.client
622            .get::<Vec<SamplerRaw>>("sdapi/v1/samplers")
623            .await
624            .map(|r| {
625                r.into_iter()
626                    .filter_map(|s| Sampler::try_from(s.name.as_str()).ok())
627                    .collect::<Vec<_>>()
628            })
629    }
630
631    /// Get the upscalers
632    pub async fn upscalers(&self) -> Result<Vec<Upscaler>> {
633        #[derive(Serialize, Deserialize)]
634        struct UpscalerRaw {
635            model_name: Option<String>,
636            model_path: Option<String>,
637            model_url: Option<String>,
638            name: String,
639        }
640
641        self.client
642            .get::<Vec<UpscalerRaw>>("sdapi/v1/upscalers")
643            .await
644            .map(|r| {
645                r.into_iter()
646                    .flat_map(|r| {
647                        Upscaler::try_from(r.model_name.as_deref().unwrap_or(r.name.as_str())).ok()
648                    })
649                    .collect()
650            })
651    }
652
653    /// Get the models
654    pub async fn models(&self) -> Result<Vec<Model>> {
655        #[derive(Serialize, Deserialize)]
656        struct ModelRaw {
657            config: Option<String>,
658            filename: String,
659            hash: Option<String>,
660            #[serde(default)]
661            sha256: Option<String>,
662            model_name: String,
663            title: String,
664        }
665
666        self.client
667            .get::<Vec<ModelRaw>>("sdapi/v1/sd-models")
668            .await
669            .map(|r| {
670                r.into_iter()
671                    .map(|r| Model {
672                        title: r.title,
673                        name: r.model_name,
674                        hash_short: r.hash,
675                        hash_sha256: r.sha256,
676                    })
677                    .collect()
678            })
679    }
680
681    /// Get the hypernetworks
682    pub async fn hypernetworks(&self) -> Result<Vec<String>> {
683        #[derive(Serialize, Deserialize)]
684        struct HypernetworkRaw {
685            name: String,
686            path: String,
687        }
688
689        self.client
690            .get::<Vec<HypernetworkRaw>>("sdapi/v1/hypernetworks")
691            .await
692            .map(|r| r.into_iter().map(|s| s.name).collect())
693    }
694
695    /// Get the face restorers
696    pub async fn face_restorers(&self) -> Result<Vec<String>> {
697        #[derive(Serialize, Deserialize)]
698        struct FaceRestorerRaw {
699            cmd_dir: Option<String>,
700            name: String,
701        }
702
703        self.client
704            .get::<Vec<FaceRestorerRaw>>("sdapi/v1/face-restorers")
705            .await
706            .map(|r| r.into_iter().map(|s| s.name).collect())
707    }
708
709    /// Get the real ESRGAN models
710    pub async fn realesrgan_models(&self) -> Result<Vec<String>> {
711        #[derive(Serialize, Deserialize)]
712        struct RealEsrganModelRaw {
713            name: String,
714            path: Option<String>,
715            scale: i64,
716        }
717
718        self.client
719            .get::<Vec<RealEsrganModelRaw>>("sdapi/v1/realesrgan-models")
720            .await
721            .map(|r| r.into_iter().map(|s| s.name).collect())
722    }
723
724    /// Get the prompt styles
725    pub async fn prompt_styles(&self) -> Result<Vec<PromptStyle>> {
726        #[derive(Serialize, Deserialize)]
727        struct PromptStyleRaw {
728            name: String,
729            negative_prompt: Option<String>,
730            prompt: Option<String>,
731        }
732
733        self.client
734            .get::<Vec<PromptStyleRaw>>("sdapi/v1/prompt-styles")
735            .await
736            .map(|r| {
737                r.into_iter()
738                    .map(|r| PromptStyle {
739                        name: r.name,
740                        prompt: r.prompt,
741                        negative_prompt: r.negative_prompt,
742                    })
743                    .collect()
744            })
745    }
746
747    /// Get the artist categories
748    pub async fn artist_categories(&self) -> Result<Vec<String>> {
749        self.client
750            .get::<Vec<String>>("sdapi/v1/artist-categories")
751            .await
752    }
753
754    /// Get the artists
755    pub async fn artists(&self) -> Result<Vec<Artist>> {
756        #[derive(Serialize, Deserialize)]
757        struct ArtistRaw {
758            category: String,
759            name: String,
760            score: f32,
761        }
762
763        self.client
764            .get::<Vec<ArtistRaw>>("sdapi/v1/artists")
765            .await
766            .map(|r| {
767                r.into_iter()
768                    .map(|r| Artist {
769                        name: r.name,
770                        category: r.category,
771                    })
772                    .collect()
773            })
774    }
775}
776impl Client {
777    async fn issue_generation_task<R: Serialize + Send + Sync + 'static>(
778        client: RequestClient,
779        url: String,
780        request: R,
781        tiling: bool,
782    ) -> Result<GenerationResult> {
783        #[derive(Deserialize)]
784        struct Response {
785            images: Vec<String>,
786            info: String,
787        }
788
789        #[derive(Deserialize)]
790        pub struct InfoResponse {
791            all_negative_prompts: Vec<String>,
792            all_prompts: Vec<String>,
793
794            all_seeds: Vec<i64>,
795            seed_resize_from_h: i32,
796            seed_resize_from_w: i32,
797
798            all_subseeds: Vec<i64>,
799            subseed_strength: f32,
800
801            cfg_scale: f32,
802            clip_skip: usize,
803            denoising_strength: f32,
804            face_restoration_model: Option<String>,
805            is_using_inpainting_conditioning: bool,
806            job_timestamp: String,
807            restore_faces: bool,
808            sd_model_hash: String,
809            styles: Vec<String>,
810
811            width: u32,
812            height: u32,
813            sampler_name: String,
814            steps: u32,
815        }
816
817        let response: Response = client.post(&url, &request).await?;
818        let pngs = response
819            .images
820            .iter()
821            .map(|b64| {
822                BASE64
823                    .decode(b64.as_bytes())
824                    .map_err(|e| ClientError::from(e))
825            })
826            .collect::<Result<Vec<_>>>()?;
827        let info = {
828            let raw: InfoResponse = serde_json::from_str(&response.info)?;
829            GenerationInfo {
830                prompts: raw.all_prompts,
831                negative_prompts: raw.all_negative_prompts,
832                seeds: raw.all_seeds,
833                subseeds: raw.all_subseeds,
834                subseed_strength: raw.subseed_strength,
835                width: raw.width,
836                height: raw.height,
837                sampler: Sampler::try_from(raw.sampler_name.as_str()).unwrap(),
838                steps: raw.steps,
839                tiling,
840
841                cfg_scale: raw.cfg_scale,
842                denoising_strength: raw.denoising_strength,
843                restore_faces: raw.restore_faces,
844                seed_resize_from_w: Some(raw.seed_resize_from_w)
845                    .filter(|v| *v > 0)
846                    .map(|v| v as u32),
847                seed_resize_from_h: Some(raw.seed_resize_from_h)
848                    .filter(|v| *v > 0)
849                    .map(|v| v as u32),
850                styles: raw.styles,
851
852                clip_skip: raw.clip_skip,
853                face_restoration_model: raw.face_restoration_model,
854                is_using_inpainting_conditioning: raw.is_using_inpainting_conditioning,
855                job_timestamp: chrono::NaiveDateTime::parse_from_str(
856                    &raw.job_timestamp,
857                    "%Y%m%d%H%M%S",
858                )
859                .map_err(|error| ClientError::ChronoParseError {
860                    timestamp: raw.job_timestamp.clone(),
861                    error,
862                })?
863                .and_local_timezone(chrono::Local)
864                .unwrap(),
865                model_hash: raw.sd_model_hash,
866            }
867        };
868
869        Ok(GenerationResult { pngs, info })
870    }
871}
872
873fn decode_image_from_base64(b64: &str) -> Result<image::DynamicImage> {
874    Ok(image::load_from_memory(&BASE64.decode(b64.as_bytes())?)?)
875}
876
877fn encode_image_to_base64(image: &image::DynamicImage) -> image::ImageResult<String> {
878    let mut bytes: Vec<u8> = Vec::new();
879    let mut cursor = std::io::Cursor::new(&mut bytes);
880    image.write_to(&mut cursor, image::ImageOutputFormat::Png)?;
881    Ok(BASE64.encode(&bytes))
882}
883
/// How much of the generation is complete.
///
/// Note that this can return a zero value on completion as the web UI
/// can take time to return the result after completion.
pub struct GenerationProgress {
    /// Estimated time to completion, in seconds
    pub eta_seconds: f32,
    /// How much of the generation is complete, from 0 to 1
    pub progress_factor: f32,
    /// The current image being generated, if available.
    ///
    /// NOTE(review): presumably a preview snapshot of the in-progress
    /// generation rather than the final output — confirm against the
    /// progress endpoint.
    pub current_image: Option<image::DynamicImage>,
    /// The timestamp that the current job was started. Can be used to disambiguate between jobs.
    pub job_timestamp: Option<chrono::DateTime<chrono::Local>>,
}
898impl GenerationProgress {
899    /// Whether or not the generation has completed.
900    pub fn is_finished(&self) -> bool {
901        self.progress_factor >= 1.0
902    }
903}
904
/// The parameters to the generation that are shared between text-to-image synthesis
/// and image-to-image synthesis.
///
/// Consider using the [Default] trait to fill in the
/// parameters that you don't need to fill in.
///
/// NOTE(review): fields left as [None] are presumably filled in with the web
/// UI's own defaults when the request is built — confirm against the
/// request-construction code.
#[derive(Default, Clone, Debug)]
pub struct BaseGenerationRequest {
    /// The prompt
    pub prompt: String,
    /// The negative prompt (elements to avoid from the generation)
    pub negative_prompt: Option<String>,

    /// The number of images in each batch
    pub batch_size: Option<u32>,
    /// The number of batches
    pub batch_count: Option<u32>,

    /// The width of the generated image
    pub width: Option<u32>,
    /// The height of the generated image
    pub height: Option<u32>,

    /// The Classifier-Free Guidance scale; how strongly the prompt is
    /// applied to the generation
    pub cfg_scale: Option<f32>,
    /// The denoising strength
    pub denoising_strength: Option<f32>,
    /// The η parameter
    pub eta: Option<f32>,
    /// The sampler to use for the generation
    pub sampler: Option<Sampler>,
    /// The number of steps
    pub steps: Option<u32>,
    /// The model override to use. If not supplied, the currently-set model will be used.
    pub model: Option<Model>,

    /// Whether or not the image should be tiled at the edges
    pub tiling: Option<bool>,
    /// Whether or not to apply the face restoration
    pub restore_faces: Option<bool>,

    /// s_churn sampler parameter
    pub s_churn: Option<f32>,
    /// s_noise sampler parameter
    pub s_noise: Option<f32>,
    /// s_tmax sampler parameter
    pub s_tmax: Option<f32>,
    /// s_tmin sampler parameter
    pub s_tmin: Option<f32>,

    /// The seed to use for this generation. This will apply to the first image,
    /// and the web UI will generate the successive seeds.
    pub seed: Option<i64>,
    /// The width to resize the image from if reusing a seed with a different size
    pub seed_resize_from_w: Option<u32>,
    /// The height to resize the image from if reusing a seed with a different size
    pub seed_resize_from_h: Option<u32>,
    /// The subseed to use for this generation
    pub subseed: Option<i64>,
    /// The strength of the subseed
    pub subseed_strength: Option<f32>,

    /// Any styles to apply to the generation
    pub styles: Option<Vec<String>>,
}
970
/// Parameters for a text-to-image generation.
///
/// Consider using the [Default] trait to fill in the
/// parameters that you don't need to fill in.
#[derive(Default, Clone, Debug)]
pub struct TextToImageGenerationRequest {
    /// The base parameters for this generation request.
    pub base: BaseGenerationRequest,

    /// The width of the first phase of the generated image
    pub firstphase_width: Option<u32>,
    /// The height of the first phase of the generated image
    pub firstphase_height: Option<u32>,

    /// Whether to enable the two-pass generation
    /// (NOTE(review): inferred from the `firstphase_*` fields and the web UI's
    /// `enable_hr` parameter name — confirm against the web UI API)
    pub enable_hr: Option<bool>,
}
988
/// Parameters for an image-to-image generation.
///
/// Consider using the [Default] trait to fill in the
/// parameters that you don't need to fill in.
#[derive(Default, Clone, Debug)]
pub struct ImageToImageGenerationRequest {
    /// The base parameters for this generation request.
    pub base: BaseGenerationRequest,

    /// The images to alter.
    pub images: Vec<image::DynamicImage>,

    /// How the image will be resized to match the generation resolution
    pub resize_mode: Option<ResizeMode>,

    /// The mask to apply
    pub mask: Option<image::DynamicImage>,

    /// The amount to blur the mask
    pub mask_blur: Option<u32>,

    /// How the area to be inpainted will be initialized
    pub inpainting_fill_mode: Option<InpaintingFillMode>,

    /// Whether or not to inpaint at full resolution
    pub inpaint_full_resolution: bool,

    /// The amount of padding to apply around the full-resolution inpainting area
    pub inpaint_full_resolution_padding: Option<u32>,

    /// By default, the masked area is inpainted. If this is turned on, the unmasked area
    /// will be inpainted.
    pub inpainting_mask_invert: bool,
}
1023
/// The result of the generation.
pub struct GenerationResult {
    /// The images produced by the generator, as PNG byte arrays.
    ///
    /// Note that these contain any information included by the generators,
    /// including the PNG info.
    pub pngs: Vec<Vec<u8>>,
    /// The information associated with this generation.
    ///
    /// Parsed from the `info` JSON returned alongside the images.
    pub info: GenerationInfo,
}
1034impl GenerationResult {
1035    /// Converts [pngs] to [image::DynamicImage]s.
1036    ///
1037    /// Note that this conversion will lose PNG info and any other information embedded in the PNGs.
1038    pub fn images(&self) -> Result<Vec<image::DynamicImage>> {
1039        self.pngs
1040            .iter()
1041            .map(|p| image::load_from_memory(&p).map_err(|e| ClientError::from(e)))
1042            .collect()
1043    }
1044}
1045
/// The information associated with a generation.
#[derive(Debug, Clone)]
pub struct GenerationInfo {
    /// The prompts used for each image in the generation.
    pub prompts: Vec<String>,
    /// The negative prompt for each image in the generation.
    pub negative_prompts: Vec<String>,
    /// The seeds for the images; each seed corresponds to an image.
    pub seeds: Vec<i64>,
    /// The subseeds for the images; each seed corresponds to an image.
    pub subseeds: Vec<i64>,
    /// The strength of the subseed.
    pub subseed_strength: f32,
    /// The width of the generated images.
    pub width: u32,
    /// The height of the generated images.
    pub height: u32,
    /// The sampler that was used for this generation.
    pub sampler: Sampler,
    /// The number of steps that were used for each generation.
    pub steps: u32,
    /// Whether or not the image should be tiled at the edges
    pub tiling: bool,

    /// The Classifier-Free Guidance scale; how strongly the prompt was
    /// applied to the generation
    pub cfg_scale: f32,
    /// The denoising strength
    pub denoising_strength: f32,

    /// Whether or not the face restoration was applied
    pub restore_faces: bool,

    /// The width to resize the image from if reusing a seed with a different size
    pub seed_resize_from_w: Option<u32>,
    /// The height to resize the image from if reusing a seed with a different size
    pub seed_resize_from_h: Option<u32>,

    /// Any styles applied to the generation
    pub styles: Vec<String>,

    /// CLIP rounds to skip
    pub clip_skip: usize,
    /// Face restoration model in use
    pub face_restoration_model: Option<String>,
    /// Whether or not inpainting conditioning is being used
    pub is_using_inpainting_conditioning: bool,
    /// When the job was run.
    ///
    /// Parsed from the web UI's `%Y%m%d%H%M%S` timestamp and interpreted in
    /// the local timezone.
    pub job_timestamp: chrono::DateTime<chrono::Local>,
    /// The hash of the model in use. Note that this is the *short* hash,
    /// not the long hash.
    pub model_hash: String,
}
1099
/// A request to post-process an image. See [Client::postprocess].
///
/// NOTE(review): `Default` leaves `scale_factor` at 0.0, which is unlikely to
/// be a useful scale — set it explicitly when using `..Default::default()`.
#[derive(Default)]
pub struct PostprocessRequest {
    /// How the image should be resized to fit the target resolution if the
    /// resolution doesn't fit within the frame
    pub resize_mode: ResizeMode,
    /// The first upscaler to use
    pub upscaler_1: Upscaler,
    /// The second upscaler to use
    pub upscaler_2: Upscaler,
    /// The scale factor to use
    pub scale_factor: f32,

    /// How much of CodeFormer's result is blended into the result? [0-1]
    pub codeformer_visibility: Option<f32>,
    /// How strong is CodeFormer's effect? [0-1]
    pub codeformer_weight: Option<f32>,
    /// How much of the second upscaler's result is blended into the result? [0-1]
    pub upscaler_2_visibility: Option<f32>,
    /// How much of GFPGAN's result is blended into the result? [0-1]
    pub gfpgan_visibility: Option<f32>,
    /// Should upscaling occur before face restoration?
    pub upscale_first: Option<bool>,
}
1124
/// Defines a public fieldless enum whose variants each carry a "user-friendly"
/// name (the exact string the web UI uses for them). Also generates:
/// - `Display`/`Debug` impls that print the friendly name,
/// - a `TryFrom<&str>` impl that parses a friendly name back into its variant
///   (failing with `()` on unknown strings),
/// - a `VALUES` constant listing every variant in declaration order.
macro_rules! define_user_friendly_enum {
    // $enum_name: identifier of the generated enum
    // $doc: doc comment attached to the enum itself
    // ($name, $friendly_name): variant identifier and its friendly name
    ($enum_name:ident, $doc:literal, {$(($name:ident, $friendly_name:literal)),*}) => {
        #[doc = $doc]
        #[derive(Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
        pub enum $enum_name {
            $(
                // The friendly name doubles as the variant's documentation.
                #[doc = $friendly_name]
                $name
            ),*
        }
        impl std::fmt::Display for $enum_name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", match self {
                    $(
                        Self::$name => $friendly_name
                    ),*
                })
            }
        }
        // Debug intentionally delegates to Display so logs show the friendly name.
        impl std::fmt::Debug for $enum_name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                std::fmt::Display::fmt(self, f)
            }
        }
        impl TryFrom<&str> for $enum_name {
            type Error = ();

            fn try_from(s: &str) -> core::result::Result<Self, ()> {
                match s {
                    $(
                        $friendly_name => Ok(Self::$name),
                    )*
                    _ => Err(()),
                }
            }
        }
        impl $enum_name {
            /// All of the possible values.
            pub const VALUES: &[Self] = &[
                $(Self::$name),*
            ];
        }
    }
}
1169
// The friendly names must match the sampler names the web UI reports exactly:
// `TryFrom<&str>` is used to parse them back out of generation info.
define_user_friendly_enum!(
    Sampler,
    "The sampler to use for the generation.",
    {
        (EulerA, "Euler a"),
        (Euler, "Euler"),
        (Lms, "LMS"),
        (Heun, "Heun"),
        (Dpm2, "DPM2"),
        (Dpm2A, "DPM2 a"),
        (DpmPP2SA, "DPM++ 2S a"),
        (DpmPP2M, "DPM++ 2M"),
        (DpmPPSDE, "DPM++ SDE"),
        (DpmFast, "DPM fast"),
        (DpmAdaptive, "DPM adaptive"),
        (LmsKarras, "LMS Karras"),
        (Dpm2Karras, "DPM2 Karras"),
        (Dpm2AKarras, "DPM2 a Karras"),
        (DpmPP2SAKarras, "DPM++ 2S a Karras"),
        (DpmPP2MKarras, "DPM++ 2M Karras"),
        (DpmPPSDEKarras, "DPM++ SDE Karras"),
        (Ddim, "DDIM"),
        (Plms, "PLMS")
    }
);
1195
// The friendly names are the strings the web UI uses to identify each
// interrogation model.
define_user_friendly_enum!(
    Interrogator,
    "Supported interrogators for [Client::interrogate]",
    {
        (Clip, "CLIP"),
        (DeepDanbooru, "DeepDanbooru")
    }
);
1204
// Declaration order matters: `From<ResizeMode> for u32` maps these variants
// to the indices 0, 1, 2 in the same order.
define_user_friendly_enum!(
    ResizeMode,
    "How to resize the image for image-to-image generation",
    {
        (Resize, "Just resize"),
        (CropAndResize, "Crop and resize"),
        (ResizeAndFill, "Resize and fill")
    }
);
1214impl Default for ResizeMode {
1215    fn default() -> Self {
1216        Self::Resize
1217    }
1218}
1219impl From<ResizeMode> for u32 {
1220    fn from(mode: ResizeMode) -> Self {
1221        match mode {
1222            ResizeMode::Resize => 0,
1223            ResizeMode::CropAndResize => 1,
1224            ResizeMode::ResizeAndFill => 2,
1225        }
1226    }
1227}
1228
// Friendly names match the inpainting fill options shown in the web UI.
define_user_friendly_enum!(
    InpaintingFillMode,
    "How the area to be inpainted will be initialized",
    {
        (Fill, "Fill"),
        (Original, "Original"),
        (LatentNoise, "Latent noise"),
        (LatentNothing, "Latent nothing")
    }
);
1239impl Default for InpaintingFillMode {
1240    fn default() -> Self {
1241        Self::Original
1242    }
1243}
1244
// Friendly names match the web UI's upscaler list; note that "ESRGAN_4x"
// keeps the underscore used by the web UI.
define_user_friendly_enum!(
    Upscaler,
    "Upscaler",
    {
        (None, "None"),
        (Lanczos, "Lanczos"),
        (Nearest, "Nearest"),
        (Ldsr, "LDSR"),
        (ScuNetGan, "ScuNET GAN"),
        (ScuNetPSNR, "ScuNET PSNR"),
        (SwinIR4x, "SwinIR 4x"),
        (ESRGAN4x, "ESRGAN_4x")
    }
);
1259impl Default for Upscaler {
1260    fn default() -> Self {
1261        Self::None
1262    }
1263}
1264
/// The currently set options for the UI
#[derive(Debug, Clone)]
pub struct Options {
    /// The currently loaded model
    pub model: String,

    /// The `s_churn` sampler setting
    pub s_churn: f32,
    /// The `s_noise` sampler setting
    pub s_noise: f32,
    /// The `s_tmin` sampler setting
    // NOTE(review): `s_tmax` is not surfaced here even though requests can set
    // it — confirm whether that is intentional.
    pub s_tmin: f32,
}
1278
/// A model (checkpoint) known to the web UI.
#[derive(Debug, Clone)]
pub struct Model {
    /// Title of the model
    pub title: String,
    /// Name of the model
    pub name: String,
    /// Short hash of the model, if available
    /// (presumably the same short form as [GenerationInfo::model_hash] — confirm)
    pub hash_short: Option<String>,
    /// Long SHA256 hash of the model, if available
    pub hash_sha256: Option<String>,
}
1291
/// A saved prompt style that can be applied to generations.
#[derive(Debug, Clone)]
pub struct PromptStyle {
    /// Name of the style
    pub name: String,
    /// Prompt of the style
    pub prompt: Option<String>,
    /// Negative prompt of the style
    pub negative_prompt: Option<String>,
}
1302
/// An artist known to the web UI.
#[derive(Debug, Clone)]
pub struct Artist {
    /// Name of the artist
    pub name: String,
    /// Category the artist belongs to
    pub category: String,
}
1311
/// A textual inversion embedding
#[derive(Debug, Clone)]
pub struct Embedding {
    /// The number of steps that were used to train this embedding, if available
    pub step: Option<u32>,
    /// The hash of the checkpoint this embedding was trained on, if available
    pub sd_checkpoint: Option<String>,
    /// The name of the checkpoint this embedding was trained on, if available
    ///
    /// Note that this is the name that was used by the trainer; for a stable
    /// identifier, use [Embedding::sd_checkpoint] instead
    pub sd_checkpoint_name: Option<String>,
    /// The length of each individual vector in the embedding
    pub shape: u32,
    /// The number of vectors in the embedding
    pub vectors: u32,
}
1328
/// All available textual inversion embeddings, keyed by embedding name.
#[derive(Debug, Clone)]
pub struct Embeddings {
    /// Embeddings loaded for the current model
    pub loaded: HashMap<String, Embedding>,
    /// Embeddings skipped for the current model (likely due to architecture incompatibility)
    pub skipped: HashMap<String, Embedding>,
}
1337impl Embeddings {
1338    /// Iterator over all embeddings available, including currently unloaded ones
1339    pub fn all(&self) -> impl Iterator<Item = (&String, &Embedding)> {
1340        self.loaded.iter().chain(self.skipped.iter())
1341    }
1342}
1343
1344#[allow(dead_code)]
1345mod config {
1346    use crate::{ClientError, RequestClient, Result};
1347    use std::collections::HashMap;
1348
1349    #[derive(Debug)]
1350    enum ConfigComponent {
1351        Dropdown { choices: Vec<String>, id: String },
1352        Radio { choices: Vec<String>, id: String },
1353    }
1354
1355    /// The top-level Gradio configuration for the Web UI.
1356    ///
1357    /// Used to get things that aren't in the API.
1358    #[derive(Debug)]
1359    struct Config(HashMap<u32, ConfigComponent>);
1360    impl Config {
1361        async fn new(client: &RequestClient) -> Result<Self> {
1362            Ok(Self(
1363                client
1364                    .get::<HashMap<String, serde_json::Value>>("config")
1365                    .await?
1366                    .get("components")
1367                    .ok_or_else(|| ClientError::invalid_response("components"))?
1368                    .as_array()
1369                    .ok_or_else(|| ClientError::invalid_response("components to be an array"))?
1370                    .iter()
1371                    .filter_map(|v| v.as_object())
1372                    .filter_map(|o| {
1373                        let id = o.get("id")?.as_u64()? as u32;
1374                        let comp_type = o.get("type")?.as_str()?;
1375                        let props = o.get("props")?.as_object()?;
1376                        match comp_type {
1377                            "dropdown" => Some((
1378                                id,
1379                                ConfigComponent::Dropdown {
1380                                    choices: extract_string_array(props.get("choices")?)?,
1381                                    id: props.get("elem_id")?.as_str()?.to_owned(),
1382                                },
1383                            )),
1384                            "radio" => Some((
1385                                id,
1386                                ConfigComponent::Radio {
1387                                    choices: extract_string_array(props.get("choices")?)?,
1388                                    id: props.get("elem_id")?.as_str()?.to_owned(),
1389                                },
1390                            )),
1391                            _ => None,
1392                        }
1393                    })
1394                    .collect(),
1395            ))
1396        }
1397
1398        /// All of the embeddings available.
1399        fn embeddings(&self) -> Result<Vec<String>> {
1400            self.get_dropdown("train_embedding")
1401        }
1402
1403        fn values(&self) -> impl Iterator<Item = &ConfigComponent> {
1404            self.0.values()
1405        }
1406        fn get_dropdown(&self, target_id: &str) -> Result<Vec<String>> {
1407            self.values()
1408                .find_map(|comp| match comp {
1409                    ConfigComponent::Dropdown { id, choices, .. } if id == target_id => {
1410                        Some(choices.clone())
1411                    }
1412                    _ => None,
1413                })
1414                .ok_or_else(|| {
1415                    ClientError::invalid_response(&format!("no {target_id} dropdown component"))
1416                })
1417        }
1418        fn get_radio(&self, target_id: &str) -> Result<Vec<String>> {
1419            self.values()
1420                .find_map(|comp| match comp {
1421                    ConfigComponent::Radio { id, choices, .. } if id == target_id => {
1422                        Some(choices.clone())
1423                    }
1424                    _ => None,
1425                })
1426                .ok_or_else(|| {
1427                    ClientError::invalid_response(&format!("no {target_id} radio component"))
1428                })
1429        }
1430    }
1431
1432    fn extract_string_array(value: &serde_json::Value) -> Option<Vec<String>> {
1433        Some(
1434            value
1435                .as_array()?
1436                .iter()
1437                .flat_map(|s| Some(s.as_str()?.to_owned()))
1438                .collect(),
1439        )
1440    }
1441}
1442
/// Internal HTTP client: wraps a `reqwest::Client` together with the web UI's
/// base URL and an optional pre-formatted Basic-auth header value.
#[derive(Clone)]
struct RequestClient {
    // Base URL of the web UI, stored without a trailing slash.
    url: String,
    // Underlying HTTP client (built with a cookie store enabled).
    client: reqwest::Client,
    // `Basic …` Authorization header value, if authentication is configured.
    authentication_token: Option<String>,
}
1449impl RequestClient {
1450    async fn new(url: &str) -> Result<Self> {
1451        if !url.starts_with("http") {
1452            return Err(ClientError::InvalidUrl);
1453        }
1454
1455        let url = url.strip_suffix('/').unwrap_or(url).to_owned();
1456        let client = reqwest::ClientBuilder::new().cookie_store(true).build()?;
1457
1458        Ok(Self {
1459            url,
1460            client,
1461            authentication_token: None,
1462        })
1463    }
1464
1465    fn set_authentication_token(&mut self, token: String) {
1466        self.authentication_token = Some(format!("Basic {token}"));
1467    }
1468
1469    fn url(&self, endpoint: &str) -> String {
1470        format!("{}/{}", self.url, endpoint)
1471    }
1472
1473    fn check_for_authentication<R: DeserializeOwned>(body: String) -> Result<R> {
1474        if body.trim() == "Internal Server Error" {
1475            return Err(ClientError::InternalServerError);
1476        }
1477
1478        match serde_json::from_str::<HashMap<String, serde_json::Value>>(&body) {
1479            Ok(json_body) => match json_body.get("detail") {
1480                Some(serde_json::Value::String(message)) => {
1481                    if message == "Not authenticated" {
1482                        Err(ClientError::NotAuthenticated)
1483                    } else {
1484                        Err(ClientError::Error {
1485                            message: message.clone(),
1486                        })
1487                    }
1488                }
1489                Some(other_error) => Err(ClientError::Error {
1490                    message: other_error.to_string(),
1491                }),
1492                _ => Ok(serde_json::from_str(&body)?),
1493            },
1494            Err(_) => Ok(serde_json::from_str(&body)?),
1495        }
1496    }
1497
1498    async fn send<R: DeserializeOwned>(&self, builder: reqwest::RequestBuilder) -> Result<R> {
1499        let builder = if let Some(token) = &self.authentication_token {
1500            builder.header("Authorization", token)
1501        } else {
1502            builder
1503        };
1504        Self::check_for_authentication(builder.send().await?.text().await?)
1505    }
1506    async fn get<R: DeserializeOwned>(&self, endpoint: &str) -> Result<R> {
1507        self.send(self.client.get(self.url(endpoint))).await
1508    }
1509    async fn post<R: DeserializeOwned, T: Serialize>(&self, endpoint: &str, body: &T) -> Result<R> {
1510        self.send(self.client.post(self.url(endpoint)).json(body))
1511            .await
1512    }
1513    fn post_raw(&self, endpoint: &str) -> reqwest::RequestBuilder {
1514        self.client.post(self.url(endpoint))
1515    }
1516}
1517
/// Settings overrides sent alongside a generation request.
#[derive(Serialize, Deserialize, Default)]
struct OverrideSettings {
    // Title of the checkpoint to generate with. Omitted from the serialized
    // JSON when unset (presumably leaving the web UI's current model in
    // effect — confirm against the web UI API).
    #[serde(skip_serializing_if = "Option::is_none")]
    sd_model_checkpoint: Option<String>,
}
1523impl OverrideSettings {
1524    fn from_model(model: Option<&Model>) -> Self {
1525        Self {
1526            sd_model_checkpoint: model.map(|m| m.title.clone()),
1527        }
1528    }
1529}