// cake_core/cake/mod.rs
use std::{
    fmt::{Debug, Display},
    path::PathBuf,
    sync::{Arc, Mutex},
};

use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use tokio::net::TcpListener;

use crate::{
    models::common::{detect_text_model_arch, Cache, Config},
    utils, Args, ModelType, TextModelArch,
};
17#[cfg(feature = "master")]
18pub mod api;
19#[cfg(feature = "master")]
20mod master;
21
22pub mod auth;
23mod client;
24pub mod discovery;
25mod proto;
26pub mod setup;
27mod topology;
28mod worker;
29
30#[cfg(feature = "master")]
31pub use master::*;
32
33pub use client::*;
34pub use proto::*;
35pub use topology::*;
36pub use worker::*;
37
/// Determines if we run in master or worker mode.
#[derive(clap::ValueEnum, Clone, Debug, Default)]
pub enum Mode {
    /// Run as the master node (the default mode).
    #[default]
    Master,
    /// Run as a worker node serving a subset of the model layers.
    Worker,
}
45
/// Main context object used as a shared state.
#[derive(Clone)]
pub struct Context {
    /// Parsed command line arguments.
    pub args: Args,
    /// Tensor dtype resolved from `args.dtype` (defaults to F16 in `from_args`).
    pub dtype: DType,
    /// Layer-to-node topology; empty when no topology file was given.
    pub topology: Topology,
    /// Local path to the model data (downloaded from HF if needed).
    pub data_path: PathBuf,
    /// Inference device (CPU / CUDA / etc.) selected by `utils::get_inference_device`.
    pub device: Device,
    /// Parsed model config; `None` unless `model_type == TextModel`.
    pub config: Option<Config>,
    /// KV cache built from the config; `None` unless `model_type == TextModel`.
    pub cache: Option<Cache>,
    /// Weight loader over the safetensor shards; `None` unless `model_type == TextModel`.
    pub var_builder: Option<VarBuilder<'static>>,
    /// Resolved text model architecture.
    pub text_model_arch: TextModelArch,
    /// True if the model uses FP8 block-wise quantization.
    pub fp8: bool,
    /// Pre-bound TCP listener from setup phase (taken once by Worker::new).
    pub listener_override: Arc<Mutex<Option<TcpListener>>>,
}
64
impl Context {
    /// Create the context from the parsed command line arguments.
    ///
    /// Resolves the compute dtype and device, locates (or downloads) the model
    /// data, builds the topology, and — for text models only — parses the model
    /// config, detects the architecture, prepares the KV cache, and constructs a
    /// `VarBuilder` restricted to the safetensor shards this node actually needs.
    ///
    /// # Errors
    /// Fails on an unsupported dtype string, an unreachable device, a missing
    /// model path, a malformed topology/config file, an architecture with no
    /// enabled feature, or a failed weight-index load.
    pub fn from_args(mut args: Args) -> Result<Self> {
        // Default to F16 when --dtype was not provided.
        let dtype: DType = match args.dtype.as_deref() {
            Some("f16") => DType::F16,
            Some("bf16") => DType::BF16,
            Some("f32") => DType::F32,
            Some(dtype) => bail!("unsupported dtype {dtype}"),
            None => DType::F16,
        };

        let device = utils::get_inference_device(args.cpu, args.device)
            .map_err(|e| anyhow!("can't attach to device: {:?}", e))?;

        // Disable cudarc event tracking for CUDA devices: cudarc's CudaStream::wait()
        // rejects events from a different CudaContext, which breaks cross-device tensor
        // transfers in multi-GPU setups. Safe since we use a single stream per device.
        #[cfg(feature = "cuda")]
        if let Device::Cuda(cuda_dev) = &device {
            unsafe { cuda_dev.disable_event_tracking(); }
        }

        // Log the resolved runtime configuration and current physical memory use
        // (0 is reported when memory stats are unavailable).
        log::info!(
            "[{:?}] dtype={:?} device={:?} mem={}",
            args.mode,
            &dtype,
            &device,
            human_bytes::human_bytes(memory_stats::memory_stats().map(|m| m.physical_mem).unwrap_or(0) as f64)
        );

        // Resolve the model location: use the local path if it exists, otherwise
        // treat the argument as a HF repo id and download it.
        let data_path = PathBuf::from(&args.model);
        let data_path = if !data_path.exists() {
            if utils::hf::looks_like_hf_repo(&args.model) {
                utils::hf::ensure_model_downloaded(&args.model)?
            } else {
                bail!("model path does not exist: {}", data_path.display());
            }
        } else {
            data_path
        };

        // Topology precedence: zero-config override > topology file > empty
        // topology (everything loads locally). `take()` moves the override out
        // of `args` before `args` itself is moved into the Context below.
        let topology = if let Some(topo) = args.topology_override.take() {
            // Zero-config setup already built the topology
            topo
        } else if let Some(path) = &args.topology {
            Topology::from_path(path, &args.model_type)?
        } else {
            log::warn!("no topology file specified, the entire model will be loaded");
            Topology::new()
        };

        // Text-model-only state; stays None/false/default for other model types.
        let mut config: Option<Config> = None;
        let mut cache: Option<Cache> = None;
        let mut var_builder: Option<VarBuilder> = None;
        let mut text_model_arch = args.text_model_arch;
        let mut fp8 = false;

        if args.model_type == ModelType::TextModel {
            let config_filename = data_path.join("config.json");

            // Auto-detect architecture if needed
            if text_model_arch == TextModelArch::Auto {
                let arch_str = detect_text_model_arch(&config_filename).unwrap_or_default();
                // Unknown or undetectable architectures fall back to Llama.
                text_model_arch = match arch_str.as_str() {
                    #[cfg(feature = "qwen2")]
                    "Qwen2ForCausalLM" => TextModelArch::Qwen2,
                    #[cfg(feature = "qwen3_5")]
                    "Qwen3_5ForConditionalGeneration" => TextModelArch::Qwen3_5,
                    _ => TextModelArch::Llama,
                };
            }

            log::info!("text model architecture: {:?}", text_model_arch);

            // Parse the architecture-specific config and normalize it into the
            // shared internal Config representation.
            let config_internal = match text_model_arch {
                #[cfg(feature = "qwen2")]
                TextModelArch::Qwen2 => {
                    crate::models::qwen2::QwenConfig::from_path(&config_filename)?.into_config()
                }
                #[cfg(feature = "qwen3_5")]
                TextModelArch::Qwen3_5 => {
                    crate::models::qwen3_5::Qwen3_5Config::from_path(&config_filename)?.into_config()
                }
                #[cfg(feature = "llama")]
                TextModelArch::Llama => {
                    crate::models::llama3::LlamaConfig::from_path(&config_filename)?.into_config()
                }
                _ => {
                    // Fallback: use a generic config parser approach
                    // Parse the raw JSON and construct Config directly
                    bail!("no text model feature enabled for architecture {:?}", text_model_arch)
                }
            };

            let model_tensors_index: PathBuf = data_path.join("model.safetensors.index.json");
            fp8 = utils::fp8::is_fp8_quantized(&config_filename);
            if fp8 {
                log::info!("model uses FP8 quantization — weights will be dequantized at load time");
            }
            let is_master = matches!(args.mode, Mode::Master);
            // Workers collect the full set of layers assigned to any worker;
            // masters use the inverse (everything NOT assigned to workers).
            let my_layers: Vec<String> = if !is_master {
                topology.all_worker_layers().into_iter().collect()
            } else {
                vec![]
            };

            // Build the weight loader, restricting which shards are mapped so each
            // node only pays for the tensors it will actually serve.
            var_builder = Some(if is_master {
                // Master: exclude shards that only contain remote-worker tensors
                let worker_layers = topology.all_worker_layers();
                if worker_layers.is_empty() {
                    utils::load_var_builder_from_index(
                        model_tensors_index,
                        dtype,
                        device.clone(),
                        fp8,
                    )?
                } else {
                    utils::load_var_builder_for_local_layers(
                        model_tensors_index,
                        dtype,
                        device.clone(),
                        &worker_layers,
                        fp8,
                    )?
                }
            } else if !my_layers.is_empty() {
                // Worker with known layers: only load shards containing our layers
                utils::load_var_builder_for_specific_layers(
                    model_tensors_index,
                    dtype,
                    device.clone(),
                    &my_layers,
                    fp8,
                )?
            } else {
                // Worker without known layers: load everything
                utils::load_var_builder_from_index(
                    model_tensors_index,
                    dtype,
                    device.clone(),
                    fp8,
                )?
            });
            cache = Some(Cache::new(true, dtype, &config_internal, &device)?);
            config = Some(config_internal);
        }

        Ok(Context {
            args,
            dtype,
            topology,
            data_path,
            device,
            config,
            cache,
            var_builder,
            text_model_arch,
            fp8,
            listener_override: Arc::new(Mutex::new(None)),
        })
    }
}
227
/// This is the trait that a shardable object must implement.
///
/// A `Forwarder` wraps one or more model layers and exposes async forward
/// passes over them, either locally or across the network (see the `client`
/// and `worker` modules).
#[async_trait]
pub trait Forwarder: Debug + Send + Sync + Display {
    /// Create an instance of this object loading the specified layer(s) from a VarBuilder.
    fn load(name: String, ctx: &Context) -> Result<Box<Self>>
    where
        Self: Sized;

    /// Applies a forward operation to the input tensor, does not require mutability.
    ///
    /// `index_pos` is the position offset within the sequence and `block_idx`
    /// selects the transformer block — NOTE(review): inferred from the names
    /// and the batch signature below; confirm against implementors.
    async fn forward(
        &self,
        x: &Tensor,
        index_pos: usize,
        block_idx: usize,
        ctx: &mut Context,
    ) -> Result<Tensor>;

    /// Applies a forward operation to the input tensor, requires mutability.
    async fn forward_mut(
        &mut self,
        x: &Tensor,
        index_pos: usize,
        block_idx: usize,
        ctx: &mut Context,
    ) -> Result<Tensor>;

    /// Applies a batch of forward operations to the input tensor.
    ///
    /// Each batch entry is a `(layer_name, index_pos, block_idx)` triple.
    /// Default implementation panics via `unimplemented!()` — implementors
    /// used for batched execution must override this.
    async fn forward_batch(
        &mut self,
        _x: &Tensor,
        _batch: Vec<(String, usize, usize)>,
        _ctx: &mut Context,
    ) -> Result<Tensor> {
        unimplemented!()
    }

    /// Called on shutdown; the default implementation is a no-op.
    async fn goodbye(&mut self) -> Result<()> {
        Ok(())
    }

    /// Return the layer name.
    fn layer_name(&self) -> &str;

    /// Return the unique identity, or "local" for in-process forwarders.
    fn ident(&self) -> &str {
        "local"
    }
}