1use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15
16use serde::Serialize;
17
18use crate::cache;
19use crate::chunked;
20use crate::error::FetchError;
21
22#[derive(Debug, Clone, Serialize)]
31pub struct TensorInfo {
32 pub name: String,
34 pub dtype: String,
36 pub shape: Vec<usize>,
38 pub data_offsets: (u64, u64),
40}
41
42impl TensorInfo {
43 #[must_use]
47 pub fn num_elements(&self) -> u64 {
48 self.shape.iter().fold(1u64, |acc, &d| {
49 #[allow(clippy::as_conversions)]
51 let dim = d as u64;
52 acc.saturating_mul(dim)
53 })
54 }
55
56 #[must_use]
58 pub const fn byte_len(&self) -> u64 {
59 self.data_offsets.1.saturating_sub(self.data_offsets.0)
60 }
61
62 #[must_use]
75 pub fn dtype_bytes(&self) -> Option<usize> {
76 match self.dtype.as_str() {
78 "BOOL" | "U8" | "I8" | "F8_E4M3" | "F8_E5M2" => Some(1),
79 "U16" | "I16" | "F16" | "BF16" => Some(2),
80 "U32" | "I32" | "F32" => Some(4),
81 "U64" | "I64" | "F64" => Some(8),
82 _ => None,
83 }
84 }
85}
86
87#[derive(Debug, Clone, Serialize)]
89pub struct SafetensorsHeaderInfo {
90 pub tensors: Vec<TensorInfo>,
92 pub metadata: Option<HashMap<String, String>>,
98 pub header_size: u64,
100 pub file_size: Option<u64>,
107}
108
109impl SafetensorsHeaderInfo {
110 #[must_use]
112 pub fn total_params(&self) -> u64 {
113 self.tensors
114 .iter()
115 .map(TensorInfo::num_elements)
116 .fold(0u64, u64::saturating_add)
117 }
118
119 #[must_use]
121 pub fn tensors_with_dtype(&self, dtype: &str) -> Vec<&TensorInfo> {
122 self.tensors
123 .iter()
124 .filter(|t| t.dtype.as_str() == dtype)
126 .collect()
127 }
128}
129
/// Where the header returned by an inspection call was read from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InspectSource {
    /// The header was read from a file found in the local cache.
    Cached,

    /// The header was fetched over HTTP.
    Remote,
}
138
139#[derive(Debug, Clone, Serialize)]
141pub struct ShardedIndex {
142 pub weight_map: HashMap<String, String>,
144 pub shards: Vec<String>,
146 pub metadata: Option<HashMap<String, serde_json::Value>>,
148}
149
150#[derive(Debug, Clone, Serialize)]
156pub struct AdapterConfig {
157 pub peft_type: Option<String>,
159 pub base_model_name_or_path: Option<String>,
161 pub r: Option<u32>,
163 pub lora_alpha: Option<f64>,
165 pub target_modules: Vec<String>,
167 pub task_type: Option<String>,
169}
170
171#[derive(serde::Deserialize)]
177struct RawTensorEntry {
178 dtype: String,
179 shape: Vec<usize>,
180 data_offsets: (u64, u64),
181}
182
183type ParsedHeader = (Vec<TensorInfo>, Option<HashMap<String, String>>);
185
186fn parse_header_json(json_bytes: &[u8], filename: &str) -> Result<ParsedHeader, FetchError> {
190 let raw: HashMap<String, serde_json::Value> =
191 serde_json::from_slice(json_bytes).map_err(|e| FetchError::SafetensorsHeader {
192 filename: filename.to_owned(),
193 reason: format!("failed to parse header JSON: {e}"),
194 })?;
195
196 let mut metadata: Option<HashMap<String, String>> = None;
197 let mut tensors = Vec::new();
198
199 for (key, value) in &raw {
200 if key == "__metadata__" {
201 if let Some(obj) = value.as_object() {
202 let mut meta_map = HashMap::new();
203 for (mk, mv) in obj {
204 let v_str = if let Some(s) = mv.as_str() {
206 s.to_owned()
207 } else {
208 mv.to_string()
209 };
210 meta_map.insert(mk.clone(), v_str);
212 }
213 metadata = Some(meta_map);
214 }
215 continue;
216 }
217
218 let entry: RawTensorEntry =
219 serde_json::from_value(value.clone()).map_err(|e| FetchError::SafetensorsHeader {
220 filename: filename.to_owned(),
221 reason: format!("failed to parse tensor \"{key}\": {e}"),
222 })?;
223
224 tensors.push(TensorInfo {
225 name: key.clone(),
227 dtype: entry.dtype,
228 shape: entry.shape,
229 data_offsets: entry.data_offsets,
230 });
231 }
232
233 tensors.sort_by_key(|t| t.data_offsets.0);
235
236 Ok((tensors, metadata))
237}
238
239fn resolve_cached_path(repo_id: &str, revision: &str, filename: &str) -> Option<PathBuf> {
247 let cache_dir = cache::hf_cache_dir().ok()?;
248 let repo_folder = chunked::repo_folder_name(repo_id);
249 let repo_dir = cache_dir.join(&repo_folder);
250 let commit_hash = cache::read_ref(&repo_dir, revision)?;
251 let cached_path = repo_dir.join("snapshots").join(commit_hash).join(filename);
252 if cached_path.exists() {
253 Some(cached_path)
254 } else {
255 None
256 }
257}
258
259pub fn inspect_safetensors_local(path: &Path) -> Result<SafetensorsHeaderInfo, FetchError> {
272 use std::io::Read;
273
274 let file_size = std::fs::metadata(path)
275 .map_err(|e| FetchError::Io {
276 path: path.to_path_buf(),
277 source: e,
278 })?
279 .len();
280
281 let filename = path.file_name().map_or_else(
283 || path.display().to_string(),
284 |n| n.to_string_lossy().to_string(),
285 );
286
287 let mut file = std::fs::File::open(path).map_err(|e| FetchError::Io {
288 path: path.to_path_buf(),
289 source: e,
290 })?;
291
292 let mut len_buf = [0u8; 8];
294 file.read_exact(&mut len_buf).map_err(|e| FetchError::Io {
295 path: path.to_path_buf(),
296 source: e,
297 })?;
298 let header_size = u64::from_le_bytes(len_buf);
299
300 if header_size.saturating_add(8) > file_size {
302 return Err(FetchError::SafetensorsHeader {
303 filename,
304 reason: format!("header length {header_size} exceeds file size {file_size}"),
305 });
306 }
307
308 #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
311 let json_len = header_size as usize;
312 let mut json_buf = vec![0u8; json_len];
313 file.read_exact(&mut json_buf).map_err(|e| FetchError::Io {
314 path: path.to_path_buf(),
315 source: e,
316 })?;
317
318 let (tensors, metadata) = parse_header_json(&json_buf, filename.as_str())?;
320
321 Ok(SafetensorsHeaderInfo {
322 tensors,
323 metadata,
324 header_size,
325 file_size: Some(file_size),
326 })
327}
328
329async fn fetch_header_bytes(
341 client: &reqwest::Client,
342 url: &str,
343 filename: &str,
344) -> Result<(Vec<u8>, Option<u64>), FetchError> {
345 let resp1 = client
347 .get(url)
348 .header(reqwest::header::RANGE, "bytes=0-7")
349 .send()
350 .await
351 .map_err(|e| {
352 FetchError::Http(format!("failed to fetch header length for {filename}: {e}"))
353 })?;
354
355 if !resp1.status().is_success() && resp1.status() != reqwest::StatusCode::PARTIAL_CONTENT {
356 return Err(FetchError::Http(format!(
357 "Range request for {filename} returned status {}",
358 resp1.status()
359 )));
360 }
361
362 let file_size = resp1
364 .headers()
365 .get(reqwest::header::CONTENT_RANGE)
366 .and_then(|v| v.to_str().ok())
367 .and_then(|s| s.split('/').next_back())
368 .and_then(|s| s.parse::<u64>().ok());
369
370 let len_bytes = resp1.bytes().await.map_err(|e| {
371 FetchError::Http(format!("failed to read header length for {filename}: {e}"))
372 })?;
373
374 if len_bytes.len() < 8 {
375 return Err(FetchError::SafetensorsHeader {
376 filename: filename.to_owned(),
377 reason: format!(
378 "expected 8 bytes for length prefix, got {}",
379 len_bytes.len()
380 ),
381 });
382 }
383
384 #[allow(clippy::indexing_slicing)]
386 let header_size = u64::from_le_bytes([
387 len_bytes[0],
388 len_bytes[1],
389 len_bytes[2],
390 len_bytes[3],
391 len_bytes[4],
392 len_bytes[5],
393 len_bytes[6],
394 len_bytes[7],
395 ]);
396
397 let range_end = 8u64.saturating_add(header_size).saturating_sub(1);
399 let range_header = format!("bytes=8-{range_end}");
400 let resp2 = client
401 .get(url)
402 .header(reqwest::header::RANGE, range_header.as_str())
404 .send()
405 .await
406 .map_err(|e| {
407 FetchError::Http(format!("failed to fetch header JSON for {filename}: {e}"))
408 })?;
409
410 if !resp2.status().is_success() && resp2.status() != reqwest::StatusCode::PARTIAL_CONTENT {
411 return Err(FetchError::Http(format!(
412 "Range request for {filename} header JSON returned status {}",
413 resp2.status()
414 )));
415 }
416
417 let json_bytes = resp2
418 .bytes()
419 .await
420 .map_err(|e| FetchError::Http(format!("failed to read header JSON for {filename}: {e}")))?;
421
422 Ok((json_bytes.to_vec(), file_size))
423}
424
425pub async fn inspect_safetensors(
441 repo_id: &str,
442 filename: &str,
443 token: Option<&str>,
444 revision: Option<&str>,
445) -> Result<(SafetensorsHeaderInfo, InspectSource), FetchError> {
446 let rev = revision.unwrap_or("main");
447
448 if let Some(cached_path) = resolve_cached_path(repo_id, rev, filename) {
450 let info = inspect_safetensors_local(&cached_path)?;
451 return Ok((info, InspectSource::Cached));
452 }
453
454 let client = chunked::build_client(token)?;
456 let url = chunked::build_download_url(repo_id, rev, filename);
457
458 let (json_bytes, file_size) = fetch_header_bytes(&client, url.as_str(), filename).await?;
460
461 #[allow(clippy::as_conversions)]
463 let header_size = json_bytes.len() as u64;
464
465 let (tensors, metadata) = parse_header_json(&json_bytes, filename)?;
466
467 Ok((
468 SafetensorsHeaderInfo {
469 tensors,
470 metadata,
471 header_size,
472 file_size,
473 },
474 InspectSource::Remote,
475 ))
476}
477
478pub fn inspect_safetensors_cached(
489 repo_id: &str,
490 filename: &str,
491 revision: Option<&str>,
492) -> Result<SafetensorsHeaderInfo, FetchError> {
493 let rev = revision.unwrap_or("main");
494
495 let cached_path = resolve_cached_path(repo_id, rev, filename).ok_or_else(|| {
496 FetchError::SafetensorsHeader {
497 filename: filename.to_owned(),
498 reason: format!("file not found in local cache for {repo_id} ({rev})"),
499 }
500 })?;
501
502 inspect_safetensors_local(&cached_path)
503}
504
505pub async fn inspect_repo_safetensors(
523 repo_id: &str,
524 token: Option<&str>,
525 revision: Option<&str>,
526) -> Result<Vec<(String, SafetensorsHeaderInfo, InspectSource)>, FetchError> {
527 let files = crate::repo::list_repo_files_with_metadata(repo_id, token, revision).await?;
528
529 let safetensors_files: Vec<String> = files
530 .into_iter()
531 .filter(|f| f.filename.ends_with(".safetensors"))
532 .map(|f| f.filename)
533 .collect();
534
535 if safetensors_files.is_empty() {
536 return Ok(Vec::new());
537 }
538
539 let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(4));
540 let mut handles = Vec::new();
541
542 for filename in safetensors_files {
543 let sem = semaphore.clone();
544 let repo = repo_id.to_owned();
545 let tok = token.map(str::to_owned);
546 let rev = revision.map(str::to_owned);
547
548 handles.push(tokio::spawn(async move {
549 let _permit = sem
550 .acquire()
551 .await
552 .map_err(|e| FetchError::Http(format!("semaphore error: {e}")))?;
553 let (info, source) =
555 inspect_safetensors(&repo, &filename, tok.as_deref(), rev.as_deref()).await?;
556 Ok::<_, FetchError>((filename, info, source))
557 }));
558 }
559
560 let mut results = Vec::new();
561 for handle in handles {
562 let result = handle
563 .await
564 .map_err(|e| FetchError::Http(format!("task join error: {e}")))?;
565 results.push(result?);
566 }
567
568 results.sort_by(|a, b| a.0.cmp(&b.0));
569
570 Ok(results)
571}
572
573pub fn inspect_repo_safetensors_cached(
583 repo_id: &str,
584 revision: Option<&str>,
585) -> Result<Vec<(String, SafetensorsHeaderInfo)>, FetchError> {
586 let rev = revision.unwrap_or("main");
587 let cache_dir = cache::hf_cache_dir()?;
588 let repo_folder = chunked::repo_folder_name(repo_id);
589 let repo_dir = cache_dir.join(&repo_folder);
590
591 let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
592 return Ok(Vec::new());
593 };
594
595 let snapshot_dir = repo_dir.join("snapshots").join(commit_hash);
596 if !snapshot_dir.exists() {
597 return Ok(Vec::new());
598 }
599
600 let mut results = Vec::new();
601 collect_safetensors_recursive(&snapshot_dir, "", &mut results)?;
602 results.sort_by(|a, b| a.0.cmp(&b.0));
603
604 Ok(results)
605}
606
607fn collect_safetensors_recursive(
609 dir: &Path,
610 prefix: &str,
611 results: &mut Vec<(String, SafetensorsHeaderInfo)>,
612) -> Result<(), FetchError> {
613 let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
614 path: dir.to_path_buf(),
615 source: e,
616 })?;
617
618 for entry in entries {
619 let Ok(entry) = entry else { continue };
620 let path = entry.path();
621 let name = entry.file_name().to_string_lossy().to_string();
623
624 if path.is_dir() {
625 let child_prefix = if prefix.is_empty() {
626 name
627 } else {
628 format!("{prefix}/{name}")
629 };
630 collect_safetensors_recursive(&path, &child_prefix, results)?;
631 } else if name.ends_with(".safetensors") {
632 let filename = if prefix.is_empty() {
633 name
634 } else {
635 format!("{prefix}/{name}")
636 };
637 let info = inspect_safetensors_local(&path)?;
638 results.push((filename, info));
639 }
640 }
641
642 Ok(())
643}
644
645#[derive(serde::Deserialize)]
651struct RawShardIndex {
652 weight_map: HashMap<String, String>,
653 #[serde(default)]
654 metadata: Option<HashMap<String, serde_json::Value>>,
655}
656
657pub async fn fetch_shard_index(
667 repo_id: &str,
668 token: Option<&str>,
669 revision: Option<&str>,
670) -> Result<Option<ShardedIndex>, FetchError> {
671 let rev = revision.unwrap_or("main");
672 let index_filename = "model.safetensors.index.json";
673
674 if let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) {
676 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
677 path: cached_path,
678 source: e,
679 })?;
680 let index = parse_shard_index_json(&content, repo_id)?;
681 return Ok(Some(index));
682 }
683
684 let client = chunked::build_client(token)?;
686 let url = chunked::build_download_url(repo_id, rev, index_filename);
687
688 let response =
690 client.get(url.as_str()).send().await.map_err(|e| {
691 FetchError::Http(format!("failed to fetch shard index for {repo_id}: {e}"))
692 })?;
693
694 if response.status() == reqwest::StatusCode::NOT_FOUND {
695 return Ok(None);
696 }
697
698 if !response.status().is_success() {
699 return Err(FetchError::Http(format!(
700 "shard index request for {repo_id} returned status {}",
701 response.status()
702 )));
703 }
704
705 let content = response
706 .text()
707 .await
708 .map_err(|e| FetchError::Http(format!("failed to read shard index for {repo_id}: {e}")))?;
709
710 let index = parse_shard_index_json(&content, repo_id)?;
711 Ok(Some(index))
712}
713
714pub fn fetch_shard_index_cached(
723 repo_id: &str,
724 revision: Option<&str>,
725) -> Result<Option<ShardedIndex>, FetchError> {
726 let rev = revision.unwrap_or("main");
727 let index_filename = "model.safetensors.index.json";
728
729 let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) else {
730 return Ok(None);
731 };
732
733 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
734 path: cached_path,
735 source: e,
736 })?;
737
738 let index = parse_shard_index_json(&content, repo_id)?;
739 Ok(Some(index))
740}
741
742fn parse_shard_index_json(content: &str, repo_id: &str) -> Result<ShardedIndex, FetchError> {
744 let raw: RawShardIndex =
745 serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
746 filename: "model.safetensors.index.json".to_owned(),
747 reason: format!("failed to parse shard index for {repo_id}: {e}"),
748 })?;
749
750 let mut shard_set: Vec<String> = raw.weight_map.values().cloned().collect();
752 shard_set.sort();
753 shard_set.dedup();
754
755 Ok(ShardedIndex {
756 weight_map: raw.weight_map,
757 shards: shard_set,
758 metadata: raw.metadata,
759 })
760}
761
/// Formats a parameter count with a `K`, `M` or `B` suffix.
///
/// * below 1 000: the plain integer, e.g. `"999"`
/// * thousands: one decimal, e.g. `"1.2K"`
/// * millions: one decimal, e.g. `"1.5M"`
/// * billions: two decimals, e.g. `"7.24B"`
///
/// Counts just below a unit boundary that would round up to `1000.0` are
/// shown in the next unit instead, so `999_999` gives `"1.0M"` and not
/// `"1000.0K"`.
#[must_use]
pub fn format_params(count: u64) -> String {
    // Precision loss only affects digits beyond those that are displayed.
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let val = count as f64;

    if count >= 1_000_000_000 {
        return format!("{:.2}B", val / 1_000_000_000.0);
    }

    if count >= 1_000_000 {
        let millions = format!("{:.1}M", val / 1_000_000.0);
        // Rounding carried into a fourth integer digit: promote to billions.
        if millions == "1000.0M" {
            return format!("{:.2}B", val / 1_000_000_000.0);
        }
        return millions;
    }

    if count >= 1_000 {
        let thousands = format!("{:.1}K", val / 1_000.0);
        // Rounding carried into a fourth integer digit: promote to millions.
        if thousands == "1000.0K" {
            return format!("{:.1}M", val / 1_000_000.0);
        }
        return thousands;
    }

    count.to_string()
}
783
784#[derive(serde::Deserialize)]
790struct RawAdapterConfig {
791 #[serde(default)]
792 peft_type: Option<String>,
793 #[serde(default)]
794 base_model_name_or_path: Option<String>,
795 #[serde(default)]
796 r: Option<u32>,
797 #[serde(default)]
798 lora_alpha: Option<f64>,
799 #[serde(default)]
800 target_modules: Option<AdapterTargetModules>,
801 #[serde(default)]
802 task_type: Option<String>,
803}
804
805#[derive(serde::Deserialize)]
807#[serde(untagged)]
808enum AdapterTargetModules {
809 List(Vec<String>),
811 Single(String),
813}
814
815pub async fn fetch_adapter_config(
825 repo_id: &str,
826 token: Option<&str>,
827 revision: Option<&str>,
828) -> Result<Option<AdapterConfig>, FetchError> {
829 let rev = revision.unwrap_or("main");
830 let config_filename = "adapter_config.json";
831
832 if let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) {
834 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
835 path: cached_path,
836 source: e,
837 })?;
838 let config = parse_adapter_config_json(&content, repo_id)?;
839 return Ok(Some(config));
840 }
841
842 let client = chunked::build_client(token)?;
844 let url = chunked::build_download_url(repo_id, rev, config_filename);
845
846 let response = client.get(url.as_str()).send().await.map_err(|e| {
848 FetchError::Http(format!("failed to fetch adapter config for {repo_id}: {e}"))
849 })?;
850
851 if response.status() == reqwest::StatusCode::NOT_FOUND {
852 return Ok(None);
853 }
854
855 if !response.status().is_success() {
856 return Err(FetchError::Http(format!(
857 "adapter config request for {repo_id} returned status {}",
858 response.status()
859 )));
860 }
861
862 let content = response.text().await.map_err(|e| {
863 FetchError::Http(format!("failed to read adapter config for {repo_id}: {e}"))
864 })?;
865
866 let config = parse_adapter_config_json(&content, repo_id)?;
867 Ok(Some(config))
868}
869
870pub fn fetch_adapter_config_cached(
879 repo_id: &str,
880 revision: Option<&str>,
881) -> Result<Option<AdapterConfig>, FetchError> {
882 let rev = revision.unwrap_or("main");
883 let config_filename = "adapter_config.json";
884
885 let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) else {
886 return Ok(None);
887 };
888
889 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
890 path: cached_path,
891 source: e,
892 })?;
893
894 let config = parse_adapter_config_json(&content, repo_id)?;
895 Ok(Some(config))
896}
897
898fn parse_adapter_config_json(content: &str, repo_id: &str) -> Result<AdapterConfig, FetchError> {
900 let raw: RawAdapterConfig =
901 serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
902 filename: "adapter_config.json".to_owned(),
903 reason: format!("failed to parse adapter config for {repo_id}: {e}"),
904 })?;
905
906 let target_modules = match raw.target_modules {
907 Some(AdapterTargetModules::List(v)) => v,
908 Some(AdapterTargetModules::Single(s)) => vec![s],
909 None => Vec::new(),
910 };
911
912 Ok(AdapterConfig {
913 peft_type: raw.peft_type,
914 base_model_name_or_path: raw.base_model_name_or_path,
915 r: raw.r,
916 lora_alpha: raw.lora_alpha,
917 target_modules,
918 task_type: raw.task_type,
919 })
920}