//! Streaming causal 1-D convolutions: a full causal convolution, a depthwise
//! variant, a short-kernel wrapper, and dilated layers/stacks.

use scirs2_core::ndarray::Array1;

/// Causal 1-D convolution evaluated one frame at a time (streaming).
#[derive(Debug, Clone)]
pub struct CausalConv1d {
    /// Filter taps indexed as `[out_channel][in_channel][tap]`.
    weights: Vec<Vec<Vec<f32>>>,
    /// Per-output-channel bias.
    bias: Vec<f32>,
    kernel_size: usize,
    in_channels: usize,
    out_channels: usize,
    /// Rolling buffer of the most recent input frames (streaming state).
    history: Vec<Vec<f32>>,
}

impl CausalConv1d {
    /// Creates a layer with deterministic sinusoidal weights scaled by
    /// `sqrt(2 / (in_channels * kernel_size))` and zero bias.
    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
        let scale = (2.0 / (in_channels * kernel_size) as f32).sqrt();
        let mut weights = Vec::with_capacity(out_channels);

        for _ in 0..out_channels {
            let mut out_ch = Vec::with_capacity(in_channels);
            for _ in 0..in_channels {
                let kernel: Vec<f32> = (0..kernel_size)
                    .map(|i| (i as f32 * 0.1).sin() * scale)
                    .collect();
                out_ch.push(kernel);
            }
            weights.push(out_ch);
        }

        let bias = vec![0.0; out_channels];

        // Zero-filled history covering the kernel_size - 1 previous frames.
        let history: Vec<Vec<f32>> = (0..(kernel_size - 1))
            .map(|_| vec![0.0; in_channels])
            .collect();

        Self {
            weights,
            bias,
            kernel_size,
            in_channels,
            out_channels,
            history,
        }
    }

    /// Replaces all filter weights; panics if the nested shape does not match
    /// `[out_channels][in_channels][kernel_size]`.
    pub fn set_weights(&mut self, weights: Vec<Vec<Vec<f32>>>) {
        assert_eq!(weights.len(), self.out_channels);
        for oc in &weights {
            assert_eq!(oc.len(), self.in_channels);
            for ic in oc {
                assert_eq!(ic.len(), self.kernel_size);
            }
        }
        self.weights = weights;
    }

    /// Replaces the bias vector; panics if its length is not `out_channels`.
    pub fn set_bias(&mut self, bias: Vec<f32>) {
        assert_eq!(bias.len(), self.out_channels);
        self.bias = bias;
    }

    /// Processes a single input frame and returns the corresponding output frame.
    pub fn forward_step(&mut self, input: &[f32]) -> Vec<f32> {
        assert_eq!(input.len(), self.in_channels);

        self.history.push(input.to_vec());

        // Keep only the kernel_size most recent frames (current + past).
        while self.history.len() > self.kernel_size {
            self.history.remove(0);
        }

        let mut output = self.bias.clone();

        for (oc, out_weights) in self.weights.iter().enumerate() {
            for (ic, in_weights) in out_weights.iter().enumerate() {
                for (k, &weight) in in_weights.iter().enumerate() {
                    if k < self.history.len() {
                        // Tap k reads the frame k steps in the past (k = 0 is the current frame).
                        let hist_idx = self.history.len() - 1 - k;
                        output[oc] += weight * self.history[hist_idx][ic];
                    }
                }
            }
        }

        output
    }

    /// Processes a sequence of frames, updating the streaming state as it goes.
    pub fn forward_batch(&mut self, input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        input.iter().map(|x| self.forward_step(x)).collect()
    }

    pub fn reset(&mut self) {
        for h in &mut self.history {
            h.fill(0.0);
        }
    }

    /// Returns the streaming state: the `kernel_size - 1` most recent input frames.
    pub fn get_history(&self) -> Vec<Vec<f32>> {
        let expected_len = self.kernel_size - 1;
        if self.history.len() >= expected_len {
            // Take the most recent frames; older ones no longer affect future outputs.
            self.history[self.history.len() - expected_len..].to_vec()
        } else {
            self.history.clone()
        }
    }

    pub fn set_history(&mut self, history: Vec<Vec<f32>>) {
        assert_eq!(
            history.len(),
            self.kernel_size - 1,
            "History length must be kernel_size - 1 = {}",
            self.kernel_size - 1
        );
        for h in &history {
            assert_eq!(
                h.len(),
                self.in_channels,
                "Each history frame must have in_channels = {} elements",
                self.in_channels
            );
        }
        self.history = history;
    }

    pub fn kernel_size(&self) -> usize {
        self.kernel_size
    }

    pub fn in_channels(&self) -> usize {
        self.in_channels
    }

    pub fn out_channels(&self) -> usize {
        self.out_channels
    }
}

/// Depthwise causal 1-D convolution: each channel is convolved with its own
/// kernel, independently of the other channels.
#[derive(Debug, Clone)]
pub struct DepthwiseCausalConv1d {
    /// One kernel per channel, indexed as `[channel][tap]`.
    weights: Vec<Vec<f32>>,
    bias: Vec<f32>,
    kernel_size: usize,
    channels: usize,
    /// Rolling buffer of the most recent input frames (streaming state).
    history: Vec<Vec<f32>>,
}

impl DepthwiseCausalConv1d {
    pub fn new(channels: usize, kernel_size: usize) -> Self {
        let scale = (2.0 / kernel_size as f32).sqrt();
        let weights: Vec<Vec<f32>> = (0..channels)
            .map(|c| {
                (0..kernel_size)
                    .map(|k| ((c + k) as f32 * 0.1).sin() * scale)
                    .collect()
            })
            .collect();

        let bias = vec![0.0; channels];
        let history: Vec<Vec<f32>> = (0..(kernel_size - 1))
            .map(|_| vec![0.0; channels])
            .collect();

        Self {
            weights,
            bias,
            kernel_size,
            channels,
            history,
        }
    }

    pub fn set_weights(&mut self, weights: Vec<Vec<f32>>) {
        assert_eq!(weights.len(), self.channels);
        for w in &weights {
            assert_eq!(w.len(), self.kernel_size);
        }
        self.weights = weights;
    }

    pub fn set_bias(&mut self, bias: Vec<f32>) {
        assert_eq!(bias.len(), self.channels);
        self.bias = bias;
    }

    /// Processes a single input frame and returns one output frame.
    pub fn forward_step(&mut self, input: &[f32]) -> Vec<f32> {
        assert_eq!(input.len(), self.channels);

        self.history.push(input.to_vec());
        while self.history.len() > self.kernel_size {
            self.history.remove(0);
        }

        let mut output = self.bias.clone();

        for (c, kernel) in self.weights.iter().enumerate() {
            for (k, &weight) in kernel.iter().enumerate() {
                if k < self.history.len() {
                    // Tap k reads the frame k steps in the past for this channel only.
                    let hist_idx = self.history.len() - 1 - k;
                    output[c] += weight * self.history[hist_idx][c];
                }
            }
        }

        output
    }

    pub fn forward(&mut self, input: &Array1<f32>) -> Array1<f32> {
        Array1::from_vec(self.forward_step(input.as_slice().unwrap()))
    }

    pub fn forward_batch(&mut self, input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        input.iter().map(|x| self.forward_step(x)).collect()
    }

    pub fn reset(&mut self) {
        for h in &mut self.history {
            h.fill(0.0);
        }
    }

    /// Returns the streaming state: the `kernel_size - 1` most recent input frames.
    pub fn get_history(&self) -> Vec<Vec<f32>> {
        let expected_len = self.kernel_size - 1;
        if self.history.len() >= expected_len {
            // Take the most recent frames; older ones no longer affect future outputs.
            self.history[self.history.len() - expected_len..].to_vec()
        } else {
            self.history.clone()
        }
    }

    pub fn set_history(&mut self, history: Vec<Vec<f32>>) {
        assert_eq!(
            history.len(),
            self.kernel_size - 1,
            "History length must be kernel_size - 1 = {}",
            self.kernel_size - 1
        );
        for h in &history {
            assert_eq!(
                h.len(),
                self.channels,
                "Each history frame must have channels = {} elements",
                self.channels
            );
        }
        self.history = history;
    }

    pub fn kernel_size(&self) -> usize {
        self.kernel_size
    }

    pub fn channels(&self) -> usize {
        self.channels
    }
}

/// Convenience wrapper: a depthwise causal convolution with a short default
/// kernel of 4 taps.
#[derive(Debug, Clone)]
pub struct ShortConv {
    conv: DepthwiseCausalConv1d,
}

impl ShortConv {
    pub fn new(channels: usize) -> Self {
        Self::with_kernel_size(channels, 4)
    }

    pub fn with_kernel_size(channels: usize, kernel_size: usize) -> Self {
        Self {
            conv: DepthwiseCausalConv1d::new(channels, kernel_size),
        }
    }

    pub fn forward(&mut self, input: &Array1<f32>) -> Array1<f32> {
        self.conv.forward(input)
    }

    pub fn reset(&mut self) {
        self.conv.reset();
    }

    pub fn set_weights(&mut self, weights: Vec<Vec<f32>>) {
        self.conv.set_weights(weights);
    }

    pub fn channels(&self) -> usize {
        self.conv.channels()
    }
}

/// Depthwise causal convolution whose taps are spaced `dilation` steps apart,
/// enlarging the receptive field without adding parameters.
#[derive(Debug, Clone)]
pub struct DilatedCausalConv1d {
    weights: Vec<Vec<f32>>,
    bias: Vec<f32>,
    kernel_size: usize,
    dilation: usize,
    channels: usize,
    history: Vec<Vec<f32>>,
}

impl DilatedCausalConv1d {
    pub fn new(channels: usize, kernel_size: usize, dilation: usize) -> Self {
        let scale = (2.0 / kernel_size as f32).sqrt();
        let weights: Vec<Vec<f32>> = (0..channels)
            .map(|c| {
                (0..kernel_size)
                    .map(|k| ((c + k) as f32 * 0.1).sin() * scale)
                    .collect()
            })
            .collect();

        let bias = vec![0.0; channels];

        // The dilated kernel reaches (kernel_size - 1) * dilation steps into the past.
        let effective_size = (kernel_size - 1) * dilation;
        let history: Vec<Vec<f32>> = (0..effective_size).map(|_| vec![0.0; channels]).collect();

        Self {
            weights,
            bias,
            kernel_size,
            dilation,
            channels,
            history,
        }
    }

    pub fn forward_step(&mut self, input: &[f32]) -> Vec<f32> {
        assert_eq!(input.len(), self.channels);

        self.history.push(input.to_vec());
        let effective_size = (self.kernel_size - 1) * self.dilation;
        while self.history.len() > effective_size + 1 {
            self.history.remove(0);
        }

        let mut output = self.bias.clone();

        for (c, kernel) in self.weights.iter().enumerate() {
            for (k, &weight) in kernel.iter().enumerate() {
                // Tap k reads the frame k * dilation steps in the past.
                let offset = k * self.dilation;
                if offset < self.history.len() {
                    let hist_idx = self.history.len() - 1 - offset;
                    output[c] += weight * self.history[hist_idx][c];
                }
            }
        }

        output
    }

    pub fn forward(&mut self, input: &Array1<f32>) -> Array1<f32> {
        Array1::from_vec(self.forward_step(input.as_slice().unwrap()))
    }

    pub fn reset(&mut self) {
        for h in &mut self.history {
            h.fill(0.0);
        }
    }

    /// Number of input steps that can influence the current output.
    pub fn receptive_field(&self) -> usize {
        (self.kernel_size - 1) * self.dilation + 1
    }
}

/// Stack of dilated causal convolutions with dilation doubling per layer
/// (1, 2, 4, ...), optionally connected by residual additions.
#[derive(Debug, Clone)]
pub struct DilatedStack {
    layers: Vec<DilatedCausalConv1d>,
    residual: bool,
}

impl DilatedStack {
    pub fn new(channels: usize, kernel_size: usize, num_layers: usize) -> Self {
        let layers: Vec<_> = (0..num_layers)
            .map(|i| {
                // Dilation doubles with depth: 1, 2, 4, ...
                let dilation = 1 << i;
                DilatedCausalConv1d::new(channels, kernel_size, dilation)
            })
            .collect();

        Self {
            layers,
            residual: true,
        }
    }

    pub fn without_residual(mut self) -> Self {
        self.residual = false;
        self
    }

    pub fn forward(&mut self, input: &Array1<f32>) -> Array1<f32> {
        let mut x = input.clone();
        for layer in &mut self.layers {
            let y = layer.forward(&x);
            if self.residual {
                // Residual connection: add the layer input back onto its output.
                x = &x + &y;
            } else {
                x = y;
            }
        }
        x
    }

    pub fn reset(&mut self) {
        for layer in &mut self.layers {
            layer.reset();
        }
    }

    /// Total receptive field of the stack: each layer contributes its own
    /// receptive field minus one (the layers share the current step).
    pub fn receptive_field(&self) -> usize {
        self.layers
            .iter()
            .map(|l| l.receptive_field() - 1)
            .sum::<usize>()
            + 1
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_causal_conv1d() {
        let mut conv = CausalConv1d::new(2, 3, 3);

        let out1 = conv.forward_step(&[1.0, 0.0]);
        assert_eq!(out1.len(), 3);

        let out2 = conv.forward_step(&[0.0, 1.0]);
        assert_eq!(out2.len(), 3);

        let out3 = conv.forward_step(&[0.5, 0.5]);
        assert_eq!(out3.len(), 3);
    }
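
    // Sanity check: with an all-zero history, the first output frame depends
    // only on the current input and tap k = 0, i.e.
    // out[oc] = bias[oc] + sum_ic w[oc][ic][0] * x[ic].
    #[test]
    fn test_first_step_uses_only_current_input() {
        let mut conv = CausalConv1d::new(2, 1, 3);
        conv.set_weights(vec![vec![vec![0.5, 10.0, 10.0], vec![0.25, 10.0, 10.0]]]);
        conv.set_bias(vec![1.0]);

        // The history is zero, so the large taps at k = 1, 2 contribute nothing.
        let out = conv.forward_step(&[2.0, 4.0]);
        assert_eq!(out, vec![1.0 + 0.5 * 2.0 + 0.25 * 4.0]);
    }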

    #[test]
    fn test_depthwise_causal() {
        let mut conv = DepthwiseCausalConv1d::new(4, 3);

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let out = conv.forward(&input);
        assert_eq!(out.len(), 4);

        conv.reset();
        let out2 = conv.forward(&input);
        assert_eq!(out, out2);
    }
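
    // Streaming-state round trip: saving the history after a few steps and
    // restoring it into a fresh layer should reproduce the next output exactly
    // (both layers use the same deterministic initial weights).
    #[test]
    fn test_history_round_trip() {
        let mut conv = DepthwiseCausalConv1d::new(2, 3);
        let _ = conv.forward_step(&[1.0, -1.0]);
        let _ = conv.forward_step(&[0.5, 2.0]);

        // Snapshot the state, then continue on the original layer.
        let saved = conv.get_history();
        let expected = conv.forward_step(&[0.25, -0.75]);

        // A fresh layer with the restored state must produce the same frame.
        let mut restored = DepthwiseCausalConv1d::new(2, 3);
        restored.set_history(saved);
        assert_eq!(restored.forward_step(&[0.25, -0.75]), expected);
    }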

    #[test]
    fn test_short_conv() {
        let mut conv = ShortConv::new(8);
        assert_eq!(conv.channels(), 8);

        let input = Array1::ones(8);
        let out = conv.forward(&input);
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn test_dilated_conv() {
        let mut conv = DilatedCausalConv1d::new(4, 3, 2);
        assert_eq!(conv.receptive_field(), 5);

        let input = Array1::ones(4);
        let out = conv.forward(&input);
        assert_eq!(out.len(), 4);
    }
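
    // With dilation d, tap k of a kernel sees the input from d * k steps ago,
    // so a unit impulse re-appears scaled by weights[c][k] at step d * k.
    #[test]
    fn test_dilated_impulse_response() {
        let mut conv = DilatedCausalConv1d::new(2, 2, 3);
        let taps = conv.weights[1].clone();

        let mut outputs = vec![conv.forward_step(&[0.0, 1.0])];
        for _ in 0..3 {
            outputs.push(conv.forward_step(&[0.0, 0.0]));
        }

        assert_eq!(outputs[0][1], taps[0]);
        assert_eq!(outputs[1][1], 0.0);
        assert_eq!(outputs[2][1], 0.0);
        assert_eq!(outputs[3][1], taps[1]);
    }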

    #[test]
    fn test_dilated_stack() {
        let mut stack = DilatedStack::new(4, 2, 4);
        let input = Array1::ones(4);
        let out = stack.forward(&input);
        assert_eq!(out.len(), 4);
    }
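
    // Worked receptive-field example: kernel_size = 2 with dilations 1, 2, 4, 8
    // covers 1 + 2 + 4 + 8 past steps plus the current one, i.e. 16 samples.
    #[test]
    fn test_dilated_stack_receptive_field() {
        let stack = DilatedStack::new(4, 2, 4);
        assert_eq!(stack.receptive_field(), 16);
    }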

    #[test]
    fn test_causality() {
        let mut conv1 = DepthwiseCausalConv1d::new(2, 3);
        let mut conv2 = DepthwiseCausalConv1d::new(2, 3);

        conv2.set_weights(conv1.weights.clone());
        conv2.set_bias(conv1.bias.clone());

        let in1 = vec![1.0, 0.0];
        let in2 = vec![0.0, 1.0];

        // The same input prefix produces the same outputs.
        let _ = conv1.forward_step(&in1);
        let out1 = conv1.forward_step(&in2);

        let _ = conv2.forward_step(&in1);
        let out2 = conv2.forward_step(&in2);

        assert_eq!(out1, out2);

        // Once the inputs diverge, so do the outputs, while the outputs already
        // emitted above are unaffected.
        let out3a = conv1.forward_step(&[1.0, 1.0]);
        let out3b = conv2.forward_step(&[0.5, 0.5]);
        assert_ne!(out3a, out3b);
    }
}