/// Deserialized contents of a sharded-model index JSON file.
///
/// Only the `weight_map` field is read here; serde's default behavior of
/// ignoring unknown fields applies to anything else in the file.
#[derive(Debug, Deserialize)]
struct SafetensorsIndex {
/// Maps each tensor name to the filename of the shard that stores it.
/// Filenames are resolved relative to the directory containing the index file.
weight_map: HashMap<String, String>,
}
/// A model whose tensors are spread across several shard files, addressed
/// through a single index JSON file.
///
/// Only compiled for non-`wasm32` targets.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
pub struct ShardedSafeTensorsModel {
/// Loaded shards, one per distinct shard filename referenced by the index.
shards: Vec<MappedSafeTensorsModel>,
/// Maps a tensor name to the position in `shards` of the shard holding it.
tensor_to_shard: HashMap<String, usize>,
/// Directory containing the index file; shard filenames are joined onto it.
base_path: std::path::PathBuf,
}
#[cfg(not(target_arch = "wasm32"))]
impl ShardedSafeTensorsModel {
    /// Loads every shard referenced by a sharded-model index JSON file.
    ///
    /// The file at `index_path` is parsed as an index whose `weight_map` maps
    /// tensor names to shard filenames. Each distinct filename is joined onto
    /// the directory containing `index_path` and loaded with
    /// `MappedSafeTensorsModel::load`.
    ///
    /// Shards are loaded in sorted filename order, so the shard layout and the
    /// first error reported for a broken model are the same on every run
    /// (iterating the `HashMap` directly would yield an arbitrary order).
    ///
    /// # Errors
    ///
    /// * `RealizarError::IoError` if `index_path` has no parent directory or
    ///   the index file cannot be read.
    /// * `RealizarError::FormatError` if the index file is not valid JSON of
    ///   the expected shape.
    /// * Any error returned by `MappedSafeTensorsModel::load` for a shard.
    pub fn load_from_index(index_path: &std::path::Path) -> Result<Self> {
        let base_path = index_path
            .parent()
            .ok_or_else(|| RealizarError::IoError {
                message: format!(
                    "Cannot determine parent directory of '{}'",
                    index_path.display()
                ),
            })?
            .to_path_buf();

        let index_content =
            std::fs::read_to_string(index_path).map_err(|e| RealizarError::IoError {
                message: format!(
                    "Failed to read index file '{}': {}",
                    index_path.display(),
                    e
                ),
            })?;

        let index: SafetensorsIndex =
            serde_json::from_str(&index_content).map_err(|e| RealizarError::FormatError {
                reason: format!("Failed to parse index.json: {}", e),
            })?;

        // Distinct shard filenames, borrowed from the index and sorted so that
        // shard indices do not depend on HashMap iteration order.
        let mut shard_filenames: Vec<&str> =
            index.weight_map.values().map(String::as_str).collect();
        shard_filenames.sort_unstable();
        shard_filenames.dedup();

        let mut shards = Vec::with_capacity(shard_filenames.len());
        for filename in &shard_filenames {
            let shard_path = base_path.join(filename);
            shards.push(MappedSafeTensorsModel::load(&shard_path)?);
        }

        // Position of each filename in `shards` (both follow `shard_filenames`).
        let filename_to_idx: HashMap<&str, usize> = shard_filenames
            .iter()
            .enumerate()
            .map(|(idx, &filename)| (filename, idx))
            .collect();

        // Indexing cannot fail: every value of `weight_map` was inserted into
        // `filename_to_idx` above.
        let tensor_to_shard: HashMap<String, usize> = index
            .weight_map
            .iter()
            .map(|(tensor_name, shard_file)| {
                (tensor_name.clone(), filename_to_idx[shard_file.as_str()])
            })
            .collect();

        Ok(Self {
            shards,
            tensor_to_shard,
            base_path,
        })
    }

    /// Returns the data of tensor `name` as `f32` values by delegating to the
    /// `get_tensor_auto` method of the shard that holds it.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::UnsupportedOperation` if `name` is not listed in
    /// the index, or whatever error the owning shard reports.
    pub fn get_tensor_auto(&self, name: &str) -> Result<Vec<f32>> {
        let shard_idx = self.tensor_to_shard.get(name).copied().ok_or_else(|| {
            RealizarError::UnsupportedOperation {
                operation: "get_tensor_auto".to_string(),
                reason: format!("Tensor '{}' not found in sharded model", name),
            }
        })?;
        self.shards[shard_idx].get_tensor_auto(name)
    }

    /// Returns the names of all tensors listed in the index, in arbitrary order.
    #[must_use]
    pub fn tensor_names(&self) -> Vec<&str> {
        self.tensor_to_shard.keys().map(String::as_str).collect()
    }

    /// Looks up the metadata of tensor `name` in the shard that holds it.
    ///
    /// Returns `None` if the index does not list `name`, or if the owning
    /// shard has no info for it.
    #[must_use]
    pub fn get_tensor_info(&self, name: &str) -> Option<&SafetensorsTensorInfo> {
        let shard_idx = *self.tensor_to_shard.get(name)?;
        self.shards[shard_idx].get_tensor_info(name)
    }

    /// Returns `true` if the index lists a tensor called `name`.
    #[must_use]
    pub fn has_tensor(&self, name: &str) -> bool {
        self.tensor_to_shard.contains_key(name)
    }

    /// Returns the directory that contains the index file and, relative to
    /// which, the shard files were resolved.
    #[must_use]
    pub fn path(&self) -> &std::path::Path {
        &self.base_path
    }

    /// Returns the number of tensors listed in the index.
    #[must_use]
    pub fn tensor_count(&self) -> usize {
        self.tensor_to_shard.len()
    }

    /// Returns the number of distinct shard files that were loaded.
    #[must_use]
    pub fn shard_count(&self) -> usize {
        self.shards.len()
    }
}
pub mod validation;
pub use validation::{
enforce_embedding_validation,
enforce_weight_validation,
validate_embedding,
validate_weight,
ContractValidationError,
TensorStats,
ValidatedAprTransformer,
ValidatedEmbedding,
ValidatedVector,
ValidatedWeight,
ValidationResult,
};
#[cfg(test)]
mod tests;
#[cfg(test)]
#[path = "tests_find_sibling.rs"]
mod safetensors_tests_part_02;