entrenar/hf_pipeline/export/
weights.rs1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone)]
8pub struct ModelWeights {
9 pub tensors: HashMap<String, Vec<f32>>,
11 pub shapes: HashMap<String, Vec<usize>>,
13 pub metadata: ModelMetadata,
15}
16
17#[derive(Debug, Clone, Default, Serialize, Deserialize)]
19pub struct ModelMetadata {
20 pub architecture: Option<String>,
22 pub model_name: Option<String>,
24 pub num_params: u64,
26 pub hidden_size: Option<usize>,
28 pub num_layers: Option<usize>,
30 pub vocab_size: Option<usize>,
32 pub training: Option<TrainingMetadata>,
34}
35
36#[derive(Debug, Clone, Default, Serialize, Deserialize)]
38pub struct TrainingMetadata {
39 pub epochs: usize,
41 pub final_loss: Option<f32>,
43 pub final_val_loss: Option<f32>,
45 pub learning_rate: Option<f64>,
47 pub batch_size: Option<usize>,
49 pub temperature: Option<f32>,
51 pub teacher_model: Option<String>,
53}
54
55impl ModelWeights {
56 #[must_use]
58 pub fn new() -> Self {
59 Self { tensors: HashMap::new(), shapes: HashMap::new(), metadata: ModelMetadata::default() }
60 }
61
62 pub fn add_tensor(&mut self, name: impl Into<String>, data: Vec<f32>, shape: Vec<usize>) {
64 let name = name.into();
65 self.tensors.insert(name.clone(), data);
66 self.shapes.insert(name, shape);
67 }
68
69 #[must_use]
71 pub fn get_tensor(&self, name: &str) -> Option<(&Vec<f32>, &Vec<usize>)> {
72 let data = self.tensors.get(name)?;
73 let shape = self.shapes.get(name)?;
74 Some((data, shape))
75 }
76
77 #[must_use]
79 pub fn tensor_names(&self) -> Vec<&str> {
80 self.tensors.keys().map(String::as_str).collect()
81 }
82
83 #[must_use]
85 pub fn param_count(&self) -> u64 {
86 self.tensors.values().map(|t| t.len() as u64).sum()
87 }
88
89 pub fn with_metadata(mut self, metadata: ModelMetadata) -> Self {
91 self.metadata = metadata;
92 self
93 }
94
95 #[must_use]
97 pub fn mock(num_layers: usize, hidden_size: usize) -> Self {
98 let mut weights = Self::new();
99
100 for layer in 0..num_layers {
101 for proj in &["q_proj", "k_proj", "v_proj", "o_proj"] {
103 let name = format!("layer.{layer}.attention.{proj}.weight");
104 let size = hidden_size * hidden_size;
105 let data = vec![0.01; size];
106 weights.add_tensor(name, data, vec![hidden_size, hidden_size]);
107 }
108
109 let mlp_size = hidden_size * 4;
111 weights.add_tensor(
112 format!("layer.{layer}.mlp.up.weight"),
113 vec![0.01; hidden_size * mlp_size],
114 vec![mlp_size, hidden_size],
115 );
116 weights.add_tensor(
117 format!("layer.{layer}.mlp.down.weight"),
118 vec![0.01; mlp_size * hidden_size],
119 vec![hidden_size, mlp_size],
120 );
121 }
122
123 weights.metadata = ModelMetadata {
124 num_params: weights.param_count(),
125 hidden_size: Some(hidden_size),
126 num_layers: Some(num_layers),
127 ..Default::default()
128 };
129
130 weights
131 }
132}
133
134impl Default for ModelWeights {
135 fn default() -> Self {
136 Self::new()
137 }
138}