use crate::{CoreError, CoreResult};
use scirs2_core::ndarray::{s, Array1, Array2};
use scirs2_core::random::quick::random_f32;
use serde::{Deserialize, Serialize};
/// Configuration for the H3 (Hungry Hungry Hippos) state-space model stack.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct H3Config {
/// Dimensionality of the raw input vectors fed to the model embedding.
pub d_input: usize,
/// Hidden (model) dimensionality used throughout the layers.
pub d_model: usize,
/// Length of each SSM recurrent state.
pub d_state: usize,
/// Number of stacked H3 layers.
pub n_layers: usize,
/// Per-head feature width; kept equal to d_model / n_heads by the builders.
pub d_head: usize,
/// Number of heads in the shift SSM.
pub n_heads: usize,
/// Dropout probability (stored in the config; not applied anywhere in this file).
pub dropout: f32,
/// Epsilon added to the variance inside layer normalization.
pub layer_norm_eps: f32,
}
impl Default for H3Config {
fn default() -> Self {
Self {
d_input: 128,
d_model: 512,
d_state: 64,
n_layers: 6,
d_head: 64,
n_heads: 8,
dropout: 0.1,
layer_norm_eps: 1e-5,
}
}
}
impl H3Config {
    /// Builds a config with the given input size, model size, and layer count.
    ///
    /// Uses 8 heads and derives `d_head = d_model / n_heads`; the remaining
    /// fields (`d_state`, `dropout`, `layer_norm_eps`) come from `Default`.
    pub fn new(d_input: usize, d_model: usize, n_layers: usize) -> Self {
        let n_heads = 8;
        let d_head = d_model / n_heads;
        Self {
            d_input,
            d_model,
            n_layers,
            d_head,
            n_heads,
            ..Default::default()
        }
    }

    /// Overrides the SSM state dimensionality.
    pub fn with_d_state(mut self, d_state: usize) -> Self {
        self.d_state = d_state;
        self
    }

    /// Overrides the head count, recomputing `d_head` when `n_heads > 0`.
    /// A zero `n_heads` leaves `d_head` unchanged and is rejected by `validate`.
    pub fn with_n_heads(mut self, n_heads: usize) -> Self {
        self.n_heads = n_heads;
        if let Some(d_head) = self.d_model.checked_div(n_heads) {
            self.d_head = d_head;
        }
        self
    }

    /// Checks the configuration for internal consistency.
    ///
    /// # Errors
    /// Returns `CoreError::InvalidConfig` when any dimension or the layer
    /// count is zero, when `n_heads` is zero, or when `d_model` is not
    /// divisible by `n_heads`.
    pub fn validate(&self) -> CoreResult<()> {
        if self.d_input == 0 || self.d_model == 0 || self.d_state == 0 || self.n_layers == 0 {
            return Err(CoreError::InvalidConfig(
                "All dimensions and layers must be positive".to_string(),
            ));
        }
        // Check n_heads BEFORE divisibility: `d_model.is_multiple_of(0)` is
        // false for non-zero d_model, so the previous ordering reported a
        // misleading "not divisible" error and made this check unreachable.
        if self.n_heads == 0 {
            return Err(CoreError::InvalidConfig(
                "n_heads must be positive".to_string(),
            ));
        }
        if !self.d_model.is_multiple_of(self.n_heads) {
            return Err(CoreError::InvalidConfig(
                "d_model must be divisible by n_heads".to_string(),
            ));
        }
        Ok(())
    }
}
/// Shift state-space model: each head keeps a d_state-long buffer that is
/// shifted by one slot per step, with new input injected at slot 0.
pub struct ShiftSSM {
/// Model (feature) dimensionality.
d_model: usize,
/// Per-head recurrent state length.
d_state: usize,
/// Feature width handled by each head (d_model / n_heads).
d_head: usize,
/// Number of heads; `state` holds one row per head.
n_heads: usize,
/// Per-slot decay applied during the shift (set to exp(-ln(i+1)) = 1/(i+1) in `new`).
shift_coeffs: Array1<f32>,
/// Input projections; only `v_proj` actually drives the recurrence in `forward`.
q_proj: Array2<f32>,
k_proj: Array2<f32>,
v_proj: Array2<f32>,
/// Recurrent state, shape (n_heads, d_state); zeroed by `reset`.
state: Array2<f32>, }
impl ShiftSSM {
    /// Builds a shift SSM sized per `config`, with uniformly-initialized
    /// Q/K/V projections and a zeroed per-head recurrent state.
    pub fn new(config: &H3Config) -> Self {
        let d_model = config.d_model;
        let d_state = config.d_state;
        let d_head = config.d_head;
        let n_heads = config.n_heads;
        // exp(-ln(i + 1)) == 1 / (i + 1): harmonically decaying shift gains.
        let shift_coeffs = Array1::from_shape_fn(d_state, |i| (-((i + 1) as f32).ln()).exp());
        // Uniform init in [-scale, scale].
        let scale = (2.0 / d_model as f32).sqrt();
        let q_proj =
            Array2::from_shape_fn((d_model, d_model), |_| (random_f32() - 0.5) * 2.0 * scale);
        let k_proj =
            Array2::from_shape_fn((d_model, d_model), |_| (random_f32() - 0.5) * 2.0 * scale);
        let v_proj =
            Array2::from_shape_fn((d_model, d_model), |_| (random_f32() - 0.5) * 2.0 * scale);
        let state = Array2::zeros((n_heads, d_state));
        Self {
            d_model,
            d_state,
            d_head,
            n_heads,
            shift_coeffs,
            q_proj,
            k_proj,
            v_proj,
            state,
        }
    }

    /// Advances the recurrence by one step.
    ///
    /// Per head: the state is shifted down one slot (slot i takes
    /// `shift_coeffs[i] * state[i - 1]`), slot 0 receives the first element
    /// of the head's V projection, and the leading `min(d_head, d_state)`
    /// state entries become that head's output slice.
    ///
    /// # Errors
    /// Returns `CoreError::DimensionMismatch` when `x.len() != d_model`.
    pub fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        if x.len() != self.d_model {
            return Err(CoreError::DimensionMismatch {
                expected: self.d_model,
                got: x.len(),
            });
        }
        // Only the V projection feeds the recurrence; the Q and K matvecs the
        // original computed here were never used, so they are dropped to save
        // two O(d_model^2) products per step. The q_proj/k_proj weights stay
        // as fields so the struct layout and initialization are unchanged.
        let v = self.v_proj.dot(x);
        let mut output = Array1::zeros(self.d_model);
        for h in 0..self.n_heads {
            let head_start = h * self.d_head;
            let head_end = (h + 1) * self.d_head;
            let v_head = v.slice(s![head_start..head_end]);
            // Shift with per-slot decay.
            let mut new_state = Array1::zeros(self.d_state);
            for i in 1..self.d_state {
                new_state[i] = self.state[[h, i - 1]] * self.shift_coeffs[i];
            }
            // Inject the newest input at slot 0 (guarded for a degenerate
            // zero-width head).
            if !v_head.is_empty() {
                new_state[0] = v_head[0];
            }
            // Read out the leading state entries as this head's output.
            let mut head_out = Array1::zeros(self.d_head);
            for i in 0..self.d_head.min(self.d_state) {
                head_out[i] = new_state[i];
            }
            // Persist the updated state for the next step.
            for i in 0..self.d_state {
                self.state[[h, i]] = new_state[i];
            }
            for (i, &val) in head_out.iter().enumerate() {
                output[head_start + i] = val;
            }
        }
        Ok(output)
    }

    /// Clears the recurrent state of every head.
    pub fn reset(&mut self) {
        self.state.fill(0.0);
    }
}
/// Diagonal state-space model: `state <- a ⊙ state + B·x`, `y = C·state + D ⊙ x`.
pub struct DiagonalSSM {
/// Input/output dimensionality.
d_model: usize,
/// Recurrent state length.
d_state: usize,
/// Diagonal recurrence coefficients (negative; set to -(i+1)/d_state in `new`).
a_diag: Array1<f32>,
/// Input matrix B, shape (d_state, d_model).
b_matrix: Array2<f32>,
/// Readout matrix C, shape (d_model, d_state).
c_matrix: Array2<f32>,
/// Direct feed-through D (initialized to all zeros in `new`).
d_matrix: Array1<f32>,
/// Recurrent state vector; zeroed by `reset`.
state: Array1<f32>,
}
impl DiagonalSSM {
    /// Constructs a diagonal SSM sized per `config`, with random B/C
    /// matrices, a zero feed-through D, and a zeroed state.
    pub fn new(config: &H3Config) -> Self {
        let d_model = config.d_model;
        let d_state = config.d_state;
        // Evenly spaced negative recurrence coefficients: -(i + 1) / d_state.
        let a_diag = Array1::from_shape_fn(d_state, |i| -((i + 1) as f32 / d_state as f32));
        // Uniform init in [-scale, scale].
        let scale = (1.0 / d_model as f32).sqrt();
        let b_matrix =
            Array2::from_shape_fn((d_state, d_model), |_| (random_f32() - 0.5) * 2.0 * scale);
        let c_matrix =
            Array2::from_shape_fn((d_model, d_state), |_| (random_f32() - 0.5) * 2.0 * scale);
        Self {
            d_model,
            d_state,
            a_diag,
            b_matrix,
            c_matrix,
            d_matrix: Array1::zeros(d_model),
            state: Array1::zeros(d_state),
        }
    }

    /// One recurrent step: `state <- a ⊙ state + B·x`, returns `C·state + D ⊙ x`.
    ///
    /// # Errors
    /// Returns `CoreError::DimensionMismatch` when `x.len() != d_model`.
    pub fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        if x.len() != self.d_model {
            return Err(CoreError::DimensionMismatch {
                expected: self.d_model,
                got: x.len(),
            });
        }
        // Input drive B·x, one entry per state slot.
        let driven = self.b_matrix.dot(x);
        for (i, &bu) in driven.iter().enumerate() {
            self.state[i] = self.a_diag[i] * self.state[i] + bu;
        }
        // Readout plus direct feed-through.
        let readout = self.c_matrix.dot(&self.state);
        Ok(Array1::from_shape_fn(self.d_model, |i| {
            readout[i] + self.d_matrix[i] * x[i]
        }))
    }

    /// Zeroes the recurrent state.
    pub fn reset(&mut self) {
        self.state.fill(0.0);
    }
}
/// One H3 block: pre-norm, parallel shift + diagonal SSM branches blended by
/// a sigmoid gate, output projection with residual connection, then post-norm.
pub struct H3Layer {
/// Layer configuration (dimensions, layer-norm epsilon, ...).
config: H3Config,
/// Shift-SSM branch.
shift_ssm: ShiftSSM,
/// Diagonal-SSM branch.
diag_ssm: DiagonalSSM,
/// Projection producing the gate logits from the normalized input.
gate_proj: Array2<f32>,
/// Final (d_model x d_model) projection applied to the gated combination.
output_proj: Array2<f32>,
/// Scale/shift parameters for the pre-SSM layer norm.
ln1_weight: Array1<f32>,
ln1_bias: Array1<f32>,
/// Scale/shift parameters for the post-residual layer norm.
ln2_weight: Array1<f32>,
ln2_bias: Array1<f32>,
}
impl H3Layer {
    /// Builds one H3 layer from a validated `config`.
    ///
    /// Initialization order (SSMs first, then gate/output projections) is
    /// kept so that random-weight draws match the previous behavior.
    ///
    /// # Errors
    /// Propagates `CoreError::InvalidConfig` from `H3Config::validate`.
    pub fn new(config: H3Config) -> CoreResult<Self> {
        config.validate()?;
        let shift_ssm = ShiftSSM::new(&config);
        let diag_ssm = DiagonalSSM::new(&config);
        let d = config.d_model;
        // Uniform init in [-scale, scale].
        let scale = (2.0 / d as f32).sqrt();
        let gate_proj = Array2::from_shape_fn((d, d), |_| (random_f32() - 0.5) * 2.0 * scale);
        let output_proj = Array2::from_shape_fn((d, d), |_| (random_f32() - 0.5) * 2.0 * scale);
        Ok(Self {
            shift_ssm,
            diag_ssm,
            gate_proj,
            output_proj,
            ln1_weight: Array1::ones(d),
            ln1_bias: Array1::zeros(d),
            ln2_weight: Array1::ones(d),
            ln2_bias: Array1::zeros(d),
            config,
        })
    }

    /// Runs one step of the layer.
    ///
    /// Pre-norms `x`, feeds it to both SSM branches, blends them with a
    /// per-element sigmoid gate, projects, adds the residual, and post-norms.
    ///
    /// # Errors
    /// Propagates dimension-mismatch errors from the SSM branches.
    pub fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let x_norm = self.layer_norm(x, &self.ln1_weight, &self.ln1_bias)?;
        let shift_out = self.shift_ssm.forward(&x_norm)?;
        let diag_out = self.diag_ssm.forward(&x_norm)?;
        // Sigmoid gate: 1 selects the shift branch, 0 the diagonal branch.
        let gate_act = self
            .gate_proj
            .dot(&x_norm)
            .mapv(|v| 1.0 / (1.0 + (-v).exp()));
        let combined = Array1::from_shape_fn(self.config.d_model, |i| {
            gate_act[i] * shift_out[i] + (1.0 - gate_act[i]) * diag_out[i]
        });
        // Residual connection around the projected combination, then post-norm.
        let residual = x + &self.output_proj.dot(&combined);
        self.layer_norm(&residual, &self.ln2_weight, &self.ln2_bias)
    }

    /// Layer normalization with the given scale (`weight`) and shift (`bias`).
    fn layer_norm(
        &self,
        x: &Array1<f32>,
        weight: &Array1<f32>,
        bias: &Array1<f32>,
    ) -> CoreResult<Array1<f32>> {
        let mean = x.mean().unwrap_or(0.0);
        let var = x.mapv(|v| (v - mean).powi(2)).mean().unwrap_or(0.0);
        let std = (var + self.config.layer_norm_eps).sqrt();
        Ok(Array1::from_shape_fn(x.len(), |i| {
            ((x[i] - mean) / std) * weight[i] + bias[i]
        }))
    }

    /// Zeroes the recurrent state of both SSM branches.
    pub fn reset(&mut self) {
        self.shift_ssm.reset();
        self.diag_ssm.reset();
    }
}
/// Full H3 model: linear input embedding, a stack of H3 layers, and a final
/// layer normalization.
pub struct H3Model {
/// Model configuration shared (cloned) by every layer.
config: H3Config,
/// (d_model x d_input) input embedding matrix.
embedding: Array2<f32>,
/// Stacked H3 layers, applied in order.
layers: Vec<H3Layer>,
/// Final layer-norm scale and shift parameters.
ln_out_weight: Array1<f32>,
ln_out_bias: Array1<f32>,
}
impl H3Model {
    /// Builds the model from `config`: a random input embedding, `n_layers`
    /// independently-initialized H3 layers, and identity final-norm params.
    ///
    /// # Errors
    /// Propagates `CoreError::InvalidConfig` from `H3Config::validate`.
    pub fn new(config: H3Config) -> CoreResult<Self> {
        config.validate()?;
        // Uniform init in [-scale, scale], scaled by the input fan-in.
        let scale = (1.0 / config.d_input as f32).sqrt();
        let embedding = Array2::from_shape_fn((config.d_model, config.d_input), |_| {
            (random_f32() - 0.5) * 2.0 * scale
        });
        let mut layers = Vec::with_capacity(config.n_layers);
        for _ in 0..config.n_layers {
            layers.push(H3Layer::new(config.clone())?);
        }
        Ok(Self {
            embedding,
            layers,
            ln_out_weight: Array1::ones(config.d_model),
            ln_out_bias: Array1::zeros(config.d_model),
            config,
        })
    }

    /// Embeds `x`, runs it through every layer in order, then applies the
    /// final layer norm.
    ///
    /// # Errors
    /// Returns `CoreError::DimensionMismatch` when `x.len() != d_input`.
    pub fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        if x.len() != self.config.d_input {
            return Err(CoreError::DimensionMismatch {
                expected: self.config.d_input,
                got: x.len(),
            });
        }
        let mut hidden = self.embedding.dot(x);
        for layer in self.layers.iter_mut() {
            hidden = layer.forward(&hidden)?;
        }
        self.layer_norm(&hidden)
    }

    /// Layer normalization with the model's output scale/shift parameters.
    fn layer_norm(&self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let mean = x.mean().unwrap_or(0.0);
        let var = x.mapv(|v| (v - mean).powi(2)).mean().unwrap_or(0.0);
        let std = (var + self.config.layer_norm_eps).sqrt();
        Ok(Array1::from_shape_fn(x.len(), |i| {
            ((x[i] - mean) / std) * self.ln_out_weight[i] + self.ln_out_bias[i]
        }))
    }

    /// Zeroes the recurrent state of every layer.
    pub fn reset(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.reset();
        }
    }

    /// Read-only access to the model configuration.
    pub fn config(&self) -> &H3Config {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_h3_config() {
        // The constructor wires the given dimensions through unchanged and
        // the result validates.
        let cfg = H3Config::new(64, 256, 4);
        assert_eq!(cfg.d_input, 64);
        assert_eq!(cfg.d_model, 256);
        assert_eq!(cfg.n_layers, 4);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_h3_config_validation() {
        // A zero d_model must be rejected.
        let zero_dim = H3Config {
            d_model: 0,
            ..Default::default()
        };
        assert!(zero_dim.validate().is_err());
        // 100 is not divisible by 3 heads, so this must also be rejected.
        let indivisible = H3Config {
            d_model: 100,
            n_heads: 3,
            ..Default::default()
        };
        assert!(indivisible.validate().is_err());
    }

    #[test]
    fn test_shift_ssm_creation() {
        let cfg = H3Config::new(64, 256, 2);
        let ssm = ShiftSSM::new(&cfg);
        assert_eq!(ssm.d_model, 256);
        assert_eq!(ssm.d_state, 64);
    }

    #[test]
    fn test_shift_ssm_forward() {
        let cfg = H3Config::new(64, 256, 2);
        let mut ssm = ShiftSSM::new(&cfg);
        let input = Array1::from_elem(256, 0.5);
        let out = ssm.forward(&input);
        assert!(out.is_ok());
        let out = out.unwrap();
        assert_eq!(out.len(), 256);
        assert!(out.iter().all(|&v| v.is_finite()));
    }

    #[test]
    fn test_diag_ssm_creation() {
        let cfg = H3Config::new(64, 256, 2);
        let ssm = DiagonalSSM::new(&cfg);
        assert_eq!(ssm.d_model, 256);
        assert_eq!(ssm.d_state, 64);
    }

    #[test]
    fn test_diag_ssm_forward() {
        let cfg = H3Config::new(64, 256, 2);
        let mut ssm = DiagonalSSM::new(&cfg);
        let input = Array1::from_elem(256, 0.3);
        let out = ssm.forward(&input);
        assert!(out.is_ok());
        let out = out.unwrap();
        assert_eq!(out.len(), 256);
        assert!(out.iter().all(|&v| v.is_finite()));
    }

    #[test]
    fn test_h3_layer_creation() {
        let cfg = H3Config::new(64, 256, 2);
        assert!(H3Layer::new(cfg).is_ok());
    }

    #[test]
    fn test_h3_layer_forward() {
        let mut layer = H3Layer::new(H3Config::new(64, 256, 2)).unwrap();
        let out = layer.forward(&Array1::from_elem(256, 0.2));
        assert!(out.is_ok());
        let out = out.unwrap();
        assert_eq!(out.len(), 256);
        assert!(out.iter().all(|&v| v.is_finite()));
    }

    #[test]
    fn test_h3_model_creation() {
        let model = H3Model::new(H3Config::new(64, 256, 4));
        assert!(model.is_ok());
        // One H3Layer per configured layer.
        assert_eq!(model.unwrap().layers.len(), 4);
    }

    #[test]
    fn test_h3_model_forward() {
        let mut model = H3Model::new(H3Config::new(64, 256, 3)).unwrap();
        let out = model.forward(&Array1::from_elem(64, 0.1));
        assert!(out.is_ok());
        let out = out.unwrap();
        // The output lives in model space (d_model), not input space.
        assert_eq!(out.len(), 256);
        assert!(out.iter().all(|&v| v.is_finite()));
    }

    #[test]
    fn test_h3_reset() {
        let mut model = H3Model::new(H3Config::new(64, 256, 2)).unwrap();
        let _ = model.forward(&Array1::from_elem(64, 0.5)).unwrap();
        model.reset();
        // Every SSM state must be zeroed after reset.
        for layer in &model.layers {
            assert!(layer.shift_ssm.state.iter().all(|&v| v == 0.0));
            assert!(layer.diag_ssm.state.iter().all(|&v| v == 0.0));
        }
    }

    #[test]
    fn test_h3_sequence_processing() {
        let mut model = H3Model::new(H3Config::new(32, 128, 2)).unwrap();
        // Feed three steps; every output must stay finite.
        for &level in &[0.1_f32, 0.5, 0.9] {
            let out = model.forward(&Array1::from_elem(32, level)).unwrap();
            assert!(out.iter().all(|&v| v.is_finite()));
        }
        // After processing, the recurrent state should have moved off zero.
        assert!(model.layers[0].shift_ssm.state.iter().any(|&v| v != 0.0));
    }
}