kizzasi_model/rwkv.rs

//! RWKV v6: Receptance Weighted Key Value
//!
//! RWKV is an RNN architecture that combines the parallelizable training of
//! Transformers with the efficient inference of RNNs. In place of softmax
//! attention, RWKV uses linear attention with time-mixing, achieving O(1)
//! time and memory per inference step.
//!
//! # RWKV v6 Features
//!
//! - **Time-Mixing**: Linear attention with exponential decay
//! - **Channel-Mixing**: Token-shift with gated linear units
//! - **Efficient Training**: Parallelizable via the WKV algorithm
//! - **O(1) Inference**: Constant memory and time per step
//! - **No Positional Encoding**: Time awareness through mixing
//!
//! # Architecture
//!
//! Each layer applies two pre-norm residual sub-blocks:
//!
//! ```text
//! x = x + TimeMixing(LayerNorm(x))
//! x = x + ChannelMixing(LayerNorm(x))
//! ```
//!
//! # References
//!
//! - RWKV paper: <https://arxiv.org/abs/2305.13048>
//! - RWKV v6: improved stability and performance over earlier versions
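//!
//! # Example
//!
//! A minimal usage sketch. The `kizzasi_model::rwkv` re-export path and the
//! `SignalPredictor` import location are assumptions about this crate's
//! public API:
//!
//! ```no_run
//! use kizzasi_model::rwkv::{Rwkv, RwkvConfig};
//! use kizzasi_core::SignalPredictor;
//! use scirs2_core::ndarray::Array1;
//!
//! let config = RwkvConfig::new().hidden_dim(128).num_heads(4).num_layers(2);
//! let mut model = Rwkv::new(config).expect("invalid config");
//!
//! // Feed a scalar signal one step at a time; recurrent state carries over.
//! for &v in &[0.1_f32, 0.2, 0.3] {
//!     let y = model.step(&Array1::from_vec(vec![v])).expect("step failed");
//!     println!("prediction: {:?}", y);
//! }
//! model.reset(); // clear recurrent state before a new sequence
//! ```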

use crate::error::{ModelError, ModelResult};
use crate::{AutoregressiveModel, ModelType};
use kizzasi_core::{sigmoid, silu, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rng, Rng};
#[allow(unused_imports)]
use tracing::{debug, instrument, trace};

/// Configuration for RWKV v6
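///
/// Built with a fluent builder; the builder methods keep `head_dim`
/// consistent with `hidden_dim / num_heads`. A small sketch (the
/// `kizzasi_model::rwkv` module path is assumed):
///
/// ```no_run
/// use kizzasi_model::rwkv::RwkvConfig;
///
/// let config = RwkvConfig::new().hidden_dim(256).num_heads(8);
/// assert_eq!(config.head_dim, 32); // derived from hidden_dim / num_heads
/// assert!(config.validate().is_ok());
/// ```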
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RwkvConfig {
    /// Input dimension
    pub input_dim: usize,
    /// Hidden dimension (d_model)
    pub hidden_dim: usize,
    /// Intermediate dimension for FFN (typically 4x hidden_dim)
    pub intermediate_dim: usize,
    /// Number of layers
    pub num_layers: usize,
    /// Number of attention heads (v6 uses multi-head)
    pub num_heads: usize,
    /// Head dimension
    pub head_dim: usize,
    /// Dropout rate
    pub dropout: f32,
    /// Time decay initialization
    pub time_decay_init: f32,
    /// Use RMSNorm instead of LayerNorm
    pub use_rms_norm: bool,
}

impl Default for RwkvConfig {
    fn default() -> Self {
        let hidden_dim = 512;
        let num_heads = 8;
        Self {
            input_dim: 1,
            hidden_dim,
            intermediate_dim: hidden_dim * 4,
            num_layers: 12,
            num_heads,
            head_dim: hidden_dim / num_heads,
            dropout: 0.0,
            time_decay_init: -5.0,
            use_rms_norm: true,
        }
    }
}

impl RwkvConfig {
    /// Create a new RWKV configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Set input dimension
    pub fn input_dim(mut self, dim: usize) -> Self {
        self.input_dim = dim;
        self
    }

    /// Set hidden dimension (keeps `head_dim` in sync)
    pub fn hidden_dim(mut self, dim: usize) -> Self {
        self.hidden_dim = dim;
        if self.num_heads > 0 {
            self.head_dim = dim / self.num_heads;
        }
        self
    }

    /// Set intermediate dimension
    pub fn intermediate_dim(mut self, dim: usize) -> Self {
        self.intermediate_dim = dim;
        self
    }

    /// Set number of layers
    pub fn num_layers(mut self, n: usize) -> Self {
        self.num_layers = n;
        self
    }

    /// Set number of heads (keeps `head_dim` in sync; a zero head count is
    /// left for `validate` to reject rather than dividing by zero here)
    pub fn num_heads(mut self, n: usize) -> Self {
        self.num_heads = n;
        if n > 0 {
            self.head_dim = self.hidden_dim / n;
        }
        self
    }

    /// Validate the configuration
    pub fn validate(&self) -> ModelResult<()> {
        if self.hidden_dim == 0 {
            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
        }
        if self.num_layers == 0 {
            return Err(ModelError::invalid_config("num_layers must be > 0"));
        }
        if self.num_heads == 0 {
            return Err(ModelError::invalid_config("num_heads must be > 0"));
        }
        if !self.hidden_dim.is_multiple_of(self.num_heads) {
            return Err(ModelError::invalid_config(
                "hidden_dim must be divisible by num_heads",
            ));
        }
        Ok(())
    }
}

/// RWKV Time-Mixing block
///
/// Implements linear attention with per-dimension exponential time decay,
/// maintaining an accumulator a and a normalizer b for each channel:
///
/// a[t] = w * a[t-1] + k[t] * v[t]
/// b[t] = w * b[t-1] + k[t]
/// wkv[t] = a[t] / b[t]
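///
/// As a single-channel worked example with decay w = e^-1 ≈ 0.368, feeding
/// (k, v) = (1.0, 2.0) then (1.0, 4.0) from zero state:
///
/// a[1] = 2.0, b[1] = 1.0, wkv[1] = 2.0
/// a[2] = 2.0*0.368 + 4.0 ≈ 4.74, b[2] = 0.368 + 1.0 ≈ 1.37, wkv[2] ≈ 3.46
///
/// Older contributions decay geometrically, so recent tokens dominate.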
struct TimeMixing {
    hidden_dim: usize,
    num_heads: usize,
    head_dim: usize,

    /// Time-mixing parameters
    time_mix_k: Array1<f32>,
    #[allow(dead_code)]
    time_mix_v: Array1<f32>, // Reserved for future use
    time_mix_r: Array1<f32>,
    time_mix_g: Array1<f32>,

    /// Time decay (per head)
    time_decay: Array2<f32>, // [num_heads, head_dim]

    /// Projection matrices
    key_proj: Array2<f32>,
    value_proj: Array2<f32>,
    receptance_proj: Array2<f32>,
    gate_proj: Array2<f32>,
    output_proj: Array2<f32>,

    /// State: WKV accumulator and normalizer, per head and per dimension
    wkv_state: Vec<Array1<f32>>, // [num_heads][head_dim]
    wkv_norm: Vec<Array1<f32>>,  // [num_heads][head_dim]
    prev_x: Array1<f32>,
}

impl TimeMixing {
    fn new(config: &RwkvConfig) -> ModelResult<Self> {
        let mut rng = rng();

        // Initialize time-mixing parameters (learnable interpolation)
        let time_mix_k = Array1::from_shape_fn(config.hidden_dim, |_| rng.random::<f32>());
        let time_mix_v = Array1::from_shape_fn(config.hidden_dim, |_| rng.random::<f32>());
        let time_mix_r = Array1::from_shape_fn(config.hidden_dim, |_| rng.random::<f32>());
        let time_mix_g = Array1::from_shape_fn(config.hidden_dim, |_| rng.random::<f32>());

        // Initialize time decay (log scale, negative for decay)
        let time_decay = Array2::from_shape_fn((config.num_heads, config.head_dim), |(h, i)| {
            // Different decay rates per head and dimension
            config.time_decay_init - (h as f32 * 0.1) - (i as f32 * 0.01)
        });

        // Initialize projection matrices
        let scale = (2.0 / config.hidden_dim as f32).sqrt();
        let key_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let value_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let receptance_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let gate_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // Initialize states (zeroed accumulator and normalizer per head)
        let wkv_state = (0..config.num_heads)
            .map(|_| Array1::zeros(config.head_dim))
            .collect();
        let wkv_norm = (0..config.num_heads)
            .map(|_| Array1::zeros(config.head_dim))
            .collect();
        let prev_x = Array1::zeros(config.hidden_dim);

        Ok(Self {
            hidden_dim: config.hidden_dim,
            num_heads: config.num_heads,
            head_dim: config.head_dim,
            time_mix_k,
            time_mix_v,
            time_mix_r,
            time_mix_g,
            time_decay,
            key_proj,
            value_proj,
            receptance_proj,
            gate_proj,
            output_proj,
            wkv_state,
            wkv_norm,
            prev_x,
        })
    }

    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let dim = x.len().min(self.hidden_dim);

        // Token shift: interpolate each channel between the current and the
        // previous input using learned per-channel mixing coefficients
        let xx = self.token_shift(x, &self.time_mix_k, dim);
        let xr = self.token_shift(x, &self.time_mix_r, dim);
        let xg = self.token_shift(x, &self.time_mix_g, dim);

        // Compute K, V, R, G projections (K and V share the key-mixed input;
        // time_mix_v is reserved for future use)
        let k = self.project(&xx, &self.key_proj);
        let v = self.project(&xx, &self.value_proj);
        let r = self.project(&xr, &self.receptance_proj);
        let g = self.project(&xg, &self.gate_proj);

        // WKV: weighted key-value with per-head, per-dimension time decay
        let mut wkv_output = Array1::zeros(dim);

        for head in 0..self.num_heads {
            let head_start = head * self.head_dim;
            let head_end = (head_start + self.head_dim).min(dim);

            for i in 0..head_end.saturating_sub(head_start) {
                let idx = head_start + i;
                if idx >= k.len() || idx >= v.len() {
                    break;
                }

                // Time decay for this head and dimension (stored in log space)
                let w = self.time_decay[[head, i]].exp();

                // Accumulator: a[t] = w * a[t-1] + k[t] * v[t]
                let new_wkv = w * self.wkv_state[head][i] + k[idx] * v[idx];
                self.wkv_state[head][i] = new_wkv;

                // Normalizer: b[t] = w * b[t-1] + k[t]
                self.wkv_norm[head][i] = w * self.wkv_norm[head][i] + k[idx];

                // Output: a[t] / b[t], guarded against a vanishing normalizer
                let norm = self.wkv_norm[head][i].max(1e-8);
                wkv_output[idx] = new_wkv / norm;
            }
        }

        // Receptance gating: sigmoid(r) controls how much of each channel
        // passes through
        let r_sigmoid = sigmoid(&r);
        for i in 0..wkv_output.len().min(r_sigmoid.len()) {
            wkv_output[i] *= r_sigmoid[i];
        }

        // Output gate (v6 feature): SiLU-activated gate scales each channel
        let g_silu = silu(&g);
        for i in 0..wkv_output.len().min(g_silu.len()) {
            wkv_output[i] *= g_silu[i];
        }

        // Output projection
        let output = self.project(&wkv_output, &self.output_proj);

        // Remember the current input for the next step's token shift
        self.prev_x = Array1::from_vec(x.iter().take(self.hidden_dim).copied().collect());

        Ok(output)
    }

    /// Token shift: per-channel interpolation between the current input and
    /// the previous step's input
    fn token_shift(&self, x: &Array1<f32>, mix: &Array1<f32>, dim: usize) -> Array1<f32> {
        Array1::from_shape_fn(dim, |i| {
            let prev_val = if i < self.prev_x.len() {
                self.prev_x[i]
            } else {
                0.0
            };
            mix[i] * x[i] + (1.0 - mix[i]) * prev_val
        })
    }

    fn project(&self, x: &Array1<f32>, weight: &Array2<f32>) -> Array1<f32> {
        // Naive matrix-vector product: output[i] = sum_j weight[i, j] * x[j],
        // truncated to the smaller of the available dimensions
        let out_dim = weight.shape()[0];
        let mut output = Array1::zeros(out_dim.min(x.len()));
        for i in 0..output.len() {
            let mut sum = 0.0;
            for j in 0..x.len().min(weight.shape()[1]) {
                sum += weight[[i, j]] * x[j];
            }
            output[i] = sum;
        }
        output
    }

    fn reset(&mut self) {
        for state in &mut self.wkv_state {
            state.fill(0.0);
        }
        for norm in &mut self.wkv_norm {
            norm.fill(0.0);
        }
        self.prev_x.fill(0.0);
    }
}

/// RWKV Channel-Mixing block
///
/// Token-shifted feed-forward network with squared-ReLU activation and a
/// sigmoid receptance gate.
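///
/// Per step (matching `forward` below):
///
/// out = sigmoid(R(x_r)) ⊙ V(relu(K(x_k))²)
///
/// where x_k and x_r are token-shifted mixes of the current and previous
/// input.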
struct ChannelMixing {
    hidden_dim: usize,
    intermediate_dim: usize,

    /// Time-mixing parameters for channel mixing
    time_mix_k: Array1<f32>,
    time_mix_r: Array1<f32>,

    /// Projection matrices
    key_proj: Array2<f32>,
    value_proj: Array2<f32>,
    receptance_proj: Array2<f32>,

    /// Previous input
    prev_x: Array1<f32>,
}

impl ChannelMixing {
    fn new(config: &RwkvConfig) -> ModelResult<Self> {
        let mut rng = rng();

        // Initialize time-mixing parameters
        let time_mix_k = Array1::from_shape_fn(config.hidden_dim, |_| rng.random::<f32>());
        let time_mix_r = Array1::from_shape_fn(config.hidden_dim, |_| rng.random::<f32>());

        // Initialize projection matrices
        let scale = (2.0 / config.hidden_dim as f32).sqrt();
        let key_proj = Array2::from_shape_fn((config.hidden_dim, config.intermediate_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let value_proj =
            Array2::from_shape_fn((config.intermediate_dim, config.hidden_dim), |_| {
                (rng.random::<f32>() - 0.5) * 2.0 * scale
            });

        let receptance_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let prev_x = Array1::zeros(config.hidden_dim);

        Ok(Self {
            hidden_dim: config.hidden_dim,
            intermediate_dim: config.intermediate_dim,
            time_mix_k,
            time_mix_r,
            key_proj,
            value_proj,
            receptance_proj,
            prev_x,
        })
    }

    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let dim = x.len().min(self.hidden_dim);

        // Token shift for the key and receptance paths
        let xk = self.token_shift(x, &self.time_mix_k, dim);
        let xr = self.token_shift(x, &self.time_mix_r, dim);

        // Project the key path and apply the activation
        let k = self.project(&xk, &self.key_proj);
        let k_act = k.mapv(|v| {
            let relu = v.max(0.0);
            relu * relu // squared ReLU: relu(v)^2
        });
        let vk = self.project_back(&k_act, &self.value_proj);

        // Apply receptance gating
        let r = self.project_r(&xr, &self.receptance_proj);
        let r_sigmoid = sigmoid(&r);

        let mut output = Array1::zeros(dim);
        for i in 0..output.len().min(vk.len()).min(r_sigmoid.len()) {
            output[i] = r_sigmoid[i] * vk[i];
        }

        // Remember the current input for the next step's token shift
        self.prev_x = Array1::from_vec(x.iter().take(self.hidden_dim).copied().collect());

        Ok(output)
    }

    /// Token shift: per-channel interpolation between the current input and
    /// the previous step's input
    fn token_shift(&self, x: &Array1<f32>, mix: &Array1<f32>, dim: usize) -> Array1<f32> {
        Array1::from_shape_fn(dim, |i| {
            let prev_val = if i < self.prev_x.len() {
                self.prev_x[i]
            } else {
                0.0
            };
            mix[i] * x[i] + (1.0 - mix[i]) * prev_val
        })
    }

    fn project(&self, x: &Array1<f32>, weight: &Array2<f32>) -> Array1<f32> {
        // x [hidden] through weight [hidden, intermediate] -> [intermediate]
        let out_dim = weight.shape()[1].min(self.intermediate_dim);
        let mut output = Array1::zeros(out_dim);
        for i in 0..out_dim {
            let mut sum = 0.0;
            for j in 0..x.len().min(weight.shape()[0]) {
                sum += weight[[j, i]] * x[j];
            }
            output[i] = sum;
        }
        output
    }

    fn project_back(&self, x: &Array1<f32>, weight: &Array2<f32>) -> Array1<f32> {
        // x [intermediate] through weight [intermediate, hidden] -> [hidden]
        let out_dim = weight.shape()[1].min(self.hidden_dim);
        let mut output = Array1::zeros(out_dim);
        for i in 0..out_dim {
            let mut sum = 0.0;
            for j in 0..x.len().min(weight.shape()[0]) {
                sum += weight[[j, i]] * x[j];
            }
            output[i] = sum;
        }
        output
    }

    fn project_r(&self, x: &Array1<f32>, weight: &Array2<f32>) -> Array1<f32> {
        // Row-wise product with weight [hidden, hidden]
        let out_dim = weight.shape()[0];
        let mut output = Array1::zeros(out_dim.min(x.len()));
        for i in 0..output.len() {
            let mut sum = 0.0;
            for j in 0..x.len().min(weight.shape()[1]) {
                sum += weight[[i, j]] * x[j];
            }
            output[i] = sum;
        }
        output
    }

    fn reset(&mut self) {
        self.prev_x.fill(0.0);
    }
}

/// RWKV Layer combining time-mixing and channel-mixing
struct RwkvLayer {
    ln1: LayerNorm,
    ln2: LayerNorm,
    time_mixing: TimeMixing,
    channel_mixing: ChannelMixing,
}

impl RwkvLayer {
    fn new(config: &RwkvConfig) -> ModelResult<Self> {
        let norm_type = if config.use_rms_norm {
            NormType::RMSNorm
        } else {
            NormType::LayerNorm
        };

        let ln1 = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);
        let ln2 = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);
        let time_mixing = TimeMixing::new(config)?;
        let channel_mixing = ChannelMixing::new(config)?;

        Ok(Self {
            ln1,
            ln2,
            time_mixing,
            channel_mixing,
        })
    }

    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // Time-mixing with pre-norm residual
        let x_norm = self.ln1.forward(x);
        let tm_out = self.time_mixing.forward(&x_norm)?;
        let mut x_tm = x.clone();
        for i in 0..x_tm.len().min(tm_out.len()) {
            x_tm[i] += tm_out[i];
        }

        // Channel-mixing with pre-norm residual
        let x_norm2 = self.ln2.forward(&x_tm);
        let cm_out = self.channel_mixing.forward(&x_norm2)?;
        let mut output = x_tm;
        for i in 0..output.len().min(cm_out.len()) {
            output[i] += cm_out[i];
        }

        Ok(output)
    }

    fn reset(&mut self) {
        self.time_mixing.reset();
        self.channel_mixing.reset();
    }
}

/// RWKV v6 model
pub struct Rwkv {
    config: RwkvConfig,
    layers: Vec<RwkvLayer>,
    ln_out: LayerNorm,
    input_proj: Array2<f32>,
    output_proj: Array2<f32>,
}

impl Rwkv {
    /// Create a new RWKV model
    pub fn new(config: RwkvConfig) -> ModelResult<Self> {
        config.validate()?;

        // Initialize layers
        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(RwkvLayer::new(&config)?);
        }

        // Output layer normalization
        let norm_type = if config.use_rms_norm {
            NormType::RMSNorm
        } else {
            NormType::LayerNorm
        };
        let ln_out = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);

        // Initialize input/output projections (Xavier-style scale; the same
        // value serves both since fan_in + fan_out is identical for each)
        let mut rng = rng();
        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            config,
            layers,
            ln_out,
            input_proj,
            output_proj,
        })
    }

    /// Get the configuration
    pub fn config(&self) -> &RwkvConfig {
        &self.config
    }

    /// Load weights from a SafeTensors model file
    ///
    /// # Weight Naming Convention
    ///
    /// The following tensor names are expected:
    /// - `input_proj`: Input projection matrix (input_dim, hidden_dim)
    /// - `output_proj`: Output projection matrix (hidden_dim, input_dim)
    /// - `ln_out.weight`: Output layer norm weight
    /// - `ln_out.bias`: Output layer norm bias (optional)
    ///
    /// For each layer i:
    /// - `layers.{i}.ln1.weight`: Time-mixing layer norm weight
    /// - `layers.{i}.ln1.bias`: Time-mixing layer norm bias (optional)
    /// - `layers.{i}.ln2.weight`: Channel-mixing layer norm weight
    /// - `layers.{i}.ln2.bias`: Channel-mixing layer norm bias (optional)
    ///
    /// Time-mixing parameters:
    /// - `layers.{i}.time_mixing.time_mix_k`: Time mixing for key
    /// - `layers.{i}.time_mixing.time_mix_v`: Time mixing for value
    /// - `layers.{i}.time_mixing.time_mix_r`: Time mixing for receptance
    /// - `layers.{i}.time_mixing.time_mix_g`: Time mixing for gate
    /// - `layers.{i}.time_mixing.time_decay`: Time decay matrix
    /// - `layers.{i}.time_mixing.key_proj`: Key projection
    /// - `layers.{i}.time_mixing.value_proj`: Value projection
    /// - `layers.{i}.time_mixing.receptance_proj`: Receptance projection
    /// - `layers.{i}.time_mixing.gate_proj`: Gate projection
    /// - `layers.{i}.time_mixing.output_proj`: Output projection
    ///
    /// Channel-mixing parameters:
    /// - `layers.{i}.channel_mixing.time_mix_k`: Time mixing for key
    /// - `layers.{i}.channel_mixing.time_mix_r`: Time mixing for receptance
    /// - `layers.{i}.channel_mixing.key_proj`: Key projection
    /// - `layers.{i}.channel_mixing.value_proj`: Value projection
    /// - `layers.{i}.channel_mixing.receptance_proj`: Receptance projection
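    ///
    /// Missing tensors are skipped, leaving the randomly initialized values
    /// in place. A usage sketch (the `ModelLoader` constructor shown here is
    /// hypothetical; see `crate::loader` for the actual API):
    ///
    /// ```ignore
    /// // Hypothetical constructor name; only has_tensor/load_array1/
    /// // load_array2 are actually called by load_weights.
    /// let loader = crate::loader::ModelLoader::open("rwkv.safetensors")?;
    /// let mut model = Rwkv::new(RwkvConfig::new())?;
    /// model.load_weights(&loader)?;
    /// ```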
    pub fn load_weights(&mut self, loader: &crate::loader::ModelLoader) -> ModelResult<()> {
        // Load input/output projections
        if loader.has_tensor("input_proj") {
            self.input_proj = loader.load_array2("input_proj")?;
        }
        if loader.has_tensor("output_proj") {
            self.output_proj = loader.load_array2("output_proj")?;
        }

        // Load output layer norm
        if loader.has_tensor("ln_out.weight") {
            let weight = loader.load_array1("ln_out.weight")?;
            self.ln_out.set_gamma(weight);
        }
        if loader.has_tensor("ln_out.bias") {
            let bias = loader.load_array1("ln_out.bias")?;
            self.ln_out.set_beta(bias);
        }

        // Load each layer's weights
        for (i, layer) in self.layers.iter_mut().enumerate() {
            let prefix = format!("layers.{}", i);

            // Load layer norm 1
            if loader.has_tensor(&format!("{}.ln1.weight", prefix)) {
                let weight = loader.load_array1(&format!("{}.ln1.weight", prefix))?;
                layer.ln1.set_gamma(weight);
            }
            if loader.has_tensor(&format!("{}.ln1.bias", prefix)) {
                let bias = loader.load_array1(&format!("{}.ln1.bias", prefix))?;
                layer.ln1.set_beta(bias);
            }

            // Load layer norm 2
            if loader.has_tensor(&format!("{}.ln2.weight", prefix)) {
                let weight = loader.load_array1(&format!("{}.ln2.weight", prefix))?;
                layer.ln2.set_gamma(weight);
            }
            if loader.has_tensor(&format!("{}.ln2.bias", prefix)) {
                let bias = loader.load_array1(&format!("{}.ln2.bias", prefix))?;
                layer.ln2.set_beta(bias);
            }

            // Load time-mixing parameters
            let tm_prefix = format!("{}.time_mixing", prefix);
            if loader.has_tensor(&format!("{}.time_mix_k", tm_prefix)) {
                layer.time_mixing.time_mix_k =
                    loader.load_array1(&format!("{}.time_mix_k", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.time_mix_v", tm_prefix)) {
                layer.time_mixing.time_mix_v =
                    loader.load_array1(&format!("{}.time_mix_v", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.time_mix_r", tm_prefix)) {
                layer.time_mixing.time_mix_r =
                    loader.load_array1(&format!("{}.time_mix_r", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.time_mix_g", tm_prefix)) {
                layer.time_mixing.time_mix_g =
                    loader.load_array1(&format!("{}.time_mix_g", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.time_decay", tm_prefix)) {
                layer.time_mixing.time_decay =
                    loader.load_array2(&format!("{}.time_decay", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.key_proj", tm_prefix)) {
                layer.time_mixing.key_proj =
                    loader.load_array2(&format!("{}.key_proj", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.value_proj", tm_prefix)) {
                layer.time_mixing.value_proj =
                    loader.load_array2(&format!("{}.value_proj", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.receptance_proj", tm_prefix)) {
                layer.time_mixing.receptance_proj =
                    loader.load_array2(&format!("{}.receptance_proj", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.gate_proj", tm_prefix)) {
                layer.time_mixing.gate_proj =
                    loader.load_array2(&format!("{}.gate_proj", tm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.output_proj", tm_prefix)) {
                layer.time_mixing.output_proj =
                    loader.load_array2(&format!("{}.output_proj", tm_prefix))?;
            }

            // Load channel-mixing parameters
            let cm_prefix = format!("{}.channel_mixing", prefix);
            if loader.has_tensor(&format!("{}.time_mix_k", cm_prefix)) {
                layer.channel_mixing.time_mix_k =
                    loader.load_array1(&format!("{}.time_mix_k", cm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.time_mix_r", cm_prefix)) {
                layer.channel_mixing.time_mix_r =
                    loader.load_array1(&format!("{}.time_mix_r", cm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.key_proj", cm_prefix)) {
                layer.channel_mixing.key_proj =
                    loader.load_array2(&format!("{}.key_proj", cm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.value_proj", cm_prefix)) {
                layer.channel_mixing.value_proj =
                    loader.load_array2(&format!("{}.value_proj", cm_prefix))?;
            }
            if loader.has_tensor(&format!("{}.receptance_proj", cm_prefix)) {
                layer.channel_mixing.receptance_proj =
                    loader.load_array2(&format!("{}.receptance_proj", cm_prefix))?;
            }
        }

        Ok(())
    }

    /// Save weights to a SafeTensors model file (stub for future implementation)
    #[allow(unused_variables)]
    pub fn save_weights(&self, path: &str) -> ModelResult<()> {
        // TODO: Implement SafeTensors saving
        Err(ModelError::simple_load_error(
            "RWKV save_weights not yet implemented".to_string(),
        ))
    }
}

impl SignalPredictor for Rwkv {
    #[instrument(skip(self, input))]
    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // Project input to hidden dimension
        let mut hidden = input.dot(&self.input_proj);

        // Pass through each layer
        for layer in &mut self.layers {
            hidden = layer.forward(&hidden)?;
        }

        // Final layer normalization
        hidden = self.ln_out.forward(&hidden);

        // Project back to input dimension
        let output = hidden.dot(&self.output_proj);
        Ok(output)
    }

    fn reset(&mut self) {
        for layer in &mut self.layers {
            layer.reset();
        }
    }

    fn context_window(&self) -> usize {
        // RWKV has no fixed window: the recurrent state summarizes the
        // entire history, so context is unbounded in principle
        usize::MAX
    }
}

impl AutoregressiveModel for Rwkv {
    fn hidden_dim(&self) -> usize {
        self.config.hidden_dim
    }

    fn state_dim(&self) -> usize {
        // Matches get_states: each layer exposes num_heads * head_dim values
        self.config.num_heads * self.config.head_dim
    }

    fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    fn model_type(&self) -> ModelType {
        ModelType::Rwkv
    }

    fn get_states(&self) -> Vec<HiddenState> {
        // Collect WKV states from each layer
        self.layers
            .iter()
            .map(|layer| {
                // Flatten multi-head WKV states
                let total_size = layer.time_mixing.num_heads * layer.time_mixing.head_dim;
                let mut combined = Array2::zeros((total_size, 1));

                for (head_idx, head_state) in layer.time_mixing.wkv_state.iter().enumerate() {
                    let start_idx = head_idx * layer.time_mixing.head_dim;
                    for i in 0..layer.time_mixing.head_dim.min(head_state.len()) {
                        combined[[start_idx + i, 0]] = head_state[i];
                    }
                }

                let mut hs = HiddenState::new(combined.shape()[0], combined.shape()[1]);
                hs.update(combined);
                hs
            })
            .collect()
    }

    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != self.config.num_layers {
            return Err(ModelError::state_count_mismatch(
                "RWKV",
                self.config.num_layers,
                states.len(),
            ));
        }

        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
            let combined = states[layer_idx].state();
            if combined.shape()[1] == 0 {
                continue;
            }

            // Split combined state back into per-head WKV states
            for (head_idx, head_state) in layer.time_mixing.wkv_state.iter_mut().enumerate() {
                let start_idx = head_idx * layer.time_mixing.head_dim;
                for i in 0..layer.time_mixing.head_dim.min(head_state.len()) {
                    if start_idx + i < combined.shape()[0] {
                        head_state[i] = combined[[start_idx + i, 0]];
                    }
                }
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rwkv_config() {
        let config = RwkvConfig::new().hidden_dim(512).num_heads(8).num_layers(6);

        assert_eq!(config.hidden_dim, 512);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.head_dim, 64);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_rwkv_creation() {
        // Use a smaller configuration for a faster test; the default
        // num_layers=12 is slow to initialize
        let config = RwkvConfig::new().hidden_dim(128).num_heads(4).num_layers(2);
        let model = Rwkv::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_rwkv_forward() {
        let config = RwkvConfig::new().hidden_dim(128).num_heads(4).num_layers(2);
        let mut model = Rwkv::new(config).expect("Failed to create RWKV model");

        let input = Array1::from_vec(vec![0.5]);
        let output = model.step(&input);
        assert!(output.is_ok());
    }

    #[test]
    fn test_invalid_config() {
        let config = RwkvConfig::new().hidden_dim(100).num_heads(3); // Not divisible
        assert!(config.validate().is_err());
    }
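
    // A hedged round-trip sketch for the state API above: snapshot the WKV
    // state after a step, then restore it. Assumes HiddenState::state()
    // returns exactly what set_states consumes, as the impls suggest.
    #[test]
    fn test_state_roundtrip() {
        let config = RwkvConfig::new().hidden_dim(64).num_heads(4).num_layers(2);
        let mut model = Rwkv::new(config).expect("Failed to create RWKV model");

        // Advance the recurrent state, then snapshot it
        model.step(&Array1::from_vec(vec![0.5])).expect("step failed");
        let states = model.get_states();
        assert_eq!(states.len(), 2); // one state per layer

        // Restoring the snapshot must succeed
        assert!(model.set_states(states).is_ok());
    }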
}