1use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15
16use serde::Serialize;
17use tokio::task::JoinSet;
18
19use crate::cache;
20use crate::cache_layout;
21use crate::chunked;
22use crate::error::FetchError;
23
/// Description of a single tensor entry from a model file header.
///
/// Built either from a safetensors JSON header (`parse_header_json`) or from
/// the tensor table of a GGUF file (`inspect_gguf_cached`).
#[derive(Debug, Clone, Serialize)]
pub struct TensorInfo {
    /// Tensor name as it appears in the file header.
    pub name: String,
    /// Element type as a string (e.g. `"F32"`, `"BF16"` for safetensors);
    /// for GGUF files this is the `to_string()` form of the GGUF dtype.
    pub dtype: String,
    /// Tensor dimensions; an empty shape counts as one element in `num_elements`.
    pub shape: Vec<usize>,
    /// `(start, end)` byte offsets of the tensor data; `byte_len` is `end - start`.
    /// NOTE(review): for safetensors these are presumably relative to the data
    /// section that follows the header — confirm before using as file offsets.
    pub data_offsets: (u64, u64),
}
43
44impl TensorInfo {
45 #[must_use]
49 pub fn num_elements(&self) -> u64 {
50 self.shape.iter().fold(1u64, |acc, &d| {
51 #[allow(clippy::as_conversions)]
53 let dim = d as u64;
54 acc.saturating_mul(dim)
55 })
56 }
57
58 #[must_use]
60 pub const fn byte_len(&self) -> u64 {
61 self.data_offsets.1.saturating_sub(self.data_offsets.0)
62 }
63
64 #[must_use]
77 pub fn dtype_bytes(&self) -> Option<usize> {
78 match self.dtype.as_str() {
80 "BOOL" | "U8" | "I8" | "F8_E4M3" | "F8_E5M2" => Some(1),
81 "U16" | "I16" | "F16" | "BF16" => Some(2),
82 "U32" | "I32" | "F32" => Some(4),
83 "U64" | "I64" | "F64" => Some(8),
84 _ => None,
85 }
86 }
87}
88
/// Parsed header information for one model file.
///
/// Produced for safetensors files (local or remote) and, via
/// `inspect_gguf_cached`, for GGUF files as well.
#[derive(Debug, Clone, Serialize)]
pub struct SafetensorsHeaderInfo {
    /// Tensor entries from the header (for safetensors, sorted by start offset).
    pub tensors: Vec<TensorInfo>,
    /// String metadata: the safetensors `__metadata__` object, or for GGUF the
    /// scalar metadata values plus `gguf.version` and `gguf.alignment`.
    /// `None` when a safetensors header has no object-valued `__metadata__`.
    pub metadata: Option<HashMap<String, String>>,
    /// Length in bytes of the JSON header that follows the 8-byte length
    /// prefix. Always `0` for GGUF files.
    pub header_size: u64,
    /// Total file size in bytes, or `None` when it could not be determined
    /// (e.g. a remote response without a parsable `Content-Range` header).
    pub file_size: Option<u64>,
}
110
111impl SafetensorsHeaderInfo {
112 #[must_use]
114 pub fn total_params(&self) -> u64 {
115 self.tensors
116 .iter()
117 .map(TensorInfo::num_elements)
118 .fold(0u64, u64::saturating_add)
119 }
120
121 #[must_use]
123 pub fn tensors_with_dtype(&self, dtype: &str) -> Vec<&TensorInfo> {
124 self.tensors
125 .iter()
126 .filter(|t| t.dtype.as_str() == dtype)
128 .collect()
129 }
130}
131
/// Where a file's header information was obtained from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum InspectSource {
    /// Read from the file in the local Hugging Face cache.
    Cached,
    /// Fetched over HTTP with range requests.
    Remote,
}
141
/// Contents of a sharded checkpoint index (`model.safetensors.index.json`).
#[derive(Debug, Clone, Serialize)]
pub struct ShardedIndex {
    /// The index's `weight_map` object as-is (by the index format's convention,
    /// tensor name -> shard filename).
    pub weight_map: HashMap<String, String>,
    /// Distinct values of `weight_map` (the shard filenames), sorted.
    pub shards: Vec<String>,
    /// The index's optional `metadata` object, kept as raw JSON values.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
152
/// Selected fields of an adapter's `adapter_config.json` (PEFT-style config).
///
/// Every field is optional in the source file; absent values are `None`, and
/// an absent `target_modules` becomes an empty list.
#[derive(Debug, Clone, Serialize)]
pub struct AdapterConfig {
    /// The `peft_type` field.
    pub peft_type: Option<String>,
    /// The `base_model_name_or_path` field.
    pub base_model_name_or_path: Option<String>,
    /// The `r` field (the LoRA rank, by PEFT convention).
    pub r: Option<u32>,
    /// The `lora_alpha` field.
    pub lora_alpha: Option<f64>,
    /// The `target_modules` field normalized to a list: a single string becomes
    /// a one-element list.
    pub target_modules: Vec<String>,
    /// The `task_type` field.
    pub task_type: Option<String>,
}
173
/// Deserialization target for one tensor entry in a safetensors header JSON.
#[derive(serde::Deserialize)]
struct RawTensorEntry {
    dtype: String,
    shape: Vec<usize>,
    /// Deserialized from a two-element JSON array `[start, end]`.
    data_offsets: (u64, u64),
}
185
/// Result of `parse_header_json`: the tensor list and the optional `__metadata__` map.
type ParsedHeader = (Vec<TensorInfo>, Option<HashMap<String, String>>);
188
189fn parse_header_json(json_bytes: &[u8], filename: &str) -> Result<ParsedHeader, FetchError> {
193 let raw: HashMap<String, serde_json::Value> =
194 serde_json::from_slice(json_bytes).map_err(|e| FetchError::SafetensorsHeader {
195 filename: filename.to_owned(),
196 reason: format!("failed to parse header JSON: {e}"),
197 })?;
198
199 let mut metadata: Option<HashMap<String, String>> = None;
200 let mut tensors = Vec::new();
201
202 for (key, value) in raw {
203 if key == "__metadata__" {
204 if let serde_json::Value::Object(obj) = value {
205 let mut meta_map = HashMap::new();
206 for (mk, mv) in obj {
207 let v_str = if let Some(s) = mv.as_str() {
209 s.to_owned()
210 } else {
211 mv.to_string()
212 };
213 meta_map.insert(mk, v_str);
214 }
215 metadata = Some(meta_map);
216 }
217 continue;
218 }
219
220 let entry: RawTensorEntry =
221 serde_json::from_value(value).map_err(|e| FetchError::SafetensorsHeader {
222 filename: filename.to_owned(),
223 reason: format!("failed to parse tensor \"{key}\": {e}"),
224 })?;
225
226 tensors.push(TensorInfo {
227 name: key,
228 dtype: entry.dtype,
229 shape: entry.shape,
230 data_offsets: entry.data_offsets,
231 });
232 }
233
234 tensors.sort_by_key(|t| t.data_offsets.0);
236
237 Ok((tensors, metadata))
238}
239
240fn resolve_cached_path(repo_id: &str, revision: &str, filename: &str) -> Option<PathBuf> {
248 let cache_dir = cache::hf_cache_dir().ok()?;
249 let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
250 let commit_hash = cache::read_ref(&repo_dir, revision)?;
251 let cached_path = cache_layout::pointer_path(&repo_dir, &commit_hash, filename);
252 if cached_path.exists() {
253 Some(cached_path)
254 } else {
255 None
256 }
257}
258
259pub fn inspect_safetensors_local(path: &Path) -> Result<SafetensorsHeaderInfo, FetchError> {
279 use std::io::Read;
280
281 let file_size = std::fs::metadata(path)
282 .map_err(|e| FetchError::Io {
283 path: path.to_path_buf(),
284 source: e,
285 })?
286 .len();
287
288 let filename = path.file_name().map_or_else(
290 || path.display().to_string(),
291 |n| n.to_string_lossy().to_string(),
292 );
293
294 let mut file = std::fs::File::open(path).map_err(|e| FetchError::Io {
295 path: path.to_path_buf(),
296 source: e,
297 })?;
298
299 let mut len_buf = [0u8; 8];
301 file.read_exact(&mut len_buf).map_err(|e| FetchError::Io {
302 path: path.to_path_buf(),
303 source: e,
304 })?;
305 let header_size = u64::from_le_bytes(len_buf);
306
307 if header_size.saturating_add(8) > file_size {
309 return Err(FetchError::SafetensorsHeader {
310 filename,
311 reason: format!("header length {header_size} exceeds file size {file_size}"),
312 });
313 }
314
315 #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
318 let json_len = header_size as usize;
319 let mut json_buf = vec![0u8; json_len];
320 file.read_exact(&mut json_buf).map_err(|e| FetchError::Io {
321 path: path.to_path_buf(),
322 source: e,
323 })?;
324
325 let (tensors, metadata) = parse_header_json(&json_buf, filename.as_str())?;
327
328 Ok(SafetensorsHeaderInfo {
329 tensors,
330 metadata,
331 header_size,
332 file_size: Some(file_size),
333 })
334}
335
336async fn fetch_header_bytes(
348 client: &reqwest::Client,
349 url: &str,
350 filename: &str,
351) -> Result<(Vec<u8>, Option<u64>), FetchError> {
352 let resp1 = client
354 .get(url)
355 .header(reqwest::header::RANGE, "bytes=0-7")
356 .send()
357 .await
358 .map_err(|e| {
359 FetchError::Http(format!("failed to fetch header length for {filename}: {e}"))
360 })?;
361
362 if !resp1.status().is_success() && resp1.status() != reqwest::StatusCode::PARTIAL_CONTENT {
363 return Err(FetchError::Http(format!(
364 "Range request for {filename} returned status {}",
365 resp1.status()
366 )));
367 }
368
369 let file_size = resp1
371 .headers()
372 .get(reqwest::header::CONTENT_RANGE)
373 .and_then(|v| v.to_str().ok())
374 .and_then(|s| s.split('/').next_back())
375 .and_then(|s| s.parse::<u64>().ok());
376
377 let len_bytes = resp1.bytes().await.map_err(|e| {
378 FetchError::Http(format!("failed to read header length for {filename}: {e}"))
379 })?;
380
381 if len_bytes.len() < 8 {
382 return Err(FetchError::SafetensorsHeader {
383 filename: filename.to_owned(),
384 reason: format!(
385 "expected 8 bytes for length prefix, got {}",
386 len_bytes.len()
387 ),
388 });
389 }
390
391 #[allow(clippy::indexing_slicing)]
393 let header_size = u64::from_le_bytes([
394 len_bytes[0],
395 len_bytes[1],
396 len_bytes[2],
397 len_bytes[3],
398 len_bytes[4],
399 len_bytes[5],
400 len_bytes[6],
401 len_bytes[7],
402 ]);
403
404 let range_end = 8u64.saturating_add(header_size).saturating_sub(1);
406 let range_header = format!("bytes=8-{range_end}");
407 let resp2 = client
408 .get(url)
409 .header(reqwest::header::RANGE, range_header.as_str())
411 .send()
412 .await
413 .map_err(|e| {
414 FetchError::Http(format!("failed to fetch header JSON for {filename}: {e}"))
415 })?;
416
417 if !resp2.status().is_success() && resp2.status() != reqwest::StatusCode::PARTIAL_CONTENT {
418 return Err(FetchError::Http(format!(
419 "Range request for {filename} header JSON returned status {}",
420 resp2.status()
421 )));
422 }
423
424 let json_bytes = resp2
425 .bytes()
426 .await
427 .map_err(|e| FetchError::Http(format!("failed to read header JSON for {filename}: {e}")))?;
428
429 Ok((json_bytes.to_vec(), file_size))
430}
431
432pub async fn inspect_safetensors(
448 repo_id: &str,
449 filename: &str,
450 token: Option<&str>,
451 revision: Option<&str>,
452) -> Result<(SafetensorsHeaderInfo, InspectSource), FetchError> {
453 let rev = revision.unwrap_or("main");
454
455 if let Some(cached_path) = resolve_cached_path(repo_id, rev, filename) {
457 let info = inspect_safetensors_local(&cached_path)?;
458 return Ok((info, InspectSource::Cached));
459 }
460
461 let client = chunked::build_client(token)?;
463 let url = chunked::build_download_url(repo_id, rev, filename);
464
465 let (json_bytes, file_size) = fetch_header_bytes(&client, url.as_str(), filename).await?;
467
468 #[allow(clippy::as_conversions)]
470 let header_size = json_bytes.len() as u64;
471
472 let (tensors, metadata) = parse_header_json(&json_bytes, filename)?;
473
474 Ok((
475 SafetensorsHeaderInfo {
476 tensors,
477 metadata,
478 header_size,
479 file_size,
480 },
481 InspectSource::Remote,
482 ))
483}
484
485pub fn inspect_safetensors_cached(
501 repo_id: &str,
502 filename: &str,
503 revision: Option<&str>,
504) -> Result<SafetensorsHeaderInfo, FetchError> {
505 let rev = revision.unwrap_or("main");
506
507 let cached_path = resolve_cached_path(repo_id, rev, filename).ok_or_else(|| {
508 FetchError::SafetensorsHeader {
509 filename: filename.to_owned(),
510 reason: format!("file not found in local cache for {repo_id} ({rev})"),
511 }
512 })?;
513
514 inspect_safetensors_local(&cached_path)
515}
516
517pub fn inspect_gguf_cached(
551 repo_id: &str,
552 filename: &str,
553 revision: Option<&str>,
554) -> Result<SafetensorsHeaderInfo, FetchError> {
555 let rev = revision.unwrap_or("main");
556
557 let cached_path = resolve_cached_path(repo_id, rev, filename).ok_or_else(|| {
558 FetchError::SafetensorsHeader {
559 filename: filename.to_owned(),
561 reason: format!("file not found in local cache for {repo_id} ({rev})"),
562 }
563 })?;
564
565 let file_size = std::fs::metadata(&cached_path).ok().map(|m| m.len());
566
567 let parsed =
568 anamnesis::parse_gguf(&cached_path).map_err(|e| FetchError::SafetensorsHeader {
569 filename: filename.to_owned(),
571 reason: format!("failed to parse GGUF: {e}"),
572 })?;
573
574 let tensors: Vec<TensorInfo> = parsed
575 .tensor_info()
576 .iter()
577 .map(|info| {
578 let start = info.data_offset;
579 let end = info.byte_len.map_or(start, |b| start.saturating_add(b));
580 TensorInfo {
581 name: info.name.clone(),
584 dtype: info.dtype.to_string(),
585 shape: info.shape.clone(),
586 data_offsets: (start, end),
587 }
588 })
589 .collect();
590
591 let mut metadata: HashMap<String, String> = parsed
595 .metadata()
596 .iter()
597 .filter_map(|(k, v)| stringify_gguf_scalar(v).map(|s| (k.clone(), s)))
600 .collect();
601 metadata.insert("gguf.version".to_owned(), parsed.version().to_string());
603 metadata.insert("gguf.alignment".to_owned(), parsed.alignment().to_string());
604
605 Ok(SafetensorsHeaderInfo {
606 tensors,
607 metadata: Some(metadata),
608 header_size: 0,
614 file_size,
615 })
616}
617
618#[allow(clippy::match_same_arms)]
631fn stringify_gguf_scalar(value: &anamnesis::parse::gguf::GgufMetadataValue) -> Option<String> {
632 use anamnesis::parse::gguf::GgufMetadataValue as V;
633 match value {
634 V::String(s) => Some(s.clone()),
635 V::Bool(b) => Some(b.to_string()),
636 V::U8(n) => Some(n.to_string()),
637 V::I8(n) => Some(n.to_string()),
638 V::U16(n) => Some(n.to_string()),
639 V::I16(n) => Some(n.to_string()),
640 V::U32(n) => Some(n.to_string()),
641 V::I32(n) => Some(n.to_string()),
642 V::U64(n) => Some(n.to_string()),
643 V::I64(n) => Some(n.to_string()),
644 V::F32(n) => Some(format!("{n}")),
645 V::F64(n) => Some(format!("{n}")),
646 V::Array(_) => None,
647 _ => None,
648 }
649}
650
651pub async fn inspect_repo_safetensors(
669 repo_id: &str,
670 token: Option<&str>,
671 revision: Option<&str>,
672) -> Result<Vec<(String, SafetensorsHeaderInfo, InspectSource)>, FetchError> {
673 let client = crate::chunked::build_client(token)?;
674 let files =
675 crate::repo::list_repo_files_with_metadata(repo_id, token, revision, &client).await?;
676
677 let safetensors_files: Vec<String> = files
678 .into_iter()
679 .filter(|f| f.filename.ends_with(".safetensors"))
680 .map(|f| f.filename)
681 .collect();
682
683 if safetensors_files.is_empty() {
684 return Ok(Vec::new());
685 }
686
687 let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(4));
688 let mut join_set = JoinSet::new();
689
690 for filename in safetensors_files {
691 let sem = semaphore.clone();
693 let repo = repo_id.to_owned();
694 let tok = token.map(str::to_owned);
695 let rev = revision.map(str::to_owned);
696
697 join_set.spawn(async move {
698 let _permit = sem
699 .acquire()
700 .await
701 .map_err(|e| FetchError::Http(format!("semaphore error: {e}")))?;
702 let (info, source) =
704 inspect_safetensors(&repo, &filename, tok.as_deref(), rev.as_deref()).await?;
705 Ok::<_, FetchError>((filename, info, source))
706 });
707 }
708
709 let mut results = Vec::new();
710 while let Some(join_result) = join_set.join_next().await {
711 match join_result {
712 Ok(Ok(item)) => results.push(item),
713 Ok(Err(e)) => {
714 join_set.abort_all();
715 return Err(e);
716 }
717 Err(e) => {
718 join_set.abort_all();
719 return Err(FetchError::Http(format!("task join error: {e}")));
720 }
721 }
722 }
723
724 results.sort_by(|a, b| a.0.cmp(&b.0));
725
726 Ok(results)
727}
728
/// Result of [`list_cached_safetensors`]: `(relative_path, size_in_bytes)`
/// pairs sorted by path, plus the commit hash the revision resolved to
/// (`None` if it could not be resolved).
pub type SafetensorsListing = (Vec<(String, u64)>, Option<String>);
737
738pub fn list_cached_safetensors(
758 repo_id: &str,
759 revision: Option<&str>,
760) -> Result<SafetensorsListing, FetchError> {
761 let rev = revision.unwrap_or("main");
762 let cache_dir = cache::hf_cache_dir()?;
763 let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
764
765 let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
766 return Ok((Vec::new(), None));
767 };
768
769 let snapshot_dir = cache_layout::snapshot_dir(&repo_dir, &commit_hash);
770 if !snapshot_dir.exists() {
771 return Ok((Vec::new(), Some(commit_hash)));
772 }
773
774 let mut results = Vec::new();
775 collect_safetensors_names_sizes(&snapshot_dir, "", &mut results)?;
776 results.sort_by(|a, b| a.0.cmp(&b.0));
777 Ok((results, Some(commit_hash)))
778}
779
780fn collect_safetensors_names_sizes(
782 dir: &Path,
783 prefix: &str,
784 results: &mut Vec<(String, u64)>,
785) -> Result<(), FetchError> {
786 let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
787 path: dir.to_path_buf(),
788 source: e,
789 })?;
790
791 for entry in entries {
792 let Ok(entry) = entry else { continue };
793 let path = entry.path();
794 let name = entry.file_name().to_string_lossy().to_string();
796
797 if path.is_dir() {
798 let child_prefix = if prefix.is_empty() {
799 name
800 } else {
801 format!("{prefix}/{name}")
802 };
803 collect_safetensors_names_sizes(&path, &child_prefix, results)?;
804 } else if name.ends_with(".safetensors") {
805 let filename = if prefix.is_empty() {
806 name
807 } else {
808 format!("{prefix}/{name}")
809 };
810 let size = entry.metadata().map_or(0, |m| m.len());
811 results.push((filename, size));
812 }
813 }
814
815 Ok(())
816}
817
818pub fn inspect_repo_safetensors_cached(
834 repo_id: &str,
835 revision: Option<&str>,
836) -> Result<Vec<(String, SafetensorsHeaderInfo)>, FetchError> {
837 let rev = revision.unwrap_or("main");
838 let cache_dir = cache::hf_cache_dir()?;
839 let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
840
841 let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
842 return Ok(Vec::new());
843 };
844
845 let snapshot_dir = cache_layout::snapshot_dir(&repo_dir, &commit_hash);
846 if !snapshot_dir.exists() {
847 return Ok(Vec::new());
848 }
849
850 let mut results = Vec::new();
851 collect_safetensors_recursive(&snapshot_dir, "", &mut results)?;
852 results.sort_by(|a, b| a.0.cmp(&b.0));
853
854 Ok(results)
855}
856
857fn collect_safetensors_recursive(
859 dir: &Path,
860 prefix: &str,
861 results: &mut Vec<(String, SafetensorsHeaderInfo)>,
862) -> Result<(), FetchError> {
863 let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
864 path: dir.to_path_buf(),
865 source: e,
866 })?;
867
868 for entry in entries {
869 let Ok(entry) = entry else { continue };
870 let path = entry.path();
871 let name = entry.file_name().to_string_lossy().to_string();
873
874 if path.is_dir() {
875 let child_prefix = if prefix.is_empty() {
876 name
877 } else {
878 format!("{prefix}/{name}")
879 };
880 collect_safetensors_recursive(&path, &child_prefix, results)?;
881 } else if name.ends_with(".safetensors") {
882 let filename = if prefix.is_empty() {
883 name
884 } else {
885 format!("{prefix}/{name}")
886 };
887 let info = inspect_safetensors_local(&path)?;
888 results.push((filename, info));
889 }
890 }
891
892 Ok(())
893}
894
/// Deserialization target for `model.safetensors.index.json`.
///
/// `weight_map` is required; `metadata` is `None` when the key is absent.
#[derive(serde::Deserialize)]
struct RawShardIndex {
    weight_map: HashMap<String, String>,
    #[serde(default)]
    metadata: Option<HashMap<String, serde_json::Value>>,
}
906
907pub async fn fetch_shard_index(
917 repo_id: &str,
918 token: Option<&str>,
919 revision: Option<&str>,
920) -> Result<Option<ShardedIndex>, FetchError> {
921 let rev = revision.unwrap_or("main");
922 let index_filename = "model.safetensors.index.json";
923
924 if let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) {
926 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
927 path: cached_path,
928 source: e,
929 })?;
930 let index = parse_shard_index_json(&content, repo_id)?;
931 return Ok(Some(index));
932 }
933
934 let client = chunked::build_client(token)?;
936 let url = chunked::build_download_url(repo_id, rev, index_filename);
937
938 let response =
940 client.get(url.as_str()).send().await.map_err(|e| {
941 FetchError::Http(format!("failed to fetch shard index for {repo_id}: {e}"))
942 })?;
943
944 if response.status() == reqwest::StatusCode::NOT_FOUND {
945 return Ok(None);
946 }
947
948 if !response.status().is_success() {
949 return Err(FetchError::Http(format!(
950 "shard index request for {repo_id} returned status {}",
951 response.status()
952 )));
953 }
954
955 let content = response
956 .text()
957 .await
958 .map_err(|e| FetchError::Http(format!("failed to read shard index for {repo_id}: {e}")))?;
959
960 let index = parse_shard_index_json(&content, repo_id)?;
961 Ok(Some(index))
962}
963
964pub fn fetch_shard_index_cached(
973 repo_id: &str,
974 revision: Option<&str>,
975) -> Result<Option<ShardedIndex>, FetchError> {
976 let rev = revision.unwrap_or("main");
977 let index_filename = "model.safetensors.index.json";
978
979 let Some(cached_path) = resolve_cached_path(repo_id, rev, index_filename) else {
980 return Ok(None);
981 };
982
983 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
984 path: cached_path,
985 source: e,
986 })?;
987
988 let index = parse_shard_index_json(&content, repo_id)?;
989 Ok(Some(index))
990}
991
992fn parse_shard_index_json(content: &str, repo_id: &str) -> Result<ShardedIndex, FetchError> {
994 let raw: RawShardIndex =
995 serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
996 filename: "model.safetensors.index.json".to_owned(),
997 reason: format!("failed to parse shard index for {repo_id}: {e}"),
998 })?;
999
1000 let mut shard_set: Vec<String> = raw.weight_map.values().cloned().collect();
1002 shard_set.sort();
1003 shard_set.dedup();
1004
1005 Ok(ShardedIndex {
1006 weight_map: raw.weight_map,
1007 shards: shard_set,
1008 metadata: raw.metadata,
1009 })
1010}
1011
/// Format a parameter count for display: `"999"`, `"1.5K"`, `"12.3M"`, `"7.00B"`.
///
/// Counts below 1000 are printed exactly. Larger counts use the smallest unit
/// whose *rounded* mantissa stays below 1000. This way 999_999 prints as
/// `"1.0M"` rather than `"1000.0K"`, and 999_999_999 prints as `"1.00B"` rather
/// than `"1000.0M"`. Billions are the largest unit and use two decimals.
#[must_use]
pub fn format_params(count: u64) -> String {
    // (divisor, suffix, decimal places), smallest unit first.
    const UNITS: [(f64, &str, usize); 3] = [
        (1_000.0, "K", 1),
        (1_000_000.0, "M", 1),
        (1_000_000_000.0, "B", 2),
    ];

    if count < 1_000 {
        return count.to_string();
    }

    // Precision loss above 2^53 is irrelevant at two displayed decimals.
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let val = count as f64;

    let mut formatted = String::new();
    for &(scale, suffix, decimals) in UNITS.iter() {
        let mantissa = format!("{:.*}", decimals, val / scale);
        formatted = format!("{mantissa}{suffix}");
        // Stay in this unit only if the rounded mantissa has at most three
        // integer digits; otherwise try the next larger unit (the last unit
        // is used regardless).
        let integer_digits = mantissa.find('.').unwrap_or(mantissa.len());
        if integer_digits <= 3 {
            break;
        }
    }
    formatted
}
1033
/// Deserialization target for `adapter_config.json`.
///
/// Every field is `#[serde(default)]`, so absent keys become `None`. Keys not
/// listed here are ignored.
#[derive(serde::Deserialize)]
struct RawAdapterConfig {
    #[serde(default)]
    peft_type: Option<String>,
    #[serde(default)]
    base_model_name_or_path: Option<String>,
    #[serde(default)]
    r: Option<u32>,
    #[serde(default)]
    lora_alpha: Option<f64>,
    /// Either a list of module names or a single string; see `AdapterTargetModules`.
    #[serde(default)]
    target_modules: Option<AdapterTargetModules>,
    #[serde(default)]
    task_type: Option<String>,
}
1054
/// The `target_modules` value, which may be a JSON array of strings or a
/// single JSON string.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum AdapterTargetModules {
    /// A JSON array of module names.
    List(Vec<String>),
    /// A single JSON string.
    Single(String),
}
1064
1065pub async fn fetch_adapter_config(
1075 repo_id: &str,
1076 token: Option<&str>,
1077 revision: Option<&str>,
1078) -> Result<Option<AdapterConfig>, FetchError> {
1079 let rev = revision.unwrap_or("main");
1080 let config_filename = "adapter_config.json";
1081
1082 if let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) {
1084 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
1085 path: cached_path,
1086 source: e,
1087 })?;
1088 let config = parse_adapter_config_json(&content, repo_id)?;
1089 return Ok(Some(config));
1090 }
1091
1092 let client = chunked::build_client(token)?;
1094 let url = chunked::build_download_url(repo_id, rev, config_filename);
1095
1096 let response = client.get(url.as_str()).send().await.map_err(|e| {
1098 FetchError::Http(format!("failed to fetch adapter config for {repo_id}: {e}"))
1099 })?;
1100
1101 if response.status() == reqwest::StatusCode::NOT_FOUND {
1102 return Ok(None);
1103 }
1104
1105 if !response.status().is_success() {
1106 return Err(FetchError::Http(format!(
1107 "adapter config request for {repo_id} returned status {}",
1108 response.status()
1109 )));
1110 }
1111
1112 let content = response.text().await.map_err(|e| {
1113 FetchError::Http(format!("failed to read adapter config for {repo_id}: {e}"))
1114 })?;
1115
1116 let config = parse_adapter_config_json(&content, repo_id)?;
1117 Ok(Some(config))
1118}
1119
1120pub fn fetch_adapter_config_cached(
1129 repo_id: &str,
1130 revision: Option<&str>,
1131) -> Result<Option<AdapterConfig>, FetchError> {
1132 let rev = revision.unwrap_or("main");
1133 let config_filename = "adapter_config.json";
1134
1135 let Some(cached_path) = resolve_cached_path(repo_id, rev, config_filename) else {
1136 return Ok(None);
1137 };
1138
1139 let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
1140 path: cached_path,
1141 source: e,
1142 })?;
1143
1144 let config = parse_adapter_config_json(&content, repo_id)?;
1145 Ok(Some(config))
1146}
1147
1148fn parse_adapter_config_json(content: &str, repo_id: &str) -> Result<AdapterConfig, FetchError> {
1150 let raw: RawAdapterConfig =
1151 serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
1152 filename: "adapter_config.json".to_owned(),
1153 reason: format!("failed to parse adapter config for {repo_id}: {e}"),
1154 })?;
1155
1156 let target_modules = match raw.target_modules {
1157 Some(AdapterTargetModules::List(v)) => v,
1158 Some(AdapterTargetModules::Single(s)) => vec![s],
1159 None => Vec::new(),
1160 };
1161
1162 Ok(AdapterConfig {
1163 peft_type: raw.peft_type,
1164 base_model_name_or_path: raw.base_model_name_or_path,
1165 r: raw.r,
1166 lora_alpha: raw.lora_alpha,
1167 target_modules,
1168 task_type: raw.task_type,
1169 })
1170}