use crate::error::{ModelError, ModelResult};
use crate::{AutoregressiveModel, ModelType};
use kizzasi_core::{sigmoid, silu, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor};
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
#[allow(unused_imports)]
use tracing::{debug, instrument, trace};
/// Minimal deterministic xorshift64 generator used for reproducible weight init.
struct SeededRng {
    state: u64,
}
impl SeededRng {
    /// Create a generator; a zero seed is bumped to 1 (xorshift has a fixed point at 0).
    fn new(seed: u64) -> Self {
        SeededRng {
            state: if seed == 0 { 1 } else { seed },
        }
    }
    /// Advance the state with xorshift64 and map it to a value in roughly [-1, 1).
    fn next_f32(&mut self) -> f32 {
        let mut s = self.state;
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        self.state = s;
        let unit = s as f64 / u64::MAX as f64;
        (unit * 2.0 - 1.0) as f32
    }
}
/// Configuration for an [`Rwkv7Model`].
///
/// Invariant: `head_dim * num_heads` should equal `hidden_dim`; the builder
/// setters keep `head_dim` in sync when `hidden_dim` or `num_heads` change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Rwkv7Config {
    /// Dimensionality of the raw input/output signal.
    pub input_dim: usize,
    /// Width of the model's hidden representation.
    pub hidden_dim: usize,
    /// Number of stacked RWKV layers.
    pub num_layers: usize,
    /// Number of WKV heads per layer.
    pub num_heads: usize,
    /// Per-head dimension (hidden_dim / num_heads).
    pub head_dim: usize,
    /// Channel-mixing FFN expansion ratio (intermediate = hidden_dim * expand_factor).
    pub expand_factor: f32,
    /// Nominal context length; informational — the recurrent state is unbounded.
    pub context_length: usize,
    /// Initial time-decay value — NOTE(review): currently unused by the layers; confirm intent.
    pub time_decay_init: f32,
}
impl Default for Rwkv7Config {
fn default() -> Self {
let hidden_dim = 768;
let num_heads = 12;
Self {
input_dim: 1,
hidden_dim,
num_layers: 24,
num_heads,
head_dim: hidden_dim / num_heads,
expand_factor: 3.5,
context_length: 16384,
time_decay_init: -6.0,
}
}
}
impl Rwkv7Config {
    /// Default configuration (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }
    /// Small preset: 256 hidden, 4 layers, 4 heads. Suitable for tests and CPU use.
    pub fn small(input_dim: usize) -> Self {
        Self {
            input_dim,
            hidden_dim: 256,
            num_layers: 4,
            num_heads: 4,
            head_dim: 64,
            expand_factor: 3.5,
            context_length: 4096,
            time_decay_init: -5.0,
        }
    }
    /// Base preset: 768 hidden, 12 layers, 12 heads.
    pub fn base(input_dim: usize) -> Self {
        Self {
            input_dim,
            hidden_dim: 768,
            num_layers: 12,
            num_heads: 12,
            head_dim: 64,
            expand_factor: 3.5,
            context_length: 8192,
            time_decay_init: -6.0,
        }
    }
    /// Large preset: 4096 hidden, 32 layers, 32 heads of 128 dims.
    pub fn large(input_dim: usize) -> Self {
        Self {
            input_dim,
            hidden_dim: 4096,
            num_layers: 32,
            num_heads: 32,
            head_dim: 128,
            expand_factor: 3.5,
            context_length: 16384,
            time_decay_init: -6.0,
        }
    }
    /// Builder-style setter for `input_dim`.
    pub fn input_dim(mut self, dim: usize) -> Self {
        self.input_dim = dim;
        self
    }
    /// Builder-style setter for `hidden_dim`; keeps `head_dim` consistent.
    pub fn hidden_dim(mut self, dim: usize) -> Self {
        self.hidden_dim = dim;
        // checked_div only fails when num_heads == 0; head_dim is then left unchanged
        // and validate() will reject the config anyway.
        if let Some(d) = dim.checked_div(self.num_heads) {
            self.head_dim = d;
        }
        self
    }
    /// Builder-style setter for `num_layers`.
    pub fn num_layers(mut self, n: usize) -> Self {
        self.num_layers = n;
        self
    }
    /// Builder-style setter for `num_heads`; keeps `head_dim` consistent.
    pub fn num_heads(mut self, n: usize) -> Self {
        self.num_heads = n;
        if let Some(d) = self.hidden_dim.checked_div(n) {
            self.head_dim = d;
        }
        self
    }
    /// Builder-style setter for `context_length`.
    pub fn context_length(mut self, len: usize) -> Self {
        self.context_length = len;
        self
    }
    /// Check the configuration for internal consistency.
    ///
    /// # Errors
    /// Returns `ModelError::invalid_config` when any dimension is zero,
    /// `hidden_dim` is not divisible by `num_heads`, `head_dim` disagrees with
    /// `hidden_dim / num_heads`, or `expand_factor` is non-positive.
    pub fn validate(&self) -> ModelResult<()> {
        if self.hidden_dim == 0 {
            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
        }
        if self.num_layers == 0 {
            return Err(ModelError::invalid_config("num_layers must be > 0"));
        }
        if self.num_heads == 0 {
            return Err(ModelError::invalid_config("num_heads must be > 0"));
        }
        if !self.hidden_dim.is_multiple_of(self.num_heads) {
            return Err(ModelError::invalid_config(
                "hidden_dim must be divisible by num_heads",
            ));
        }
        // head_dim must agree with hidden_dim / num_heads: the time-mixing forward
        // slices the hidden vector into num_heads chunks of head_dim, so an
        // inconsistent head_dim would panic at runtime instead of failing here.
        if self.head_dim * self.num_heads != self.hidden_dim {
            return Err(ModelError::invalid_config(
                "head_dim * num_heads must equal hidden_dim",
            ));
        }
        if self.expand_factor <= 0.0 {
            return Err(ModelError::invalid_config("expand_factor must be > 0"));
        }
        Ok(())
    }
}
/// Mutable recurrent state for an [`Rwkv7Model`], kept outside the weights so
/// it can be reset or swapped independently.
pub struct Rwkv7State {
    /// Per-layer, per-head WKV matrices, each of shape (head_dim, head_dim).
    pub wkv_states: Vec<Vec<Array2<f32>>>,
    /// Per-layer previous hidden vector used for token shift (length hidden_dim).
    pub shift_states: Vec<Array1<f32>>,
}
impl Rwkv7State {
pub fn new(config: &Rwkv7Config) -> Self {
let wkv_states = (0..config.num_layers)
.map(|_| {
(0..config.num_heads)
.map(|_| Array2::zeros((config.head_dim, config.head_dim)))
.collect()
})
.collect();
let shift_states = (0..config.num_layers)
.map(|_| Array1::zeros(config.hidden_dim))
.collect();
Self {
wkv_states,
shift_states,
}
}
pub fn reset(&mut self) {
for layer_states in &mut self.wkv_states {
for head_state in layer_states.iter_mut() {
head_state.fill(0.0);
}
}
for shift in &mut self.shift_states {
shift.fill(0.0);
}
}
}
/// RWKV-7 time-mixing block (the attention replacement) for a single layer.
pub struct Rwkv7TimeMixing {
    // Dense projections, all (hidden_dim x hidden_dim): receptance, decay,
    // key, value, output, gate, and two auxiliary gates a/b (sigmoid-activated
    // in `forward`).
    w_r: Array2<f32>, w_w: Array2<f32>, w_k: Array2<f32>, w_v: Array2<f32>, w_o: Array2<f32>, w_g: Array2<f32>, w_a: Array2<f32>, w_b: Array2<f32>,
    // Per-channel token-shift interpolation factors for the r/w/k/v inputs.
    lerp_r: Array1<f32>,
    lerp_w: Array1<f32>,
    lerp_k: Array1<f32>,
    lerp_v: Array1<f32>,
    // RMS norm applied to the concatenated head outputs before gating.
    ln_x: LayerNorm,
    num_heads: usize,
    head_dim: usize,
}
impl Rwkv7TimeMixing {
pub fn new(config: &Rwkv7Config) -> ModelResult<Self> {
let d = config.hidden_dim;
let mut rng = SeededRng::new(42 + d as u64);
let scale = (2.0 / d as f32).sqrt();
let make_proj = |rng: &mut SeededRng| -> Array2<f32> {
Array2::from_shape_fn((d, d), |_| rng.next_f32() * scale)
};
let w_r = make_proj(&mut rng);
let w_w = make_proj(&mut rng);
let w_k = make_proj(&mut rng);
let w_v = make_proj(&mut rng);
let w_o = make_proj(&mut rng);
let w_g = make_proj(&mut rng);
let w_a = make_proj(&mut rng);
let w_b = make_proj(&mut rng);
let lerp_r = Array1::from_shape_fn(d, |_| rng.next_f32().abs() * 0.5 + 0.25);
let lerp_w = Array1::from_shape_fn(d, |_| rng.next_f32().abs() * 0.5 + 0.25);
let lerp_k = Array1::from_shape_fn(d, |_| rng.next_f32().abs() * 0.5 + 0.25);
let lerp_v = Array1::from_shape_fn(d, |_| rng.next_f32().abs() * 0.5 + 0.25);
let ln_x = LayerNorm::new(d, NormType::RMSNorm).with_eps(1e-5);
Ok(Self {
w_r,
w_w,
w_k,
w_v,
w_o,
w_g,
w_a,
w_b,
lerp_r,
lerp_w,
lerp_k,
lerp_v,
ln_x,
num_heads: config.num_heads,
head_dim: config.head_dim,
})
}
pub fn forward(
&self,
x: &Array1<f32>,
state: &mut Rwkv7State,
layer_idx: usize,
) -> ModelResult<Array1<f32>> {
let d = x.len();
let prev = &state.shift_states[layer_idx];
let dx = x - prev;
state.shift_states[layer_idx] = x.clone();
let xr = x + &(&self.lerp_r * &dx);
let xw = x + &(&self.lerp_w * &dx);
let xk = x + &(&self.lerp_k * &dx);
let xv = x + &(&self.lerp_v * &dx);
let r_raw = self.matvec(&self.w_r, &xr);
let w_raw = self.matvec(&self.w_w, &xw);
let k_raw = self.matvec(&self.w_k, &xk);
let v_raw = self.matvec(&self.w_v, &xv);
let r = sigmoid(&r_raw); let w = sigmoid(&w_raw); let g = silu(&self.matvec(&self.w_g, x)); let a = sigmoid(&self.matvec(&self.w_a, x)); let b = sigmoid(&self.matvec(&self.w_b, x));
let mut output_heads = Array1::zeros(d);
for h in 0..self.num_heads {
let lo = h * self.head_dim;
let hi = lo + self.head_dim;
let r_h = r.slice(scirs2_core::ndarray::s![lo..hi]).to_owned();
let k_h = k_raw.slice(scirs2_core::ndarray::s![lo..hi]).to_owned();
let v_h = v_raw.slice(scirs2_core::ndarray::s![lo..hi]).to_owned();
let w_h = w.slice(scirs2_core::ndarray::s![lo..hi]).to_owned();
let a_h = a.slice(scirs2_core::ndarray::s![lo..hi]).to_owned();
let b_h = b.slice(scirs2_core::ndarray::s![lo..hi]).to_owned();
let head_state = &mut state.wkv_states[layer_idx][h];
for i in 0..self.head_dim {
let decay = w_h[i].clamp(0.0, 1.0);
for j in 0..self.head_dim {
head_state[[i, j]] = decay * head_state[[i, j]] + k_h[i] * v_h[j];
}
}
let state_b = self.matvec_small(head_state, &b_h);
for i in 0..self.head_dim {
let val = r_h[i] * (state_b[i] + a_h[i] * v_h[i]);
output_heads[lo + i] = val;
}
}
let normed = self.ln_x.forward(&output_heads);
let gated = &g * &normed;
let out = self.matvec(&self.w_o, &gated);
Ok(out)
}
fn matvec(&self, w: &Array2<f32>, x: &Array1<f32>) -> Array1<f32> {
let rows = w.shape()[0];
let cols = w.shape()[1];
let xlen = x.len();
let mut out = Array1::zeros(rows);
for i in 0..rows {
let mut sum = 0.0f32;
for j in 0..cols.min(xlen) {
sum += w[[i, j]] * x[j];
}
out[i] = sum;
}
out
}
fn matvec_small(&self, w: &Array2<f32>, x: &Array1<f32>) -> Array1<f32> {
let rows = w.shape()[0];
let cols = w.shape()[1];
let xlen = x.len();
let mut out = Array1::zeros(rows);
for i in 0..rows {
let mut sum = 0.0f32;
for j in 0..cols.min(xlen) {
sum += w[[i, j]] * x[j];
}
out[i] = sum;
}
out
}
}
/// RWKV channel-mixing block: a token-shifted, receptance-gated FFN.
struct Rwkv7ChannelMixing {
    hidden_dim: usize,
    // FFN inner width: hidden_dim * expand_factor (truncated to usize).
    intermediate_dim: usize,
    // Per-channel mixing factors between current input and previous token.
    time_mix_k: Array1<f32>,
    time_mix_r: Array1<f32>,
    // key_proj: (hidden, inter); value_proj: (inter, hidden); receptance_proj: (hidden, hidden).
    key_proj: Array2<f32>, value_proj: Array2<f32>, receptance_proj: Array2<f32>,
    // Previous input vector for token shift (this block's own recurrent state).
    prev_x: Array1<f32>,
}
impl Rwkv7ChannelMixing {
fn new(config: &Rwkv7Config) -> ModelResult<Self> {
let d = config.hidden_dim;
let inter = (d as f32 * config.expand_factor) as usize;
let mut rng = SeededRng::new(137 + d as u64 + inter as u64);
let scale = (2.0 / d as f32).sqrt();
let time_mix_k = Array1::from_shape_fn(d, |_| rng.next_f32().abs() * 0.5 + 0.25);
let time_mix_r = Array1::from_shape_fn(d, |_| rng.next_f32().abs() * 0.5 + 0.25);
let key_proj = Array2::from_shape_fn((d, inter), |_| rng.next_f32() * scale);
let value_proj = Array2::from_shape_fn((inter, d), |_| rng.next_f32() * scale);
let receptance_proj = Array2::from_shape_fn((d, d), |_| rng.next_f32() * scale);
Ok(Self {
hidden_dim: d,
intermediate_dim: inter,
time_mix_k,
time_mix_r,
key_proj,
value_proj,
receptance_proj,
prev_x: Array1::zeros(d),
})
}
fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
let d = x.len().min(self.hidden_dim);
let mut xk = Array1::zeros(d);
let mut xr = Array1::zeros(d);
for i in 0..d {
let prev = if i < self.prev_x.len() {
self.prev_x[i]
} else {
0.0
};
xk[i] = self.time_mix_k[i] * x[i] + (1.0 - self.time_mix_k[i]) * prev;
xr[i] = self.time_mix_r[i] * x[i] + (1.0 - self.time_mix_r[i]) * prev;
}
let k = self.project_up(&xk);
let k_act = k.mapv(|v| {
let relu = v.max(0.0);
relu * relu
});
let vk = self.project_down(&k_act);
let r = self.project_r(&xr);
let r_sig = sigmoid(&r);
let mut output = Array1::zeros(d);
for i in 0..d.min(vk.len()).min(r_sig.len()) {
output[i] = r_sig[i] * vk[i];
}
self.prev_x = x.slice(scirs2_core::ndarray::s![..d]).to_owned();
Ok(output)
}
fn project_up(&self, x: &Array1<f32>) -> Array1<f32> {
let out_dim = self.intermediate_dim;
let mut output = Array1::zeros(out_dim);
for i in 0..out_dim {
let mut sum = 0.0f32;
for j in 0..x.len().min(self.key_proj.shape()[0]) {
sum += self.key_proj[[j, i]] * x[j];
}
output[i] = sum;
}
output
}
fn project_down(&self, x: &Array1<f32>) -> Array1<f32> {
let out_dim = self.hidden_dim;
let mut output = Array1::zeros(out_dim);
for i in 0..out_dim {
let mut sum = 0.0f32;
for j in 0..x.len().min(self.value_proj.shape()[0]) {
sum += self.value_proj[[j, i]] * x[j];
}
output[i] = sum;
}
output
}
fn project_r(&self, x: &Array1<f32>) -> Array1<f32> {
let out_dim = self.receptance_proj.shape()[0];
let mut output = Array1::zeros(out_dim.min(x.len()));
for i in 0..output.len() {
let mut sum = 0.0f32;
for j in 0..x.len().min(self.receptance_proj.shape()[1]) {
sum += self.receptance_proj[[i, j]] * x[j];
}
output[i] = sum;
}
output
}
fn reset(&mut self) {
self.prev_x.fill(0.0);
}
}
/// One RWKV-7 layer: pre-norm time mixing followed by pre-norm channel mixing,
/// each wrapped in a residual connection.
struct Rwkv7Layer {
    ln1: LayerNorm,
    ln2: LayerNorm,
    time_mixing: Rwkv7TimeMixing,
    channel_mixing: Rwkv7ChannelMixing,
}
impl Rwkv7Layer {
    /// Construct both norms and both mixing blocks from the shared config.
    fn new(config: &Rwkv7Config) -> ModelResult<Self> {
        Ok(Self {
            ln1: LayerNorm::new(config.hidden_dim, NormType::RMSNorm).with_eps(1e-5),
            ln2: LayerNorm::new(config.hidden_dim, NormType::RMSNorm).with_eps(1e-5),
            time_mixing: Rwkv7TimeMixing::new(config)?,
            channel_mixing: Rwkv7ChannelMixing::new(config)?,
        })
    }
    /// One recurrent step through the layer (pre-norm residual structure).
    fn forward(
        &mut self,
        x: &Array1<f32>,
        state: &mut Rwkv7State,
        layer_idx: usize,
    ) -> ModelResult<Array1<f32>> {
        // Residual 1: time mixing on the normalized input.
        let tm_in = self.ln1.forward(x);
        let tm_out = self.time_mixing.forward(&tm_in, state, layer_idx)?;
        let residual = x + &tm_out;
        // Residual 2: channel mixing on the normalized intermediate.
        let cm_in = self.ln2.forward(&residual);
        let cm_out = self
            .channel_mixing
            .forward(&cm_in)
            .map_err(|e| ModelError::forward_error(layer_idx, format!("channel mixing: {e}")))?;
        Ok(&residual + &cm_out)
    }
    /// Clear the channel-mixing token-shift memory (WKV state lives in Rwkv7State).
    fn reset_channel_mixing(&mut self) {
        self.channel_mixing.reset();
    }
}
/// Full RWKV-7 model: input projection, a stack of layers, an output norm,
/// an output projection, and the recurrent state for streaming inference.
pub struct Rwkv7Model {
    pub config: Rwkv7Config,
    layers: Vec<Rwkv7Layer>,
    ln_out: LayerNorm,
    // (input_dim x hidden_dim): projects the raw signal into the hidden space.
    input_proj: Array2<f32>,
    // (hidden_dim x input_dim): projects the hidden state back to signal space.
    output_proj: Array2<f32>,
    // Owned recurrent state, advanced by `step` and cleared by `reset`.
    state: Rwkv7State,
}
impl Rwkv7Model {
pub fn new(config: Rwkv7Config) -> ModelResult<Self> {
config.validate()?;
let mut layers = Vec::with_capacity(config.num_layers);
for _ in 0..config.num_layers {
layers.push(Rwkv7Layer::new(&config)?);
}
let ln_out = LayerNorm::new(config.hidden_dim, NormType::RMSNorm).with_eps(1e-5);
let mut rng = SeededRng::new(7777 + config.hidden_dim as u64);
let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
rng.next_f32() * scale
});
let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
rng.next_f32() * scale
});
let state = Rwkv7State::new(&config);
debug!(
"Created RWKV v7 model: {} layers, {} hidden, {} heads",
config.num_layers, config.hidden_dim, config.num_heads
);
Ok(Self {
config,
layers,
ln_out,
input_proj,
output_proj,
state,
})
}
pub fn small() -> ModelResult<Self> {
Self::new(Rwkv7Config::small(1))
}
pub fn base() -> ModelResult<Self> {
Self::new(Rwkv7Config::base(1))
}
pub fn large() -> ModelResult<Self> {
Self::new(Rwkv7Config::large(1))
}
pub fn init_state(&self) -> Rwkv7State {
Rwkv7State::new(&self.config)
}
pub fn config(&self) -> &Rwkv7Config {
&self.config
}
}
impl SignalPredictor for Rwkv7Model {
#[instrument(skip(self, input))]
fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
let mut hidden = input.dot(&self.input_proj);
for layer_idx in 0..self.layers.len() {
let layer = &mut self.layers[layer_idx];
hidden = layer
.forward(&hidden, &mut self.state, layer_idx)
.map_err(|e| {
kizzasi_core::CoreError::InferenceError(format!("rwkv7 layer {layer_idx}: {e}"))
})?;
}
hidden = self.ln_out.forward(&hidden);
let output = hidden.dot(&self.output_proj);
Ok(output)
}
fn reset(&mut self) {
self.state.reset();
for layer in &mut self.layers {
layer.reset_channel_mixing();
}
}
fn context_window(&self) -> usize {
usize::MAX
}
}
impl AutoregressiveModel for Rwkv7Model {
    fn hidden_dim(&self) -> usize {
        self.config.hidden_dim
    }
    // For a validated config this equals hidden_dim (head_dim * num_heads).
    fn state_dim(&self) -> usize {
        self.config.head_dim * self.config.num_heads
    }
    fn num_layers(&self) -> usize {
        self.config.num_layers
    }
    fn model_type(&self) -> ModelType {
        ModelType::Rwkv
    }
    /// Snapshot each layer's WKV state as one stacked matrix: the per-head
    /// (head_dim x head_dim) blocks are concatenated row-wise into a
    /// (num_heads * head_dim, head_dim) matrix per layer.
    ///
    /// NOTE(review): only the WKV matrices are exported — the token-shift
    /// vectors (`shift_states`) and each layer's channel-mixing `prev_x` are
    /// not included, so get/set does not round-trip the *full* recurrent
    /// state. Confirm whether that is intended.
    fn get_states(&self) -> Vec<HiddenState> {
        self.state
            .wkv_states
            .iter()
            .map(|layer_heads| {
                let total_rows = self.config.num_heads * self.config.head_dim;
                let cols = self.config.head_dim;
                let mut combined = Array2::zeros((total_rows, cols));
                for (h, head_state) in layer_heads.iter().enumerate() {
                    // Head h occupies rows [h * head_dim, (h + 1) * head_dim).
                    let row_start = h * self.config.head_dim;
                    for i in 0..self.config.head_dim {
                        for j in 0..cols {
                            combined[[row_start + i, j]] = head_state[[i, j]];
                        }
                    }
                }
                let mut hs = HiddenState::new(total_rows, cols);
                hs.update(combined);
                hs
            })
            .collect()
    }
    /// Restore the per-layer WKV state from the layout produced by `get_states`.
    ///
    /// # Errors
    /// Returns a state-count mismatch error if `states.len() != num_layers`.
    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != self.config.num_layers {
            return Err(ModelError::state_count_mismatch(
                "RWKV7",
                self.config.num_layers,
                states.len(),
            ));
        }
        for (layer_idx, hs) in states.iter().enumerate() {
            let combined = hs.state();
            for h in 0..self.config.num_heads {
                let row_start = h * self.config.head_dim;
                let head_state = &mut self.state.wkv_states[layer_idx][h];
                for i in 0..self.config.head_dim {
                    for j in 0..self.config.head_dim {
                        // Bounds guard: silently skips entries outside the provided
                        // matrix rather than erroring on a shape mismatch.
                        if row_start + i < combined.shape()[0] && j < combined.shape()[1] {
                            head_state[[i, j]] = combined[[row_start + i, j]];
                        }
                    }
                }
            }
        }
        Ok(())
    }
}
/// Convenience alias for [`Rwkv7Model`].
pub type Rwkv7 = Rwkv7Model;
#[cfg(test)]
mod tests {
    use super::*;
    /// Shared 64-dim, 2-layer, 4-head config small enough for fast CPU tests.
    fn tiny_config() -> Rwkv7Config {
        Rwkv7Config {
            input_dim: 1,
            hidden_dim: 64,
            num_layers: 2,
            num_heads: 4,
            head_dim: 16,
            expand_factor: 2.0,
            context_length: 256,
            time_decay_init: -5.0,
        }
    }
    /// Default config validates; zero hidden_dim and non-divisible head counts are rejected.
    #[test]
    fn test_rwkv7_config_valid() {
        let config = Rwkv7Config::new();
        assert!(config.validate().is_ok());
        let bad = Rwkv7Config {
            hidden_dim: 0,
            ..Rwkv7Config::default()
        };
        assert!(bad.validate().is_err());
        let bad2 = Rwkv7Config {
            hidden_dim: 100,
            num_heads: 3,
            ..Rwkv7Config::default()
        };
        assert!(bad2.validate().is_err());
    }
    /// A single step produces a finite output of input_dim length.
    #[test]
    fn test_rwkv7_small_forward() {
        let config = tiny_config();
        let mut model = Rwkv7Model::new(config).expect("model creation");
        let input = Array1::from_vec(vec![0.5]);
        let output = model.step(&input).expect("forward step");
        assert_eq!(output.len(), 1, "output should match input_dim");
        assert!(output[0].is_finite(), "output must be finite");
    }
    /// Repeated steps on the same input stay finite (state does not blow up).
    #[test]
    fn test_rwkv7_state_persistence() {
        let config = tiny_config();
        let mut model = Rwkv7Model::new(config).expect("model creation");
        let input = Array1::from_vec(vec![0.1]);
        for _ in 0..10 {
            let out = model.step(&input).expect("step");
            for &v in out.iter() {
                assert!(v.is_finite(), "output should stay finite over 10 steps");
                assert!(!v.is_nan(), "no NaN values");
            }
        }
    }
    /// After reset, the model behaves like a freshly constructed one
    /// (construction is deterministic via the seeded RNG).
    #[test]
    fn test_rwkv7_state_reset() {
        let config = tiny_config();
        let mut model = Rwkv7Model::new(config).expect("model creation");
        let input = Array1::from_vec(vec![0.3]);
        for _ in 0..5 {
            let _ = model.step(&input).expect("step");
        }
        model.reset();
        let out_after_reset = model.step(&input).expect("step after reset");
        let config2 = tiny_config();
        let mut fresh = Rwkv7Model::new(config2).expect("fresh model creation");
        let out_fresh = fresh.step(&input).expect("fresh step");
        for (a, b) in out_after_reset.iter().zip(out_fresh.iter()) {
            assert!(
                (a - b).abs() < 1e-5,
                "reset output should match fresh model: {a} vs {b}"
            );
        }
    }
    /// Deeper stacks also produce finite, correctly shaped output.
    #[test]
    fn test_rwkv7_multi_layer() {
        let mut config = tiny_config();
        config.num_layers = 4;
        let mut model = Rwkv7Model::new(config).expect("4-layer model");
        let input = Array1::from_vec(vec![0.42]);
        let out = model.step(&input).expect("forward");
        assert_eq!(out.len(), 1);
        assert!(out[0].is_finite());
    }
    /// SignalPredictor trait surface: step, reset, unbounded context window.
    #[test]
    fn test_rwkv7_signal_predictor_trait() {
        let config = tiny_config();
        let mut model = Rwkv7Model::new(config).expect("model");
        let input = Array1::from_vec(vec![1.0]);
        let out = model.step(&input).expect("step");
        assert_eq!(out.len(), 1);
        model.reset();
        assert_eq!(model.context_window(), usize::MAX);
    }
    /// get_states/set_states round-trips the WKV state between two models.
    #[test]
    fn test_rwkv7_autoregressive_trait() {
        let config = tiny_config();
        let mut model = Rwkv7Model::new(config.clone()).expect("model");
        let input = Array1::from_vec(vec![0.7]);
        let _ = model.step(&input).expect("step");
        let states = model.get_states();
        assert_eq!(states.len(), config.num_layers);
        let mut model2 = Rwkv7Model::new(config).expect("model2");
        model2.set_states(states.clone()).expect("set_states");
        let states2 = model2.get_states();
        assert_eq!(states.len(), states2.len());
        for (s1, s2) in states.iter().zip(states2.iter()) {
            let a = s1.state();
            let b = s2.state();
            assert_eq!(a.shape(), b.shape());
            for (va, vb) in a.iter().zip(b.iter()) {
                assert!((va - vb).abs() < 1e-6, "state roundtrip mismatch");
            }
        }
    }
    /// Extreme input magnitudes (large, tiny, negative) must not produce NaN/inf.
    #[test]
    fn test_rwkv7_numerical_stability() {
        let config = tiny_config();
        let mut model = Rwkv7Model::new(config).expect("model");
        let large_input = Array1::from_vec(vec![1000.0]);
        let out_large = model.step(&large_input).expect("large input step");
        for &v in out_large.iter() {
            assert!(
                v.is_finite(),
                "output should be finite for large input: {v}"
            );
        }
        model.reset();
        let small_input = Array1::from_vec(vec![1e-10]);
        let out_small = model.step(&small_input).expect("small input step");
        for &v in out_small.iter() {
            assert!(
                v.is_finite(),
                "output should be finite for small input: {v}"
            );
        }
        model.reset();
        let neg_input = Array1::from_vec(vec![-500.0]);
        let out_neg = model.step(&neg_input).expect("negative input step");
        for &v in out_neg.iter() {
            assert!(
                v.is_finite(),
                "output should be finite for negative input: {v}"
            );
        }
    }
    /// Dimension accessors report the tiny_config values.
    #[test]
    fn test_rwkv7_hidden_dim_state_dim() {
        let config = tiny_config();
        let model = Rwkv7Model::new(config).expect("model");
        assert_eq!(model.hidden_dim(), 64);
        assert_eq!(model.state_dim(), 64); assert_eq!(model.num_layers(), 2);
        assert_eq!(model.model_type(), ModelType::Rwkv);
    }
}