Skip to main content

hf_fetch_model/
inspect.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Safetensors header inspection (local and remote).
4//!
5//! Reads tensor metadata (names, shapes, dtypes, byte offsets) from
6//! `.safetensors` files without downloading full weight data. Supports
7//! cache-first resolution with HTTP Range request fallback.
8//!
9//! The primary types are [`TensorInfo`] (per-tensor metadata),
10//! [`SafetensorsHeaderInfo`] (parsed header), and [`ShardedIndex`]
11//! (shard-to-tensor mapping for sharded models).
12
13use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15
16use serde::Serialize;
17use tokio::task::JoinSet;
18
19use crate::cache;
20use crate::cache_layout;
21use crate::chunked;
22use crate::error::FetchError;
23
24// -----------------------------------------------------------------------
25// Types
26// -----------------------------------------------------------------------
27
28/// Metadata for a single tensor from a `.safetensors` header.
29///
30/// This is hf-fetch-model's own type — lightweight, no quantization logic.
31/// Consumers (e.g., anamnesis) map this into their own richer types.
32#[derive(Debug, Clone, Serialize)]
33pub struct TensorInfo {
34    /// Tensor name (e.g., `"model.layers.0.self_attn.q_proj.weight"`).
35    pub name: String,
36    /// Element dtype string as it appears in the header (e.g., `"F8_E4M3"`, `"BF16"`).
37    pub dtype: String,
38    /// Tensor shape (e.g., `[7168, 7168]`).
39    pub shape: Vec<usize>,
40    /// Byte offset range `[start, end)` within the data section of the file.
41    pub data_offsets: (u64, u64),
42}
43
44impl TensorInfo {
45    /// Total number of elements (product of shape dimensions).
46    ///
47    /// Returns `1` for a scalar (empty shape).
48    #[must_use]
49    pub fn num_elements(&self) -> u64 {
50        self.shape.iter().fold(1u64, |acc, &d| {
51            // CAST: usize → u64, dimension values fit in u64
52            #[allow(clippy::as_conversions)]
53            let dim = d as u64;
54            acc.saturating_mul(dim)
55        })
56    }
57
58    /// Byte length of the tensor data (`end - start`).
59    #[must_use]
60    pub const fn byte_len(&self) -> u64 {
61        self.data_offsets.1.saturating_sub(self.data_offsets.0)
62    }
63
64    /// Bytes per element for the tensor's dtype, if recognized.
65    ///
66    /// Returns `None` for unknown dtype strings. Recognized dtypes:
67    ///
68    /// | Dtype string | Bytes | Notes |
69    /// |-------------|-------|-------|
70    /// | `"BOOL"` | 1 | |
71    /// | `"U8"`, `"I8"` | 1 | |
72    /// | `"F8_E4M3"`, `"F8_E5M2"` | 1 | FP8 variants |
73    /// | `"U16"`, `"I16"`, `"F16"`, `"BF16"` | 2 | |
74    /// | `"U32"`, `"I32"`, `"F32"` | 4 | |
75    /// | `"U64"`, `"I64"`, `"F64"` | 8 | |
76    #[must_use]
77    pub fn dtype_bytes(&self) -> Option<usize> {
78        // BORROW: explicit .as_str() instead of Deref coercion
79        match self.dtype.as_str() {
80            "BOOL" | "U8" | "I8" | "F8_E4M3" | "F8_E5M2" => Some(1),
81            "U16" | "I16" | "F16" | "BF16" => Some(2),
82            "U32" | "I32" | "F32" => Some(4),
83            "U64" | "I64" | "F64" => Some(8),
84            _ => None,
85        }
86    }
87}
88
89/// Parsed `.safetensors` header metadata.
90#[derive(Debug, Clone, Serialize)]
91pub struct SafetensorsHeaderInfo {
92    /// All tensors in the header, in the order they appear in the JSON.
93    pub tensors: Vec<TensorInfo>,
94    /// Raw `__metadata__` entries, if present.
95    ///
96    /// For quantized models, this typically contains entries like
97    /// `quant_method`, `bits`, `group_size` that consumers like anamnesis
98    /// use to distinguish GPTQ from AWQ without downloading weights.
99    pub metadata: Option<HashMap<String, String>>,
100    /// Size of the JSON header in bytes.
101    pub header_size: u64,
102    /// Total file size in bytes (header + data), if known.
103    ///
104    /// **Source:** for local files, from `std::fs::metadata().len()`. For HTTP
105    /// Range requests, extracted from the `Content-Range` response header of
106    /// the first request (`bytes 0-7/TOTAL` → `TOTAL`). This is free — no
107    /// extra request needed.
108    pub file_size: Option<u64>,
109}
110
111impl SafetensorsHeaderInfo {
112    /// Total parameter count across all tensors.
113    #[must_use]
114    pub fn total_params(&self) -> u64 {
115        self.tensors
116            .iter()
117            .map(TensorInfo::num_elements)
118            .fold(0u64, u64::saturating_add)
119    }
120
121    /// Returns tensors matching a dtype string (e.g., `"F8_E4M3"`).
122    #[must_use]
123    pub fn tensors_with_dtype(&self, dtype: &str) -> Vec<&TensorInfo> {
124        self.tensors
125            .iter()
126            // BORROW: explicit .as_str() instead of Deref coercion
127            .filter(|t| t.dtype.as_str() == dtype)
128            .collect()
129    }
130}
131
/// Where the bytes of an inspected header came from.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InspectSource {
    /// Served from the local HF cache; no network traffic was needed.
    Cached,
    /// Downloaded with HTTP Range requests.
    Remote,
}
141
142/// Parsed `model.safetensors.index.json` for a sharded model.
143#[derive(Debug, Clone, Serialize)]
144pub struct ShardedIndex {
145    /// Mapping from tensor name to shard filename.
146    pub weight_map: HashMap<String, String>,
147    /// Ordered list of unique shard filenames.
148    pub shards: Vec<String>,
149    /// Raw metadata from the index, if present.
150    pub metadata: Option<HashMap<String, serde_json::Value>>,
151}
152
153/// `PEFT` adapter configuration parsed from `adapter_config.json`.
154///
155/// Contains the key fields that identify an adapter: the `PEFT` type,
156/// base model, `LoRA` rank and scaling parameters, and target modules.
157/// All fields are optional because adapter configs vary across `PEFT` methods.
158#[derive(Debug, Clone, Serialize)]
159pub struct AdapterConfig {
160    /// `PEFT` method type (e.g., `"LORA"`, `"ADALORA"`, `"IA3"`).
161    pub peft_type: Option<String>,
162    /// The base model this adapter was trained on.
163    pub base_model_name_or_path: Option<String>,
164    /// `LoRA` rank (the `r` parameter). Only meaningful for `LoRA`-family methods.
165    pub r: Option<u32>,
166    /// `LoRA` alpha scaling factor. Only meaningful for `LoRA`-family methods.
167    pub lora_alpha: Option<f64>,
168    /// List of model modules targeted by the adapter.
169    pub target_modules: Vec<String>,
170    /// Task type the adapter was trained for (e.g., `"CAUSAL_LM"`).
171    pub task_type: Option<String>,
172}
173
174// -----------------------------------------------------------------------
175// JSON parsing
176// -----------------------------------------------------------------------
177
178/// Raw tensor entry as it appears in the safetensors JSON header.
179#[derive(serde::Deserialize)]
180struct RawTensorEntry {
181    dtype: String,
182    shape: Vec<usize>,
183    data_offsets: (u64, u64),
184}
185
186/// Parsed tensor list and optional metadata from a safetensors header.
187type ParsedHeader = (Vec<TensorInfo>, Option<HashMap<String, String>>);
188
189/// Parses the safetensors JSON header bytes into tensor metadata.
190///
191/// Extracts the `__metadata__` key separately (if present).
192fn parse_header_json(json_bytes: &[u8], filename: &str) -> Result<ParsedHeader, FetchError> {
193    let raw: HashMap<String, serde_json::Value> =
194        serde_json::from_slice(json_bytes).map_err(|e| FetchError::SafetensorsHeader {
195            filename: filename.to_owned(),
196            reason: format!("failed to parse header JSON: {e}"),
197        })?;
198
199    let mut metadata: Option<HashMap<String, String>> = None;
200    let mut tensors = Vec::new();
201
202    for (key, value) in raw {
203        if key == "__metadata__" {
204            if let serde_json::Value::Object(obj) = value {
205                let mut meta_map = HashMap::new();
206                for (mk, mv) in obj {
207                    // BORROW: explicit .to_owned()/.to_string() for Value → String conversion
208                    let v_str = if let Some(s) = mv.as_str() {
209                        s.to_owned()
210                    } else {
211                        mv.to_string()
212                    };
213                    meta_map.insert(mk, v_str);
214                }
215                metadata = Some(meta_map);
216            }
217            continue;
218        }
219
220        let entry: RawTensorEntry =
221            serde_json::from_value(value).map_err(|e| FetchError::SafetensorsHeader {
222                filename: filename.to_owned(),
223                reason: format!("failed to parse tensor \"{key}\": {e}"),
224            })?;
225
226        tensors.push(TensorInfo {
227            name: key,
228            dtype: entry.dtype,
229            shape: entry.shape,
230            data_offsets: entry.data_offsets,
231        });
232    }
233
234    // Sort by data offset start to preserve file order.
235    tensors.sort_by_key(|t| t.data_offsets.0);
236
237    Ok((tensors, metadata))
238}
239
240// -----------------------------------------------------------------------
241// Cache resolution
242// -----------------------------------------------------------------------
243
244/// Resolves a cached file path for a given repo, revision, and filename.
245///
246/// Returns `None` if the file is not in the local cache.
247fn resolve_cached_path(repo_id: &str, revision: &str, filename: &str) -> Option<PathBuf> {
248    let cache_dir = cache::hf_cache_dir().ok()?;
249    let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
250    let commit_hash = cache::read_ref(&repo_dir, revision)?;
251    let cached_path = cache_layout::pointer_path(&repo_dir, &commit_hash, filename);
252    if cached_path.exists() {
253        Some(cached_path)
254    } else {
255        None
256    }
257}
258
259// -----------------------------------------------------------------------
260// Local file reading
261// -----------------------------------------------------------------------
262
263/// Inspects a single `.safetensors` file's header from a local file path.
264///
265/// Reads the first `8 + header_size` bytes from disk. Does not read tensor data.
266///
267/// # Blocking I/O
268///
269/// This function performs synchronous filesystem I/O. In async contexts, wrap
270/// it in [`tokio::task::spawn_blocking`] so the calling task does not stall
271/// the runtime — particularly important on network-mounted caches (NFS/CIFS)
272/// where `read`/`stat` calls can take tens of milliseconds each.
273///
274/// # Errors
275///
276/// Returns [`FetchError::Io`] if the file cannot be read.
277/// Returns [`FetchError::SafetensorsHeader`] if the header is malformed.
278pub fn inspect_safetensors_local(path: &Path) -> Result<SafetensorsHeaderInfo, FetchError> {
279    use std::io::Read;
280
281    let file_size = std::fs::metadata(path)
282        .map_err(|e| FetchError::Io {
283            path: path.to_path_buf(),
284            source: e,
285        })?
286        .len();
287
288    // BORROW: explicit .to_string_lossy() for Path → str conversion
289    let filename = path.file_name().map_or_else(
290        || path.display().to_string(),
291        |n| n.to_string_lossy().to_string(),
292    );
293
294    let mut file = std::fs::File::open(path).map_err(|e| FetchError::Io {
295        path: path.to_path_buf(),
296        source: e,
297    })?;
298
299    // Read 8-byte header length prefix (little-endian u64).
300    let mut len_buf = [0u8; 8];
301    file.read_exact(&mut len_buf).map_err(|e| FetchError::Io {
302        path: path.to_path_buf(),
303        source: e,
304    })?;
305    let header_size = u64::from_le_bytes(len_buf);
306
307    // Sanity check: header cannot be larger than the file.
308    if header_size.saturating_add(8) > file_size {
309        return Err(FetchError::SafetensorsHeader {
310            filename,
311            reason: format!("header length {header_size} exceeds file size {file_size}"),
312        });
313    }
314
315    // Read the JSON header.
316    // CAST: u64 → usize, header size bounded by file size (checked above)
317    #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
318    let json_len = header_size as usize;
319    let mut json_buf = vec![0u8; json_len];
320    file.read_exact(&mut json_buf).map_err(|e| FetchError::Io {
321        path: path.to_path_buf(),
322        source: e,
323    })?;
324
325    // BORROW: explicit .as_str() instead of Deref coercion
326    let (tensors, metadata) = parse_header_json(&json_buf, filename.as_str())?;
327
328    Ok(SafetensorsHeaderInfo {
329        tensors,
330        metadata,
331        header_size,
332        file_size: Some(file_size),
333    })
334}
335
336// -----------------------------------------------------------------------
337// Remote fetching (HTTP Range requests)
338// -----------------------------------------------------------------------
339
340/// Fetches safetensors header bytes via two HTTP Range requests.
341///
342/// 1. `Range: bytes=0-7` → 8-byte header length (little-endian `u64`)
343/// 2. `Range: bytes=8-{8+length-1}` → JSON header
344///
345/// Returns `(json_bytes, total_file_size)`. The file size is extracted from
346/// the `Content-Range` header of the first request.
347async fn fetch_header_bytes(
348    client: &reqwest::Client,
349    url: &str,
350    filename: &str,
351) -> Result<(Vec<u8>, Option<u64>), FetchError> {
352    // Request 1: 8-byte length prefix.
353    let resp1 = client
354        .get(url)
355        .header(reqwest::header::RANGE, "bytes=0-7")
356        .send()
357        .await
358        .map_err(|e| {
359            FetchError::Http(format!("failed to fetch header length for {filename}: {e}"))
360        })?;
361
362    if !resp1.status().is_success() && resp1.status() != reqwest::StatusCode::PARTIAL_CONTENT {
363        return Err(FetchError::Http(format!(
364            "Range request for {filename} returned status {}",
365            resp1.status()
366        )));
367    }
368
369    // Extract total file size from Content-Range: bytes 0-7/{total}
370    let file_size = resp1
371        .headers()
372        .get(reqwest::header::CONTENT_RANGE)
373        .and_then(|v| v.to_str().ok())
374        .and_then(|s| s.split('/').next_back())
375        .and_then(|s| s.parse::<u64>().ok());
376
377    let len_bytes = resp1.bytes().await.map_err(|e| {
378        FetchError::Http(format!("failed to read header length for {filename}: {e}"))
379    })?;
380
381    if len_bytes.len() < 8 {
382        return Err(FetchError::SafetensorsHeader {
383            filename: filename.to_owned(),
384            reason: format!(
385                "expected 8 bytes for length prefix, got {}",
386                len_bytes.len()
387            ),
388        });
389    }
390
391    // INDEX: first 8 bytes guaranteed by length check above
392    #[allow(clippy::indexing_slicing)]
393    let header_size = u64::from_le_bytes([
394        len_bytes[0],
395        len_bytes[1],
396        len_bytes[2],
397        len_bytes[3],
398        len_bytes[4],
399        len_bytes[5],
400        len_bytes[6],
401        len_bytes[7],
402    ]);
403
404    // Request 2: JSON header.
405    let range_end = 8u64.saturating_add(header_size).saturating_sub(1);
406    let range_header = format!("bytes=8-{range_end}");
407    let resp2 = client
408        .get(url)
409        // BORROW: explicit .as_str() instead of Deref coercion
410        .header(reqwest::header::RANGE, range_header.as_str())
411        .send()
412        .await
413        .map_err(|e| {
414            FetchError::Http(format!("failed to fetch header JSON for {filename}: {e}"))
415        })?;
416
417    if !resp2.status().is_success() && resp2.status() != reqwest::StatusCode::PARTIAL_CONTENT {
418        return Err(FetchError::Http(format!(
419            "Range request for {filename} header JSON returned status {}",
420            resp2.status()
421        )));
422    }
423
424    let json_bytes = resp2
425        .bytes()
426        .await
427        .map_err(|e| FetchError::Http(format!("failed to read header JSON for {filename}: {e}")))?;
428
429    Ok((json_bytes.to_vec(), file_size))
430}
431
432// -----------------------------------------------------------------------
433// Public API: single-file inspection
434// -----------------------------------------------------------------------
435
436/// Inspects a single `.safetensors` file's header (cache-first).
437///
438/// Checks the local HF cache first. If the file is cached, reads the header
439/// from disk with zero network requests. Otherwise, falls back to two HTTP
440/// Range requests (8-byte length prefix + JSON header). Does not download
441/// tensor data in either case.
442///
443/// # Errors
444///
445/// Returns [`FetchError::Http`] if the Range requests fail.
446/// Returns [`FetchError::SafetensorsHeader`] if the header is malformed.
447pub async fn inspect_safetensors(
448    repo_id: &str,
449    filename: &str,
450    token: Option<&str>,
451    revision: Option<&str>,
452) -> Result<(SafetensorsHeaderInfo, InspectSource), FetchError> {
453    let rev = revision.unwrap_or("main");
454
455    // Try local cache first.
456    if let Some(cached_path) = resolve_cached_path(repo_id, rev, filename) {
457        let info = inspect_safetensors_local(&cached_path)?;
458        return Ok((info, InspectSource::Cached));
459    }
460
461    // Fall back to HTTP Range requests.
462    let client = chunked::build_client(token)?;
463    let url = chunked::build_download_url(repo_id, rev, filename);
464
465    // BORROW: explicit .as_str() instead of Deref coercion
466    let (json_bytes, file_size) = fetch_header_bytes(&client, url.as_str(), filename).await?;
467
468    // CAST: usize → u64, JSON buffer length is always small
469    #[allow(clippy::as_conversions)]
470    let header_size = json_bytes.len() as u64;
471
472    let (tensors, metadata) = parse_header_json(&json_bytes, filename)?;
473
474    Ok((
475        SafetensorsHeaderInfo {
476            tensors,
477            metadata,
478            header_size,
479            file_size,
480        },
481        InspectSource::Remote,
482    ))
483}
484
485/// Inspects a single `.safetensors` file from cache only.
486///
487/// Resolves the file in the local HF cache using the given `repo_id`,
488/// `revision`, and `filename`. Returns an error if the file is not cached.
489///
490/// # Blocking I/O
491///
492/// Performs synchronous filesystem I/O; wrap in [`tokio::task::spawn_blocking`]
493/// from async contexts. See [`inspect_safetensors_local`] for rationale.
494///
495/// # Errors
496///
497/// Returns [`FetchError::SafetensorsHeader`] if the file is not in the cache.
498/// Returns [`FetchError::Io`] if the cached file cannot be read.
499/// Returns [`FetchError::SafetensorsHeader`] if the header is malformed.
500pub fn inspect_safetensors_cached(
501    repo_id: &str,
502    filename: &str,
503    revision: Option<&str>,
504) -> Result<SafetensorsHeaderInfo, FetchError> {
505    let rev = revision.unwrap_or("main");
506
507    let cached_path = resolve_cached_path(repo_id, rev, filename).ok_or_else(|| {
508        FetchError::SafetensorsHeader {
509            filename: filename.to_owned(),
510            reason: format!("file not found in local cache for {repo_id} ({rev})"),
511        }
512    })?;
513
514    inspect_safetensors_local(&cached_path)
515}
516
517// -----------------------------------------------------------------------
518// Public API: multi-file inspection
519// -----------------------------------------------------------------------
520
521/// Inspects all `.safetensors` files in a repository (cache-first per file).
522///
523/// Fetches the file listing via `list_repo_files_with_metadata()`, then
524/// inspects each `.safetensors` file's header via [`inspect_safetensors()`].
525/// For each file, checks the local cache first and only makes HTTP Range
526/// requests on cache miss. Returns full per-shard headers in filename order.
527///
528/// For a lightweight summary of sharded models (tensor counts per shard
529/// without fetching individual headers), use [`fetch_shard_index()`] instead.
530///
531/// # Errors
532///
533/// Returns [`FetchError::Http`] if the metadata or Range requests fail.
534pub async fn inspect_repo_safetensors(
535    repo_id: &str,
536    token: Option<&str>,
537    revision: Option<&str>,
538) -> Result<Vec<(String, SafetensorsHeaderInfo, InspectSource)>, FetchError> {
539    let client = crate::chunked::build_client(token)?;
540    let files =
541        crate::repo::list_repo_files_with_metadata(repo_id, token, revision, &client).await?;
542
543    let safetensors_files: Vec<String> = files
544        .into_iter()
545        .filter(|f| f.filename.ends_with(".safetensors"))
546        .map(|f| f.filename)
547        .collect();
548
549    if safetensors_files.is_empty() {
550        return Ok(Vec::new());
551    }
552
553    let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(4));
554    let mut join_set = JoinSet::new();
555
556    for filename in safetensors_files {
557        // BORROW: explicit .clone()/.to_owned() to move into async task
558        let sem = semaphore.clone();
559        let repo = repo_id.to_owned();
560        let tok = token.map(str::to_owned);
561        let rev = revision.map(str::to_owned);
562
563        join_set.spawn(async move {
564            let _permit = sem
565                .acquire()
566                .await
567                .map_err(|e| FetchError::Http(format!("semaphore error: {e}")))?;
568            // BORROW: explicit .as_deref() for Option<String> → Option<&str>
569            let (info, source) =
570                inspect_safetensors(&repo, &filename, tok.as_deref(), rev.as_deref()).await?;
571            Ok::<_, FetchError>((filename, info, source))
572        });
573    }
574
575    let mut results = Vec::new();
576    while let Some(join_result) = join_set.join_next().await {
577        match join_result {
578            Ok(Ok(item)) => results.push(item),
579            Ok(Err(e)) => {
580                join_set.abort_all();
581                return Err(e);
582            }
583            Err(e) => {
584                join_set.abort_all();
585                return Err(FetchError::Http(format!("task join error: {e}")));
586            }
587        }
588    }
589
590    results.sort_by(|a, b| a.0.cmp(&b.0));
591
592    Ok(results)
593}
594
/// Enumeration of a repository's safetensors files as `(filename, size_bytes)`
/// pairs, together with the commit SHA of the resolved revision (if known).
///
/// Local and remote listings share this shape. [`list_cached_safetensors`]
/// builds it from a cached snapshot. `repo::list_repo_files_with_commit`,
/// filtered to `*.safetensors`, builds it from the `HuggingFace` API. Callers
/// that want one view of "which safetensors can I inspect?", regardless of
/// source, use this alias.
pub type SafetensorsListing = (Vec<(String, u64)>, Option<String>);
603
604/// Lists `.safetensors` files in the cached snapshot for `repo_id`@`revision`.
605///
606/// Returns `(entries, commit_sha)` where `entries` is a sorted list of
607/// `(filename, size_bytes)` tuples, and `commit_sha` is the snapshot's commit
608/// hash (same value stored in `refs/<revision>`). Returns empty lists when the
609/// repo or revision is not cached. Unlike [`inspect_repo_safetensors_cached`],
610/// this does **not** parse any headers — it is a cheap name-and-size enumeration
611/// intended for discovery UI (e.g. `inspect --list --cached`).
612///
613/// # Blocking I/O
614///
615/// Performs a synchronous recursive directory walk with a `stat` call per
616/// `.safetensors` entry. On local SSDs the cost is sub-millisecond; on
617/// networked caches (NFS/CIFS) a large sharded repo can take seconds. Wrap
618/// in [`tokio::task::spawn_blocking`] from async contexts.
619///
620/// # Errors
621///
622/// Returns [`FetchError::Io`] if the snapshot directory cannot be read.
623pub fn list_cached_safetensors(
624    repo_id: &str,
625    revision: Option<&str>,
626) -> Result<SafetensorsListing, FetchError> {
627    let rev = revision.unwrap_or("main");
628    let cache_dir = cache::hf_cache_dir()?;
629    let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
630
631    let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
632        return Ok((Vec::new(), None));
633    };
634
635    let snapshot_dir = cache_layout::snapshot_dir(&repo_dir, &commit_hash);
636    if !snapshot_dir.exists() {
637        return Ok((Vec::new(), Some(commit_hash)));
638    }
639
640    let mut results = Vec::new();
641    collect_safetensors_names_sizes(&snapshot_dir, "", &mut results)?;
642    results.sort_by(|a, b| a.0.cmp(&b.0));
643    Ok((results, Some(commit_hash)))
644}
645
646/// Recursively collects `(filename, size)` pairs for `.safetensors` files.
647fn collect_safetensors_names_sizes(
648    dir: &Path,
649    prefix: &str,
650    results: &mut Vec<(String, u64)>,
651) -> Result<(), FetchError> {
652    let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
653        path: dir.to_path_buf(),
654        source: e,
655    })?;
656
657    for entry in entries {
658        let Ok(entry) = entry else { continue };
659        let path = entry.path();
660        // BORROW: explicit .to_string_lossy() for OsString → str conversion
661        let name = entry.file_name().to_string_lossy().to_string();
662
663        if path.is_dir() {
664            let child_prefix = if prefix.is_empty() {
665                name
666            } else {
667                format!("{prefix}/{name}")
668            };
669            collect_safetensors_names_sizes(&path, &child_prefix, results)?;
670        } else if name.ends_with(".safetensors") {
671            let filename = if prefix.is_empty() {
672                name
673            } else {
674                format!("{prefix}/{name}")
675            };
676            let size = entry.metadata().map_or(0, |m| m.len());
677            results.push((filename, size));
678        }
679    }
680
681    Ok(())
682}
683
684/// Inspects all `.safetensors` files in a cached repository (no network).
685///
686/// Walks the snapshot directory and inspects each `.safetensors` file's
687/// header from local disk. Returns results in filename order.
688///
689/// # Blocking I/O
690///
691/// Walks the snapshot directory and reads each header synchronously. In async
692/// contexts, wrap in [`tokio::task::spawn_blocking`] to avoid stalling the
693/// runtime — multi-shard repos on network-mounted caches can take seconds.
694///
695/// # Errors
696///
697/// Returns [`FetchError::Io`] if the cache directory cannot be read.
698/// Returns [`FetchError::SafetensorsHeader`] if any header is malformed.
699pub fn inspect_repo_safetensors_cached(
700    repo_id: &str,
701    revision: Option<&str>,
702) -> Result<Vec<(String, SafetensorsHeaderInfo)>, FetchError> {
703    let rev = revision.unwrap_or("main");
704    let cache_dir = cache::hf_cache_dir()?;
705    let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
706
707    let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
708        return Ok(Vec::new());
709    };
710
711    let snapshot_dir = cache_layout::snapshot_dir(&repo_dir, &commit_hash);
712    if !snapshot_dir.exists() {
713        return Ok(Vec::new());
714    }
715
716    let mut results = Vec::new();
717    collect_safetensors_recursive(&snapshot_dir, "", &mut results)?;
718    results.sort_by(|a, b| a.0.cmp(&b.0));
719
720    Ok(results)
721}
722
723/// Recursively finds and inspects `.safetensors` files in a snapshot directory.
724fn collect_safetensors_recursive(
725    dir: &Path,
726    prefix: &str,
727    results: &mut Vec<(String, SafetensorsHeaderInfo)>,
728) -> Result<(), FetchError> {
729    let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
730        path: dir.to_path_buf(),
731        source: e,
732    })?;
733
734    for entry in entries {
735        let Ok(entry) = entry else { continue };
736        let path = entry.path();
737        // BORROW: explicit .to_string_lossy() for OsString → str conversion
738        let name = entry.file_name().to_string_lossy().to_string();
739
740        if path.is_dir() {
741            let child_prefix = if prefix.is_empty() {
742                name
743            } else {
744                format!("{prefix}/{name}")
745            };
746            collect_safetensors_recursive(&path, &child_prefix, results)?;
747        } else if name.ends_with(".safetensors") {
748            let filename = if prefix.is_empty() {
749                name
750            } else {
751                format!("{prefix}/{name}")
752            };
753            let info = inspect_safetensors_local(&path)?;
754            results.push((filename, info));
755        }
756    }
757
758    Ok(())
759}
760
761// -----------------------------------------------------------------------
762// Shard index
763// -----------------------------------------------------------------------
764
765/// Raw JSON structure of `model.safetensors.index.json`.
766#[derive(serde::Deserialize)]
767struct RawShardIndex {
768    weight_map: HashMap<String, String>,
769    #[serde(default)]
770    metadata: Option<HashMap<String, serde_json::Value>>,
771}
772
773/// Fetches and parses the shard index for a sharded `.safetensors` model (cache-first).
774///
775/// Returns `Ok(None)` if the repo has no `model.safetensors.index.json` (i.e.,
776/// the model is not sharded or uses a single `.safetensors` file).
777///
778/// # Errors
779///
780/// Returns [`FetchError::Http`] if the index fetch fails.
781/// Returns [`FetchError::SafetensorsHeader`] if the index JSON is malformed.
782pub async fn fetch_shard_index(
783    repo_id: &str,
784    token: Option<&str>,
785    revision: Option<&str>,
786) -> Result<Option<ShardedIndex>, FetchError> {
787    let rev = revision.unwrap_or("main");
788    let index_filename = "model.safetensors.index.json";
789
790    // Try local cache first.
791    if let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) {
792        let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
793            path: cached_path,
794            source: e,
795        })?;
796        let index = parse_shard_index_json(&content, repo_id)?;
797        return Ok(Some(index));
798    }
799
800    // Fall back to HTTP.
801    let client = chunked::build_client(token)?;
802    let url = chunked::build_download_url(repo_id, rev, index_filename);
803
804    // BORROW: explicit .as_str() instead of Deref coercion
805    let response =
806        client.get(url.as_str()).send().await.map_err(|e| {
807            FetchError::Http(format!("failed to fetch shard index for {repo_id}: {e}"))
808        })?;
809
810    if response.status() == reqwest::StatusCode::NOT_FOUND {
811        return Ok(None);
812    }
813
814    if !response.status().is_success() {
815        return Err(FetchError::Http(format!(
816            "shard index request for {repo_id} returned status {}",
817            response.status()
818        )));
819    }
820
821    let content = response
822        .text()
823        .await
824        .map_err(|e| FetchError::Http(format!("failed to read shard index for {repo_id}: {e}")))?;
825
826    let index = parse_shard_index_json(&content, repo_id)?;
827    Ok(Some(index))
828}
829
830/// Fetches the shard index from cache only (no network).
831///
832/// Returns `Ok(None)` if the index file is not cached.
833///
834/// # Errors
835///
836/// Returns [`FetchError::Io`] if the cached file cannot be read.
837/// Returns [`FetchError::SafetensorsHeader`] if the index JSON is malformed.
838pub fn fetch_shard_index_cached(
839    repo_id: &str,
840    revision: Option<&str>,
841) -> Result<Option<ShardedIndex>, FetchError> {
842    let rev = revision.unwrap_or("main");
843    let index_filename = "model.safetensors.index.json";
844
845    let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) else {
846        return Ok(None);
847    };
848
849    let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
850        path: cached_path,
851        source: e,
852    })?;
853
854    let index = parse_shard_index_json(&content, repo_id)?;
855    Ok(Some(index))
856}
857
858/// Parses shard index JSON into a `ShardedIndex`.
859fn parse_shard_index_json(content: &str, repo_id: &str) -> Result<ShardedIndex, FetchError> {
860    let raw: RawShardIndex =
861        serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
862            filename: "model.safetensors.index.json".to_owned(),
863            reason: format!("failed to parse shard index for {repo_id}: {e}"),
864        })?;
865
866    // Collect unique shard filenames in sorted order.
867    let mut shard_set: Vec<String> = raw.weight_map.values().cloned().collect();
868    shard_set.sort();
869    shard_set.dedup();
870
871    Ok(ShardedIndex {
872        weight_map: raw.weight_map,
873        shards: shard_set,
874        metadata: raw.metadata,
875    })
876}
877
878// -----------------------------------------------------------------------
879// Param formatting helper
880// -----------------------------------------------------------------------
881
/// Formats a parameter count with a compact suffix (e.g., `927.0M`, `1.02B`).
///
/// Counts below 1 000 are printed verbatim. `K` and `M` use one decimal
/// place and `B` uses two. When rounding would display `1000.0` in a unit
/// (e.g. `999_999` → `1000.0K`), the next larger suffix is used instead
/// (`1.0M`). Values of a trillion or more stay in `B`.
#[must_use]
pub fn format_params(count: u64) -> String {
    if count < 1_000 {
        return count.to_string();
    }

    // CAST: u64 → f64, precision loss acceptable; value is a display-only scalar
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let val = count as f64;

    // Smallest unit first; each unit covers values below `divisor * 1000`.
    for (divisor, suffix) in [(1_000.0, "K"), (1_000_000.0, "M")] {
        if val < divisor * 1_000.0 {
            let scaled = format!("{:.1}", val / divisor);
            // Rounding can carry the display up to "1000.0"; in that case fall
            // through so the next, larger suffix is used instead.
            if matches!(scaled.parse::<f64>(), Ok(shown) if shown < 1_000.0) {
                return format!("{scaled}{suffix}");
            }
        }
    }

    format!("{:.2}B", val / 1_000_000_000.0)
}
899
900// -----------------------------------------------------------------------
901// Adapter config
902// -----------------------------------------------------------------------
903
904/// Raw JSON structure of `adapter_config.json`.
905#[derive(serde::Deserialize)]
906struct RawAdapterConfig {
907    #[serde(default)]
908    peft_type: Option<String>,
909    #[serde(default)]
910    base_model_name_or_path: Option<String>,
911    #[serde(default)]
912    r: Option<u32>,
913    #[serde(default)]
914    lora_alpha: Option<f64>,
915    #[serde(default)]
916    target_modules: Option<AdapterTargetModules>,
917    #[serde(default)]
918    task_type: Option<String>,
919}
920
/// `target_modules` in adapter configs can be a list of strings or a single string.
///
/// `#[serde(untagged)]` tries the variants in declaration order, so the JSON
/// shape (array vs. string) alone selects the variant.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum AdapterTargetModules {
    /// A list of module name strings (JSON array).
    List(Vec<String>),
    /// A single module name string (JSON string).
    Single(String),
}
930
931/// Fetches and parses `adapter_config.json` for a `PEFT` adapter repository (cache-first).
932///
933/// Returns `Ok(None)` if the file does not exist (HTTP 404), meaning the
934/// repository is not a `PEFT` adapter.
935///
936/// # Errors
937///
938/// Returns [`FetchError::Http`] if the request fails (other than 404).
939/// Returns [`FetchError::SafetensorsHeader`] if the JSON is malformed.
940pub async fn fetch_adapter_config(
941    repo_id: &str,
942    token: Option<&str>,
943    revision: Option<&str>,
944) -> Result<Option<AdapterConfig>, FetchError> {
945    let rev = revision.unwrap_or("main");
946    let config_filename = "adapter_config.json";
947
948    // Try local cache first.
949    if let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) {
950        let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
951            path: cached_path,
952            source: e,
953        })?;
954        let config = parse_adapter_config_json(&content, repo_id)?;
955        return Ok(Some(config));
956    }
957
958    // Fall back to HTTP.
959    let client = chunked::build_client(token)?;
960    let url = chunked::build_download_url(repo_id, rev, config_filename);
961
962    // BORROW: explicit .as_str() instead of Deref coercion
963    let response = client.get(url.as_str()).send().await.map_err(|e| {
964        FetchError::Http(format!("failed to fetch adapter config for {repo_id}: {e}"))
965    })?;
966
967    if response.status() == reqwest::StatusCode::NOT_FOUND {
968        return Ok(None);
969    }
970
971    if !response.status().is_success() {
972        return Err(FetchError::Http(format!(
973            "adapter config request for {repo_id} returned status {}",
974            response.status()
975        )));
976    }
977
978    let content = response.text().await.map_err(|e| {
979        FetchError::Http(format!("failed to read adapter config for {repo_id}: {e}"))
980    })?;
981
982    let config = parse_adapter_config_json(&content, repo_id)?;
983    Ok(Some(config))
984}
985
986/// Fetches the adapter config from cache only (no network).
987///
988/// Returns `Ok(None)` if the file is not cached.
989///
990/// # Errors
991///
992/// Returns [`FetchError::Io`] if the cached file cannot be read.
993/// Returns [`FetchError::SafetensorsHeader`] if the JSON is malformed.
994pub fn fetch_adapter_config_cached(
995    repo_id: &str,
996    revision: Option<&str>,
997) -> Result<Option<AdapterConfig>, FetchError> {
998    let rev = revision.unwrap_or("main");
999    let config_filename = "adapter_config.json";
1000
1001    let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) else {
1002        return Ok(None);
1003    };
1004
1005    let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
1006        path: cached_path,
1007        source: e,
1008    })?;
1009
1010    let config = parse_adapter_config_json(&content, repo_id)?;
1011    Ok(Some(config))
1012}
1013
1014/// Parses adapter config JSON into an [`AdapterConfig`].
1015fn parse_adapter_config_json(content: &str, repo_id: &str) -> Result<AdapterConfig, FetchError> {
1016    let raw: RawAdapterConfig =
1017        serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
1018            filename: "adapter_config.json".to_owned(),
1019            reason: format!("failed to parse adapter config for {repo_id}: {e}"),
1020        })?;
1021
1022    let target_modules = match raw.target_modules {
1023        Some(AdapterTargetModules::List(v)) => v,
1024        Some(AdapterTargetModules::Single(s)) => vec![s],
1025        None => Vec::new(),
1026    };
1027
1028    Ok(AdapterConfig {
1029        peft_type: raw.peft_type,
1030        base_model_name_or_path: raw.base_model_name_or_path,
1031        r: raw.r,
1032        lora_alpha: raw.lora_alpha,
1033        target_modules,
1034        task_type: raw.task_type,
1035    })
1036}