1pub mod fp8;
4pub mod hf;
5pub mod models;
6pub mod split;
7
use std::path::{Path, PathBuf};

use anyhow::{anyhow, bail, Result};
use candle_core::{
    utils::{cuda_is_available, metal_is_available},
    DType, Device, Tensor,
};
use candle_nn::VarBuilder;
18
19pub fn get_inference_device(force_cpu: bool, ordinal: usize) -> Result<Device> {
21 if force_cpu {
22 log::debug!("device is forced cpu");
23 Ok(Device::Cpu)
24 } else if cuda_is_available() {
25 log::debug!("device is cuda {ordinal}");
26 Ok(Device::new_cuda(ordinal)?)
27 } else if metal_is_available() {
28 log::debug!("device is metal {ordinal}");
29 Ok(Device::new_metal(ordinal)?)
30 } else {
31 log::debug!("device is cpu");
32 Ok(Device::Cpu)
34 }
35}
36
37pub fn load_safetensors_from_model(path: &Path) -> Result<Vec<std::path::PathBuf>> {
38 log::info!("loading tensors from {} ...", "model.safetensors");
39 let result = vec![path.join("model.safetensors")];
40 Ok(result)
41}
42
43pub fn load_safetensors_paths_from_index(
45 tensors_index_json_filename: PathBuf,
46) -> Result<Vec<std::path::PathBuf>> {
47 log::info!(
48 "loading tensors from {} ...",
49 tensors_index_json_filename.display()
50 );
51
52 let parent_dir = tensors_index_json_filename.parent().unwrap();
53 let json_file = std::fs::File::open(&tensors_index_json_filename).map_err(|e| {
54 anyhow!(
55 "can't open {}: {:?}",
56 tensors_index_json_filename.display(),
57 e
58 )
59 })?;
60 let json: serde_json::Value = serde_json::from_reader(&json_file).map_err(|e| {
61 anyhow!(
62 "can't parse {}: {:?}",
63 tensors_index_json_filename.display(),
64 e
65 )
66 })?;
67 let weight_map = match json.get("weight_map") {
68 None => bail!("no weight map in {json_file:?}"),
69 Some(serde_json::Value::Object(map)) => map,
70 Some(_) => bail!("weight map in {json_file:?} is not a map"),
71 };
72 let mut safetensors_files = std::collections::HashSet::new();
73 for value in weight_map.values() {
74 if let Some(file) = value.as_str() {
75 safetensors_files.insert(file.to_string());
76 }
77 }
78 let safetensors_files = safetensors_files
79 .iter()
80 .map(|v| parent_dir.join(v))
81 .collect::<Vec<std::path::PathBuf>>();
82
83 Ok(safetensors_files)
84}
85
86fn prefetch_safetensors(filenames: &[PathBuf]) -> Result<()> {
90 use std::sync::OnceLock;
91 static DONE: OnceLock<()> = OnceLock::new();
92
93 if DONE.get().is_some() {
94 log::info!("safetensor files already in page cache, skipping prefetch");
95 return Ok(());
96 }
97
98 use std::io::Read;
99 let start = std::time::Instant::now();
100 let mut total_bytes: u64 = 0;
101 let mut buf = Vec::new();
102 for filename in filenames {
103 buf.clear();
104 std::fs::File::open(filename)
105 .map_err(|e| anyhow!("prefetch: can't open {}: {e}", filename.display()))?
106 .read_to_end(&mut buf)
107 .map_err(|e| anyhow!("prefetch: can't read {}: {e}", filename.display()))?;
108 total_bytes += buf.len() as u64;
109 }
110 log::info!(
111 "pre-cached {} in {:.1}s",
112 human_bytes::human_bytes(total_bytes as f64),
113 start.elapsed().as_secs_f64()
114 );
115
116 DONE.set(()).ok();
117 Ok(())
118}
119
120pub fn load_var_builder_from_index<'a>(
122 tensor_index: PathBuf,
123 dtype: DType,
124 device: Device,
125 fp8: bool,
126) -> Result<VarBuilder<'a>> {
127 let filenames: Vec<std::path::PathBuf> = if tensor_index.exists() {
128 load_safetensors_paths_from_index(tensor_index)
129 .map_err(|e| anyhow!("can't load tensors index: {:?}", e))?
130 } else {
131 load_safetensors_from_model(tensor_index.parent().unwrap())
132 .map_err(|e| anyhow!("can't load tensors index: {:?}", e))?
133 };
134
135 prefetch_safetensors(&filenames)?;
136
137 if fp8 {
138 unsafe {
139 fp8::load_fp8_var_builder(&filenames, dtype, &device)
140 .map_err(|e| anyhow!("can't create fp8 varbuilder from tensors: {:?}", e))
141 }
142 } else {
143 unsafe {
144 VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)
145 .map_err(|e| anyhow!("can't create varbuilder from tensors: {:?}", e))
146 }
147 }
148}
149
150pub fn load_var_builder_for_local_layers<'a>(
154 tensor_index: PathBuf,
155 dtype: DType,
156 device: Device,
157 worker_layers: &std::collections::HashSet<String>,
158 fp8: bool,
159) -> Result<VarBuilder<'a>> {
160 if !tensor_index.exists() {
161 return load_var_builder_from_index(tensor_index, dtype, device, fp8);
163 }
164
165 if worker_layers.is_empty() {
166 return load_var_builder_from_index(tensor_index, dtype, device, fp8);
168 }
169
170 let parent_dir = tensor_index.parent().unwrap();
171 let json_data = std::fs::read_to_string(&tensor_index)
172 .map_err(|e| anyhow!("can't read {}: {:?}", tensor_index.display(), e))?;
173 let json: serde_json::Value = serde_json::from_str(&json_data)
174 .map_err(|e| anyhow!("can't parse {}: {:?}", tensor_index.display(), e))?;
175 let weight_map = json
176 .get("weight_map")
177 .and_then(|v| v.as_object())
178 .ok_or_else(|| anyhow!("no weight_map in {}", tensor_index.display()))?;
179
180 let mut needed_shards: std::collections::HashSet<String> = std::collections::HashSet::new();
183 for (tensor_name, shard_file) in weight_map {
184 let is_worker_tensor = worker_layers
185 .iter()
186 .any(|layer| tensor_name.starts_with(&format!("{}.", layer)));
187 if !is_worker_tensor {
188 if let Some(filename) = shard_file.as_str() {
189 needed_shards.insert(filename.to_string());
190 }
191 }
192 }
193
194 let filenames: Vec<PathBuf> = needed_shards
195 .iter()
196 .map(|f| parent_dir.join(f))
197 .collect();
198
199 log::info!(
200 "loading {} of {} shard file(s) for local layers",
201 filenames.len(),
202 weight_map
203 .values()
204 .filter_map(|v| v.as_str())
205 .collect::<std::collections::HashSet<_>>()
206 .len()
207 );
208
209 prefetch_safetensors(&filenames)?;
210
211 if fp8 {
212 unsafe {
213 fp8::load_fp8_var_builder(&filenames, dtype, &device)
214 .map_err(|e| anyhow!("can't create fp8 varbuilder from tensors: {:?}", e))
215 }
216 } else {
217 unsafe {
218 VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)
219 .map_err(|e| anyhow!("can't create varbuilder from tensors: {:?}", e))
220 }
221 }
222}
223
224pub fn load_var_builder_for_specific_layers<'a>(
228 tensor_index: PathBuf,
229 dtype: DType,
230 device: Device,
231 layer_prefixes: &[String],
232 fp8: bool,
233) -> Result<VarBuilder<'a>> {
234 if !tensor_index.exists() || layer_prefixes.is_empty() {
235 return load_var_builder_from_index(tensor_index, dtype, device, fp8);
236 }
237
238 let parent_dir = tensor_index.parent().unwrap();
239 let json_data = std::fs::read_to_string(&tensor_index)
240 .map_err(|e| anyhow!("can't read {}: {:?}", tensor_index.display(), e))?;
241 let json: serde_json::Value = serde_json::from_str(&json_data)
242 .map_err(|e| anyhow!("can't parse {}: {:?}", tensor_index.display(), e))?;
243 let weight_map = json
244 .get("weight_map")
245 .and_then(|v| v.as_object())
246 .ok_or_else(|| anyhow!("no weight_map in {}", tensor_index.display()))?;
247
248 let mut needed_shards: std::collections::HashSet<String> = std::collections::HashSet::new();
249 for (tensor_name, shard_file) in weight_map {
250 let is_needed = layer_prefixes
251 .iter()
252 .any(|prefix| tensor_name.starts_with(&format!("{}.", prefix)));
253 if is_needed {
254 if let Some(filename) = shard_file.as_str() {
255 needed_shards.insert(filename.to_string());
256 }
257 }
258 }
259
260 let total_shards = weight_map
261 .values()
262 .filter_map(|v| v.as_str())
263 .collect::<std::collections::HashSet<_>>()
264 .len();
265
266 let filenames: Vec<PathBuf> = needed_shards.iter().map(|f| parent_dir.join(f)).collect();
267
268 log::info!(
269 "loading {} of {} shard file(s) for {} layers",
270 filenames.len(),
271 total_shards,
272 layer_prefixes.len()
273 );
274
275 prefetch_safetensors(&filenames)?;
276
277 if fp8 {
278 unsafe {
279 fp8::load_fp8_var_builder(&filenames, dtype, &device)
280 .map_err(|e| anyhow!("can't create fp8 varbuilder from tensors: {:?}", e))
281 }
282 } else {
283 unsafe {
284 VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)
285 .map_err(|e| anyhow!("can't create varbuilder from tensors: {:?}", e))
286 }
287 }
288}
289
290#[allow(dead_code)]
292pub(crate) fn panic_on_nan(t: &Tensor, name: &str) {
293 if t.to_string().contains("NaN") {
294 panic!("\ntensor '{name}' contains NaN: \n{t}");
295 }
296}