use std::collections::HashMap;
use std::path::{Path, PathBuf};
use serde::Serialize;
use crate::cache;
use crate::chunked;
use crate::error::FetchError;
/// A single tensor entry parsed from a safetensors header.
#[derive(Debug, Clone, Serialize)]
pub struct TensorInfo {
/// Tensor name — the key under which the entry appears in the header JSON.
pub name: String,
/// Dtype string exactly as it appears in the header (e.g. "F16", "BF16").
pub dtype: String,
/// Dimension sizes of the tensor.
pub shape: Vec<usize>,
/// Byte offset span as (start, end); `byte_len` computes `end - start`.
pub data_offsets: (u64, u64),
}
impl TensorInfo {
    /// Number of scalar elements: the product of all dimensions.
    ///
    /// Uses saturating multiplication so a pathological shape cannot
    /// overflow `u64`.
    #[must_use]
    pub fn num_elements(&self) -> u64 {
        let mut total = 1u64;
        for &dim in &self.shape {
            #[allow(clippy::as_conversions)]
            let d = dim as u64;
            total = total.saturating_mul(d);
        }
        total
    }

    /// Size of the tensor payload in bytes, derived from the offset span
    /// (saturating, so a malformed `end < start` yields 0 rather than wrap).
    #[must_use]
    pub const fn byte_len(&self) -> u64 {
        self.data_offsets.1.saturating_sub(self.data_offsets.0)
    }

    /// Width in bytes of one element for recognized dtype strings;
    /// `None` for anything unknown.
    #[must_use]
    pub fn dtype_bytes(&self) -> Option<usize> {
        let width = match self.dtype.as_str() {
            "BOOL" | "U8" | "I8" | "F8_E4M3" | "F8_E5M2" => 1,
            "U16" | "I16" | "F16" | "BF16" => 2,
            "U32" | "I32" | "F32" => 4,
            "U64" | "I64" | "F64" => 8,
            _ => return None,
        };
        Some(width)
    }
}
/// The fully parsed header of one safetensors file.
#[derive(Debug, Clone, Serialize)]
pub struct SafetensorsHeaderInfo {
/// Tensor entries, sorted by data start offset by the parser.
pub tensors: Vec<TensorInfo>,
/// Optional `__metadata__` object from the header, stringified values.
pub metadata: Option<HashMap<String, String>>,
/// Length in bytes of the JSON header (from the 8-byte length prefix,
/// or the fetched byte count for remote inspection).
pub header_size: u64,
/// Total file size when known (always known for local files; remote
/// inspection relies on the server's Content-Range header).
pub file_size: Option<u64>,
}
impl SafetensorsHeaderInfo {
#[must_use]
pub fn total_params(&self) -> u64 {
self.tensors
.iter()
.map(TensorInfo::num_elements)
.fold(0u64, u64::saturating_add)
}
#[must_use]
pub fn tensors_with_dtype(&self, dtype: &str) -> Vec<&TensorInfo> {
self.tensors
.iter()
.filter(|t| t.dtype.as_str() == dtype)
.collect()
}
}
/// Where a safetensors header was read from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InspectSource {
/// Read from a file already present in the local HF cache.
Cached,
/// Fetched over HTTP with range requests.
Remote,
}
/// Parsed `model.safetensors.index.json` for a sharded checkpoint.
#[derive(Debug, Clone, Serialize)]
pub struct ShardedIndex {
/// Maps each tensor name to the shard file that contains it.
pub weight_map: HashMap<String, String>,
/// Sorted, deduplicated list of shard filenames from `weight_map`.
pub shards: Vec<String>,
/// Optional free-form metadata object from the index file.
pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// Fields extracted from a PEFT `adapter_config.json`; all optional in the
/// source file, so absent values stay `None` (or empty for the module list).
#[derive(Debug, Clone, Serialize)]
pub struct AdapterConfig {
/// PEFT adapter type string, if present.
pub peft_type: Option<String>,
/// Base model repo id or path the adapter was trained against.
pub base_model_name_or_path: Option<String>,
/// Adapter rank, if present.
pub r: Option<u32>,
/// LoRA alpha scaling value, if present.
pub lora_alpha: Option<f64>,
/// Target module names; empty when the config omits the field.
pub target_modules: Vec<String>,
/// Task type string, if present.
pub task_type: Option<String>,
}
/// Deserialization target for one tensor entry in the header JSON.
/// Mirrors `TensorInfo` minus the name (the name is the JSON key).
#[derive(serde::Deserialize)]
struct RawTensorEntry {
dtype: String,
shape: Vec<usize>,
data_offsets: (u64, u64),
}
/// A parsed safetensors header: tensor entries plus optional `__metadata__`.
type ParsedHeader = (Vec<TensorInfo>, Option<HashMap<String, String>>);

/// Parses the JSON portion of a safetensors header.
///
/// The header is a JSON object mapping tensor names to
/// `{dtype, shape, data_offsets}` entries, with an optional `__metadata__`
/// object whose values are kept as strings (non-string values are
/// serialized). Tensors are returned sorted by start offset. `filename` is
/// used only to label error messages.
///
/// # Errors
///
/// Returns [`FetchError::SafetensorsHeader`] when the JSON is malformed or a
/// tensor entry does not match the expected shape.
fn parse_header_json(json_bytes: &[u8], filename: &str) -> Result<ParsedHeader, FetchError> {
    let raw: HashMap<String, serde_json::Value> =
        serde_json::from_slice(json_bytes).map_err(|e| FetchError::SafetensorsHeader {
            filename: filename.to_owned(),
            reason: format!("failed to parse header JSON: {e}"),
        })?;
    let mut metadata: Option<HashMap<String, String>> = None;
    let mut tensors = Vec::with_capacity(raw.len());
    // Consume the map so keys and entry values are moved, not deep-cloned
    // per tensor (the original cloned every `serde_json::Value`).
    for (key, value) in raw {
        if key == "__metadata__" {
            if let serde_json::Value::Object(obj) = value {
                let meta_map = obj
                    .into_iter()
                    .map(|(mk, mv)| {
                        // Keep string values verbatim; serialize anything else.
                        let v_str = match mv {
                            serde_json::Value::String(s) => s,
                            other => other.to_string(),
                        };
                        (mk, v_str)
                    })
                    .collect();
                metadata = Some(meta_map);
            }
            continue;
        }
        let entry: RawTensorEntry =
            serde_json::from_value(value).map_err(|e| FetchError::SafetensorsHeader {
                filename: filename.to_owned(),
                reason: format!("failed to parse tensor \"{key}\": {e}"),
            })?;
        tensors.push(TensorInfo {
            name: key,
            dtype: entry.dtype,
            shape: entry.shape,
            data_offsets: entry.data_offsets,
        });
    }
    tensors.sort_by_key(|t| t.data_offsets.0);
    Ok((tensors, metadata))
}
/// Locates `filename` inside the local HF cache for `repo_id` at `revision`.
///
/// Returns `None` when the cache dir cannot be determined, the revision ref
/// is absent, or the snapshot file does not exist on disk.
fn resolve_cached_path(repo_id: &str, revision: &str, filename: &str) -> Option<PathBuf> {
    let repo_dir = cache::hf_cache_dir()
        .ok()?
        .join(chunked::repo_folder_name(repo_id));
    let commit_hash = cache::read_ref(&repo_dir, revision)?;
    let candidate = repo_dir.join("snapshots").join(commit_hash).join(filename);
    candidate.exists().then_some(candidate)
}
pub fn inspect_safetensors_local(path: &Path) -> Result<SafetensorsHeaderInfo, FetchError> {
use std::io::Read;
let file_size = std::fs::metadata(path)
.map_err(|e| FetchError::Io {
path: path.to_path_buf(),
source: e,
})?
.len();
let filename = path.file_name().map_or_else(
|| path.display().to_string(),
|n| n.to_string_lossy().to_string(),
);
let mut file = std::fs::File::open(path).map_err(|e| FetchError::Io {
path: path.to_path_buf(),
source: e,
})?;
let mut len_buf = [0u8; 8];
file.read_exact(&mut len_buf).map_err(|e| FetchError::Io {
path: path.to_path_buf(),
source: e,
})?;
let header_size = u64::from_le_bytes(len_buf);
if header_size.saturating_add(8) > file_size {
return Err(FetchError::SafetensorsHeader {
filename,
reason: format!("header length {header_size} exceeds file size {file_size}"),
});
}
#[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
let json_len = header_size as usize;
let mut json_buf = vec![0u8; json_len];
file.read_exact(&mut json_buf).map_err(|e| FetchError::Io {
path: path.to_path_buf(),
source: e,
})?;
let (tensors, metadata) = parse_header_json(&json_buf, filename.as_str())?;
Ok(SafetensorsHeaderInfo {
tensors,
metadata,
header_size,
file_size: Some(file_size),
})
}
/// Downloads the 8-byte length prefix and then the JSON header of a remote
/// safetensors file using two HTTP range requests.
///
/// Returns the raw header JSON bytes together with the total file size when
/// the server reports it via `Content-Range`.
///
/// # Errors
///
/// Returns [`FetchError::Http`] on transport failures or non-success
/// statuses, and [`FetchError::SafetensorsHeader`] when fewer than 8 prefix
/// bytes come back.
async fn fetch_header_bytes(
    client: &reqwest::Client,
    url: &str,
    filename: &str,
) -> Result<(Vec<u8>, Option<u64>), FetchError> {
    // First range request: just the little-endian u64 length prefix.
    let prefix_resp = client
        .get(url)
        .header(reqwest::header::RANGE, "bytes=0-7")
        .send()
        .await
        .map_err(|e| {
            // Bug fix: error messages previously said "(unknown)" even
            // though `filename` is available — name the file instead.
            FetchError::Http(format!("failed to fetch header length for {filename}: {e}"))
        })?;
    if !prefix_resp.status().is_success()
        && prefix_resp.status() != reqwest::StatusCode::PARTIAL_CONTENT
    {
        return Err(FetchError::Http(format!(
            "Range request for {filename} returned status {}",
            prefix_resp.status()
        )));
    }
    // "Content-Range: bytes 0-7/12345" — the part after '/' is the full size.
    let file_size = prefix_resp
        .headers()
        .get(reqwest::header::CONTENT_RANGE)
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.split('/').next_back())
        .and_then(|s| s.parse::<u64>().ok());
    let len_bytes = prefix_resp.bytes().await.map_err(|e| {
        FetchError::Http(format!("failed to read header length for {filename}: {e}"))
    })?;
    // Panic-free slice -> array conversion replaces eight manual indexes.
    let prefix: [u8; 8] = len_bytes
        .get(..8)
        .and_then(|s| s.try_into().ok())
        .ok_or_else(|| FetchError::SafetensorsHeader {
            filename: filename.to_owned(),
            reason: format!(
                "expected 8 bytes for length prefix, got {}",
                len_bytes.len()
            ),
        })?;
    let header_size = u64::from_le_bytes(prefix);
    // Second range request: the JSON header immediately after the prefix.
    let range_end = 8u64.saturating_add(header_size).saturating_sub(1);
    let range_header = format!("bytes=8-{range_end}");
    let json_resp = client
        .get(url)
        .header(reqwest::header::RANGE, range_header.as_str())
        .send()
        .await
        .map_err(|e| {
            FetchError::Http(format!("failed to fetch header JSON for {filename}: {e}"))
        })?;
    if !json_resp.status().is_success()
        && json_resp.status() != reqwest::StatusCode::PARTIAL_CONTENT
    {
        return Err(FetchError::Http(format!(
            "Range request for {filename} header JSON returned status {}",
            json_resp.status()
        )));
    }
    let json_bytes = json_resp.bytes().await.map_err(|e| {
        FetchError::Http(format!("failed to read header JSON for {filename}: {e}"))
    })?;
    Ok((json_bytes.to_vec(), file_size))
}
/// Inspects one safetensors file, preferring the local HF cache and falling
/// back to remote range requests. Defaults `revision` to "main".
///
/// # Errors
///
/// Propagates I/O, HTTP, and header-parsing failures from the helpers.
pub async fn inspect_safetensors(
    repo_id: &str,
    filename: &str,
    token: Option<&str>,
    revision: Option<&str>,
) -> Result<(SafetensorsHeaderInfo, InspectSource), FetchError> {
    let rev = revision.unwrap_or("main");
    // A cached copy wins: no network round-trips needed.
    if let Some(local) = resolve_cached_path(repo_id, rev, filename) {
        return inspect_safetensors_local(&local).map(|info| (info, InspectSource::Cached));
    }
    let client = chunked::build_client(token)?;
    let url = chunked::build_download_url(repo_id, rev, filename);
    let (json_bytes, file_size) = fetch_header_bytes(&client, url.as_str(), filename).await?;
    let (tensors, metadata) = parse_header_json(&json_bytes, filename)?;
    #[allow(clippy::as_conversions)]
    let header_size = json_bytes.len() as u64;
    let info = SafetensorsHeaderInfo {
        tensors,
        metadata,
        header_size,
        file_size,
    };
    Ok((info, InspectSource::Remote))
}
/// Inspects a safetensors file only if it is already in the local HF cache.
///
/// # Errors
///
/// Returns [`FetchError::SafetensorsHeader`] when the file is not cached,
/// plus whatever `inspect_safetensors_local` reports.
pub fn inspect_safetensors_cached(
    repo_id: &str,
    filename: &str,
    revision: Option<&str>,
) -> Result<SafetensorsHeaderInfo, FetchError> {
    let rev = revision.unwrap_or("main");
    match resolve_cached_path(repo_id, rev, filename) {
        Some(path) => inspect_safetensors_local(&path),
        None => Err(FetchError::SafetensorsHeader {
            filename: filename.to_owned(),
            reason: format!("file not found in local cache for {repo_id} ({rev})"),
        }),
    }
}
pub async fn inspect_repo_safetensors(
repo_id: &str,
token: Option<&str>,
revision: Option<&str>,
) -> Result<Vec<(String, SafetensorsHeaderInfo, InspectSource)>, FetchError> {
let files = crate::repo::list_repo_files_with_metadata(repo_id, token, revision).await?;
let safetensors_files: Vec<String> = files
.into_iter()
.filter(|f| f.filename.ends_with(".safetensors"))
.map(|f| f.filename)
.collect();
if safetensors_files.is_empty() {
return Ok(Vec::new());
}
let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(4));
let mut handles = Vec::new();
for filename in safetensors_files {
let sem = semaphore.clone();
let repo = repo_id.to_owned();
let tok = token.map(str::to_owned);
let rev = revision.map(str::to_owned);
handles.push(tokio::spawn(async move {
let _permit = sem
.acquire()
.await
.map_err(|e| FetchError::Http(format!("semaphore error: {e}")))?;
let (info, source) =
inspect_safetensors(&repo, &filename, tok.as_deref(), rev.as_deref()).await?;
Ok::<_, FetchError>((filename, info, source))
}));
}
let mut results = Vec::new();
for handle in handles {
let result = handle
.await
.map_err(|e| FetchError::Http(format!("task join error: {e}")))?;
results.push(result?);
}
results.sort_by(|a, b| a.0.cmp(&b.0));
Ok(results)
}
/// Walks the cached snapshot of a repo and inspects every `.safetensors`
/// file found, returning results sorted by filename.
///
/// Returns an empty `Vec` when the revision ref or snapshot dir is missing.
///
/// # Errors
///
/// Fails on cache-dir resolution, directory reads, or header parsing.
pub fn inspect_repo_safetensors_cached(
    repo_id: &str,
    revision: Option<&str>,
) -> Result<Vec<(String, SafetensorsHeaderInfo)>, FetchError> {
    let rev = revision.unwrap_or("main");
    let repo_dir = cache::hf_cache_dir()?.join(chunked::repo_folder_name(repo_id));
    let Some(commit_hash) = cache::read_ref(&repo_dir, rev) else {
        return Ok(Vec::new());
    };
    let snapshot_dir = repo_dir.join("snapshots").join(commit_hash);
    if !snapshot_dir.exists() {
        return Ok(Vec::new());
    }
    let mut found = Vec::new();
    collect_safetensors_recursive(&snapshot_dir, "", &mut found)?;
    found.sort_by(|left, right| left.0.cmp(&right.0));
    Ok(found)
}
/// Recursively gathers `.safetensors` files under `dir` into `results`,
/// naming each with its `/`-joined path relative to the walk root.
fn collect_safetensors_recursive(
    dir: &Path,
    prefix: &str,
    results: &mut Vec<(String, SafetensorsHeaderInfo)>,
) -> Result<(), FetchError> {
    let entries = std::fs::read_dir(dir).map_err(|e| FetchError::Io {
        path: dir.to_path_buf(),
        source: e,
    })?;
    // Unreadable individual entries are skipped rather than aborting.
    for entry in entries.flatten() {
        let path = entry.path();
        let name = entry.file_name().to_string_lossy().to_string();
        let qualified = if prefix.is_empty() {
            name.clone()
        } else {
            format!("{prefix}/{name}")
        };
        if path.is_dir() {
            collect_safetensors_recursive(&path, &qualified, results)?;
        } else if name.ends_with(".safetensors") {
            results.push((qualified, inspect_safetensors_local(&path)?));
        }
    }
    Ok(())
}
/// Deserialization target for `model.safetensors.index.json`;
/// `metadata` defaults to `None` when the key is absent.
#[derive(serde::Deserialize)]
struct RawShardIndex {
weight_map: HashMap<String, String>,
#[serde(default)]
metadata: Option<HashMap<String, serde_json::Value>>,
}
/// Fetches `model.safetensors.index.json` for a sharded model, checking
/// the local cache before the network. `Ok(None)` means the repo has no
/// shard index (HTTP 404).
///
/// # Errors
///
/// Returns [`FetchError::Io`]/[`FetchError::Http`] on read or transport
/// failures, and [`FetchError::SafetensorsHeader`] when the index JSON is
/// malformed.
pub async fn fetch_shard_index(
    repo_id: &str,
    token: Option<&str>,
    revision: Option<&str>,
) -> Result<Option<ShardedIndex>, FetchError> {
    let rev = revision.unwrap_or("main");
    let index_filename = "model.safetensors.index.json";
    if let Some(local) = resolve_cached_path(repo_id, rev, index_filename) {
        let body = std::fs::read_to_string(&local).map_err(|e| FetchError::Io {
            path: local,
            source: e,
        })?;
        return parse_shard_index_json(&body, repo_id).map(Some);
    }
    let client = chunked::build_client(token)?;
    let url = chunked::build_download_url(repo_id, rev, index_filename);
    let response =
        client.get(url.as_str()).send().await.map_err(|e| {
            FetchError::Http(format!("failed to fetch shard index for {repo_id}: {e}"))
        })?;
    let status = response.status();
    if status == reqwest::StatusCode::NOT_FOUND {
        return Ok(None);
    }
    if !status.is_success() {
        return Err(FetchError::Http(format!(
            "shard index request for {repo_id} returned status {status}"
        )));
    }
    let body = response
        .text()
        .await
        .map_err(|e| FetchError::Http(format!("failed to read shard index for {repo_id}: {e}")))?;
    parse_shard_index_json(&body, repo_id).map(Some)
}
/// Reads the shard index from the local cache only; `Ok(None)` when it is
/// not cached.
///
/// # Errors
///
/// Returns [`FetchError::Io`] on read failures and
/// [`FetchError::SafetensorsHeader`] on malformed JSON.
pub fn fetch_shard_index_cached(
    repo_id: &str,
    revision: Option<&str>,
) -> Result<Option<ShardedIndex>, FetchError> {
    let rev = revision.unwrap_or("main");
    let index_filename = "model.safetensors.index.json";
    resolve_cached_path(repo_id, rev, index_filename)
        .map(|cached_path| {
            let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
                path: cached_path,
                source: e,
            })?;
            parse_shard_index_json(&content, repo_id)
        })
        .transpose()
}
/// Parses shard-index JSON and derives the sorted, deduplicated shard list
/// from the values of `weight_map`.
fn parse_shard_index_json(content: &str, repo_id: &str) -> Result<ShardedIndex, FetchError> {
    let raw: RawShardIndex =
        serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
            filename: "model.safetensors.index.json".to_owned(),
            reason: format!("failed to parse shard index for {repo_id}: {e}"),
        })?;
    // A BTreeSet gives sorted + unique in one pass.
    let shards: Vec<String> = raw
        .weight_map
        .values()
        .cloned()
        .collect::<std::collections::BTreeSet<_>>()
        .into_iter()
        .collect();
    Ok(ShardedIndex {
        weight_map: raw.weight_map,
        shards,
        metadata: raw.metadata,
    })
}
/// Formats a parameter count with a human-readable suffix:
/// billions get two decimals ("7.24B"), millions and thousands one
/// ("1.3M", "1.2K"), and anything below 1000 is printed verbatim.
#[must_use]
pub fn format_params(count: u64) -> String {
    const BILLION: u64 = 1_000_000_000;
    const MILLION: u64 = 1_000_000;
    const THOUSAND: u64 = 1_000;
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let val = count as f64;
    match count {
        c if c >= BILLION => format!("{:.2}B", val / 1_000_000_000.0),
        c if c >= MILLION => format!("{:.1}M", val / 1_000_000.0),
        c if c >= THOUSAND => format!("{:.1}K", val / 1_000.0),
        _ => count.to_string(),
    }
}
/// Deserialization target for `adapter_config.json`; every field defaults
/// so partial configs still parse.
#[derive(serde::Deserialize)]
struct RawAdapterConfig {
#[serde(default)]
peft_type: Option<String>,
#[serde(default)]
base_model_name_or_path: Option<String>,
#[serde(default)]
r: Option<u32>,
#[serde(default)]
lora_alpha: Option<f64>,
// May be a list or a single string in the wild; see AdapterTargetModules.
#[serde(default)]
target_modules: Option<AdapterTargetModules>,
#[serde(default)]
task_type: Option<String>,
}
/// `target_modules` appears either as a list of module names or as a single
/// string; `untagged` lets serde accept both representations.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum AdapterTargetModules {
List(Vec<String>),
Single(String),
}
/// Fetches a repo's `adapter_config.json`, preferring the local cache.
/// `Ok(None)` means the repo has no adapter config (HTTP 404).
///
/// # Errors
///
/// Returns [`FetchError::Io`]/[`FetchError::Http`] on read or transport
/// failures and [`FetchError::SafetensorsHeader`] on malformed JSON.
pub async fn fetch_adapter_config(
    repo_id: &str,
    token: Option<&str>,
    revision: Option<&str>,
) -> Result<Option<AdapterConfig>, FetchError> {
    let rev = revision.unwrap_or("main");
    let config_filename = "adapter_config.json";
    if let Some(local) = resolve_cached_path(repo_id, rev, config_filename) {
        let body = std::fs::read_to_string(&local).map_err(|e| FetchError::Io {
            path: local,
            source: e,
        })?;
        return parse_adapter_config_json(&body, repo_id).map(Some);
    }
    let client = chunked::build_client(token)?;
    let url = chunked::build_download_url(repo_id, rev, config_filename);
    let response = client.get(url.as_str()).send().await.map_err(|e| {
        FetchError::Http(format!("failed to fetch adapter config for {repo_id}: {e}"))
    })?;
    let status = response.status();
    if status == reqwest::StatusCode::NOT_FOUND {
        return Ok(None);
    }
    if !status.is_success() {
        return Err(FetchError::Http(format!(
            "adapter config request for {repo_id} returned status {status}"
        )));
    }
    let body = response.text().await.map_err(|e| {
        FetchError::Http(format!("failed to read adapter config for {repo_id}: {e}"))
    })?;
    parse_adapter_config_json(&body, repo_id).map(Some)
}
/// Reads the adapter config from the local cache only; `Ok(None)` when it
/// is not cached.
///
/// # Errors
///
/// Returns [`FetchError::Io`] on read failures and
/// [`FetchError::SafetensorsHeader`] on malformed JSON.
pub fn fetch_adapter_config_cached(
    repo_id: &str,
    revision: Option<&str>,
) -> Result<Option<AdapterConfig>, FetchError> {
    let rev = revision.unwrap_or("main");
    match resolve_cached_path(repo_id, rev, "adapter_config.json") {
        None => Ok(None),
        Some(cached_path) => {
            let content = std::fs::read_to_string(&cached_path).map_err(|e| FetchError::Io {
                path: cached_path,
                source: e,
            })?;
            parse_adapter_config_json(&content, repo_id).map(Some)
        }
    }
}
/// Parses adapter-config JSON into the public `AdapterConfig` shape,
/// normalizing the one-or-many `target_modules` field into a plain list
/// (empty when absent).
fn parse_adapter_config_json(content: &str, repo_id: &str) -> Result<AdapterConfig, FetchError> {
    let raw: RawAdapterConfig =
        serde_json::from_str(content).map_err(|e| FetchError::SafetensorsHeader {
            filename: "adapter_config.json".to_owned(),
            reason: format!("failed to parse adapter config for {repo_id}: {e}"),
        })?;
    let target_modules = raw.target_modules.map_or_else(Vec::new, |tm| match tm {
        AdapterTargetModules::List(list) => list,
        AdapterTargetModules::Single(one) => vec![one],
    });
    Ok(AdapterConfig {
        peft_type: raw.peft_type,
        base_model_name_or_path: raw.base_model_name_or_path,
        r: raw.r,
        lora_alpha: raw.lora_alpha,
        target_modules,
        task_type: raw.task_type,
    })
}