// candle_mi/backend.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Core backend trait and model wrapper.
4//!
5//! [`MIBackend`] is the trait that every model backend implements.
6//! [`MIModel`] wraps a backend with device metadata and convenience methods.
7
8use candle_core::{DType, Device, Tensor};
9
10use crate::error::{MIError, Result};
11use crate::hooks::{HookCache, HookSpec};
12use crate::tokenizer::MITokenizer;
13
14// ---------------------------------------------------------------------------
15// MIBackend trait
16// ---------------------------------------------------------------------------
17
/// Unified interface for model backends with hook-aware forward passes.
///
/// Implementing this trait is the only requirement for adding a new model
/// to candle-mi.  The single [`forward`](Self::forward) method replaces
/// plip-rs's (frozen predecessor project, v1.4.0) proliferation of `forward_with_*` variants: the caller
/// specifies captures and interventions via [`HookSpec`], and the backend
/// returns a [`HookCache`] containing the output plus any requested
/// activations.
///
/// Optional capabilities (chat template, embedding access) have default
/// implementations that return `None` or an error.
///
/// The `Send + Sync` bound requires implementations to be thread-safe.
pub trait MIBackend: Send + Sync {
    // --- Metadata --------------------------------------------------------

    /// Number of layers (transformer blocks or RWKV blocks).
    fn num_layers(&self) -> usize;

    /// Hidden dimension (`d_model`).
    fn hidden_size(&self) -> usize;

    /// Vocabulary size.
    fn vocab_size(&self) -> usize;

    /// Number of attention heads (or RWKV heads).
    fn num_heads(&self) -> usize;

    // --- Core forward pass -----------------------------------------------

    /// Unified forward pass with optional hook capture and interventions.
    ///
    /// When `hooks` is empty, this must be equivalent to a plain forward
    /// pass with **zero extra allocations** (see `design/hook-overhead.md`).
    ///
    /// The returned [`HookCache`] always contains the output tensor
    /// (logits or hidden states, depending on the backend) and any
    /// activations requested via [`HookSpec::capture`].
    ///
    /// # Shapes
    /// - `input_ids`: `[batch, seq]` -- token IDs
    /// - returns: [`HookCache`] containing `logits` at `[batch, seq, vocab_size]`
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Model`] on tensor operation failures and
    /// [`MIError::Intervention`] if an intervention is invalid for
    /// the current model dimensions.
    fn forward(&self, input_ids: &Tensor, hooks: &HookSpec) -> Result<HookCache>;

    // --- Logit projection ------------------------------------------------

    /// Project a hidden-state tensor to vocabulary logits.
    ///
    /// Applies the model's final layer norm before the unembedding projection,
    /// matching the standard logit lens technique (nostalgebraist, 2020).
    ///
    /// # Shapes
    /// - `hidden`: `[batch, hidden_size]` -- hidden states (pre-norm)
    /// - returns: `[batch, vocab_size]`
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Model`] on shape mismatch or tensor operation failure.
    fn project_to_vocab(&self, hidden: &Tensor) -> Result<Tensor>;

    // --- Optional capabilities -------------------------------------------

    /// Format a prompt with the model's chat template, if any.
    ///
    /// Returns `None` for base (non-instruct) models.
    fn chat_template(&self, _prompt: &str, _system_prompt: Option<&str>) -> Option<String> {
        // Default: no chat template (base-model behavior).
        None
    }

    /// Return the raw embedding vector for a single token.
    ///
    /// For models with tied embeddings this is also the unembedding direction.
    ///
    /// # Shapes
    /// - returns: `[hidden_size]`
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Hook`] if the backend does not support this.
    fn embedding_vector(&self, _token_id: u32) -> Result<Tensor> {
        // Default: capability not supported; backends opt in by overriding.
        Err(MIError::Hook(
            "embedding_vector not supported for this backend".into(),
        ))
    }
}
107
108// ---------------------------------------------------------------------------
109// MIModel
110// ---------------------------------------------------------------------------
111
/// High-level model wrapper combining a backend with device metadata.
///
/// `MIModel` delegates to the wrapped [`MIBackend`] and adds convenience
/// methods including [`from_pretrained`](Self::from_pretrained) for
/// one-line model loading from `HuggingFace`.
pub struct MIModel {
    /// The underlying model backend.
    // TRAIT_OBJECT: heterogeneous model backends require dynamic dispatch
    backend: Box<dyn MIBackend>,
    /// The device this model lives on.
    device: Device,
    /// Tokenizer loaded alongside the model (present when loaded via `from_pretrained`).
    /// `None` when constructed via `new` or when no `tokenizer.json` was found.
    tokenizer: Option<MITokenizer>,
}
126
127impl MIModel {
    /// Load a model from a `HuggingFace` model ID or local path.
    ///
    /// Checks local `HuggingFace` cache first, then downloads if necessary.
    /// Automatically selects the appropriate backend based on `model_type`
    /// in the model's `config.json`.
    ///
    /// # `DType` selection
    ///
    /// Always uses `F32` for research-grade precision — numerically identical
    /// to Python/PyTorch F32 on both CPU and CUDA.  Models up to ~7B fit in
    /// 16 GB VRAM at F32.  For larger models or when speed matters more than
    /// precision, use the backend-specific `load()` API with `DType::BF16`.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Config`] if the model type is unsupported, or
    /// [`MIError::Model`] if weight loading fails.
    #[cfg(any(feature = "transformer", feature = "rwkv"))]
    pub fn from_pretrained(model_id: &str) -> Result<Self> {
        // --- Device and dtype ---
        let device = Self::select_device()?;
        // F32 everywhere: research-grade precision, matching Python/PyTorch.
        let dtype = DType::F32;

        // --- Download / resolve local files ---
        // hf-fetch-model 0.9.x requires explicit opt-in to HF_TOKEN; go through
        // the shared builder so gated models (Llama/Mistral/Gemma/Qwen) work.
        let fetch_config = crate::download::fetch_config_builder()
            .build()
            .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
        // `files` maps repo-relative filenames to local filesystem paths.
        let files =
            hf_fetch_model::download_files_with_config_blocking(model_id.to_owned(), &fetch_config)
                .map(hf_fetch_model::DownloadOutcome::into_inner)
                .map_err(|e| MIError::Download(e.to_string()))?;

        let config_path = files
            .get("config.json")
            .ok_or_else(|| MIError::Config("config.json not found in downloaded files".into()))?;
        let config_str = std::fs::read_to_string(config_path)
            .map_err(|e| MIError::Config(format!("read config.json: {e}")))?;
        let json: serde_json::Value = serde_json::from_str(&config_str)
            .map_err(|e| MIError::Config(format!("parse config.json: {e}")))?;

        // --- Dispatch on model_type ---
        let model_type = json
            .get("model_type")
            .and_then(serde_json::Value::as_str)
            .ok_or_else(|| MIError::Config("missing 'model_type' field".into()))?;

        // --- Load tokenizer (best-effort: present for HF models) ---
        // `.ok()` deliberately swallows tokenizer load failures: a model
        // without a usable tokenizer.json is still loadable via raw IDs.
        let tokenizer = files
            .get("tokenizer.json")
            .and_then(|p| MITokenizer::from_hf_path(p).ok());

        let weights_paths = resolve_safetensors_paths(&files)?;
        let vb = create_var_builder(&weights_paths, dtype, &device)?;

        // NOTE: guard-arm order matters — the known transformer / RWKV
        // model types must be tried before the `_unknown` auto-config
        // fallback arm below.
        match model_type {
            #[cfg(feature = "transformer")]
            mt if crate::config::SUPPORTED_MODEL_TYPES.contains(&mt) => {
                use crate::config::TransformerConfig;
                use crate::transformer::GenericTransformer;

                let config = TransformerConfig::from_hf_config(&json)?;
                let transformer = GenericTransformer::load(config, &device, dtype, vb)?;
                Ok(Self::with_tokenizer(
                    Box::new(transformer),
                    device,
                    tokenizer,
                ))
            }
            #[cfg(feature = "rwkv")]
            mt if crate::rwkv::SUPPORTED_RWKV_MODEL_TYPES.contains(&mt) => {
                use crate::rwkv::{GenericRwkv, RwkvConfig};

                let config = RwkvConfig::from_hf_config(&json)?;
                let rwkv = GenericRwkv::load(config, &device, dtype, vb)?;
                Ok(Self::with_tokenizer(Box::new(rwkv), device, tokenizer))
            }
            #[cfg(feature = "transformer")]
            _unknown => {
                use crate::config::TransformerConfig;
                use crate::transformer::GenericTransformer;

                // Extract tensor names for auto-config inference
                let tensor_names = extract_tensor_names(&files)?;

                // Preflight: check compatibility before attempting to load
                TransformerConfig::check_auto_compatibility(&json, &tensor_names).into_result()?;

                let config = TransformerConfig::from_hf_config_auto(&json, &tensor_names)?;
                let transformer = GenericTransformer::load(config, &device, dtype, vb)?;
                Ok(Self::with_tokenizer(
                    Box::new(transformer),
                    device,
                    tokenizer,
                ))
            }
            // Without the `transformer` feature there is no auto-config
            // fallback, so any unrecognized model type is an error.
            #[cfg(not(feature = "transformer"))]
            other => Err(MIError::Config(format!(
                "unsupported model_type: '{other}' (enable the `transformer` feature for auto-config)"
            ))),
        }
    }
232
233    /// Select the best available device (CUDA GPU 0, or CPU fallback).
234    ///
235    /// # Errors
236    ///
237    /// Returns [`MIError::Model`] on device detection failure.
238    #[cfg(any(feature = "transformer", feature = "rwkv"))]
239    fn select_device() -> Result<Device> {
240        match Device::cuda_if_available(0) {
241            Ok(dev) => Ok(dev),
242            Err(e) => Err(MIError::Model(e)),
243        }
244    }
245
246    /// Wrap an existing backend (no tokenizer).
247    // TRAIT_OBJECT: heterogeneous model backends require dynamic dispatch
248    #[must_use]
249    pub fn new(backend: Box<dyn MIBackend>, device: Device) -> Self {
250        Self {
251            backend,
252            device,
253            tokenizer: None,
254        }
255    }
256
257    /// Wrap an existing backend with an optional tokenizer.
258    // TRAIT_OBJECT: heterogeneous model backends require dynamic dispatch
259    #[must_use]
260    pub fn with_tokenizer(
261        backend: Box<dyn MIBackend>,
262        device: Device,
263        tokenizer: Option<MITokenizer>,
264    ) -> Self {
265        Self {
266            backend,
267            device,
268            tokenizer,
269        }
270    }
271
    /// The device this model lives on.
    #[must_use]
    pub const fn device(&self) -> &Device {
        &self.device
    }

    /// The tokenizer loaded alongside the model, if available.
    ///
    /// Present when the model was loaded via [`from_pretrained`](Self::from_pretrained)
    /// and a `tokenizer.json` was found in the downloaded files.
    #[must_use]
    pub const fn tokenizer(&self) -> Option<&MITokenizer> {
        self.tokenizer.as_ref()
    }

    /// Number of layers.  Delegates to [`MIBackend::num_layers`].
    #[must_use]
    pub fn num_layers(&self) -> usize {
        self.backend.num_layers()
    }

    /// Hidden dimension.  Delegates to [`MIBackend::hidden_size`].
    #[must_use]
    pub fn hidden_size(&self) -> usize {
        self.backend.hidden_size()
    }

    /// Vocabulary size.  Delegates to [`MIBackend::vocab_size`].
    #[must_use]
    pub fn vocab_size(&self) -> usize {
        self.backend.vocab_size()
    }

    /// Number of attention heads.  Delegates to [`MIBackend::num_heads`].
    #[must_use]
    pub fn num_heads(&self) -> usize {
        self.backend.num_heads()
    }
310
    /// Run a forward pass with the given hook specification.
    ///
    /// Thin delegation to [`MIBackend::forward`] on the wrapped backend.
    ///
    /// # Shapes
    /// - `input_ids`: `[batch, seq]` -- token IDs
    /// - returns: [`HookCache`] containing `logits` at `[batch, seq, vocab_size]`
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying backend.
    pub fn forward(&self, input_ids: &Tensor, hooks: &HookSpec) -> Result<HookCache> {
        self.backend.forward(input_ids, hooks)
    }
323
    /// Project hidden states to vocabulary logits.
    ///
    /// Thin delegation to [`MIBackend::project_to_vocab`] on the wrapped
    /// backend.  Applies the model's final layer norm before the unembedding
    /// projection, matching the standard logit lens technique
    /// (nostalgebraist, 2020).
    ///
    /// # Shapes
    /// - `hidden`: `[batch, hidden_size]` -- hidden states (pre-norm)
    /// - returns: `[batch, vocab_size]`
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying backend.
    pub fn project_to_vocab(&self, hidden: &Tensor) -> Result<Tensor> {
        self.backend.project_to_vocab(hidden)
    }
339
340    /// Access the underlying backend (e.g., for backend-specific methods).
341    // TRAIT_OBJECT: caller needs dynamic dispatch for backend-specific methods
342    #[must_use]
343    pub fn backend(&self) -> &dyn MIBackend {
344        &*self.backend
345    }
346
347    /// Run a forward pass from text, returning both MI outputs and token
348    /// position mapping.
349    ///
350    /// Combines [`MITokenizer::encode_with_offsets`](crate::MITokenizer::encode_with_offsets)
351    /// + tensor creation + [`forward`](Self::forward) in a single call.
352    ///
353    /// The returned [`TextForwardResult`] carries the [`HookCache`]
354    /// alongside the [`EncodingWithOffsets`](crate::EncodingWithOffsets),
355    /// eliminating the need for separate encoding and manual offset
356    /// tracking.
357    ///
358    /// # Shapes
359    /// - input: text string
360    /// - returns: [`TextForwardResult`] with logits at `[1, seq, vocab_size]`
361    ///
362    /// # Errors
363    ///
364    /// Returns [`MIError::Config`] if no tokenizer is available.
365    /// Returns [`MIError::Tokenizer`] on encoding failure.
366    /// Propagates errors from the underlying backend.
367    pub fn forward_text(&self, text: &str, hooks: &HookSpec) -> Result<TextForwardResult> {
368        let tokenizer = self
369            .tokenizer()
370            .ok_or_else(|| MIError::Config("forward_text requires a tokenizer".into()))?;
371        let encoding = tokenizer.encode_with_offsets(text)?;
372        let input = Tensor::new(&encoding.ids[..], &self.device)?.unsqueeze(0)?;
373        let cache = self.forward(&input, hooks)?;
374        Ok(TextForwardResult { cache, encoding })
375    }
376}
377
378// ---------------------------------------------------------------------------
379// TextForwardResult
380// ---------------------------------------------------------------------------
381
/// Result of a text-based forward pass, bundling MI outputs with token
/// position mapping.
///
/// Returned by [`MIModel::forward_text`]. Provides access to the
/// [`HookCache`] (logits + captured activations) alongside the
/// [`EncodingWithOffsets`](crate::EncodingWithOffsets) (token strings,
/// IDs, and byte offset ranges for mapping between source text positions
/// and token indices).
#[derive(Debug)]
pub struct TextForwardResult {
    /// Hook cache containing output logits and any captured activations.
    cache: HookCache,
    /// Token encoding with character offset mapping.
    encoding: crate::util::positioning::EncodingWithOffsets,
}
397
impl TextForwardResult {
    /// Access the hook cache (logits + captured activations).
    #[must_use]
    pub const fn cache(&self) -> &HookCache {
        &self.cache
    }

    /// Consume the result and return the hook cache.
    #[must_use]
    pub fn into_cache(self) -> HookCache {
        self.cache
    }

    /// Access the token encoding with character offset mapping.
    #[must_use]
    pub const fn encoding(&self) -> &crate::util::positioning::EncodingWithOffsets {
        &self.encoding
    }

    /// The output tensor from the forward pass (typically logits).
    ///
    /// Shortcut for `self.cache().output()`.
    ///
    /// # Shapes
    /// - returns: `[1, seq, vocab_size]`
    #[must_use]
    pub const fn output(&self) -> &Tensor {
        self.cache.output()
    }

    /// Retrieve a captured tensor by hook point, returning an error if
    /// not found.
    ///
    /// Shortcut for `self.cache().require(hook)`.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Hook`] if the hook point was not captured.
    pub fn require(&self, hook: &crate::hooks::HookPoint) -> Result<&Tensor> {
        self.cache.require(hook)
    }

    /// Retrieve a captured tensor by hook point.
    ///
    /// Shortcut for `self.cache().get(hook)`.
    #[must_use]
    pub fn get(&self, hook: &crate::hooks::HookPoint) -> Option<&Tensor> {
        self.cache.get(hook)
    }

    /// The raw BPE token strings (with space-prefix markers like `Ġ`).
    ///
    /// Shortcut for `self.encoding().tokens`.
    #[must_use]
    pub fn tokens(&self) -> &[String] {
        &self.encoding.tokens
    }

    /// Number of tokens in the encoded sequence.
    ///
    /// Shortcut for `self.encoding().len()`.
    #[must_use]
    pub const fn seq_len(&self) -> usize {
        self.encoding.len()
    }
}
462
463// ---------------------------------------------------------------------------
464// Sampling helpers
465// ---------------------------------------------------------------------------
466
467/// Sample a token from logits using the given temperature.
468///
469/// When `temperature <= 0.0`, performs greedy (argmax) decoding.
470///
471/// # Shapes
472/// - `logits`: `[vocab_size]` -- logit scores for each vocabulary token
473///
474/// # Errors
475///
476/// Returns [`MIError::Model`] if the logits tensor is empty or
477/// cannot be converted to `f32`.
478pub fn sample_token(logits: &Tensor, temperature: f32) -> Result<u32> {
479    if temperature <= 0.0 {
480        argmax(logits)
481    } else {
482        sample_with_temperature(logits, temperature)
483    }
484}
485
486/// Greedy (argmax) sampling.
487fn argmax(logits: &Tensor) -> Result<u32> {
488    let logits_f32 = logits.to_dtype(DType::F32)?;
489    let logits_vec: Vec<f32> = logits_f32.flatten_all()?.to_vec1()?;
490
491    let (max_idx, _) = logits_vec
492        .iter()
493        .enumerate()
494        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
495        .ok_or_else(|| MIError::Model(candle_core::Error::Msg("empty logits".into())))?;
496
497    // CAST: usize → u32, vocab size fits in u32 (max ~250K tokens)
498    #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
499    Ok(max_idx as u32)
500}
501
502/// Temperature-scaled softmax sampling.
503fn sample_with_temperature(logits: &Tensor, temperature: f32) -> Result<u32> {
504    use rand::Rng;
505
506    let logits_f32 = logits.to_dtype(DType::F32)?;
507    let logits_vec: Vec<f32> = logits_f32.flatten_all()?.to_vec1()?;
508
509    if logits_vec.is_empty() {
510        return Err(MIError::Model(candle_core::Error::Msg(
511            "empty logits".into(),
512        )));
513    }
514
515    // Scale by temperature.
516    let scaled: Vec<f32> = logits_vec.iter().map(|x| x / temperature).collect();
517
518    // Numerically stable softmax.
519    let max_val = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
520    let exp_vals: Vec<f32> = scaled.iter().map(|x| (x - max_val).exp()).collect();
521    let sum: f32 = exp_vals.iter().sum();
522    let probs: Vec<f32> = exp_vals.iter().map(|x| x / sum).collect();
523
524    // Sample from the categorical distribution.
525    let mut rng = rand::thread_rng();
526    let r: f32 = rng.r#gen();
527    let mut cumsum = 0.0;
528    for (idx, &p) in probs.iter().enumerate() {
529        cumsum += p;
530        if r < cumsum {
531            // CAST: usize → u32, vocab index fits in u32
532            #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
533            return Ok(idx as u32);
534        }
535    }
536
537    // Fallback to last token (floating-point rounding edge case).
538    // CAST: usize → u32, vocab index fits in u32
539    #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
540    Ok((probs.len() - 1) as u32)
541}
542
543/// Extract the probability of a specific token from a logit tensor.
544///
545/// Applies softmax to the last sequence position and returns the probability
546/// at `token_id`. Useful for measuring steering effectiveness.
547///
548/// # Shapes
549/// - `logits`: `[1, seq_len, vocab]` or `[seq_len, vocab]` or `[vocab]`
550///
551/// # Errors
552///
553/// Returns [`MIError::Model`] on shape mismatch or tensor operation failure.
554pub fn extract_token_prob(logits: &Tensor, token_id: u32) -> Result<f32> {
555    use candle_core::IndexOp;
556
557    let logits_f32 = logits.to_dtype(DType::F32)?;
558
559    // Get last-position logits as a 1-D vector.
560    let last_logits = match logits_f32.dims().len() {
561        1 => logits_f32,
562        2 => {
563            let seq_len = logits_f32.dim(0)?;
564            logits_f32.i(seq_len - 1)?
565        }
566        3 => {
567            let seq_len = logits_f32.dim(1)?;
568            logits_f32.i((0, seq_len - 1))?
569        }
570        n => {
571            return Err(MIError::Model(candle_core::Error::Msg(format!(
572                "extract_token_prob: expected 1-3 dims, got {n}"
573            ))));
574        }
575    };
576
577    let probs = candle_nn::ops::softmax_last_dim(&last_logits)?;
578    // CAST: u32 → usize, token ID used as tensor index
579    #[allow(clippy::as_conversions)]
580    let prob = probs.i(token_id as usize)?.to_scalar::<f32>()?;
581    Ok(prob)
582}
583
584// ---------------------------------------------------------------------------
585// GenerationResult
586// ---------------------------------------------------------------------------
587
/// Output of a text generation run with token-level details.
///
/// NOTE(review): presumably `total_tokens == prompt_tokens.len() +
/// generated_tokens.len()` — confirm against the producing generation loop.
#[derive(Debug, Clone)]
pub struct GenerationResult {
    /// Original prompt text.
    pub prompt: String,
    /// Full output (prompt + generated).
    pub full_text: String,
    /// Only the generated portion.
    pub generated_text: String,
    /// Token IDs from the prompt.
    pub prompt_tokens: Vec<u32>,
    /// Token IDs that were generated.
    pub generated_tokens: Vec<u32>,
    /// Total token count (prompt + generated).
    pub total_tokens: usize,
}
604
605// ---------------------------------------------------------------------------
606// Weight loading helpers (used by from_pretrained)
607// ---------------------------------------------------------------------------
608
/// Index structure for sharded safetensors models.
///
/// Mirrors only the `weight_map` key of `model.safetensors.index.json`;
/// any other keys in the index are ignored during deserialization.
#[cfg(any(feature = "transformer", feature = "rwkv"))]
#[derive(serde::Deserialize)]
struct SafetensorsIndex {
    /// Maps weight name → shard filename.
    weight_map: std::collections::HashMap<String, String>,
}
616
/// Extract tensor names from a downloaded file map for auto-config inference.
///
/// Tries `model.safetensors.index.json` first (sharded models), falls back
/// to reading the header of `model.safetensors` (single-file models).
#[cfg(feature = "transformer")]
fn extract_tensor_names(
    files: &std::collections::HashMap<String, std::path::PathBuf>,
) -> Result<Vec<String>> {
    // Prefer the shard index; fall back to the single-file header.
    match (
        files.get("model.safetensors.index.json"),
        files.get("model.safetensors"),
    ) {
        (Some(index_path), _) => crate::config::tensor_names_from_index(index_path),
        (None, Some(st_path)) => crate::config::tensor_names_from_safetensors(st_path),
        (None, None) => Err(MIError::Config(
            "no safetensors files found for tensor name extraction".into(),
        )),
    }
}
635
/// Resolve safetensors file paths from a downloaded file map.
///
/// Tries `model.safetensors.index.json` first (sharded), falls back to
/// single `model.safetensors`.
#[cfg(any(feature = "transformer", feature = "rwkv"))]
fn resolve_safetensors_paths(
    files: &std::collections::HashMap<String, std::path::PathBuf>,
) -> Result<Vec<std::path::PathBuf>> {
    // Sharded model: the index maps weight names to shard filenames.
    if let Some(index_path) = files.get("model.safetensors.index.json") {
        let index_str = std::fs::read_to_string(index_path)
            .map_err(|e| MIError::Model(candle_core::Error::Msg(format!("read index: {e}"))))?;
        let index: SafetensorsIndex = serde_json::from_str(&index_str)
            .map_err(|e| MIError::Config(format!("parse index: {e}")))?;

        // BTreeSet yields the shard names sorted and deduplicated in one pass.
        let shard_names: std::collections::BTreeSet<&String> =
            index.weight_map.values().collect();

        // Fallible collect short-circuits on the first missing shard.
        return shard_names
            .into_iter()
            .map(|shard_name| {
                files.get(shard_name.as_str()).cloned().ok_or_else(|| {
                    MIError::Model(candle_core::Error::Msg(format!(
                        "shard {shard_name} not found in downloaded files"
                    )))
                })
            })
            .collect();
    }

    // Single-file model.
    files.get("model.safetensors").map_or_else(
        || {
            Err(MIError::Model(candle_core::Error::Msg(
                "model.safetensors not found in downloaded files".into(),
            )))
        },
        |path| Ok(vec![path.clone()]),
    )
}
679
/// Create a `VarBuilder` from safetensors file paths.
///
/// Uses buffered (safe) loading by default. With the `mmap` feature,
/// uses memory-mapped loading for reduced memory overhead on large models.
#[cfg(any(feature = "transformer", feature = "rwkv"))]
fn create_var_builder(
    paths: &[std::path::PathBuf],
    dtype: DType,
    device: &Device,
) -> Result<candle_nn::VarBuilder<'static>> {
    // Compile-time dispatch: exactly one of the two branches exists in
    // any given build, selected by the `mmap` feature.
    #[cfg(feature = "mmap")]
    {
        mmap_var_builder(paths, dtype, device)
    }
    #[cfg(not(feature = "mmap"))]
    {
        buffered_var_builder(paths, dtype, device)
    }
}
699
/// Load weights via buffered (safe) reading — reads all data into RAM.
///
/// Only supports single-file models. For sharded models (7B+), enable
/// the `mmap` feature.
#[cfg(all(any(feature = "transformer", feature = "rwkv"), not(feature = "mmap")))]
fn buffered_var_builder(
    paths: &[std::path::PathBuf],
    dtype: DType,
    device: &Device,
) -> Result<candle_nn::VarBuilder<'static>> {
    match paths {
        // Exactly one file: slurp it into RAM and build from the buffer.
        [path] => {
            let data = std::fs::read(path).map_err(|e| {
                MIError::Model(candle_core::Error::Msg(format!(
                    "read {}: {e}",
                    path.display()
                )))
            })?;
            Ok(candle_nn::VarBuilder::from_buffered_safetensors(
                data, dtype, device,
            )?)
        }
        [] => Err(MIError::Model(candle_core::Error::Msg(
            "no safetensors files".into(),
        ))),
        // Two or more shards: buffered loading is refused; point the user
        // at the `mmap` feature instead.
        sharded => Err(MIError::Config(format!(
            "this model is sharded across {} files and requires the `mmap` feature.\n  \
             Library:  candle-mi = {{ features = [\"mmap\"] }}\n  \
             Example:  cargo run --features mmap --example <name>",
            sharded.len()
        ))),
    }
}
730
/// Load weights via memory-mapped files — minimal RAM overhead for large models.
///
/// Handles both single-file and sharded checkpoints (all `paths` are mapped).
///
/// # Safety
///
/// The safetensors files must not be modified while the model is loaded.
/// This is the standard invariant for memory-mapped files.
#[cfg(all(any(feature = "transformer", feature = "rwkv"), feature = "mmap"))]
#[allow(unsafe_code)]
fn mmap_var_builder(
    paths: &[std::path::PathBuf],
    dtype: DType,
    device: &Device,
) -> Result<candle_nn::VarBuilder<'static>> {
    // SAFETY: safetensors files must not be modified while loaded.
    let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(paths, dtype, device)? };
    Ok(vb)
}
747}