use crate::{Error, Result};
use std::path::PathBuf;
pub fn download_files(repo_id: &str, files: &[&str]) -> Result<Vec<PathBuf>> {
let api = hf_hub::api::sync::Api::new()
.map_err(|e| Error::Other(anyhow::anyhow!("hf-hub api init failed: {e}")))?;
let repo = api.model(repo_id.to_string());
files
.iter()
.map(|f| {
repo.get(f)
.map_err(|e| Error::Other(anyhow::anyhow!("hf-hub get {repo_id}/{f}: {e}")))
})
.collect()
}
/// Downloads `config.json`, `tokenizer.json` and the single-file `model.safetensors`
/// weights of `repo_id`, returning their local paths in that order.
///
/// There is no fallback to a sharded checkpoint here: the call fails if the repository
/// has no `model.safetensors`.
///
/// # Errors
///
/// Propagates any failure from [`download_files`].
pub fn download_qwen2_single_shard(repo_id: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
    let wanted = ["config.json", "tokenizer.json", "model.safetensors"];
    let mut fetched = download_files(repo_id, &wanted)?.into_iter();
    // `download_files` yields exactly one path per requested file, in request order.
    match (fetched.next(), fetched.next(), fetched.next()) {
        (Some(config), Some(tokenizer), Some(weights)) => Ok((config, tokenizer, weights)),
        _ => unreachable!("download_files returns one path per requested file"),
    }
}
pub fn download_pth_sharded(
repo_id: &str,
) -> Result<(std::path::PathBuf, Vec<std::path::PathBuf>)> {
let api = hf_hub::api::sync::Api::new()
.map_err(|e| Error::Other(anyhow::anyhow!("hf-hub api init failed: {e}")))?;
let repo = api.model(repo_id.to_string());
let index_path = repo
.get("pytorch_model.bin.index.json")
.map_err(|e| Error::Other(anyhow::anyhow!("get pytorch_model.bin.index.json: {e}")))?;
let index_json: serde_json::Value = serde_json::from_reader(std::fs::File::open(&index_path)?)
.map_err(|e| Error::Other(anyhow::anyhow!("parse pth index: {e}")))?;
let weight_map = index_json
.get("weight_map")
.and_then(|v| v.as_object())
.ok_or_else(|| {
Error::Other(anyhow::anyhow!(
"pytorch_model.bin.index.json has no `weight_map`"
))
})?;
let mut shards: Vec<String> = weight_map
.values()
.filter_map(|v| v.as_str().map(String::from))
.collect();
shards.sort();
shards.dedup();
let mut paths = Vec::with_capacity(shards.len());
for s in &shards {
let p = repo
.get(s)
.map_err(|e| Error::Other(anyhow::anyhow!("get pth shard {s}: {e}")))?;
paths.push(p);
}
Ok((index_path, paths))
}
/// Weight backend over a checkpoint split across several PyTorch pickle shards.
///
/// Constructed with `MultiPthBackend::from_paths` from the shard index file and the shard
/// paths; implements `candle_nn::var_builder::SimpleBackend`.
pub struct MultiPthBackend {
    /// One opened pickle archive per shard, in the order the shard paths were given.
    shards: Vec<candle_core::pickle::PthTensors>,
    /// Tensor name -> index into `shards` of the shard that holds that tensor.
    tensor_to_shard: std::collections::HashMap<String, usize>,
}
impl MultiPthBackend {
    /// Builds a backend from a shard index JSON file and the shard files it references.
    ///
    /// Each tensor in the index's `weight_map` is routed to the entry of `shard_paths`
    /// whose file name equals the basename of the shard named for it in the index. Every
    /// shard is opened eagerly with [`candle_core::pickle::PthTensors::new`].
    ///
    /// Some tensors may name a shard that is not among `shard_paths`. Those tensors are
    /// left unmapped, and requesting them later yields `CannotFindTensor`. If two entries
    /// of `shard_paths` share a file name, the later one wins.
    ///
    /// # Errors
    ///
    /// - the index cannot be opened or is not valid JSON;
    /// - the index has no `weight_map` object;
    /// - any shard fails to open;
    /// - no tensor at all could be mapped to a provided shard. The message lists the
    ///   shard names the index expects and the file names that were provided.
    pub fn from_paths(
        index_path: impl AsRef<std::path::Path>,
        shard_paths: &[impl AsRef<std::path::Path>],
    ) -> Result<Self> {
        // serde_json does not buffer the reader itself; without `BufReader` every byte of
        // the index would be a separate read call on the file.
        let index_file = std::io::BufReader::new(std::fs::File::open(index_path.as_ref())?);
        let index_json: serde_json::Value = serde_json::from_reader(index_file)
            .map_err(|e| Error::Other(anyhow::anyhow!("parse pth index: {e}")))?;
        let weight_map = index_json
            .get("weight_map")
            .and_then(|v| v.as_object())
            .ok_or_else(|| Error::Other(anyhow::anyhow!("missing weight_map")))?;

        let shards = shard_paths
            .iter()
            .map(|p| candle_core::pickle::PthTensors::new(p.as_ref(), None).map_err(Error::Candle))
            .collect::<Result<Vec<_>>>()?;

        // Shard file name -> position in `shard_paths` (and therefore in `shards`).
        // Paths without a UTF-8 file name cannot match an index entry and are skipped.
        let basename_to_idx: std::collections::HashMap<&str, usize> = shard_paths
            .iter()
            .enumerate()
            .filter_map(|(i, p)| Some((p.as_ref().file_name()?.to_str()?, i)))
            .collect();

        let mut tensor_to_shard = std::collections::HashMap::with_capacity(weight_map.len());
        for (tensor_name, shard_name) in weight_map.iter() {
            // Non-string `weight_map` values are ignored.
            if let Some(shard_name) = shard_name.as_str() {
                // The index may name shards with a directory prefix; compare basenames only.
                let basename = std::path::Path::new(shard_name)
                    .file_name()
                    .and_then(|n| n.to_str())
                    .unwrap_or(shard_name);
                if let Some(&idx) = basename_to_idx.get(basename) {
                    tensor_to_shard.insert(tensor_name.clone(), idx);
                }
            }
        }

        if tensor_to_shard.is_empty() {
            // Only computed on the error path: report both sides so the mismatch is visible.
            let mut expected: Vec<&str> = weight_map.values().filter_map(|v| v.as_str()).collect();
            expected.sort_unstable();
            expected.dedup();
            let mut provided: Vec<&str> = basename_to_idx.keys().copied().collect();
            provided.sort_unstable();
            return Err(Error::Other(anyhow::anyhow!(
                "no tensors mapped — shard filename mismatch between index and downloaded files \
                 (index references {expected:?}, provided {provided:?})"
            )));
        }

        Ok(Self {
            shards,
            tensor_to_shard,
        })
    }
}
impl candle_nn::var_builder::SimpleBackend for MultiPthBackend {
fn get(
&self,
s: candle_core::Shape,
name: &str,
_h: candle_nn::Init,
dtype: candle_core::DType,
dev: &candle_core::Device,
) -> candle_core::Result<candle_core::Tensor> {
let idx =
self.tensor_to_shard
.get(name)
.ok_or_else(|| candle_core::Error::CannotFindTensor {
path: name.to_string(),
})?;
let tensor = match self.shards[*idx].get(name)? {
Some(t) => t,
None => {
return Err(candle_core::Error::CannotFindTensor {
path: name.to_string(),
});
}
};
let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
if tensor.shape() != &s {
return Err(candle_core::Error::UnexpectedShape {
msg: format!("shape mismatch for {name}"),
expected: s,
got: tensor.shape().clone(),
});
}
Ok(tensor)
}
fn contains_tensor(&self, name: &str) -> bool {
self.tensor_to_shard.contains_key(name)
}
}
pub fn download_qwen2(repo_id: &str) -> Result<(PathBuf, PathBuf, Vec<PathBuf>)> {
let api = hf_hub::api::sync::Api::new()
.map_err(|e| Error::Other(anyhow::anyhow!("hf-hub api init failed: {e}")))?;
let repo = api.model(repo_id.to_string());
let config_path = repo
.get("config.json")
.map_err(|e| Error::Other(anyhow::anyhow!("get config.json: {e}")))?;
let tokenizer_path = repo
.get("tokenizer.json")
.map_err(|e| Error::Other(anyhow::anyhow!("get tokenizer.json: {e}")))?;
if let Ok(p) = repo.get("model.safetensors") {
return Ok((config_path, tokenizer_path, vec![p]));
}
let index_path = repo
.get("model.safetensors.index.json")
.map_err(|e| Error::Other(anyhow::anyhow!("get index.json: {e}")))?;
let index_json: serde_json::Value = serde_json::from_reader(std::fs::File::open(&index_path)?)
.map_err(|e| Error::Other(anyhow::anyhow!("parse index.json: {e}")))?;
let weight_map = index_json
.get("weight_map")
.and_then(|v| v.as_object())
.ok_or_else(|| {
Error::Other(anyhow::anyhow!(
"model.safetensors.index.json has no `weight_map`"
))
})?;
let mut shards: Vec<String> = weight_map
.values()
.filter_map(|v| v.as_str().map(String::from))
.collect();
shards.sort();
shards.dedup();
let mut paths = Vec::with_capacity(shards.len());
for s in &shards {
let p = repo
.get(s)
.map_err(|e| Error::Other(anyhow::anyhow!("get shard {s}: {e}")))?;
paths.push(p);
}
Ok((config_path, tokenizer_path, paths))
}