use crate::{CoreError, CoreResult};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::quick::random_f32;
use serde::{Deserialize, Serialize};
/// Hyper-parameters shared by every component of the RWKV-7 style model in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RWKV7Config {
    /// Length of the input vector accepted by `RWKV7Model::forward`; also the number of
    /// columns of the model's embedding matrix.
    pub d_input: usize,
    /// Width of the hidden representation processed by every layer.
    pub d_model: usize,
    /// Number of `RWKV7Layer`s stacked inside the model.
    pub n_layers: usize,
    /// Multiplier applied to `d_model` (result truncated to an integer) to obtain the
    /// hidden width of the channel-mixing block.
    pub ffn_factor: f32,
    /// Value added to the variance before the square root in layer normalisation.
    pub layer_norm_eps: f32,
    /// Initial value of every element of the time-mixing `time_decay` vector.
    pub time_decay_init: f32,
    /// Initial value of every element of the time-mixing `time_first` vector.
    pub time_first_init: f32,
}
impl Default for RWKV7Config {
fn default() -> Self {
Self {
d_input: 128,
d_model: 256,
n_layers: 6,
ffn_factor: 3.5,
layer_norm_eps: 1e-5,
time_decay_init: -5.0,
time_first_init: 0.0,
}
}
}
impl RWKV7Config {
    /// Creates a configuration with the given dimensions and layer count; every other
    /// field takes its [`Default`] value.
    pub fn new(d_input: usize, d_model: usize, n_layers: usize) -> Self {
        Self {
            d_input,
            d_model,
            n_layers,
            ..Default::default()
        }
    }

    /// Returns the configuration with `ffn_factor` replaced by `factor`.
    #[must_use]
    pub fn with_ffn_factor(mut self, factor: f32) -> Self {
        self.ffn_factor = factor;
        self
    }

    /// Returns the configuration with `layer_norm_eps` replaced by `eps`.
    #[must_use]
    pub fn with_layer_norm_eps(mut self, eps: f32) -> Self {
        self.layer_norm_eps = eps;
        self
    }

    /// Checks that the configuration can be used to build a model.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::InvalidConfig`] when
    /// * `d_input`, `d_model` or `n_layers` is zero,
    /// * `ffn_factor` is NaN, infinite, zero or negative, or
    /// * `layer_norm_eps` is NaN, infinite or negative.
    pub fn validate(&self) -> CoreResult<()> {
        if self.d_input == 0 || self.d_model == 0 || self.n_layers == 0 {
            return Err(CoreError::InvalidConfig(
                "Dimensions and layers must be positive".to_string(),
            ));
        }
        // `NaN <= 0.0` is false, so the sign test alone would accept NaN. The FFN width is
        // derived with a float-to-usize cast, which maps NaN to 0 and +inf to `usize::MAX`,
        // so non-finite factors have to be rejected explicitly.
        if !self.ffn_factor.is_finite() {
            return Err(CoreError::InvalidConfig(
                "FFN factor must be finite".to_string(),
            ));
        }
        if self.ffn_factor <= 0.0 {
            return Err(CoreError::InvalidConfig(
                "FFN factor must be positive".to_string(),
            ));
        }
        // The epsilon is added to a non-negative variance before a square root; a negative
        // or non-finite value can turn every normalised output into NaN or a constant.
        if !self.layer_norm_eps.is_finite() || self.layer_norm_eps < 0.0 {
            return Err(CoreError::InvalidConfig(
                "Layer norm epsilon must be finite and non-negative".to_string(),
            ));
        }
        Ok(())
    }
}
/// Recurrent time-mixing block: blends the current input with the previous one,
/// projects the blends to key / value / receptance vectors and maintains a
/// per-channel running weighted average across calls.
pub struct TimeMixing {
    /// Length required of every input vector.
    d_model: usize,
    /// Per-channel weight of the current input (vs. `prev_x`) when forming the key input.
    time_mix_k: Array1<f32>,
    /// Per-channel weight of the current input (vs. `prev_x`) when forming the value input.
    time_mix_v: Array1<f32>,
    /// Per-channel weight of the current input (vs. `prev_x`) when forming the
    /// receptance input.
    time_mix_r: Array1<f32>,
    /// Per-channel decay parameter; the state is multiplied by `exp(-exp(time_decay))`
    /// on every step.
    time_decay: Array1<f32>,
    /// Per-channel parameter whose exponential weights the current step's contribution.
    time_first: Array1<f32>,
    /// `d_model x d_model` key projection.
    key_w: Array2<f32>,
    /// `d_model x d_model` value projection.
    value_w: Array2<f32>,
    /// `d_model x d_model` receptance (gate) projection.
    receptance_w: Array2<f32>,
    /// `d_model x d_model` output projection.
    output_w: Array2<f32>,
    /// Recurrent state of shape `(d_model, 2)`: column 0 holds the running numerator,
    /// column 1 the running denominator.
    wkv_state: Array2<f32>,
    /// Input passed to the previous `forward` call (zeros after construction or `reset`).
    prev_x: Array1<f32>,
}
impl TimeMixing {
    /// Builds a time-mixing block for `config.d_model` channels.
    ///
    /// Mixing weights start at 0.5, the decay / first vectors at the values given in
    /// `config`, the projections at uniform random values in `[-scale, scale)` with
    /// `scale = sqrt(2 / d_model)` (assuming `random_f32` yields values in `[0, 1)` —
    /// confirm against `scirs2_core`), and the recurrent state at zero.
    pub fn new(config: &RWKV7Config) -> Self {
        let d_model = config.d_model;
        let scale = (2.0 / d_model as f32).sqrt();
        let half_mix = || Array1::from_elem(d_model, 0.5_f32);
        let random_square =
            || Array2::from_shape_fn((d_model, d_model), |_| (random_f32() - 0.5) * 2.0 * scale);

        let key_w = random_square();
        let value_w = random_square();
        let receptance_w = random_square();
        let output_w = random_square();

        Self {
            d_model,
            time_mix_k: half_mix(),
            time_mix_v: half_mix(),
            time_mix_r: half_mix(),
            time_decay: Array1::from_elem(d_model, config.time_decay_init),
            time_first: Array1::from_elem(d_model, config.time_first_init),
            key_w,
            value_w,
            receptance_w,
            output_w,
            wkv_state: Array2::zeros((d_model, 2)),
            prev_x: Array1::zeros(d_model),
        }
    }

    /// Processes one time step and returns the block output (length `d_model`).
    ///
    /// The call updates the recurrent state and remembers `x` for the next step.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::DimensionMismatch`] when `x.len() != d_model`; the state is
    /// left untouched in that case.
    pub fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let got = x.len();
        if got != self.d_model {
            return Err(CoreError::DimensionMismatch {
                expected: self.d_model,
                got,
            });
        }

        // Token shift: interpolate between the current and the previous input.
        let k_input = Self::blend_with_previous(&self.time_mix_k, x, &self.prev_x);
        let v_input = Self::blend_with_previous(&self.time_mix_v, x, &self.prev_x);
        let r_input = Self::blend_with_previous(&self.time_mix_r, x, &self.prev_x);

        let k = self.key_w.dot(&k_input);
        let v = self.value_w.dot(&v_input);
        let r = self.receptance_w.dot(&r_input);

        let wkv = self.compute_wkv(&k, &v)?;

        // Gate the running average with the sigmoid of the receptance.
        let gated: Array1<f32> = r
            .iter()
            .zip(wkv.iter())
            .map(|(&gate, &avg)| self.sigmoid(gate) * avg)
            .collect();

        let output = self.output_w.dot(&gated);
        self.prev_x.assign(x);
        Ok(output)
    }

    /// Computes `mix * current + (1 - mix) * previous` element-wise.
    fn blend_with_previous(
        mix: &Array1<f32>,
        current: &Array1<f32>,
        previous: &Array1<f32>,
    ) -> Array1<f32> {
        mix.iter()
            .zip(current.iter())
            .zip(previous.iter())
            .map(|((&m, &cur), &old)| m * cur + (1.0 - m) * old)
            .collect()
    }

    /// Advances the per-channel running numerator / denominator by one step and
    /// returns their ratio.
    ///
    /// Channels whose denominator magnitude is at most `1e-8` yield `0.0` to avoid
    /// dividing by (almost) zero. This function has no error path; it always returns `Ok`.
    fn compute_wkv(&mut self, k: &Array1<f32>, v: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let mut wkv: Array1<f32> = Array1::zeros(self.d_model);
        for (i, slot) in wkv.iter_mut().enumerate() {
            // decay = exp(-exp(time_decay)) lies in (0, 1) for finite parameters.
            let decay = (-self.time_decay[i].exp()).exp();
            let bonus = self.time_first[i].exp();
            let weighted_key = bonus * k[i];

            let numerator = decay * self.wkv_state[[i, 0]] + weighted_key * v[i];
            let denominator = decay * self.wkv_state[[i, 1]] + weighted_key;

            *slot = if denominator.abs() > 1e-8 {
                numerator / denominator
            } else {
                0.0
            };

            self.wkv_state[[i, 0]] = numerator;
            self.wkv_state[[i, 1]] = denominator;
        }
        Ok(wkv)
    }

    /// Logistic function `1 / (1 + e^(-x))`.
    fn sigmoid(&self, x: f32) -> f32 {
        let exp_neg = (-x).exp();
        1.0 / (1.0 + exp_neg)
    }

    /// Clears the recurrent state and the remembered previous input.
    pub fn reset(&mut self) {
        self.prev_x.fill(0.0);
        self.wkv_state.fill(0.0);
    }
}
/// Gated feed-forward block that mixes the current input with the previous one
/// before projecting through a squared-ReLU hidden layer.
pub struct ChannelMixing {
    /// Length required of every input vector.
    d_model: usize,
    /// Width of the hidden layer, `d_model * ffn_factor` truncated to an integer.
    // Stored for inspection only: within this file it is read solely by the unit tests.
    #[allow(dead_code)]
    d_ffn: usize,
    /// Per-channel weight of the current input (vs. `prev_x`) when forming the key input.
    time_mix_k: Array1<f32>,
    /// Per-channel weight of the current input (vs. `prev_x`) when forming the
    /// receptance input.
    time_mix_r: Array1<f32>,
    /// `d_ffn x d_model` projection into the hidden layer.
    key_w: Array2<f32>,
    /// `d_model x d_ffn` projection back out of the hidden layer.
    value_w: Array2<f32>,
    /// `d_model x d_model` receptance (gate) projection.
    receptance_w: Array2<f32>,
    /// Input passed to the previous `forward` call (zeros after construction or `reset`).
    prev_x: Array1<f32>,
}
impl ChannelMixing {
    /// Builds a channel-mixing block for `config.d_model` channels with a hidden width
    /// of `d_model * ffn_factor` (truncated).
    ///
    /// Mixing weights start at 0.5 and all projections are drawn with the same
    /// `sqrt(2 / d_model)` scale.
    pub fn new(config: &RWKV7Config) -> Self {
        let d_model = config.d_model;
        let d_ffn = (d_model as f32 * config.ffn_factor) as usize;
        let scale = (2.0 / d_model as f32).sqrt();
        let random_matrix = |shape: (usize, usize)| {
            Array2::from_shape_fn(shape, |_| (random_f32() - 0.5) * 2.0 * scale)
        };

        let key_w = random_matrix((d_ffn, d_model));
        let value_w = random_matrix((d_model, d_ffn));
        let receptance_w = random_matrix((d_model, d_model));

        Self {
            d_model,
            d_ffn,
            time_mix_k: Array1::from_elem(d_model, 0.5),
            time_mix_r: Array1::from_elem(d_model, 0.5),
            key_w,
            value_w,
            receptance_w,
            prev_x: Array1::zeros(d_model),
        }
    }

    /// Processes one time step and returns the block output (length `d_model`).
    ///
    /// The call remembers `x` as the previous input for the next step.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::DimensionMismatch`] when `x.len() != d_model`.
    pub fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let got = x.len();
        if got != self.d_model {
            return Err(CoreError::DimensionMismatch {
                expected: self.d_model,
                got,
            });
        }

        // Interpolate between the current and the previous input, per channel.
        let blend = |mix: &Array1<f32>| -> Array1<f32> {
            mix.iter()
                .zip(x.iter())
                .zip(self.prev_x.iter())
                .map(|((&m, &cur), &old)| m * cur + (1.0 - m) * old)
                .collect()
        };
        let k_input = blend(&self.time_mix_k);
        let r_input = blend(&self.time_mix_r);

        // Hidden layer: project up, apply squared ReLU, project back down.
        let hidden = self.apply_squared_relu(&self.key_w.dot(&k_input));
        let v = self.value_w.dot(&hidden);
        let r = self.receptance_w.dot(&r_input);

        let output: Array1<f32> = r
            .iter()
            .zip(v.iter())
            .map(|(&gate, &value)| self.sigmoid(gate) * value)
            .collect();

        self.prev_x.assign(x);
        Ok(output)
    }

    /// Maps every element `v` to `v * v` when `v > 0` and to `0` otherwise.
    fn apply_squared_relu(&self, x: &Array1<f32>) -> Array1<f32> {
        x.map(|&v| if v > 0.0 { v * v } else { 0.0 })
    }

    /// Logistic function `1 / (1 + e^(-x))`.
    fn sigmoid(&self, x: f32) -> f32 {
        let exp_neg = (-x).exp();
        1.0 / (1.0 + exp_neg)
    }

    /// Forgets the remembered previous input.
    pub fn reset(&mut self) {
        self.prev_x.fill(0.0);
    }
}
/// One residual layer: layer norm + time mixing followed by layer norm + channel mixing.
pub struct RWKV7Layer {
    /// Configuration the layer was built from (supplies `layer_norm_eps`).
    config: RWKV7Config,
    /// Recurrent time-mixing sub-block.
    time_mixing: TimeMixing,
    /// Feed-forward channel-mixing sub-block.
    channel_mixing: ChannelMixing,
    /// Scale of the layer norm applied before time mixing.
    ln1_weight: Array1<f32>,
    /// Shift of the layer norm applied before time mixing.
    ln1_bias: Array1<f32>,
    /// Scale of the layer norm applied before channel mixing.
    ln2_weight: Array1<f32>,
    /// Shift of the layer norm applied before channel mixing.
    ln2_bias: Array1<f32>,
}
impl RWKV7Layer {
pub fn new(config: RWKV7Config) -> CoreResult<Self> {
config.validate()?;
let time_mixing = TimeMixing::new(&config);
let channel_mixing = ChannelMixing::new(&config);
let ln1_weight = Array1::ones(config.d_model);
let ln1_bias = Array1::zeros(config.d_model);
let ln2_weight = Array1::ones(config.d_model);
let ln2_bias = Array1::zeros(config.d_model);
Ok(Self {
config,
time_mixing,
channel_mixing,
ln1_weight,
ln1_bias,
ln2_weight,
ln2_bias,
})
}
pub fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
let x_norm1 = self.layer_norm(x, &self.ln1_weight, &self.ln1_bias)?;
let tm_out = self.time_mixing.forward(&x_norm1)?;
let x = &(x + &tm_out);
let x_norm2 = self.layer_norm(x, &self.ln2_weight, &self.ln2_bias)?;
let cm_out = self.channel_mixing.forward(&x_norm2)?;
let output = x + &cm_out;
Ok(output)
}
fn layer_norm(
&self,
x: &Array1<f32>,
weight: &Array1<f32>,
bias: &Array1<f32>,
) -> CoreResult<Array1<f32>> {
let mean = x.mean().unwrap_or(0.0);
let var = x.mapv(|v| (v - mean).powi(2)).mean().unwrap_or(0.0);
let std = (var + self.config.layer_norm_eps).sqrt();
let mut normalized = Array1::zeros(x.len());
for i in 0..x.len() {
normalized[i] = ((x[i] - mean) / std) * weight[i] + bias[i];
}
Ok(normalized)
}
pub fn reset(&mut self) {
self.time_mixing.reset();
self.channel_mixing.reset();
}
}
/// Full model: a linear input embedding, a stack of [`RWKV7Layer`]s and a final
/// layer normalisation.
pub struct RWKV7Model {
    /// Configuration the model was built from.
    config: RWKV7Config,
    /// `d_model x d_input` matrix that lifts an input vector to the hidden width.
    embedding: Array2<f32>,
    /// Layers applied in order to the embedded input.
    layers: Vec<RWKV7Layer>,
    /// Scale of the final layer norm.
    ln_out_weight: Array1<f32>,
    /// Shift of the final layer norm.
    ln_out_bias: Array1<f32>,
}
impl RWKV7Model {
    /// Builds a model with `config.n_layers` layers.
    ///
    /// The embedding is drawn with scale `sqrt(1 / d_input)` and the final layer norm
    /// starts as the identity affine transform.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`RWKV7Config::validate`].
    pub fn new(config: RWKV7Config) -> CoreResult<Self> {
        config.validate()?;

        let scale = (1.0 / config.d_input as f32).sqrt();
        let embedding = Array2::from_shape_fn((config.d_model, config.d_input), |_| {
            (random_f32() - 0.5) * 2.0 * scale
        });

        let layers = (0..config.n_layers)
            .map(|_| RWKV7Layer::new(config.clone()))
            .collect::<CoreResult<Vec<_>>>()?;

        let ln_out_weight = Array1::ones(config.d_model);
        let ln_out_bias = Array1::zeros(config.d_model);

        Ok(Self {
            config,
            embedding,
            layers,
            ln_out_weight,
            ln_out_bias,
        })
    }

    /// Processes one time step: embeds `x`, runs it through every layer in order and
    /// layer-normalises the result. The returned vector has `d_model` elements.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::DimensionMismatch`] when `x.len() != d_input`, and
    /// propagates any error raised by a layer.
    pub fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let expected = self.config.d_input;
        let got = x.len();
        if got != expected {
            return Err(CoreError::DimensionMismatch { expected, got });
        }

        let embedded = self.embedding.dot(x);
        let hidden = self
            .layers
            .iter_mut()
            .try_fold(embedded, |h, layer| layer.forward(&h))?;

        self.layer_norm(&hidden)
    }

    /// Applies the final layer norm (population variance stabilised by
    /// `layer_norm_eps`, then `ln_out_weight * x_hat + ln_out_bias`).
    fn layer_norm(&self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let mean = x.mean().unwrap_or(0.0);
        let var = x.mapv(|v| (v - mean).powi(2)).mean().unwrap_or(0.0);
        let std = (var + self.config.layer_norm_eps).sqrt();
        let normalized = x
            .iter()
            .zip(self.ln_out_weight.iter())
            .zip(self.ln_out_bias.iter())
            .map(|((&xi, &w), &b)| ((xi - mean) / std) * w + b)
            .collect();
        Ok(normalized)
    }

    /// Clears the recurrent state of every layer.
    pub fn reset(&mut self) {
        self.layers.iter_mut().for_each(RWKV7Layer::reset);
    }

    /// Returns the configuration the model was built from.
    pub fn config(&self) -> &RWKV7Config {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 64-input / 128-wide configuration shared by these tests.
    fn config_with_layers(n_layers: usize) -> RWKV7Config {
        RWKV7Config::new(64, 128, n_layers)
    }

    /// `new` stores the requested dimensions and yields a valid configuration.
    #[test]
    fn test_rwkv7_config() {
        let config = config_with_layers(4);
        assert_eq!(config.d_input, 64);
        assert_eq!(config.d_model, 128);
        assert_eq!(config.n_layers, 4);
        assert!(config.validate().is_ok());
    }

    /// A zero hidden width is rejected by `validate`.
    #[test]
    fn test_rwkv7_config_validation() {
        let invalid = RWKV7Config {
            d_model: 0,
            ..Default::default()
        };
        assert!(invalid.validate().is_err());
    }

    /// The time-mixing block is sized from `d_model`.
    #[test]
    fn test_time_mixing_creation() {
        let tm = TimeMixing::new(&config_with_layers(2));
        assert_eq!(tm.d_model, 128);
        assert_eq!(tm.time_mix_k.len(), 128);
    }

    /// A correctly sized input produces an output of the same width.
    #[test]
    fn test_time_mixing_forward() {
        let mut tm = TimeMixing::new(&config_with_layers(2));
        let input = Array1::from_elem(128, 0.5);
        let output = tm
            .forward(&input)
            .expect("time mixing should accept a d_model-sized input");
        assert_eq!(output.len(), 128);
    }

    /// The hidden width is `d_model * ffn_factor`, truncated.
    #[test]
    fn test_channel_mixing_creation() {
        let cm = ChannelMixing::new(&config_with_layers(2));
        assert_eq!(cm.d_model, 128);
        assert_eq!(cm.d_ffn, (128.0 * 3.5) as usize);
    }

    /// A correctly sized input produces an output of the same width.
    #[test]
    fn test_channel_mixing_forward() {
        let mut cm = ChannelMixing::new(&config_with_layers(2));
        let input = Array1::from_elem(128, 0.3);
        let output = cm
            .forward(&input)
            .expect("channel mixing should accept a d_model-sized input");
        assert_eq!(output.len(), 128);
    }

    /// A valid configuration builds a layer.
    #[test]
    fn test_rwkv7_layer_creation() {
        assert!(RWKV7Layer::new(config_with_layers(2)).is_ok());
    }

    /// A layer step keeps the width and yields only finite values.
    #[test]
    fn test_rwkv7_layer_forward() {
        let mut layer = RWKV7Layer::new(config_with_layers(2)).unwrap();
        let input = Array1::from_elem(128, 0.1);
        let output = layer
            .forward(&input)
            .expect("layer should accept a d_model-sized input");
        assert_eq!(output.len(), 128);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    /// The model owns exactly `n_layers` layers.
    #[test]
    fn test_rwkv7_model_creation() {
        let model = RWKV7Model::new(config_with_layers(4));
        assert!(model.is_ok());
        let model = model.unwrap();
        assert_eq!(model.layers.len(), 4);
    }

    /// A model step maps `d_input` values to `d_model` finite values.
    #[test]
    fn test_rwkv7_model_forward() {
        let mut model = RWKV7Model::new(config_with_layers(3)).unwrap();
        let input = Array1::from_elem(64, 0.2);
        let output = model
            .forward(&input)
            .expect("model should accept a d_input-sized input");
        assert_eq!(output.len(), 128);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    /// `reset` zeroes the recurrent state of every layer.
    #[test]
    fn test_rwkv7_reset() {
        let mut model = RWKV7Model::new(config_with_layers(2)).unwrap();
        let input = Array1::from_elem(64, 0.5);
        let _ = model.forward(&input).unwrap();
        model.reset();
        for layer in model.layers.iter() {
            let state = &layer.time_mixing.wkv_state;
            assert!(state.iter().all(|&v| v == 0.0));
        }
    }

    /// Several steps leave a non-zero state and still produce finite output.
    #[test]
    fn test_wkv_mechanism() {
        let mut tm = TimeMixing::new(&config_with_layers(2));
        let first = Array1::from_elem(128, 0.1);
        let second = Array1::from_elem(128, 0.5);
        let third = Array1::from_elem(128, 0.9);
        let _ = tm.forward(&first).unwrap();
        let _ = tm.forward(&second).unwrap();
        let last_output = tm.forward(&third).unwrap();
        assert!(tm.wkv_state.iter().any(|&v| v != 0.0));
        assert!(last_output.iter().all(|v| v.is_finite()));
    }
}