1use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15
16use serde::Serialize;
17
18use crate::cache;
19use crate::chunked;
20use crate::error::FetchError;
21
/// Description of a single tensor entry from a safetensors header.
#[derive(Debug, Clone, Serialize)]
pub struct TensorInfo {
    /// Tensor name: the key of this tensor's entry in the header JSON.
    pub name: String,
    /// Element type string exactly as written in the header (e.g. `"F16"`).
    pub dtype: String,
    /// Tensor dimensions. An empty shape counts as one element in
    /// [`TensorInfo::num_elements`].
    pub shape: Vec<usize>,
    /// `(begin, end)` byte offsets from the header's `data_offsets` field;
    /// `end - begin` is the tensor's data size (see [`TensorInfo::byte_len`]).
    /// NOTE(review): per the safetensors format these are presumably relative
    /// to the start of the data section after the header — confirm.
    pub data_offsets: (u64, u64),
}
41
42impl TensorInfo {
43 #[must_use]
47 pub fn num_elements(&self) -> u64 {
48 self.shape.iter().fold(1u64, |acc, &d| {
49 #[allow(clippy::as_conversions)]
51 let dim = d as u64;
52 acc.saturating_mul(dim)
53 })
54 }
55
56 #[must_use]
58 pub const fn byte_len(&self) -> u64 {
59 self.data_offsets.1.saturating_sub(self.data_offsets.0)
60 }
61
62 #[must_use]
75 pub fn dtype_bytes(&self) -> Option<usize> {
76 match self.dtype.as_str() {
78 "BOOL" | "U8" | "I8" | "F8_E4M3" | "F8_E5M2" => Some(1),
79 "U16" | "I16" | "F16" | "BF16" => Some(2),
80 "U32" | "I32" | "F32" => Some(4),
81 "U64" | "I64" | "F64" => Some(8),
82 _ => None,
83 }
84 }
85}
86
/// Parsed contents of a safetensors file header.
#[derive(Debug, Clone, Serialize)]
pub struct SafetensorsHeaderInfo {
    /// Tensor descriptors, ordered by starting data offset.
    pub tensors: Vec<TensorInfo>,
    /// Entries of the header's `__metadata__` object, if it was present and
    /// an object. Non-string values are stored as their JSON text.
    pub metadata: Option<HashMap<String, String>>,
    /// Length in bytes of the JSON header, excluding the 8-byte length prefix.
    pub header_size: u64,
    /// Total file size in bytes, when known. Always set for local files; for
    /// remote files it is only set when the server reports a numeric total in
    /// its `Content-Range` response header.
    pub file_size: Option<u64>,
}
108
109impl SafetensorsHeaderInfo {
110 #[must_use]
112 pub fn total_params(&self) -> u64 {
113 self.tensors
114 .iter()
115 .map(TensorInfo::num_elements)
116 .fold(0u64, u64::saturating_add)
117 }
118
119 #[must_use]
121 pub fn tensors_with_dtype(&self, dtype: &str) -> Vec<&TensorInfo> {
122 self.tensors
123 .iter()
124 .filter(|t| t.dtype.as_str() == dtype)
126 .collect()
127 }
128}
129
/// Where an inspected safetensors header was read from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum InspectSource {
    /// Read from a file already present in the local Hugging Face cache.
    Cached,
    /// Fetched over HTTP using range requests (no full download).
    Remote,
}
139
/// Parsed `model.safetensors.index.json` shard index.
#[derive(Debug, Clone, Serialize)]
pub struct ShardedIndex {
    /// The index's `weight_map`: tensor name to the shard filename holding it.
    pub weight_map: HashMap<String, String>,
    /// Distinct shard filenames referenced by `weight_map`, sorted.
    pub shards: Vec<String>,
    /// The index's optional `metadata` object, kept as raw JSON values.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
150
/// Selected fields of a PEFT `adapter_config.json`.
///
/// Every field other than `target_modules` is `None` when absent from the JSON.
#[derive(Debug, Clone, Serialize)]
pub struct AdapterConfig {
    /// The `peft_type` field (adapter method identifier).
    pub peft_type: Option<String>,
    /// The `base_model_name_or_path` field: the model the adapter was trained on.
    pub base_model_name_or_path: Option<String>,
    /// The `r` field (the LoRA rank in LoRA-style configs).
    pub r: Option<u32>,
    /// The `lora_alpha` field (LoRA scaling factor).
    pub lora_alpha: Option<f64>,
    /// Modules the adapter targets. A single string in the JSON becomes a
    /// one-element list; a missing field becomes an empty list.
    pub target_modules: Vec<String>,
    /// The `task_type` field.
    pub task_type: Option<String>,
}
171
/// Deserialization target for one tensor entry of a safetensors header.
///
/// Converted into a [`TensorInfo`] (with the entry's key as its name) by
/// [`parse_header_json`].
#[derive(serde::Deserialize)]
struct RawTensorEntry {
    dtype: String,
    shape: Vec<usize>,
    data_offsets: (u64, u64),
}
183
/// Output of [`parse_header_json`]: the tensor descriptors plus the optional
/// `__metadata__` map.
type ParsedHeader = (Vec<TensorInfo>, Option<HashMap<String, String>>);
186
187fn parse_header_json(json_bytes: &[u8], filename: &str) -> Result<ParsedHeader, FetchError> {
191 let raw: HashMap<String, serde_json::Value> =
192 serde_json::from_slice(json_bytes).map_err(|e| FetchError::SafetensorsHeader {
193 filename: filename.to_owned(),
194 reason: format!("failed to parse header JSON: {e}"),
195 })?;
196
197 let mut metadata: Option<HashMap<String, String>> = None;
198 let mut tensors = Vec::new();
199
200 for (key, value) in &raw {
201 if key == "__metadata__" {
202 if let Some(obj) = value.as_object() {
203 let mut meta_map = HashMap::new();
204 for (mk, mv) in obj {
205 let v_str = if let Some(s) = mv.as_str() {
207 s.to_owned()
208 } else {
209 mv.to_string()
210 };
211 meta_map.insert(mk.clone(), v_str);
213 }
214 metadata = Some(meta_map);
215 }
216 continue;
217 }
218
219 let entry: RawTensorEntry =
220 serde_json::from_value(value.clone()).map_err(|e| FetchError::SafetensorsHeader {
221 filename: filename.to_owned(),
222 reason: format!("failed to parse tensor \"{key}\": {e}"),
223 })?;
224
225 tensors.push(TensorInfo {
226 name: key.clone(),
228 dtype: entry.dtype,
229 shape: entry.shape,
230 data_offsets: entry.data_offsets,
231 });
232 }
233
234 tensors.sort_by_key(|t| t.data_offsets.0);
236
237 Ok((tensors, metadata))
238}
239
240fn resolve_cached_path(repo_id: &str, revision: &str, filename: &str) -> Option<PathBuf> {
248 let cache_dir = cache::hf_cache_dir().ok()?;
249 let repo_folder = chunked::repo_folder_name(repo_id);
250 let repo_dir = cache_dir.join(repo_folder.as_str());
252 let commit_hash = cache::read_ref(&repo_dir, revision)?;
253 let cached_path = repo_dir.join("snapshots").join(commit_hash).join(filename);
254 if cached_path.exists() {
255 Some(cached_path)
256 } else {
257 None
258 }
259}
260
261pub fn inspect_safetensors_local(path: &Path) -> Result<SafetensorsHeaderInfo, FetchError> {
274 use std::io::Read;
275
276 let file_size = std::fs::metadata(path)
277 .map_err(|e| FetchError::Io {
278 path: path.to_path_buf(),
279 source: e,
280 })?
281 .len();
282
283 let filename = path.file_name().map_or_else(
285 || path.display().to_string(),
286 |n| n.to_string_lossy().to_string(),
287 );
288
289 let mut file = std::fs::File::open(path).map_err(|e| FetchError::Io {
290 path: path.to_path_buf(),
291 source: e,
292 })?;
293
294 let mut len_buf = [0u8; 8];
296 file.read_exact(&mut len_buf).map_err(|e| FetchError::Io {
297 path: path.to_path_buf(),
298 source: e,
299 })?;
300 let header_size = u64::from_le_bytes(len_buf);
301
302 if header_size.saturating_add(8) > file_size {
304 return Err(FetchError::SafetensorsHeader {
305 filename,
306 reason: format!("header length {header_size} exceeds file size {file_size}"),
307 });
308 }
309
310 #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
313 let json_len = header_size as usize;
314 let mut json_buf = vec![0u8; json_len];
315 file.read_exact(&mut json_buf).map_err(|e| FetchError::Io {
316 path: path.to_path_buf(),
317 source: e,
318 })?;
319
320 let (tensors, metadata) = parse_header_json(&json_buf, filename.as_str())?;
322
323 Ok(SafetensorsHeaderInfo {
324 tensors,
325 metadata,
326 header_size,
327 file_size: Some(file_size),
328 })
329}
330
331async fn fetch_header_bytes(
343 client: &reqwest::Client,
344 url: &str,
345 filename: &str,
346) -> Result<(Vec<u8>, Option<u64>), FetchError> {
347 let resp1 = client
349 .get(url)
350 .header(reqwest::header::RANGE, "bytes=0-7")
351 .send()
352 .await
353 .map_err(|e| {
354 FetchError::Http(format!("failed to fetch header length for {filename}: {e}"))
355 })?;
356
357 if !resp1.status().is_success() && resp1.status() != reqwest::StatusCode::PARTIAL_CONTENT {
358 return Err(FetchError::Http(format!(
359 "Range request for {filename} returned status {}",
360 resp1.status()
361 )));
362 }
363
364 let file_size = resp1
366 .headers()
367 .get(reqwest::header::CONTENT_RANGE)
368 .and_then(|v| v.to_str().ok())
369 .and_then(|s| s.split('/').next_back())
370 .and_then(|s| s.parse::<u64>().ok());
371
372 let len_bytes = resp1.bytes().await.map_err(|e| {
373 FetchError::Http(format!("failed to read header length for {filename}: {e}"))
374 })?;
375
376 if len_bytes.len() < 8 {
377 return Err(FetchError::SafetensorsHeader {
378 filename: filename.to_owned(),
379 reason: format!(
380 "expected 8 bytes for length prefix, got {}",
381 len_bytes.len()
382 ),
383 });
384 }
385
386 #[allow(clippy::indexing_slicing)]
388 let header_size = u64::from_le_bytes([
389 len_bytes[0],
390 len_bytes[1],
391 len_bytes[2],
392 len_bytes[3],
393 len_bytes[4],
394 len_bytes[5],
395 len_bytes[6],
396 len_bytes[7],
397 ]);
398
399 let range_end = 8u64.saturating_add(header_size).saturating_sub(1);
401 let range_header = format!("bytes=8-{range_end}");
402 let resp2 = client
403 .get(url)
404 .header(reqwest::header::RANGE, range_header.as_str())
406 .send()
407 .await
408 .map_err(|e| {
409 FetchError::Http(format!("failed to fetch header JSON for {filename}: {e}"))
410 })?;
411
412 if !resp2.status().is_success() && resp2.status() != reqwest::StatusCode::PARTIAL_CONTENT {
413 return Err(FetchError::Http(format!(
414 "Range request for {filename} header JSON returned status {}",
415 resp2.status()
416 )));
417 }
418
419 let json_bytes = resp2
420 .bytes()
421 .await
422 .map_err(|e| FetchError::Http(format!("failed to read header JSON for {filename}: {e}")))?;
423
424 Ok((json_bytes.to_vec(), file_size))
425}
426
427pub async fn inspect_safetensors(
443 repo_id: &str,
444 filename: &str,
445 token: Option<&str>,
446 revision: Option<&str>,
447) -> Result<(SafetensorsHeaderInfo, InspectSource), FetchError> {
448 let rev = revision.unwrap_or("main");
449
450 if let Some(cached_path) = resolve_cached_path(repo_id, rev, filename) {
452 let info = inspect_safetensors_local(&cached_path)?;
453 return Ok((info, InspectSource::Cached));
454 }
455
456 let client = chunked::build_client(token)?;
458 let url = chunked::build_download_url(repo_id, rev, filename);
459
460 let (json_bytes, file_size) = fetch_header_bytes(&client, url.as_str(), filename).await?;
462
463 #[allow(clippy::as_conversions)]
465 let header_size = json_bytes.len() as u64;
466
467 let (tensors, metadata) = parse_header_json(&json_bytes, filename)?;
468
469 Ok((
470 SafetensorsHeaderInfo {
471 tensors,
472 metadata,
473 header_size,
474 file_size,
475 },
476 InspectSource::Remote,
477 ))
478}
479
480pub fn inspect_safetensors_cached(
491 repo_id: &str,
492 filename: &str,
493 revision: Option<&str>,
494) -> Result<SafetensorsHeaderInfo, FetchError> {
495 let rev = revision.unwrap_or("main");
496
497 let cached_path = resolve_cached_path(repo_id, rev, filename).ok_or_else(|| {
498 FetchError::SafetensorsHeader {
499 filename: filename.to_owned(),
500 reason: format!("file not found in local cache for {repo_id} ({rev})"),
501 }
502 })?;
503
504 inspect_safetensors_local(&cached_path)
505}
506
507pub async fn inspect_repo_safetensors(
525 repo_id: &str,
526 token: Option<&str>,
527 revision: Option<&str>,
528) -> Result<Vec<(String, SafetensorsHeaderInfo, InspectSource)>, FetchError> {
529 let files = crate::repo::list_repo_files_with_metadata(repo_id, token, revision).await?;
530
531 let safetensors_files: Vec<String> = files
532 .into_iter()
533 .filter(|f| f.filename.ends_with(".safetensors"))
534 .map(|f| f.filename)
535 .collect();
536
537 if safetensors_files.is_empty() {
538 return Ok(Vec::new());
539 }
540
541 let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(4));
542 let mut handles = Vec::new();
543
544 for filename in safetensors_files {
545 let sem = semaphore.clone();
547 let repo = repo_id.to_owned();
548 let tok = token.map(str::to_owned);
549 let rev = revision.map(str::to_owned);
550
551 handles.push(tokio::spawn(async move {
552 let _permit = sem
553 .acquire()
554 .await
555 .map_err(|e| FetchError::Http(format!("semaphore error: {e}")))?;
556 let (info, source) =
558 inspect_safetensors(&repo, &filename, tok.as_deref(), rev.as_deref()).await?;
559 Ok::<_, FetchError>((filename, info, source))
560 }));
561 }
562
563 let mut results = Vec::new();
564 for handle in handles {
565 let result = handle
566 .await
567 .map_err(|e| FetchError::Http(format!("task join error: {e}")))?;
568 results.push(result?);
569 }
570
571 results.sort_by(|a, b| a.0.cmp(&b.0));
572
573 Ok(results)
574}
575
576pub fn inspect_repo_safetensors_cached(
586 repo_id: &str,
587 revision: Option<&str>,
588) -> Result<Vec<(String, SafetensorsHeaderInfo)>, FetchError> {
589 let rev = revision.unwrap_or("main");
590 let cache_dir = cache::hf_cache_dir()?;
591 let repo_folder = chunked::repo_folder_name(repo_id);
592 let repo_dir = cache_dir.join(repo_folder.as_str());
594
595 let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
596 return Ok(Vec::new());
597 };
598
599 let snapshot_dir = repo_dir.join("snapshots").join(commit_hash);
600 if !snapshot_dir.exists() {
601 return Ok(Vec::new());
602 }
603
604 let mut results = Vec::new();
605 collect_safetensors_recursive(&snapshot_dir, "", &mut results)?;
606 results.sort_by(|a, b| a.0.cmp(&b.0));
607
608 Ok(results)
609}
610
611fn collect_safetensors_recursive(
613 dir: &Path,
614 prefix: &str,
615 results: &mut Vec<(String, SafetensorsHeaderInfo)>,
616) -> Result<(), FetchError> {
617 let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
618 path: dir.to_path_buf(),
619 source: e,
620 })?;
621
622 for entry in entries {
623 let Ok(entry) = entry else { continue };
624 let path = entry.path();
625 let name = entry.file_name().to_string_lossy().to_string();
627
628 if path.is_dir() {
629 let child_prefix = if prefix.is_empty() {
630 name
631 } else {
632 format!("{prefix}/{name}")
633 };
634 collect_safetensors_recursive(&path, &child_prefix, results)?;
635 } else if name.ends_with(".safetensors") {
636 let filename = if prefix.is_empty() {
637 name
638 } else {
639 format!("{prefix}/{name}")
640 };
641 let info = inspect_safetensors_local(&path)?;
642 results.push((filename, info));
643 }
644 }
645
646 Ok(())
647}
648
/// Deserialization target for `model.safetensors.index.json`.
///
/// A missing `metadata` key deserializes as `None`; other top-level keys are
/// ignored.
#[derive(serde::Deserialize)]
struct RawShardIndex {
    weight_map: HashMap<String, String>,
    #[serde(default)]
    metadata: Option<HashMap<String, serde_json::Value>>,
}
660
661pub async fn fetch_shard_index(
671 repo_id: &str,
672 token: Option<&str>,
673 revision: Option<&str>,
674) -> Result<Option<ShardedIndex>, FetchError> {
675 let rev = revision.unwrap_or("main");
676 let index_filename = "model.safetensors.index.json";
677
678 if let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) {
680 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
681 path: cached_path,
682 source: e,
683 })?;
684 let index = parse_shard_index_json(&content, repo_id)?;
685 return Ok(Some(index));
686 }
687
688 let client = chunked::build_client(token)?;
690 let url = chunked::build_download_url(repo_id, rev, index_filename);
691
692 let response =
694 client.get(url.as_str()).send().await.map_err(|e| {
695 FetchError::Http(format!("failed to fetch shard index for {repo_id}: {e}"))
696 })?;
697
698 if response.status() == reqwest::StatusCode::NOT_FOUND {
699 return Ok(None);
700 }
701
702 if !response.status().is_success() {
703 return Err(FetchError::Http(format!(
704 "shard index request for {repo_id} returned status {}",
705 response.status()
706 )));
707 }
708
709 let content = response
710 .text()
711 .await
712 .map_err(|e| FetchError::Http(format!("failed to read shard index for {repo_id}: {e}")))?;
713
714 let index = parse_shard_index_json(&content, repo_id)?;
715 Ok(Some(index))
716}
717
718pub fn fetch_shard_index_cached(
727 repo_id: &str,
728 revision: Option<&str>,
729) -> Result<Option<ShardedIndex>, FetchError> {
730 let rev = revision.unwrap_or("main");
731 let index_filename = "model.safetensors.index.json";
732
733 let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) else {
734 return Ok(None);
735 };
736
737 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
738 path: cached_path,
739 source: e,
740 })?;
741
742 let index = parse_shard_index_json(&content, repo_id)?;
743 Ok(Some(index))
744}
745
746fn parse_shard_index_json(content: &str, repo_id: &str) -> Result<ShardedIndex, FetchError> {
748 let raw: RawShardIndex =
749 serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
750 filename: "model.safetensors.index.json".to_owned(),
751 reason: format!("failed to parse shard index for {repo_id}: {e}"),
752 })?;
753
754 let mut shard_set: Vec<String> = raw.weight_map.values().cloned().collect();
756 shard_set.sort();
757 shard_set.dedup();
758
759 Ok(ShardedIndex {
760 weight_map: raw.weight_map,
761 shards: shard_set,
762 metadata: raw.metadata,
763 })
764}
765
/// Formats a parameter count with a `K`/`M`/`B` suffix for display.
///
/// Counts below 1,000 are printed as plain integers. Thousands and millions
/// use one decimal place; billions use two (e.g. `7_241_732_096` → `"7.24B"`).
///
/// A count is promoted to the next unit as soon as rounding would display
/// `1000.0` of the current one, so `999_950` formats as `"1.0M"` rather than
/// `"1000.0K"`, and `999_950_000` as `"1.00B"` rather than `"1000.0M"`.
#[must_use]
pub fn format_params(count: u64) -> String {
    // Smallest counts whose one-decimal K / M rendering would round up to 1000.0.
    const MILLION_CUTOFF: u64 = 999_950;
    const BILLION_CUTOFF: u64 = 999_950_000;

    // Display-only conversion: precision lost above 2^53 cannot affect the
    // few decimal places printed.
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let val = count as f64;

    if count >= BILLION_CUTOFF {
        format!("{:.2}B", val / 1_000_000_000.0)
    } else if count >= MILLION_CUTOFF {
        format!("{:.1}M", val / 1_000_000.0)
    } else if count >= 1_000 {
        format!("{:.1}K", val / 1_000.0)
    } else {
        count.to_string()
    }
}
787
/// Deserialization target for a PEFT `adapter_config.json`.
///
/// Every field is optional (`#[serde(default)]`), and keys not listed here
/// are ignored. Converted into an [`AdapterConfig`] by
/// [`parse_adapter_config_json`].
#[derive(serde::Deserialize)]
struct RawAdapterConfig {
    #[serde(default)]
    peft_type: Option<String>,
    #[serde(default)]
    base_model_name_or_path: Option<String>,
    #[serde(default)]
    r: Option<u32>,
    #[serde(default)]
    lora_alpha: Option<f64>,
    // May be a list of names or a single string; see `AdapterTargetModules`.
    #[serde(default)]
    target_modules: Option<AdapterTargetModules>,
    #[serde(default)]
    task_type: Option<String>,
}
808
/// The `target_modules` value of an adapter config, which may be either a JSON
/// array of strings or a single string.
///
/// `#[serde(untagged)]` tries the variants in declaration order, so an array
/// matches `List` before `Single` is attempted.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum AdapterTargetModules {
    /// A JSON array of module names.
    List(Vec<String>),
    /// A single JSON string.
    Single(String),
}
818
819pub async fn fetch_adapter_config(
829 repo_id: &str,
830 token: Option<&str>,
831 revision: Option<&str>,
832) -> Result<Option<AdapterConfig>, FetchError> {
833 let rev = revision.unwrap_or("main");
834 let config_filename = "adapter_config.json";
835
836 if let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) {
838 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
839 path: cached_path,
840 source: e,
841 })?;
842 let config = parse_adapter_config_json(&content, repo_id)?;
843 return Ok(Some(config));
844 }
845
846 let client = chunked::build_client(token)?;
848 let url = chunked::build_download_url(repo_id, rev, config_filename);
849
850 let response = client.get(url.as_str()).send().await.map_err(|e| {
852 FetchError::Http(format!("failed to fetch adapter config for {repo_id}: {e}"))
853 })?;
854
855 if response.status() == reqwest::StatusCode::NOT_FOUND {
856 return Ok(None);
857 }
858
859 if !response.status().is_success() {
860 return Err(FetchError::Http(format!(
861 "adapter config request for {repo_id} returned status {}",
862 response.status()
863 )));
864 }
865
866 let content = response.text().await.map_err(|e| {
867 FetchError::Http(format!("failed to read adapter config for {repo_id}: {e}"))
868 })?;
869
870 let config = parse_adapter_config_json(&content, repo_id)?;
871 Ok(Some(config))
872}
873
874pub fn fetch_adapter_config_cached(
883 repo_id: &str,
884 revision: Option<&str>,
885) -> Result<Option<AdapterConfig>, FetchError> {
886 let rev = revision.unwrap_or("main");
887 let config_filename = "adapter_config.json";
888
889 let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) else {
890 return Ok(None);
891 };
892
893 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
894 path: cached_path,
895 source: e,
896 })?;
897
898 let config = parse_adapter_config_json(&content, repo_id)?;
899 Ok(Some(config))
900}
901
902fn parse_adapter_config_json(content: &str, repo_id: &str) -> Result<AdapterConfig, FetchError> {
904 let raw: RawAdapterConfig =
905 serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
906 filename: "adapter_config.json".to_owned(),
907 reason: format!("failed to parse adapter config for {repo_id}: {e}"),
908 })?;
909
910 let target_modules = match raw.target_modules {
911 Some(AdapterTargetModules::List(v)) => v,
912 Some(AdapterTargetModules::Single(s)) => vec![s],
913 None => Vec::new(),
914 };
915
916 Ok(AdapterConfig {
917 peft_type: raw.peft_type,
918 base_model_name_or_path: raw.base_model_name_or_path,
919 r: raw.r,
920 lora_alpha: raw.lora_alpha,
921 target_modules,
922 task_type: raw.task_type,
923 })
924}