// cake_core/lib.rs

1//! This is the core library where all Cake logic is implemented.
2#[macro_use]
3extern crate anyhow;
4
5use cake::Mode;
6
7use clap::{Parser, ValueEnum};
8use serde::Deserialize;
9
10pub mod cake;
11pub mod models;
12pub mod utils;
13
14#[derive(Copy, Clone, Parser, Default, Debug, Eq, PartialEq, PartialOrd, Ord, ValueEnum)]
15pub enum ModelType {
16    #[default]
17    TextModel,
18    ImageModel,
19}
20
21/// Supported text model architectures.
22#[derive(Copy, Clone, Parser, Default, Debug, Eq, PartialEq, PartialOrd, Ord, ValueEnum)]
23pub enum TextModelArch {
24    /// Auto-detect from config.json
25    #[default]
26    Auto,
27    /// LLaMA family
28    Llama,
29    /// Qwen2/Qwen2.5 family
30    Qwen2,
31    /// Qwen3.5 hybrid linear/full attention
32    Qwen3_5,
33}
34
35#[derive(Clone, Parser, Default, Debug)]
36#[command(author, version, about, long_about = None)]
37pub struct Args {
38    /// GPU device index.
39    #[arg(long, default_value_t = 0)]
40    pub device: usize,
41    /// Mode (set by subcommand, not directly by user).
42    #[arg(skip)]
43    pub mode: Mode,
44    /// Worker name.
45    #[arg(long)]
46    pub name: Option<String>,
47    /// Binding address and port for workers.
48    #[arg(long, default_value = "0.0.0.0:10128")]
49    pub address: String,
50    /// Enable OpenAI compatible chat completion API.
51    #[arg(long)]
52    pub api: Option<String>,
53    /// Path to model directory, or HuggingFace repo ID (e.g., Qwen/Qwen2.5-Coder-1.5B-Instruct).
54    #[arg(long, default_value = "./cake-data/Meta-Llama-3-8B/")]
55    pub model: String,
56    /// Topology file.
57    #[arg(long)]
58    pub topology: Option<String>,
59    /// The initial prompt.
60    #[arg(long, default_value = "The sky is blue because ")]
61    pub prompt: String,
62    /// The system prompt.
63    #[arg(long, default_value = "You are a helpful AI assistant.")]
64    pub system_prompt: String,
65    /// The seed to use when generating random samples.
66    #[arg(long, default_value_t = 299792458)]
67    pub seed: u64,
68    /// The length of the sample to generate (in tokens).
69    #[arg(short = 'n', long, default_value_t = 2048)]
70    pub sample_len: usize,
71    /// The temperature used to generate samples.
72    #[arg(long, default_value_t = 1.0)]
73    pub temperature: f64,
74    /// Nucleus sampling probability cutoff.
75    #[arg(long)]
76    pub top_p: Option<f64>,
77    /// Only sample among the top K samples.
78    #[arg(long)]
79    pub top_k: Option<usize>,
80    /// Penalty to be applied for repeating tokens, 1. means no penalty.
81    #[arg(long, default_value_t = 1.1)]
82    pub repeat_penalty: f32,
83    /// The context size to consider for the repeat penalty.
84    #[arg(long, default_value_t = 128)]
85    pub repeat_last_n: usize,
86    /// Use different dtype than f16
87    #[arg(long)]
88    pub dtype: Option<String>,
89
90    /// Cluster key for zero-config mDNS discovery and PSK authentication.
91    /// When set on both master and workers, enables automatic discovery,
92    /// layer assignment, and model data push without topology files.
93    #[arg(long, env = "CAKE_CLUSTER_KEY")]
94    pub cluster_key: Option<String>,
95
96    /// How long to wait for worker discovery (seconds). 0 = skip discovery.
97    #[arg(long, default_value_t = 10)]
98    pub discovery_timeout: u64,
99
100    /// Optional basic auth for the web UI (format: "user:pass").
101    #[arg(long)]
102    pub ui_auth: Option<String>,
103
104    /// Topology built during zero-config setup (not a CLI arg).
105    #[arg(skip)]
106    pub topology_override: Option<cake::Topology>,
107
108    /// Run on CPU rather than on GPU.
109    #[arg(long, default_value_t = false)]
110    pub cpu: bool,
111
112    #[arg(long, default_value = "text-model")]
113    pub model_type: ModelType,
114
115    /// Text model architecture (auto-detected from config.json if omitted).
116    #[arg(long, default_value = "auto")]
117    pub text_model_arch: TextModelArch,
118
119    #[clap(flatten)]
120    pub sd_args: SDArgs,
121
122    #[clap(flatten)]
123    pub sd_img_gen_args: ImageGenerationArgs,
124}
125
126#[derive(Clone, Parser, Default, Debug)]
127pub struct SDArgs {
128    #[arg(long = "sd-tokenizer")]
129    pub tokenizer: Option<String>,
130
131    #[arg(long = "sd-tokenizer-2")]
132    pub tokenizer_2: Option<String>,
133
134    #[arg(long = "sd-version", value_enum, default_value = "v1-5")]
135    sd_version: StableDiffusionVersion,
136
137    #[arg(long = "sd-use-f16", default_value_t = true)]
138    use_f16: bool,
139
140    #[arg(long = "sd-width")]
141    width: Option<usize>,
142
143    #[arg(long = "sd-height")]
144    height: Option<usize>,
145
146    #[arg(long = "sd-sliced-attention-size")]
147    sliced_attention_size: Option<usize>,
148
149    #[arg(long = "sd-clip")]
150    clip: Option<String>,
151
152    #[arg(long = "sd-clip2")]
153    clip2: Option<String>,
154
155    #[arg(long = "sd-vae")]
156    vae: Option<String>,
157
158    #[arg(long = "sd-unet")]
159    unet: Option<String>,
160
161    #[arg(long = "sd-use-flash-attention", default_value_t = false)]
162    use_flash_attention: bool,
163}
164
/// Serde fallback for `ImageGenerationArgs::image_prompt`.
///
/// Returns the same text as the clap `default_value` of `--sd-image-prompt`;
/// keep the two in sync when changing either.
fn default_prompt() -> String {
    String::from("A very realistic photo of a rusty robot walking on a sandy beach")
}

/// Serde fallback that yields an empty `String` (used for `sd-uncond-prompt`).
fn empty_str() -> String {
    String::new()
}

/// Serde fallback that yields `1` (used for `sd-num-samples` and `sd-bsize`).
fn usize_one() -> usize {
    1
}

/// Serde fallback for `sd-img2img-strength`; equals the clap default of `0.8`.
fn default_img2img_strength() -> f64 {
    0.8
}

181#[derive(Clone, Parser, Default, Debug, Deserialize)]
182pub struct ImageGenerationArgs {
183    /// The prompt to be used for image generation.
184    #[arg(
185        long = "sd-image-prompt",
186        default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
187    )]
188    #[serde(rename(deserialize = "sd-image-prompt"), default = "default_prompt")]
189    image_prompt: String,
190
191    #[arg(long = "sd-uncond-prompt", default_value = "")]
192    #[serde(rename(deserialize = "sd-uncond-prompt"), default = "empty_str")]
193    uncond_prompt: String,
194
195    /// Enable tracing (generates a trace-timestamp.json file).
196    #[arg(long = "sd-tracing", default_value_t = false)]
197    #[serde(rename(deserialize = "sd-tracing"), default)]
198    tracing: bool,
199
200    /// The number of steps to run the diffusion for.
201    #[arg(long = "sd-n-steps")]
202    #[serde(rename(deserialize = "sd-n-steps"))]
203    n_steps: Option<usize>,
204
205    /// The number of samples to generate iteratively.
206    #[arg(long = "sd-num-samples", default_value_t = 1)]
207    #[serde(rename(deserialize = "sd-num-samples"), default = "usize_one")]
208    num_samples: usize,
209
210    /// The numbers of samples to generate simultaneously.
211    #[arg(long = "sd-bsize", default_value_t = 1)]
212    #[serde(rename(deserialize = "sd-bsize"), default = "usize_one")]
213    bsize: usize,
214
215    /// Generate intermediary images every n steps.
216    #[arg(long = "sd-intermediary-images", default_value_t = 0, action)]
217    #[serde(rename(deserialize = "sd-intermediary-images"), default)]
218    intermediary_images: usize,
219
220    #[arg(long = "sd-guidance-scale")]
221    #[serde(rename(deserialize = "sd-guidance-scale"))]
222    guidance_scale: Option<f64>,
223
224    #[arg(long = "sd-img2img", value_name = "FILE")]
225    #[serde(rename(deserialize = "sd-img2img"))]
226    img2img: Option<String>,
227
228    /// The strength, indicates how much to transform the initial image. The
229    /// value must be between 0 and 1, a value of 1 discards the initial image
230    /// information.
231    #[arg(long = "sd-img2img-strength", default_value_t = 0.8)]
232    #[serde(
233        rename(deserialize = "sd-img2img-strength"),
234        default = "default_img2img_strength"
235    )]
236    img2img_strength: f64,
237
238    /// The seed to use when generating random samples.
239    #[arg(long = "sd-seed")]
240    #[serde(rename(deserialize = "sd-seed"))]
241    image_seed: Option<u64>,
242}
243
244#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq, Default)]
245pub enum StableDiffusionVersion {
246    #[default]
247    V1_5,
248    V2_1,
249    Xl,
250    Turbo,
251}
252
253impl StableDiffusionVersion {
254    fn repo(&self) -> &'static str {
255        match self {
256            Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
257            Self::V2_1 => "stabilityai/stable-diffusion-2-1",
258            Self::V1_5 => "runwayml/stable-diffusion-v1-5",
259            Self::Turbo => "stabilityai/sdxl-turbo",
260        }
261    }
262
263    fn unet_file(&self, use_f16: bool) -> &'static str {
264        match self {
265            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
266                if use_f16 {
267                    "unet/diffusion_pytorch_model.fp16.safetensors"
268                } else {
269                    "unet/diffusion_pytorch_model.safetensors"
270                }
271            }
272        }
273    }
274
275    fn vae_file(&self, use_f16: bool) -> &'static str {
276        match self {
277            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
278                if use_f16 {
279                    "vae/diffusion_pytorch_model.fp16.safetensors"
280                } else {
281                    "vae/diffusion_pytorch_model.safetensors"
282                }
283            }
284        }
285    }
286
287    fn clip_file(&self, use_f16: bool) -> &'static str {
288        match self {
289            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
290                if use_f16 {
291                    "text_encoder/model.fp16.safetensors"
292                } else {
293                    "text_encoder/model.safetensors"
294                }
295            }
296        }
297    }
298
299    fn clip2_file(&self, use_f16: bool) -> &'static str {
300        match self {
301            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
302                if use_f16 {
303                    "text_encoder_2/model.fp16.safetensors"
304                } else {
305                    "text_encoder_2/model.safetensors"
306                }
307            }
308        }
309    }
310}