use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// In-memory container for a model's raw weights: flat tensor data and
/// tensor shapes stored in parallel maps keyed by tensor name, plus
/// descriptive metadata.
#[derive(Debug, Clone)]
pub struct ModelWeights {
/// Flat `f32` data for each tensor, keyed by tensor name.
pub tensors: HashMap<String, Vec<f32>>,
/// Dimension sizes for each tensor, keyed by the same names as `tensors`.
/// Kept in sync with `tensors` by `add_tensor`; no invariant is enforced
/// if the maps are mutated directly.
pub shapes: HashMap<String, Vec<usize>>,
/// Architecture / provenance information about the model.
pub metadata: ModelMetadata,
}
/// Descriptive metadata for a model. Serializable so it can be persisted
/// alongside the weights. Most fields are `Option` because they may be
/// unknown for a given checkpoint; `Default` yields all-`None` / zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelMetadata {
/// Architecture identifier — format depends on the producer; not
/// interpreted by this module.
pub architecture: Option<String>,
/// Human-readable model name, if known.
pub model_name: Option<String>,
/// Total scalar parameter count (0 if never computed).
pub num_params: u64,
/// Hidden dimension of the model, if known.
pub hidden_size: Option<usize>,
/// Number of layers, if known.
pub num_layers: Option<usize>,
/// Vocabulary size, if known.
pub vocab_size: Option<usize>,
/// Training run details, present only for models trained locally.
pub training: Option<TrainingMetadata>,
}
/// Details about the training run that produced a model.
/// Optional fields may be absent when the information was not recorded.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingMetadata {
/// Number of epochs trained.
pub epochs: usize,
/// Final training loss, if recorded.
pub final_loss: Option<f32>,
/// Final validation loss, if recorded.
pub final_val_loss: Option<f32>,
/// Learning rate used, if recorded.
pub learning_rate: Option<f64>,
/// Batch size used, if recorded.
pub batch_size: Option<usize>,
/// Temperature — presumably the distillation softmax temperature given
/// `teacher_model` below; confirm against the trainer that writes this.
pub temperature: Option<f32>,
/// Name of the teacher model, when this model was distilled.
pub teacher_model: Option<String>,
}
impl ModelWeights {
    /// Creates an empty weight collection with default (all-unknown) metadata.
    #[must_use]
    pub fn new() -> Self {
        Self { tensors: HashMap::new(), shapes: HashMap::new(), metadata: ModelMetadata::default() }
    }

    /// Inserts (or replaces) a named tensor together with its shape.
    ///
    /// NOTE: no consistency check is performed between `data.len()` and the
    /// product of `shape` — callers are trusted to keep them in sync.
    pub fn add_tensor(&mut self, name: impl Into<String>, data: Vec<f32>, shape: Vec<usize>) {
        let name = name.into();
        self.tensors.insert(name.clone(), data);
        self.shapes.insert(name, shape);
    }

    /// Looks up a tensor's data and shape by name.
    ///
    /// Returns `None` if either the data or the shape entry is missing.
    #[must_use]
    pub fn get_tensor(&self, name: &str) -> Option<(&Vec<f32>, &Vec<usize>)> {
        let data = self.tensors.get(name)?;
        let shape = self.shapes.get(name)?;
        Some((data, shape))
    }

    /// Names of all stored tensors, in arbitrary (hash-map) order.
    #[must_use]
    pub fn tensor_names(&self) -> Vec<&str> {
        self.tensors.keys().map(String::as_str).collect()
    }

    /// Total number of scalar parameters across all stored tensors.
    #[must_use]
    pub fn param_count(&self) -> u64 {
        self.tensors.values().map(|t| t.len() as u64).sum()
    }

    /// Builder-style setter replacing the metadata and returning `self`.
    // Fix: added `#[must_use]` — this consumes `self`, so ignoring the
    // return value silently drops the weights; also matches the attribute
    // style of the other methods in this impl.
    #[must_use]
    pub fn with_metadata(mut self, metadata: ModelMetadata) -> Self {
        self.metadata = metadata;
        self
    }

    /// Builds a synthetic transformer-style model for tests: per layer,
    /// four square attention projections (`q/k/v/o`, each
    /// `hidden_size x hidden_size`) and a 4x-expansion MLP (`up`/`down`),
    /// all filled with the constant 0.01. Metadata records the parameter
    /// count, hidden size, and layer count.
    #[must_use]
    pub fn mock(num_layers: usize, hidden_size: usize) -> Self {
        let mut weights = Self::new();
        for layer in 0..num_layers {
            for proj in &["q_proj", "k_proj", "v_proj", "o_proj"] {
                let name = format!("layer.{layer}.attention.{proj}.weight");
                let size = hidden_size * hidden_size;
                let data = vec![0.01; size];
                weights.add_tensor(name, data, vec![hidden_size, hidden_size]);
            }
            // Conventional 4x feed-forward expansion.
            let mlp_size = hidden_size * 4;
            weights.add_tensor(
                format!("layer.{layer}.mlp.up.weight"),
                vec![0.01; hidden_size * mlp_size],
                vec![mlp_size, hidden_size],
            );
            weights.add_tensor(
                format!("layer.{layer}.mlp.down.weight"),
                vec![0.01; mlp_size * hidden_size],
                vec![hidden_size, mlp_size],
            );
        }
        weights.metadata = ModelMetadata {
            num_params: weights.param_count(),
            hidden_size: Some(hidden_size),
            num_layers: Some(num_layers),
            ..Default::default()
        };
        weights
    }
}
impl Default for ModelWeights {
fn default() -> Self {
Self::new()
}
}