// candle_mi/backend.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Core backend trait and model wrapper.
4//!
5//! [`MIBackend`] is the trait that every model backend implements.
6//! [`MIModel`] wraps a backend with device metadata and convenience methods.
7
8use candle_core::{DType, Device, Tensor};
9
10use crate::error::{MIError, Result};
11use crate::hooks::{HookCache, HookSpec};
12use crate::tokenizer::MITokenizer;
13
14// ---------------------------------------------------------------------------
15// MIBackend trait
16// ---------------------------------------------------------------------------
17
/// Unified interface for model backends with hook-aware forward passes.
///
/// Implementing this trait is the only requirement for adding a new model
/// to candle-mi. The single [`forward`](Self::forward) method replaces
/// plip-rs's (frozen predecessor project, v1.4.0) proliferation of `forward_with_*` variants: the caller
/// specifies captures and interventions via [`HookSpec`], and the backend
/// returns a [`HookCache`] containing the output plus any requested
/// activations.
///
/// Optional capabilities (chat template, embedding access) have default
/// implementations that return `None` or an error.
///
/// Backends must be `Send + Sync` so a wrapped model can be shared
/// across threads.
pub trait MIBackend: Send + Sync {
    // --- Metadata --------------------------------------------------------

    /// Number of layers (transformer blocks or RWKV blocks).
    fn num_layers(&self) -> usize;

    /// Hidden dimension (`d_model`).
    fn hidden_size(&self) -> usize;

    /// Vocabulary size.
    fn vocab_size(&self) -> usize;

    /// Number of attention heads (or RWKV heads).
    fn num_heads(&self) -> usize;

    // --- Core forward pass -----------------------------------------------

    /// Unified forward pass with optional hook capture and interventions.
    ///
    /// When `hooks` is empty, this must be equivalent to a plain forward
    /// pass with **zero extra allocations** (see `design/hook-overhead.md`).
    ///
    /// The returned [`HookCache`] always contains the output tensor
    /// (logits or hidden states, depending on the backend) and any
    /// activations requested via [`HookSpec::capture`].
    ///
    /// # Shapes
    /// - `input_ids`: `[batch, seq]` -- token IDs
    /// - returns: [`HookCache`] containing `logits` at `[batch, seq, vocab_size]`
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Model`] on tensor operation failures and
    /// [`MIError::Intervention`] if an intervention is invalid for
    /// the current model dimensions.
    fn forward(&self, input_ids: &Tensor, hooks: &HookSpec) -> Result<HookCache>;

    // --- Logit projection ------------------------------------------------

    /// Project a hidden-state tensor to vocabulary logits.
    ///
    /// Applies the model's final layer norm before the unembedding projection,
    /// matching the standard logit lens technique (nostalgebraist, 2020).
    ///
    /// # Shapes
    /// - `hidden`: `[batch, hidden_size]` -- hidden states (pre-norm)
    /// - returns: `[batch, vocab_size]`
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Model`] on shape mismatch or tensor operation failure.
    fn project_to_vocab(&self, hidden: &Tensor) -> Result<Tensor>;

    // --- Optional capabilities -------------------------------------------

    /// Format a prompt with the model's chat template, if any.
    ///
    /// Returns `None` for base (non-instruct) models; the default
    /// implementation is that base-model behavior.
    fn chat_template(&self, _prompt: &str, _system_prompt: Option<&str>) -> Option<String> {
        None
    }

    /// Return the raw embedding vector for a single token.
    ///
    /// For models with tied embeddings this is also the unembedding direction.
    ///
    /// # Shapes
    /// - returns: `[hidden_size]`
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Hook`] if the backend does not support this.
    /// The default implementation always returns that error.
    fn embedding_vector(&self, _token_id: u32) -> Result<Tensor> {
        Err(MIError::Hook(
            "embedding_vector not supported for this backend".into(),
        ))
    }
}
107
108// ---------------------------------------------------------------------------
109// MIModel
110// ---------------------------------------------------------------------------
111
/// High-level model wrapper combining a backend with device metadata.
///
/// `MIModel` delegates to the wrapped [`MIBackend`] and adds convenience
/// methods including [`from_pretrained`](Self::from_pretrained) for
/// one-line model loading from `HuggingFace`.
pub struct MIModel {
    /// The underlying model backend.
    // TRAIT_OBJECT: heterogeneous model backends require dynamic dispatch
    backend: Box<dyn MIBackend>,
    /// The device this model lives on.
    device: Device,
    /// Tokenizer loaded alongside the model (present when loaded via `from_pretrained`
    /// and a `tokenizer.json` was found; `None` when built via `new`).
    tokenizer: Option<MITokenizer>,
}
126
127impl MIModel {
    /// Load a model from a `HuggingFace` model ID or local path.
    ///
    /// Checks local `HuggingFace` cache first, then downloads if necessary.
    /// Automatically selects the appropriate backend based on `model_type`
    /// in the model's `config.json`.
    ///
    /// # `DType` selection
    ///
    /// Always uses `F32` for research-grade precision — numerically identical
    /// to Python/PyTorch F32 on both CPU and CUDA. Models up to ~7B fit in
    /// 16 GB VRAM at F32. For larger models or when speed matters more than
    /// precision, use the backend-specific `load()` API with `DType::BF16`.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Config`] if the model type is unsupported, or
    /// [`MIError::Model`] if weight loading fails.
    #[cfg(any(feature = "transformer", feature = "rwkv"))]
    pub fn from_pretrained(model_id: &str) -> Result<Self> {
        // --- Device and dtype ---
        let device = Self::select_device()?;
        // F32 everywhere: research-grade precision, matching Python/PyTorch.
        let dtype = DType::F32;

        // --- Download / resolve local files ---
        // hf-fetch-model 0.9.x requires explicit opt-in to HF_TOKEN; go through
        // the shared builder so gated models (Llama/Mistral/Gemma/Qwen) work.
        let fetch_config = crate::download::fetch_config_builder()
            .build()
            .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
        let files =
            hf_fetch_model::download_files_with_config_blocking(model_id.to_owned(), &fetch_config)
                .map(hf_fetch_model::DownloadOutcome::into_inner)
                .map_err(|e| MIError::Download(e.to_string()))?;

        // Parse config.json to discover which architecture this checkpoint is.
        let config_path = files
            .get("config.json")
            .ok_or_else(|| MIError::Config("config.json not found in downloaded files".into()))?;
        let config_str = std::fs::read_to_string(config_path)
            .map_err(|e| MIError::Config(format!("read config.json: {e}")))?;
        let json: serde_json::Value = serde_json::from_str(&config_str)
            .map_err(|e| MIError::Config(format!("parse config.json: {e}")))?;

        // --- Dispatch on model_type ---
        let model_type = json
            .get("model_type")
            .and_then(serde_json::Value::as_str)
            .ok_or_else(|| MIError::Config("missing 'model_type' field".into()))?;

        // --- Load tokenizer (best-effort: present for HF models) ---
        // Errors are deliberately swallowed via `.ok()`: a missing or broken
        // tokenizer.json only disables text-based helpers, not model loading.
        let tokenizer = files
            .get("tokenizer.json")
            .and_then(|p| MITokenizer::from_hf_path(p).ok());

        let weights_paths = resolve_safetensors_paths(&files)?;
        let vb = create_var_builder(&weights_paths, dtype, &device)?;

        // Arms are tried top to bottom: explicitly supported transformer types,
        // then RWKV types, then (with the `transformer` feature) a catch-all
        // that attempts auto-config inference from tensor names. Without the
        // `transformer` feature the catch-all is a hard error instead.
        match model_type {
            #[cfg(feature = "transformer")]
            mt if crate::config::SUPPORTED_MODEL_TYPES.contains(&mt) => {
                use crate::config::TransformerConfig;
                use crate::transformer::GenericTransformer;

                let config = TransformerConfig::from_hf_config(&json)?;
                let transformer = GenericTransformer::load(config, &device, dtype, vb)?;
                Ok(Self::with_tokenizer(
                    Box::new(transformer),
                    device,
                    tokenizer,
                ))
            }
            #[cfg(feature = "rwkv")]
            mt if crate::rwkv::SUPPORTED_RWKV_MODEL_TYPES.contains(&mt) => {
                use crate::rwkv::{GenericRwkv, RwkvConfig};

                let config = RwkvConfig::from_hf_config(&json)?;
                let rwkv = GenericRwkv::load(config, &device, dtype, vb)?;
                Ok(Self::with_tokenizer(Box::new(rwkv), device, tokenizer))
            }
            #[cfg(feature = "transformer")]
            _unknown => {
                use crate::config::TransformerConfig;
                use crate::transformer::GenericTransformer;

                // Extract tensor names for auto-config inference
                let tensor_names = extract_tensor_names(&files)?;

                // Preflight: check compatibility before attempting to load
                TransformerConfig::check_auto_compatibility(&json, &tensor_names).into_result()?;

                let config = TransformerConfig::from_hf_config_auto(&json, &tensor_names)?;
                let transformer = GenericTransformer::load(config, &device, dtype, vb)?;
                Ok(Self::with_tokenizer(
                    Box::new(transformer),
                    device,
                    tokenizer,
                ))
            }
            #[cfg(not(feature = "transformer"))]
            other => Err(MIError::Config(format!(
                "unsupported model_type: '{other}' (enable the `transformer` feature for auto-config)"
            ))),
        }
    }
232
233 /// Select the best available device (CUDA GPU 0, or CPU fallback).
234 ///
235 /// # Errors
236 ///
237 /// Returns [`MIError::Model`] on device detection failure.
238 #[cfg(any(feature = "transformer", feature = "rwkv"))]
239 fn select_device() -> Result<Device> {
240 match Device::cuda_if_available(0) {
241 Ok(dev) => Ok(dev),
242 Err(e) => Err(MIError::Model(e)),
243 }
244 }
245
246 /// Wrap an existing backend (no tokenizer).
247 // TRAIT_OBJECT: heterogeneous model backends require dynamic dispatch
248 #[must_use]
249 pub fn new(backend: Box<dyn MIBackend>, device: Device) -> Self {
250 Self {
251 backend,
252 device,
253 tokenizer: None,
254 }
255 }
256
257 /// Wrap an existing backend with an optional tokenizer.
258 // TRAIT_OBJECT: heterogeneous model backends require dynamic dispatch
259 #[must_use]
260 pub fn with_tokenizer(
261 backend: Box<dyn MIBackend>,
262 device: Device,
263 tokenizer: Option<MITokenizer>,
264 ) -> Self {
265 Self {
266 backend,
267 device,
268 tokenizer,
269 }
270 }
271
272 /// The device this model lives on.
273 #[must_use]
274 pub const fn device(&self) -> &Device {
275 &self.device
276 }
277
278 /// The tokenizer loaded alongside the model, if available.
279 ///
280 /// Present when the model was loaded via [`from_pretrained`](Self::from_pretrained)
281 /// and a `tokenizer.json` was found in the downloaded files.
282 #[must_use]
283 pub const fn tokenizer(&self) -> Option<&MITokenizer> {
284 self.tokenizer.as_ref()
285 }
286
287 /// Number of layers.
288 #[must_use]
289 pub fn num_layers(&self) -> usize {
290 self.backend.num_layers()
291 }
292
293 /// Hidden dimension.
294 #[must_use]
295 pub fn hidden_size(&self) -> usize {
296 self.backend.hidden_size()
297 }
298
299 /// Vocabulary size.
300 #[must_use]
301 pub fn vocab_size(&self) -> usize {
302 self.backend.vocab_size()
303 }
304
305 /// Number of attention heads.
306 #[must_use]
307 pub fn num_heads(&self) -> usize {
308 self.backend.num_heads()
309 }
310
311 /// Run a forward pass with the given hook specification.
312 ///
313 /// # Shapes
314 /// - `input_ids`: `[batch, seq]` -- token IDs
315 /// - returns: [`HookCache`] containing `logits` at `[batch, seq, vocab_size]`
316 ///
317 /// # Errors
318 ///
319 /// Propagates errors from the underlying backend.
320 pub fn forward(&self, input_ids: &Tensor, hooks: &HookSpec) -> Result<HookCache> {
321 self.backend.forward(input_ids, hooks)
322 }
323
324 /// Project hidden states to vocabulary logits.
325 ///
326 /// Applies the model's final layer norm before the unembedding projection,
327 /// matching the standard logit lens technique (nostalgebraist, 2020).
328 ///
329 /// # Shapes
330 /// - `hidden`: `[batch, hidden_size]` -- hidden states (pre-norm)
331 /// - returns: `[batch, vocab_size]`
332 ///
333 /// # Errors
334 ///
335 /// Propagates errors from the underlying backend.
336 pub fn project_to_vocab(&self, hidden: &Tensor) -> Result<Tensor> {
337 self.backend.project_to_vocab(hidden)
338 }
339
340 /// Access the underlying backend (e.g., for backend-specific methods).
341 // TRAIT_OBJECT: caller needs dynamic dispatch for backend-specific methods
342 #[must_use]
343 pub fn backend(&self) -> &dyn MIBackend {
344 &*self.backend
345 }
346
347 /// Run a forward pass from text, returning both MI outputs and token
348 /// position mapping.
349 ///
350 /// Combines [`MITokenizer::encode_with_offsets`](crate::MITokenizer::encode_with_offsets)
351 /// + tensor creation + [`forward`](Self::forward) in a single call.
352 ///
353 /// The returned [`TextForwardResult`] carries the [`HookCache`]
354 /// alongside the [`EncodingWithOffsets`](crate::EncodingWithOffsets),
355 /// eliminating the need for separate encoding and manual offset
356 /// tracking.
357 ///
358 /// # Shapes
359 /// - input: text string
360 /// - returns: [`TextForwardResult`] with logits at `[1, seq, vocab_size]`
361 ///
362 /// # Errors
363 ///
364 /// Returns [`MIError::Config`] if no tokenizer is available.
365 /// Returns [`MIError::Tokenizer`] on encoding failure.
366 /// Propagates errors from the underlying backend.
367 pub fn forward_text(&self, text: &str, hooks: &HookSpec) -> Result<TextForwardResult> {
368 let tokenizer = self
369 .tokenizer()
370 .ok_or_else(|| MIError::Config("forward_text requires a tokenizer".into()))?;
371 let encoding = tokenizer.encode_with_offsets(text)?;
372 let input = Tensor::new(&encoding.ids[..], &self.device)?.unsqueeze(0)?;
373 let cache = self.forward(&input, hooks)?;
374 Ok(TextForwardResult { cache, encoding })
375 }
376}
377
378// ---------------------------------------------------------------------------
379// TextForwardResult
380// ---------------------------------------------------------------------------
381
/// Result of a text-based forward pass, bundling MI outputs with token
/// position mapping.
///
/// Returned by [`MIModel::forward_text`]. Provides access to the
/// [`HookCache`] (logits + captured activations) alongside the
/// [`EncodingWithOffsets`](crate::EncodingWithOffsets) (token strings,
/// IDs, and byte offset ranges for mapping between source text positions
/// and token indices).
#[derive(Debug)]
pub struct TextForwardResult {
    /// Hook cache containing output logits and any captured activations.
    cache: HookCache,
    /// Token encoding with character offset mapping (see `encoding()` accessor).
    encoding: crate::util::positioning::EncodingWithOffsets,
}
397
398impl TextForwardResult {
399 /// Access the hook cache (logits + captured activations).
400 #[must_use]
401 pub const fn cache(&self) -> &HookCache {
402 &self.cache
403 }
404
405 /// Consume the result and return the hook cache.
406 #[must_use]
407 pub fn into_cache(self) -> HookCache {
408 self.cache
409 }
410
411 /// Access the token encoding with character offset mapping.
412 #[must_use]
413 pub const fn encoding(&self) -> &crate::util::positioning::EncodingWithOffsets {
414 &self.encoding
415 }
416
417 /// The output tensor from the forward pass (typically logits).
418 ///
419 /// Shortcut for `self.cache().output()`.
420 ///
421 /// # Shapes
422 /// - returns: `[1, seq, vocab_size]`
423 #[must_use]
424 pub const fn output(&self) -> &Tensor {
425 self.cache.output()
426 }
427
428 /// Retrieve a captured tensor by hook point, returning an error if
429 /// not found.
430 ///
431 /// Shortcut for `self.cache().require(hook)`.
432 ///
433 /// # Errors
434 ///
435 /// Returns [`MIError::Hook`] if the hook point was not captured.
436 pub fn require(&self, hook: &crate::hooks::HookPoint) -> Result<&Tensor> {
437 self.cache.require(hook)
438 }
439
440 /// Retrieve a captured tensor by hook point.
441 ///
442 /// Shortcut for `self.cache().get(hook)`.
443 #[must_use]
444 pub fn get(&self, hook: &crate::hooks::HookPoint) -> Option<&Tensor> {
445 self.cache.get(hook)
446 }
447
448 /// The raw BPE token strings (with space-prefix markers like `Ġ`).
449 ///
450 /// Shortcut for `self.encoding().tokens`.
451 #[must_use]
452 pub fn tokens(&self) -> &[String] {
453 &self.encoding.tokens
454 }
455
456 /// Number of tokens in the encoded sequence.
457 #[must_use]
458 pub const fn seq_len(&self) -> usize {
459 self.encoding.len()
460 }
461}
462
463// ---------------------------------------------------------------------------
464// Sampling helpers
465// ---------------------------------------------------------------------------
466
467/// Sample a token from logits using the given temperature.
468///
469/// When `temperature <= 0.0`, performs greedy (argmax) decoding.
470///
471/// # Shapes
472/// - `logits`: `[vocab_size]` -- logit scores for each vocabulary token
473///
474/// # Errors
475///
476/// Returns [`MIError::Model`] if the logits tensor is empty or
477/// cannot be converted to `f32`.
478pub fn sample_token(logits: &Tensor, temperature: f32) -> Result<u32> {
479 if temperature <= 0.0 {
480 argmax(logits)
481 } else {
482 sample_with_temperature(logits, temperature)
483 }
484}
485
486/// Greedy (argmax) sampling.
487fn argmax(logits: &Tensor) -> Result<u32> {
488 let logits_f32 = logits.to_dtype(DType::F32)?;
489 let logits_vec: Vec<f32> = logits_f32.flatten_all()?.to_vec1()?;
490
491 let (max_idx, _) = logits_vec
492 .iter()
493 .enumerate()
494 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
495 .ok_or_else(|| MIError::Model(candle_core::Error::Msg("empty logits".into())))?;
496
497 // CAST: usize → u32, vocab size fits in u32 (max ~250K tokens)
498 #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
499 Ok(max_idx as u32)
500}
501
/// Temperature-scaled softmax sampling.
///
/// Scales the logits by `1 / temperature`, applies a numerically stable
/// softmax over a host-side `Vec<f32>`, then draws one index from the
/// resulting categorical distribution via inverse-CDF sampling.
///
/// The caller ([`sample_token`]) only routes here when `temperature` is
/// positive (or NaN); a zero temperature would divide by zero.
///
/// # Errors
///
/// Returns [`MIError::Model`] if the logits tensor is empty or cannot be
/// converted to `f32`.
fn sample_with_temperature(logits: &Tensor, temperature: f32) -> Result<u32> {
    use rand::Rng;

    let logits_f32 = logits.to_dtype(DType::F32)?;
    let logits_vec: Vec<f32> = logits_f32.flatten_all()?.to_vec1()?;

    if logits_vec.is_empty() {
        return Err(MIError::Model(candle_core::Error::Msg(
            "empty logits".into(),
        )));
    }

    // Scale by temperature.
    let scaled: Vec<f32> = logits_vec.iter().map(|x| x / temperature).collect();

    // Numerically stable softmax: subtract the max before exponentiating.
    let max_val = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exp_vals: Vec<f32> = scaled.iter().map(|x| (x - max_val).exp()).collect();
    let sum: f32 = exp_vals.iter().sum();
    let probs: Vec<f32> = exp_vals.iter().map(|x| x / sum).collect();

    // Sample from the categorical distribution: walk the CDF until it
    // exceeds a uniform draw in [0, 1).
    // (`gen` is a reserved keyword in edition 2024, hence the raw identifier.)
    let mut rng = rand::thread_rng();
    let r: f32 = rng.r#gen();
    let mut cumsum = 0.0;
    for (idx, &p) in probs.iter().enumerate() {
        cumsum += p;
        if r < cumsum {
            // CAST: usize → u32, vocab index fits in u32
            #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
            return Ok(idx as u32);
        }
    }

    // Fallback to last token (floating-point rounding edge case).
    // CAST: usize → u32, vocab index fits in u32
    #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
    Ok((probs.len() - 1) as u32)
}
542
543/// Extract the probability of a specific token from a logit tensor.
544///
545/// Applies softmax to the last sequence position and returns the probability
546/// at `token_id`. Useful for measuring steering effectiveness.
547///
548/// # Shapes
549/// - `logits`: `[1, seq_len, vocab]` or `[seq_len, vocab]` or `[vocab]`
550///
551/// # Errors
552///
553/// Returns [`MIError::Model`] on shape mismatch or tensor operation failure.
554pub fn extract_token_prob(logits: &Tensor, token_id: u32) -> Result<f32> {
555 use candle_core::IndexOp;
556
557 let logits_f32 = logits.to_dtype(DType::F32)?;
558
559 // Get last-position logits as a 1-D vector.
560 let last_logits = match logits_f32.dims().len() {
561 1 => logits_f32,
562 2 => {
563 let seq_len = logits_f32.dim(0)?;
564 logits_f32.i(seq_len - 1)?
565 }
566 3 => {
567 let seq_len = logits_f32.dim(1)?;
568 logits_f32.i((0, seq_len - 1))?
569 }
570 n => {
571 return Err(MIError::Model(candle_core::Error::Msg(format!(
572 "extract_token_prob: expected 1-3 dims, got {n}"
573 ))));
574 }
575 };
576
577 let probs = candle_nn::ops::softmax_last_dim(&last_logits)?;
578 // CAST: u32 → usize, token ID used as tensor index
579 #[allow(clippy::as_conversions)]
580 let prob = probs.i(token_id as usize)?.to_scalar::<f32>()?;
581 Ok(prob)
582}
583
584// ---------------------------------------------------------------------------
585// GenerationResult
586// ---------------------------------------------------------------------------
587
/// Output of a text generation run with token-level details.
#[derive(Debug, Clone)]
pub struct GenerationResult {
    /// Original prompt text.
    pub prompt: String,
    /// Full output (prompt + generated).
    pub full_text: String,
    /// Only the generated portion.
    pub generated_text: String,
    /// Token IDs from the prompt.
    pub prompt_tokens: Vec<u32>,
    /// Token IDs that were generated.
    pub generated_tokens: Vec<u32>,
    /// Total token count (prompt + generated).
    // NOTE(review): presumably prompt_tokens.len() + generated_tokens.len();
    // the producer is not in this file — confirm before relying on it.
    pub total_tokens: usize,
}
604
605// ---------------------------------------------------------------------------
606// Weight loading helpers (used by from_pretrained)
607// ---------------------------------------------------------------------------
608
/// Index structure for sharded safetensors models.
///
/// Mirrors the subset of `model.safetensors.index.json` we need; any other
/// top-level fields (e.g. `metadata`) are ignored by serde's default
/// unknown-field handling.
#[cfg(any(feature = "transformer", feature = "rwkv"))]
#[derive(serde::Deserialize)]
struct SafetensorsIndex {
    /// Maps weight name → shard filename.
    weight_map: std::collections::HashMap<String, String>,
}
616
/// Extract tensor names from a downloaded file map for auto-config inference.
///
/// Tries `model.safetensors.index.json` first (sharded models), falls back
/// to reading the header of `model.safetensors` (single-file models).
#[cfg(feature = "transformer")]
fn extract_tensor_names(
    files: &std::collections::HashMap<String, std::path::PathBuf>,
) -> Result<Vec<String>> {
    match (
        files.get("model.safetensors.index.json"),
        files.get("model.safetensors"),
    ) {
        // Sharded checkpoint: the index lists every tensor by name.
        (Some(index_path), _) => crate::config::tensor_names_from_index(index_path),
        // Single-file checkpoint: read names from the safetensors header.
        (None, Some(st_path)) => crate::config::tensor_names_from_safetensors(st_path),
        (None, None) => Err(MIError::Config(
            "no safetensors files found for tensor name extraction".into(),
        )),
    }
}
635
/// Resolve safetensors file paths from a downloaded file map.
///
/// Tries `model.safetensors.index.json` first (sharded), falls back to
/// single `model.safetensors`.
#[cfg(any(feature = "transformer", feature = "rwkv"))]
fn resolve_safetensors_paths(
    files: &std::collections::HashMap<String, std::path::PathBuf>,
) -> Result<Vec<std::path::PathBuf>> {
    // Sharded checkpoint: the index maps weight names to shard filenames.
    if let Some(index_path) = files.get("model.safetensors.index.json") {
        let index_str = std::fs::read_to_string(index_path)
            .map_err(|e| MIError::Model(candle_core::Error::Msg(format!("read index: {e}"))))?;
        let index: SafetensorsIndex = serde_json::from_str(&index_str)
            .map_err(|e| MIError::Config(format!("parse index: {e}")))?;

        // BTreeSet yields each shard name exactly once, in lexicographic
        // order — equivalent to collect + sort + dedup.
        let shard_names: std::collections::BTreeSet<&String> =
            index.weight_map.values().collect();

        return shard_names
            .into_iter()
            .map(|shard_name| {
                files.get(shard_name.as_str()).cloned().ok_or_else(|| {
                    MIError::Model(candle_core::Error::Msg(format!(
                        "shard {shard_name} not found in downloaded files"
                    )))
                })
            })
            .collect();
    }

    // Single-file checkpoint.
    let path = files.get("model.safetensors").ok_or_else(|| {
        MIError::Model(candle_core::Error::Msg(
            "model.safetensors not found in downloaded files".into(),
        ))
    })?;
    Ok(vec![path.clone()])
}
679
/// Create a `VarBuilder` from safetensors file paths.
///
/// Uses buffered (safe) loading by default. With the `mmap` feature,
/// uses memory-mapped loading for reduced memory overhead on large models.
///
/// Exactly one of the two branches below is compiled in, selected by the
/// `mmap` feature flag; both delegates share this signature.
#[cfg(any(feature = "transformer", feature = "rwkv"))]
fn create_var_builder(
    paths: &[std::path::PathBuf],
    dtype: DType,
    device: &Device,
) -> Result<candle_nn::VarBuilder<'static>> {
    #[cfg(feature = "mmap")]
    {
        mmap_var_builder(paths, dtype, device)
    }
    #[cfg(not(feature = "mmap"))]
    {
        buffered_var_builder(paths, dtype, device)
    }
}
699
/// Load weights via buffered (safe) reading — reads all data into RAM.
///
/// Only supports single-file models. For sharded models (7B+), enable
/// the `mmap` feature.
#[cfg(all(any(feature = "transformer", feature = "rwkv"), not(feature = "mmap")))]
fn buffered_var_builder(
    paths: &[std::path::PathBuf],
    dtype: DType,
    device: &Device,
) -> Result<candle_nn::VarBuilder<'static>> {
    // Exactly one path is loadable here; zero or many are distinct errors.
    let path = match paths {
        [single] => single,
        [] => {
            return Err(MIError::Model(candle_core::Error::Msg(
                "no safetensors files".into(),
            )));
        }
        sharded => {
            return Err(MIError::Config(format!(
                "this model is sharded across {} files and requires the `mmap` feature.\n \
                 Library: candle-mi = {{ features = [\"mmap\"] }}\n \
                 Example: cargo run --features mmap --example <name>",
                sharded.len()
            )));
        }
    };
    let data = std::fs::read(path).map_err(|e| {
        MIError::Model(candle_core::Error::Msg(format!(
            "read {}: {e}",
            path.display()
        )))
    })?;
    Ok(candle_nn::VarBuilder::from_buffered_safetensors(
        data, dtype, device,
    )?)
}
730
/// Load weights via memory-mapped files — minimal RAM overhead for large models.
///
/// Handles both single-file and sharded checkpoints (one mapping per path).
///
/// # Safety
///
/// The safetensors files must not be modified while the model is loaded.
/// This is the standard invariant for memory-mapped files.
#[cfg(all(any(feature = "transformer", feature = "rwkv"), feature = "mmap"))]
#[allow(unsafe_code)]
fn mmap_var_builder(
    paths: &[std::path::PathBuf],
    dtype: DType,
    device: &Device,
) -> Result<candle_nn::VarBuilder<'static>> {
    // SAFETY: safetensors files must not be modified while loaded.
    let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(paths, dtype, device)? };
    Ok(vb)
}