// File: entrenar/hf_pipeline/export/weights.rs

//! Model weights and metadata containers.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6/// Model weights container for export
7#[derive(Debug, Clone)]
8pub struct ModelWeights {
9    /// Tensor data by name
10    pub tensors: HashMap<String, Vec<f32>>,
11    /// Tensor shapes by name
12    pub shapes: HashMap<String, Vec<usize>>,
13    /// Model metadata
14    pub metadata: ModelMetadata,
15}
16
17/// Model metadata
18#[derive(Debug, Clone, Default, Serialize, Deserialize)]
19pub struct ModelMetadata {
20    /// Model architecture
21    pub architecture: Option<String>,
22    /// Model name
23    pub model_name: Option<String>,
24    /// Number of parameters
25    pub num_params: u64,
26    /// Hidden size
27    pub hidden_size: Option<usize>,
28    /// Number of layers
29    pub num_layers: Option<usize>,
30    /// Vocabulary size
31    pub vocab_size: Option<usize>,
32    /// Training info
33    pub training: Option<TrainingMetadata>,
34}
35
36/// Training metadata
37#[derive(Debug, Clone, Default, Serialize, Deserialize)]
38pub struct TrainingMetadata {
39    /// Training epochs completed
40    pub epochs: usize,
41    /// Final training loss
42    pub final_loss: Option<f32>,
43    /// Final validation loss
44    pub final_val_loss: Option<f32>,
45    /// Learning rate used
46    pub learning_rate: Option<f64>,
47    /// Batch size used
48    pub batch_size: Option<usize>,
49    /// Distillation temperature (if applicable)
50    pub temperature: Option<f32>,
51    /// Teacher model (if distilled)
52    pub teacher_model: Option<String>,
53}
54
55impl ModelWeights {
56    /// Create new empty weights container
57    #[must_use]
58    pub fn new() -> Self {
59        Self { tensors: HashMap::new(), shapes: HashMap::new(), metadata: ModelMetadata::default() }
60    }
61
62    /// Add a tensor
63    pub fn add_tensor(&mut self, name: impl Into<String>, data: Vec<f32>, shape: Vec<usize>) {
64        let name = name.into();
65        self.tensors.insert(name.clone(), data);
66        self.shapes.insert(name, shape);
67    }
68
69    /// Get tensor by name
70    #[must_use]
71    pub fn get_tensor(&self, name: &str) -> Option<(&Vec<f32>, &Vec<usize>)> {
72        let data = self.tensors.get(name)?;
73        let shape = self.shapes.get(name)?;
74        Some((data, shape))
75    }
76
77    /// Get all tensor names
78    #[must_use]
79    pub fn tensor_names(&self) -> Vec<&str> {
80        self.tensors.keys().map(String::as_str).collect()
81    }
82
83    /// Count total parameters
84    #[must_use]
85    pub fn param_count(&self) -> u64 {
86        self.tensors.values().map(|t| t.len() as u64).sum()
87    }
88
89    /// Set metadata
90    pub fn with_metadata(mut self, metadata: ModelMetadata) -> Self {
91        self.metadata = metadata;
92        self
93    }
94
95    /// Create mock weights for testing
96    #[must_use]
97    pub fn mock(num_layers: usize, hidden_size: usize) -> Self {
98        let mut weights = Self::new();
99
100        for layer in 0..num_layers {
101            // Q, K, V, O projections
102            for proj in &["q_proj", "k_proj", "v_proj", "o_proj"] {
103                let name = format!("layer.{layer}.attention.{proj}.weight");
104                let size = hidden_size * hidden_size;
105                let data = vec![0.01; size];
106                weights.add_tensor(name, data, vec![hidden_size, hidden_size]);
107            }
108
109            // MLP layers
110            let mlp_size = hidden_size * 4;
111            weights.add_tensor(
112                format!("layer.{layer}.mlp.up.weight"),
113                vec![0.01; hidden_size * mlp_size],
114                vec![mlp_size, hidden_size],
115            );
116            weights.add_tensor(
117                format!("layer.{layer}.mlp.down.weight"),
118                vec![0.01; mlp_size * hidden_size],
119                vec![hidden_size, mlp_size],
120            );
121        }
122
123        weights.metadata = ModelMetadata {
124            num_params: weights.param_count(),
125            hidden_size: Some(hidden_size),
126            num_layers: Some(num_layers),
127            ..Default::default()
128        };
129
130        weights
131    }
132}
133
134impl Default for ModelWeights {
135    fn default() -> Self {
136        Self::new()
137    }
138}