1use crate::error::{ModelError, ModelResult};
29use crate::{AutoregressiveModel, ModelType};
30use kizzasi_core::{sigmoid, silu, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor};
31use scirs2_core::ndarray::{Array1, Array2};
32use scirs2_core::random::{rng, Rng};
33#[allow(unused_imports)]
34use tracing::{debug, instrument, trace};
35
/// Configuration for the RWKV model.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RwkvConfig {
    /// Dimensionality of each raw input vector.
    pub input_dim: usize,
    /// Width of the hidden representation flowing between layers.
    pub hidden_dim: usize,
    /// Inner width of the channel-mixing (feed-forward) projection.
    pub intermediate_dim: usize,
    /// Number of stacked RWKV layers.
    pub num_layers: usize,
    /// Number of heads used by the time-mixing sublayer.
    pub num_heads: usize,
    /// Per-head width; the builders keep this equal to hidden_dim / num_heads.
    pub head_dim: usize,
    /// Dropout probability. NOTE(review): stored but not visibly applied
    /// anywhere in this file — confirm whether training code consumes it.
    pub dropout: f32,
    /// Initial per-head/channel time-decay value, in log space
    /// (TimeMixing::forward exponentiates it each step).
    pub time_decay_init: f32,
    /// Use RMSNorm instead of LayerNorm when true.
    pub use_rms_norm: bool,
}
58
59impl Default for RwkvConfig {
60 fn default() -> Self {
61 let hidden_dim = 512;
62 let num_heads = 8;
63 Self {
64 input_dim: 1,
65 hidden_dim,
66 intermediate_dim: hidden_dim * 4,
67 num_layers: 12,
68 num_heads,
69 head_dim: hidden_dim / num_heads,
70 dropout: 0.0,
71 time_decay_init: -5.0,
72 use_rms_norm: true,
73 }
74 }
75}
76
77impl RwkvConfig {
78 pub fn new() -> Self {
80 Self::default()
81 }
82
83 pub fn input_dim(mut self, dim: usize) -> Self {
85 self.input_dim = dim;
86 self
87 }
88
89 pub fn hidden_dim(mut self, dim: usize) -> Self {
91 self.hidden_dim = dim;
92 self.head_dim = dim / self.num_heads;
93 self
94 }
95
96 pub fn intermediate_dim(mut self, dim: usize) -> Self {
98 self.intermediate_dim = dim;
99 self
100 }
101
102 pub fn num_layers(mut self, n: usize) -> Self {
104 self.num_layers = n;
105 self
106 }
107
108 pub fn num_heads(mut self, n: usize) -> Self {
110 self.num_heads = n;
111 self.head_dim = self.hidden_dim / n;
112 self
113 }
114
115 pub fn validate(&self) -> ModelResult<()> {
117 if self.hidden_dim == 0 {
118 return Err(ModelError::invalid_config("hidden_dim must be > 0"));
119 }
120 if self.num_layers == 0 {
121 return Err(ModelError::invalid_config("num_layers must be > 0"));
122 }
123 if self.num_heads == 0 {
124 return Err(ModelError::invalid_config("num_heads must be > 0"));
125 }
126 if !self.hidden_dim.is_multiple_of(self.num_heads) {
127 return Err(ModelError::invalid_config(
128 "hidden_dim must be divisible by num_heads",
129 ));
130 }
131 Ok(())
132 }
133}
134
/// Time-mixing (WKV attention) sublayer: weights plus recurrent state.
struct TimeMixing {
    hidden_dim: usize,
    num_heads: usize,
    head_dim: usize,

    // Token-shift interpolation coefficients, one per channel.
    time_mix_k: Array1<f32>,
    #[allow(dead_code)]
    time_mix_v: Array1<f32>,
    time_mix_r: Array1<f32>,
    time_mix_g: Array1<f32>,

    // Per-(head, channel) decay stored in log space; forward() exponentiates it.
    time_decay: Array2<f32>,
    // Square (hidden_dim x hidden_dim) projections for k/v/r/g and the output.
    key_proj: Array2<f32>,
    value_proj: Array2<f32>,
    receptance_proj: Array2<f32>,
    gate_proj: Array2<f32>,
    output_proj: Array2<f32>,

    // Recurrent state: per-head WKV accumulator, per-head scalar normalizer,
    // and the previous step's input (for token shift).
    wkv_state: Vec<Array1<f32>>,
    wkv_norm: Vec<f32>,
    prev_x: Array1<f32>,
}
166
167impl TimeMixing {
168 fn new(config: &RwkvConfig) -> ModelResult<Self> {
169 let mut rng = rng();
170
171 let time_mix_k = Array1::from_shape_fn(config.hidden_dim, |_| rng.random::<f32>());
173 let time_mix_v = Array1::from_shape_fn(config.hidden_dim, |_| rng.random::<f32>());
174 let time_mix_r = Array1::from_shape_fn(config.hidden_dim, |_| rng.random::<f32>());
175 let time_mix_g = Array1::from_shape_fn(config.hidden_dim, |_| rng.random::<f32>());
176
177 let time_decay = Array2::from_shape_fn((config.num_heads, config.head_dim), |(h, i)| {
179 config.time_decay_init - (h as f32 * 0.1) - (i as f32 * 0.01)
181 });
182
183 let scale = (2.0 / config.hidden_dim as f32).sqrt();
185 let key_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
186 (rng.random::<f32>() - 0.5) * 2.0 * scale
187 });
188 let value_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
189 (rng.random::<f32>() - 0.5) * 2.0 * scale
190 });
191 let receptance_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
192 (rng.random::<f32>() - 0.5) * 2.0 * scale
193 });
194 let gate_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
195 (rng.random::<f32>() - 0.5) * 2.0 * scale
196 });
197 let output_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
198 (rng.random::<f32>() - 0.5) * 2.0 * scale
199 });
200
201 let wkv_state = (0..config.num_heads)
203 .map(|_| Array1::zeros(config.head_dim))
204 .collect();
205 let wkv_norm = vec![0.0; config.num_heads];
206 let prev_x = Array1::zeros(config.hidden_dim);
207
208 Ok(Self {
209 hidden_dim: config.hidden_dim,
210 num_heads: config.num_heads,
211 head_dim: config.head_dim,
212 time_mix_k,
213 time_mix_v,
214 time_mix_r,
215 time_mix_g,
216 time_decay,
217 key_proj,
218 value_proj,
219 receptance_proj,
220 gate_proj,
221 output_proj,
222 wkv_state,
223 wkv_norm,
224 prev_x,
225 })
226 }
227
228 fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
229 let batch_size = x.len().min(self.hidden_dim);
230
231 let mut xx = Array1::zeros(batch_size);
233 for i in 0..batch_size {
234 let prev_val = if i < self.prev_x.len() {
235 self.prev_x[i]
236 } else {
237 0.0
238 };
239 xx[i] = self.time_mix_k[i] * x[i] + (1.0 - self.time_mix_k[i]) * prev_val;
240 }
241
242 let k = self.project(&xx, &self.key_proj);
244 let v = self.project(&xx, &self.value_proj);
245
246 let mut xr = Array1::zeros(batch_size);
247 for i in 0..batch_size {
248 let prev_val = if i < self.prev_x.len() {
249 self.prev_x[i]
250 } else {
251 0.0
252 };
253 xr[i] = self.time_mix_r[i] * x[i] + (1.0 - self.time_mix_r[i]) * prev_val;
254 }
255 let r = self.project(&xr, &self.receptance_proj);
256
257 let mut xg = Array1::zeros(batch_size);
258 for i in 0..batch_size {
259 let prev_val = if i < self.prev_x.len() {
260 self.prev_x[i]
261 } else {
262 0.0
263 };
264 xg[i] = self.time_mix_g[i] * x[i] + (1.0 - self.time_mix_g[i]) * prev_val;
265 }
266 let g = self.project(&xg, &self.gate_proj);
267
268 let mut wkv_output = Array1::zeros(batch_size);
270
271 for head in 0..self.num_heads {
272 let head_start = head * self.head_dim;
273 let head_end = (head_start + self.head_dim).min(batch_size);
274
275 for i in 0..(head_end - head_start) {
276 let idx = head_start + i;
277 if idx >= k.len() || idx >= v.len() {
278 break;
279 }
280
281 let w = self.time_decay[[head, i]].exp();
283
284 let new_wkv = w * self.wkv_state[head][i] + k[idx] * v[idx];
286 self.wkv_state[head][i] = new_wkv;
287
288 self.wkv_norm[head] = w * self.wkv_norm[head] + k[idx];
290
291 let norm = self.wkv_norm[head].max(1e-8);
293 wkv_output[idx] = new_wkv / norm;
294 }
295 }
296
297 let r_sigmoid = sigmoid(&r);
299 for i in 0..wkv_output.len().min(r_sigmoid.len()) {
300 wkv_output[i] *= r_sigmoid[i];
301 }
302
303 let g_silu = silu(&g);
305 for i in 0..wkv_output.len().min(g_silu.len()) {
306 wkv_output[i] *= g_silu[i];
307 }
308
309 let output = self.project(&wkv_output, &self.output_proj);
311
312 self.prev_x = Array1::from_vec(x.iter().take(self.hidden_dim).copied().collect());
314
315 Ok(output)
316 }
317
318 fn project(&self, x: &Array1<f32>, weight: &Array2<f32>) -> Array1<f32> {
319 let out_dim = weight.shape()[0];
320 let mut output = Array1::zeros(out_dim.min(x.len()));
321 for i in 0..output.len() {
322 let mut sum = 0.0;
323 for j in 0..x.len().min(weight.shape()[1]) {
324 sum += weight[[i, j]] * x[j];
325 }
326 output[i] = sum;
327 }
328 output
329 }
330
331 fn reset(&mut self) {
332 for state in &mut self.wkv_state {
333 state.fill(0.0);
334 }
335 self.wkv_norm.fill(0.0);
336 self.prev_x.fill(0.0);
337 }
338}
339
/// Channel-mixing (position-wise feed-forward) sublayer: weights plus the
/// token-shift state.
struct ChannelMixing {
    hidden_dim: usize,
    intermediate_dim: usize,

    // Token-shift interpolation coefficients, one per channel.
    time_mix_k: Array1<f32>,
    time_mix_r: Array1<f32>,

    // (hidden_dim, intermediate_dim) up-projection.
    key_proj: Array2<f32>,
    // (intermediate_dim, hidden_dim) down-projection.
    value_proj: Array2<f32>,
    // (hidden_dim, hidden_dim) receptance gate projection.
    receptance_proj: Array2<f32>,

    // Previous step's input, kept for token shift.
    prev_x: Array1<f32>,
}
359
360impl ChannelMixing {
361 fn new(config: &RwkvConfig) -> ModelResult<Self> {
362 let mut rng = rng();
363
364 let time_mix_k = Array1::from_shape_fn(config.hidden_dim, |_| rng.random::<f32>());
366 let time_mix_r = Array1::from_shape_fn(config.hidden_dim, |_| rng.random::<f32>());
367
368 let scale = (2.0 / config.hidden_dim as f32).sqrt();
370 let key_proj = Array2::from_shape_fn((config.hidden_dim, config.intermediate_dim), |_| {
371 (rng.random::<f32>() - 0.5) * 2.0 * scale
372 });
373
374 let value_proj =
375 Array2::from_shape_fn((config.intermediate_dim, config.hidden_dim), |_| {
376 (rng.random::<f32>() - 0.5) * 2.0 * scale
377 });
378
379 let receptance_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
380 (rng.random::<f32>() - 0.5) * 2.0 * scale
381 });
382
383 let prev_x = Array1::zeros(config.hidden_dim);
384
385 Ok(Self {
386 hidden_dim: config.hidden_dim,
387 intermediate_dim: config.intermediate_dim,
388 time_mix_k,
389 time_mix_r,
390 key_proj,
391 value_proj,
392 receptance_proj,
393 prev_x,
394 })
395 }
396
397 fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
398 let batch_size = x.len().min(self.hidden_dim);
399
400 let mut xk = Array1::zeros(batch_size);
402 for i in 0..batch_size {
403 let prev_val = if i < self.prev_x.len() {
404 self.prev_x[i]
405 } else {
406 0.0
407 };
408 xk[i] = self.time_mix_k[i] * x[i] + (1.0 - self.time_mix_k[i]) * prev_val;
409 }
410
411 let mut xr = Array1::zeros(batch_size);
413 for i in 0..batch_size {
414 let prev_val = if i < self.prev_x.len() {
415 self.prev_x[i]
416 } else {
417 0.0
418 };
419 xr[i] = self.time_mix_r[i] * x[i] + (1.0 - self.time_mix_r[i]) * prev_val;
420 }
421
422 let k = self.project(&xk, &self.key_proj);
424 let k_squared = k.mapv(|v| v * v); let vk = self.project_back(&k_squared, &self.value_proj);
426
427 let r = self.project_r(&xr, &self.receptance_proj);
429 let r_sigmoid = sigmoid(&r);
430
431 let mut output = Array1::zeros(batch_size);
432 for i in 0..output.len().min(vk.len()).min(r_sigmoid.len()) {
433 output[i] = r_sigmoid[i] * vk[i];
434 }
435
436 self.prev_x = Array1::from_vec(x.iter().take(self.hidden_dim).copied().collect());
438
439 Ok(output)
440 }
441
442 fn project(&self, x: &Array1<f32>, weight: &Array2<f32>) -> Array1<f32> {
443 let out_dim = weight.shape()[1].min(self.intermediate_dim);
444 let mut output = Array1::zeros(out_dim);
445 for i in 0..out_dim {
446 let mut sum = 0.0;
447 for j in 0..x.len().min(weight.shape()[0]) {
448 sum += weight[[j, i]] * x[j];
449 }
450 output[i] = sum;
451 }
452 output
453 }
454
455 fn project_back(&self, x: &Array1<f32>, weight: &Array2<f32>) -> Array1<f32> {
456 let out_dim = weight.shape()[1].min(self.hidden_dim);
457 let mut output = Array1::zeros(out_dim);
458 for i in 0..out_dim {
459 let mut sum = 0.0;
460 for j in 0..x.len().min(weight.shape()[0]) {
461 sum += weight[[j, i]] * x[j];
462 }
463 output[i] = sum;
464 }
465 output
466 }
467
468 fn project_r(&self, x: &Array1<f32>, weight: &Array2<f32>) -> Array1<f32> {
469 let out_dim = weight.shape()[0];
470 let mut output = Array1::zeros(out_dim.min(x.len()));
471 for i in 0..output.len() {
472 let mut sum = 0.0;
473 for j in 0..x.len().min(weight.shape()[1]) {
474 sum += weight[[i, j]] * x[j];
475 }
476 output[i] = sum;
477 }
478 output
479 }
480
481 fn reset(&mut self) {
482 self.prev_x.fill(0.0);
483 }
484}
485
/// One RWKV block: pre-norm time mixing and pre-norm channel mixing, each
/// wrapped in a residual connection.
struct RwkvLayer {
    ln1: LayerNorm,
    ln2: LayerNorm,
    time_mixing: TimeMixing,
    channel_mixing: ChannelMixing,
}
493
494impl RwkvLayer {
495 fn new(config: &RwkvConfig) -> ModelResult<Self> {
496 let norm_type = if config.use_rms_norm {
497 NormType::RMSNorm
498 } else {
499 NormType::LayerNorm
500 };
501
502 let ln1 = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);
503 let ln2 = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);
504 let time_mixing = TimeMixing::new(config)?;
505 let channel_mixing = ChannelMixing::new(config)?;
506
507 Ok(Self {
508 ln1,
509 ln2,
510 time_mixing,
511 channel_mixing,
512 })
513 }
514
515 fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
516 let x_norm = self.ln1.forward(x);
518 let tm_out = self.time_mixing.forward(&x_norm)?;
519 let mut x_tm = x.clone();
520 for i in 0..x_tm.len().min(tm_out.len()) {
521 x_tm[i] += tm_out[i];
522 }
523
524 let x_norm2 = self.ln2.forward(&x_tm);
526 let cm_out = self.channel_mixing.forward(&x_norm2)?;
527 let mut output = x_tm;
528 for i in 0..output.len().min(cm_out.len()) {
529 output[i] += cm_out[i];
530 }
531
532 Ok(output)
533 }
534
535 fn reset(&mut self) {
536 self.time_mixing.reset();
537 self.channel_mixing.reset();
538 }
539}
540
/// RWKV autoregressive signal model: an input up-projection, a stack of RWKV
/// layers, a final norm, and an output down-projection.
pub struct Rwkv {
    config: RwkvConfig,
    layers: Vec<RwkvLayer>,
    // Final normalization applied after the layer stack.
    ln_out: LayerNorm,
    // (input_dim, hidden_dim) projection applied to each incoming vector.
    input_proj: Array2<f32>,
    // (hidden_dim, input_dim) projection producing the prediction.
    output_proj: Array2<f32>,
}
549
impl Rwkv {
    /// Construct an RWKV model from a validated configuration.
    ///
    /// All weights are randomly initialized; use [`Rwkv::load_weights`] to
    /// overwrite them from a checkpoint.
    ///
    /// # Errors
    /// Returns the error from [`RwkvConfig::validate`] for invalid configs.
    pub fn new(config: RwkvConfig) -> ModelResult<Self> {
        config.validate()?;

        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(RwkvLayer::new(&config)?);
        }

        let norm_type = if config.use_rms_norm {
            NormType::RMSNorm
        } else {
            NormType::LayerNorm
        };
        let ln_out = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);

        // Xavier-style uniform init scaled by fan-in + fan-out.
        let mut rng = rng();
        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // Same scale value as above (operands merely reordered).
        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            config,
            layers,
            ln_out,
            input_proj,
            output_proj,
        })
    }

    /// Borrow the model configuration.
    pub fn config(&self) -> &RwkvConfig {
        &self.config
    }

    /// Load any weights present in `loader`; tensors missing from the
    /// checkpoint keep their random initialization.
    ///
    /// Expected tensor names: top-level `input_proj`, `output_proj`,
    /// `ln_out.{weight,bias}`, then per layer
    /// `layers.{i}.{ln1,ln2}.{weight,bias}` and
    /// `layers.{i}.{time_mixing,channel_mixing}.{param}`.
    pub fn load_weights(&mut self, loader: &crate::loader::ModelLoader) -> ModelResult<()> {
        // Top-level input/output projections.
        if loader.has_tensor("input_proj") {
            self.input_proj = loader.load_array2("input_proj")?;
        }
        if loader.has_tensor("output_proj") {
            self.output_proj = loader.load_array2("output_proj")?;
        }

        // Final normalization parameters.
        if loader.has_tensor("ln_out.weight") {
            let weight = loader.load_array1("ln_out.weight")?;
            self.ln_out.set_gamma(weight);
        }
        if loader.has_tensor("ln_out.bias") {
            let bias = loader.load_array1("ln_out.bias")?;
            self.ln_out.set_beta(bias);
        }

        for (i, layer) in self.layers.iter_mut().enumerate() {
            let prefix = format!("layers.{}", i);

            // Per-layer norm parameters.
            if loader.has_tensor(&format!("{}.ln1.weight", prefix)) {
                let weight = loader.load_array1(&format!("{}.ln1.weight", prefix))?;
                layer.ln1.set_gamma(weight);
            }
            if loader.has_tensor(&format!("{}.ln1.bias", prefix)) {
                let bias = loader.load_array1(&format!("{}.ln1.bias", prefix))?;
                layer.ln1.set_beta(bias);
            }

            if loader.has_tensor(&format!("{}.ln2.weight", prefix)) {
                let weight = loader.load_array1(&format!("{}.ln2.weight", prefix))?;
                layer.ln2.set_gamma(weight);
            }
            if loader.has_tensor(&format!("{}.ln2.bias", prefix)) {
                let bias = loader.load_array1(&format!("{}.ln2.bias", prefix))?;
                layer.ln2.set_beta(bias);
            }

            // Time-mixing parameters.
            let tm_prefix = format!("{}.time_mixing", prefix);
            if loader.has_tensor(&format!("{}.time_mix_k", tm_prefix)) {
                layer.time_mixing.time_mix_k =
                    loader.load_array1(&format!("{}.time_mix_k", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.time_mix_v", tm_prefix)) {
                layer.time_mixing.time_mix_v =
                    loader.load_array1(&format!("{}.time_mix_v", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.time_mix_r", tm_prefix)) {
                layer.time_mixing.time_mix_r =
                    loader.load_array1(&format!("{}.time_mix_r", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.time_mix_g", tm_prefix)) {
                layer.time_mixing.time_mix_g =
                    loader.load_array1(&format!("{}.time_mix_g", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.time_decay", tm_prefix)) {
                layer.time_mixing.time_decay =
                    loader.load_array2(&format!("{}.time_decay", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.key_proj", tm_prefix)) {
                layer.time_mixing.key_proj =
                    loader.load_array2(&format!("{}.key_proj", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.value_proj", tm_prefix)) {
                layer.time_mixing.value_proj =
                    loader.load_array2(&format!("{}.value_proj", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.receptance_proj", tm_prefix)) {
                layer.time_mixing.receptance_proj =
                    loader.load_array2(&format!("{}.receptance_proj", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.gate_proj", tm_prefix)) {
                layer.time_mixing.gate_proj =
                    loader.load_array2(&format!("{}.gate_proj", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.output_proj", tm_prefix)) {
                layer.time_mixing.output_proj =
                    loader.load_array2(&format!("{}.output_proj", tm_prefix))?;
            }

            // Channel-mixing parameters.
            let cm_prefix = format!("{}.channel_mixing", prefix);
            if loader.has_tensor(&format!("{}.time_mix_k", cm_prefix)) {
                layer.channel_mixing.time_mix_k =
                    loader.load_array1(&format!("{}.time_mix_k", cm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.time_mix_r", cm_prefix)) {
                layer.channel_mixing.time_mix_r =
                    loader.load_array1(&format!("{}.time_mix_r", cm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.key_proj", cm_prefix)) {
                layer.channel_mixing.key_proj =
                    loader.load_array2(&format!("{}.key_proj", cm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.value_proj", cm_prefix)) {
                layer.channel_mixing.value_proj =
                    loader.load_array2(&format!("{}.value_proj", cm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.receptance_proj", cm_prefix)) {
                layer.channel_mixing.receptance_proj =
                    loader.load_array2(&format!("{}.receptance_proj", cm_prefix))?;
            }
        }

        Ok(())
    }

    /// Persist model weights to `path`.
    ///
    /// # Errors
    /// Not yet implemented; always returns a load error.
    #[allow(unused_variables)]
    pub fn save_weights(&self, path: &str) -> ModelResult<()> {
        Err(ModelError::simple_load_error(
            "RWKV save_weights not yet implemented".to_string(),
        ))
    }
}
751
752impl SignalPredictor for Rwkv {
753 #[instrument(skip(self, input))]
754 fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
755 let mut hidden = input.dot(&self.input_proj);
757
758 for layer in &mut self.layers {
760 hidden = layer.forward(&hidden)?;
761 }
762
763 hidden = self.ln_out.forward(&hidden);
765
766 let output = hidden.dot(&self.output_proj);
768 Ok(output)
769 }
770
771 fn reset(&mut self) {
772 for layer in &mut self.layers {
773 layer.reset();
774 }
775 }
776
777 fn context_window(&self) -> usize {
778 usize::MAX
780 }
781}
782
impl AutoregressiveModel for Rwkv {
    fn hidden_dim(&self) -> usize {
        self.config.hidden_dim
    }

    // NOTE(review): reports head_dim, but get_states() produces
    // num_heads * head_dim rows per layer — confirm which size callers expect.
    fn state_dim(&self) -> usize {
        self.config.head_dim
    }

    fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    fn model_type(&self) -> ModelType {
        ModelType::Rwkv
    }

    /// Snapshot each layer's WKV state as a (num_heads * head_dim, 1) matrix.
    /// The token-shift inputs (`prev_x`) and scalar normalizers (`wkv_norm`)
    /// are NOT captured, so a restored model is not a bit-exact resume.
    fn get_states(&self) -> Vec<HiddenState> {
        self.layers
            .iter()
            .map(|layer| {
                let total_size = layer.time_mixing.num_heads * layer.time_mixing.head_dim;
                let mut combined = Array2::zeros((total_size, 1));

                // Concatenate the per-head state vectors head by head.
                for (head_idx, head_state) in layer.time_mixing.wkv_state.iter().enumerate() {
                    let start_idx = head_idx * layer.time_mixing.head_dim;
                    for i in 0..layer.time_mixing.head_dim.min(head_state.len()) {
                        combined[[start_idx + i, 0]] = head_state[i];
                    }
                }

                let mut hs = HiddenState::new(combined.shape()[0], combined.shape()[1]);
                hs.update(combined);
                hs
            })
            .collect()
    }

    /// Restore per-layer WKV state from `get_states`-shaped snapshots.
    ///
    /// # Errors
    /// Returns a mismatch error when `states.len() != num_layers`.
    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != self.config.num_layers {
            return Err(ModelError::state_count_mismatch(
                "RWKV",
                self.config.num_layers,
                states.len(),
            ));
        }

        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
            let combined = states[layer_idx].state();

            for (head_idx, head_state) in layer.time_mixing.wkv_state.iter_mut().enumerate() {
                let start_idx = head_idx * layer.time_mixing.head_dim;
                for i in 0..layer.time_mixing.head_dim.min(head_state.len()) {
                    // Skip entries outside the snapshot; the second clause
                    // guards against a zero-column state matrix.
                    if start_idx + i < combined.shape()[0] && 0 < combined.shape()[1] {
                        head_state[i] = combined[[start_idx + i, 0]];
                    }
                }
            }
        }

        Ok(())
    }
}
849
#[cfg(test)]
mod tests {
    use super::*;

    // Builder chain populates fields and derives head_dim.
    #[test]
    fn test_rwkv_config() {
        let cfg = RwkvConfig::new().hidden_dim(512).num_heads(8).num_layers(6);

        assert_eq!(cfg.hidden_dim, 512);
        assert_eq!(cfg.num_heads, 8);
        assert_eq!(cfg.head_dim, 64);
        assert!(cfg.validate().is_ok());
    }

    // A small valid configuration constructs successfully.
    #[test]
    fn test_rwkv_creation() {
        let cfg = RwkvConfig::new().hidden_dim(128).num_heads(4).num_layers(2);
        assert!(Rwkv::new(cfg).is_ok());
    }

    // A single step on a scalar input succeeds.
    #[test]
    fn test_rwkv_forward() {
        let cfg = RwkvConfig::new().hidden_dim(128).num_heads(4).num_layers(2);
        let mut model = Rwkv::new(cfg).expect("Failed to create RWKV model");

        let input = Array1::from_vec(vec![0.5]);
        assert!(model.step(&input).is_ok());
    }

    // 100 is not divisible by 3 heads, so validation must fail.
    #[test]
    fn test_invalid_config() {
        let cfg = RwkvConfig::new().hidden_dim(100).num_heads(3);
        assert!(cfg.validate().is_err());
    }
}