use crate::error::{AprenderError, Result};
use std::collections::{HashMap, VecDeque};
use std::path::{Path, PathBuf};
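/// Maps tensor names to the shard files that contain them, as described by
/// the `weight_map` object of a sharded-checkpoint index.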
#[derive(Debug, Clone, Default)]
pub struct ShardIndex {
weight_map: HashMap<String, String>,
shard_files: Vec<String>,
shard_indices: HashMap<String, usize>,
metadata: HashMap<String, String>,
}
impl ShardIndex {
#[must_use]
pub fn new() -> Self {
Self::default()
}
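    /// Parses a sharded-checkpoint index from its JSON manifest (e.g. a
    /// `*.safetensors.index.json` file).
    ///
    /// This is a minimal, dependency-free scanner rather than a full JSON
    /// parser: it assumes `"weight_map"` is a flat object of
    /// string-to-string entries without nested braces or escaped quotes.
    ///
    /// # Examples
    ///
    /// A sketch with illustrative names (doc-test ignored because the
    /// import path depends on the crate layout):
    ///
    /// ```ignore
    /// let json = r#"{"weight_map": {"embed.weight": "model-00001-of-00002.safetensors"}}"#;
    /// let index = ShardIndex::from_json(json).unwrap();
    /// assert_eq!(
    ///     index.shard_for_tensor("embed.weight"),
    ///     Some("model-00001-of-00002.safetensors")
    /// );
    /// ```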
    pub fn from_json(json: &str) -> Result<Self> {
        if json.is_empty() {
            return Ok(Self::default());
        }
        let mut index = Self::new();
        if let Some(weight_map_start) = json.find("\"weight_map\"") {
            let after_key = &json[weight_map_start..];
            if let Some(brace_start) = after_key.find('{') {
                let content = &after_key[brace_start + 1..];
                // The weight map is flat, so the first '}' closes it.
                if let Some(brace_end) = content.find('}') {
                    let entries = &content[..brace_end];
                    for entry in entries.split(',') {
                        // Split on the first ':' only so shard values stay
                        // intact even if a tensor name contains ':'.
                        let mut parts = entry.splitn(2, ':');
                        if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
                            let tensor_name = key.trim().trim_matches('"').to_string();
                            let shard_file = value.trim().trim_matches('"').to_string();
                            if !tensor_name.is_empty() && !shard_file.is_empty() {
                                index.add_mapping(&tensor_name, &shard_file);
                            }
                        }
                    }
                }
            }
        }
        Ok(index)
    }
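    /// Records that `tensor_name` is stored in `shard_file`, registering
    /// the shard in first-seen order if it is new.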
pub fn add_mapping(&mut self, tensor_name: &str, shard_file: &str) {
self.weight_map
.insert(tensor_name.to_string(), shard_file.to_string());
if !self.shard_indices.contains_key(shard_file) {
let idx = self.shard_files.len();
self.shard_files.push(shard_file.to_string());
self.shard_indices.insert(shard_file.to_string(), idx);
}
}
#[must_use]
pub fn shard_for_tensor(&self, tensor_name: &str) -> Option<&str> {
self.weight_map.get(tensor_name).map(String::as_str)
}
#[must_use]
pub fn shard_index(&self, shard_file: &str) -> Option<usize> {
self.shard_indices.get(shard_file).copied()
}
#[must_use]
pub fn shard_count(&self) -> usize {
self.shard_files.len()
}
#[must_use]
pub fn tensor_count(&self) -> usize {
self.weight_map.len()
}
#[must_use]
pub fn shard_files(&self) -> &[String] {
&self.shard_files
}
#[must_use]
pub fn tensor_names(&self) -> Vec<&str> {
let mut names: Vec<&str> = self.weight_map.keys().map(String::as_str).collect();
names.sort_unstable();
names
}
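    /// Groups tensor names by containing shard; each shard's tensor list
    /// is sorted so iteration order is deterministic.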
#[must_use]
pub fn tensors_by_shard(&self) -> HashMap<&str, Vec<&str>> {
let mut by_shard: HashMap<&str, Vec<&str>> = HashMap::new();
for (tensor, shard) in &self.weight_map {
by_shard
.entry(shard.as_str())
.or_default()
.push(tensor.as_str());
}
for tensors in by_shard.values_mut() {
tensors.sort_unstable();
}
by_shard
}
#[must_use]
pub fn is_valid(&self) -> bool {
!self.weight_map.is_empty() && !self.shard_files.is_empty()
}
pub fn set_metadata(&mut self, key: &str, value: &str) {
self.metadata.insert(key.to_string(), value.to_string());
}
#[must_use]
pub fn get_metadata(&self, key: &str) -> Option<&str> {
self.metadata.get(key).map(String::as_str)
}
}
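/// A shard loaded into memory: raw tensor payloads keyed by tensor name,
/// plus their total byte size for cache accounting.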
#[derive(Debug, Clone)]
pub struct CachedShard {
pub filename: String,
pub tensors: HashMap<String, Vec<u8>>,
pub size: usize,
}
impl CachedShard {
#[must_use]
pub fn new(filename: String) -> Self {
Self {
filename,
tensors: HashMap::new(),
size: 0,
}
}
    pub fn add_tensor(&mut self, name: String, data: Vec<u8>) {
        self.size += data.len();
        // Subtract a replaced tensor's bytes so `size` is never double-counted.
        if let Some(old) = self.tensors.insert(name, data) {
            self.size = self.size.saturating_sub(old.len());
        }
    }
#[must_use]
pub fn get_tensor(&self, name: &str) -> Option<&[u8]> {
self.tensors.get(name).map(Vec::as_slice)
}
#[must_use]
pub fn has_tensor(&self, name: &str) -> bool {
self.tensors.contains_key(name)
}
}
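/// Bounded LRU cache of loaded shards. `insert` evicts from the front
/// (least recently used) until both the shard-count and byte budgets are
/// satisfied; `get` bumps a hit to the back (most recently used).
///
/// # Examples
///
/// A minimal sketch (doc-test ignored because the import path depends on
/// the crate layout):
///
/// ```ignore
/// let mut cache = ShardCache::new(2, 1024 * 1024);
/// let mut shard = CachedShard::new("model-00001-of-00002.safetensors".to_string());
/// shard.add_tensor("embed.weight".to_string(), vec![0u8; 4096]);
/// cache.insert(shard);
/// assert!(cache.get("model-00001-of-00002.safetensors").is_some());
/// assert_eq!(cache.stats().cached_shards, 1);
/// ```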
#[derive(Debug)]
pub struct ShardCache {
max_shards: usize,
max_bytes: usize,
cache: VecDeque<CachedShard>,
current_size: usize,
hits: usize,
misses: usize,
}
impl ShardCache {
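    /// Creates a cache holding at most `max_shards` shards (clamped to at
    /// least 1) and roughly `max_bytes` of tensor data.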
#[must_use]
pub fn new(max_shards: usize, max_bytes: usize) -> Self {
Self {
max_shards: max_shards.max(1),
max_bytes,
cache: VecDeque::new(),
current_size: 0,
hits: 0,
misses: 0,
}
}
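    /// Default budgets for model import: at most two resident shards, with
    /// a 256 MiB byte budget on wasm32 and 4 GiB on native targets.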
#[must_use]
pub fn default_for_import() -> Self {
#[cfg(target_arch = "wasm32")]
        let max_bytes = 256 * 1024 * 1024;
        #[cfg(not(target_arch = "wasm32"))]
        let max_bytes = 4_usize * 1024 * 1024 * 1024;
        Self::new(2, max_bytes)
}
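    /// Looks up a shard by filename, recording a hit or miss and bumping a
    /// hit to most-recently-used.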
#[must_use]
pub fn get(&mut self, filename: &str) -> Option<&CachedShard> {
let pos = self.cache.iter().position(|s| s.filename == filename);
if let Some(idx) = pos {
            // Bump the hit to the back of the queue (most recently used)
            // unless it is already there.
            if idx < self.cache.len() - 1 {
                if let Some(s) = self.cache.remove(idx) {
                    self.cache.push_back(s);
                }
            }
self.hits += 1;
self.cache.back()
} else {
self.misses += 1;
None
}
}
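    /// Inserts a shard, evicting least-recently-used shards until the
    /// count and byte budgets hold. A shard larger than `max_bytes` is
    /// still admitted once the cache has been emptied, and nothing dedupes
    /// filenames, so callers should check `get` before loading a shard.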
pub fn insert(&mut self, shard: CachedShard) {
while self.cache.len() >= self.max_shards
|| (self.current_size + shard.size > self.max_bytes && !self.cache.is_empty())
{
if let Some(evicted) = self.cache.pop_front() {
self.current_size = self.current_size.saturating_sub(evicted.size);
}
}
self.current_size += shard.size;
self.cache.push_back(shard);
}
pub fn clear(&mut self) {
self.cache.clear();
self.current_size = 0;
}
#[must_use]
pub fn stats(&self) -> CacheStats {
CacheStats {
cached_shards: self.cache.len(),
cached_bytes: self.current_size,
hits: self.hits,
misses: self.misses,
}
}
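    /// Fraction of lookups that were hits, or `0.0` if none have occurred.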
#[must_use]
pub fn hit_rate(&self) -> f32 {
let total = self.hits + self.misses;
if total > 0 {
self.hits as f32 / total as f32
} else {
0.0
}
}
}
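/// Point-in-time snapshot of cache occupancy and hit/miss counters.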
#[derive(Debug, Clone, Copy)]
pub struct CacheStats {
pub cached_shards: usize,
pub cached_bytes: usize,
pub hits: usize,
pub misses: usize,
}
impl Default for ShardCache {
fn default() -> Self {
Self::default_for_import()
}
}
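/// Callback invoked with progress snapshots as an import advances.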
pub type ProgressCallback = Box<dyn Fn(ImportProgress) + Send + Sync>;
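/// Snapshot of import progress handed to a `ProgressCallback`.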
#[derive(Debug, Clone)]
pub struct ImportProgress {
pub phase: ImportPhase,
pub tensors_processed: usize,
pub total_tensors: usize,
pub shards_loaded: usize,
pub total_shards: usize,
pub bytes_written: u64,
pub progress: f32,
}
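/// Phases of a sharded-model import, in execution order.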
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImportPhase {
Parsing,
Loading,
Merging,
Finalizing,
Complete,
}
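// Textually include the sibling `config.rs` so its items compile as part
// of this module.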
include!("config.rs");