sd3/
sd3.rs

1// cargo run --example sd3 -- 'A cute Crab in gradient colors.'
2use gradio::{Client, ClientOptions, PredictionInput, PredictionOutput};
3
4#[tokio::main]
5async fn main() {
6    if std::env::args().len() < 2 {
7        println!("Please provide the prompt as an argument");
8        std::process::exit(1);
9    }
10    let args: Vec<String> = std::env::args().collect();
11    let prompt = &args[1];
12
13    let client = Client::new(
14        "stabilityai/stable-diffusion-3-medium",
15        ClientOptions::default(),
16    )
17    .await
18    .unwrap();
19
20    let mut prediction = client
21        .submit(
22            "/infer",
23            vec![
24                PredictionInput::from_value(prompt),
25                PredictionInput::from_value(""),   // negative_prompt
26                PredictionInput::from_value(0),    // seed
27                PredictionInput::from_value(true), // randomize_seed
28                PredictionInput::from_value(1024), // width
29                PredictionInput::from_value(1024), // height
30                PredictionInput::from_value(5),    // guidance_scale
31                PredictionInput::from_value(28),   // num_inference_steps
32            ],
33        )
34        .await
35        .unwrap();
36
37    while let Some(event) = prediction.next().await {
38        let event = event.unwrap();
39        match event {
40            gradio::structs::QueueDataMessage::Estimation {
41                rank, queue_size, ..
42            } => {
43                println!("Queueing: {}/{}", rank + 1, queue_size);
44            }
45            gradio::structs::QueueDataMessage::Progress { progress_data, .. } => {
46                if progress_data.is_none() {
47                    continue;
48                }
49                let progress_data = progress_data.unwrap();
50                if !progress_data.is_empty() {
51                    let progress_data = &progress_data[0];
52                    println!(
53                        "Processing: {}/{} {}",
54                        progress_data.index + 1,
55                        progress_data.length.unwrap(),
56                        progress_data.unit
57                    );
58                }
59            }
60            gradio::structs::QueueDataMessage::ProcessCompleted { output, .. } => {
61                let output: Vec<PredictionOutput> = output.try_into().unwrap();
62
63                println!(
64                    "Generated Image: {}",
65                    output[0].clone().as_file().unwrap().url.unwrap()
66                );
67                println!(
68                    "Seed: {}",
69                    output[1].clone().as_value().unwrap().as_i64().unwrap()
70                );
71                break;
72            }
73            _ => {}
74        }
75    }
76}