/// Tuning knobs for importing a sharded model checkpoint.
#[derive(Debug, Clone)]
pub struct ShardedImportConfig {
    /// Maximum number of shards held in the cache at once; must be > 0
    /// (enforced by `validate`).
    pub max_cached_shards: usize,
    /// Byte budget for the shard cache (passed to `ShardCache::new`).
    pub max_cache_bytes: usize,
    /// When true, `stream_merge` processes tensors in the order returned by
    /// `ShardIndex::tensor_names()`; otherwise in raw `weight_map` key order.
    pub sort_tensors: bool,
    /// Whether to verify shard checksums during import.
    /// NOTE(review): not consulted anywhere in this module — confirm the
    /// consumer before relying on it.
    pub verify_checksums: bool,
    /// I/O buffer size in bytes; must be > 0 (enforced by `validate`).
    pub buffer_size: usize,
}
impl Default for ShardedImportConfig {
    /// Platform-aware defaults: wasm32 targets get a 256 MiB cache budget,
    /// native targets 4 GiB. All other fields share one set of defaults.
    fn default() -> Self {
        #[cfg(target_arch = "wasm32")]
        const DEFAULT_CACHE_BYTES: usize = 256 * 1024 * 1024;
        #[cfg(not(target_arch = "wasm32"))]
        const DEFAULT_CACHE_BYTES: usize = 4_usize * 1024 * 1024 * 1024;

        Self {
            max_cached_shards: 2,
            max_cache_bytes: DEFAULT_CACHE_BYTES,
            sort_tensors: true,
            verify_checksums: true,
            buffer_size: 8 * 1024 * 1024,
        }
    }
}
impl ShardedImportConfig {
    /// Preset for memory-constrained environments: a single cached shard,
    /// a 1 GiB cache budget, and a 4 MiB I/O buffer.
    #[must_use]
    pub fn low_memory() -> Self {
        Self {
            max_cached_shards: 1,
            max_cache_bytes: 1024 * 1024 * 1024,
            buffer_size: 4 * 1024 * 1024,
            ..Self::default()
        }
    }

    /// Preset for hosts with ample RAM: four cached shards and a 16 MiB
    /// I/O buffer. The cache budget is 512 MiB on wasm32, 8 GiB elsewhere.
    #[must_use]
    pub fn high_memory() -> Self {
        #[cfg(target_arch = "wasm32")]
        let cache_budget = 512 * 1024 * 1024;
        #[cfg(not(target_arch = "wasm32"))]
        let cache_budget = 8_usize * 1024 * 1024 * 1024;

        Self {
            max_cached_shards: 4,
            max_cache_bytes: cache_budget,
            buffer_size: 16 * 1024 * 1024,
            ..Self::default()
        }
    }

    /// Rejects configurations that would make the importer inoperable.
    ///
    /// # Errors
    /// Returns `AprenderError::FormatError` when `max_cached_shards` or
    /// `buffer_size` is zero.
    pub fn validate(&self) -> Result<()> {
        // Identify the first zero-valued field, if any.
        let zero_field = if self.max_cached_shards == 0 {
            Some("max_cached_shards")
        } else if self.buffer_size == 0 {
            Some("buffer_size")
        } else {
            None
        };
        match zero_field {
            Some(field) => Err(AprenderError::FormatError {
                message: format!("{field} must be > 0"),
            }),
            None => Ok(()),
        }
    }
}
/// Summary statistics produced by a streaming merge of sharded files.
#[derive(Debug, Clone)]
pub struct ImportReport {
    /// Number of tensors processed from the index.
    pub tensor_count: usize,
    /// Number of shard files referenced by the index.
    pub shard_count: usize,
    /// Total bytes written during the merge.
    pub bytes_written: u64,
    /// Highest number of bytes held in the shard cache at any point.
    pub peak_memory_bytes: u64,
    /// Fraction of shard lookups served from the cache.
    pub cache_hit_rate: f32,
    /// Wall-clock duration of the merge, in milliseconds.
    pub duration_ms: u64,
    /// Non-fatal problems encountered (e.g. shards that failed to load).
    pub warnings: Vec<String>,
}
/// Imports sharded model checkpoints, keeping recently used shards in a
/// bounded cache to limit peak memory.
#[derive(Debug)]
pub struct ShardedImporter {
    /// Import settings; validated at the start of `stream_merge`.
    config: ShardedImportConfig,
    /// Bounded cache of loaded shards, sized from `config`.
    cache: ShardCache,
    /// Directory that shard file names from the index are resolved against.
    base_dir: PathBuf,
}
impl ShardedImporter {
    /// Builds an importer whose shard cache is sized from `config`.
    #[must_use]
    pub fn new(config: ShardedImportConfig, base_dir: PathBuf) -> Self {
        let cache = ShardCache::new(config.max_cached_shards, config.max_cache_bytes);
        Self {
            config,
            cache,
            base_dir,
        }
    }

    /// Convenience constructor using `ShardedImportConfig::default()`.
    #[must_use]
    pub fn with_defaults(base_dir: PathBuf) -> Self {
        Self::new(ShardedImportConfig::default(), base_dir)
    }

    /// Reads `index_path` as UTF-8 text and parses it as a shard index.
    ///
    /// # Errors
    /// Fails if the file cannot be read, or if `ShardIndex::from_json`
    /// rejects the contents.
    pub fn parse_index(&self, index_path: &Path) -> Result<ShardIndex> {
        let content = std::fs::read_to_string(index_path).map_err(AprenderError::Io)?;
        ShardIndex::from_json(&content)
    }

    /// Returns the cached entry for `shard_file`, inserting one on a miss.
    ///
    /// On a miss, the shard's size is taken from filesystem metadata when the
    /// file exists; a missing file still gets a (zero-size) cache entry.
    ///
    /// NOTE(review): `cache.get` is called twice on the hit path. This is the
    /// usual borrow-checker workaround for returning a cached reference, but
    /// if `ShardCache::get` updates hit/miss counters, a hit is counted
    /// twice, skewing `hit_rate()` — confirm against the `ShardCache` impl.
    ///
    /// # Errors
    /// Returns `FormatError` if the entry cannot be retrieved right after a
    /// successful probe/insert (cache inconsistency), or an I/O error from
    /// reading metadata.
    pub fn load_shard(&mut self, shard_file: &str) -> Result<&CachedShard> {
        if self.cache.get(shard_file).is_some() {
            return self
                .cache
                .get(shard_file)
                .ok_or_else(|| AprenderError::FormatError {
                    message: "Cache inconsistency".to_string(),
                });
        }
        let shard_path = self.base_dir.join(shard_file);
        let mut shard = CachedShard::new(shard_file.to_string());
        if shard_path.exists() {
            let metadata = std::fs::metadata(&shard_path).map_err(AprenderError::Io)?;
            // NOTE(review): truncating cast on 32-bit targets for files > 4 GiB.
            shard.size = metadata.len() as usize;
        }
        self.cache.insert(shard);
        // Re-fetch so the returned reference borrows from the cache.
        self.cache
            .get(shard_file)
            .ok_or_else(|| AprenderError::FormatError {
                message: "Failed to retrieve cached shard".to_string(),
            })
    }

    /// Walks every tensor in `index`, loading its shard through the cache,
    /// and reports aggregate statistics.
    ///
    /// Shard load failures are recorded as warnings rather than aborting the
    /// merge. `_output_path` is currently unused and `bytes_written` is a
    /// fixed 1024-byte placeholder per tensor — actual tensor copying is not
    /// implemented here yet.
    ///
    /// # Errors
    /// Fails only if the configuration is invalid (`validate`).
    pub fn stream_merge(
        &mut self,
        index: &ShardIndex,
        _output_path: &Path,
    ) -> Result<ImportReport> {
        self.config.validate()?;
        let start_time = std::time::Instant::now();
        let mut warnings = Vec::new();
        // Deterministic (sorted) order when configured; raw map order otherwise.
        let tensor_names = if self.config.sort_tensors {
            index.tensor_names()
        } else {
            index.weight_map.keys().map(String::as_str).collect()
        };
        let mut bytes_written = 0u64;
        let mut peak_memory = 0u64;
        for tensor_name in &tensor_names {
            if let Some(shard_file) = index.shard_for_tensor(tensor_name) {
                match self.load_shard(shard_file) {
                    Ok(_shard) => {
                        // Placeholder accounting until real tensor writes land.
                        bytes_written += 1024;
                    }
                    Err(e) => {
                        warnings.push(format!("Failed to load shard {shard_file}: {e}"));
                    }
                }
            }
            // Track the high-water mark of cache memory after each tensor.
            let current_memory = self.cache.stats().cached_bytes as u64;
            peak_memory = peak_memory.max(current_memory);
        }
        let duration_ms = start_time.elapsed().as_millis() as u64;
        Ok(ImportReport {
            tensor_count: tensor_names.len(),
            shard_count: index.shard_count(),
            bytes_written,
            peak_memory_bytes: peak_memory,
            cache_hit_rate: self.cache.hit_rate(),
            duration_ms,
            warnings,
        })
    }

    /// Current cache statistics snapshot.
    #[must_use]
    pub fn cache_stats(&self) -> CacheStats {
        self.cache.stats()
    }

    /// Evicts everything from the shard cache.
    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }

    /// Read-only access to the active configuration.
    #[must_use]
    pub fn config(&self) -> &ShardedImportConfig {
        &self.config
    }
}
impl Default for ShardedImporter {
fn default() -> Self {
Self::with_defaults(PathBuf::from("."))
}
}
/// Heuristically decides whether `dir` holds a sharded model: either an
/// explicit `model.safetensors.index.json` is present, or more than one
/// `.safetensors` file exists in the directory.
///
/// An unreadable directory is treated as "not sharded".
#[must_use]
pub fn is_sharded_model(dir: &Path) -> bool {
    // An explicit index file is definitive.
    if dir.join("model.safetensors.index.json").exists() {
        return true;
    }
    // Otherwise count `.safetensors` entries; two or more imply sharding.
    match std::fs::read_dir(dir) {
        Ok(entries) => {
            let shard_like = entries
                .flatten()
                .filter(|entry| {
                    entry
                        .path()
                        .extension()
                        .is_some_and(|ext| ext == "safetensors")
                })
                .count();
            shard_like > 1
        }
        Err(_) => false,
    }
}
/// Lists every `.safetensors` file in `dir`, sorted lexicographically by path.
///
/// # Errors
/// Returns `AprenderError::Io` if the directory cannot be read.
pub fn get_shard_files(dir: &Path) -> Result<Vec<PathBuf>> {
    let dir_iter = std::fs::read_dir(dir).map_err(AprenderError::Io)?;
    let mut shard_paths = Vec::new();
    for entry in dir_iter.flatten() {
        let path = entry.path();
        if path.extension().is_some_and(|ext| ext == "safetensors") {
            shard_paths.push(path);
        }
    }
    shard_paths.sort();
    Ok(shard_paths)
}
#[must_use]
pub fn estimate_shard_memory(file_size: u64) -> u64 {
file_size + (file_size / 100) }
#[cfg(test)]
mod tests;