1use crate::error::{ModelError, ModelResult};
29use crate::{AutoregressiveModel, ModelType};
30use kizzasi_core::{CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor};
31use scirs2_core::ndarray::{Array1, Array2};
32use scirs2_core::random::{rng, Rng};
33use serde::{Deserialize, Serialize};
34
35#[allow(unused_imports)]
36use tracing::{debug, instrument, trace};
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct Rwkv7Config {
41 pub input_dim: usize,
43 pub hidden_dim: usize,
45 pub num_layers: usize,
47 pub num_heads: usize,
49 pub head_dim: usize,
51 pub intermediate_dim: usize,
53 pub time_decay_init: f32,
55 pub enhanced_gradient_flow: bool,
57 pub multi_modal: bool,
59 pub max_context_length: usize,
61}
62
63impl Default for Rwkv7Config {
64 fn default() -> Self {
65 let hidden_dim = 768;
66 let num_heads = 12;
67 Self {
68 input_dim: 1,
69 hidden_dim,
70 num_layers: 24,
71 num_heads,
72 head_dim: hidden_dim / num_heads,
73 intermediate_dim: hidden_dim * 4,
74 time_decay_init: -6.0, enhanced_gradient_flow: true,
76 multi_modal: false,
77 max_context_length: 8192,
78 }
79 }
80}
81
82impl Rwkv7Config {
83 pub fn new() -> Self {
85 Self::default()
86 }
87
88 pub fn small(input_dim: usize) -> Self {
90 Self {
91 input_dim,
92 hidden_dim: 512,
93 num_layers: 12,
94 num_heads: 8,
95 head_dim: 64,
96 intermediate_dim: 2048,
97 ..Default::default()
98 }
99 }
100
101 pub fn base(input_dim: usize) -> Self {
103 Self {
104 input_dim,
105 hidden_dim: 768,
106 num_layers: 24,
107 num_heads: 12,
108 head_dim: 64,
109 intermediate_dim: 3072,
110 ..Default::default()
111 }
112 }
113
114 pub fn large(input_dim: usize) -> Self {
116 Self {
117 input_dim,
118 hidden_dim: 4096,
119 num_layers: 32,
120 num_heads: 32,
121 head_dim: 128,
122 intermediate_dim: 16384,
123 max_context_length: 16384,
124 ..Default::default()
125 }
126 }
127
128 pub fn input_dim(mut self, dim: usize) -> Self {
130 self.input_dim = dim;
131 self
132 }
133
134 pub fn multi_modal(mut self, enable: bool) -> Self {
136 self.multi_modal = enable;
137 self
138 }
139
140 pub fn max_context_length(mut self, length: usize) -> Self {
142 self.max_context_length = length;
143 self
144 }
145
146 pub fn validate(&self) -> ModelResult<()> {
148 if self.hidden_dim == 0 {
149 return Err(ModelError::invalid_config("hidden_dim must be > 0"));
150 }
151 if self.num_layers == 0 {
152 return Err(ModelError::invalid_config("num_layers must be > 0"));
153 }
154 if self.num_heads == 0 {
155 return Err(ModelError::invalid_config("num_heads must be > 0"));
156 }
157 if !self.hidden_dim.is_multiple_of(self.num_heads) {
158 return Err(ModelError::invalid_config(
159 "hidden_dim must be divisible by num_heads",
160 ));
161 }
162 Ok(())
163 }
164}
165
166#[allow(dead_code)]
170struct EnhancedTimeMixing {
171 hidden_dim: usize,
172 num_heads: usize,
173 head_dim: usize,
174
175 time_decay: Array1<f32>,
177 time_first: Array1<f32>,
178
179 key_proj: Array2<f32>,
181 value_proj: Array2<f32>,
182 receptance_proj: Array2<f32>,
183 gate_proj: Array2<f32>, output_proj: Array2<f32>,
185
186 ln: LayerNorm,
188
189 state: Vec<Array1<f32>>,
191}
192
193impl EnhancedTimeMixing {
194 #[allow(dead_code)]
195 fn new(hidden_dim: usize, num_heads: usize) -> Self {
196 let mut rng = rng();
197 let head_dim = hidden_dim / num_heads;
198
199 let scale = (1.0 / hidden_dim as f32).sqrt();
200
201 let time_decay = Array1::from_shape_fn(hidden_dim, |i| {
203 let layer_idx = (i / head_dim) as f32;
204 -6.0 - layer_idx * 0.1 });
206
207 let time_first = Array1::from_shape_fn(hidden_dim, |_| (rng.random::<f32>() - 0.5) * 0.1);
208
209 let key_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
210 (rng.random::<f32>() - 0.5) * 2.0 * scale
211 });
212
213 let value_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
214 (rng.random::<f32>() - 0.5) * 2.0 * scale
215 });
216
217 let receptance_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
218 (rng.random::<f32>() - 0.5) * 2.0 * scale
219 });
220
221 let gate_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
222 (rng.random::<f32>() - 0.5) * 2.0 * scale
223 });
224
225 let output_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
226 (rng.random::<f32>() - 0.5) * 2.0 * scale
227 });
228
229 let ln = LayerNorm::new(hidden_dim, NormType::LayerNorm);
230
231 let state = (0..num_heads)
232 .map(|_| Array1::zeros(head_dim * 2))
233 .collect();
234
235 Self {
236 hidden_dim,
237 num_heads,
238 head_dim,
239 time_decay,
240 time_first,
241 key_proj,
242 value_proj,
243 receptance_proj,
244 gate_proj,
245 output_proj,
246 ln,
247 state,
248 }
249 }
250
251 #[allow(dead_code)]
252 fn forward(&mut self, x: &Array1<f32>) -> ModelResult<Array1<f32>> {
253 let normalized = self.ln.forward(x);
256 Ok(normalized)
257 }
258
259 #[allow(dead_code)]
260 fn reset(&mut self) {
261 for state in &mut self.state {
262 state.fill(0.0);
263 }
264 }
265}
266
267pub struct Rwkv7 {
269 config: Rwkv7Config,
270 input_proj: Array2<f32>,
272 output_proj: Array2<f32>,
273}
274
275impl Rwkv7 {
276 pub fn new(config: Rwkv7Config) -> ModelResult<Self> {
278 config.validate()?;
279
280 let mut rng = rng();
281 let scale = (1.0 / config.hidden_dim as f32).sqrt();
282
283 let input_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
284 (rng.random::<f32>() - 0.5) * 2.0 * scale
285 });
286
287 let output_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
288 (rng.random::<f32>() - 0.5) * 2.0 * scale
289 });
290
291 debug!(
292 "Created RWKV v7 model: {} layers, {} hidden_dim, {} heads (SCAFFOLDING)",
293 config.num_layers, config.hidden_dim, config.num_heads
294 );
295
296 Ok(Self {
297 config,
298 input_proj,
299 output_proj,
300 })
301 }
302
303 pub fn config(&self) -> &Rwkv7Config {
305 &self.config
306 }
307}
308
309impl SignalPredictor for Rwkv7 {
310 fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
311 trace!("RWKV v7 forward pass (scaffolding)");
314
315 let hidden = self.input_proj.dot(input);
317 let output = self.output_proj.dot(&hidden);
318
319 Ok(output)
320 }
321
322 fn reset(&mut self) {
323 trace!("RWKV v7 reset state (scaffolding)");
324 }
326
327 fn context_window(&self) -> usize {
328 self.config.max_context_length
329 }
330}
331
332impl AutoregressiveModel for Rwkv7 {
333 fn hidden_dim(&self) -> usize {
334 self.config.hidden_dim
335 }
336
337 fn state_dim(&self) -> usize {
338 self.config.hidden_dim * self.config.num_layers
339 }
340
341 fn num_layers(&self) -> usize {
342 self.config.num_layers
343 }
344
345 fn model_type(&self) -> ModelType {
346 ModelType::Rwkv }
348
349 fn get_states(&self) -> Vec<HiddenState> {
350 vec![]
352 }
353
354 fn set_states(&mut self, _states: Vec<HiddenState>) -> ModelResult<()> {
355 Ok(())
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rwkv7_config_creation() {
        let cfg = Rwkv7Config::new();
        assert_eq!((cfg.num_heads, cfg.hidden_dim), (12, 768));
    }

    #[test]
    fn test_rwkv7_small_config() {
        let cfg = Rwkv7Config::small(8);
        assert_eq!(
            (cfg.input_dim, cfg.hidden_dim, cfg.num_layers),
            (8, 512, 12)
        );
    }

    #[test]
    fn test_rwkv7_base_config() {
        let cfg = Rwkv7Config::base(8);
        assert_eq!((cfg.hidden_dim, cfg.num_layers), (768, 24));
    }

    #[test]
    fn test_rwkv7_large_config() {
        let cfg = Rwkv7Config::large(8);
        assert_eq!(
            (cfg.hidden_dim, cfg.num_layers, cfg.max_context_length),
            (4096, 32, 16384)
        );
    }

    #[test]
    fn test_rwkv7_model_creation() {
        assert!(Rwkv7::new(Rwkv7Config::small(4)).is_ok());
    }

    #[test]
    fn test_rwkv7_forward_pass() {
        let mut model = Rwkv7::new(Rwkv7Config::small(4)).expect("Failed to create RWKV7 model");
        let signal = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let result = model.step(&signal);
        assert!(result.is_ok());

        let prediction = result.expect("Failed to get output");
        assert_eq!(prediction.len(), 4);
    }

    #[test]
    fn test_rwkv7_multi_modal_config() {
        assert!(Rwkv7Config::base(8).multi_modal(true).multi_modal);
    }

    #[test]
    fn test_rwkv7_validation() {
        assert!(Rwkv7Config::new().validate().is_ok());

        let zero_hidden = Rwkv7Config {
            hidden_dim: 0,
            ..Rwkv7Config::default()
        };
        assert!(zero_hidden.validate().is_err());
    }
}