Skip to main content

hf_fetch_model/
inspect.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Safetensors header inspection (local and remote).
4//!
5//! Reads tensor metadata (names, shapes, dtypes, byte offsets) from
6//! `.safetensors` files without downloading full weight data. Supports
7//! cache-first resolution with HTTP Range request fallback.
8//!
9//! The primary types are [`TensorInfo`] (per-tensor metadata),
10//! [`SafetensorsHeaderInfo`] (parsed header), and [`ShardedIndex`]
11//! (shard-to-tensor mapping for sharded models).
12
13use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15
16use serde::Serialize;
17
18use crate::cache;
19use crate::chunked;
20use crate::error::FetchError;
21
22// -----------------------------------------------------------------------
23// Types
24// -----------------------------------------------------------------------
25
/// Metadata for a single tensor from a `.safetensors` header.
///
/// This is hf-fetch-model's own type — lightweight, no quantization logic.
/// Consumers (e.g., anamnesis) map this into their own richer types.
///
/// The `dtype` string is stored exactly as found in the header; it is not
/// validated here, so [`TensorInfo::dtype_bytes`] may return `None` for it.
#[derive(Debug, Clone, Serialize)]
pub struct TensorInfo {
    /// Tensor name (e.g., `"model.layers.0.self_attn.q_proj.weight"`).
    ///
    /// This is the JSON key of the tensor's entry in the header.
    pub name: String,
    /// Element dtype string as it appears in the header (e.g., `"F8_E4M3"`, `"BF16"`).
    pub dtype: String,
    /// Tensor shape (e.g., `[7168, 7168]`). An empty shape denotes a scalar.
    pub shape: Vec<usize>,
    /// Byte offset range `[start, end)` within the data section of the file.
    ///
    /// Offsets are relative to the data section, not to the start of the file.
    pub data_offsets: (u64, u64),
}
41
42impl TensorInfo {
43    /// Total number of elements (product of shape dimensions).
44    ///
45    /// Returns `1` for a scalar (empty shape).
46    #[must_use]
47    pub fn num_elements(&self) -> u64 {
48        self.shape.iter().fold(1u64, |acc, &d| {
49            // CAST: usize → u64, dimension values fit in u64
50            #[allow(clippy::as_conversions)]
51            let dim = d as u64;
52            acc.saturating_mul(dim)
53        })
54    }
55
56    /// Byte length of the tensor data (`end - start`).
57    #[must_use]
58    pub const fn byte_len(&self) -> u64 {
59        self.data_offsets.1.saturating_sub(self.data_offsets.0)
60    }
61
62    /// Bytes per element for the tensor's dtype, if recognized.
63    ///
64    /// Returns `None` for unknown dtype strings. Recognized dtypes:
65    ///
66    /// | Dtype string | Bytes | Notes |
67    /// |-------------|-------|-------|
68    /// | `"BOOL"` | 1 | |
69    /// | `"U8"`, `"I8"` | 1 | |
70    /// | `"F8_E4M3"`, `"F8_E5M2"` | 1 | FP8 variants |
71    /// | `"U16"`, `"I16"`, `"F16"`, `"BF16"` | 2 | |
72    /// | `"U32"`, `"I32"`, `"F32"` | 4 | |
73    /// | `"U64"`, `"I64"`, `"F64"` | 8 | |
74    #[must_use]
75    pub fn dtype_bytes(&self) -> Option<usize> {
76        // BORROW: explicit .as_str() instead of Deref coercion
77        match self.dtype.as_str() {
78            "BOOL" | "U8" | "I8" | "F8_E4M3" | "F8_E5M2" => Some(1),
79            "U16" | "I16" | "F16" | "BF16" => Some(2),
80            "U32" | "I32" | "F32" => Some(4),
81            "U64" | "I64" | "F64" => Some(8),
82            _ => None,
83        }
84    }
85}
86
/// Parsed `.safetensors` header metadata.
#[derive(Debug, Clone, Serialize)]
pub struct SafetensorsHeaderInfo {
    /// All tensors in the header, sorted by the start of their data offset
    /// range (i.e. in the order their data is laid out in the file).
    pub tensors: Vec<TensorInfo>,
    /// Raw `__metadata__` entries, if present.
    ///
    /// For quantized models, this typically contains entries like
    /// `quant_method`, `bits`, `group_size` that consumers like anamnesis
    /// use to distinguish GPTQ from AWQ without downloading weights.
    ///
    /// Non-string JSON values are stored as their JSON text.
    pub metadata: Option<HashMap<String, String>>,
    /// Size of the JSON header in bytes.
    ///
    /// For local files this is the value of the 8-byte length prefix; for
    /// remote files it is the length of the JSON body that was fetched.
    pub header_size: u64,
    /// Total file size in bytes (header + data), if known.
    ///
    /// **Source:** for local files, from `std::fs::metadata().len()`. For HTTP
    /// Range requests, extracted from the `Content-Range` response header of
    /// the first request (`bytes 0-7/TOTAL` → `TOTAL`). This is free — no
    /// extra request needed.
    pub file_size: Option<u64>,
}
108
109impl SafetensorsHeaderInfo {
110    /// Total parameter count across all tensors.
111    #[must_use]
112    pub fn total_params(&self) -> u64 {
113        self.tensors
114            .iter()
115            .map(TensorInfo::num_elements)
116            .fold(0u64, u64::saturating_add)
117    }
118
119    /// Returns tensors matching a dtype string (e.g., `"F8_E4M3"`).
120    #[must_use]
121    pub fn tensors_with_dtype(&self, dtype: &str) -> Vec<&TensorInfo> {
122        self.tensors
123            .iter()
124            // BORROW: explicit .as_str() instead of Deref coercion
125            .filter(|t| t.dtype.as_str() == dtype)
126            .collect()
127    }
128}
129
/// The source from which a header was read.
///
/// Returned alongside the parsed header by the cache-first inspection
/// functions so callers can tell whether any network access happened.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InspectSource {
    /// Read from local cache (no network).
    Cached,
    /// Fetched via HTTP Range requests.
    Remote,
}
138
/// Parsed `model.safetensors.index.json` for a sharded model.
#[derive(Debug, Clone, Serialize)]
pub struct ShardedIndex {
    /// Mapping from tensor name to shard filename.
    pub weight_map: HashMap<String, String>,
    /// Unique shard filenames referenced by `weight_map`, sorted
    /// lexicographically.
    pub shards: Vec<String>,
    /// Raw metadata from the index, if present.
    ///
    /// Values are kept as untyped JSON.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
149
/// `PEFT` adapter configuration parsed from `adapter_config.json`.
///
/// Contains the key fields that identify an adapter: the `PEFT` type,
/// base model, `LoRA` rank and scaling parameters, and target modules.
/// All fields are optional because adapter configs vary across `PEFT` methods.
#[derive(Debug, Clone, Serialize)]
pub struct AdapterConfig {
    /// `PEFT` method type (e.g., `"LORA"`, `"ADALORA"`, `"IA3"`).
    pub peft_type: Option<String>,
    /// The base model this adapter was trained on.
    pub base_model_name_or_path: Option<String>,
    /// `LoRA` rank (the `r` parameter). Only meaningful for `LoRA`-family methods.
    pub r: Option<u32>,
    /// `LoRA` alpha scaling factor. Only meaningful for `LoRA`-family methods.
    pub lora_alpha: Option<f64>,
    /// List of model modules targeted by the adapter.
    ///
    /// Empty when the config has no `target_modules` entry; a config that
    /// gives a single string is represented as a one-element list.
    pub target_modules: Vec<String>,
    /// Task type the adapter was trained for (e.g., `"CAUSAL_LM"`).
    pub task_type: Option<String>,
}
170
171// -----------------------------------------------------------------------
172// JSON parsing
173// -----------------------------------------------------------------------
174
/// Raw tensor entry as it appears in the safetensors JSON header.
///
/// This is the value stored under each tensor-name key; the name itself is
/// the JSON key and is therefore not part of this struct.
#[derive(serde::Deserialize)]
struct RawTensorEntry {
    /// Element dtype string (e.g., `"BF16"`).
    dtype: String,
    /// Tensor dimensions.
    shape: Vec<usize>,
    /// `[start, end)` byte range of the tensor data.
    data_offsets: (u64, u64),
}
182
/// Parsed tensor list and optional metadata from a safetensors header.
///
/// The first element holds the tensors, the second the `__metadata__`
/// entries (`None` when the header has no such object).
type ParsedHeader = (Vec<TensorInfo>, Option<HashMap<String, String>>);
185
186/// Parses the safetensors JSON header bytes into tensor metadata.
187///
188/// Extracts the `__metadata__` key separately (if present).
189fn parse_header_json(json_bytes: &[u8], filename: &str) -> Result<ParsedHeader, FetchError> {
190    let raw: HashMap<String, serde_json::Value> =
191        serde_json::from_slice(json_bytes).map_err(|e| FetchError::SafetensorsHeader {
192            filename: filename.to_owned(),
193            reason: format!("failed to parse header JSON: {e}"),
194        })?;
195
196    let mut metadata: Option<HashMap<String, String>> = None;
197    let mut tensors = Vec::new();
198
199    for (key, value) in &raw {
200        if key == "__metadata__" {
201            if let Some(obj) = value.as_object() {
202                let mut meta_map = HashMap::new();
203                for (mk, mv) in obj {
204                    // BORROW: explicit .to_string() for Value → String (strips quotes from strings)
205                    let v_str = if let Some(s) = mv.as_str() {
206                        s.to_owned()
207                    } else {
208                        mv.to_string()
209                    };
210                    // BORROW: explicit .clone() for owned String
211                    meta_map.insert(mk.clone(), v_str);
212                }
213                metadata = Some(meta_map);
214            }
215            continue;
216        }
217
218        let entry: RawTensorEntry =
219            serde_json::from_value(value.clone()).map_err(|e| FetchError::SafetensorsHeader {
220                filename: filename.to_owned(),
221                reason: format!("failed to parse tensor \"{key}\": {e}"),
222            })?;
223
224        tensors.push(TensorInfo {
225            // BORROW: explicit .clone() for owned String
226            name: key.clone(),
227            dtype: entry.dtype,
228            shape: entry.shape,
229            data_offsets: entry.data_offsets,
230        });
231    }
232
233    // Sort by data offset start to preserve file order.
234    tensors.sort_by_key(|t| t.data_offsets.0);
235
236    Ok((tensors, metadata))
237}
238
239// -----------------------------------------------------------------------
240// Cache resolution
241// -----------------------------------------------------------------------
242
243/// Resolves a cached file path for a given repo, revision, and filename.
244///
245/// Returns `None` if the file is not in the local cache.
246fn resolve_cached_path(repo_id: &str, revision: &str, filename: &str) -> Option<PathBuf> {
247    let cache_dir = cache::hf_cache_dir().ok()?;
248    let repo_folder = chunked::repo_folder_name(repo_id);
249    let repo_dir = cache_dir.join(&repo_folder);
250    let commit_hash = cache::read_ref(&repo_dir, revision)?;
251    let cached_path = repo_dir.join("snapshots").join(commit_hash).join(filename);
252    if cached_path.exists() {
253        Some(cached_path)
254    } else {
255        None
256    }
257}
258
259// -----------------------------------------------------------------------
260// Local file reading
261// -----------------------------------------------------------------------
262
263/// Inspects a single `.safetensors` file's header from a local file path.
264///
265/// Reads the first `8 + header_size` bytes from disk. Does not read tensor data.
266///
267/// # Errors
268///
269/// Returns [`FetchError::Io`] if the file cannot be read.
270/// Returns [`FetchError::SafetensorsHeader`] if the header is malformed.
271pub fn inspect_safetensors_local(path: &Path) -> Result<SafetensorsHeaderInfo, FetchError> {
272    use std::io::Read;
273
274    let file_size = std::fs::metadata(path)
275        .map_err(|e| FetchError::Io {
276            path: path.to_path_buf(),
277            source: e,
278        })?
279        .len();
280
281    // BORROW: explicit .to_string_lossy() for Path → str conversion
282    let filename = path.file_name().map_or_else(
283        || path.display().to_string(),
284        |n| n.to_string_lossy().to_string(),
285    );
286
287    let mut file = std::fs::File::open(path).map_err(|e| FetchError::Io {
288        path: path.to_path_buf(),
289        source: e,
290    })?;
291
292    // Read 8-byte header length prefix (little-endian u64).
293    let mut len_buf = [0u8; 8];
294    file.read_exact(&mut len_buf).map_err(|e| FetchError::Io {
295        path: path.to_path_buf(),
296        source: e,
297    })?;
298    let header_size = u64::from_le_bytes(len_buf);
299
300    // Sanity check: header cannot be larger than the file.
301    if header_size.saturating_add(8) > file_size {
302        return Err(FetchError::SafetensorsHeader {
303            filename,
304            reason: format!("header length {header_size} exceeds file size {file_size}"),
305        });
306    }
307
308    // Read the JSON header.
309    // CAST: u64 → usize, header size bounded by file size (checked above)
310    #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
311    let json_len = header_size as usize;
312    let mut json_buf = vec![0u8; json_len];
313    file.read_exact(&mut json_buf).map_err(|e| FetchError::Io {
314        path: path.to_path_buf(),
315        source: e,
316    })?;
317
318    // BORROW: explicit .as_str() instead of Deref coercion
319    let (tensors, metadata) = parse_header_json(&json_buf, filename.as_str())?;
320
321    Ok(SafetensorsHeaderInfo {
322        tensors,
323        metadata,
324        header_size,
325        file_size: Some(file_size),
326    })
327}
328
329// -----------------------------------------------------------------------
330// Remote fetching (HTTP Range requests)
331// -----------------------------------------------------------------------
332
333/// Fetches safetensors header bytes via two HTTP Range requests.
334///
335/// 1. `Range: bytes=0-7` → 8-byte header length (little-endian `u64`)
336/// 2. `Range: bytes=8-{8+length-1}` → JSON header
337///
338/// Returns `(json_bytes, total_file_size)`. The file size is extracted from
339/// the `Content-Range` header of the first request.
340async fn fetch_header_bytes(
341    client: &reqwest::Client,
342    url: &str,
343    filename: &str,
344) -> Result<(Vec<u8>, Option<u64>), FetchError> {
345    // Request 1: 8-byte length prefix.
346    let resp1 = client
347        .get(url)
348        .header(reqwest::header::RANGE, "bytes=0-7")
349        .send()
350        .await
351        .map_err(|e| {
352            FetchError::Http(format!("failed to fetch header length for {filename}: {e}"))
353        })?;
354
355    if !resp1.status().is_success() && resp1.status() != reqwest::StatusCode::PARTIAL_CONTENT {
356        return Err(FetchError::Http(format!(
357            "Range request for {filename} returned status {}",
358            resp1.status()
359        )));
360    }
361
362    // Extract total file size from Content-Range: bytes 0-7/{total}
363    let file_size = resp1
364        .headers()
365        .get(reqwest::header::CONTENT_RANGE)
366        .and_then(|v| v.to_str().ok())
367        .and_then(|s| s.split('/').next_back())
368        .and_then(|s| s.parse::<u64>().ok());
369
370    let len_bytes = resp1.bytes().await.map_err(|e| {
371        FetchError::Http(format!("failed to read header length for {filename}: {e}"))
372    })?;
373
374    if len_bytes.len() < 8 {
375        return Err(FetchError::SafetensorsHeader {
376            filename: filename.to_owned(),
377            reason: format!(
378                "expected 8 bytes for length prefix, got {}",
379                len_bytes.len()
380            ),
381        });
382    }
383
384    // INDEX: first 8 bytes guaranteed by length check above
385    #[allow(clippy::indexing_slicing)]
386    let header_size = u64::from_le_bytes([
387        len_bytes[0],
388        len_bytes[1],
389        len_bytes[2],
390        len_bytes[3],
391        len_bytes[4],
392        len_bytes[5],
393        len_bytes[6],
394        len_bytes[7],
395    ]);
396
397    // Request 2: JSON header.
398    let range_end = 8u64.saturating_add(header_size).saturating_sub(1);
399    let range_header = format!("bytes=8-{range_end}");
400    let resp2 = client
401        .get(url)
402        // BORROW: explicit .as_str() instead of Deref coercion
403        .header(reqwest::header::RANGE, range_header.as_str())
404        .send()
405        .await
406        .map_err(|e| {
407            FetchError::Http(format!("failed to fetch header JSON for {filename}: {e}"))
408        })?;
409
410    if !resp2.status().is_success() && resp2.status() != reqwest::StatusCode::PARTIAL_CONTENT {
411        return Err(FetchError::Http(format!(
412            "Range request for {filename} header JSON returned status {}",
413            resp2.status()
414        )));
415    }
416
417    let json_bytes = resp2
418        .bytes()
419        .await
420        .map_err(|e| FetchError::Http(format!("failed to read header JSON for {filename}: {e}")))?;
421
422    Ok((json_bytes.to_vec(), file_size))
423}
424
425// -----------------------------------------------------------------------
426// Public API: single-file inspection
427// -----------------------------------------------------------------------
428
429/// Inspects a single `.safetensors` file's header (cache-first).
430///
431/// Checks the local HF cache first. If the file is cached, reads the header
432/// from disk with zero network requests. Otherwise, falls back to two HTTP
433/// Range requests (8-byte length prefix + JSON header). Does not download
434/// tensor data in either case.
435///
436/// # Errors
437///
438/// Returns [`FetchError::Http`] if the Range requests fail.
439/// Returns [`FetchError::SafetensorsHeader`] if the header is malformed.
440pub async fn inspect_safetensors(
441    repo_id: &str,
442    filename: &str,
443    token: Option<&str>,
444    revision: Option<&str>,
445) -> Result<(SafetensorsHeaderInfo, InspectSource), FetchError> {
446    let rev = revision.unwrap_or("main");
447
448    // Try local cache first.
449    if let Some(cached_path) = resolve_cached_path(repo_id, rev, filename) {
450        let info = inspect_safetensors_local(&cached_path)?;
451        return Ok((info, InspectSource::Cached));
452    }
453
454    // Fall back to HTTP Range requests.
455    let client = chunked::build_client(token)?;
456    let url = chunked::build_download_url(repo_id, rev, filename);
457
458    // BORROW: explicit .as_str() instead of Deref coercion
459    let (json_bytes, file_size) = fetch_header_bytes(&client, url.as_str(), filename).await?;
460
461    // CAST: usize → u64, JSON buffer length is always small
462    #[allow(clippy::as_conversions)]
463    let header_size = json_bytes.len() as u64;
464
465    let (tensors, metadata) = parse_header_json(&json_bytes, filename)?;
466
467    Ok((
468        SafetensorsHeaderInfo {
469            tensors,
470            metadata,
471            header_size,
472            file_size,
473        },
474        InspectSource::Remote,
475    ))
476}
477
478/// Inspects a single `.safetensors` file from cache only.
479///
480/// Resolves the file in the local HF cache using the given `repo_id`,
481/// `revision`, and `filename`. Returns an error if the file is not cached.
482///
483/// # Errors
484///
485/// Returns [`FetchError::SafetensorsHeader`] if the file is not in the cache.
486/// Returns [`FetchError::Io`] if the cached file cannot be read.
487/// Returns [`FetchError::SafetensorsHeader`] if the header is malformed.
488pub fn inspect_safetensors_cached(
489    repo_id: &str,
490    filename: &str,
491    revision: Option<&str>,
492) -> Result<SafetensorsHeaderInfo, FetchError> {
493    let rev = revision.unwrap_or("main");
494
495    let cached_path = resolve_cached_path(repo_id, rev, filename).ok_or_else(|| {
496        FetchError::SafetensorsHeader {
497            filename: filename.to_owned(),
498            reason: format!("file not found in local cache for {repo_id} ({rev})"),
499        }
500    })?;
501
502    inspect_safetensors_local(&cached_path)
503}
504
505// -----------------------------------------------------------------------
506// Public API: multi-file inspection
507// -----------------------------------------------------------------------
508
509/// Inspects all `.safetensors` files in a repository (cache-first per file).
510///
511/// Fetches the file listing via `list_repo_files_with_metadata()`, then
512/// inspects each `.safetensors` file's header via [`inspect_safetensors()`].
513/// For each file, checks the local cache first and only makes HTTP Range
514/// requests on cache miss. Returns full per-shard headers in filename order.
515///
516/// For a lightweight summary of sharded models (tensor counts per shard
517/// without fetching individual headers), use [`fetch_shard_index()`] instead.
518///
519/// # Errors
520///
521/// Returns [`FetchError::Http`] if the metadata or Range requests fail.
522pub async fn inspect_repo_safetensors(
523    repo_id: &str,
524    token: Option<&str>,
525    revision: Option<&str>,
526) -> Result<Vec<(String, SafetensorsHeaderInfo, InspectSource)>, FetchError> {
527    let files = crate::repo::list_repo_files_with_metadata(repo_id, token, revision).await?;
528
529    let safetensors_files: Vec<String> = files
530        .into_iter()
531        .filter(|f| f.filename.ends_with(".safetensors"))
532        .map(|f| f.filename)
533        .collect();
534
535    if safetensors_files.is_empty() {
536        return Ok(Vec::new());
537    }
538
539    let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(4));
540    let mut handles = Vec::new();
541
542    for filename in safetensors_files {
543        let sem = semaphore.clone();
544        let repo = repo_id.to_owned();
545        let tok = token.map(str::to_owned);
546        let rev = revision.map(str::to_owned);
547
548        handles.push(tokio::spawn(async move {
549            let _permit = sem
550                .acquire()
551                .await
552                .map_err(|e| FetchError::Http(format!("semaphore error: {e}")))?;
553            // BORROW: explicit .as_deref() for Option<String> → Option<&str>
554            let (info, source) =
555                inspect_safetensors(&repo, &filename, tok.as_deref(), rev.as_deref()).await?;
556            Ok::<_, FetchError>((filename, info, source))
557        }));
558    }
559
560    let mut results = Vec::new();
561    for handle in handles {
562        let result = handle
563            .await
564            .map_err(|e| FetchError::Http(format!("task join error: {e}")))?;
565        results.push(result?);
566    }
567
568    results.sort_by(|a, b| a.0.cmp(&b.0));
569
570    Ok(results)
571}
572
573/// Inspects all `.safetensors` files in a cached repository (no network).
574///
575/// Walks the snapshot directory and inspects each `.safetensors` file's
576/// header from local disk. Returns results in filename order.
577///
578/// # Errors
579///
580/// Returns [`FetchError::Io`] if the cache directory cannot be read.
581/// Returns [`FetchError::SafetensorsHeader`] if any header is malformed.
582pub fn inspect_repo_safetensors_cached(
583    repo_id: &str,
584    revision: Option<&str>,
585) -> Result<Vec<(String, SafetensorsHeaderInfo)>, FetchError> {
586    let rev = revision.unwrap_or("main");
587    let cache_dir = cache::hf_cache_dir()?;
588    let repo_folder = chunked::repo_folder_name(repo_id);
589    let repo_dir = cache_dir.join(&repo_folder);
590
591    let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
592        return Ok(Vec::new());
593    };
594
595    let snapshot_dir = repo_dir.join("snapshots").join(commit_hash);
596    if !snapshot_dir.exists() {
597        return Ok(Vec::new());
598    }
599
600    let mut results = Vec::new();
601    collect_safetensors_recursive(&snapshot_dir, "", &mut results)?;
602    results.sort_by(|a, b| a.0.cmp(&b.0));
603
604    Ok(results)
605}
606
607/// Recursively finds and inspects `.safetensors` files in a snapshot directory.
608fn collect_safetensors_recursive(
609    dir: &Path,
610    prefix: &str,
611    results: &mut Vec<(String, SafetensorsHeaderInfo)>,
612) -> Result<(), FetchError> {
613    let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
614        path: dir.to_path_buf(),
615        source: e,
616    })?;
617
618    for entry in entries {
619        let Ok(entry) = entry else { continue };
620        let path = entry.path();
621        // BORROW: explicit .to_string_lossy() for OsString → str conversion
622        let name = entry.file_name().to_string_lossy().to_string();
623
624        if path.is_dir() {
625            let child_prefix = if prefix.is_empty() {
626                name
627            } else {
628                format!("{prefix}/{name}")
629            };
630            collect_safetensors_recursive(&path, &child_prefix, results)?;
631        } else if name.ends_with(".safetensors") {
632            let filename = if prefix.is_empty() {
633                name
634            } else {
635                format!("{prefix}/{name}")
636            };
637            let info = inspect_safetensors_local(&path)?;
638            results.push((filename, info));
639        }
640    }
641
642    Ok(())
643}
644
645// -----------------------------------------------------------------------
646// Shard index
647// -----------------------------------------------------------------------
648
/// Raw JSON structure of `model.safetensors.index.json`.
///
/// Only the two fields below are read from the index.
#[derive(serde::Deserialize)]
struct RawShardIndex {
    /// Tensor name → shard filename. Required.
    weight_map: HashMap<String, String>,
    /// Free-form metadata object; `None` when the key is missing.
    #[serde(default)]
    metadata: Option<HashMap<String, serde_json::Value>>,
}
656
657/// Fetches and parses the shard index for a sharded `.safetensors` model (cache-first).
658///
659/// Returns `Ok(None)` if the repo has no `model.safetensors.index.json` (i.e.,
660/// the model is not sharded or uses a single `.safetensors` file).
661///
662/// # Errors
663///
664/// Returns [`FetchError::Http`] if the index fetch fails.
665/// Returns [`FetchError::SafetensorsHeader`] if the index JSON is malformed.
666pub async fn fetch_shard_index(
667    repo_id: &str,
668    token: Option<&str>,
669    revision: Option<&str>,
670) -> Result<Option<ShardedIndex>, FetchError> {
671    let rev = revision.unwrap_or("main");
672    let index_filename = "model.safetensors.index.json";
673
674    // Try local cache first.
675    if let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) {
676        let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
677            path: cached_path,
678            source: e,
679        })?;
680        let index = parse_shard_index_json(&content, repo_id)?;
681        return Ok(Some(index));
682    }
683
684    // Fall back to HTTP.
685    let client = chunked::build_client(token)?;
686    let url = chunked::build_download_url(repo_id, rev, index_filename);
687
688    // BORROW: explicit .as_str() instead of Deref coercion
689    let response =
690        client.get(url.as_str()).send().await.map_err(|e| {
691            FetchError::Http(format!("failed to fetch shard index for {repo_id}: {e}"))
692        })?;
693
694    if response.status() == reqwest::StatusCode::NOT_FOUND {
695        return Ok(None);
696    }
697
698    if !response.status().is_success() {
699        return Err(FetchError::Http(format!(
700            "shard index request for {repo_id} returned status {}",
701            response.status()
702        )));
703    }
704
705    let content = response
706        .text()
707        .await
708        .map_err(|e| FetchError::Http(format!("failed to read shard index for {repo_id}: {e}")))?;
709
710    let index = parse_shard_index_json(&content, repo_id)?;
711    Ok(Some(index))
712}
713
714/// Fetches the shard index from cache only (no network).
715///
716/// Returns `Ok(None)` if the index file is not cached.
717///
718/// # Errors
719///
720/// Returns [`FetchError::Io`] if the cached file cannot be read.
721/// Returns [`FetchError::SafetensorsHeader`] if the index JSON is malformed.
722pub fn fetch_shard_index_cached(
723    repo_id: &str,
724    revision: Option<&str>,
725) -> Result<Option<ShardedIndex>, FetchError> {
726    let rev = revision.unwrap_or("main");
727    let index_filename = "model.safetensors.index.json";
728
729    let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) else {
730        return Ok(None);
731    };
732
733    let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
734        path: cached_path,
735        source: e,
736    })?;
737
738    let index = parse_shard_index_json(&content, repo_id)?;
739    Ok(Some(index))
740}
741
742/// Parses shard index JSON into a `ShardedIndex`.
743fn parse_shard_index_json(content: &str, repo_id: &str) -> Result<ShardedIndex, FetchError> {
744    let raw: RawShardIndex =
745        serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
746            filename: "model.safetensors.index.json".to_owned(),
747            reason: format!("failed to parse shard index for {repo_id}: {e}"),
748        })?;
749
750    // Collect unique shard filenames in sorted order.
751    let mut shard_set: Vec<String> = raw.weight_map.values().cloned().collect();
752    shard_set.sort();
753    shard_set.dedup();
754
755    Ok(ShardedIndex {
756        weight_map: raw.weight_map,
757        shards: shard_set,
758        metadata: raw.metadata,
759    })
760}
761
762// -----------------------------------------------------------------------
763// Param formatting helper
764// -----------------------------------------------------------------------
765
/// Formats a parameter count with a compact suffix (e.g., `927.0M`, `1.02B`).
///
/// | Range | Output |
/// |-------|--------|
/// | below 1 000 | the plain integer (`999`) |
/// | thousands | one decimal and `K` (`1.5K`) |
/// | millions | one decimal and `M` (`927.0M`) |
/// | billions and above | two decimals and `B` (`1.02B`) |
///
/// A value is promoted to the next unit as soon as rounding to one decimal
/// would reach `1000.0` in the current unit, so the output is `1.0M` rather
/// than `1000.0K`, and `1.00B` rather than `1000.0M`.
#[must_use]
pub fn format_params(count: u64) -> String {
    // CAST: u64 → f64, precision loss acceptable; value is a display-only scalar
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let val = count as f64;

    // The K and M tiers print one decimal, so anything from 999.95 of a unit
    // upward would round to "1000.0"; such values belong to the next tier.
    if count >= 999_950_000 {
        format!("{:.2}B", val / 1_000_000_000.0)
    } else if count >= 999_950 {
        format!("{:.1}M", val / 1_000_000.0)
    } else if count >= 1_000 {
        format!("{:.1}K", val / 1_000.0)
    } else {
        count.to_string()
    }
}
783
784// -----------------------------------------------------------------------
785// Adapter config
786// -----------------------------------------------------------------------
787
/// Raw JSON structure of `adapter_config.json`.
///
/// Every field is optional and defaults to `None` when the key is missing,
/// so a config that lacks any of them still deserializes.
#[derive(serde::Deserialize)]
struct RawAdapterConfig {
    /// `PEFT` method type.
    #[serde(default)]
    peft_type: Option<String>,
    /// Base model the adapter was trained on.
    #[serde(default)]
    base_model_name_or_path: Option<String>,
    /// `LoRA` rank.
    #[serde(default)]
    r: Option<u32>,
    /// `LoRA` alpha scaling factor.
    #[serde(default)]
    lora_alpha: Option<f64>,
    /// Targeted modules, given either as a list or as a single string.
    #[serde(default)]
    target_modules: Option<AdapterTargetModules>,
    /// Task type the adapter was trained for.
    #[serde(default)]
    task_type: Option<String>,
}
804
/// `target_modules` in adapter configs can be a list of strings or a single string.
///
/// The enum is untagged: the variant is chosen from the shape of the JSON
/// value (array vs. string) rather than from an explicit tag.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum AdapterTargetModules {
    /// A list of module name strings.
    List(Vec<String>),
    /// A single module name string.
    Single(String),
}
814
815/// Fetches and parses `adapter_config.json` for a `PEFT` adapter repository (cache-first).
816///
817/// Returns `Ok(None)` if the file does not exist (HTTP 404), meaning the
818/// repository is not a `PEFT` adapter.
819///
820/// # Errors
821///
822/// Returns [`FetchError::Http`] if the request fails (other than 404).
823/// Returns [`FetchError::SafetensorsHeader`] if the JSON is malformed.
824pub async fn fetch_adapter_config(
825    repo_id: &str,
826    token: Option<&str>,
827    revision: Option<&str>,
828) -> Result<Option<AdapterConfig>, FetchError> {
829    let rev = revision.unwrap_or("main");
830    let config_filename = "adapter_config.json";
831
832    // Try local cache first.
833    if let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) {
834        let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
835            path: cached_path,
836            source: e,
837        })?;
838        let config = parse_adapter_config_json(&content, repo_id)?;
839        return Ok(Some(config));
840    }
841
842    // Fall back to HTTP.
843    let client = chunked::build_client(token)?;
844    let url = chunked::build_download_url(repo_id, rev, config_filename);
845
846    // BORROW: explicit .as_str() instead of Deref coercion
847    let response = client.get(url.as_str()).send().await.map_err(|e| {
848        FetchError::Http(format!("failed to fetch adapter config for {repo_id}: {e}"))
849    })?;
850
851    if response.status() == reqwest::StatusCode::NOT_FOUND {
852        return Ok(None);
853    }
854
855    if !response.status().is_success() {
856        return Err(FetchError::Http(format!(
857            "adapter config request for {repo_id} returned status {}",
858            response.status()
859        )));
860    }
861
862    let content = response.text().await.map_err(|e| {
863        FetchError::Http(format!("failed to read adapter config for {repo_id}: {e}"))
864    })?;
865
866    let config = parse_adapter_config_json(&content, repo_id)?;
867    Ok(Some(config))
868}
869
870/// Fetches the adapter config from cache only (no network).
871///
872/// Returns `Ok(None)` if the file is not cached.
873///
874/// # Errors
875///
876/// Returns [`FetchError::Io`] if the cached file cannot be read.
877/// Returns [`FetchError::SafetensorsHeader`] if the JSON is malformed.
878pub fn fetch_adapter_config_cached(
879    repo_id: &str,
880    revision: Option<&str>,
881) -> Result<Option<AdapterConfig>, FetchError> {
882    let rev = revision.unwrap_or("main");
883    let config_filename = "adapter_config.json";
884
885    let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) else {
886        return Ok(None);
887    };
888
889    let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
890        path: cached_path,
891        source: e,
892    })?;
893
894    let config = parse_adapter_config_json(&content, repo_id)?;
895    Ok(Some(config))
896}
897
898/// Parses adapter config JSON into an [`AdapterConfig`].
899fn parse_adapter_config_json(content: &str, repo_id: &str) -> Result<AdapterConfig, FetchError> {
900    let raw: RawAdapterConfig =
901        serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
902            filename: "adapter_config.json".to_owned(),
903            reason: format!("failed to parse adapter config for {repo_id}: {e}"),
904        })?;
905
906    let target_modules = match raw.target_modules {
907        Some(AdapterTargetModules::List(v)) => v,
908        Some(AdapterTargetModules::Single(s)) => vec![s],
909        None => Vec::new(),
910    };
911
912    Ok(AdapterConfig {
913        peft_type: raw.peft_type,
914        base_model_name_or_path: raw.base_model_name_or_path,
915        r: raw.r,
916        lora_alpha: raw.lora_alpha,
917        target_modules,
918        task_type: raw.task_type,
919    })
920}