Skip to main content

hf_fetch_model/
inspect.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Safetensors header inspection (local and remote).
4//!
5//! Reads tensor metadata (names, shapes, dtypes, byte offsets) from
6//! `.safetensors` files without downloading full weight data. Supports
7//! cache-first resolution with HTTP Range request fallback.
8//!
9//! The primary types are [`TensorInfo`] (per-tensor metadata),
10//! [`SafetensorsHeaderInfo`] (parsed header), and [`ShardedIndex`]
11//! (shard-to-tensor mapping for sharded models).
12
13use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15
16use serde::Serialize;
17use tokio::task::JoinSet;
18
19use crate::cache;
20use crate::cache_layout;
21use crate::chunked;
22use crate::error::FetchError;
23
24// -----------------------------------------------------------------------
25// Types
26// -----------------------------------------------------------------------
27
/// Metadata for a single tensor from a `.safetensors` header.
///
/// This is hf-fetch-model's own type — lightweight, no quantization logic.
/// Consumers (e.g., anamnesis) map this into their own richer types.
///
/// All fields are plain owned data copied out of the parsed header; no tensor
/// bytes are held. Derived sizes are available through
/// [`TensorInfo::num_elements`], [`TensorInfo::byte_len`], and
/// [`TensorInfo::dtype_bytes`].
#[derive(Debug, Clone, Serialize)]
pub struct TensorInfo {
    /// Tensor name (e.g., `"model.layers.0.self_attn.q_proj.weight"`).
    pub name: String,
    /// Element dtype string as it appears in the header (e.g., `"F8_E4M3"`, `"BF16"`).
    ///
    /// Stored verbatim and not validated; unrecognized names simply make
    /// [`TensorInfo::dtype_bytes`] return `None`.
    pub dtype: String,
    /// Tensor shape (e.g., `[7168, 7168]`). An empty shape denotes a scalar.
    pub shape: Vec<usize>,
    /// Byte offset range `[start, end)` within the data section of the file.
    pub data_offsets: (u64, u64),
}
43
44impl TensorInfo {
45    /// Total number of elements (product of shape dimensions).
46    ///
47    /// Returns `1` for a scalar (empty shape).
48    #[must_use]
49    pub fn num_elements(&self) -> u64 {
50        self.shape.iter().fold(1u64, |acc, &d| {
51            // CAST: usize → u64, dimension values fit in u64
52            #[allow(clippy::as_conversions)]
53            let dim = d as u64;
54            acc.saturating_mul(dim)
55        })
56    }
57
58    /// Byte length of the tensor data (`end - start`).
59    #[must_use]
60    pub const fn byte_len(&self) -> u64 {
61        self.data_offsets.1.saturating_sub(self.data_offsets.0)
62    }
63
64    /// Bytes per element for the tensor's dtype, if recognized.
65    ///
66    /// Returns `None` for unknown dtype strings. Recognized dtypes:
67    ///
68    /// | Dtype string | Bytes | Notes |
69    /// |-------------|-------|-------|
70    /// | `"BOOL"` | 1 | |
71    /// | `"U8"`, `"I8"` | 1 | |
72    /// | `"F8_E4M3"`, `"F8_E5M2"` | 1 | FP8 variants |
73    /// | `"U16"`, `"I16"`, `"F16"`, `"BF16"` | 2 | |
74    /// | `"U32"`, `"I32"`, `"F32"` | 4 | |
75    /// | `"U64"`, `"I64"`, `"F64"` | 8 | |
76    #[must_use]
77    pub fn dtype_bytes(&self) -> Option<usize> {
78        // BORROW: explicit .as_str() instead of Deref coercion
79        match self.dtype.as_str() {
80            "BOOL" | "U8" | "I8" | "F8_E4M3" | "F8_E5M2" => Some(1),
81            "U16" | "I16" | "F16" | "BF16" => Some(2),
82            "U32" | "I32" | "F32" => Some(4),
83            "U64" | "I64" | "F64" => Some(8),
84            _ => None,
85        }
86    }
87}
88
/// Parsed `.safetensors` header metadata.
///
/// Also used as the return shape of [`inspect_gguf_cached`], in which case
/// the fields carry the GGUF equivalents noted below.
#[derive(Debug, Clone, Serialize)]
pub struct SafetensorsHeaderInfo {
    /// All tensors described by the header.
    ///
    /// For safetensors files the list is sorted by the start of each tensor's
    /// data offset range; the key order of the JSON object is not preserved.
    /// For GGUF files the order is whatever the GGUF parser reports.
    pub tensors: Vec<TensorInfo>,
    /// Raw `__metadata__` entries, if present.
    ///
    /// For quantized models, this typically contains entries like
    /// `quant_method`, `bits`, `group_size` that consumers like anamnesis
    /// use to distinguish GPTQ from AWQ without downloading weights.
    /// Non-string JSON values are stored as their JSON text.
    pub metadata: Option<HashMap<String, String>>,
    /// Size of the JSON header in bytes (excluding the 8-byte length prefix).
    ///
    /// Set to `0` for GGUF files, which have no discrete length-prefixed header.
    pub header_size: u64,
    /// Total file size in bytes (header + data), if known.
    ///
    /// **Source:** for local files, from `std::fs::metadata().len()`. For HTTP
    /// Range requests, extracted from the `Content-Range` response header of
    /// the first request (`bytes 0-7/TOTAL` → `TOTAL`). This is free — no
    /// extra request needed. `None` when that header is missing or unparsable.
    pub file_size: Option<u64>,
}
110
111impl SafetensorsHeaderInfo {
112    /// Total parameter count across all tensors.
113    #[must_use]
114    pub fn total_params(&self) -> u64 {
115        self.tensors
116            .iter()
117            .map(TensorInfo::num_elements)
118            .fold(0u64, u64::saturating_add)
119    }
120
121    /// Returns tensors matching a dtype string (e.g., `"F8_E4M3"`).
122    #[must_use]
123    pub fn tensors_with_dtype(&self, dtype: &str) -> Vec<&TensorInfo> {
124        self.tensors
125            .iter()
126            // BORROW: explicit .as_str() instead of Deref coercion
127            .filter(|t| t.dtype.as_str() == dtype)
128            .collect()
129    }
130}
131
/// The source from which a header was read.
///
/// Returned alongside the parsed header by [`inspect_safetensors`] so callers
/// can tell whether any network traffic was needed. Marked `#[non_exhaustive]`,
/// so downstream `match` expressions need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum InspectSource {
    /// Read from local cache (no network).
    Cached,
    /// Fetched via HTTP Range requests.
    Remote,
}
141
/// Parsed `model.safetensors.index.json` for a sharded model.
///
/// Describes which shard file holds each tensor without requiring any
/// shard header to be read.
#[derive(Debug, Clone, Serialize)]
pub struct ShardedIndex {
    /// Mapping from tensor name to shard filename.
    pub weight_map: HashMap<String, String>,
    /// Ordered list of unique shard filenames.
    pub shards: Vec<String>,
    /// Raw metadata from the index, if present.
    ///
    /// Values are kept as untyped JSON rather than stringified.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
152
/// `PEFT` adapter configuration parsed from `adapter_config.json`.
///
/// Contains the key fields that identify an adapter: the `PEFT` type,
/// base model, `LoRA` rank and scaling parameters, and target modules.
/// The scalar fields are `Option`s because adapter configs vary across
/// `PEFT` methods; `target_modules` is a plain list instead.
#[derive(Debug, Clone, Serialize)]
pub struct AdapterConfig {
    /// `PEFT` method type (e.g., `"LORA"`, `"ADALORA"`, `"IA3"`).
    pub peft_type: Option<String>,
    /// The base model this adapter was trained on.
    pub base_model_name_or_path: Option<String>,
    /// `LoRA` rank (the `r` parameter). Only meaningful for `LoRA`-family methods.
    pub r: Option<u32>,
    /// `LoRA` alpha scaling factor. Only meaningful for `LoRA`-family methods.
    pub lora_alpha: Option<f64>,
    /// List of model modules targeted by the adapter.
    ///
    /// NOTE(review): presumably empty when the config names no target
    /// modules — confirm against the code that builds this struct.
    pub target_modules: Vec<String>,
    /// Task type the adapter was trained for (e.g., `"CAUSAL_LM"`).
    pub task_type: Option<String>,
}
173
174// -----------------------------------------------------------------------
175// JSON parsing
176// -----------------------------------------------------------------------
177
/// Raw tensor entry as it appears in the safetensors JSON header.
///
/// Deserialization target for the value stored under each tensor-name key;
/// the name itself is the JSON key and is attached when the entry is turned
/// into a [`TensorInfo`].
#[derive(serde::Deserialize)]
struct RawTensorEntry {
    /// Element dtype string, kept verbatim.
    dtype: String,
    /// Tensor dimensions.
    shape: Vec<usize>,
    /// `[start, end)` byte range, deserialized from a two-element JSON array.
    data_offsets: (u64, u64),
}
185
/// Parsed tensor list and optional metadata from a safetensors header.
///
/// The first element holds one [`TensorInfo`] per tensor key; the second is
/// the stringified `__metadata__` object, or `None` when the header has none.
type ParsedHeader = (Vec<TensorInfo>, Option<HashMap<String, String>>);
188
189/// Parses the safetensors JSON header bytes into tensor metadata.
190///
191/// Extracts the `__metadata__` key separately (if present).
192fn parse_header_json(json_bytes: &[u8], filename: &str) -> Result<ParsedHeader, FetchError> {
193    let raw: HashMap<String, serde_json::Value> =
194        serde_json::from_slice(json_bytes).map_err(|e| FetchError::SafetensorsHeader {
195            filename: filename.to_owned(),
196            reason: format!("failed to parse header JSON: {e}"),
197        })?;
198
199    let mut metadata: Option<HashMap<String, String>> = None;
200    let mut tensors = Vec::new();
201
202    for (key, value) in raw {
203        if key == "__metadata__" {
204            if let serde_json::Value::Object(obj) = value {
205                let mut meta_map = HashMap::new();
206                for (mk, mv) in obj {
207                    // BORROW: explicit .to_owned()/.to_string() for Value → String conversion
208                    let v_str = if let Some(s) = mv.as_str() {
209                        s.to_owned()
210                    } else {
211                        mv.to_string()
212                    };
213                    meta_map.insert(mk, v_str);
214                }
215                metadata = Some(meta_map);
216            }
217            continue;
218        }
219
220        let entry: RawTensorEntry =
221            serde_json::from_value(value).map_err(|e| FetchError::SafetensorsHeader {
222                filename: filename.to_owned(),
223                reason: format!("failed to parse tensor \"{key}\": {e}"),
224            })?;
225
226        tensors.push(TensorInfo {
227            name: key,
228            dtype: entry.dtype,
229            shape: entry.shape,
230            data_offsets: entry.data_offsets,
231        });
232    }
233
234    // Sort by data offset start to preserve file order.
235    tensors.sort_by_key(|t| t.data_offsets.0);
236
237    Ok((tensors, metadata))
238}
239
240// -----------------------------------------------------------------------
241// Cache resolution
242// -----------------------------------------------------------------------
243
244/// Resolves a cached file path for a given repo, revision, and filename.
245///
246/// Returns `None` if the file is not in the local cache.
247fn resolve_cached_path(repo_id: &str, revision: &str, filename: &str) -> Option<PathBuf> {
248    let cache_dir = cache::hf_cache_dir().ok()?;
249    let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
250    let commit_hash = cache::read_ref(&repo_dir, revision)?;
251    let cached_path = cache_layout::pointer_path(&repo_dir, &commit_hash, filename);
252    if cached_path.exists() {
253        Some(cached_path)
254    } else {
255        None
256    }
257}
258
259// -----------------------------------------------------------------------
260// Local file reading
261// -----------------------------------------------------------------------
262
263/// Inspects a single `.safetensors` file's header from a local file path.
264///
265/// Reads the first `8 + header_size` bytes from disk. Does not read tensor data.
266///
267/// # Blocking I/O
268///
269/// This function performs synchronous filesystem I/O. In async contexts, wrap
270/// it in [`tokio::task::spawn_blocking`] so the calling task does not stall
271/// the runtime — particularly important on network-mounted caches (NFS/CIFS)
272/// where `read`/`stat` calls can take tens of milliseconds each.
273///
274/// # Errors
275///
276/// Returns [`FetchError::Io`] if the file cannot be read.
277/// Returns [`FetchError::SafetensorsHeader`] if the header is malformed.
278pub fn inspect_safetensors_local(path: &Path) -> Result<SafetensorsHeaderInfo, FetchError> {
279    use std::io::Read;
280
281    let file_size = std::fs::metadata(path)
282        .map_err(|e| FetchError::Io {
283            path: path.to_path_buf(),
284            source: e,
285        })?
286        .len();
287
288    // BORROW: explicit .to_string_lossy() for Path → str conversion
289    let filename = path.file_name().map_or_else(
290        || path.display().to_string(),
291        |n| n.to_string_lossy().to_string(),
292    );
293
294    let mut file = std::fs::File::open(path).map_err(|e| FetchError::Io {
295        path: path.to_path_buf(),
296        source: e,
297    })?;
298
299    // Read 8-byte header length prefix (little-endian u64).
300    let mut len_buf = [0u8; 8];
301    file.read_exact(&mut len_buf).map_err(|e| FetchError::Io {
302        path: path.to_path_buf(),
303        source: e,
304    })?;
305    let header_size = u64::from_le_bytes(len_buf);
306
307    // Sanity check: header cannot be larger than the file.
308    if header_size.saturating_add(8) > file_size {
309        return Err(FetchError::SafetensorsHeader {
310            filename,
311            reason: format!("header length {header_size} exceeds file size {file_size}"),
312        });
313    }
314
315    // Read the JSON header.
316    // CAST: u64 → usize, header size bounded by file size (checked above)
317    #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
318    let json_len = header_size as usize;
319    let mut json_buf = vec![0u8; json_len];
320    file.read_exact(&mut json_buf).map_err(|e| FetchError::Io {
321        path: path.to_path_buf(),
322        source: e,
323    })?;
324
325    // BORROW: explicit .as_str() instead of Deref coercion
326    let (tensors, metadata) = parse_header_json(&json_buf, filename.as_str())?;
327
328    Ok(SafetensorsHeaderInfo {
329        tensors,
330        metadata,
331        header_size,
332        file_size: Some(file_size),
333    })
334}
335
336// -----------------------------------------------------------------------
337// Remote fetching (HTTP Range requests)
338// -----------------------------------------------------------------------
339
340/// Fetches safetensors header bytes via two HTTP Range requests.
341///
342/// 1. `Range: bytes=0-7` → 8-byte header length (little-endian `u64`)
343/// 2. `Range: bytes=8-{8+length-1}` → JSON header
344///
345/// Returns `(json_bytes, total_file_size)`. The file size is extracted from
346/// the `Content-Range` header of the first request.
347async fn fetch_header_bytes(
348    client: &reqwest::Client,
349    url: &str,
350    filename: &str,
351) -> Result<(Vec<u8>, Option<u64>), FetchError> {
352    // Request 1: 8-byte length prefix.
353    let resp1 = client
354        .get(url)
355        .header(reqwest::header::RANGE, "bytes=0-7")
356        .send()
357        .await
358        .map_err(|e| {
359            FetchError::Http(format!("failed to fetch header length for {filename}: {e}"))
360        })?;
361
362    if !resp1.status().is_success() && resp1.status() != reqwest::StatusCode::PARTIAL_CONTENT {
363        return Err(FetchError::Http(format!(
364            "Range request for {filename} returned status {}",
365            resp1.status()
366        )));
367    }
368
369    // Extract total file size from Content-Range: bytes 0-7/{total}
370    let file_size = resp1
371        .headers()
372        .get(reqwest::header::CONTENT_RANGE)
373        .and_then(|v| v.to_str().ok())
374        .and_then(|s| s.split('/').next_back())
375        .and_then(|s| s.parse::<u64>().ok());
376
377    let len_bytes = resp1.bytes().await.map_err(|e| {
378        FetchError::Http(format!("failed to read header length for {filename}: {e}"))
379    })?;
380
381    if len_bytes.len() < 8 {
382        return Err(FetchError::SafetensorsHeader {
383            filename: filename.to_owned(),
384            reason: format!(
385                "expected 8 bytes for length prefix, got {}",
386                len_bytes.len()
387            ),
388        });
389    }
390
391    // INDEX: first 8 bytes guaranteed by length check above
392    #[allow(clippy::indexing_slicing)]
393    let header_size = u64::from_le_bytes([
394        len_bytes[0],
395        len_bytes[1],
396        len_bytes[2],
397        len_bytes[3],
398        len_bytes[4],
399        len_bytes[5],
400        len_bytes[6],
401        len_bytes[7],
402    ]);
403
404    // Request 2: JSON header.
405    let range_end = 8u64.saturating_add(header_size).saturating_sub(1);
406    let range_header = format!("bytes=8-{range_end}");
407    let resp2 = client
408        .get(url)
409        // BORROW: explicit .as_str() instead of Deref coercion
410        .header(reqwest::header::RANGE, range_header.as_str())
411        .send()
412        .await
413        .map_err(|e| {
414            FetchError::Http(format!("failed to fetch header JSON for {filename}: {e}"))
415        })?;
416
417    if !resp2.status().is_success() && resp2.status() != reqwest::StatusCode::PARTIAL_CONTENT {
418        return Err(FetchError::Http(format!(
419            "Range request for {filename} header JSON returned status {}",
420            resp2.status()
421        )));
422    }
423
424    let json_bytes = resp2
425        .bytes()
426        .await
427        .map_err(|e| FetchError::Http(format!("failed to read header JSON for {filename}: {e}")))?;
428
429    Ok((json_bytes.to_vec(), file_size))
430}
431
432// -----------------------------------------------------------------------
433// Public API: single-file inspection
434// -----------------------------------------------------------------------
435
436/// Inspects a single `.safetensors` file's header (cache-first).
437///
438/// Checks the local HF cache first. If the file is cached, reads the header
439/// from disk with zero network requests. Otherwise, falls back to two HTTP
440/// Range requests (8-byte length prefix + JSON header). Does not download
441/// tensor data in either case.
442///
443/// # Errors
444///
445/// Returns [`FetchError::Http`] if the Range requests fail.
446/// Returns [`FetchError::SafetensorsHeader`] if the header is malformed.
447pub async fn inspect_safetensors(
448    repo_id: &str,
449    filename: &str,
450    token: Option<&str>,
451    revision: Option<&str>,
452) -> Result<(SafetensorsHeaderInfo, InspectSource), FetchError> {
453    let rev = revision.unwrap_or("main");
454
455    // Try local cache first.
456    if let Some(cached_path) = resolve_cached_path(repo_id, rev, filename) {
457        let info = inspect_safetensors_local(&cached_path)?;
458        return Ok((info, InspectSource::Cached));
459    }
460
461    // Fall back to HTTP Range requests.
462    let client = chunked::build_client(token)?;
463    let url = chunked::build_download_url(repo_id, rev, filename);
464
465    // BORROW: explicit .as_str() instead of Deref coercion
466    let (json_bytes, file_size) = fetch_header_bytes(&client, url.as_str(), filename).await?;
467
468    // CAST: usize → u64, JSON buffer length is always small
469    #[allow(clippy::as_conversions)]
470    let header_size = json_bytes.len() as u64;
471
472    let (tensors, metadata) = parse_header_json(&json_bytes, filename)?;
473
474    Ok((
475        SafetensorsHeaderInfo {
476            tensors,
477            metadata,
478            header_size,
479            file_size,
480        },
481        InspectSource::Remote,
482    ))
483}
484
485/// Inspects a single `.safetensors` file from cache only.
486///
487/// Resolves the file in the local HF cache using the given `repo_id`,
488/// `revision`, and `filename`. Returns an error if the file is not cached.
489///
490/// # Blocking I/O
491///
492/// Performs synchronous filesystem I/O; wrap in [`tokio::task::spawn_blocking`]
493/// from async contexts. See [`inspect_safetensors_local`] for rationale.
494///
495/// # Errors
496///
497/// Returns [`FetchError::SafetensorsHeader`] if the file is not in the cache.
498/// Returns [`FetchError::Io`] if the cached file cannot be read.
499/// Returns [`FetchError::SafetensorsHeader`] if the header is malformed.
500pub fn inspect_safetensors_cached(
501    repo_id: &str,
502    filename: &str,
503    revision: Option<&str>,
504) -> Result<SafetensorsHeaderInfo, FetchError> {
505    let rev = revision.unwrap_or("main");
506
507    let cached_path = resolve_cached_path(repo_id, rev, filename).ok_or_else(|| {
508        FetchError::SafetensorsHeader {
509            filename: filename.to_owned(),
510            reason: format!("file not found in local cache for {repo_id} ({rev})"),
511        }
512    })?;
513
514    inspect_safetensors_local(&cached_path)
515}
516
517/// Inspects a `.gguf` file's metadata from the local `HuggingFace` cache.
518///
519/// Delegates to [`anamnesis::parse_gguf`] for the on-disk parse, then maps the
520/// result into the format-agnostic [`SafetensorsHeaderInfo`] shape used by
521/// hf-fm's existing render path. Tensor names, GGUF-native shape order, and
522/// dtype name strings carry over directly; per-tensor `data_offsets` are
523/// `(data_offset, data_offset + byte_len)` (with `byte_len = 0` for tensors
524/// whose dtype has no known byte size in anamnesis yet).
525///
526/// **Naming note:** the returned type is still called [`SafetensorsHeaderInfo`]
527/// in v0.10.x because renaming a public type is a breaking change; the
528/// uniform-dispatch rename to a format-agnostic name is scheduled for v0.10.3
529/// when the dispatcher extends across `.npz` / `.pth` (see the cache-management
530/// roadmap). For now, treat the type name as "header / file-level inspect
531/// info" regardless of format.
532///
533/// **Metadata surfacing:** the GGUF metadata table can contain very large
534/// arrays (e.g. tokenizer.ggml.tokens with 50K+ entries). To keep `Metadata:`
535/// rendering useful, this function surfaces *scalar* metadata values only —
536/// strings, booleans, integers, floats — and skips arrays. The GGUF format
537/// version is surfaced under the synthetic key `gguf.version`, the effective
538/// alignment under `gguf.alignment`. The original `general.architecture`,
539/// `general.name`, and friends pass through unchanged.
540///
541/// **Blocking I/O:** anamnesis's GGUF parser mmaps the file; this function is
542/// synchronous and should be wrapped in [`tokio::task::spawn_blocking`] from
543/// async contexts.
544///
545/// # Errors
546///
547/// Returns [`FetchError::SafetensorsHeader`] if the file is not in the cache.
548/// Returns [`FetchError::SafetensorsHeader`] if anamnesis rejects the GGUF
549/// file (malformed header, truncated tensor table, etc.).
550pub fn inspect_gguf_cached(
551    repo_id: &str,
552    filename: &str,
553    revision: Option<&str>,
554) -> Result<SafetensorsHeaderInfo, FetchError> {
555    let rev = revision.unwrap_or("main");
556
557    let cached_path = resolve_cached_path(repo_id, rev, filename).ok_or_else(|| {
558        FetchError::SafetensorsHeader {
559            // BORROW: explicit .to_owned() for owned String in the error variant
560            filename: filename.to_owned(),
561            reason: format!("file not found in local cache for {repo_id} ({rev})"),
562        }
563    })?;
564
565    let file_size = std::fs::metadata(&cached_path).ok().map(|m| m.len());
566
567    let parsed =
568        anamnesis::parse_gguf(&cached_path).map_err(|e| FetchError::SafetensorsHeader {
569            // BORROW: explicit .to_owned() for owned String in the error variant
570            filename: filename.to_owned(),
571            reason: format!("failed to parse GGUF: {e}"),
572        })?;
573
574    let tensors: Vec<TensorInfo> = parsed
575        .tensor_info()
576        .iter()
577        .map(|info| {
578            let start = info.data_offset;
579            let end = info.byte_len.map_or(start, |b| start.saturating_add(b));
580            TensorInfo {
581                // BORROW: explicit .clone() / .to_string() to materialise owned
582                // String + Vec<usize> from anamnesis's borrowed metadata
583                name: info.name.clone(),
584                dtype: info.dtype.to_string(),
585                shape: info.shape.clone(),
586                data_offsets: (start, end),
587            }
588        })
589        .collect();
590
591    // Stringify scalar metadata only; skip arrays (potentially huge — e.g.
592    // tokenizer.ggml.tokens). Add synthetic keys for the format version and
593    // alignment so they appear in the `Metadata:` block.
594    let mut metadata: HashMap<String, String> = parsed
595        .metadata()
596        .iter()
597        // BORROW: explicit .clone() to materialise an owned String key from
598        // the borrowed HashMap iteration
599        .filter_map(|(k, v)| stringify_gguf_scalar(v).map(|s| (k.clone(), s)))
600        .collect();
601    // BORROW: explicit .to_owned() for owned String keys
602    metadata.insert("gguf.version".to_owned(), parsed.version().to_string());
603    metadata.insert("gguf.alignment".to_owned(), parsed.alignment().to_string());
604
605    Ok(SafetensorsHeaderInfo {
606        tensors,
607        metadata: Some(metadata),
608        // GGUF has no discrete "header size" like safetensors's
609        // u64-length-prefix + JSON. The value is left at 0 here; consumers
610        // that care can derive an approximation from `file_size` minus the
611        // tensor byte sum. The `Metadata:` block's `gguf.version` /
612        // `gguf.alignment` keys surface the equivalent format-level info.
613        header_size: 0,
614        file_size,
615    })
616}
617
618/// Stringifies a scalar `GgufMetadataValue` from anamnesis.
619///
620/// Returns `None` for array variants (potentially huge — vocab tables, merges
621/// lists) and for any future `#[non_exhaustive]` variants we don't yet
622/// recognise. Surfaced through the `Metadata:` block in `inspect` output by
623/// [`inspect_gguf_cached`].
624//
625// `GgufMetadataValue` is `#[non_exhaustive]`. The explicit `V::Array(_)` arm
626// and the `_ =>` catch-all both return `None`, but they document different
627// intents — "array variants are deliberately skipped" vs "future unknown
628// variants fall through". Clippy's `match_same_arms` flags the bodies as
629// identical; the duplication is intentional.
630#[allow(clippy::match_same_arms)]
631fn stringify_gguf_scalar(value: &anamnesis::parse::gguf::GgufMetadataValue) -> Option<String> {
632    use anamnesis::parse::gguf::GgufMetadataValue as V;
633    match value {
634        V::String(s) => Some(s.clone()),
635        V::Bool(b) => Some(b.to_string()),
636        V::U8(n) => Some(n.to_string()),
637        V::I8(n) => Some(n.to_string()),
638        V::U16(n) => Some(n.to_string()),
639        V::I16(n) => Some(n.to_string()),
640        V::U32(n) => Some(n.to_string()),
641        V::I32(n) => Some(n.to_string()),
642        V::U64(n) => Some(n.to_string()),
643        V::I64(n) => Some(n.to_string()),
644        V::F32(n) => Some(format!("{n}")),
645        V::F64(n) => Some(format!("{n}")),
646        V::Array(_) => None,
647        _ => None,
648    }
649}
650
651// -----------------------------------------------------------------------
652// Public API: multi-file inspection
653// -----------------------------------------------------------------------
654
655/// Inspects all `.safetensors` files in a repository (cache-first per file).
656///
657/// Fetches the file listing via `list_repo_files_with_metadata()`, then
658/// inspects each `.safetensors` file's header via [`inspect_safetensors()`].
659/// For each file, checks the local cache first and only makes HTTP Range
660/// requests on cache miss. Returns full per-shard headers in filename order.
661///
662/// For a lightweight summary of sharded models (tensor counts per shard
663/// without fetching individual headers), use [`fetch_shard_index()`] instead.
664///
665/// # Errors
666///
667/// Returns [`FetchError::Http`] if the metadata or Range requests fail.
668pub async fn inspect_repo_safetensors(
669    repo_id: &str,
670    token: Option<&str>,
671    revision: Option<&str>,
672) -> Result<Vec<(String, SafetensorsHeaderInfo, InspectSource)>, FetchError> {
673    let client = crate::chunked::build_client(token)?;
674    let files =
675        crate::repo::list_repo_files_with_metadata(repo_id, token, revision, &client).await?;
676
677    let safetensors_files: Vec<String> = files
678        .into_iter()
679        .filter(|f| f.filename.ends_with(".safetensors"))
680        .map(|f| f.filename)
681        .collect();
682
683    if safetensors_files.is_empty() {
684        return Ok(Vec::new());
685    }
686
687    let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(4));
688    let mut join_set = JoinSet::new();
689
690    for filename in safetensors_files {
691        // BORROW: explicit .clone()/.to_owned() to move into async task
692        let sem = semaphore.clone();
693        let repo = repo_id.to_owned();
694        let tok = token.map(str::to_owned);
695        let rev = revision.map(str::to_owned);
696
697        join_set.spawn(async move {
698            let _permit = sem
699                .acquire()
700                .await
701                .map_err(|e| FetchError::Http(format!("semaphore error: {e}")))?;
702            // BORROW: explicit .as_deref() for Option<String> → Option<&str>
703            let (info, source) =
704                inspect_safetensors(&repo, &filename, tok.as_deref(), rev.as_deref()).await?;
705            Ok::<_, FetchError>((filename, info, source))
706        });
707    }
708
709    let mut results = Vec::new();
710    while let Some(join_result) = join_set.join_next().await {
711        match join_result {
712            Ok(Ok(item)) => results.push(item),
713            Ok(Err(e)) => {
714                join_set.abort_all();
715                return Err(e);
716            }
717            Err(e) => {
718                join_set.abort_all();
719                return Err(FetchError::Http(format!("task join error: {e}")));
720            }
721        }
722    }
723
724    results.sort_by(|a, b| a.0.cmp(&b.0));
725
726    Ok(results)
727}
728
/// A `(filename, size_bytes)` enumeration of safetensors files in a repo,
/// paired with the commit SHA of the resolved revision (when known).
///
/// The same tuple shape serves both local and remote listings: [`list_cached_safetensors`]
/// produces it from a cached snapshot; `repo::list_repo_files_with_commit` filtered to
/// `*.safetensors` produces it from the `HuggingFace` API. Callers that need a uniform
/// view over "what safetensors can I inspect?" regardless of source use this alias.
///
/// Layout: `.0` is the list of `(filename, size_bytes)` entries, `.1` is the
/// commit SHA, or `None` when the revision could not be resolved.
pub type SafetensorsListing = (Vec<(String, u64)>, Option<String>);
737
738/// Lists `.safetensors` files in the cached snapshot for `repo_id`@`revision`.
739///
740/// Returns `(entries, commit_sha)` where `entries` is a sorted list of
741/// `(filename, size_bytes)` tuples, and `commit_sha` is the snapshot's commit
742/// hash (same value stored in `refs/<revision>`). Returns empty lists when the
743/// repo or revision is not cached. Unlike [`inspect_repo_safetensors_cached`],
744/// this does **not** parse any headers — it is a cheap name-and-size enumeration
745/// intended for discovery UI (e.g. `inspect --list --cached`).
746///
747/// # Blocking I/O
748///
749/// Performs a synchronous recursive directory walk with a `stat` call per
750/// `.safetensors` entry. On local SSDs the cost is sub-millisecond; on
751/// networked caches (NFS/CIFS) a large sharded repo can take seconds. Wrap
752/// in [`tokio::task::spawn_blocking`] from async contexts.
753///
754/// # Errors
755///
756/// Returns [`FetchError::Io`] if the snapshot directory cannot be read.
757pub fn list_cached_safetensors(
758    repo_id: &str,
759    revision: Option<&str>,
760) -> Result<SafetensorsListing, FetchError> {
761    let rev = revision.unwrap_or("main");
762    let cache_dir = cache::hf_cache_dir()?;
763    let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
764
765    let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
766        return Ok((Vec::new(), None));
767    };
768
769    let snapshot_dir = cache_layout::snapshot_dir(&repo_dir, &commit_hash);
770    if !snapshot_dir.exists() {
771        return Ok((Vec::new(), Some(commit_hash)));
772    }
773
774    let mut results = Vec::new();
775    collect_safetensors_names_sizes(&snapshot_dir, "", &mut results)?;
776    results.sort_by(|a, b| a.0.cmp(&b.0));
777    Ok((results, Some(commit_hash)))
778}
779
780/// Recursively collects `(filename, size)` pairs for `.safetensors` files.
781fn collect_safetensors_names_sizes(
782    dir: &Path,
783    prefix: &str,
784    results: &mut Vec<(String, u64)>,
785) -> Result<(), FetchError> {
786    let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
787        path: dir.to_path_buf(),
788        source: e,
789    })?;
790
791    for entry in entries {
792        let Ok(entry) = entry else { continue };
793        let path = entry.path();
794        // BORROW: explicit .to_string_lossy() for OsString → str conversion
795        let name = entry.file_name().to_string_lossy().to_string();
796
797        if path.is_dir() {
798            let child_prefix = if prefix.is_empty() {
799                name
800            } else {
801                format!("{prefix}/{name}")
802            };
803            collect_safetensors_names_sizes(&path, &child_prefix, results)?;
804        } else if name.ends_with(".safetensors") {
805            let filename = if prefix.is_empty() {
806                name
807            } else {
808                format!("{prefix}/{name}")
809            };
810            let size = entry.metadata().map_or(0, |m| m.len());
811            results.push((filename, size));
812        }
813    }
814
815    Ok(())
816}
817
818/// Inspects all `.safetensors` files in a cached repository (no network).
819///
820/// Walks the snapshot directory and inspects each `.safetensors` file's
821/// header from local disk. Returns results in filename order.
822///
823/// # Blocking I/O
824///
825/// Walks the snapshot directory and reads each header synchronously. In async
826/// contexts, wrap in [`tokio::task::spawn_blocking`] to avoid stalling the
827/// runtime — multi-shard repos on network-mounted caches can take seconds.
828///
829/// # Errors
830///
831/// Returns [`FetchError::Io`] if the cache directory cannot be read.
832/// Returns [`FetchError::SafetensorsHeader`] if any header is malformed.
833pub fn inspect_repo_safetensors_cached(
834    repo_id: &str,
835    revision: Option<&str>,
836) -> Result<Vec<(String, SafetensorsHeaderInfo)>, FetchError> {
837    let rev = revision.unwrap_or("main");
838    let cache_dir = cache::hf_cache_dir()?;
839    let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
840
841    let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
842        return Ok(Vec::new());
843    };
844
845    let snapshot_dir = cache_layout::snapshot_dir(&repo_dir, &commit_hash);
846    if !snapshot_dir.exists() {
847        return Ok(Vec::new());
848    }
849
850    let mut results = Vec::new();
851    collect_safetensors_recursive(&snapshot_dir, "", &mut results)?;
852    results.sort_by(|a, b| a.0.cmp(&b.0));
853
854    Ok(results)
855}
856
857/// Recursively finds and inspects `.safetensors` files in a snapshot directory.
858fn collect_safetensors_recursive(
859    dir: &Path,
860    prefix: &str,
861    results: &mut Vec<(String, SafetensorsHeaderInfo)>,
862) -> Result<(), FetchError> {
863    let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
864        path: dir.to_path_buf(),
865        source: e,
866    })?;
867
868    for entry in entries {
869        let Ok(entry) = entry else { continue };
870        let path = entry.path();
871        // BORROW: explicit .to_string_lossy() for OsString → str conversion
872        let name = entry.file_name().to_string_lossy().to_string();
873
874        if path.is_dir() {
875            let child_prefix = if prefix.is_empty() {
876                name
877            } else {
878                format!("{prefix}/{name}")
879            };
880            collect_safetensors_recursive(&path, &child_prefix, results)?;
881        } else if name.ends_with(".safetensors") {
882            let filename = if prefix.is_empty() {
883                name
884            } else {
885                format!("{prefix}/{name}")
886            };
887            let info = inspect_safetensors_local(&path)?;
888            results.push((filename, info));
889        }
890    }
891
892    Ok(())
893}
894
895// -----------------------------------------------------------------------
896// Shard index
897// -----------------------------------------------------------------------
898
899/// Raw JSON structure of `model.safetensors.index.json`.
900#[derive(serde::Deserialize)]
901struct RawShardIndex {
902    weight_map: HashMap<String, String>,
903    #[serde(default)]
904    metadata: Option<HashMap<String, serde_json::Value>>,
905}
906
907/// Fetches and parses the shard index for a sharded `.safetensors` model (cache-first).
908///
909/// Returns `Ok(None)` if the repo has no `model.safetensors.index.json` (i.e.,
910/// the model is not sharded or uses a single `.safetensors` file).
911///
912/// # Errors
913///
914/// Returns [`FetchError::Http`] if the index fetch fails.
915/// Returns [`FetchError::SafetensorsHeader`] if the index JSON is malformed.
916pub async fn fetch_shard_index(
917    repo_id: &str,
918    token: Option<&str>,
919    revision: Option<&str>,
920) -> Result<Option<ShardedIndex>, FetchError> {
921    let rev = revision.unwrap_or("main");
922    let index_filename = "model.safetensors.index.json";
923
924    // Try local cache first.
925    if let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) {
926        let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
927            path: cached_path,
928            source: e,
929        })?;
930        let index = parse_shard_index_json(&content, repo_id)?;
931        return Ok(Some(index));
932    }
933
934    // Fall back to HTTP.
935    let client = chunked::build_client(token)?;
936    let url = chunked::build_download_url(repo_id, rev, index_filename);
937
938    // BORROW: explicit .as_str() instead of Deref coercion
939    let response =
940        client.get(url.as_str()).send().await.map_err(|e| {
941            FetchError::Http(format!("failed to fetch shard index for {repo_id}: {e}"))
942        })?;
943
944    if response.status() == reqwest::StatusCode::NOT_FOUND {
945        return Ok(None);
946    }
947
948    if !response.status().is_success() {
949        return Err(FetchError::Http(format!(
950            "shard index request for {repo_id} returned status {}",
951            response.status()
952        )));
953    }
954
955    let content = response
956        .text()
957        .await
958        .map_err(|e| FetchError::Http(format!("failed to read shard index for {repo_id}: {e}")))?;
959
960    let index = parse_shard_index_json(&content, repo_id)?;
961    Ok(Some(index))
962}
963
964/// Fetches the shard index from cache only (no network).
965///
966/// Returns `Ok(None)` if the index file is not cached.
967///
968/// # Errors
969///
970/// Returns [`FetchError::Io`] if the cached file cannot be read.
971/// Returns [`FetchError::SafetensorsHeader`] if the index JSON is malformed.
972pub fn fetch_shard_index_cached(
973    repo_id: &str,
974    revision: Option<&str>,
975) -> Result<Option<ShardedIndex>, FetchError> {
976    let rev = revision.unwrap_or("main");
977    let index_filename = "model.safetensors.index.json";
978
979    let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) else {
980        return Ok(None);
981    };
982
983    let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
984        path: cached_path,
985        source: e,
986    })?;
987
988    let index = parse_shard_index_json(&content, repo_id)?;
989    Ok(Some(index))
990}
991
992/// Parses shard index JSON into a `ShardedIndex`.
993fn parse_shard_index_json(content: &str, repo_id: &str) -> Result<ShardedIndex, FetchError> {
994    let raw: RawShardIndex =
995        serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
996            filename: "model.safetensors.index.json".to_owned(),
997            reason: format!("failed to parse shard index for {repo_id}: {e}"),
998        })?;
999
1000    // Collect unique shard filenames in sorted order.
1001    let mut shard_set: Vec<String> = raw.weight_map.values().cloned().collect();
1002    shard_set.sort();
1003    shard_set.dedup();
1004
1005    Ok(ShardedIndex {
1006        weight_map: raw.weight_map,
1007        shards: shard_set,
1008        metadata: raw.metadata,
1009    })
1010}
1011
1012// -----------------------------------------------------------------------
1013// Param formatting helper
1014// -----------------------------------------------------------------------
1015
/// Formats a parameter count with a compact suffix (e.g., `927.0M`, `1.02B`).
///
/// Counts below 1,000 are printed verbatim. Larger counts are scaled to `K`
/// (one decimal), `M` (one decimal) or `B` (two decimals).
///
/// A value that sits just below a unit boundary and would round up to a
/// four-digit mantissa is promoted to the next unit, so `999_999` renders as
/// `1.0M` rather than `1000.0K`, and `999_999_999` as `1.00B` rather than
/// `1000.0M`.
#[must_use]
pub fn format_params(count: u64) -> String {
    if count < 1_000 {
        return count.to_string();
    }

    // CAST: u64 → f64, precision loss acceptable; value is a display-only scalar
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let val = count as f64;

    if count < 1_000_000 {
        let scaled = format!("{:.1}", val / 1_000.0);
        // "1000.0" means rounding crossed the unit boundary: fall through to `M`.
        if scaled != "1000.0" {
            return format!("{scaled}K");
        }
    }

    if count < 1_000_000_000 {
        let scaled = format!("{:.1}", val / 1_000_000.0);
        // Same boundary check as above: fall through to `B`.
        if scaled != "1000.0" {
            return format!("{scaled}M");
        }
    }

    format!("{:.2}B", val / 1_000_000_000.0)
}
1033
1034// -----------------------------------------------------------------------
1035// Adapter config
1036// -----------------------------------------------------------------------
1037
1038/// Raw JSON structure of `adapter_config.json`.
1039#[derive(serde::Deserialize)]
1040struct RawAdapterConfig {
1041    #[serde(default)]
1042    peft_type: Option<String>,
1043    #[serde(default)]
1044    base_model_name_or_path: Option<String>,
1045    #[serde(default)]
1046    r: Option<u32>,
1047    #[serde(default)]
1048    lora_alpha: Option<f64>,
1049    #[serde(default)]
1050    target_modules: Option<AdapterTargetModules>,
1051    #[serde(default)]
1052    task_type: Option<String>,
1053}
1054
1055/// `target_modules` in adapter configs can be a list of strings or a single string.
1056#[derive(serde::Deserialize)]
1057#[serde(untagged)]
1058enum AdapterTargetModules {
1059    /// A list of module name strings.
1060    List(Vec<String>),
1061    /// A single module name string.
1062    Single(String),
1063}
1064
1065/// Fetches and parses `adapter_config.json` for a `PEFT` adapter repository (cache-first).
1066///
1067/// Returns `Ok(None)` if the file does not exist (HTTP 404), meaning the
1068/// repository is not a `PEFT` adapter.
1069///
1070/// # Errors
1071///
1072/// Returns [`FetchError::Http`] if the request fails (other than 404).
1073/// Returns [`FetchError::SafetensorsHeader`] if the JSON is malformed.
1074pub async fn fetch_adapter_config(
1075    repo_id: &str,
1076    token: Option<&str>,
1077    revision: Option<&str>,
1078) -> Result<Option<AdapterConfig>, FetchError> {
1079    let rev = revision.unwrap_or("main");
1080    let config_filename = "adapter_config.json";
1081
1082    // Try local cache first.
1083    if let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) {
1084        let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
1085            path: cached_path,
1086            source: e,
1087        })?;
1088        let config = parse_adapter_config_json(&content, repo_id)?;
1089        return Ok(Some(config));
1090    }
1091
1092    // Fall back to HTTP.
1093    let client = chunked::build_client(token)?;
1094    let url = chunked::build_download_url(repo_id, rev, config_filename);
1095
1096    // BORROW: explicit .as_str() instead of Deref coercion
1097    let response = client.get(url.as_str()).send().await.map_err(|e| {
1098        FetchError::Http(format!("failed to fetch adapter config for {repo_id}: {e}"))
1099    })?;
1100
1101    if response.status() == reqwest::StatusCode::NOT_FOUND {
1102        return Ok(None);
1103    }
1104
1105    if !response.status().is_success() {
1106        return Err(FetchError::Http(format!(
1107            "adapter config request for {repo_id} returned status {}",
1108            response.status()
1109        )));
1110    }
1111
1112    let content = response.text().await.map_err(|e| {
1113        FetchError::Http(format!("failed to read adapter config for {repo_id}: {e}"))
1114    })?;
1115
1116    let config = parse_adapter_config_json(&content, repo_id)?;
1117    Ok(Some(config))
1118}
1119
1120/// Fetches the adapter config from cache only (no network).
1121///
1122/// Returns `Ok(None)` if the file is not cached.
1123///
1124/// # Errors
1125///
1126/// Returns [`FetchError::Io`] if the cached file cannot be read.
1127/// Returns [`FetchError::SafetensorsHeader`] if the JSON is malformed.
1128pub fn fetch_adapter_config_cached(
1129    repo_id: &str,
1130    revision: Option<&str>,
1131) -> Result<Option<AdapterConfig>, FetchError> {
1132    let rev = revision.unwrap_or("main");
1133    let config_filename = "adapter_config.json";
1134
1135    let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) else {
1136        return Ok(None);
1137    };
1138
1139    let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
1140        path: cached_path,
1141        source: e,
1142    })?;
1143
1144    let config = parse_adapter_config_json(&content, repo_id)?;
1145    Ok(Some(config))
1146}
1147
1148/// Parses adapter config JSON into an [`AdapterConfig`].
1149fn parse_adapter_config_json(content: &str, repo_id: &str) -> Result<AdapterConfig, FetchError> {
1150    let raw: RawAdapterConfig =
1151        serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
1152            filename: "adapter_config.json".to_owned(),
1153            reason: format!("failed to parse adapter config for {repo_id}: {e}"),
1154        })?;
1155
1156    let target_modules = match raw.target_modules {
1157        Some(AdapterTargetModules::List(v)) => v,
1158        Some(AdapterTargetModules::Single(s)) => vec![s],
1159        None => Vec::new(),
1160    };
1161
1162    Ok(AdapterConfig {
1163        peft_type: raw.peft_type,
1164        base_model_name_or_path: raw.base_model_name_or_path,
1165        r: raw.r,
1166        lora_alpha: raw.lora_alpha,
1167        target_modules,
1168        task_type: raw.task_type,
1169    })
1170}