kizzasi_core/retnet.rs

//! RetNet: Retention Networks for Multi-Scale Sequence Modeling
//!
//! RetNet replaces traditional attention with a retention mechanism that provides:
//! - **O(1)** inference complexity (like RNNs)
//! - **Parallel training** (like Transformers)
//! - **Multi-scale temporal modeling** via multiple retention heads
//! - **Linear memory complexity**
//!
//! # Architecture
//!
//! RetNet uses Multi-Scale Retention (MSR) which applies retention at different scales:
//! ```text
//! Input → [GroupNorm] → [MSRetention] → [FFN] → Output
//!                          ↓
//!                       [State]
//! ```
//!
//! # Retention Mechanism
//!
//! The retention mechanism for head h:
//! ```text
//! Q = X W_Q,  K = X W_K,  V = X W_V
//! Retention = (Q K^T ⊙ D) V
//! ```
//!
//! Where D is a causal decay matrix: D[i,j] = γ^(i-j) for i >= j, and 0 otherwise.
//! γ is the per-head decay factor (each head uses a different γ, giving multiple time scales).
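//!
//! For example, with γ = 0.9 and seq_len = 3:
//! ```text
//! D = [ 1.00  0     0    ]
//!     [ 0.90  1.00  0    ]
//!     [ 0.81  0.90  1.00 ]
//! ```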
//!
//! # Recurrent Form (O(1) inference)
//!
//! ```text
//! S_t = γ S_{t-1} + K_t^T V_t
//! O_t = Q_t S_t
//! ```
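//!
//! # Example
//!
//! A minimal recurrent-inference sketch (fenced as `ignore` because the exact
//! crate paths and error-handling context are assumed):
//!
//! ```ignore
//! use scirs2_core::ndarray::Array1;
//!
//! // hidden_dim = 256, num_heads = 4, num_layers = 6
//! let config = RetNetConfig::new(256, 4, 6)?;
//! let model = RetNetModel::new(config)?;
//!
//! // One retention state per layer, carried across time steps (O(1) per token).
//! let mut states = model.reset_states();
//! for _ in 0..10 {
//!     let token = Array1::from_vec(vec![0.1_f32; 256]);
//!     let output = model.step(&token, &mut states)?;
//!     assert_eq!(output.len(), 256);
//! }
//! ```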

use crate::error::{CoreError, CoreResult};
use crate::nn::{silu, LayerNorm, NormType};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
use scirs2_core::random::thread_rng;

/// Multi-Scale Retention Configuration
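///
/// A minimal construction sketch (the setters consume and return `Self`):
///
/// ```ignore
/// let config = RetNetConfig::new(256, 4, 6)?  // hidden_dim, num_heads, num_layers
///     .ffn_dim(512)                           // override the default 4x expansion
///     .dropout(0.1);
/// ```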
#[derive(Debug, Clone)]
pub struct RetNetConfig {
    /// Model dimension
    pub hidden_dim: usize,
    /// Number of retention heads
    pub num_heads: usize,
    /// Head dimension (hidden_dim / num_heads)
    pub head_dim: usize,
    /// FFN hidden dimension
    pub ffn_dim: usize,
    /// Number of layers
    pub num_layers: usize,
    /// Dropout rate (reserved; not used by the layers in this module)
    pub dropout: f32,
}

impl RetNetConfig {
    /// Create a new RetNet configuration
    pub fn new(hidden_dim: usize, num_heads: usize, num_layers: usize) -> CoreResult<Self> {
        if !hidden_dim.is_multiple_of(num_heads) {
            return Err(CoreError::InvalidConfig(format!(
                "hidden_dim ({}) must be divisible by num_heads ({})",
                hidden_dim, num_heads
            )));
        }

        Ok(Self {
            hidden_dim,
            num_heads,
            head_dim: hidden_dim / num_heads,
            ffn_dim: hidden_dim * 4, // Standard 4x expansion
            num_layers,
            dropout: 0.0,
        })
    }

    /// Set FFN dimension
    pub fn ffn_dim(mut self, dim: usize) -> Self {
        self.ffn_dim = dim;
        self
    }

    /// Set dropout rate
    pub fn dropout(mut self, rate: f32) -> Self {
        self.dropout = rate;
        self
    }
}

/// Multi-Scale Retention (MSR) Module
///
/// Implements the retention mechanism with multiple heads, each operating at a different temporal scale.
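///
/// A usage sketch of the recurrent path (fenced as `ignore`; error-handling context assumed):
///
/// ```ignore
/// let msr = MultiScaleRetention::new(RetNetConfig::new(128, 4, 2)?)?;
/// let mut state = msr.reset_state();            // shape: (num_heads, head_dim, head_dim) = (4, 32, 32)
/// let x = Array1::from_vec(vec![0.1_f32; 128]);
/// let y = msr.step(&x, &mut state)?;            // O(1) per token
/// assert_eq!(y.len(), 128);
/// ```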
#[derive(Debug)]
pub struct MultiScaleRetention {
    config: RetNetConfig,
    // QKV projections
    w_q: Array2<f32>,
    w_k: Array2<f32>,
    w_v: Array2<f32>,
    // Output projection
    w_o: Array2<f32>,
    // Decay factors (gamma) for each head
    gamma: Array1<f32>,
    // Group norm
    group_norm: LayerNorm,
}

impl MultiScaleRetention {
    /// Create a new multi-scale retention module
    pub fn new(config: RetNetConfig) -> CoreResult<Self> {
        let hidden_dim = config.hidden_dim;
        let num_heads = config.num_heads;
        let mut rng = thread_rng();
        let scale = (1.0 / hidden_dim as f32).sqrt();

        // Initialize Q/K/V and output projections
        let w_q = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_k = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_v = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_o = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // Initialize decay factors (gamma) for multi-scale retention
        // Each head has a different decay: γ_h = 1 - 2^(-5-h) for h = 0..H-1
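        // e.g. with 4 heads: γ ≈ [0.96875, 0.984375, 0.9921875, 0.99609375]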
        let gamma = Array1::from_shape_fn(num_heads, |h| {
            let exponent = -(5.0 + h as f32);
            1.0 - 2.0_f32.powf(exponent)
        });

        // Normalization over the concatenated head outputs
        // (RMSNorm over the full hidden dimension, used in place of per-head GroupNorm for efficiency)
        let group_norm = LayerNorm::new(hidden_dim, NormType::RMSNorm);

        Ok(Self {
            config,
            w_q,
            w_k,
            w_v,
            w_o,
            gamma,
            group_norm,
        })
    }

    /// Recurrent forward step (O(1) inference)
    ///
    /// Updates retention state and computes output
    /// State: (num_heads, head_dim, head_dim)
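    ///
    /// Per head h this computes `S_t = γ_h · S_{t-1} + k_t^T v_t` and `o_t = q_t S_t`,
    /// followed by RMSNorm, the output projection `W_O`, and SiLU.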
    pub fn step(&self, input: &Array1<f32>, state: &mut Array3<f32>) -> CoreResult<Array1<f32>> {
        let num_heads = self.config.num_heads;
        let head_dim = self.config.head_dim;

        // Project to Q, K, V
        let q = input.dot(&self.w_q);
        let k = input.dot(&self.w_k);
        let v = input.dot(&self.w_v);

        let mut output = Array1::zeros(self.config.hidden_dim);

        // Process each head
        for h in 0..num_heads {
            let start = h * head_dim;
            let end = start + head_dim;

            let q_h = q.slice(s![start..end]);
            let k_h = k.slice(s![start..end]);
            let v_h = v.slice(s![start..end]);

            // Get state for this head
            let mut s_h = state.index_axis_mut(Axis(0), h);

            // Decay previous state: S_t = γ S_{t-1}
            let gamma_h = self.gamma[h];
            for i in 0..head_dim {
                for j in 0..head_dim {
                    s_h[[i, j]] *= gamma_h;
                }
            }

            // Add new contribution: S_t += K_t^T V_t (outer product)
            for i in 0..head_dim {
                for j in 0..head_dim {
                    s_h[[i, j]] += k_h[i] * v_h[j];
                }
            }

            // Output: O_t = Q_t S_t
            for j in 0..head_dim {
                let mut sum = 0.0;
                for i in 0..head_dim {
                    sum += q_h[i] * s_h[[i, j]];
                }
                output[start + j] = sum;
            }
        }

        // Group normalization
        let normed = self.group_norm.forward(&output);

        // Output projection with SiLU activation
        let output_proj = normed.dot(&self.w_o);
        let activated = silu(&output_proj);

        Ok(activated)
    }

    /// Forward pass over a full sequence.
    ///
    /// This implementation applies the recurrent form step by step; the parallel
    /// `(Q K^T ⊙ D) V` form (sketched below) could replace it for faster training.
    ///
    /// Input shape: (seq_len, hidden_dim)
    /// Output shape: (seq_len, hidden_dim)
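    ///
    /// A single-head parallel-form sketch for illustration (hypothetical local
    /// variables `q`, `k`, `v`, `gamma`; not part of this implementation):
    ///
    /// ```ignore
    /// // q, k, v: (seq_len, head_dim); gamma: f32
    /// let seq_len = q.nrows();
    /// let mut scores = q.dot(&k.t());            // Q K^T, shape (seq_len, seq_len)
    /// for i in 0..seq_len {
    ///     for j in 0..seq_len {
    ///         scores[[i, j]] = if j <= i {
    ///             scores[[i, j]] * gamma.powi((i - j) as i32) // decay D[i, j] = γ^(i-j)
    ///         } else {
    ///             0.0                                         // causal mask
    ///         };
    ///     }
    /// }
    /// let retention = scores.dot(&v);            // shape (seq_len, head_dim)
    /// ```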
    pub fn forward_sequence(&self, input: &Array2<f32>) -> CoreResult<Array2<f32>> {
        let (seq_len, _) = input.dim();

        let mut output = Array2::zeros((seq_len, self.config.hidden_dim));
        let mut state = self.reset_state();

        // Process sequence step by step
        for t in 0..seq_len {
            let x_t = input.row(t).to_owned();
            let y_t = self.step(&x_t, &mut state)?;
            output.row_mut(t).assign(&y_t);
        }

        Ok(output)
    }

    /// Reset retention state
    pub fn reset_state(&self) -> Array3<f32> {
        Array3::zeros((
            self.config.num_heads,
            self.config.head_dim,
            self.config.head_dim,
        ))
    }

    /// Get number of parameters
    pub fn num_parameters(&self) -> usize {
        self.w_q.len() + self.w_k.len() + self.w_v.len() + self.w_o.len() + self.gamma.len()
    }
}

/// Feed-Forward Network for RetNet
#[derive(Debug)]
pub struct RetNetFFN {
    w1: Array2<f32>,
    w2: Array2<f32>,
    layer_norm: LayerNorm,
}

impl RetNetFFN {
    /// Create a new FFN
    pub fn new(hidden_dim: usize, ffn_dim: usize) -> Self {
        let mut rng = thread_rng();
        let scale1 = (1.0 / hidden_dim as f32).sqrt();
        let scale2 = (1.0 / ffn_dim as f32).sqrt();

        let w1 = Array2::from_shape_fn((hidden_dim, ffn_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale1
        });
        let w2 = Array2::from_shape_fn((ffn_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale2
        });

        let layer_norm = LayerNorm::new(hidden_dim, NormType::RMSNorm);

        Self { w1, w2, layer_norm }
    }

    /// Forward pass
    pub fn forward(&self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // Layer norm
        let normed = self.layer_norm.forward(input);

        // FFN with SiLU activation: SiLU(x W1) W2
        let hidden = normed.dot(&self.w1);
        let activated = silu(&hidden);
        let output = activated.dot(&self.w2);

        Ok(output)
    }
}

/// RetNet Layer
///
/// Combines Multi-Scale Retention and FFN with residual connections.
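///
/// Per step: `h = x + MSR(x)` followed by `y = h + FFN(h)`; each sub-module applies
/// its own RMSNorm internally.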
#[derive(Debug)]
pub struct RetNetLayer {
    retention: MultiScaleRetention,
    ffn: RetNetFFN,
}

impl RetNetLayer {
    /// Create a new RetNet layer
    pub fn new(config: RetNetConfig) -> CoreResult<Self> {
        let retention = MultiScaleRetention::new(config.clone())?;
        let ffn = RetNetFFN::new(config.hidden_dim, config.ffn_dim);

        Ok(Self { retention, ffn })
    }

    /// Forward step with residual connections
    pub fn step(&self, input: &Array1<f32>, state: &mut Array3<f32>) -> CoreResult<Array1<f32>> {
        // Retention with residual
        let retention_out = self.retention.step(input, state)?;
        let after_retention = input + &retention_out;

        // FFN with residual
        let ffn_out = self.ffn.forward(&after_retention)?;
        let output = &after_retention + &ffn_out;

        Ok(output)
    }

    /// Forward sequence
    pub fn forward_sequence(&self, input: &Array2<f32>) -> CoreResult<Array2<f32>> {
        let (seq_len, _) = input.dim();
        let mut output = Array2::zeros(input.dim());
        let mut state = self.retention.reset_state();

        for t in 0..seq_len {
            let x_t = input.row(t).to_owned();
            let y_t = self.step(&x_t, &mut state)?;
            output.row_mut(t).assign(&y_t);
        }

        Ok(output)
    }

    /// Reset state
    pub fn reset_state(&self) -> Array3<f32> {
        self.retention.reset_state()
    }
}

/// Multi-layer RetNet Model
#[derive(Debug)]
pub struct RetNetModel {
    layers: Vec<RetNetLayer>,
    config: RetNetConfig,
}

impl RetNetModel {
    /// Create a new multi-layer RetNet model
    pub fn new(config: RetNetConfig) -> CoreResult<Self> {
        let num_layers = config.num_layers;
        let mut layers = Vec::with_capacity(num_layers);

        for _ in 0..num_layers {
            layers.push(RetNetLayer::new(config.clone())?);
        }

        Ok(Self { layers, config })
    }

    /// Single step inference
    pub fn step(&self, input: &Array1<f32>, states: &mut [Array3<f32>]) -> CoreResult<Array1<f32>> {
        if states.len() != self.config.num_layers {
            return Err(CoreError::InvalidConfig(format!(
                "Expected {} states, got {}",
                self.config.num_layers,
                states.len()
            )));
        }

        let mut x = input.clone();
        for (i, layer) in self.layers.iter().enumerate() {
            x = layer.step(&x, &mut states[i])?;
        }

        Ok(x)
    }

    /// Forward pass for sequence
    pub fn forward(&self, input: &Array2<f32>) -> CoreResult<Array2<f32>> {
        let mut x = input.clone();

        for layer in &self.layers {
            x = layer.forward_sequence(&x)?;
        }

        Ok(x)
    }

    /// Reset all states
    pub fn reset_states(&self) -> Vec<Array3<f32>> {
        self.layers
            .iter()
            .map(|layer| layer.reset_state())
            .collect()
    }

    /// Get total number of parameters
    pub fn num_parameters(&self) -> usize {
        self.layers
            .iter()
            .map(|layer| layer.retention.num_parameters() + layer.ffn.w1.len() + layer.ffn.w2.len())
            .sum()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retnet_config() {
        let config = RetNetConfig::new(256, 4, 6).unwrap();
        assert_eq!(config.hidden_dim, 256);
        assert_eq!(config.num_heads, 4);
        assert_eq!(config.head_dim, 64);
        assert_eq!(config.num_layers, 6);
    }

    #[test]
    fn test_multi_scale_retention() {
        let config = RetNetConfig::new(128, 4, 2).unwrap();
        let msr = MultiScaleRetention::new(config).unwrap();

        let input = Array1::from_vec(vec![0.1; 128]);
        let mut state = msr.reset_state();

        let output = msr.step(&input, &mut state).unwrap();
        assert_eq!(output.len(), 128);

        // Check that state has been updated
        assert!(state.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn test_retnet_layer() {
        let config = RetNetConfig::new(128, 4, 2).unwrap();
        let layer = RetNetLayer::new(config).unwrap();

        let input = Array1::from_vec(vec![0.1; 128]);
        let mut state = layer.reset_state();

        let output = layer.step(&input, &mut state).unwrap();
        assert_eq!(output.len(), 128);
    }

    #[test]
    fn test_retnet_model() {
        let config = RetNetConfig::new(64, 2, 3).unwrap();
        let model = RetNetModel::new(config).unwrap();

        let seq_len = 10;
        let input = Array2::from_shape_vec((seq_len, 64), vec![0.1; seq_len * 64]).unwrap();

        let output = model.forward(&input).unwrap();
        assert_eq!(output.dim(), (seq_len, 64));
    }

    #[test]
    fn test_retnet_inference() {
        let config = RetNetConfig::new(64, 2, 2).unwrap();
        let model = RetNetModel::new(config).unwrap();

        let mut states = model.reset_states();
        let input = Array1::from_vec(vec![0.1; 64]);

        // Process multiple steps
        for _ in 0..5 {
            let output = model.step(&input, &mut states).unwrap();
            assert_eq!(output.len(), 64);
        }
    }

    #[test]
    fn test_gamma_values() {
        let config = RetNetConfig::new(128, 4, 2).unwrap();
        let msr = MultiScaleRetention::new(config).unwrap();

        // Check that gamma values are in valid range (0, 1)
        for &gamma in msr.gamma.iter() {
            assert!(gamma > 0.0 && gamma < 1.0);
        }

        // Check that gammas are non-decreasing (later heads decay more slowly, i.e. retain longer)
        for i in 1..msr.gamma.len() {
            assert!(msr.gamma[i] >= msr.gamma[i - 1]);
        }
    }
}