1use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15
16use serde::Serialize;
17use tokio::task::JoinSet;
18
19use crate::cache;
20use crate::cache_layout;
21use crate::chunked;
22use crate::error::FetchError;
23
/// One tensor entry parsed from a safetensors JSON header.
#[derive(Debug, Clone, Serialize)]
pub struct TensorInfo {
    /// Tensor name: the entry's key in the header JSON.
    pub name: String,
    /// Element type exactly as written in the header (e.g. `"F16"`, `"BF16"`);
    /// see [`TensorInfo::dtype_bytes`] for the recognised spellings.
    pub dtype: String,
    /// Dimension sizes as listed in the header.
    pub shape: Vec<usize>,
    /// `(start, end)` byte offsets as recorded in the header; the tensor's
    /// byte length is `end - start` (see [`TensorInfo::byte_len`]).
    pub data_offsets: (u64, u64),
}
43
44impl TensorInfo {
45 #[must_use]
49 pub fn num_elements(&self) -> u64 {
50 self.shape.iter().fold(1u64, |acc, &d| {
51 #[allow(clippy::as_conversions)]
53 let dim = d as u64;
54 acc.saturating_mul(dim)
55 })
56 }
57
58 #[must_use]
60 pub const fn byte_len(&self) -> u64 {
61 self.data_offsets.1.saturating_sub(self.data_offsets.0)
62 }
63
64 #[must_use]
77 pub fn dtype_bytes(&self) -> Option<usize> {
78 match self.dtype.as_str() {
80 "BOOL" | "U8" | "I8" | "F8_E4M3" | "F8_E5M2" => Some(1),
81 "U16" | "I16" | "F16" | "BF16" => Some(2),
82 "U32" | "I32" | "F32" => Some(4),
83 "U64" | "I64" | "F64" => Some(8),
84 _ => None,
85 }
86 }
87}
88
/// Parsed header of a single `.safetensors` file.
#[derive(Debug, Clone, Serialize)]
pub struct SafetensorsHeaderInfo {
    /// Tensor entries from the header, ordered by start offset.
    pub tensors: Vec<TensorInfo>,
    /// Contents of the header's `__metadata__` object, if it was present and
    /// an object. String values are kept verbatim; any other JSON value is
    /// stored as its JSON text.
    pub metadata: Option<HashMap<String, String>>,
    /// Size in bytes of the JSON header that follows the 8-byte length prefix.
    pub header_size: u64,
    /// Total file size in bytes, when known. Always set when the header was
    /// read from a local file; for remote inspection it comes from the
    /// `Content-Range` response header and may be absent.
    pub file_size: Option<u64>,
}
110
111impl SafetensorsHeaderInfo {
112 #[must_use]
114 pub fn total_params(&self) -> u64 {
115 self.tensors
116 .iter()
117 .map(TensorInfo::num_elements)
118 .fold(0u64, u64::saturating_add)
119 }
120
121 #[must_use]
123 pub fn tensors_with_dtype(&self, dtype: &str) -> Vec<&TensorInfo> {
124 self.tensors
125 .iter()
126 .filter(|t| t.dtype.as_str() == dtype)
128 .collect()
129 }
130}
131
/// Where a header returned by [`inspect_safetensors`] was read from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum InspectSource {
    /// Read from a file already present in the local Hugging Face cache.
    Cached,
    /// Fetched over HTTP using Range requests for just the header bytes.
    Remote,
}
141
/// Contents of a sharded checkpoint's `model.safetensors.index.json`.
#[derive(Debug, Clone, Serialize)]
pub struct ShardedIndex {
    /// The index's `weight_map`: tensor name -> shard filename.
    pub weight_map: HashMap<String, String>,
    /// Distinct shard filenames referenced by `weight_map`, sorted.
    pub shards: Vec<String>,
    /// The index's optional `metadata` object, kept as raw JSON values.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
152
/// Selected fields of a PEFT `adapter_config.json`.
///
/// Every field except `target_modules` is `None` when absent from the file.
#[derive(Debug, Clone, Serialize)]
pub struct AdapterConfig {
    /// Value of `peft_type`.
    pub peft_type: Option<String>,
    /// Value of `base_model_name_or_path`.
    pub base_model_name_or_path: Option<String>,
    /// Value of `r`.
    pub r: Option<u32>,
    /// Value of `lora_alpha`.
    pub lora_alpha: Option<f64>,
    /// Value of `target_modules`, normalised to a list: a single string becomes
    /// a one-element list, and a missing or null value becomes an empty list.
    pub target_modules: Vec<String>,
    /// Value of `task_type`.
    pub task_type: Option<String>,
}
173
/// Deserialization target for one non-`__metadata__` entry of a safetensors
/// header; converted into a [`TensorInfo`] by `parse_header_json`.
#[derive(serde::Deserialize)]
struct RawTensorEntry {
    dtype: String,
    shape: Vec<usize>,
    data_offsets: (u64, u64),
}
185
/// Result of `parse_header_json`: the tensor entries plus the optional
/// stringified `__metadata__` map.
type ParsedHeader = (Vec<TensorInfo>, Option<HashMap<String, String>>);
188
189fn parse_header_json(json_bytes: &[u8], filename: &str) -> Result<ParsedHeader, FetchError> {
193 let raw: HashMap<String, serde_json::Value> =
194 serde_json::from_slice(json_bytes).map_err(|e| FetchError::SafetensorsHeader {
195 filename: filename.to_owned(),
196 reason: format!("failed to parse header JSON: {e}"),
197 })?;
198
199 let mut metadata: Option<HashMap<String, String>> = None;
200 let mut tensors = Vec::new();
201
202 for (key, value) in raw {
203 if key == "__metadata__" {
204 if let serde_json::Value::Object(obj) = value {
205 let mut meta_map = HashMap::new();
206 for (mk, mv) in obj {
207 let v_str = if let Some(s) = mv.as_str() {
209 s.to_owned()
210 } else {
211 mv.to_string()
212 };
213 meta_map.insert(mk, v_str);
214 }
215 metadata = Some(meta_map);
216 }
217 continue;
218 }
219
220 let entry: RawTensorEntry =
221 serde_json::from_value(value).map_err(|e| FetchError::SafetensorsHeader {
222 filename: filename.to_owned(),
223 reason: format!("failed to parse tensor \"{key}\": {e}"),
224 })?;
225
226 tensors.push(TensorInfo {
227 name: key,
228 dtype: entry.dtype,
229 shape: entry.shape,
230 data_offsets: entry.data_offsets,
231 });
232 }
233
234 tensors.sort_by_key(|t| t.data_offsets.0);
236
237 Ok((tensors, metadata))
238}
239
240fn resolve_cached_path(repo_id: &str, revision: &str, filename: &str) -> Option<PathBuf> {
248 let cache_dir = cache::hf_cache_dir().ok()?;
249 let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
250 let commit_hash = cache::read_ref(&repo_dir, revision)?;
251 let cached_path = cache_layout::pointer_path(&repo_dir, &commit_hash, filename);
252 if cached_path.exists() {
253 Some(cached_path)
254 } else {
255 None
256 }
257}
258
259pub fn inspect_safetensors_local(path: &Path) -> Result<SafetensorsHeaderInfo, FetchError> {
279 use std::io::Read;
280
281 let file_size = std::fs::metadata(path)
282 .map_err(|e| FetchError::Io {
283 path: path.to_path_buf(),
284 source: e,
285 })?
286 .len();
287
288 let filename = path.file_name().map_or_else(
290 || path.display().to_string(),
291 |n| n.to_string_lossy().to_string(),
292 );
293
294 let mut file = std::fs::File::open(path).map_err(|e| FetchError::Io {
295 path: path.to_path_buf(),
296 source: e,
297 })?;
298
299 let mut len_buf = [0u8; 8];
301 file.read_exact(&mut len_buf).map_err(|e| FetchError::Io {
302 path: path.to_path_buf(),
303 source: e,
304 })?;
305 let header_size = u64::from_le_bytes(len_buf);
306
307 if header_size.saturating_add(8) > file_size {
309 return Err(FetchError::SafetensorsHeader {
310 filename,
311 reason: format!("header length {header_size} exceeds file size {file_size}"),
312 });
313 }
314
315 #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
318 let json_len = header_size as usize;
319 let mut json_buf = vec![0u8; json_len];
320 file.read_exact(&mut json_buf).map_err(|e| FetchError::Io {
321 path: path.to_path_buf(),
322 source: e,
323 })?;
324
325 let (tensors, metadata) = parse_header_json(&json_buf, filename.as_str())?;
327
328 Ok(SafetensorsHeaderInfo {
329 tensors,
330 metadata,
331 header_size,
332 file_size: Some(file_size),
333 })
334}
335
336async fn fetch_header_bytes(
348 client: &reqwest::Client,
349 url: &str,
350 filename: &str,
351) -> Result<(Vec<u8>, Option<u64>), FetchError> {
352 let resp1 = client
354 .get(url)
355 .header(reqwest::header::RANGE, "bytes=0-7")
356 .send()
357 .await
358 .map_err(|e| {
359 FetchError::Http(format!("failed to fetch header length for {filename}: {e}"))
360 })?;
361
362 if !resp1.status().is_success() && resp1.status() != reqwest::StatusCode::PARTIAL_CONTENT {
363 return Err(FetchError::Http(format!(
364 "Range request for {filename} returned status {}",
365 resp1.status()
366 )));
367 }
368
369 let file_size = resp1
371 .headers()
372 .get(reqwest::header::CONTENT_RANGE)
373 .and_then(|v| v.to_str().ok())
374 .and_then(|s| s.split('/').next_back())
375 .and_then(|s| s.parse::<u64>().ok());
376
377 let len_bytes = resp1.bytes().await.map_err(|e| {
378 FetchError::Http(format!("failed to read header length for {filename}: {e}"))
379 })?;
380
381 if len_bytes.len() < 8 {
382 return Err(FetchError::SafetensorsHeader {
383 filename: filename.to_owned(),
384 reason: format!(
385 "expected 8 bytes for length prefix, got {}",
386 len_bytes.len()
387 ),
388 });
389 }
390
391 #[allow(clippy::indexing_slicing)]
393 let header_size = u64::from_le_bytes([
394 len_bytes[0],
395 len_bytes[1],
396 len_bytes[2],
397 len_bytes[3],
398 len_bytes[4],
399 len_bytes[5],
400 len_bytes[6],
401 len_bytes[7],
402 ]);
403
404 let range_end = 8u64.saturating_add(header_size).saturating_sub(1);
406 let range_header = format!("bytes=8-{range_end}");
407 let resp2 = client
408 .get(url)
409 .header(reqwest::header::RANGE, range_header.as_str())
411 .send()
412 .await
413 .map_err(|e| {
414 FetchError::Http(format!("failed to fetch header JSON for {filename}: {e}"))
415 })?;
416
417 if !resp2.status().is_success() && resp2.status() != reqwest::StatusCode::PARTIAL_CONTENT {
418 return Err(FetchError::Http(format!(
419 "Range request for {filename} header JSON returned status {}",
420 resp2.status()
421 )));
422 }
423
424 let json_bytes = resp2
425 .bytes()
426 .await
427 .map_err(|e| FetchError::Http(format!("failed to read header JSON for {filename}: {e}")))?;
428
429 Ok((json_bytes.to_vec(), file_size))
430}
431
432pub async fn inspect_safetensors(
448 repo_id: &str,
449 filename: &str,
450 token: Option<&str>,
451 revision: Option<&str>,
452) -> Result<(SafetensorsHeaderInfo, InspectSource), FetchError> {
453 let rev = revision.unwrap_or("main");
454
455 if let Some(cached_path) = resolve_cached_path(repo_id, rev, filename) {
457 let info = inspect_safetensors_local(&cached_path)?;
458 return Ok((info, InspectSource::Cached));
459 }
460
461 let client = chunked::build_client(token)?;
463 let url = chunked::build_download_url(repo_id, rev, filename);
464
465 let (json_bytes, file_size) = fetch_header_bytes(&client, url.as_str(), filename).await?;
467
468 #[allow(clippy::as_conversions)]
470 let header_size = json_bytes.len() as u64;
471
472 let (tensors, metadata) = parse_header_json(&json_bytes, filename)?;
473
474 Ok((
475 SafetensorsHeaderInfo {
476 tensors,
477 metadata,
478 header_size,
479 file_size,
480 },
481 InspectSource::Remote,
482 ))
483}
484
485pub fn inspect_safetensors_cached(
501 repo_id: &str,
502 filename: &str,
503 revision: Option<&str>,
504) -> Result<SafetensorsHeaderInfo, FetchError> {
505 let rev = revision.unwrap_or("main");
506
507 let cached_path = resolve_cached_path(repo_id, rev, filename).ok_or_else(|| {
508 FetchError::SafetensorsHeader {
509 filename: filename.to_owned(),
510 reason: format!("file not found in local cache for {repo_id} ({rev})"),
511 }
512 })?;
513
514 inspect_safetensors_local(&cached_path)
515}
516
517pub async fn inspect_repo_safetensors(
535 repo_id: &str,
536 token: Option<&str>,
537 revision: Option<&str>,
538) -> Result<Vec<(String, SafetensorsHeaderInfo, InspectSource)>, FetchError> {
539 let client = crate::chunked::build_client(token)?;
540 let files =
541 crate::repo::list_repo_files_with_metadata(repo_id, token, revision, &client).await?;
542
543 let safetensors_files: Vec<String> = files
544 .into_iter()
545 .filter(|f| f.filename.ends_with(".safetensors"))
546 .map(|f| f.filename)
547 .collect();
548
549 if safetensors_files.is_empty() {
550 return Ok(Vec::new());
551 }
552
553 let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(4));
554 let mut join_set = JoinSet::new();
555
556 for filename in safetensors_files {
557 let sem = semaphore.clone();
559 let repo = repo_id.to_owned();
560 let tok = token.map(str::to_owned);
561 let rev = revision.map(str::to_owned);
562
563 join_set.spawn(async move {
564 let _permit = sem
565 .acquire()
566 .await
567 .map_err(|e| FetchError::Http(format!("semaphore error: {e}")))?;
568 let (info, source) =
570 inspect_safetensors(&repo, &filename, tok.as_deref(), rev.as_deref()).await?;
571 Ok::<_, FetchError>((filename, info, source))
572 });
573 }
574
575 let mut results = Vec::new();
576 while let Some(join_result) = join_set.join_next().await {
577 match join_result {
578 Ok(Ok(item)) => results.push(item),
579 Ok(Err(e)) => {
580 join_set.abort_all();
581 return Err(e);
582 }
583 Err(e) => {
584 join_set.abort_all();
585 return Err(FetchError::Http(format!("task join error: {e}")));
586 }
587 }
588 }
589
590 results.sort_by(|a, b| a.0.cmp(&b.0));
591
592 Ok(results)
593}
594
/// Result of [`list_cached_safetensors`]: `(relative filename, size in bytes)`
/// pairs for the cached `.safetensors` files, plus the commit hash the
/// revision resolved to (`None` when the ref could not be resolved).
pub type SafetensorsListing = (Vec<(String, u64)>, Option<String>);
603
604pub fn list_cached_safetensors(
624 repo_id: &str,
625 revision: Option<&str>,
626) -> Result<SafetensorsListing, FetchError> {
627 let rev = revision.unwrap_or("main");
628 let cache_dir = cache::hf_cache_dir()?;
629 let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
630
631 let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
632 return Ok((Vec::new(), None));
633 };
634
635 let snapshot_dir = cache_layout::snapshot_dir(&repo_dir, &commit_hash);
636 if !snapshot_dir.exists() {
637 return Ok((Vec::new(), Some(commit_hash)));
638 }
639
640 let mut results = Vec::new();
641 collect_safetensors_names_sizes(&snapshot_dir, "", &mut results)?;
642 results.sort_by(|a, b| a.0.cmp(&b.0));
643 Ok((results, Some(commit_hash)))
644}
645
646fn collect_safetensors_names_sizes(
648 dir: &Path,
649 prefix: &str,
650 results: &mut Vec<(String, u64)>,
651) -> Result<(), FetchError> {
652 let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
653 path: dir.to_path_buf(),
654 source: e,
655 })?;
656
657 for entry in entries {
658 let Ok(entry) = entry else { continue };
659 let path = entry.path();
660 let name = entry.file_name().to_string_lossy().to_string();
662
663 if path.is_dir() {
664 let child_prefix = if prefix.is_empty() {
665 name
666 } else {
667 format!("{prefix}/{name}")
668 };
669 collect_safetensors_names_sizes(&path, &child_prefix, results)?;
670 } else if name.ends_with(".safetensors") {
671 let filename = if prefix.is_empty() {
672 name
673 } else {
674 format!("{prefix}/{name}")
675 };
676 let size = entry.metadata().map_or(0, |m| m.len());
677 results.push((filename, size));
678 }
679 }
680
681 Ok(())
682}
683
684pub fn inspect_repo_safetensors_cached(
700 repo_id: &str,
701 revision: Option<&str>,
702) -> Result<Vec<(String, SafetensorsHeaderInfo)>, FetchError> {
703 let rev = revision.unwrap_or("main");
704 let cache_dir = cache::hf_cache_dir()?;
705 let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
706
707 let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
708 return Ok(Vec::new());
709 };
710
711 let snapshot_dir = cache_layout::snapshot_dir(&repo_dir, &commit_hash);
712 if !snapshot_dir.exists() {
713 return Ok(Vec::new());
714 }
715
716 let mut results = Vec::new();
717 collect_safetensors_recursive(&snapshot_dir, "", &mut results)?;
718 results.sort_by(|a, b| a.0.cmp(&b.0));
719
720 Ok(results)
721}
722
723fn collect_safetensors_recursive(
725 dir: &Path,
726 prefix: &str,
727 results: &mut Vec<(String, SafetensorsHeaderInfo)>,
728) -> Result<(), FetchError> {
729 let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
730 path: dir.to_path_buf(),
731 source: e,
732 })?;
733
734 for entry in entries {
735 let Ok(entry) = entry else { continue };
736 let path = entry.path();
737 let name = entry.file_name().to_string_lossy().to_string();
739
740 if path.is_dir() {
741 let child_prefix = if prefix.is_empty() {
742 name
743 } else {
744 format!("{prefix}/{name}")
745 };
746 collect_safetensors_recursive(&path, &child_prefix, results)?;
747 } else if name.ends_with(".safetensors") {
748 let filename = if prefix.is_empty() {
749 name
750 } else {
751 format!("{prefix}/{name}")
752 };
753 let info = inspect_safetensors_local(&path)?;
754 results.push((filename, info));
755 }
756 }
757
758 Ok(())
759}
760
/// Deserialization target for `model.safetensors.index.json`.
///
/// `weight_map` is required; `metadata` may be absent. Other keys are ignored
/// (serde's default behaviour without `deny_unknown_fields`).
#[derive(serde::Deserialize)]
struct RawShardIndex {
    weight_map: HashMap<String, String>,
    #[serde(default)]
    metadata: Option<HashMap<String, serde_json::Value>>,
}
772
773pub async fn fetch_shard_index(
783 repo_id: &str,
784 token: Option<&str>,
785 revision: Option<&str>,
786) -> Result<Option<ShardedIndex>, FetchError> {
787 let rev = revision.unwrap_or("main");
788 let index_filename = "model.safetensors.index.json";
789
790 if let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) {
792 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
793 path: cached_path,
794 source: e,
795 })?;
796 let index = parse_shard_index_json(&content, repo_id)?;
797 return Ok(Some(index));
798 }
799
800 let client = chunked::build_client(token)?;
802 let url = chunked::build_download_url(repo_id, rev, index_filename);
803
804 let response =
806 client.get(url.as_str()).send().await.map_err(|e| {
807 FetchError::Http(format!("failed to fetch shard index for {repo_id}: {e}"))
808 })?;
809
810 if response.status() == reqwest::StatusCode::NOT_FOUND {
811 return Ok(None);
812 }
813
814 if !response.status().is_success() {
815 return Err(FetchError::Http(format!(
816 "shard index request for {repo_id} returned status {}",
817 response.status()
818 )));
819 }
820
821 let content = response
822 .text()
823 .await
824 .map_err(|e| FetchError::Http(format!("failed to read shard index for {repo_id}: {e}")))?;
825
826 let index = parse_shard_index_json(&content, repo_id)?;
827 Ok(Some(index))
828}
829
830pub fn fetch_shard_index_cached(
839 repo_id: &str,
840 revision: Option<&str>,
841) -> Result<Option<ShardedIndex>, FetchError> {
842 let rev = revision.unwrap_or("main");
843 let index_filename = "model.safetensors.index.json";
844
845 let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) else {
846 return Ok(None);
847 };
848
849 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
850 path: cached_path,
851 source: e,
852 })?;
853
854 let index = parse_shard_index_json(&content, repo_id)?;
855 Ok(Some(index))
856}
857
858fn parse_shard_index_json(content: &str, repo_id: &str) -> Result<ShardedIndex, FetchError> {
860 let raw: RawShardIndex =
861 serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
862 filename: "model.safetensors.index.json".to_owned(),
863 reason: format!("failed to parse shard index for {repo_id}: {e}"),
864 })?;
865
866 let mut shard_set: Vec<String> = raw.weight_map.values().cloned().collect();
868 shard_set.sort();
869 shard_set.dedup();
870
871 Ok(ShardedIndex {
872 weight_map: raw.weight_map,
873 shards: shard_set,
874 metadata: raw.metadata,
875 })
876}
877
/// Formats a parameter count for display with a `K`/`M`/`B` suffix.
///
/// - Counts below 1 000 are printed verbatim.
/// - `K` and `M` use one decimal place; `B` uses two.
/// - The unit is chosen so the rounded number never reads as `1000`. For
///   example, 999 999 becomes `"1.0M"` rather than `"1000.0K"`.
#[must_use]
pub fn format_params(count: u64) -> String {
    // Smallest counts that would round up to "1000.0" in the next-smaller
    // unit at one decimal place (999.95K / 999.95M); promote them instead.
    const BILLION_CUTOFF: u64 = 999_950_000;
    const MILLION_CUTOFF: u64 = 999_950;

    // Precision loss above 2^53 is irrelevant at two decimal places.
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let val = count as f64;

    if count >= BILLION_CUTOFF {
        format!("{:.2}B", val / 1_000_000_000.0)
    } else if count >= MILLION_CUTOFF {
        format!("{:.1}M", val / 1_000_000.0)
    } else if count >= 1_000 {
        format!("{:.1}K", val / 1_000.0)
    } else {
        count.to_string()
    }
}
899
/// Deserialization target for `adapter_config.json`.
///
/// Every field is optional, and keys not listed here are ignored (serde's
/// default without `deny_unknown_fields`). Converted into [`AdapterConfig`] by
/// `parse_adapter_config_json`.
#[derive(serde::Deserialize)]
struct RawAdapterConfig {
    #[serde(default)]
    peft_type: Option<String>,
    #[serde(default)]
    base_model_name_or_path: Option<String>,
    #[serde(default)]
    r: Option<u32>,
    #[serde(default)]
    lora_alpha: Option<f64>,
    #[serde(default)]
    target_modules: Option<AdapterTargetModules>,
    #[serde(default)]
    task_type: Option<String>,
}
920
/// `target_modules` may be written either as a list of module names or as a
/// single string. With `#[serde(untagged)]`, serde tries `List` first, then
/// `Single`.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum AdapterTargetModules {
    /// A JSON array of module names.
    List(Vec<String>),
    /// A single JSON string.
    Single(String),
}
930
931pub async fn fetch_adapter_config(
941 repo_id: &str,
942 token: Option<&str>,
943 revision: Option<&str>,
944) -> Result<Option<AdapterConfig>, FetchError> {
945 let rev = revision.unwrap_or("main");
946 let config_filename = "adapter_config.json";
947
948 if let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) {
950 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
951 path: cached_path,
952 source: e,
953 })?;
954 let config = parse_adapter_config_json(&content, repo_id)?;
955 return Ok(Some(config));
956 }
957
958 let client = chunked::build_client(token)?;
960 let url = chunked::build_download_url(repo_id, rev, config_filename);
961
962 let response = client.get(url.as_str()).send().await.map_err(|e| {
964 FetchError::Http(format!("failed to fetch adapter config for {repo_id}: {e}"))
965 })?;
966
967 if response.status() == reqwest::StatusCode::NOT_FOUND {
968 return Ok(None);
969 }
970
971 if !response.status().is_success() {
972 return Err(FetchError::Http(format!(
973 "adapter config request for {repo_id} returned status {}",
974 response.status()
975 )));
976 }
977
978 let content = response.text().await.map_err(|e| {
979 FetchError::Http(format!("failed to read adapter config for {repo_id}: {e}"))
980 })?;
981
982 let config = parse_adapter_config_json(&content, repo_id)?;
983 Ok(Some(config))
984}
985
986pub fn fetch_adapter_config_cached(
995 repo_id: &str,
996 revision: Option<&str>,
997) -> Result<Option<AdapterConfig>, FetchError> {
998 let rev = revision.unwrap_or("main");
999 let config_filename = "adapter_config.json";
1000
1001 let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) else {
1002 return Ok(None);
1003 };
1004
1005 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
1006 path: cached_path,
1007 source: e,
1008 })?;
1009
1010 let config = parse_adapter_config_json(&content, repo_id)?;
1011 Ok(Some(config))
1012}
1013
1014fn parse_adapter_config_json(content: &str, repo_id: &str) -> Result<AdapterConfig, FetchError> {
1016 let raw: RawAdapterConfig =
1017 serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
1018 filename: "adapter_config.json".to_owned(),
1019 reason: format!("failed to parse adapter config for {repo_id}: {e}"),
1020 })?;
1021
1022 let target_modules = match raw.target_modules {
1023 Some(AdapterTargetModules::List(v)) => v,
1024 Some(AdapterTargetModules::Single(s)) => vec![s],
1025 None => Vec::new(),
1026 };
1027
1028 Ok(AdapterConfig {
1029 peft_type: raw.peft_type,
1030 base_model_name_or_path: raw.base_model_name_or_path,
1031 r: raw.r,
1032 lora_alpha: raw.lora_alpha,
1033 target_modules,
1034 task_type: raw.task_type,
1035 })
1036}