Skip to main content

hf_fetch_model/
inspect.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Safetensors header inspection (local and remote).
4//!
5//! Reads tensor metadata (names, shapes, dtypes, byte offsets) from
6//! `.safetensors` files without downloading full weight data. Supports
7//! cache-first resolution with HTTP Range request fallback.
8//!
9//! The primary types are [`TensorInfo`] (per-tensor metadata),
10//! [`SafetensorsHeaderInfo`] (parsed header), and [`ShardedIndex`]
11//! (shard-to-tensor mapping for sharded models).
12
13use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15
16use serde::Serialize;
17
18use crate::cache;
19use crate::chunked;
20use crate::error::FetchError;
21
22// -----------------------------------------------------------------------
23// Types
24// -----------------------------------------------------------------------
25
26/// Metadata for a single tensor from a `.safetensors` header.
27///
28/// This is hf-fetch-model's own type — lightweight, no quantization logic.
29/// Consumers (e.g., anamnesis) map this into their own richer types.
30#[derive(Debug, Clone, Serialize)]
31pub struct TensorInfo {
32    /// Tensor name (e.g., `"model.layers.0.self_attn.q_proj.weight"`).
33    pub name: String,
34    /// Element dtype string as it appears in the header (e.g., `"F8_E4M3"`, `"BF16"`).
35    pub dtype: String,
36    /// Tensor shape (e.g., `[7168, 7168]`).
37    pub shape: Vec<usize>,
38    /// Byte offset range `[start, end)` within the data section of the file.
39    pub data_offsets: (u64, u64),
40}
41
42impl TensorInfo {
43    /// Total number of elements (product of shape dimensions).
44    ///
45    /// Returns `1` for a scalar (empty shape).
46    #[must_use]
47    pub fn num_elements(&self) -> u64 {
48        self.shape.iter().fold(1u64, |acc, &d| {
49            // CAST: usize → u64, dimension values fit in u64
50            #[allow(clippy::as_conversions)]
51            let dim = d as u64;
52            acc.saturating_mul(dim)
53        })
54    }
55
56    /// Byte length of the tensor data (`end - start`).
57    #[must_use]
58    pub const fn byte_len(&self) -> u64 {
59        self.data_offsets.1.saturating_sub(self.data_offsets.0)
60    }
61
62    /// Bytes per element for the tensor's dtype, if recognized.
63    ///
64    /// Returns `None` for unknown dtype strings. Recognized dtypes:
65    ///
66    /// | Dtype string | Bytes | Notes |
67    /// |-------------|-------|-------|
68    /// | `"BOOL"` | 1 | |
69    /// | `"U8"`, `"I8"` | 1 | |
70    /// | `"F8_E4M3"`, `"F8_E5M2"` | 1 | FP8 variants |
71    /// | `"U16"`, `"I16"`, `"F16"`, `"BF16"` | 2 | |
72    /// | `"U32"`, `"I32"`, `"F32"` | 4 | |
73    /// | `"U64"`, `"I64"`, `"F64"` | 8 | |
74    #[must_use]
75    pub fn dtype_bytes(&self) -> Option<usize> {
76        // BORROW: explicit .as_str() instead of Deref coercion
77        match self.dtype.as_str() {
78            "BOOL" | "U8" | "I8" | "F8_E4M3" | "F8_E5M2" => Some(1),
79            "U16" | "I16" | "F16" | "BF16" => Some(2),
80            "U32" | "I32" | "F32" => Some(4),
81            "U64" | "I64" | "F64" => Some(8),
82            _ => None,
83        }
84    }
85}
86
87/// Parsed `.safetensors` header metadata.
88#[derive(Debug, Clone, Serialize)]
89pub struct SafetensorsHeaderInfo {
90    /// All tensors in the header, in the order they appear in the JSON.
91    pub tensors: Vec<TensorInfo>,
92    /// Raw `__metadata__` entries, if present.
93    ///
94    /// For quantized models, this typically contains entries like
95    /// `quant_method`, `bits`, `group_size` that consumers like anamnesis
96    /// use to distinguish GPTQ from AWQ without downloading weights.
97    pub metadata: Option<HashMap<String, String>>,
98    /// Size of the JSON header in bytes.
99    pub header_size: u64,
100    /// Total file size in bytes (header + data), if known.
101    ///
102    /// **Source:** for local files, from `std::fs::metadata().len()`. For HTTP
103    /// Range requests, extracted from the `Content-Range` response header of
104    /// the first request (`bytes 0-7/TOTAL` → `TOTAL`). This is free — no
105    /// extra request needed.
106    pub file_size: Option<u64>,
107}
108
109impl SafetensorsHeaderInfo {
110    /// Total parameter count across all tensors.
111    #[must_use]
112    pub fn total_params(&self) -> u64 {
113        self.tensors
114            .iter()
115            .map(TensorInfo::num_elements)
116            .fold(0u64, u64::saturating_add)
117    }
118
119    /// Returns tensors matching a dtype string (e.g., `"F8_E4M3"`).
120    #[must_use]
121    pub fn tensors_with_dtype(&self, dtype: &str) -> Vec<&TensorInfo> {
122        self.tensors
123            .iter()
124            // BORROW: explicit .as_str() instead of Deref coercion
125            .filter(|t| t.dtype.as_str() == dtype)
126            .collect()
127    }
128}
129
/// Where a safetensors header was obtained from.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InspectSource {
    /// The header came from a file in the local cache; no network was used.
    Cached,
    /// The header was downloaded with HTTP Range requests.
    Remote,
}
139
140/// Parsed `model.safetensors.index.json` for a sharded model.
141#[derive(Debug, Clone, Serialize)]
142pub struct ShardedIndex {
143    /// Mapping from tensor name to shard filename.
144    pub weight_map: HashMap<String, String>,
145    /// Ordered list of unique shard filenames.
146    pub shards: Vec<String>,
147    /// Raw metadata from the index, if present.
148    pub metadata: Option<HashMap<String, serde_json::Value>>,
149}
150
151/// `PEFT` adapter configuration parsed from `adapter_config.json`.
152///
153/// Contains the key fields that identify an adapter: the `PEFT` type,
154/// base model, `LoRA` rank and scaling parameters, and target modules.
155/// All fields are optional because adapter configs vary across `PEFT` methods.
156#[derive(Debug, Clone, Serialize)]
157pub struct AdapterConfig {
158    /// `PEFT` method type (e.g., `"LORA"`, `"ADALORA"`, `"IA3"`).
159    pub peft_type: Option<String>,
160    /// The base model this adapter was trained on.
161    pub base_model_name_or_path: Option<String>,
162    /// `LoRA` rank (the `r` parameter). Only meaningful for `LoRA`-family methods.
163    pub r: Option<u32>,
164    /// `LoRA` alpha scaling factor. Only meaningful for `LoRA`-family methods.
165    pub lora_alpha: Option<f64>,
166    /// List of model modules targeted by the adapter.
167    pub target_modules: Vec<String>,
168    /// Task type the adapter was trained for (e.g., `"CAUSAL_LM"`).
169    pub task_type: Option<String>,
170}
171
172// -----------------------------------------------------------------------
173// JSON parsing
174// -----------------------------------------------------------------------
175
176/// Raw tensor entry as it appears in the safetensors JSON header.
177#[derive(serde::Deserialize)]
178struct RawTensorEntry {
179    dtype: String,
180    shape: Vec<usize>,
181    data_offsets: (u64, u64),
182}
183
184/// Parsed tensor list and optional metadata from a safetensors header.
185type ParsedHeader = (Vec<TensorInfo>, Option<HashMap<String, String>>);
186
187/// Parses the safetensors JSON header bytes into tensor metadata.
188///
189/// Extracts the `__metadata__` key separately (if present).
190fn parse_header_json(json_bytes: &[u8], filename: &str) -> Result<ParsedHeader, FetchError> {
191    let raw: HashMap<String, serde_json::Value> =
192        serde_json::from_slice(json_bytes).map_err(|e| FetchError::SafetensorsHeader {
193            filename: filename.to_owned(),
194            reason: format!("failed to parse header JSON: {e}"),
195        })?;
196
197    let mut metadata: Option<HashMap<String, String>> = None;
198    let mut tensors = Vec::new();
199
200    for (key, value) in &raw {
201        if key == "__metadata__" {
202            if let Some(obj) = value.as_object() {
203                let mut meta_map = HashMap::new();
204                for (mk, mv) in obj {
205                    // BORROW: explicit .to_string() for Value → String (strips quotes from strings)
206                    let v_str = if let Some(s) = mv.as_str() {
207                        s.to_owned()
208                    } else {
209                        mv.to_string()
210                    };
211                    // BORROW: explicit .clone() for owned String
212                    meta_map.insert(mk.clone(), v_str);
213                }
214                metadata = Some(meta_map);
215            }
216            continue;
217        }
218
219        let entry: RawTensorEntry =
220            serde_json::from_value(value.clone()).map_err(|e| FetchError::SafetensorsHeader {
221                filename: filename.to_owned(),
222                reason: format!("failed to parse tensor \"{key}\": {e}"),
223            })?;
224
225        tensors.push(TensorInfo {
226            // BORROW: explicit .clone() for owned String
227            name: key.clone(),
228            dtype: entry.dtype,
229            shape: entry.shape,
230            data_offsets: entry.data_offsets,
231        });
232    }
233
234    // Sort by data offset start to preserve file order.
235    tensors.sort_by_key(|t| t.data_offsets.0);
236
237    Ok((tensors, metadata))
238}
239
240// -----------------------------------------------------------------------
241// Cache resolution
242// -----------------------------------------------------------------------
243
244/// Resolves a cached file path for a given repo, revision, and filename.
245///
246/// Returns `None` if the file is not in the local cache.
247fn resolve_cached_path(repo_id: &str, revision: &str, filename: &str) -> Option<PathBuf> {
248    let cache_dir = cache::hf_cache_dir().ok()?;
249    let repo_folder = chunked::repo_folder_name(repo_id);
250    // BORROW: explicit .as_str() instead of Deref coercion
251    let repo_dir = cache_dir.join(repo_folder.as_str());
252    let commit_hash = cache::read_ref(&repo_dir, revision)?;
253    let cached_path = repo_dir.join("snapshots").join(commit_hash).join(filename);
254    if cached_path.exists() {
255        Some(cached_path)
256    } else {
257        None
258    }
259}
260
261// -----------------------------------------------------------------------
262// Local file reading
263// -----------------------------------------------------------------------
264
265/// Inspects a single `.safetensors` file's header from a local file path.
266///
267/// Reads the first `8 + header_size` bytes from disk. Does not read tensor data.
268///
269/// # Errors
270///
271/// Returns [`FetchError::Io`] if the file cannot be read.
272/// Returns [`FetchError::SafetensorsHeader`] if the header is malformed.
273pub fn inspect_safetensors_local(path: &Path) -> Result<SafetensorsHeaderInfo, FetchError> {
274    use std::io::Read;
275
276    let file_size = std::fs::metadata(path)
277        .map_err(|e| FetchError::Io {
278            path: path.to_path_buf(),
279            source: e,
280        })?
281        .len();
282
283    // BORROW: explicit .to_string_lossy() for Path → str conversion
284    let filename = path.file_name().map_or_else(
285        || path.display().to_string(),
286        |n| n.to_string_lossy().to_string(),
287    );
288
289    let mut file = std::fs::File::open(path).map_err(|e| FetchError::Io {
290        path: path.to_path_buf(),
291        source: e,
292    })?;
293
294    // Read 8-byte header length prefix (little-endian u64).
295    let mut len_buf = [0u8; 8];
296    file.read_exact(&mut len_buf).map_err(|e| FetchError::Io {
297        path: path.to_path_buf(),
298        source: e,
299    })?;
300    let header_size = u64::from_le_bytes(len_buf);
301
302    // Sanity check: header cannot be larger than the file.
303    if header_size.saturating_add(8) > file_size {
304        return Err(FetchError::SafetensorsHeader {
305            filename,
306            reason: format!("header length {header_size} exceeds file size {file_size}"),
307        });
308    }
309
310    // Read the JSON header.
311    // CAST: u64 → usize, header size bounded by file size (checked above)
312    #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
313    let json_len = header_size as usize;
314    let mut json_buf = vec![0u8; json_len];
315    file.read_exact(&mut json_buf).map_err(|e| FetchError::Io {
316        path: path.to_path_buf(),
317        source: e,
318    })?;
319
320    // BORROW: explicit .as_str() instead of Deref coercion
321    let (tensors, metadata) = parse_header_json(&json_buf, filename.as_str())?;
322
323    Ok(SafetensorsHeaderInfo {
324        tensors,
325        metadata,
326        header_size,
327        file_size: Some(file_size),
328    })
329}
330
331// -----------------------------------------------------------------------
332// Remote fetching (HTTP Range requests)
333// -----------------------------------------------------------------------
334
335/// Fetches safetensors header bytes via two HTTP Range requests.
336///
337/// 1. `Range: bytes=0-7` → 8-byte header length (little-endian `u64`)
338/// 2. `Range: bytes=8-{8+length-1}` → JSON header
339///
340/// Returns `(json_bytes, total_file_size)`. The file size is extracted from
341/// the `Content-Range` header of the first request.
342async fn fetch_header_bytes(
343    client: &reqwest::Client,
344    url: &str,
345    filename: &str,
346) -> Result<(Vec<u8>, Option<u64>), FetchError> {
347    // Request 1: 8-byte length prefix.
348    let resp1 = client
349        .get(url)
350        .header(reqwest::header::RANGE, "bytes=0-7")
351        .send()
352        .await
353        .map_err(|e| {
354            FetchError::Http(format!("failed to fetch header length for {filename}: {e}"))
355        })?;
356
357    if !resp1.status().is_success() && resp1.status() != reqwest::StatusCode::PARTIAL_CONTENT {
358        return Err(FetchError::Http(format!(
359            "Range request for {filename} returned status {}",
360            resp1.status()
361        )));
362    }
363
364    // Extract total file size from Content-Range: bytes 0-7/{total}
365    let file_size = resp1
366        .headers()
367        .get(reqwest::header::CONTENT_RANGE)
368        .and_then(|v| v.to_str().ok())
369        .and_then(|s| s.split('/').next_back())
370        .and_then(|s| s.parse::<u64>().ok());
371
372    let len_bytes = resp1.bytes().await.map_err(|e| {
373        FetchError::Http(format!("failed to read header length for {filename}: {e}"))
374    })?;
375
376    if len_bytes.len() < 8 {
377        return Err(FetchError::SafetensorsHeader {
378            filename: filename.to_owned(),
379            reason: format!(
380                "expected 8 bytes for length prefix, got {}",
381                len_bytes.len()
382            ),
383        });
384    }
385
386    // INDEX: first 8 bytes guaranteed by length check above
387    #[allow(clippy::indexing_slicing)]
388    let header_size = u64::from_le_bytes([
389        len_bytes[0],
390        len_bytes[1],
391        len_bytes[2],
392        len_bytes[3],
393        len_bytes[4],
394        len_bytes[5],
395        len_bytes[6],
396        len_bytes[7],
397    ]);
398
399    // Request 2: JSON header.
400    let range_end = 8u64.saturating_add(header_size).saturating_sub(1);
401    let range_header = format!("bytes=8-{range_end}");
402    let resp2 = client
403        .get(url)
404        // BORROW: explicit .as_str() instead of Deref coercion
405        .header(reqwest::header::RANGE, range_header.as_str())
406        .send()
407        .await
408        .map_err(|e| {
409            FetchError::Http(format!("failed to fetch header JSON for {filename}: {e}"))
410        })?;
411
412    if !resp2.status().is_success() && resp2.status() != reqwest::StatusCode::PARTIAL_CONTENT {
413        return Err(FetchError::Http(format!(
414            "Range request for {filename} header JSON returned status {}",
415            resp2.status()
416        )));
417    }
418
419    let json_bytes = resp2
420        .bytes()
421        .await
422        .map_err(|e| FetchError::Http(format!("failed to read header JSON for {filename}: {e}")))?;
423
424    Ok((json_bytes.to_vec(), file_size))
425}
426
427// -----------------------------------------------------------------------
428// Public API: single-file inspection
429// -----------------------------------------------------------------------
430
431/// Inspects a single `.safetensors` file's header (cache-first).
432///
433/// Checks the local HF cache first. If the file is cached, reads the header
434/// from disk with zero network requests. Otherwise, falls back to two HTTP
435/// Range requests (8-byte length prefix + JSON header). Does not download
436/// tensor data in either case.
437///
438/// # Errors
439///
440/// Returns [`FetchError::Http`] if the Range requests fail.
441/// Returns [`FetchError::SafetensorsHeader`] if the header is malformed.
442pub async fn inspect_safetensors(
443    repo_id: &str,
444    filename: &str,
445    token: Option<&str>,
446    revision: Option<&str>,
447) -> Result<(SafetensorsHeaderInfo, InspectSource), FetchError> {
448    let rev = revision.unwrap_or("main");
449
450    // Try local cache first.
451    if let Some(cached_path) = resolve_cached_path(repo_id, rev, filename) {
452        let info = inspect_safetensors_local(&cached_path)?;
453        return Ok((info, InspectSource::Cached));
454    }
455
456    // Fall back to HTTP Range requests.
457    let client = chunked::build_client(token)?;
458    let url = chunked::build_download_url(repo_id, rev, filename);
459
460    // BORROW: explicit .as_str() instead of Deref coercion
461    let (json_bytes, file_size) = fetch_header_bytes(&client, url.as_str(), filename).await?;
462
463    // CAST: usize → u64, JSON buffer length is always small
464    #[allow(clippy::as_conversions)]
465    let header_size = json_bytes.len() as u64;
466
467    let (tensors, metadata) = parse_header_json(&json_bytes, filename)?;
468
469    Ok((
470        SafetensorsHeaderInfo {
471            tensors,
472            metadata,
473            header_size,
474            file_size,
475        },
476        InspectSource::Remote,
477    ))
478}
479
480/// Inspects a single `.safetensors` file from cache only.
481///
482/// Resolves the file in the local HF cache using the given `repo_id`,
483/// `revision`, and `filename`. Returns an error if the file is not cached.
484///
485/// # Errors
486///
487/// Returns [`FetchError::SafetensorsHeader`] if the file is not in the cache.
488/// Returns [`FetchError::Io`] if the cached file cannot be read.
489/// Returns [`FetchError::SafetensorsHeader`] if the header is malformed.
490pub fn inspect_safetensors_cached(
491    repo_id: &str,
492    filename: &str,
493    revision: Option<&str>,
494) -> Result<SafetensorsHeaderInfo, FetchError> {
495    let rev = revision.unwrap_or("main");
496
497    let cached_path = resolve_cached_path(repo_id, rev, filename).ok_or_else(|| {
498        FetchError::SafetensorsHeader {
499            filename: filename.to_owned(),
500            reason: format!("file not found in local cache for {repo_id} ({rev})"),
501        }
502    })?;
503
504    inspect_safetensors_local(&cached_path)
505}
506
507// -----------------------------------------------------------------------
508// Public API: multi-file inspection
509// -----------------------------------------------------------------------
510
511/// Inspects all `.safetensors` files in a repository (cache-first per file).
512///
513/// Fetches the file listing via `list_repo_files_with_metadata()`, then
514/// inspects each `.safetensors` file's header via [`inspect_safetensors()`].
515/// For each file, checks the local cache first and only makes HTTP Range
516/// requests on cache miss. Returns full per-shard headers in filename order.
517///
518/// For a lightweight summary of sharded models (tensor counts per shard
519/// without fetching individual headers), use [`fetch_shard_index()`] instead.
520///
521/// # Errors
522///
523/// Returns [`FetchError::Http`] if the metadata or Range requests fail.
524pub async fn inspect_repo_safetensors(
525    repo_id: &str,
526    token: Option<&str>,
527    revision: Option<&str>,
528) -> Result<Vec<(String, SafetensorsHeaderInfo, InspectSource)>, FetchError> {
529    let files = crate::repo::list_repo_files_with_metadata(repo_id, token, revision).await?;
530
531    let safetensors_files: Vec<String> = files
532        .into_iter()
533        .filter(|f| f.filename.ends_with(".safetensors"))
534        .map(|f| f.filename)
535        .collect();
536
537    if safetensors_files.is_empty() {
538        return Ok(Vec::new());
539    }
540
541    let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(4));
542    let mut handles = Vec::new();
543
544    for filename in safetensors_files {
545        // BORROW: explicit .clone()/.to_owned() to move into async task
546        let sem = semaphore.clone();
547        let repo = repo_id.to_owned();
548        let tok = token.map(str::to_owned);
549        let rev = revision.map(str::to_owned);
550
551        handles.push(tokio::spawn(async move {
552            let _permit = sem
553                .acquire()
554                .await
555                .map_err(|e| FetchError::Http(format!("semaphore error: {e}")))?;
556            // BORROW: explicit .as_deref() for Option<String> → Option<&str>
557            let (info, source) =
558                inspect_safetensors(&repo, &filename, tok.as_deref(), rev.as_deref()).await?;
559            Ok::<_, FetchError>((filename, info, source))
560        }));
561    }
562
563    let mut results = Vec::new();
564    for handle in handles {
565        let result = handle
566            .await
567            .map_err(|e| FetchError::Http(format!("task join error: {e}")))?;
568        results.push(result?);
569    }
570
571    results.sort_by(|a, b| a.0.cmp(&b.0));
572
573    Ok(results)
574}
575
576/// Inspects all `.safetensors` files in a cached repository (no network).
577///
578/// Walks the snapshot directory and inspects each `.safetensors` file's
579/// header from local disk. Returns results in filename order.
580///
581/// # Errors
582///
583/// Returns [`FetchError::Io`] if the cache directory cannot be read.
584/// Returns [`FetchError::SafetensorsHeader`] if any header is malformed.
585pub fn inspect_repo_safetensors_cached(
586    repo_id: &str,
587    revision: Option<&str>,
588) -> Result<Vec<(String, SafetensorsHeaderInfo)>, FetchError> {
589    let rev = revision.unwrap_or("main");
590    let cache_dir = cache::hf_cache_dir()?;
591    let repo_folder = chunked::repo_folder_name(repo_id);
592    // BORROW: explicit .as_str() instead of Deref coercion
593    let repo_dir = cache_dir.join(repo_folder.as_str());
594
595    let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
596        return Ok(Vec::new());
597    };
598
599    let snapshot_dir = repo_dir.join("snapshots").join(commit_hash);
600    if !snapshot_dir.exists() {
601        return Ok(Vec::new());
602    }
603
604    let mut results = Vec::new();
605    collect_safetensors_recursive(&snapshot_dir, "", &mut results)?;
606    results.sort_by(|a, b| a.0.cmp(&b.0));
607
608    Ok(results)
609}
610
611/// Recursively finds and inspects `.safetensors` files in a snapshot directory.
612fn collect_safetensors_recursive(
613    dir: &Path,
614    prefix: &str,
615    results: &mut Vec<(String, SafetensorsHeaderInfo)>,
616) -> Result<(), FetchError> {
617    let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
618        path: dir.to_path_buf(),
619        source: e,
620    })?;
621
622    for entry in entries {
623        let Ok(entry) = entry else { continue };
624        let path = entry.path();
625        // BORROW: explicit .to_string_lossy() for OsString → str conversion
626        let name = entry.file_name().to_string_lossy().to_string();
627
628        if path.is_dir() {
629            let child_prefix = if prefix.is_empty() {
630                name
631            } else {
632                format!("{prefix}/{name}")
633            };
634            collect_safetensors_recursive(&path, &child_prefix, results)?;
635        } else if name.ends_with(".safetensors") {
636            let filename = if prefix.is_empty() {
637                name
638            } else {
639                format!("{prefix}/{name}")
640            };
641            let info = inspect_safetensors_local(&path)?;
642            results.push((filename, info));
643        }
644    }
645
646    Ok(())
647}
648
649// -----------------------------------------------------------------------
650// Shard index
651// -----------------------------------------------------------------------
652
653/// Raw JSON structure of `model.safetensors.index.json`.
654#[derive(serde::Deserialize)]
655struct RawShardIndex {
656    weight_map: HashMap<String, String>,
657    #[serde(default)]
658    metadata: Option<HashMap<String, serde_json::Value>>,
659}
660
661/// Fetches and parses the shard index for a sharded `.safetensors` model (cache-first).
662///
663/// Returns `Ok(None)` if the repo has no `model.safetensors.index.json` (i.e.,
664/// the model is not sharded or uses a single `.safetensors` file).
665///
666/// # Errors
667///
668/// Returns [`FetchError::Http`] if the index fetch fails.
669/// Returns [`FetchError::SafetensorsHeader`] if the index JSON is malformed.
670pub async fn fetch_shard_index(
671    repo_id: &str,
672    token: Option<&str>,
673    revision: Option<&str>,
674) -> Result<Option<ShardedIndex>, FetchError> {
675    let rev = revision.unwrap_or("main");
676    let index_filename = "model.safetensors.index.json";
677
678    // Try local cache first.
679    if let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) {
680        let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
681            path: cached_path,
682            source: e,
683        })?;
684        let index = parse_shard_index_json(&content, repo_id)?;
685        return Ok(Some(index));
686    }
687
688    // Fall back to HTTP.
689    let client = chunked::build_client(token)?;
690    let url = chunked::build_download_url(repo_id, rev, index_filename);
691
692    // BORROW: explicit .as_str() instead of Deref coercion
693    let response =
694        client.get(url.as_str()).send().await.map_err(|e| {
695            FetchError::Http(format!("failed to fetch shard index for {repo_id}: {e}"))
696        })?;
697
698    if response.status() == reqwest::StatusCode::NOT_FOUND {
699        return Ok(None);
700    }
701
702    if !response.status().is_success() {
703        return Err(FetchError::Http(format!(
704            "shard index request for {repo_id} returned status {}",
705            response.status()
706        )));
707    }
708
709    let content = response
710        .text()
711        .await
712        .map_err(|e| FetchError::Http(format!("failed to read shard index for {repo_id}: {e}")))?;
713
714    let index = parse_shard_index_json(&content, repo_id)?;
715    Ok(Some(index))
716}
717
718/// Fetches the shard index from cache only (no network).
719///
720/// Returns `Ok(None)` if the index file is not cached.
721///
722/// # Errors
723///
724/// Returns [`FetchError::Io`] if the cached file cannot be read.
725/// Returns [`FetchError::SafetensorsHeader`] if the index JSON is malformed.
726pub fn fetch_shard_index_cached(
727    repo_id: &str,
728    revision: Option<&str>,
729) -> Result<Option<ShardedIndex>, FetchError> {
730    let rev = revision.unwrap_or("main");
731    let index_filename = "model.safetensors.index.json";
732
733    let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) else {
734        return Ok(None);
735    };
736
737    let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
738        path: cached_path,
739        source: e,
740    })?;
741
742    let index = parse_shard_index_json(&content, repo_id)?;
743    Ok(Some(index))
744}
745
746/// Parses shard index JSON into a `ShardedIndex`.
747fn parse_shard_index_json(content: &str, repo_id: &str) -> Result<ShardedIndex, FetchError> {
748    let raw: RawShardIndex =
749        serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
750            filename: "model.safetensors.index.json".to_owned(),
751            reason: format!("failed to parse shard index for {repo_id}: {e}"),
752        })?;
753
754    // Collect unique shard filenames in sorted order.
755    let mut shard_set: Vec<String> = raw.weight_map.values().cloned().collect();
756    shard_set.sort();
757    shard_set.dedup();
758
759    Ok(ShardedIndex {
760        weight_map: raw.weight_map,
761        shards: shard_set,
762        metadata: raw.metadata,
763    })
764}
765
766// -----------------------------------------------------------------------
767// Param formatting helper
768// -----------------------------------------------------------------------
769
/// Formats a parameter count with a compact suffix (e.g., `927.0M`, `1.02B`).
///
/// Counts below 1000 are printed as plain integers. Larger counts use one
/// decimal with `K` or `M`, or two decimals with `B`. When rounding to one
/// decimal would print `1000.0` of a unit (e.g. `999_999` → `1000.0K`), the
/// value is promoted to the next unit instead (`1.0M`), so the mantissa
/// shown is always below 1000 for `K` and `M`.
#[must_use]
pub fn format_params(count: u64) -> String {
    // CAST: u64 → f64, precision loss acceptable; value is a display-only scalar
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let val = count as f64;

    if count < 1_000 {
        return count.to_string();
    }

    if count < 1_000_000 {
        let scaled = format!("{:.1}", val / 1_000.0);
        // Rounded up to a full thousand: fall through to the next unit.
        if scaled != "1000.0" {
            return format!("{scaled}K");
        }
    }

    if count < 1_000_000_000 {
        let scaled = format!("{:.1}", val / 1_000_000.0);
        // Same promotion rule at the M → B boundary.
        if scaled != "1000.0" {
            return format!("{scaled}M");
        }
    }

    format!("{:.2}B", val / 1_000_000_000.0)
}
787
788// -----------------------------------------------------------------------
789// Adapter config
790// -----------------------------------------------------------------------
791
792/// Raw JSON structure of `adapter_config.json`.
793#[derive(serde::Deserialize)]
794struct RawAdapterConfig {
795    #[serde(default)]
796    peft_type: Option<String>,
797    #[serde(default)]
798    base_model_name_or_path: Option<String>,
799    #[serde(default)]
800    r: Option<u32>,
801    #[serde(default)]
802    lora_alpha: Option<f64>,
803    #[serde(default)]
804    target_modules: Option<AdapterTargetModules>,
805    #[serde(default)]
806    task_type: Option<String>,
807}
808
809/// `target_modules` in adapter configs can be a list of strings or a single string.
810#[derive(serde::Deserialize)]
811#[serde(untagged)]
812enum AdapterTargetModules {
813    /// A list of module name strings.
814    List(Vec<String>),
815    /// A single module name string.
816    Single(String),
817}
818
819/// Fetches and parses `adapter_config.json` for a `PEFT` adapter repository (cache-first).
820///
821/// Returns `Ok(None)` if the file does not exist (HTTP 404), meaning the
822/// repository is not a `PEFT` adapter.
823///
824/// # Errors
825///
826/// Returns [`FetchError::Http`] if the request fails (other than 404).
827/// Returns [`FetchError::SafetensorsHeader`] if the JSON is malformed.
828pub async fn fetch_adapter_config(
829    repo_id: &str,
830    token: Option<&str>,
831    revision: Option<&str>,
832) -> Result<Option<AdapterConfig>, FetchError> {
833    let rev = revision.unwrap_or("main");
834    let config_filename = "adapter_config.json";
835
836    // Try local cache first.
837    if let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) {
838        let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
839            path: cached_path,
840            source: e,
841        })?;
842        let config = parse_adapter_config_json(&content, repo_id)?;
843        return Ok(Some(config));
844    }
845
846    // Fall back to HTTP.
847    let client = chunked::build_client(token)?;
848    let url = chunked::build_download_url(repo_id, rev, config_filename);
849
850    // BORROW: explicit .as_str() instead of Deref coercion
851    let response = client.get(url.as_str()).send().await.map_err(|e| {
852        FetchError::Http(format!("failed to fetch adapter config for {repo_id}: {e}"))
853    })?;
854
855    if response.status() == reqwest::StatusCode::NOT_FOUND {
856        return Ok(None);
857    }
858
859    if !response.status().is_success() {
860        return Err(FetchError::Http(format!(
861            "adapter config request for {repo_id} returned status {}",
862            response.status()
863        )));
864    }
865
866    let content = response.text().await.map_err(|e| {
867        FetchError::Http(format!("failed to read adapter config for {repo_id}: {e}"))
868    })?;
869
870    let config = parse_adapter_config_json(&content, repo_id)?;
871    Ok(Some(config))
872}
873
874/// Fetches the adapter config from cache only (no network).
875///
876/// Returns `Ok(None)` if the file is not cached.
877///
878/// # Errors
879///
880/// Returns [`FetchError::Io`] if the cached file cannot be read.
881/// Returns [`FetchError::SafetensorsHeader`] if the JSON is malformed.
882pub fn fetch_adapter_config_cached(
883    repo_id: &str,
884    revision: Option<&str>,
885) -> Result<Option<AdapterConfig>, FetchError> {
886    let rev = revision.unwrap_or("main");
887    let config_filename = "adapter_config.json";
888
889    let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) else {
890        return Ok(None);
891    };
892
893    let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
894        path: cached_path,
895        source: e,
896    })?;
897
898    let config = parse_adapter_config_json(&content, repo_id)?;
899    Ok(Some(config))
900}
901
902/// Parses adapter config JSON into an [`AdapterConfig`].
903fn parse_adapter_config_json(content: &str, repo_id: &str) -> Result<AdapterConfig, FetchError> {
904    let raw: RawAdapterConfig =
905        serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
906            filename: "adapter_config.json".to_owned(),
907            reason: format!("failed to parse adapter config for {repo_id}: {e}"),
908        })?;
909
910    let target_modules = match raw.target_modules {
911        Some(AdapterTargetModules::List(v)) => v,
912        Some(AdapterTargetModules::Single(s)) => vec![s],
913        None => Vec::new(),
914    };
915
916    Ok(AdapterConfig {
917        peft_type: raw.peft_type,
918        base_model_name_or_path: raw.base_model_name_or_path,
919        r: raw.r,
920        lora_alpha: raw.lora_alpha,
921        target_modules,
922        task_type: raw.task_type,
923    })
924}