// kizzasi_core/conv.rs

//! Causal convolution implementations for SSM architectures
//!
//! Causal convolutions are essential for autoregressive models as they
//! ensure the output at time t only depends on inputs at times <= t.

use scirs2_core::ndarray::Array1;

/// 1D Causal Convolution Layer
///
/// Implements a causal (left-padded) convolution that preserves causality
/// for autoregressive inference, e.g. as the short causal convolution in
/// Mamba-style blocks.
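///
/// A minimal usage sketch (the `kizzasi_core::conv` module path is assumed
/// from the file location, hence the `ignore` marker):
///
/// ```ignore
/// use kizzasi_core::conv::CausalConv1d;
///
/// let mut conv = CausalConv1d::new(2, 4, 3); // 2 in, 4 out channels, kernel 3
/// let y = conv.forward_step(&[1.0, -1.0]);   // one streaming step
/// assert_eq!(y.len(), 4);
/// ```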
#[derive(Debug, Clone)]
pub struct CausalConv1d {
    /// Convolution kernel weights [out_channels, in_channels, kernel_size]
    weights: Vec<Vec<Vec<f32>>>,
    /// Bias per output channel
    bias: Vec<f32>,
    /// Kernel size (also determines padding)
    kernel_size: usize,
    /// Input channels
    in_channels: usize,
    /// Output channels
    out_channels: usize,
    /// Sliding window of recent input frames (oldest first)
    history: Vec<Vec<f32>>,
}

impl CausalConv1d {
    /// Create a new causal convolution
    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
        // Kaiming-style fan-in scale; the weights themselves use a
        // deterministic sine pattern rather than random values
        let scale = (2.0 / (in_channels * kernel_size) as f32).sqrt();
        let mut weights = Vec::with_capacity(out_channels);

        for _ in 0..out_channels {
            let mut out_ch = Vec::with_capacity(in_channels);
            for _ in 0..in_channels {
                let kernel: Vec<f32> = (0..kernel_size)
                    .map(|i| {
                        // Simple deterministic initialization
                        (i as f32 * 0.1).sin() * scale
                    })
                    .collect();
                out_ch.push(kernel);
            }
            weights.push(out_ch);
        }

        let bias = vec![0.0; out_channels];

        // Initialize history buffer with zeros (kernel_size - 1 frames needed for causality)
        let history: Vec<Vec<f32>> = (0..(kernel_size - 1))
            .map(|_| vec![0.0; in_channels])
            .collect();

        Self {
            weights,
            bias,
            kernel_size,
            in_channels,
            out_channels,
            history,
        }
    }

    /// Set weights from external source
    pub fn set_weights(&mut self, weights: Vec<Vec<Vec<f32>>>) {
        assert_eq!(weights.len(), self.out_channels);
        for oc in &weights {
            assert_eq!(oc.len(), self.in_channels);
            for ic in oc {
                assert_eq!(ic.len(), self.kernel_size);
            }
        }
        self.weights = weights;
    }

    /// Set bias from external source
    pub fn set_bias(&mut self, bias: Vec<f32>) {
        assert_eq!(bias.len(), self.out_channels);
        self.bias = bias;
    }

    /// Forward pass for a single time step (streaming/causal)
    ///
    /// Takes input of shape [in_channels] and returns output of shape [out_channels]
    pub fn forward_step(&mut self, input: &[f32]) -> Vec<f32> {
        assert_eq!(input.len(), self.in_channels);

        // Add current input to history
        self.history.push(input.to_vec());

        // Keep only kernel_size frames
        while self.history.len() > self.kernel_size {
            self.history.remove(0);
        }

        // Compute convolution output
        let mut output = self.bias.clone();

        for (oc, out_weights) in self.weights.iter().enumerate() {
            for (ic, in_weights) in out_weights.iter().enumerate() {
                for (k, &weight) in in_weights.iter().enumerate() {
                    // Weight k taps the input k steps back; the newest
                    // frame sits at the end of the buffer
                    if k < self.history.len() {
                        let hist_idx = self.history.len() - 1 - k;
                        output[oc] += weight * self.history[hist_idx][ic];
                    }
                }
            }
        }

        output
    }

    /// Forward pass for a batch of time steps
    ///
    /// Input shape: [time, in_channels]
    /// Output shape: [time, out_channels]
    pub fn forward_batch(&mut self, input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        input.iter().map(|x| self.forward_step(x)).collect()
    }

    /// Reset the history buffer
    pub fn reset(&mut self) {
        for h in &mut self.history {
            h.fill(0.0);
        }
    }

    /// Get the current history buffer state
    ///
    /// Returns the most recent kernel_size - 1 frames, i.e. exactly the
    /// state `set_history` expects when resuming streaming.
    pub fn get_history(&self) -> Vec<Vec<f32>> {
        // After forward_step the buffer holds kernel_size frames with the
        // newest at the end; only the last kernel_size - 1 are state
        let expected_len = self.kernel_size - 1;
        if self.history.len() >= expected_len {
            self.history[self.history.len() - expected_len..].to_vec()
        } else {
            self.history.clone()
        }
    }

    /// Set the history buffer state
    pub fn set_history(&mut self, history: Vec<Vec<f32>>) {
        assert_eq!(
            history.len(),
            self.kernel_size - 1,
            "History length must be kernel_size - 1 = {}",
            self.kernel_size - 1
        );
        for h in &history {
            assert_eq!(
                h.len(),
                self.in_channels,
                "Each history frame must have in_channels = {} elements",
                self.in_channels
            );
        }
        self.history = history;
    }

    /// Get kernel size
    pub fn kernel_size(&self) -> usize {
        self.kernel_size
    }

    /// Get input channels
    pub fn in_channels(&self) -> usize {
        self.in_channels
    }

    /// Get output channels
    pub fn out_channels(&self) -> usize {
        self.out_channels
    }
}

/// Depthwise Causal Convolution (used in Mamba)
///
/// Each input channel has its own kernel (groups = in_channels).
/// More efficient than standard convolution for SSM preprocessing.
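///
/// A minimal usage sketch (module path assumed, hence `ignore`):
///
/// ```ignore
/// use kizzasi_core::conv::DepthwiseCausalConv1d;
///
/// let mut conv = DepthwiseCausalConv1d::new(4, 3); // 4 channels, kernel 3
/// let y = conv.forward_step(&[1.0, 2.0, 3.0, 4.0]);
/// assert_eq!(y.len(), 4); // depthwise: channel count is preserved
/// ```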
#[derive(Debug, Clone)]
pub struct DepthwiseCausalConv1d {
    /// Kernel weights [channels, kernel_size]
    weights: Vec<Vec<f32>>,
    /// Bias per channel
    bias: Vec<f32>,
    /// Kernel size
    kernel_size: usize,
    /// Number of channels
    channels: usize,
    /// Sliding window of recent input frames [kernel_size - 1, channels] (oldest first)
    history: Vec<Vec<f32>>,
}

impl DepthwiseCausalConv1d {
    /// Create a new depthwise causal convolution
    pub fn new(channels: usize, kernel_size: usize) -> Self {
        let scale = (2.0 / kernel_size as f32).sqrt();
        let weights: Vec<Vec<f32>> = (0..channels)
            .map(|c| {
                (0..kernel_size)
                    .map(|k| ((c + k) as f32 * 0.1).sin() * scale)
                    .collect()
            })
            .collect();

        let bias = vec![0.0; channels];
        let history: Vec<Vec<f32>> = (0..(kernel_size - 1))
            .map(|_| vec![0.0; channels])
            .collect();

        Self {
            weights,
            bias,
            kernel_size,
            channels,
            history,
        }
    }

    /// Set weights
    pub fn set_weights(&mut self, weights: Vec<Vec<f32>>) {
        assert_eq!(weights.len(), self.channels);
        for w in &weights {
            assert_eq!(w.len(), self.kernel_size);
        }
        self.weights = weights;
    }

    /// Set bias
    pub fn set_bias(&mut self, bias: Vec<f32>) {
        assert_eq!(bias.len(), self.channels);
        self.bias = bias;
    }

    /// Forward pass for single time step
    pub fn forward_step(&mut self, input: &[f32]) -> Vec<f32> {
        assert_eq!(input.len(), self.channels);

        self.history.push(input.to_vec());
        while self.history.len() > self.kernel_size {
            self.history.remove(0);
        }

        let mut output = self.bias.clone();

        for (c, kernel) in self.weights.iter().enumerate() {
            for (k, &weight) in kernel.iter().enumerate() {
                if k < self.history.len() {
                    let hist_idx = self.history.len() - 1 - k;
                    output[c] += weight * self.history[hist_idx][c];
                }
            }
        }

        output
    }

    /// Forward for Array1
    pub fn forward(&mut self, input: &Array1<f32>) -> Array1<f32> {
        Array1::from_vec(self.forward_step(input.as_slice().unwrap()))
    }

    /// Forward pass for batch
    pub fn forward_batch(&mut self, input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        input.iter().map(|x| self.forward_step(x)).collect()
    }

    /// Reset history
    pub fn reset(&mut self) {
        for h in &mut self.history {
            h.fill(0.0);
        }
    }

    /// Get the current history buffer state
    ///
    /// Returns the most recent kernel_size - 1 frames, i.e. exactly the
    /// state `set_history` expects when resuming streaming.
    pub fn get_history(&self) -> Vec<Vec<f32>> {
        // After forward_step the buffer holds kernel_size frames with the
        // newest at the end; only the last kernel_size - 1 are state
        let expected_len = self.kernel_size - 1;
        if self.history.len() >= expected_len {
            self.history[self.history.len() - expected_len..].to_vec()
        } else {
            self.history.clone()
        }
    }

    /// Set the history buffer state
    pub fn set_history(&mut self, history: Vec<Vec<f32>>) {
        assert_eq!(
            history.len(),
            self.kernel_size - 1,
            "History length must be kernel_size - 1 = {}",
            self.kernel_size - 1
        );
        for h in &history {
            assert_eq!(
                h.len(),
                self.channels,
                "Each history frame must have channels = {} elements",
                self.channels
            );
        }
        self.history = history;
    }

    /// Get kernel size
    pub fn kernel_size(&self) -> usize {
        self.kernel_size
    }

    /// Get channels
    pub fn channels(&self) -> usize {
        self.channels
    }
}

/// Short convolution for SSM (commonly kernel_size = 4 in Mamba)
///
/// A thin wrapper around `DepthwiseCausalConv1d` with a small default kernel.
#[derive(Debug, Clone)]
pub struct ShortConv {
    /// The underlying depthwise convolution
    conv: DepthwiseCausalConv1d,
}

impl ShortConv {
    /// Create a new short convolution (defaults to kernel_size = 4)
    pub fn new(channels: usize) -> Self {
        Self::with_kernel_size(channels, 4)
    }

    /// Create with custom kernel size
    pub fn with_kernel_size(channels: usize, kernel_size: usize) -> Self {
        Self {
            conv: DepthwiseCausalConv1d::new(channels, kernel_size),
        }
    }

    /// Forward pass
    pub fn forward(&mut self, input: &Array1<f32>) -> Array1<f32> {
        self.conv.forward(input)
    }

    /// Reset state
    pub fn reset(&mut self) {
        self.conv.reset();
    }

    /// Set weights
    pub fn set_weights(&mut self, weights: Vec<Vec<f32>>) {
        self.conv.set_weights(weights);
    }

    /// Get channels
    pub fn channels(&self) -> usize {
        self.conv.channels()
    }
}

/// Dilated Causal Convolution
///
/// Supports dilation for increasing receptive field without
/// increasing kernel size or computation.
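///
/// A minimal sketch of how dilation grows the receptive field (module path
/// assumed, hence `ignore`):
///
/// ```ignore
/// use kizzasi_core::conv::DilatedCausalConv1d;
///
/// // kernel_size 3 with dilation 4 covers (3 - 1) * 4 + 1 = 9 time steps
/// let conv = DilatedCausalConv1d::new(8, 3, 4);
/// assert_eq!(conv.receptive_field(), 9);
/// ```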
#[derive(Debug, Clone)]
pub struct DilatedCausalConv1d {
    /// Kernel weights [channels, kernel_size]
    weights: Vec<Vec<f32>>,
    /// Bias
    bias: Vec<f32>,
    /// Kernel size
    kernel_size: usize,
    /// Dilation factor
    dilation: usize,
    /// Channels
    channels: usize,
    /// History buffer [(kernel_size - 1) * dilation, channels]
    history: Vec<Vec<f32>>,
}

impl DilatedCausalConv1d {
    /// Create a new dilated causal convolution
    pub fn new(channels: usize, kernel_size: usize, dilation: usize) -> Self {
        let scale = (2.0 / kernel_size as f32).sqrt();
        let weights: Vec<Vec<f32>> = (0..channels)
            .map(|c| {
                (0..kernel_size)
                    .map(|k| ((c + k) as f32 * 0.1).sin() * scale)
                    .collect()
            })
            .collect();

        let bias = vec![0.0; channels];

        // The buffer needs (kernel_size - 1) * dilation past frames; the
        // receptive field is that plus the current input
        let effective_size = (kernel_size - 1) * dilation;
        let history: Vec<Vec<f32>> = (0..effective_size).map(|_| vec![0.0; channels]).collect();

        Self {
            weights,
            bias,
            kernel_size,
            dilation,
            channels,
            history,
        }
    }

    /// Forward pass for single time step
    pub fn forward_step(&mut self, input: &[f32]) -> Vec<f32> {
        assert_eq!(input.len(), self.channels);

        self.history.push(input.to_vec());
        let effective_size = (self.kernel_size - 1) * self.dilation;
        while self.history.len() > effective_size + 1 {
            self.history.remove(0);
        }

        let mut output = self.bias.clone();

        for (c, kernel) in self.weights.iter().enumerate() {
            for (k, &weight) in kernel.iter().enumerate() {
                // Dilated index: current is at end, go back by k * dilation
                let offset = k * self.dilation;
                if offset < self.history.len() {
                    let hist_idx = self.history.len() - 1 - offset;
                    output[c] += weight * self.history[hist_idx][c];
                }
            }
        }

        output
    }

    /// Forward for Array1
    pub fn forward(&mut self, input: &Array1<f32>) -> Array1<f32> {
        Array1::from_vec(self.forward_step(input.as_slice().unwrap()))
    }

    /// Reset history
    pub fn reset(&mut self) {
        for h in &mut self.history {
            h.fill(0.0);
        }
    }

    /// Get receptive field
    pub fn receptive_field(&self) -> usize {
        (self.kernel_size - 1) * self.dilation + 1
    }
}

/// Stack of dilated causal convolutions (WaveNet-style)
///
/// Each layer has increasing dilation: 1, 2, 4, 8, ...
#[derive(Debug, Clone)]
pub struct DilatedStack {
    layers: Vec<DilatedCausalConv1d>,
    residual: bool,
}

impl DilatedStack {
    /// Create a new dilated stack with num_layers
    ///
    /// Dilations: 2^0, 2^1, 2^2, ..., 2^(num_layers-1)
    pub fn new(channels: usize, kernel_size: usize, num_layers: usize) -> Self {
        let layers: Vec<_> = (0..num_layers)
            .map(|i| {
                let dilation = 1 << i; // 2^i
                DilatedCausalConv1d::new(channels, kernel_size, dilation)
            })
            .collect();

        Self {
            layers,
            residual: true,
        }
    }

    /// Disable residual connections
    pub fn without_residual(mut self) -> Self {
        self.residual = false;
        self
    }

    /// Forward pass
    pub fn forward(&mut self, input: &Array1<f32>) -> Array1<f32> {
        let mut x = input.clone();
        for layer in &mut self.layers {
            let y = layer.forward(&x);
            if self.residual {
                x = &x + &y;
            } else {
                x = y;
            }
        }
        x
    }

    /// Reset all layers
    pub fn reset(&mut self) {
        for layer in &mut self.layers {
            layer.reset();
        }
    }

    /// Get total receptive field
    pub fn receptive_field(&self) -> usize {
        self.layers
            .iter()
            .map(|l| l.receptive_field() - 1)
            .sum::<usize>()
            + 1
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_causal_conv1d() {
        let mut conv = CausalConv1d::new(2, 3, 3);

        // First step - only current input contributes
        let out1 = conv.forward_step(&[1.0, 0.0]);
        assert_eq!(out1.len(), 3);

        // Second step - current + previous
        let out2 = conv.forward_step(&[0.0, 1.0]);
        assert_eq!(out2.len(), 3);

        // Third step - full kernel used
        let out3 = conv.forward_step(&[0.5, 0.5]);
        assert_eq!(out3.len(), 3);
    }
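
    // Added check (a sketch using only this file's API): with an all-ones
    // 1-in/1-out kernel of size 3 and zero bias, the layer reduces to a
    // running sum over the last three inputs, which makes the causal
    // indexing (weight k taps the input k steps back) directly visible.
    #[test]
    fn test_running_sum_kernel() {
        let mut conv = CausalConv1d::new(1, 1, 3);
        conv.set_weights(vec![vec![vec![1.0, 1.0, 1.0]]]);
        conv.set_bias(vec![0.0]);

        assert_eq!(conv.forward_step(&[1.0]), vec![1.0]); // 1
        assert_eq!(conv.forward_step(&[2.0]), vec![3.0]); // 1 + 2
        assert_eq!(conv.forward_step(&[3.0]), vec![6.0]); // 1 + 2 + 3
        assert_eq!(conv.forward_step(&[4.0]), vec![9.0]); // 2 + 3 + 4, oldest dropped
    }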

    #[test]
    fn test_depthwise_causal() {
        let mut conv = DepthwiseCausalConv1d::new(4, 3);

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let out = conv.forward(&input);
        assert_eq!(out.len(), 4);

        // After reset, should behave as if fresh
        conv.reset();
        let out2 = conv.forward(&input);
        assert_eq!(out, out2);
    }

    #[test]
    fn test_short_conv() {
        let mut conv = ShortConv::new(8);
        assert_eq!(conv.channels(), 8);

        let input = Array1::ones(8);
        let out = conv.forward(&input);
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn test_dilated_conv() {
        let mut conv = DilatedCausalConv1d::new(4, 3, 2);
        assert_eq!(conv.receptive_field(), 5); // (3-1)*2 + 1

        let input = Array1::ones(4);
        let out = conv.forward(&input);
        assert_eq!(out.len(), 4);
    }

    #[test]
    fn test_dilated_stack() {
        let mut stack = DilatedStack::new(4, 2, 4);
        // Layers have dilations 1, 2, 4, 8 with kernel_size = 2, so the
        // receptive field is sum((k-1)*d) + 1 = 1 + 2 + 4 + 8 + 1 = 16
        assert_eq!(stack.receptive_field(), 16);

        let input = Array1::ones(4);
        let out = stack.forward(&input);
        assert_eq!(out.len(), 4);
    }

    #[test]
    fn test_causality() {
        // Verify that output only depends on current and past inputs
        let mut conv1 = DepthwiseCausalConv1d::new(2, 3);
        let mut conv2 = DepthwiseCausalConv1d::new(2, 3);

        // Same weights
        conv2.set_weights(conv1.weights.clone());
        conv2.set_bias(conv1.bias.clone());

        // Feed the same first two inputs
        let in1 = vec![1.0, 0.0];
        let in2 = vec![0.0, 1.0];

        let _ = conv1.forward_step(&in1);
        let out1 = conv1.forward_step(&in2);

        let _ = conv2.forward_step(&in1);
        let out2 = conv2.forward_step(&in2);

        // Identical input prefixes must yield identical outputs (causality)
        assert_eq!(out1, out2);

        // Feeding different third inputs cannot retroactively change the
        // outputs already produced above, which is the causal guarantee
        let _ = conv1.forward_step(&[1.0, 1.0]);
        let _ = conv2.forward_step(&[0.5, 0.5]);
    }
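
    // Added check (a sketch using only this file's API): saving streaming
    // history after some steps and restoring it into a second layer with the
    // same deterministic weights should reproduce the next output exactly;
    // this is the intended use of get_history / set_history.
    #[test]
    fn test_history_roundtrip() {
        let mut conv = DepthwiseCausalConv1d::new(2, 3);
        let _ = conv.forward_step(&[1.0, 2.0]);
        let _ = conv.forward_step(&[3.0, 4.0]);

        // new(2, 3) is deterministic, so `fresh` starts with the same weights
        let mut fresh = DepthwiseCausalConv1d::new(2, 3);
        fresh.set_history(conv.get_history());

        let a = conv.forward_step(&[5.0, 6.0]);
        let b = fresh.forward_step(&[5.0, 6.0]);
        assert_eq!(a, b);
    }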
}