// cake_core/utils/mod.rs
1//! Utility functions and abstractions.
2
3pub mod fp8;
4pub mod hf;
5pub mod models;
6pub mod split;
7
use std::path::{Path, PathBuf};

use anyhow::{anyhow, bail, Result};
use candle_core::{
    utils::{cuda_is_available, metal_is_available},
    DType, Device, Tensor,
};
use candle_nn::VarBuilder;
18
19/// Returns the best available device at `ordinal` index (in case of multiple GPUs), or CPU if `force_cpu` is true.
20pub fn get_inference_device(force_cpu: bool, ordinal: usize) -> Result<Device> {
21    if force_cpu {
22        log::debug!("device is forced cpu");
23        Ok(Device::Cpu)
24    } else if cuda_is_available() {
25        log::debug!("device is cuda {ordinal}");
26        Ok(Device::new_cuda(ordinal)?)
27    } else if metal_is_available() {
28        log::debug!("device is metal {ordinal}");
29        Ok(Device::new_metal(ordinal)?)
30    } else {
31        log::debug!("device is cpu");
32        // fallback to cpu if nothing else available
33        Ok(Device::Cpu)
34    }
35}
36
37pub fn load_safetensors_from_model(path: &Path) -> Result<Vec<std::path::PathBuf>> {
38    log::info!("loading tensors from {} ...", "model.safetensors");
39    let result = vec![path.join("model.safetensors")];
40    Ok(result)
41}
42
43/// Load the safetensors files for a model from the hub based on a json index file.
44pub fn load_safetensors_paths_from_index(
45    tensors_index_json_filename: PathBuf,
46) -> Result<Vec<std::path::PathBuf>> {
47    log::info!(
48        "loading tensors from {} ...",
49        tensors_index_json_filename.display()
50    );
51
52    let parent_dir = tensors_index_json_filename.parent().unwrap();
53    let json_file = std::fs::File::open(&tensors_index_json_filename).map_err(|e| {
54        anyhow!(
55            "can't open {}: {:?}",
56            tensors_index_json_filename.display(),
57            e
58        )
59    })?;
60    let json: serde_json::Value = serde_json::from_reader(&json_file).map_err(|e| {
61        anyhow!(
62            "can't parse {}: {:?}",
63            tensors_index_json_filename.display(),
64            e
65        )
66    })?;
67    let weight_map = match json.get("weight_map") {
68        None => bail!("no weight map in {json_file:?}"),
69        Some(serde_json::Value::Object(map)) => map,
70        Some(_) => bail!("weight map in {json_file:?} is not a map"),
71    };
72    let mut safetensors_files = std::collections::HashSet::new();
73    for value in weight_map.values() {
74        if let Some(file) = value.as_str() {
75            safetensors_files.insert(file.to_string());
76        }
77    }
78    let safetensors_files = safetensors_files
79        .iter()
80        .map(|v| parent_dir.join(v))
81        .collect::<Vec<std::path::PathBuf>>();
82
83    Ok(safetensors_files)
84}
85
86/// Pre-read safetensor files into the OS page cache so that subsequent
87/// mmap access doesn't trigger per-tensor page faults during layer loading.
88/// Uses OnceLock to skip redundant calls (e.g. multi-GPU VarBuilder creation).
89fn prefetch_safetensors(filenames: &[PathBuf]) -> Result<()> {
90    use std::sync::OnceLock;
91    static DONE: OnceLock<()> = OnceLock::new();
92
93    if DONE.get().is_some() {
94        log::info!("safetensor files already in page cache, skipping prefetch");
95        return Ok(());
96    }
97
98    use std::io::Read;
99    let start = std::time::Instant::now();
100    let mut total_bytes: u64 = 0;
101    let mut buf = Vec::new();
102    for filename in filenames {
103        buf.clear();
104        std::fs::File::open(filename)
105            .map_err(|e| anyhow!("prefetch: can't open {}: {e}", filename.display()))?
106            .read_to_end(&mut buf)
107            .map_err(|e| anyhow!("prefetch: can't read {}: {e}", filename.display()))?;
108        total_bytes += buf.len() as u64;
109    }
110    log::info!(
111        "pre-cached {} in {:.1}s",
112        human_bytes::human_bytes(total_bytes as f64),
113        start.elapsed().as_secs_f64()
114    );
115
116    DONE.set(()).ok();
117    Ok(())
118}
119
120/// Create a VarBuilder with the tensors loaded from the index.
121pub fn load_var_builder_from_index<'a>(
122    tensor_index: PathBuf,
123    dtype: DType,
124    device: Device,
125    fp8: bool,
126) -> Result<VarBuilder<'a>> {
127    let filenames: Vec<std::path::PathBuf> = if tensor_index.exists() {
128        load_safetensors_paths_from_index(tensor_index)
129            .map_err(|e| anyhow!("can't load tensors index: {:?}", e))?
130    } else {
131        load_safetensors_from_model(tensor_index.parent().unwrap())
132            .map_err(|e| anyhow!("can't load tensors index: {:?}", e))?
133    };
134
135    prefetch_safetensors(&filenames)?;
136
137    if fp8 {
138        unsafe {
139            fp8::load_fp8_var_builder(&filenames, dtype, &device)
140                .map_err(|e| anyhow!("can't create fp8 varbuilder from tensors: {:?}", e))
141        }
142    } else {
143        unsafe {
144            VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)
145                .map_err(|e| anyhow!("can't create varbuilder from tensors: {:?}", e))
146        }
147    }
148}
149
150/// Create a VarBuilder that only loads safetensors shards needed for the given
151/// local layers. Shards containing only remote-worker tensors are excluded,
152/// reducing GPU memory usage on the master.
153pub fn load_var_builder_for_local_layers<'a>(
154    tensor_index: PathBuf,
155    dtype: DType,
156    device: Device,
157    worker_layers: &std::collections::HashSet<String>,
158    fp8: bool,
159) -> Result<VarBuilder<'a>> {
160    if !tensor_index.exists() {
161        // Single safetensors file — can't filter, load all
162        return load_var_builder_from_index(tensor_index, dtype, device, fp8);
163    }
164
165    if worker_layers.is_empty() {
166        // No workers — load everything
167        return load_var_builder_from_index(tensor_index, dtype, device, fp8);
168    }
169
170    let parent_dir = tensor_index.parent().unwrap();
171    let json_data = std::fs::read_to_string(&tensor_index)
172        .map_err(|e| anyhow!("can't read {}: {:?}", tensor_index.display(), e))?;
173    let json: serde_json::Value = serde_json::from_str(&json_data)
174        .map_err(|e| anyhow!("can't parse {}: {:?}", tensor_index.display(), e))?;
175    let weight_map = json
176        .get("weight_map")
177        .and_then(|v| v.as_object())
178        .ok_or_else(|| anyhow!("no weight_map in {}", tensor_index.display()))?;
179
180    // Find shard files that contain at least one tensor NOT belonging to a worker layer.
181    // A tensor belongs to a worker layer if its name starts with "<layer_name>."
182    let mut needed_shards: std::collections::HashSet<String> = std::collections::HashSet::new();
183    for (tensor_name, shard_file) in weight_map {
184        let is_worker_tensor = worker_layers
185            .iter()
186            .any(|layer| tensor_name.starts_with(&format!("{}.", layer)));
187        if !is_worker_tensor {
188            if let Some(filename) = shard_file.as_str() {
189                needed_shards.insert(filename.to_string());
190            }
191        }
192    }
193
194    let filenames: Vec<PathBuf> = needed_shards
195        .iter()
196        .map(|f| parent_dir.join(f))
197        .collect();
198
199    log::info!(
200        "loading {} of {} shard file(s) for local layers",
201        filenames.len(),
202        weight_map
203            .values()
204            .filter_map(|v| v.as_str())
205            .collect::<std::collections::HashSet<_>>()
206            .len()
207    );
208
209    prefetch_safetensors(&filenames)?;
210
211    if fp8 {
212        unsafe {
213            fp8::load_fp8_var_builder(&filenames, dtype, &device)
214                .map_err(|e| anyhow!("can't create fp8 varbuilder from tensors: {:?}", e))
215        }
216    } else {
217        unsafe {
218            VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)
219                .map_err(|e| anyhow!("can't create varbuilder from tensors: {:?}", e))
220        }
221    }
222}
223
224/// Create a VarBuilder that only loads safetensors shards containing tensors
225/// for the given layer prefixes. Workers use this to skip shards that only
226/// contain layers assigned to other nodes.
227pub fn load_var_builder_for_specific_layers<'a>(
228    tensor_index: PathBuf,
229    dtype: DType,
230    device: Device,
231    layer_prefixes: &[String],
232    fp8: bool,
233) -> Result<VarBuilder<'a>> {
234    if !tensor_index.exists() || layer_prefixes.is_empty() {
235        return load_var_builder_from_index(tensor_index, dtype, device, fp8);
236    }
237
238    let parent_dir = tensor_index.parent().unwrap();
239    let json_data = std::fs::read_to_string(&tensor_index)
240        .map_err(|e| anyhow!("can't read {}: {:?}", tensor_index.display(), e))?;
241    let json: serde_json::Value = serde_json::from_str(&json_data)
242        .map_err(|e| anyhow!("can't parse {}: {:?}", tensor_index.display(), e))?;
243    let weight_map = json
244        .get("weight_map")
245        .and_then(|v| v.as_object())
246        .ok_or_else(|| anyhow!("no weight_map in {}", tensor_index.display()))?;
247
248    let mut needed_shards: std::collections::HashSet<String> = std::collections::HashSet::new();
249    for (tensor_name, shard_file) in weight_map {
250        let is_needed = layer_prefixes
251            .iter()
252            .any(|prefix| tensor_name.starts_with(&format!("{}.", prefix)));
253        if is_needed {
254            if let Some(filename) = shard_file.as_str() {
255                needed_shards.insert(filename.to_string());
256            }
257        }
258    }
259
260    let total_shards = weight_map
261        .values()
262        .filter_map(|v| v.as_str())
263        .collect::<std::collections::HashSet<_>>()
264        .len();
265
266    let filenames: Vec<PathBuf> = needed_shards.iter().map(|f| parent_dir.join(f)).collect();
267
268    log::info!(
269        "loading {} of {} shard file(s) for {} layers",
270        filenames.len(),
271        total_shards,
272        layer_prefixes.len()
273    );
274
275    prefetch_safetensors(&filenames)?;
276
277    if fp8 {
278        unsafe {
279            fp8::load_fp8_var_builder(&filenames, dtype, &device)
280                .map_err(|e| anyhow!("can't create fp8 varbuilder from tensors: {:?}", e))
281        }
282    } else {
283        unsafe {
284            VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)
285                .map_err(|e| anyhow!("can't create varbuilder from tensors: {:?}", e))
286        }
287    }
288}
289
290/// Nasty hack to debug NaN in tensors.
291#[allow(dead_code)]
292pub(crate) fn panic_on_nan(t: &Tensor, name: &str) {
293    if t.to_string().contains("NaN") {
294        panic!("\ntensor '{name}' contains NaN: \n{t}");
295    }
296}