use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io;
use std::io::BufReader;
use std::path::{Path, PathBuf};

use candle_core::safetensors::MmapedSafetensors;
use candle_nn::var_builder::SimpleBackend;
use serde::Deserialize;
use snafu::{ensure, OptionExt, ResultExt, Snafu};

use crate::error::BoxedError;
use crate::repository::repo::Repo;
/// File name of the SafeTensors index that lists the shard files of a sharded
/// checkpoint (parsed into `SafeTensorsIndex`).
const SAFETENSORS_INDEX: &str = "model.safetensors.index.json";
/// File name of an unsharded, single-file SafeTensors checkpoint.
const SAFETENSORS_SINGLE: &str = "model.safetensors";
/// Loads a Hugging Face checkpoint as a candle [`SimpleBackend`].
pub trait LoadHFCheckpoint {
    /// Locate and load the checkpoint, returning it as a boxed [`SimpleBackend`].
    ///
    /// # Errors
    ///
    /// Returns an error if the checkpoint files cannot be located, fetched,
    /// or loaded.
    fn load_hf_checkpoint(&self) -> Result<Box<dyn SimpleBackend>, BoxedError>;
}
/// Errors that can occur while locating, fetching, or loading a Hugging Face
/// checkpoint.
#[derive(Debug, Snafu)]
pub enum HFCheckpointError {
    /// Fetching the file `name` from the repository failed.
    #[snafu(display("Cannot download checkpoint: {name}"))]
    Download { source: BoxedError, name: String },

    /// candle failed to open or load the checkpoint files.
    #[snafu(display("Cannot open or load checkpoint"))]
    LoadCheckpoint { source: candle_core::Error },

    /// The checkpoint file `name` is not present in the repository.
    #[snafu(display("Checkpoint does not exist: {name}"))]
    NonExistentCheckpoint { name: String },

    /// A shard file `name` is not present in the repository.
    #[snafu(display("Shard does not exist: {name}"))]
    NonExistentShard { name: String },

    /// The SafeTensors index file at `path` could not be opened.
    #[snafu(display("Cannot open SafeTensors index file: {}", path.to_string_lossy()))]
    OpenSafeTensorsIndex { source: io::Error, path: PathBuf },

    /// The SafeTensors index file at `path` could not be deserialized from JSON.
    #[snafu(display("Cannot parse SafeTensors index file: {}", path.to_string_lossy()))]
    ParseSafeTensorsIndex {
        source: serde_json::Error,
        path: PathBuf,
    },
}
impl<R> LoadHFCheckpoint for R
where
R: Repo,
{
fn load_hf_checkpoint(&self) -> Result<Box<dyn SimpleBackend>, BoxedError> {
self.load_safetensors()
}
}
/// SafeTensors-specific loading steps used by the [`LoadHFCheckpoint`] implementation.
trait LoadHFSafeTensors {
    /// Resolve the checkpoint's SafeTensors files and memory-map them into a backend.
    ///
    /// If the repository has the index file, the sharded files it lists are used;
    /// otherwise the single-file checkpoint is used.
    fn load_safetensors(&self) -> Result<Box<dyn SimpleBackend>, BoxedError>;
    /// Parse the SafeTensors index at `index_path` and return the paths of all
    /// shard files it references.
    fn load_safetensors_multi(&self, index_path: &Path) -> Result<Vec<PathBuf>, BoxedError>;
    /// Return the path of the single-file checkpoint, wrapped in a one-element `Vec`.
    fn load_safetensors_single(&self) -> Result<Vec<PathBuf>, BoxedError>;
}
impl<R> LoadHFSafeTensors for R
where
R: Repo,
{
fn load_safetensors(&self) -> Result<Box<dyn SimpleBackend>, BoxedError> {
let file = self.file(SAFETENSORS_INDEX).context(DownloadSnafu {
name: SAFETENSORS_INDEX,
})?;
let paths = match file {
Some(index_path) => self.load_safetensors_multi(&index_path),
None => self.load_safetensors_single(),
}?;
Ok(Box::new(unsafe {
MmapedSafetensors::multi(&paths).context(LoadCheckpointSnafu)?
}))
}
fn load_safetensors_multi(&self, index_path: &Path) -> Result<Vec<PathBuf>, BoxedError> {
let index_file = BufReader::new(
File::open(index_path).context(OpenSafeTensorsIndexSnafu { path: index_path })?,
);
let index: SafeTensorsIndex = serde_json::from_reader(index_file)
.context(ParseSafeTensorsIndexSnafu { path: index_path })?;
let shard_names = index.shards();
let mut shards = Vec::with_capacity(shard_names.len());
for shard_name in shard_names {
let path = self.file(&shard_name).context(DownloadSnafu {
name: shard_name.clone(),
})?;
ensure!(path.is_some(), NonExistentShardSnafu { name: shard_name });
shards.push(path.unwrap());
}
Ok(shards)
}
fn load_safetensors_single(&self) -> Result<Vec<PathBuf>, BoxedError> {
let path = self.file(SAFETENSORS_SINGLE).context(DownloadSnafu {
name: SAFETENSORS_SINGLE,
})?;
ensure!(
path.is_some(),
NonExistentCheckpointSnafu {
name: SAFETENSORS_SINGLE.to_string(),
}
);
Ok(vec![path.unwrap()])
}
}
/// Deserialized contents of a SafeTensors index file (`model.safetensors.index.json`).
#[derive(Debug, Deserialize)]
struct SafeTensorsIndex {
    /// Maps each key to the name of the shard file that contains it.
    /// NOTE(review): keys are presumably tensor names — confirm against the format spec.
    weight_map: HashMap<String, String>,
}

impl SafeTensorsIndex {
    /// Returns the distinct shard file names referenced by `weight_map`.
    fn shards(&self) -> HashSet<String> {
        let mut unique_shards = HashSet::new();
        for shard_file in self.weight_map.values() {
            unique_shards.insert(shard_file.to_owned());
        }
        unique_shards
    }
}