1#[macro_use]
3extern crate anyhow;
4
5use cake::Mode;
6
7use clap::{Parser, ValueEnum};
8use serde::Deserialize;
9
10pub mod cake;
11pub mod models;
12pub mod utils;
13
// Which kind of model this node serves.
// clap's ValueEnum derive exposes the variants as kebab-case CLI values
// ("text-model" / "image-model"), matching the `--model-type` default in `Args`.
// NOTE(review): deriving `Parser` on a ValueEnum is unusual; it does not
// appear to be used as a parser entry point here -- confirm before removing.
// Plain `//` comments are used on purpose: clap turns `///` doc comments
// into --help text, and this review must not change program output.
#[derive(Copy, Clone, Parser, Default, Debug, Eq, PartialEq, PartialOrd, Ord, ValueEnum)]
pub enum ModelType {
    // Text-generation (LLM) model -- the default.
    #[default]
    TextModel,
    // Image-generation (Stable Diffusion) model.
    ImageModel,
}
20
// Architecture of the text model to load.
// CLI values are the kebab-case variant names ("auto", "llama", "qwen2",
// "qwen3-5"); `Auto` presumably lets the loader detect the architecture
// from the checkpoint -- TODO confirm against the model-loading code.
// (`//` comments on purpose: `///` would become --help text via clap.)
#[derive(Copy, Clone, Parser, Default, Debug, Eq, PartialEq, PartialOrd, Ord, ValueEnum)]
pub enum TextModelArch {
    #[default]
    Auto,
    Llama,
    Qwen2,
    Qwen3_5,
}
34
// Command-line arguments for a cake node, parsed with clap's derive API.
//
// Plain `//` comments are used on fields on purpose: clap turns `///` doc
// comments into --help text, and this review must not change program output.
#[derive(Clone, Parser, Default, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    // 0-based index of the compute device to use.
    #[arg(long, default_value_t = 0)]
    pub device: usize,
    // Runtime mode; never parsed from the CLI (clap skips it), so it must
    // be filled in programmatically after parsing.
    #[arg(skip)]
    pub mode: Mode,
    // Node name -- presumably identifies this worker in the topology;
    // TODO confirm against the code that reads it.
    #[arg(long)]
    pub name: Option<String>,
    // Network address (host:port) for this node.
    #[arg(long, default_value = "0.0.0.0:10128")]
    pub address: String,
    // Optional API endpoint address -- NOTE(review): exact semantics not
    // visible in this file; confirm against the binary that uses it.
    #[arg(long)]
    pub api: Option<String>,
    // Path to the model directory.
    #[arg(long, default_value = "./cake-data/Meta-Llama-3-8B/")]
    pub model: String,
    // Path to the cluster topology file, if any.
    #[arg(long)]
    pub topology: Option<String>,
    // Prompt used for text generation.
    #[arg(long, default_value = "The sky is blue because ")]
    pub prompt: String,
    // System prompt for chat-style models.
    #[arg(long, default_value = "You are a helpful AI assistant.")]
    pub system_prompt: String,
    // RNG seed for sampling (the default is the speed of light in m/s).
    #[arg(long, default_value_t = 299792458)]
    pub seed: u64,
    // Maximum number of tokens to sample.
    #[arg(short = 'n', long, default_value_t = 2048)]
    pub sample_len: usize,
    // Sampling temperature.
    #[arg(long, default_value_t = 1.0)]
    pub temperature: f64,
    // Nucleus (top-p) sampling cutoff; disabled when not set.
    #[arg(long)]
    pub top_p: Option<f64>,
    // Top-k sampling cutoff; disabled when not set.
    #[arg(long)]
    pub top_k: Option<usize>,
    // Penalty applied to repeated tokens by the sampler.
    #[arg(long, default_value_t = 1.1)]
    pub repeat_penalty: f32,
    // How many of the most recent tokens the repeat penalty considers.
    #[arg(long, default_value_t = 128)]
    pub repeat_last_n: usize,
    // Tensor dtype override (accepted values decided by whatever parses
    // this string; not visible here).
    #[arg(long)]
    pub dtype: Option<String>,

    // Shared cluster secret; may also come from the CAKE_CLUSTER_KEY env var.
    #[arg(long, env = "CAKE_CLUSTER_KEY")]
    pub cluster_key: Option<String>,

    // Discovery timeout -- units (seconds?) not visible here; confirm.
    #[arg(long, default_value_t = 10)]
    pub discovery_timeout: u64,

    // Optional auth credential for the UI -- confirm exact format.
    #[arg(long)]
    pub ui_auth: Option<String>,

    // Programmatic topology override; never parsed from the CLI.
    #[arg(skip)]
    pub topology_override: Option<cake::Topology>,

    // Force CPU execution -- presumably even when an accelerator is
    // available; confirm against the device-selection code.
    #[arg(long, default_value_t = false)]
    pub cpu: bool,

    // Which kind of model to run ("text-model" / "image-model").
    #[arg(long, default_value = "text-model")]
    pub model_type: ModelType,

    // Text model architecture ("auto" by default).
    #[arg(long, default_value = "auto")]
    pub text_model_arch: TextModelArch,

    // Stable Diffusion loading flags, flattened into this command line.
    #[clap(flatten)]
    pub sd_args: SDArgs,

    // Stable Diffusion image-generation flags, flattened as well.
    #[clap(flatten)]
    pub sd_img_gen_args: ImageGenerationArgs,
}
125
// Stable Diffusion model-loading flags. All CLI names carry an `sd-` prefix
// so they do not clash with the text-model flags when flattened into `Args`.
// NOTE(review): the weight paths are optional; missing ones are presumably
// resolved from the Hub repo given by `StableDiffusionVersion::repo` --
// confirm in the loading code. (`//` comments on purpose: `///` would
// become --help text via clap.)
#[derive(Clone, Parser, Default, Debug)]
pub struct SDArgs {
    // Path to the tokenizer file.
    #[arg(long = "sd-tokenizer")]
    pub tokenizer: Option<String>,

    // Path to a second tokenizer -- presumably for models with two text
    // encoders (e.g. SDXL); confirm.
    #[arg(long = "sd-tokenizer-2")]
    pub tokenizer_2: Option<String>,

    // Which Stable Diffusion version to load (v1-5, v2-1, xl, turbo).
    #[arg(long = "sd-version", value_enum, default_value = "v1-5")]
    sd_version: StableDiffusionVersion,

    // Load fp16 weights (enabled by default).
    #[arg(long = "sd-use-f16", default_value_t = true)]
    use_f16: bool,

    // Optional output image width in pixels.
    #[arg(long = "sd-width")]
    width: Option<usize>,

    // Optional output image height in pixels.
    #[arg(long = "sd-height")]
    height: Option<usize>,

    // Sliced-attention slice size; behavior when unset is decided by the
    // pipeline code -- not visible here.
    #[arg(long = "sd-sliced-attention-size")]
    sliced_attention_size: Option<usize>,

    // Path to the CLIP text-encoder weights.
    #[arg(long = "sd-clip")]
    clip: Option<String>,

    // Path to the second CLIP text-encoder weights.
    #[arg(long = "sd-clip2")]
    clip2: Option<String>,

    // Path to the VAE weights.
    #[arg(long = "sd-vae")]
    vae: Option<String>,

    // Path to the UNet weights.
    #[arg(long = "sd-unet")]
    unet: Option<String>,

    // Use flash-attention kernels (disabled by default).
    #[arg(long = "sd-use-flash-attention", default_value_t = false)]
    use_flash_attention: bool,
}
164
/// Serde fallback for `ImageGenerationArgs::image_prompt`; must stay in
/// sync with the clap `default_value` on that field.
fn default_prompt() -> String {
    String::from("A very realistic photo of a rusty robot walking on a sandy beach")
}
168
/// Serde fallback producing an empty `String` (used for the
/// unconditional prompt).
fn empty_str() -> String {
    String::new()
}
172
/// Serde fallback of `1` for count-like fields (`num_samples`, `bsize`);
/// mirrors the clap `default_value_t = 1` on those fields.
fn usize_one() -> usize {
    1
}
176
/// Serde fallback for `img2img_strength`; mirrors the clap
/// `default_value_t = 0.8` on that field.
fn default_img2img_strength() -> f64 {
    0.8
}
180
// Image-generation parameters for Stable Diffusion. Each field can arrive
// from the command line (`--sd-*` flags) or from a deserialized request
// (serde deserialize names mirror the CLI flag names), with matching
// defaults on both paths via the `default_*` helper fns in this file.
// (`//` comments on purpose: `///` would become --help text via clap.)
#[derive(Clone, Parser, Default, Debug, Deserialize)]
pub struct ImageGenerationArgs {
    // Positive prompt describing the image to generate.
    #[arg(
        long = "sd-image-prompt",
        default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
    )]
    #[serde(rename(deserialize = "sd-image-prompt"), default = "default_prompt")]
    image_prompt: String,

    // Negative (unconditional) prompt; empty by default.
    #[arg(long = "sd-uncond-prompt", default_value = "")]
    #[serde(rename(deserialize = "sd-uncond-prompt"), default = "empty_str")]
    uncond_prompt: String,

    // Enable tracing -- presumably diagnostic instrumentation; confirm.
    #[arg(long = "sd-tracing", default_value_t = false)]
    #[serde(rename(deserialize = "sd-tracing"), default)]
    tracing: bool,

    // Number of denoising steps; behavior when unset is decided by the
    // pipeline code -- not visible here.
    #[arg(long = "sd-n-steps")]
    #[serde(rename(deserialize = "sd-n-steps"))]
    n_steps: Option<usize>,

    // How many images to generate.
    #[arg(long = "sd-num-samples", default_value_t = 1)]
    #[serde(rename(deserialize = "sd-num-samples"), default = "usize_one")]
    num_samples: usize,

    // Batch size used during generation.
    #[arg(long = "sd-bsize", default_value_t = 1)]
    #[serde(rename(deserialize = "sd-bsize"), default = "usize_one")]
    bsize: usize,

    // Intermediary-image output control (0 by default) -- semantics
    // inferred from the name; confirm. NOTE(review): the bare `action`
    // attribute on a usize field is unusual for clap -- verify it parses
    // as intended.
    #[arg(long = "sd-intermediary-images", default_value_t = 0, action)]
    #[serde(rename(deserialize = "sd-intermediary-images"), default)]
    intermediary_images: usize,

    // Classifier-free guidance scale; behavior when unset is decided by
    // the pipeline code -- not visible here.
    #[arg(long = "sd-guidance-scale")]
    #[serde(rename(deserialize = "sd-guidance-scale"))]
    guidance_scale: Option<f64>,

    // Input image file enabling img2img mode.
    #[arg(long = "sd-img2img", value_name = "FILE")]
    #[serde(rename(deserialize = "sd-img2img"))]
    img2img: Option<String>,

    // Strength of the img2img transformation (0.8 by default) -- valid
    // range not enforced here; confirm where it is validated.
    #[arg(long = "sd-img2img-strength", default_value_t = 0.8)]
    #[serde(
        rename(deserialize = "sd-img2img-strength"),
        default = "default_img2img_strength"
    )]
    img2img_strength: f64,

    // Optional RNG seed for image generation; None behavior is decided by
    // the consumer.
    #[arg(long = "sd-seed")]
    #[serde(rename(deserialize = "sd-seed"))]
    image_seed: Option<u64>,
}
243
// Stable Diffusion checkpoint versions supported by the image pipeline.
// ValueEnum maps the variants to kebab-case CLI values ("v1-5", "v2-1",
// "xl", "turbo"), which is why `SDArgs` uses "v1-5" as its default.
// (`//` comments on purpose: `///` on variants becomes --help text.)
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq, Default)]
pub enum StableDiffusionVersion {
    #[default]
    V1_5,
    V2_1,
    Xl,
    Turbo,
}
252
253impl StableDiffusionVersion {
254 fn repo(&self) -> &'static str {
255 match self {
256 Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
257 Self::V2_1 => "stabilityai/stable-diffusion-2-1",
258 Self::V1_5 => "runwayml/stable-diffusion-v1-5",
259 Self::Turbo => "stabilityai/sdxl-turbo",
260 }
261 }
262
263 fn unet_file(&self, use_f16: bool) -> &'static str {
264 match self {
265 Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
266 if use_f16 {
267 "unet/diffusion_pytorch_model.fp16.safetensors"
268 } else {
269 "unet/diffusion_pytorch_model.safetensors"
270 }
271 }
272 }
273 }
274
275 fn vae_file(&self, use_f16: bool) -> &'static str {
276 match self {
277 Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
278 if use_f16 {
279 "vae/diffusion_pytorch_model.fp16.safetensors"
280 } else {
281 "vae/diffusion_pytorch_model.safetensors"
282 }
283 }
284 }
285 }
286
287 fn clip_file(&self, use_f16: bool) -> &'static str {
288 match self {
289 Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
290 if use_f16 {
291 "text_encoder/model.fp16.safetensors"
292 } else {
293 "text_encoder/model.safetensors"
294 }
295 }
296 }
297 }
298
299 fn clip2_file(&self, use_f16: bool) -> &'static str {
300 match self {
301 Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
302 if use_f16 {
303 "text_encoder_2/model.fp16.safetensors"
304 } else {
305 "text_encoder_2/model.safetensors"
306 }
307 }
308 }
309 }
310}