use crate::error::{CoreError, CoreResult};
use crate::nn::{silu, LayerNorm, NormType};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
use scirs2_core::random::thread_rng;

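/// Configuration for a RetNet (Retentive Network) model.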
#[derive(Debug, Clone)]
pub struct RetNetConfig {
    pub hidden_dim: usize,
    pub num_heads: usize,
    pub head_dim: usize,
    pub ffn_dim: usize,
    pub num_layers: usize,
    pub dropout: f32,
}

impl RetNetConfig {
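    /// Creates a configuration with `head_dim = hidden_dim / num_heads` and
    /// the conventional default `ffn_dim = 4 * hidden_dim`. Errors if
    /// `hidden_dim` is not divisible by `num_heads`.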
    pub fn new(hidden_dim: usize, num_heads: usize, num_layers: usize) -> CoreResult<Self> {
        if !hidden_dim.is_multiple_of(num_heads) {
            return Err(CoreError::InvalidConfig(format!(
                "hidden_dim ({}) must be divisible by num_heads ({})",
                hidden_dim, num_heads
            )));
        }

        Ok(Self {
            hidden_dim,
            num_heads,
            head_dim: hidden_dim / num_heads,
            ffn_dim: hidden_dim * 4,
            num_layers,
            dropout: 0.0,
        })
    }

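    /// Sets the feed-forward inner dimension (builder style).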
    pub fn ffn_dim(mut self, dim: usize) -> Self {
        self.ffn_dim = dim;
        self
    }

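    /// Sets the dropout rate (builder style).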
    pub fn dropout(mut self, rate: f32) -> Self {
        self.dropout = rate;
        self
    }
}

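/// Multi-scale retention, RetNet's recurrent replacement for attention
/// (Sun et al., 2023). Each head `h` keeps a `head_dim x head_dim` state
/// `S_h` updated as `S_h <- gamma_h * S_h + k_h^T v_h` with readout
/// `out_h = q_h S_h`; the per-head decay `gamma_h` gives each head a
/// different retention time scale.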
#[derive(Debug)]
pub struct MultiScaleRetention {
    config: RetNetConfig,
    w_q: Array2<f32>,
    w_k: Array2<f32>,
    w_v: Array2<f32>,
    w_o: Array2<f32>,
    gamma: Array1<f32>,
    group_norm: LayerNorm,
}

impl MultiScaleRetention {
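    /// Initializes projection weights uniformly in `(-s, s)` with
    /// `s = 1 / sqrt(hidden_dim)`, and per-head decays
    /// `gamma_h = 1 - 2^{-(5 + h)}` following the RetNet schedule.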
    pub fn new(config: RetNetConfig) -> CoreResult<Self> {
        let hidden_dim = config.hidden_dim;
        let num_heads = config.num_heads;
        let mut rng = thread_rng();
        let scale = (1.0 / hidden_dim as f32).sqrt();

        let w_q = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_k = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_v = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_o = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // Per-head decay gamma_h = 1 - 2^{-(5 + h)}: increases with h, so
        // later heads retain context over longer horizons.
        let gamma = Array1::from_shape_fn(num_heads, |h| {
            let exponent = -(5.0 + h as f32);
            1.0 - 2.0_f32.powf(exponent)
        });

        let group_norm = LayerNorm::new(hidden_dim, NormType::RMSNorm);

        Ok(Self {
            config,
            w_q,
            w_k,
            w_v,
            w_o,
            gamma,
            group_norm,
        })
    }

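    /// Processes one token recurrently. `state` must have shape
    /// `(num_heads, head_dim, head_dim)`; it is decayed by `gamma` and
    /// updated in place.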
    pub fn step(&self, input: &Array1<f32>, state: &mut Array3<f32>) -> CoreResult<Array1<f32>> {
        let num_heads = self.config.num_heads;
        let head_dim = self.config.head_dim;

        let q = input.dot(&self.w_q);
        let k = input.dot(&self.w_k);
        let v = input.dot(&self.w_v);

        let mut output = Array1::zeros(self.config.hidden_dim);

        for h in 0..num_heads {
            let start = h * head_dim;
            let end = start + head_dim;

            let q_h = q.slice(s![start..end]);
            let k_h = k.slice(s![start..end]);
            let v_h = v.slice(s![start..end]);

            let mut s_h = state.index_axis_mut(Axis(0), h);

            // Decay the retained state by this head's gamma.
            let gamma_h = self.gamma[h];
            for i in 0..head_dim {
                for j in 0..head_dim {
                    s_h[[i, j]] *= gamma_h;
                }
            }

            // Rank-1 update: S_h += k_h^T v_h.
            for i in 0..head_dim {
                for j in 0..head_dim {
                    s_h[[i, j]] += k_h[i] * v_h[j];
                }
            }

            // Readout: out_h = q_h S_h.
            for j in 0..head_dim {
                let mut sum = 0.0;
                for i in 0..head_dim {
                    sum += q_h[i] * s_h[[i, j]];
                }
                output[start + j] = sum;
            }
        }

        // Normalize across heads, then project and apply the SiLU activation.
        let normed = self.group_norm.forward(&output);

        let output_proj = normed.dot(&self.w_o);
        let activated = silu(&output_proj);

        Ok(activated)
    }

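    /// Processes a full sequence token by token from a fresh zero state.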
    pub fn forward_sequence(&self, input: &Array2<f32>) -> CoreResult<Array2<f32>> {
        let (seq_len, _) = input.dim();

        let mut output = Array2::zeros((seq_len, self.config.hidden_dim));
        let mut state = self.reset_state();

        for t in 0..seq_len {
            let x_t = input.row(t).to_owned();
            let y_t = self.step(&x_t, &mut state)?;
            output.row_mut(t).assign(&y_t);
        }

        Ok(output)
    }

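    /// Returns a zeroed state of shape `(num_heads, head_dim, head_dim)`.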
    pub fn reset_state(&self) -> Array3<f32> {
        Array3::zeros((
            self.config.num_heads,
            self.config.head_dim,
            self.config.head_dim,
        ))
    }

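    /// Counts the projection and decay parameters; normalization
    /// parameters are not included.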
    pub fn num_parameters(&self) -> usize {
        self.w_q.len() + self.w_k.len() + self.w_v.len() + self.w_o.len() + self.gamma.len()
    }
}

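/// Position-wise feed-forward block: pre-RMSNorm, up-projection, SiLU,
/// down-projection.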
#[derive(Debug)]
pub struct RetNetFFN {
    w1: Array2<f32>,
    w2: Array2<f32>,
    layer_norm: LayerNorm,
}

impl RetNetFFN {
    pub fn new(hidden_dim: usize, ffn_dim: usize) -> Self {
        let mut rng = thread_rng();
        let scale1 = (1.0 / hidden_dim as f32).sqrt();
        let scale2 = (1.0 / ffn_dim as f32).sqrt();

        let w1 = Array2::from_shape_fn((hidden_dim, ffn_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale1
        });
        let w2 = Array2::from_shape_fn((ffn_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale2
        });

        let layer_norm = LayerNorm::new(hidden_dim, NormType::RMSNorm);

        Self { w1, w2, layer_norm }
    }

    pub fn forward(&self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // Pre-norm MLP: w2 . silu(w1 . norm(x)).
        let normed = self.layer_norm.forward(input);

        let hidden = normed.dot(&self.w1);
        let activated = silu(&hidden);
        let output = activated.dot(&self.w2);

        Ok(output)
    }
}

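/// One RetNet block: multi-scale retention followed by a feed-forward
/// network, each wrapped in a residual connection.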
#[derive(Debug)]
pub struct RetNetLayer {
    retention: MultiScaleRetention,
    ffn: RetNetFFN,
}

impl RetNetLayer {
    pub fn new(config: RetNetConfig) -> CoreResult<Self> {
        let retention = MultiScaleRetention::new(config.clone())?;
        let ffn = RetNetFFN::new(config.hidden_dim, config.ffn_dim);

        Ok(Self { retention, ffn })
    }

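    /// Applies both sublayers to a single token, updating the retention
    /// `state` in place.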
    pub fn step(&self, input: &Array1<f32>, state: &mut Array3<f32>) -> CoreResult<Array1<f32>> {
        // Retention sublayer with residual connection.
        let retention_out = self.retention.step(input, state)?;
        let after_retention = input + &retention_out;

        // Feed-forward sublayer with residual connection.
        let ffn_out = self.ffn.forward(&after_retention)?;
        let output = &after_retention + &ffn_out;

        Ok(output)
    }

    pub fn forward_sequence(&self, input: &Array2<f32>) -> CoreResult<Array2<f32>> {
        let (seq_len, _) = input.dim();
        let mut output = Array2::zeros(input.dim());
        let mut state = self.retention.reset_state();

        for t in 0..seq_len {
            let x_t = input.row(t).to_owned();
            let y_t = self.step(&x_t, &mut state)?;
            output.row_mut(t).assign(&y_t);
        }

        Ok(output)
    }

    pub fn reset_state(&self) -> Array3<f32> {
        self.retention.reset_state()
    }
}

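/// A stack of [`RetNetLayer`]s supporting both sequence-level (`forward`)
/// and token-by-token recurrent (`step`) inference.
///
/// Usage sketch (illustrative only; exact import paths and the
/// `token_embedding` input, an `Array1<f32>` of length `hidden_dim`, are
/// assumptions, not part of this crate's documented API):
///
/// ```ignore
/// let config = RetNetConfig::new(64, 2, 2)?;
/// let model = RetNetModel::new(config)?;
/// let mut states = model.reset_states();
/// let y = model.step(&token_embedding, &mut states)?;
/// ```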
#[derive(Debug)]
pub struct RetNetModel {
    layers: Vec<RetNetLayer>,
    config: RetNetConfig,
}

impl RetNetModel {
    pub fn new(config: RetNetConfig) -> CoreResult<Self> {
        let num_layers = config.num_layers;
        let mut layers = Vec::with_capacity(num_layers);

        for _ in 0..num_layers {
            layers.push(RetNetLayer::new(config.clone())?);
        }

        Ok(Self { layers, config })
    }

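    /// Token-level inference through all layers; `states` must hold one
    /// retention state per layer.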
    pub fn step(&self, input: &Array1<f32>, states: &mut [Array3<f32>]) -> CoreResult<Array1<f32>> {
        if states.len() != self.config.num_layers {
            return Err(CoreError::InvalidConfig(format!(
                "Expected {} states, got {}",
                self.config.num_layers,
                states.len()
            )));
        }

        let mut x = input.clone();
        for (i, layer) in self.layers.iter().enumerate() {
            x = layer.step(&x, &mut states[i])?;
        }

        Ok(x)
    }

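    /// Processes a full sequence, running each layer recurrently from a
    /// fresh zero state.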
    pub fn forward(&self, input: &Array2<f32>) -> CoreResult<Array2<f32>> {
        let mut x = input.clone();

        for layer in &self.layers {
            x = layer.forward_sequence(&x)?;
        }

        Ok(x)
    }

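    /// Creates zeroed retention states, one per layer.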
    pub fn reset_states(&self) -> Vec<Array3<f32>> {
        self.layers
            .iter()
            .map(|layer| layer.reset_state())
            .collect()
    }

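    /// Total projection, FFN, and decay parameters across layers;
    /// normalization parameters are not counted.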
    pub fn num_parameters(&self) -> usize {
        self.layers
            .iter()
            .map(|layer| layer.retention.num_parameters() + layer.ffn.w1.len() + layer.ffn.w2.len())
            .sum()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retnet_config() {
        let config = RetNetConfig::new(256, 4, 6).unwrap();
        assert_eq!(config.hidden_dim, 256);
        assert_eq!(config.num_heads, 4);
        assert_eq!(config.head_dim, 64);
        assert_eq!(config.num_layers, 6);
    }

    #[test]
    fn test_multi_scale_retention() {
        let config = RetNetConfig::new(128, 4, 2).unwrap();
        let msr = MultiScaleRetention::new(config).unwrap();

        let input = Array1::from_vec(vec![0.1; 128]);
        let mut state = msr.reset_state();

        let output = msr.step(&input, &mut state).unwrap();
        assert_eq!(output.len(), 128);

        // The state should have been updated away from zero.
        assert!(state.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn test_retnet_layer() {
        let config = RetNetConfig::new(128, 4, 2).unwrap();
        let layer = RetNetLayer::new(config).unwrap();

        let input = Array1::from_vec(vec![0.1; 128]);
        let mut state = layer.reset_state();

        let output = layer.step(&input, &mut state).unwrap();
        assert_eq!(output.len(), 128);
    }

    #[test]
    fn test_retnet_model() {
        let config = RetNetConfig::new(64, 2, 3).unwrap();
        let model = RetNetModel::new(config).unwrap();

        let seq_len = 10;
        let input = Array2::from_shape_vec((seq_len, 64), vec![0.1; seq_len * 64]).unwrap();

        let output = model.forward(&input).unwrap();
        assert_eq!(output.dim(), (seq_len, 64));
    }

    #[test]
    fn test_retnet_inference() {
        let config = RetNetConfig::new(64, 2, 2).unwrap();
        let model = RetNetModel::new(config).unwrap();

        let mut states = model.reset_states();
        let input = Array1::from_vec(vec![0.1; 64]);

        // Run several recurrent steps with persistent per-layer states.
        for _ in 0..5 {
            let output = model.step(&input, &mut states).unwrap();
            assert_eq!(output.len(), 64);
        }
    }

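    // Consistency check: the recurrent `step` path and the sequence-level
    // `forward` path compute the same function (forward just steps each
    // layer over the sequence), so streaming tokens through fresh per-layer
    // states should reproduce `forward`'s output up to float noise.
    #[test]
    fn test_step_matches_forward() {
        let config = RetNetConfig::new(32, 2, 2).unwrap();
        let model = RetNetModel::new(config).unwrap();

        let seq_len = 4;
        let input = Array2::from_shape_fn((seq_len, 32), |(t, d)| 0.01 * (t + d) as f32);

        let parallel = model.forward(&input).unwrap();

        let mut states = model.reset_states();
        for t in 0..seq_len {
            let x_t = input.row(t).to_owned();
            let y_t = model.step(&x_t, &mut states).unwrap();
            for (a, b) in y_t.iter().zip(parallel.row(t).iter()) {
                assert!((a - b).abs() < 1e-5);
            }
        }
    }
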
    #[test]
    fn test_gamma_values() {
        let config = RetNetConfig::new(128, 4, 2).unwrap();
        let msr = MultiScaleRetention::new(config).unwrap();

        // All decay factors lie strictly between 0 and 1.
        for &gamma in msr.gamma.iter() {
            assert!(gamma > 0.0 && gamma < 1.0);
        }

        // Decays increase monotonically across heads.
        for i in 1..msr.gamma.len() {
            assert!(msr.gamma[i] >= msr.gamma[i - 1]);
        }
    }
}