kizzasi_model/
h3.rs

1//! H3: Hungry Hungry Hippos
2//!
3//! H3 is a state space model that uses shift SSMs with multiplicative interactions,
4//! achieving strong performance on language modeling while maintaining linear complexity.
5//!
6//! # Key Features
7//!
8//! - **Shift SSM**: Simple shift operation instead of complex SSM computations
9//! - **Multiplicative interactions**: Gating mechanisms for improved expressiveness
10//! - **Linear complexity**: O(L) for sequence length L
11//! - **Hardware-efficient**: Optimized for modern accelerators
12//!
13//! # Architecture
14//!
15//! ```text
16//! Input → [Linear] → [ShiftSSM] → [Mult Gate] → [Linear] → Output
17//!                        ↓
18//!                    [Shift Buffer]
19//! ```
20//!
21//! # Shift SSM Formulation
22//!
23//! Instead of complex state space dynamics, H3 uses:
24//! ```text
//! y[t] = shift(x[t-k+1..=t]) ⊙ gate(x[t])
//! ```
//!
//! where `k` is `shift_distance` and `shift` is a learned, per-channel weighted sum of the
//! last `k` inputs (the current input included).
29//!
30//! # References
31//!
32//! - H3 paper: "Hungry Hungry Hippos: Towards Language Modeling with State Space Models"
//! - <https://arxiv.org/abs/2212.14052>
34
35use crate::error::{ModelError, ModelResult};
36use crate::{AutoregressiveModel, ModelType};
37use kizzasi_core::{silu, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor};
38use scirs2_core::ndarray::{Array1, Array2};
39use scirs2_core::random::{rng, Rng};
40use std::collections::VecDeque;
41
42#[allow(unused_imports)]
43use tracing::{debug, instrument, trace};
44
45/// Configuration for H3 model
46#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
47pub struct H3Config {
48    /// Input dimension
49    pub input_dim: usize,
50    /// Hidden dimension
51    pub hidden_dim: usize,
52    /// SSM state dimension
53    pub ssm_dim: usize,
54    /// Number of layers
55    pub num_layers: usize,
56    /// Shift distance (how far back to look)
57    pub shift_distance: usize,
58    /// Number of heads for multi-head shift SSM
59    pub num_heads: usize,
60}
61
62impl H3Config {
63    /// Create default H3 configuration
64    pub fn new(input_dim: usize, hidden_dim: usize, num_layers: usize) -> Self {
65        Self {
66            input_dim,
67            hidden_dim,
68            ssm_dim: 64,
69            num_layers,
70            shift_distance: 4,
71            num_heads: 4,
72        }
73    }
74
75    /// Validate configuration
76    pub fn validate(&self) -> ModelResult<()> {
77        if self.hidden_dim == 0 {
78            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
79        }
80        if self.ssm_dim == 0 {
81            return Err(ModelError::invalid_config("ssm_dim must be > 0"));
82        }
83        if self.num_layers == 0 {
84            return Err(ModelError::invalid_config("num_layers must be > 0"));
85        }
86        if self.shift_distance == 0 {
87            return Err(ModelError::invalid_config("shift_distance must be > 0"));
88        }
89        if self.num_heads == 0 {
90            return Err(ModelError::invalid_config("num_heads must be > 0"));
91        }
92        if !self.hidden_dim.is_multiple_of(self.num_heads) {
93            return Err(ModelError::invalid_config(
94                "hidden_dim must be divisible by num_heads",
95            ));
96        }
97        Ok(())
98    }
99}
100
101/// Shift SSM block - core of H3
102struct ShiftSSM {
103    /// Head dimension
104    head_dim: usize,
105    /// Shift distance
106    shift_distance: usize,
107    /// Shift weights for each position [shift_distance, head_dim]
108    shift_weights: Array2<f32>,
109    /// History buffer for shifts
110    history: VecDeque<Array1<f32>>,
111}
112
113impl ShiftSSM {
114    /// Create new Shift SSM
115    fn new(head_dim: usize, shift_distance: usize) -> Self {
116        let mut rng = rng();
117
118        // Initialize shift weights
119        let scale = (1.0 / shift_distance as f32).sqrt();
120        let shift_weights = Array2::from_shape_fn((shift_distance, head_dim), |_| {
121            (rng.random::<f32>() - 0.5) * 2.0 * scale
122        });
123
124        // Initialize history buffer
125        let history = VecDeque::with_capacity(shift_distance);
126
127        Self {
128            head_dim,
129            shift_distance,
130            shift_weights,
131            history,
132        }
133    }
134
135    /// Forward pass through shift SSM
136    fn forward(&mut self, x: &Array1<f32>) -> Array1<f32> {
137        // Add current input to history
138        self.history.push_back(x.clone());
139
140        // Keep only the last shift_distance elements
141        while self.history.len() > self.shift_distance {
142            self.history.pop_front();
143        }
144
145        // Compute weighted sum of shifted inputs
146        let mut output = Array1::zeros(self.head_dim);
147        for (i, hist_x) in self.history.iter().enumerate() {
148            let weight_row = self.shift_weights.row(i);
149            output = output + hist_x * &weight_row;
150        }
151
152        output
153    }
154
155    /// Reset history
156    fn reset(&mut self) {
157        self.history.clear();
158    }
159}
160
161/// H3 layer with multi-head shift SSM
162struct H3Layer {
163    /// Number of heads
164    num_heads: usize,
165    /// Head dimension
166    head_dim: usize,
167    /// Input projection
168    input_proj: Array2<f32>,
169    /// Shift SSMs (one per head)
170    shift_ssms: Vec<ShiftSSM>,
171    /// Gate projection for multiplicative interaction
172    gate_proj: Array2<f32>,
173    /// Output projection
174    output_proj: Array2<f32>,
175    /// Layer normalization
176    layer_norm: LayerNorm,
177}
178
179impl H3Layer {
180    /// Create a new H3 layer
181    fn new(config: &H3Config) -> Self {
182        let mut rng = rng();
183        let num_heads = config.num_heads;
184        let head_dim = config.hidden_dim / num_heads;
185
186        // Input projection
187        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
188        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
189            (rng.random::<f32>() - 0.5) * 2.0 * scale
190        });
191
192        // Create shift SSMs for each head
193        let shift_ssms = (0..num_heads)
194            .map(|_| ShiftSSM::new(head_dim, config.shift_distance))
195            .collect();
196
197        // Gate projection
198        let scale = (2.0 / (config.hidden_dim + config.hidden_dim) as f32).sqrt();
199        let gate_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
200            (rng.random::<f32>() - 0.5) * 2.0 * scale
201        });
202
203        // Output projection
204        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
205        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
206            (rng.random::<f32>() - 0.5) * 2.0 * scale
207        });
208
209        // Layer normalization
210        let layer_norm = LayerNorm::new(config.hidden_dim, NormType::RMSNorm);
211
212        Self {
213            num_heads,
214            head_dim,
215            input_proj,
216            shift_ssms,
217            gate_proj,
218            output_proj,
219            layer_norm,
220        }
221    }
222
223    /// Forward pass
224    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
225        // Project input to hidden dimension
226        let hidden = x.dot(&self.input_proj);
227
228        // Split into heads and process through shift SSMs
229        let mut ssm_outputs = Vec::with_capacity(self.num_heads);
230        for (head_idx, ssm) in self.shift_ssms.iter_mut().enumerate() {
231            let start = head_idx * self.head_dim;
232            let end = start + self.head_dim;
233            let head_input = hidden.slice(s![start..end]).to_owned();
234            ssm_outputs.push(ssm.forward(&head_input));
235        }
236
237        // Concatenate head outputs
238        let mut ssm_output = Array1::zeros(self.num_heads * self.head_dim);
239        for (head_idx, head_out) in ssm_outputs.iter().enumerate() {
240            let start = head_idx * self.head_dim;
241            let end = start + self.head_dim;
242            ssm_output.slice_mut(s![start..end]).assign(head_out);
243        }
244
245        // Multiplicative gating
246        let gate = hidden.dot(&self.gate_proj);
247        let gate_activated = silu(&gate);
248        let gated = &ssm_output * &gate_activated;
249
250        // Layer normalization
251        let normed = self.layer_norm.forward(&gated);
252
253        // Output projection with residual
254        let output = normed.dot(&self.output_proj) + x;
255
256        Ok(output)
257    }
258
259    /// Reset layer state
260    fn reset(&mut self) {
261        for ssm in &mut self.shift_ssms {
262            ssm.reset();
263        }
264    }
265}
266
267/// H3 model with multiple layers
268pub struct H3 {
269    config: H3Config,
270    layers: Vec<H3Layer>,
271}
272
273impl H3 {
274    /// Create a new H3 model
275    #[instrument(skip(config), fields(input_dim = config.input_dim, hidden_dim = config.hidden_dim, num_layers = config.num_layers))]
276    pub fn new(config: H3Config) -> ModelResult<Self> {
277        debug!("Creating new H3 model");
278        config.validate()?;
279
280        let mut layers = Vec::with_capacity(config.num_layers);
281        for layer_idx in 0..config.num_layers {
282            trace!("Initializing H3 layer {}", layer_idx);
283            layers.push(H3Layer::new(&config));
284        }
285        debug!("Initialized {} H3 layers", layers.len());
286
287        debug!("H3 model created successfully");
288        Ok(Self { config, layers })
289    }
290
291    /// Get configuration
292    pub fn config(&self) -> &H3Config {
293        &self.config
294    }
295}
296
297impl SignalPredictor for H3 {
298    #[instrument(skip(self, input))]
299    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
300        let mut x = input.clone();
301
302        for layer in &mut self.layers {
303            x = layer.forward(&x)?;
304        }
305
306        Ok(x)
307    }
308
309    #[instrument(skip(self))]
310    fn reset(&mut self) {
311        debug!("Resetting H3 model state");
312        for layer in &mut self.layers {
313            layer.reset();
314        }
315    }
316
317    fn context_window(&self) -> usize {
318        // H3 has limited context via shift buffer
319        self.config.shift_distance * self.config.num_layers
320    }
321}
322
323impl AutoregressiveModel for H3 {
324    fn hidden_dim(&self) -> usize {
325        self.config.hidden_dim
326    }
327
328    fn state_dim(&self) -> usize {
329        self.config.ssm_dim
330    }
331
332    fn num_layers(&self) -> usize {
333        self.config.num_layers
334    }
335
336    fn model_type(&self) -> ModelType {
337        ModelType::S4 // H3 is an SSM variant, closest to S4
338    }
339
340    fn get_states(&self) -> Vec<HiddenState> {
341        self.layers
342            .iter()
343            .map(|layer| {
344                // Collect shift buffer histories into a state
345                let total_size =
346                    layer.shift_ssms.len() * layer.head_dim * self.config.shift_distance;
347                let mut state_vec = vec![0.0; total_size];
348
349                let mut offset = 0;
350                for ssm in &layer.shift_ssms {
351                    for hist in &ssm.history {
352                        if let Some(hist_slice) = hist.as_slice() {
353                            state_vec[offset..offset + hist.len()].copy_from_slice(hist_slice);
354                        } else {
355                            for (i, &val) in hist.iter().enumerate() {
356                                state_vec[offset + i] = val;
357                            }
358                        }
359                        offset += hist.len();
360                    }
361                    // Pad if history is shorter than shift_distance
362                    offset += (self.config.shift_distance - ssm.history.len()) * layer.head_dim;
363                }
364
365                let state_1d = Array1::from_vec(state_vec);
366                let state_2d = state_1d.insert_axis(scirs2_core::ndarray::Axis(0));
367                let mut hidden_state = HiddenState::new(
368                    self.config.hidden_dim,
369                    state_2d.len_of(scirs2_core::ndarray::Axis(1)),
370                );
371                hidden_state.update(state_2d);
372                hidden_state
373            })
374            .collect()
375    }
376
377    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
378        if states.len() != self.config.num_layers {
379            return Err(ModelError::state_count_mismatch(
380                "H3",
381                self.config.num_layers,
382                states.len(),
383            ));
384        }
385
386        for (layer, state) in self.layers.iter_mut().zip(states.iter()) {
387            let state_2d = state.state();
388            if state_2d.nrows() > 0 {
389                let state_1d = state_2d.row(0).to_owned();
390                let mut offset = 0;
391
392                for ssm in &mut layer.shift_ssms {
393                    ssm.history.clear();
394                    for _ in 0..self
395                        .config
396                        .shift_distance
397                        .min(state_1d.len() / layer.head_dim)
398                    {
399                        if offset + layer.head_dim <= state_1d.len() {
400                            let hist_vec: Vec<f32> =
401                                state_1d.slice(s![offset..offset + layer.head_dim]).to_vec();
402                            ssm.history.push_back(Array1::from_vec(hist_vec));
403                            offset += layer.head_dim;
404                        }
405                    }
406                }
407            }
408        }
409
410        Ok(())
411    }
412}
413
414// Import ndarray slice macro
415use scirs2_core::ndarray::s;
416
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_h3_creation() {
        let config = H3Config::new(32, 64, 2);
        let model = H3::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_h3_forward() {
        let config = H3Config::new(32, 64, 2);
        let mut model = H3::new(config).expect("Failed to create H3 model");

        let input = Array1::from_vec(vec![1.0; 32]);
        let output = model.step(&input);
        assert!(output.is_ok());
        assert_eq!(output.expect("Failed to get output").len(), 32);
    }

    #[test]
    fn test_h3_reset() {
        let config = H3Config::new(32, 64, 2);
        let mut model = H3::new(config).expect("Failed to create H3 model");

        let input = Array1::from_vec(vec![1.0; 32]);
        let _output1 = model.step(&input).expect("Failed to get output1");

        model.reset();

        let output2 = model.step(&input).expect("Failed to get output2");
        assert_eq!(output2.len(), 32);
    }

    #[test]
    fn test_invalid_config() {
        let mut config = H3Config::new(32, 64, 2);
        config.num_heads = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_config_indivisible_heads() {
        // 64 is not a multiple of 3, so heads cannot split the hidden vector evenly.
        let mut config = H3Config::new(32, 64, 2);
        config.num_heads = 3;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_h3_context_window() {
        let config = H3Config::new(32, 64, 3);
        let model = H3::new(config.clone()).expect("Failed to create H3 model");
        assert_eq!(
            model.context_window(),
            config.shift_distance * config.num_layers
        );
    }

    #[test]
    fn test_h3_state_management() {
        let config = H3Config::new(32, 64, 2);
        let mut model = H3::new(config).expect("Failed to create H3 model");

        // Run a few steps
        let input = Array1::from_vec(vec![0.5; 32]);
        for _ in 0..5 {
            let _ = model.step(&input).expect("Failed to step H3 model");
        }

        // Get states
        let states = model.get_states();
        assert_eq!(states.len(), 2);

        // Reset and set states
        model.reset();
        let result = model.set_states(states);
        assert!(result.is_ok());
    }

    #[test]
    fn test_h3_state_roundtrip_reproduces_output() {
        let config = H3Config::new(8, 16, 2);
        let shift_distance = config.shift_distance;
        let mut model = H3::new(config).expect("Failed to create H3 model");

        // Fill every shift buffer completely so the snapshot holds the whole history.
        let warmup = Array1::from_vec((0..8).map(|i| i as f32 * 0.1).collect());
        for _ in 0..=shift_distance {
            let _ = model.step(&warmup).expect("Failed to step H3 model");
        }
        let states = model.get_states();

        let probe = Array1::from_vec(vec![0.25; 8]);
        let expected = model.step(&probe).expect("Failed to step H3 model");

        // Restoring the snapshot must reproduce the very same next output.
        model.reset();
        model.set_states(states).expect("Failed to restore H3 states");
        let restored = model.step(&probe).expect("Failed to step H3 model");

        assert_eq!(expected.len(), restored.len());
        for (a, b) in expected.iter().zip(restored.iter()) {
            assert!((a - b).abs() < 1e-5, "restored output {b} != expected {a}");
        }
    }
}