1use crate::error::{ModelError, ModelResult};
36use crate::{AutoregressiveModel, ModelType};
37use kizzasi_core::{silu, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor};
38use scirs2_core::ndarray::{Array1, Array2};
39use scirs2_core::random::{rng, Rng};
40use std::collections::VecDeque;
41
42#[allow(unused_imports)]
43use tracing::{debug, instrument, trace};
44
45#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
47pub struct H3Config {
48 pub input_dim: usize,
50 pub hidden_dim: usize,
52 pub ssm_dim: usize,
54 pub num_layers: usize,
56 pub shift_distance: usize,
58 pub num_heads: usize,
60}
61
62impl H3Config {
63 pub fn new(input_dim: usize, hidden_dim: usize, num_layers: usize) -> Self {
65 Self {
66 input_dim,
67 hidden_dim,
68 ssm_dim: 64,
69 num_layers,
70 shift_distance: 4,
71 num_heads: 4,
72 }
73 }
74
75 pub fn validate(&self) -> ModelResult<()> {
77 if self.hidden_dim == 0 {
78 return Err(ModelError::invalid_config("hidden_dim must be > 0"));
79 }
80 if self.ssm_dim == 0 {
81 return Err(ModelError::invalid_config("ssm_dim must be > 0"));
82 }
83 if self.num_layers == 0 {
84 return Err(ModelError::invalid_config("num_layers must be > 0"));
85 }
86 if self.shift_distance == 0 {
87 return Err(ModelError::invalid_config("shift_distance must be > 0"));
88 }
89 if self.num_heads == 0 {
90 return Err(ModelError::invalid_config("num_heads must be > 0"));
91 }
92 if !self.hidden_dim.is_multiple_of(self.num_heads) {
93 return Err(ModelError::invalid_config(
94 "hidden_dim must be divisible by num_heads",
95 ));
96 }
97 Ok(())
98 }
99}
100
101struct ShiftSSM {
103 head_dim: usize,
105 shift_distance: usize,
107 shift_weights: Array2<f32>,
109 history: VecDeque<Array1<f32>>,
111}
112
113impl ShiftSSM {
114 fn new(head_dim: usize, shift_distance: usize) -> Self {
116 let mut rng = rng();
117
118 let scale = (1.0 / shift_distance as f32).sqrt();
120 let shift_weights = Array2::from_shape_fn((shift_distance, head_dim), |_| {
121 (rng.random::<f32>() - 0.5) * 2.0 * scale
122 });
123
124 let history = VecDeque::with_capacity(shift_distance);
126
127 Self {
128 head_dim,
129 shift_distance,
130 shift_weights,
131 history,
132 }
133 }
134
135 fn forward(&mut self, x: &Array1<f32>) -> Array1<f32> {
137 self.history.push_back(x.clone());
139
140 while self.history.len() > self.shift_distance {
142 self.history.pop_front();
143 }
144
145 let mut output = Array1::zeros(self.head_dim);
147 for (i, hist_x) in self.history.iter().enumerate() {
148 let weight_row = self.shift_weights.row(i);
149 output = output + hist_x * &weight_row;
150 }
151
152 output
153 }
154
155 fn reset(&mut self) {
157 self.history.clear();
158 }
159}
160
161struct H3Layer {
163 num_heads: usize,
165 head_dim: usize,
167 input_proj: Array2<f32>,
169 shift_ssms: Vec<ShiftSSM>,
171 gate_proj: Array2<f32>,
173 output_proj: Array2<f32>,
175 layer_norm: LayerNorm,
177}
178
179impl H3Layer {
180 fn new(config: &H3Config) -> Self {
182 let mut rng = rng();
183 let num_heads = config.num_heads;
184 let head_dim = config.hidden_dim / num_heads;
185
186 let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
188 let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
189 (rng.random::<f32>() - 0.5) * 2.0 * scale
190 });
191
192 let shift_ssms = (0..num_heads)
194 .map(|_| ShiftSSM::new(head_dim, config.shift_distance))
195 .collect();
196
197 let scale = (2.0 / (config.hidden_dim + config.hidden_dim) as f32).sqrt();
199 let gate_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
200 (rng.random::<f32>() - 0.5) * 2.0 * scale
201 });
202
203 let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
205 let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
206 (rng.random::<f32>() - 0.5) * 2.0 * scale
207 });
208
209 let layer_norm = LayerNorm::new(config.hidden_dim, NormType::RMSNorm);
211
212 Self {
213 num_heads,
214 head_dim,
215 input_proj,
216 shift_ssms,
217 gate_proj,
218 output_proj,
219 layer_norm,
220 }
221 }
222
223 fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
225 let hidden = x.dot(&self.input_proj);
227
228 let mut ssm_outputs = Vec::with_capacity(self.num_heads);
230 for (head_idx, ssm) in self.shift_ssms.iter_mut().enumerate() {
231 let start = head_idx * self.head_dim;
232 let end = start + self.head_dim;
233 let head_input = hidden.slice(s![start..end]).to_owned();
234 ssm_outputs.push(ssm.forward(&head_input));
235 }
236
237 let mut ssm_output = Array1::zeros(self.num_heads * self.head_dim);
239 for (head_idx, head_out) in ssm_outputs.iter().enumerate() {
240 let start = head_idx * self.head_dim;
241 let end = start + self.head_dim;
242 ssm_output.slice_mut(s![start..end]).assign(head_out);
243 }
244
245 let gate = hidden.dot(&self.gate_proj);
247 let gate_activated = silu(&gate);
248 let gated = &ssm_output * &gate_activated;
249
250 let normed = self.layer_norm.forward(&gated);
252
253 let output = normed.dot(&self.output_proj) + x;
255
256 Ok(output)
257 }
258
259 fn reset(&mut self) {
261 for ssm in &mut self.shift_ssms {
262 ssm.reset();
263 }
264 }
265}
266
267pub struct H3 {
269 config: H3Config,
270 layers: Vec<H3Layer>,
271}
272
273impl H3 {
274 #[instrument(skip(config), fields(input_dim = config.input_dim, hidden_dim = config.hidden_dim, num_layers = config.num_layers))]
276 pub fn new(config: H3Config) -> ModelResult<Self> {
277 debug!("Creating new H3 model");
278 config.validate()?;
279
280 let mut layers = Vec::with_capacity(config.num_layers);
281 for layer_idx in 0..config.num_layers {
282 trace!("Initializing H3 layer {}", layer_idx);
283 layers.push(H3Layer::new(&config));
284 }
285 debug!("Initialized {} H3 layers", layers.len());
286
287 debug!("H3 model created successfully");
288 Ok(Self { config, layers })
289 }
290
291 pub fn config(&self) -> &H3Config {
293 &self.config
294 }
295}
296
297impl SignalPredictor for H3 {
298 #[instrument(skip(self, input))]
299 fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
300 let mut x = input.clone();
301
302 for layer in &mut self.layers {
303 x = layer.forward(&x)?;
304 }
305
306 Ok(x)
307 }
308
309 #[instrument(skip(self))]
310 fn reset(&mut self) {
311 debug!("Resetting H3 model state");
312 for layer in &mut self.layers {
313 layer.reset();
314 }
315 }
316
317 fn context_window(&self) -> usize {
318 self.config.shift_distance * self.config.num_layers
320 }
321}
322
323impl AutoregressiveModel for H3 {
324 fn hidden_dim(&self) -> usize {
325 self.config.hidden_dim
326 }
327
328 fn state_dim(&self) -> usize {
329 self.config.ssm_dim
330 }
331
332 fn num_layers(&self) -> usize {
333 self.config.num_layers
334 }
335
336 fn model_type(&self) -> ModelType {
337 ModelType::S4 }
339
340 fn get_states(&self) -> Vec<HiddenState> {
341 self.layers
342 .iter()
343 .map(|layer| {
344 let total_size =
346 layer.shift_ssms.len() * layer.head_dim * self.config.shift_distance;
347 let mut state_vec = vec![0.0; total_size];
348
349 let mut offset = 0;
350 for ssm in &layer.shift_ssms {
351 for hist in &ssm.history {
352 if let Some(hist_slice) = hist.as_slice() {
353 state_vec[offset..offset + hist.len()].copy_from_slice(hist_slice);
354 } else {
355 for (i, &val) in hist.iter().enumerate() {
356 state_vec[offset + i] = val;
357 }
358 }
359 offset += hist.len();
360 }
361 offset += (self.config.shift_distance - ssm.history.len()) * layer.head_dim;
363 }
364
365 let state_1d = Array1::from_vec(state_vec);
366 let state_2d = state_1d.insert_axis(scirs2_core::ndarray::Axis(0));
367 let mut hidden_state = HiddenState::new(
368 self.config.hidden_dim,
369 state_2d.len_of(scirs2_core::ndarray::Axis(1)),
370 );
371 hidden_state.update(state_2d);
372 hidden_state
373 })
374 .collect()
375 }
376
377 fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
378 if states.len() != self.config.num_layers {
379 return Err(ModelError::state_count_mismatch(
380 "H3",
381 self.config.num_layers,
382 states.len(),
383 ));
384 }
385
386 for (layer, state) in self.layers.iter_mut().zip(states.iter()) {
387 let state_2d = state.state();
388 if state_2d.nrows() > 0 {
389 let state_1d = state_2d.row(0).to_owned();
390 let mut offset = 0;
391
392 for ssm in &mut layer.shift_ssms {
393 ssm.history.clear();
394 for _ in 0..self
395 .config
396 .shift_distance
397 .min(state_1d.len() / layer.head_dim)
398 {
399 if offset + layer.head_dim <= state_1d.len() {
400 let hist_vec: Vec<f32> =
401 state_1d.slice(s![offset..offset + layer.head_dim]).to_vec();
402 ssm.history.push_back(Array1::from_vec(hist_vec));
403 offset += layer.head_dim;
404 }
405 }
406 }
407 }
408 }
409
410 Ok(())
411 }
412}
413
414use scirs2_core::ndarray::s;
416
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_h3_creation() {
        let config = H3Config::new(32, 64, 2);
        let model = H3::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_h3_forward() {
        let config = H3Config::new(32, 64, 2);
        let mut model = H3::new(config).expect("Failed to create H3 model");

        let input = Array1::from_vec(vec![1.0; 32]);
        let output = model.step(&input);
        assert!(output.is_ok());
        assert_eq!(output.expect("Failed to get output").len(), 32);
    }

    #[test]
    fn test_h3_reset() {
        let config = H3Config::new(32, 64, 2);
        let mut model = H3::new(config).expect("Failed to create H3 model");

        let input = Array1::from_vec(vec![1.0; 32]);
        let _output1 = model.step(&input).expect("Failed to get output1");

        model.reset();

        let output2 = model.step(&input).expect("Failed to get output2");
        assert_eq!(output2.len(), 32);
    }

    #[test]
    fn test_invalid_config() {
        let mut config = H3Config::new(32, 64, 2);
        config.num_heads = 0;
        assert!(config.validate().is_err());
    }

    /// Heads must split `hidden_dim` evenly.
    #[test]
    fn test_invalid_config_indivisible_heads() {
        let mut config = H3Config::new(32, 64, 2);
        config.num_heads = 3;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_h3_context_window() {
        let config = H3Config::new(32, 64, 3);
        let model = H3::new(config.clone()).expect("Failed to create H3 model");
        assert_eq!(
            model.context_window(),
            config.shift_distance * config.num_layers
        );
    }

    #[test]
    fn test_h3_state_management() {
        let config = H3Config::new(32, 64, 2);
        let mut model = H3::new(config).expect("Failed to create H3 model");

        let input = Array1::from_vec(vec![0.5; 32]);
        for _ in 0..5 {
            let _ = model.step(&input).expect("Failed to step H3 model");
        }

        let states = model.get_states();
        assert_eq!(states.len(), 2);

        model.reset();
        let result = model.set_states(states);
        assert!(result.is_ok());
    }

    /// `set_states` must reject a state vector whose length differs from the layer count.
    #[test]
    fn test_h3_set_states_wrong_count() {
        let config = H3Config::new(32, 64, 2);
        let mut model = H3::new(config).expect("Failed to create H3 model");
        assert!(model.set_states(Vec::new()).is_err());
    }

    /// Restoring a snapshot of full histories must reproduce the next output exactly.
    #[test]
    fn test_h3_state_round_trip_full_history() {
        let config = H3Config::new(32, 64, 2);
        let warmup_steps = config.shift_distance + 1;
        let mut model = H3::new(config).expect("Failed to create H3 model");

        let input = Array1::from_vec(vec![0.5; 32]);
        for _ in 0..warmup_steps {
            let _ = model.step(&input).expect("Failed to step H3 model");
        }

        let states = model.get_states();
        let expected = model.step(&input).expect("Failed to step H3 model");

        model.reset();
        model.set_states(states).expect("Failed to restore H3 states");
        let restored = model.step(&input).expect("Failed to step H3 model");

        assert_eq!(expected, restored);
    }
}
490}