use crate::error::{Result, TextError};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::{self, Rng};
// `Statistics` supplies the iterator-based `.mean()` used on array views in the
// normalization helpers below.
use statrs::statistics::Statistics;

/// Activation functions used by the neural text models.
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
    /// Hyperbolic tangent
    Tanh,
    /// Logistic sigmoid
    Sigmoid,
    /// Rectified linear unit
    ReLU,
    /// Gaussian error linear unit (tanh approximation)
    GELU,
    /// Swish / SiLU: `x * sigmoid(x)`
    Swish,
    /// Identity
    Linear,
}

impl ActivationFunction {
    /// Apply the activation to a single value.
    pub fn apply(&self, x: f64) -> f64 {
        match self {
            ActivationFunction::Tanh => x.tanh(),
            ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            ActivationFunction::ReLU => x.max(0.0),
            ActivationFunction::GELU => {
                // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
                0.5 * x * (1.0 + (x * 0.7978845608 * (1.0 + 0.044715 * x * x)).tanh())
            }
            ActivationFunction::Swish => x / (1.0 + (-x).exp()),
            ActivationFunction::Linear => x,
        }
    }

    /// Apply the activation element-wise to an array.
    pub fn apply_array(&self, x: &Array1<f64>) -> Array1<f64> {
        x.mapv(|val| self.apply(val))
    }

    /// Derivative of the activation at `x` (approximate for GELU).
    pub fn derivative(&self, x: f64) -> f64 {
        match self {
            ActivationFunction::Tanh => {
                let tanh_x = x.tanh();
                1.0 - tanh_x * tanh_x
            }
            ActivationFunction::Sigmoid => {
                let sig_x = self.apply(x);
                sig_x * (1.0 - sig_x)
            }
            ActivationFunction::ReLU => {
                if x > 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            ActivationFunction::GELU => {
                // d/dx [x * Phi(x)] = Phi(x) + x * phi(x), with Phi approximated by a tanh
                // and phi the standard normal density (1/sqrt(2*pi) = 0.3989...).
                let cdf = 0.5 * (1.0 + (x * 0.7978845608).tanh());
                let pdf = 0.3989422804 * (-0.5 * x * x).exp();
                cdf + x * pdf
            }
            ActivationFunction::Swish => {
                let sig_x = 1.0 / (1.0 + (-x).exp());
                sig_x + x * sig_x * (1.0 - sig_x)
            }
            ActivationFunction::Linear => 1.0,
        }
    }
}

/// A single long short-term memory (LSTM) cell.
#[derive(Debug, Clone)]
pub struct LSTMCell {
    // Input-to-hidden weights for the input, forget, output, and candidate gates
    w_i: Array2<f64>,
    w_f: Array2<f64>,
    w_o: Array2<f64>,
    w_c: Array2<f64>,
    // Hidden-to-hidden (recurrent) weights
    u_i: Array2<f64>,
    u_f: Array2<f64>,
    u_o: Array2<f64>,
    u_c: Array2<f64>,
    // Gate biases
    b_i: Array1<f64>,
    b_f: Array1<f64>,
    b_o: Array1<f64>,
    b_c: Array1<f64>,
    input_size: usize,
    hidden_size: usize,
}

impl LSTMCell {
    /// Create a cell with uniform Xavier-style initialisation; the forget-gate bias
    /// starts at one so the cell retains information early in training.
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        let scale = (2.0 / (input_size + hidden_size) as f64).sqrt();

        let w_i = Array2::from_shape_fn((hidden_size, input_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_f = Array2::from_shape_fn((hidden_size, input_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_o = Array2::from_shape_fn((hidden_size, input_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_c = Array2::from_shape_fn((hidden_size, input_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let u_i = Array2::from_shape_fn((hidden_size, hidden_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let u_f = Array2::from_shape_fn((hidden_size, hidden_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let u_o = Array2::from_shape_fn((hidden_size, hidden_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let u_c = Array2::from_shape_fn((hidden_size, hidden_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let b_i = Array1::zeros(hidden_size);
        let b_f = Array1::ones(hidden_size); // forget-gate bias initialised to 1
        let b_o = Array1::zeros(hidden_size);
        let b_c = Array1::zeros(hidden_size);

        Self {
            w_i,
            w_f,
            w_o,
            w_c,
            u_i,
            u_f,
            u_o,
            u_c,
            b_i,
            b_f,
            b_o,
            b_c,
            input_size,
            hidden_size,
        }
    }

    /// One step of the cell; returns the new hidden and cell states.
    pub fn forward(
        &self,
        x: ArrayView1<f64>,
        h_prev: ArrayView1<f64>,
        c_prev: ArrayView1<f64>,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        if x.len() != self.input_size {
            return Err(TextError::InvalidInput(format!(
                "Expected input size {}, got {}",
                self.input_size,
                x.len()
            )));
        }

        if h_prev.len() != self.hidden_size || c_prev.len() != self.hidden_size {
            return Err(TextError::InvalidInput(format!(
                "Expected hidden size {}, got h: {}, c: {}",
                self.hidden_size,
                h_prev.len(),
                c_prev.len()
            )));
        }

        // Input gate
        let i_t = ActivationFunction::Sigmoid
            .apply_array(&(self.w_i.dot(&x) + self.u_i.dot(&h_prev) + &self.b_i));

        // Forget gate
        let f_t = ActivationFunction::Sigmoid
            .apply_array(&(self.w_f.dot(&x) + self.u_f.dot(&h_prev) + &self.b_f));

        // Output gate
        let o_t = ActivationFunction::Sigmoid
            .apply_array(&(self.w_o.dot(&x) + self.u_o.dot(&h_prev) + &self.b_o));

        // Candidate cell state
        let c_tilde = ActivationFunction::Tanh
            .apply_array(&(self.w_c.dot(&x) + self.u_c.dot(&h_prev) + &self.b_c));

        // New cell and hidden states
        let c_t = &f_t * &c_prev + &i_t * &c_tilde;
        let h_t = &o_t * &ActivationFunction::Tanh.apply_array(&c_t);

        Ok((h_t, c_t))
    }
}

/// A single gated recurrent unit (GRU) cell.
#[derive(Debug, Clone)]
pub struct GRUCell {
    // Input-to-hidden weights for the update, reset, and candidate gates
    w_z: Array2<f64>,
    w_r: Array2<f64>,
    w_h: Array2<f64>,
    // Hidden-to-hidden (recurrent) weights
    u_z: Array2<f64>,
    u_r: Array2<f64>,
    u_h: Array2<f64>,
    // Gate biases
    b_z: Array1<f64>,
    b_r: Array1<f64>,
    b_h: Array1<f64>,
    input_size: usize,
    hidden_size: usize,
}

impl GRUCell {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        let scale = (2.0 / (input_size + hidden_size) as f64).sqrt();

        let w_z = Array2::from_shape_fn((hidden_size, input_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_r = Array2::from_shape_fn((hidden_size, input_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_h = Array2::from_shape_fn((hidden_size, input_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let u_z = Array2::from_shape_fn((hidden_size, hidden_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let u_r = Array2::from_shape_fn((hidden_size, hidden_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let u_h = Array2::from_shape_fn((hidden_size, hidden_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let b_z = Array1::zeros(hidden_size);
        let b_r = Array1::zeros(hidden_size);
        let b_h = Array1::zeros(hidden_size);

        Self {
            w_z,
            w_r,
            w_h,
            u_z,
            u_r,
            u_h,
            b_z,
            b_r,
            b_h,
            input_size,
            hidden_size,
        }
    }

    /// One step of the cell; returns the new hidden state.
    pub fn forward(&self, x: ArrayView1<f64>, h_prev: ArrayView1<f64>) -> Result<Array1<f64>> {
        if x.len() != self.input_size {
            return Err(TextError::InvalidInput(format!(
                "Expected input size {}, got {}",
                self.input_size,
                x.len()
            )));
        }

        if h_prev.len() != self.hidden_size {
            return Err(TextError::InvalidInput(format!(
                "Expected hidden size {}, got {}",
                self.hidden_size,
                h_prev.len()
            )));
        }

        // Update gate
        let z_t = ActivationFunction::Sigmoid
            .apply_array(&(self.w_z.dot(&x) + self.u_z.dot(&h_prev) + &self.b_z));

        // Reset gate
        let r_t = ActivationFunction::Sigmoid
            .apply_array(&(self.w_r.dot(&x) + self.u_r.dot(&h_prev) + &self.b_r));

        // Candidate hidden state
        let h_tilde = ActivationFunction::Tanh
            .apply_array(&(self.w_h.dot(&x) + self.u_h.dot(&(&r_t * &h_prev)) + &self.b_h));

        // Interpolate between the previous state and the candidate state
        let h_t = &(&Array1::ones(self.hidden_size) - &z_t) * &h_prev + &z_t * &h_tilde;

        Ok(h_t)
    }
}

/// Multi-layer bidirectional LSTM.
pub struct BiLSTM {
    forward_cells: Vec<LSTMCell>,
    backward_cells: Vec<LSTMCell>,
    num_layers: usize,
    hidden_size: usize,
}

impl BiLSTM {
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut forward_cells = Vec::new();
        let mut backward_cells = Vec::new();

        for i in 0..num_layers {
            // Layers after the first consume the concatenated forward/backward states.
            let layer_input_size = if i == 0 { input_size } else { hidden_size * 2 };
            forward_cells.push(LSTMCell::new(layer_input_size, hidden_size));
            backward_cells.push(LSTMCell::new(layer_input_size, hidden_size));
        }

        Self {
            forward_cells,
            backward_cells,
            num_layers,
            hidden_size,
        }
    }

    /// Run the stack over `sequence` of shape `(seq_len, input_size)` and return the
    /// concatenated forward/backward hidden states, shape `(seq_len, 2 * hidden_size)`.
    pub fn forward(&self, sequence: ArrayView2<f64>) -> Result<Array2<f64>> {
        let (seq_len, _input_size) = sequence.dim();
        let output_size = self.hidden_size * 2;
        let mut current_input = sequence.to_owned();

        for layer in 0..self.num_layers {
            let mut forward_outputs = Vec::new();
            let mut backward_outputs = Vec::new();

            // Forward pass over the sequence
            let mut h_forward = Array1::zeros(self.hidden_size);
            let mut c_forward = Array1::zeros(self.hidden_size);

            for t in 0..seq_len {
                let (h_new, c_new) = self.forward_cells[layer].forward(
                    current_input.row(t),
                    h_forward.view(),
                    c_forward.view(),
                )?;
                h_forward = h_new;
                c_forward = c_new;
                forward_outputs.push(h_forward.clone());
            }

            // Backward pass over the reversed sequence
            let mut h_backward = Array1::zeros(self.hidden_size);
            let mut c_backward = Array1::zeros(self.hidden_size);

            for t in (0..seq_len).rev() {
                let (h_new, c_new) = self.backward_cells[layer].forward(
                    current_input.row(t),
                    h_backward.view(),
                    c_backward.view(),
                )?;
                h_backward = h_new;
                c_backward = c_new;
                backward_outputs.push(h_backward.clone());
            }

            backward_outputs.reverse();

            // Concatenate forward and backward states per time step
            let mut layer_output = Array2::zeros((seq_len, output_size));
            for t in 0..seq_len {
                let mut concat_output = Array1::zeros(output_size);
                concat_output
                    .slice_mut(s![..self.hidden_size])
                    .assign(&forward_outputs[t]);
                concat_output
                    .slice_mut(s![self.hidden_size..])
                    .assign(&backward_outputs[t]);
                layer_output.row_mut(t).assign(&concat_output);
            }

            current_input = layer_output;
        }

        Ok(current_input)
    }
}

/// 1-D convolution over an input of shape `(seq_len, channels)` with "valid" padding.
#[derive(Debug, Clone)]
pub struct Conv1D {
    filters: Array3<f64>,
    bias: Array1<f64>,
    num_filters: usize,
    kernel_size: usize,
    input_channels: usize,
    activation: ActivationFunction,
}

impl Conv1D {
    pub fn new(
        input_channels: usize,
        num_filters: usize,
        kernel_size: usize,
        activation: ActivationFunction,
    ) -> Self {
        let scale = (2.0 / (input_channels * kernel_size) as f64).sqrt();

        let filters = Array3::from_shape_fn((num_filters, input_channels, kernel_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let bias = Array1::zeros(num_filters);

        Self {
            filters,
            bias,
            num_filters,
            kernel_size,
            input_channels,
            activation,
        }
    }

    /// Returns an output of shape `(seq_len - kernel_size + 1, num_filters)`
    /// (empty when the input is shorter than the kernel).
    pub fn forward(&self, input: ArrayView2<f64>) -> Result<Array2<f64>> {
        let (seq_len, input_dim) = input.dim();

        if input_dim != self.input_channels {
            return Err(TextError::InvalidInput(format!(
                "Expected {} input channels, got {}",
                self.input_channels, input_dim
            )));
        }

        let output_len = seq_len.saturating_sub(self.kernel_size - 1);
        let mut output = Array2::zeros((output_len, self.num_filters));

        for filter_idx in 0..self.num_filters {
            for pos in 0..output_len {
                let mut conv_sum = 0.0;

                for ch in 0..self.input_channels {
                    for k in 0..self.kernel_size {
                        if pos + k < seq_len {
                            conv_sum += input[[pos + k, ch]] * self.filters[[filter_idx, ch, k]];
                        }
                    }
                }

                conv_sum += self.bias[filter_idx];
                output[[pos, filter_idx]] = self.activation.apply(conv_sum);
            }
        }

        Ok(output)
    }
}

/// 1-D max pooling over the time axis.
#[derive(Debug)]
pub struct MaxPool1D {
    pool_size: usize,
    stride: usize,
}

impl MaxPool1D {
    pub fn new(pool_size: usize, stride: usize) -> Self {
        Self { pool_size, stride }
    }

    /// Returns an output of shape `((seq_len - pool_size) / stride + 1, channels)`.
    pub fn forward(&self, input: ArrayView2<f64>) -> Array2<f64> {
        let (seq_len, channels) = input.dim();
        let output_len = (seq_len - self.pool_size) / self.stride + 1;

        let mut output = Array2::zeros((output_len, channels));

        for ch in 0..channels {
            for i in 0..output_len {
                let start = i * self.stride;
                let end = (start + self.pool_size).min(seq_len);

                let mut max_val = f64::NEG_INFINITY;
                for j in start..end {
                    max_val = max_val.max(input[[j, ch]]);
                }

                output[[i, ch]] = max_val;
            }
        }

        output
    }
}

/// Residual block of two 1-D convolutions with per-channel batch normalisation
/// and a (projected, centre-cropped) skip connection.
#[derive(Debug, Clone)]
pub struct ResidualBlock1D {
    conv1: Conv1D,
    conv2: Conv1D,
    skip_projection: Option<Array2<f64>>,
    bn1_scale: Array1<f64>,
    bn1_shift: Array1<f64>,
    bn2_scale: Array1<f64>,
    bn2_shift: Array1<f64>,
}

impl ResidualBlock1D {
    pub fn new(input_channels: usize, output_channels: usize, kernel_size: usize) -> Self {
        let conv1 = Conv1D::new(
            input_channels,
            output_channels,
            kernel_size,
            ActivationFunction::Linear,
        );
        let conv2 = Conv1D::new(
            output_channels,
            output_channels,
            kernel_size,
            ActivationFunction::Linear,
        );

        // Project the skip connection only when the channel counts differ.
        let skip_projection = if input_channels != output_channels {
            let scale = (2.0 / input_channels as f64).sqrt();
            Some(Array2::from_shape_fn(
                (output_channels, input_channels),
                |_| scirs2_core::random::rng().random_range(-scale..scale),
            ))
        } else {
            None
        };

        let bn1_scale = Array1::ones(output_channels);
        let bn1_shift = Array1::zeros(output_channels);
        let bn2_scale = Array1::ones(output_channels);
        let bn2_shift = Array1::zeros(output_channels);

        Self {
            conv1,
            conv2,
            skip_projection,
            bn1_scale,
            bn1_shift,
            bn2_scale,
            bn2_shift,
        }
    }

    pub fn forward(&self, input: ArrayView2<f64>) -> Result<Array2<f64>> {
        // First convolution -> batch norm -> ReLU
        let conv1_out = self.conv1.forward(input)?;
        let bn1_out = self.batch_norm(&conv1_out, &self.bn1_scale, &self.bn1_shift);
        let relu1_out = bn1_out.mapv(|x| ActivationFunction::ReLU.apply(x));

        // Second convolution -> batch norm
        let conv2_out = self.conv2.forward(relu1_out.view())?;
        let bn2_out = self.batch_norm(&conv2_out, &self.bn2_scale, &self.bn2_shift);

        // Skip connection, projected if needed and centre-cropped to the (shorter)
        // convolution output length.
        let skip_out = if let Some(ref projection) = self.skip_projection {
            let projected = input.dot(&projection.t());

            let conv_output_len = bn2_out.shape()[0];
            let skip_len = projected.shape()[0];

            if conv_output_len < skip_len {
                let start = (skip_len - conv_output_len) / 2;
                let end = start + conv_output_len;
                projected.slice(s![start..end, ..]).to_owned()
            } else {
                projected
            }
        } else {
            let conv_output_len = bn2_out.shape()[0];
            let skip_len = input.shape()[0];

            if conv_output_len < skip_len {
                let start = (skip_len - conv_output_len) / 2;
                let end = start + conv_output_len;
                input.slice(s![start..end, ..]).to_owned()
            } else {
                input.to_owned()
            }
        };

        let output = &bn2_out + &skip_out;
        Ok(output.mapv(|x| ActivationFunction::ReLU.apply(x)))
    }

    fn batch_norm(
        &self,
        input: &Array2<f64>,
        scale: &Array1<f64>,
        shift: &Array1<f64>,
    ) -> Array2<f64> {
        let mut result = input.clone();
        let eps = 1e-5;

        for ch in 0..input.shape()[1] {
            let channel_data = input.column(ch);
            let mean = channel_data.mean();
            let var = channel_data.mapv(|x| (x - mean).powi(2)).mean();
            let std = (var + eps).sqrt();

            let mut normalized = channel_data.mapv(|x| (x - mean) / std);
            normalized = normalized * scale[ch] + shift[ch];

            result.column_mut(ch).assign(&normalized);
        }

        result
    }
}

/// CNN with parallel convolution branches of different kernel sizes whose
/// globally pooled features are combined by a linear layer.
#[derive(Debug)]
pub struct MultiScaleCNN {
    conv_branches: Vec<Conv1D>,
    bn_branches: Vec<(Array1<f64>, Array1<f64>)>,
    combination_weights: Array2<f64>,
    #[allow(dead_code)]
    global_pool: MaxPool1D,
}

impl MultiScaleCNN {
    pub fn new(
        input_channels: usize,
        num_filters_per_scale: usize,
        kernel_sizes: Vec<usize>,
        output_size: usize,
    ) -> Self {
        let mut conv_branches = Vec::new();
        let mut bn_branches = Vec::new();

        for &kernel_size in &kernel_sizes {
            conv_branches.push(Conv1D::new(
                input_channels,
                num_filters_per_scale,
                kernel_size,
                ActivationFunction::ReLU,
            ));

            bn_branches.push((
                Array1::ones(num_filters_per_scale),
                Array1::zeros(num_filters_per_scale),
            ));
        }

        let total_features = kernel_sizes.len() * num_filters_per_scale;
        let scale = (2.0 / total_features as f64).sqrt();
        let combination_weights = Array2::from_shape_fn((output_size, total_features), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let global_pool = MaxPool1D::new(2, 2);

        Self {
            conv_branches,
            bn_branches,
            combination_weights,
            global_pool,
        }
    }

    pub fn forward(&self, input: ArrayView2<f64>) -> Result<Array1<f64>> {
        let mut branch_outputs = Vec::new();

        for (i, conv) in self.conv_branches.iter().enumerate() {
            let conv_out = conv.forward(input)?;

            let (scale, shift) = &self.bn_branches[i];
            let bn_out = self.batch_norm_branch(&conv_out, scale, shift);

            // Global max pooling over the time axis
            let global_max = bn_out.map_axis(Axis(0), |column| {
                column.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
            });

            branch_outputs.push(global_max);
        }

        // Concatenate the pooled features of all branches
        let mut concatenated = Array1::zeros(branch_outputs.iter().map(|x| x.len()).sum::<usize>());
        let mut offset = 0;
        for branch_output in branch_outputs {
            let end = offset + branch_output.len();
            concatenated
                .slice_mut(s![offset..end])
                .assign(&branch_output);
            offset = end;
        }

        Ok(self.combination_weights.dot(&concatenated))
    }

    fn batch_norm_branch(
        &self,
        input: &Array2<f64>,
        scale: &Array1<f64>,
        shift: &Array1<f64>,
    ) -> Array2<f64> {
        let mut result = input.clone();
        let eps = 1e-5;

        for ch in 0..input.shape()[1] {
            let channel_data = input.column(ch);
            let mean = channel_data.mean();
            let var = channel_data.mapv(|x| (x - mean).powi(2)).mean();
            let std = (var + eps).sqrt();

            let mut normalized = channel_data.mapv(|x| (x - mean) / std);
            normalized = normalized * scale[ch] + shift[ch];

            result.column_mut(ch).assign(&normalized);
        }

        result
    }
}

/// Bahdanau-style additive attention.
pub struct AdditiveAttention {
    w_a: Array2<f64>,
    #[allow(dead_code)]
    w_q: Array2<f64>,
    #[allow(dead_code)]
    w_k: Array2<f64>,
    #[allow(dead_code)]
    w_v: Array2<f64>,
    v_a: Array1<f64>,
}

impl AdditiveAttention {
    pub fn new(encoder_dim: usize, decoder_dim: usize, attention_dim: usize) -> Self {
        let scale = (2.0 / attention_dim as f64).sqrt();

        let w_a = Array2::from_shape_fn((attention_dim, encoder_dim + decoder_dim), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let w_q = Array2::from_shape_fn((attention_dim, decoder_dim), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let w_k = Array2::from_shape_fn((attention_dim, encoder_dim), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let w_v = Array2::from_shape_fn((encoder_dim, encoder_dim), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let v_a = Array1::from_shape_fn(attention_dim, |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        Self {
            w_a,
            w_q,
            w_k,
            w_v,
            v_a,
        }
    }

    /// Returns the context vector and the attention weights over the encoder outputs.
    pub fn forward(
        &self,
        query: ArrayView1<f64>,
        encoder_outputs: ArrayView2<f64>,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        let seq_len = encoder_outputs.shape()[0];
        let mut attention_scores = Array1::zeros(seq_len);

        for i in 0..seq_len {
            let encoder_output = encoder_outputs.row(i);

            // score_i = v_a^T tanh(W_a [query; encoder_output_i])
            let mut combined = Array1::zeros(query.len() + encoder_output.len());
            combined.slice_mut(s![..query.len()]).assign(&query);
            combined
                .slice_mut(s![query.len()..])
                .assign(&encoder_output);

            let attention_input = self.w_a.dot(&combined);
            let activated = ActivationFunction::Tanh.apply_array(&attention_input);
            attention_scores[i] = self.v_a.dot(&activated);
        }

        let attention_weights = self.softmax(&attention_scores);

        let context = encoder_outputs.t().dot(&attention_weights);

        Ok((context, attention_weights))
    }

    fn softmax(&self, scores: &Array1<f64>) -> Array1<f64> {
        let max_score = scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_scores = scores.mapv(|x| (x - max_score).exp());
        let sum_exp = exp_scores.sum();
        exp_scores / sum_exp
    }
}

/// Single-head scaled dot-product self-attention.
#[derive(Debug)]
pub struct SelfAttention {
    w_q: Array2<f64>,
    w_k: Array2<f64>,
    w_v: Array2<f64>,
    w_o: Array2<f64>,
    d_k: usize,
    #[allow(dead_code)]
    dropout: f64,
}

impl SelfAttention {
    pub fn new(d_model: usize, dropout: f64) -> Self {
        let d_k = d_model;
        let scale = (2.0 / d_model as f64).sqrt();

        let w_q = Array2::from_shape_fn((d_model, d_k), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_k = Array2::from_shape_fn((d_model, d_k), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_v = Array2::from_shape_fn((d_model, d_k), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_o = Array2::from_shape_fn((d_k, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        Self {
            w_q,
            w_k,
            w_v,
            w_o,
            d_k,
            dropout,
        }
    }

    pub fn forward(
        &self,
        input: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let _seq_len = input.shape()[0];

        let q = input.dot(&self.w_q);
        let k = input.dot(&self.w_k);
        let v = input.dot(&self.w_v);

        let attention_output =
            self.scaled_dot_product_attention(q.view(), k.view(), v.view(), mask)?;

        Ok(attention_output.dot(&self.w_o))
    }

    fn scaled_dot_product_attention(
        &self,
        q: ArrayView2<f64>,
        k: ArrayView2<f64>,
        v: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let d_k = self.d_k as f64;

        let scores = q.dot(&k.t()) / d_k.sqrt();

        // Mask out disallowed positions before the softmax
        let mut masked_scores = scores;
        if let Some(mask) = mask {
            for ((i, j), &should_mask) in mask.indexed_iter() {
                if should_mask {
                    masked_scores[[i, j]] = f64::NEG_INFINITY;
                }
            }
        }

        let attention_weights = self.softmax_2d(&masked_scores)?;

        Ok(attention_weights.dot(&v))
    }

    fn softmax_2d(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let mut result = x.clone();

        for mut row in result.rows_mut() {
            let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            row.mapv_inplace(|x| (x - max_val).exp());
            let sum: f64 = row.sum();
            if sum > 0.0 {
                row /= sum;
            }
        }

        Ok(result)
    }
}

/// Scaled dot-product cross-attention between a query sequence and a key/value sequence.
#[derive(Debug)]
pub struct CrossAttention {
    w_q: Array2<f64>,
    w_k: Array2<f64>,
    w_v: Array2<f64>,
    w_o: Array2<f64>,
    d_k: usize,
}

impl CrossAttention {
    pub fn new(d_model: usize) -> Self {
        let d_k = d_model;
        let scale = (2.0 / d_model as f64).sqrt();

        let w_q = Array2::from_shape_fn((d_model, d_k), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_k = Array2::from_shape_fn((d_model, d_k), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_v = Array2::from_shape_fn((d_model, d_k), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_o = Array2::from_shape_fn((d_k, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        Self {
            w_q,
            w_k,
            w_v,
            w_o,
            d_k,
        }
    }

    pub fn forward(
        &self,
        query: ArrayView2<f64>,
        key: ArrayView2<f64>,
        value: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let q = query.dot(&self.w_q);
        let k = key.dot(&self.w_k);
        let v = value.dot(&self.w_v);

        let attention_output =
            self.scaled_dot_product_attention(q.view(), k.view(), v.view(), mask)?;

        Ok(attention_output.dot(&self.w_o))
    }

    fn scaled_dot_product_attention(
        &self,
        q: ArrayView2<f64>,
        k: ArrayView2<f64>,
        v: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let d_k = self.d_k as f64;

        let scores = q.dot(&k.t()) / d_k.sqrt();

        // Mask out disallowed positions before the softmax
        let mut masked_scores = scores;
        if let Some(mask) = mask {
            for ((i, j), &should_mask) in mask.indexed_iter() {
                if should_mask {
                    masked_scores[[i, j]] = f64::NEG_INFINITY;
                }
            }
        }

        let attention_weights = self.softmax_2d(&masked_scores)?;

        Ok(attention_weights.dot(&v))
    }

    fn softmax_2d(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let mut result = x.clone();

        for mut row in result.rows_mut() {
            let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            row.mapv_inplace(|x| (x - max_val).exp());
            let sum: f64 = row.sum();
            if sum > 0.0 {
                row /= sum;
            }
        }

        Ok(result)
    }
}

/// Position-wise feed-forward network: two linear layers with a GELU in between.
#[derive(Debug)]
pub struct PositionwiseFeedForward {
    w1: Array2<f64>,
    w2: Array2<f64>,
    b1: Array1<f64>,
    b2: Array1<f64>,
    dropout: f64,
}

impl PositionwiseFeedForward {
    pub fn new(d_model: usize, d_ff: usize, dropout: f64) -> Self {
        let scale1 = (2.0 / d_model as f64).sqrt();
        let scale2 = (2.0 / d_ff as f64).sqrt();

        let w1 = Array2::from_shape_fn((d_ff, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale1..scale1)
        });
        let w2 = Array2::from_shape_fn((d_model, d_ff), |_| {
            scirs2_core::random::rng().random_range(-scale2..scale2)
        });
        let b1 = Array1::zeros(d_ff);
        let b2 = Array1::zeros(d_model);

        Self {
            w1,
            w2,
            b1,
            b2,
            dropout,
        }
    }

    pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
        let hidden = x.dot(&self.w1.t()) + &self.b1;
        let activated = hidden.mapv(|x| ActivationFunction::GELU.apply(x));

        // Deterministic stand-in for dropout: scale by the keep probability
        // instead of sampling a random mask.
        let dropout_mask = if self.dropout > 0.0 {
            1.0 - self.dropout
        } else {
            1.0
        };
        let dropped = activated * dropout_mask;

        dropped.dot(&self.w2.t()) + &self.b2
    }
}

/// Convolutional text classifier: parallel convolution/pooling branches with
/// different filter widths feeding a linear classifier.
pub struct TextCNN {
    conv_layers: Vec<Conv1D>,
    pool_layers: Vec<MaxPool1D>,
    fc_weights: Array2<f64>,
    fc_bias: Array1<f64>,
    dropout_rate: f64,
}

impl TextCNN {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        _vocab_size: usize,
        embedding_dim: usize,
        num_filters: usize,
        filter_sizes: Vec<usize>,
        num_classes: usize,
        dropout_rate: f64,
    ) -> Self {
        let mut conv_layers = Vec::new();
        let mut pool_layers = Vec::new();

        for &filter_size in &filter_sizes {
            conv_layers.push(Conv1D::new(
                embedding_dim,
                num_filters,
                filter_size,
                ActivationFunction::ReLU,
            ));
            pool_layers.push(MaxPool1D::new(2, 2));
        }

        let fc_input_size = num_filters * filter_sizes.len();
        let scale = (2.0 / fc_input_size as f64).sqrt();

        let fc_weights = Array2::from_shape_fn((num_classes, fc_input_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let fc_bias = Array1::zeros(num_classes);

        Self {
            conv_layers,
            pool_layers,
            fc_weights,
            fc_bias,
            dropout_rate,
        }
    }

    /// Classify a sequence of pre-computed embeddings of shape `(seq_len, embedding_dim)`.
    pub fn forward(&self, embeddings: ArrayView2<f64>) -> Result<Array1<f64>> {
        let mut feature_maps = Vec::new();

        for (conv_layer, pool_layer) in self.conv_layers.iter().zip(&self.pool_layers) {
            let conv_output = conv_layer.forward(embeddings)?;
            let pooled_output = pool_layer.forward(conv_output.view());

            // Global max pooling over the remaining time axis
            let global_max = pooled_output.map_axis(Axis(0), |column| {
                column.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
            });

            feature_maps.push(global_max);
        }

        // Concatenate the per-branch features
        let mut concatenated_features =
            Array1::zeros(feature_maps.iter().map(|fm| fm.len()).sum::<usize>());
        let mut offset = 0;
        for feature_map in feature_maps {
            let end = offset + feature_map.len();
            concatenated_features
                .slice_mut(s![offset..end])
                .assign(&feature_map);
            offset = end;
        }

        // Deterministic dropout approximation: scale by the keep probability.
        let dropout_mask = if self.dropout_rate > 0.0 {
            1.0 - self.dropout_rate
        } else {
            1.0
        };
        concatenated_features *= dropout_mask;

        let output = self.fc_weights.dot(&concatenated_features) + &self.fc_bias;

        Ok(output)
    }
}

/// Hybrid model: a TextCNN feature extractor followed by a BiLSTM and a linear classifier.
pub struct CNNLSTMHybrid {
    cnn: TextCNN,
    lstm: BiLSTM,
    classifier: Array2<f64>,
    classifier_bias: Array1<f64>,
}

impl CNNLSTMHybrid {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        embedding_dim: usize,
        cnn_filters: usize,
        filter_sizes: Vec<usize>,
        lstm_hidden_size: usize,
        lstm_layers: usize,
        num_classes: usize,
    ) -> Self {
        let cnn = TextCNN::new(
            0, // the vocabulary size is unused by TextCNN::forward
            embedding_dim,
            cnn_filters,
            filter_sizes.clone(),
            cnn_filters * filter_sizes.len(),
            0.0, // no dropout inside the CNN stage
        );

        let lstm_input_size = cnn_filters * filter_sizes.len();
        let lstm = BiLSTM::new(lstm_input_size, lstm_hidden_size, lstm_layers);

        let classifier_input_size = lstm_hidden_size * 2; // bidirectional output
        let scale = (2.0 / classifier_input_size as f64).sqrt();

        let classifier = Array2::from_shape_fn((num_classes, classifier_input_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let classifier_bias = Array1::zeros(num_classes);

        Self {
            cnn,
            lstm,
            classifier,
            classifier_bias,
        }
    }

    pub fn forward(&self, embeddings: ArrayView2<f64>) -> Result<Array1<f64>> {
        // CNN feature extraction collapses the sequence to a single feature vector.
        let cnn_features = self.cnn.forward(embeddings)?;

        // Treat the feature vector as a length-1 sequence for the BiLSTM.
        let lstm_input = Array2::from_shape_vec((1, cnn_features.len()), cnn_features.to_vec())
            .map_err(|e| TextError::InvalidInput(format!("Reshape error: {e}")))?;

        let lstm_output = self.lstm.forward(lstm_input.view())?;

        let final_hidden = lstm_output.row(lstm_output.shape()[0] - 1);

        let output = self.classifier.dot(&final_hidden) + &self.classifier_bias;

        Ok(output)
    }
}

/// Layer normalization over the last dimension of a 2-D input.
pub struct LayerNorm {
    weight: Array1<f64>,
    bias: Array1<f64>,
    eps: f64,
}

impl LayerNorm {
    pub fn new(normalized_shape: usize) -> Self {
        Self {
            weight: Array1::ones(normalized_shape),
            bias: Array1::zeros(normalized_shape),
            eps: 1e-6,
        }
    }

    pub fn forward(&self, x: ArrayView2<f64>) -> Result<Array2<f64>> {
        let mut output = Array2::zeros(x.raw_dim());

        for (i, row) in x.outer_iter().enumerate() {
            let mean = row.mean();
            let variance = row.mapv(|v| (v - mean).powi(2)).mean();
            let std = (variance + self.eps).sqrt();

            for (j, &val) in row.iter().enumerate() {
                let normalized = (val - mean) / std;
                output[[i, j]] = normalized * self.weight[j] + self.bias[j];
            }
        }

        Ok(output)
    }
}

/// Inverted dropout: surviving activations are rescaled by `1 / (1 - p)` during training.
pub struct Dropout {
    p: f64,
    training: bool,
}

impl Dropout {
    pub fn new(p: f64) -> Self {
        Self {
            p: p.clamp(0.0, 1.0),
            training: true,
        }
    }

    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
        if !self.training || self.p == 0.0 {
            return x.to_owned();
        }

        let mut output = x.to_owned();
        let scale = 1.0 / (1.0 - self.p);

        for elem in output.iter_mut() {
            if scirs2_core::random::rng().random_range(0.0..1.0) < self.p {
                *elem = 0.0; // dropped
            } else {
                *elem *= scale; // rescaled so the expected value is preserved
            }
        }

        output
    }
}

/// Multi-head scaled dot-product attention with dropout on the attention weights.
pub struct MultiHeadAttention {
    num_heads: usize,
    d_model: usize,
    d_k: usize,
    w_q: Array2<f64>,
    w_k: Array2<f64>,
    w_v: Array2<f64>,
    w_o: Array2<f64>,
    dropout: Dropout,
}

impl MultiHeadAttention {
    pub fn new(d_model: usize, num_heads: usize, dropout_p: f64) -> Result<Self> {
        if !d_model.is_multiple_of(num_heads) {
            return Err(TextError::InvalidInput(
                "Model dimension must be divisible by the number of heads".to_string(),
            ));
        }

        let d_k = d_model / num_heads;
        let scale = (2.0 / d_model as f64).sqrt();

        let w_q = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_k = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_v = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_o = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        Ok(Self {
            num_heads,
            d_model,
            d_k,
            w_q,
            w_k,
            w_v,
            w_o,
            dropout: Dropout::new(dropout_p),
        })
    }

    pub fn forward(
        &self,
        query: ArrayView2<f64>,
        key: ArrayView2<f64>,
        value: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let seq_len = query.shape()[0];
        let _batch_size = 1; // sequences are processed one at a time

        let q = query.dot(&self.w_q);
        let k = key.dot(&self.w_k);
        let v = value.dot(&self.w_v);

        // Split the projections into per-head slices
        let mut q_heads = Array3::zeros((seq_len, self.num_heads, self.d_k));
        let mut k_heads = Array3::zeros((seq_len, self.num_heads, self.d_k));
        let mut v_heads = Array3::zeros((seq_len, self.num_heads, self.d_k));

        for i in 0..seq_len {
            for h in 0..self.num_heads {
                let start = h * self.d_k;
                let _end = start + self.d_k;

                for j in 0..self.d_k {
                    q_heads[[i, h, j]] = q[[i, start + j]];
                    k_heads[[i, h, j]] = k[[i, start + j]];
                    v_heads[[i, h, j]] = v[[i, start + j]];
                }
            }
        }

        let mut attention_outputs = Array3::zeros((seq_len, self.num_heads, self.d_k));

        for h in 0..self.num_heads {
            let q_h = q_heads.slice(s![.., h, ..]);
            let k_h = k_heads.slice(s![.., h, ..]);
            let v_h = v_heads.slice(s![.., h, ..]);

            let scores = q_h.dot(&k_h.t()) / (self.d_k as f64).sqrt();

            let mut masked_scores = scores;
            if let Some(mask) = mask {
                for i in 0..seq_len {
                    for j in 0..seq_len {
                        if mask[[i, j]] {
                            masked_scores[[i, j]] = f64::NEG_INFINITY;
                        }
                    }
                }
            }

            // Row-wise softmax over the masked scores
            let mut attention_weights = Array2::zeros((seq_len, seq_len));
            for i in 0..seq_len {
                let row = masked_scores.row(i);
                let max_val = row.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                let exp_sum: f64 = row.iter().map(|&x| (x - max_val).exp()).sum();

                for j in 0..seq_len {
                    attention_weights[[i, j]] = (masked_scores[[i, j]] - max_val).exp() / exp_sum;
                }
            }

            let attention_weights_dropped = self.dropout.forward(attention_weights.view());

            let attended = attention_weights_dropped.dot(&v_h);

            for i in 0..seq_len {
                for j in 0..self.d_k {
                    attention_outputs[[i, h, j]] = attended[[i, j]];
                }
            }
        }

        // Concatenate the heads and apply the output projection
        let mut concatenated = Array2::zeros((seq_len, self.d_model));
        for i in 0..seq_len {
            for h in 0..self.num_heads {
                let start = h * self.d_k;
                for j in 0..self.d_k {
                    concatenated[[i, start + j]] = attention_outputs[[i, h, j]];
                }
            }
        }

        Ok(concatenated.dot(&self.w_o))
    }

    pub fn set_training(&mut self, training: bool) {
        self.dropout.set_training(training);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_activation_functions() {
        let x = 0.5;

        assert!(ActivationFunction::Sigmoid.apply(x) > 0.0);
        assert!(ActivationFunction::Sigmoid.apply(x) < 1.0);
        assert!(ActivationFunction::Tanh.apply(x) > -1.0);
        assert!(ActivationFunction::Tanh.apply(x) < 1.0);
        assert_eq!(ActivationFunction::ReLU.apply(-1.0), 0.0);
        assert_eq!(ActivationFunction::ReLU.apply(1.0), 1.0);
    }
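
    // Added numerical check (sketch): the analytic derivatives should agree with a
    // central finite difference for the smooth activations. GELU is left out because
    // `derivative` uses a simpler approximation than `apply`.
    #[test]
    fn test_activation_derivatives_match_finite_difference() {
        let h = 1e-5;
        let x = 0.3;
        for act in [
            ActivationFunction::Tanh,
            ActivationFunction::Sigmoid,
            ActivationFunction::Swish,
            ActivationFunction::Linear,
        ] {
            let numeric = (act.apply(x + h) - act.apply(x - h)) / (2.0 * h);
            assert!((act.derivative(x) - numeric).abs() < 1e-4);
        }
    }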

    #[test]
    fn test_lstm_cell() {
        let lstm = LSTMCell::new(10, 20);
        let input = Array1::ones(10);
        let h_prev = Array1::zeros(20);
        let c_prev = Array1::zeros(20);

        let (h_new, c_new) = lstm
            .forward(input.view(), h_prev.view(), c_prev.view())
            .unwrap();

        assert_eq!(h_new.len(), 20);
        assert_eq!(c_new.len(), 20);
    }

    #[test]
    fn test_conv1d() {
        let conv = Conv1D::new(5, 10, 3, ActivationFunction::ReLU);
        let input = Array2::ones((8, 5)); // 8 time steps, 5 channels
        let output = conv.forward(input.view()).unwrap();
        assert_eq!(output.shape(), &[6, 10]); // 8 - 3 + 1 positions, 10 filters
    }
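
    // Added check (sketch): MaxPool1D with pool_size = stride = 2 should halve the
    // sequence length, keep the channel count, and take the maximum of each window.
    #[test]
    fn test_maxpool1d() {
        let pool = MaxPool1D::new(2, 2);
        let input = Array2::from_shape_fn((6, 3), |(i, j)| (i * 3 + j) as f64);
        let output = pool.forward(input.view());
        assert_eq!(output.shape(), &[3, 3]);
        // The first window of channel 0 covers rows 0 and 1, so the max is input[[1, 0]].
        assert_eq!(output[[0, 0]], input[[1, 0]]);
    }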

    #[test]
    fn test_bilstm() {
        let bilstm = BiLSTM::new(10, 20, 2);
        let input = Array2::ones((5, 10)); // 5 time steps, 10 features
        let output = bilstm.forward(input.view()).unwrap();
        assert_eq!(output.shape(), &[5, 40]); // hidden size 20 per direction
    }
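
    // Added check (sketch): TextCNN over a short embedded sequence should emit one
    // logit per class; the vocabulary-size argument is unused by forward().
    #[test]
    fn test_text_cnn() {
        let cnn = TextCNN::new(100, 6, 4, vec![2, 3], 3, 0.0);
        let embeddings = Array2::ones((10, 6)); // 10 tokens, 6-dimensional embeddings
        let output = cnn.forward(embeddings.view()).unwrap();
        assert_eq!(output.len(), 3);
    }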

    #[test]
    fn test_gru_cell() {
        let gru = GRUCell::new(10, 20);
        let input = Array1::ones(10);
        let h_prev = Array1::zeros(20);

        let h_new = gru.forward(input.view(), h_prev.view()).unwrap();

        assert_eq!(h_new.len(), 20);
        assert!(h_new.iter().any(|&x| x != 0.0));
    }
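
    // Added end-to-end check (sketch): the CNN front-end collapses the sequence to a
    // single feature vector, the BiLSTM sees it as a length-1 sequence, and the
    // classifier produces one logit per class.
    #[test]
    fn test_cnn_lstm_hybrid() {
        let model = CNNLSTMHybrid::new(6, 4, vec![2, 3], 5, 1, 3);
        let embeddings = Array2::ones((10, 6));
        let output = model.forward(embeddings.view()).unwrap();
        assert_eq!(output.len(), 3);
    }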

    #[test]
    fn test_self_attention() {
        let attention = SelfAttention::new(8, 0.1);
        let input = Array2::ones((4, 8)); // 4 positions, d_model = 8
        let output = attention.forward(input.view(), None).unwrap();
        assert_eq!(output.shape(), &[4, 8]);
    }
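
    // Added check (sketch): additive attention weights should form a probability
    // distribution over the encoder positions, and the context vector should have
    // the encoder dimension.
    #[test]
    fn test_additive_attention() {
        let attention = AdditiveAttention::new(8, 6, 4);
        let query = Array1::ones(6);
        let encoder_outputs = Array2::ones((5, 8));
        let (context, weights) = attention
            .forward(query.view(), encoder_outputs.view())
            .unwrap();
        assert_eq!(context.len(), 8);
        assert_eq!(weights.len(), 5);
        assert!((weights.sum() - 1.0).abs() < 1e-10);
    }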

    #[test]
    fn test_cross_attention() {
        let attention = CrossAttention::new(8);
        let query = Array2::ones((3, 8));
        let key = Array2::ones((5, 8));
        let value = Array2::ones((5, 8));

        let output = attention
            .forward(query.view(), key.view(), value.view(), None)
            .unwrap();
        assert_eq!(output.shape(), &[3, 8]);
    }
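
    // Added check (sketch): multi-head attention with d_model = 8 and 2 heads should
    // preserve the input shape, and a head count that does not divide d_model must
    // be rejected.
    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(8, 2, 0.0).unwrap();
        let input = Array2::ones((4, 8));
        let output = mha
            .forward(input.view(), input.view(), input.view(), None)
            .unwrap();
        assert_eq!(output.shape(), &[4, 8]);

        assert!(MultiHeadAttention::new(8, 3, 0.0).is_err());
    }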

    #[test]
    fn test_residual_block() {
        let block = ResidualBlock1D::new(4, 8, 3);
        let input = Array2::ones((10, 4)); // 10 time steps, 4 channels
        let output = block.forward(input.view()).unwrap();
        assert_eq!(output.shape(), &[6, 8]); // two kernel-3 convolutions shrink 10 to 6
    }

    #[test]
    fn test_multi_scale_cnn() {
        let cnn = MultiScaleCNN::new(
            5,             // input channels
            10,            // filters per scale
            vec![2, 3, 4], // kernel sizes
            30,            // output size
        );
        let input = Array2::ones((8, 5)); // 8 time steps, 5 channels
        let output = cnn.forward(input.view()).unwrap();
        assert_eq!(output.len(), 30);
    }
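
    // Added check (sketch): with the default weight and bias, every normalised row
    // should have (approximately) zero mean.
    #[test]
    fn test_layer_norm() {
        let ln = LayerNorm::new(4);
        let input = Array2::from_shape_fn((3, 4), |(i, j)| (i * 4 + j) as f64);
        let output = ln.forward(input.view()).unwrap();
        for row in output.outer_iter() {
            let mean: f64 = row.sum() / row.len() as f64;
            assert!(mean.abs() < 1e-6);
        }
    }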

    #[test]
    fn test_positionwise_feedforward() {
        let ff = PositionwiseFeedForward::new(8, 16, 0.1);
        let input = Array2::ones((4, 8)); // 4 positions, d_model = 8
        let output = ff.forward(input.view());
        assert_eq!(output.shape(), &[4, 8]);
    }
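
    // Added check (sketch): dropout should be the identity in evaluation mode.
    #[test]
    fn test_dropout_eval_is_identity() {
        let mut dropout = Dropout::new(0.5);
        dropout.set_training(false);
        let input = Array2::ones((4, 4));
        let output = dropout.forward(input.view());
        assert_eq!(output, input);
    }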
}