use crate::error::{Result, TextError};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::{self, Rng, RngExt};
use statrs::statistics::Statistics;
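/// Elementwise activation functions shared by the layers in this module.
/// `apply` evaluates the function at a point, `apply_array` maps it over an
/// `Array1`, and `derivative` returns the first derivative at a point.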
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
Tanh,
Sigmoid,
ReLU,
GELU,
Swish,
Linear,
}
impl ActivationFunction {
pub fn apply(&self, x: f64) -> f64 {
match self {
ActivationFunction::Tanh => x.tanh(),
ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
ActivationFunction::ReLU => x.max(0.0),
ActivationFunction::GELU => {
0.5 * x * (1.0 + (x * 0.7978845608 * (1.0 + 0.044715 * x * x)).tanh())
}
ActivationFunction::Swish => x / (1.0 + (-x).exp()),
ActivationFunction::Linear => x,
}
}
pub fn apply_array(&self, x: &Array1<f64>) -> Array1<f64> {
x.mapv(|val| self.apply(val))
}
pub fn derivative(&self, x: f64) -> f64 {
match self {
ActivationFunction::Tanh => {
let tanh_x = x.tanh();
1.0 - tanh_x * tanh_x
}
ActivationFunction::Sigmoid => {
let sig_x = self.apply(x);
sig_x * (1.0 - sig_x)
}
ActivationFunction::ReLU => {
if x > 0.0 {
1.0
} else {
0.0
}
}
ActivationFunction::GELU => {
// Approximate GELU'(x) ≈ Φ(x) + x·φ(x), using the same tanh approximation of the
// Gaussian CDF Φ as the forward pass and the standard normal PDF φ(x) = exp(-x²/2) / √(2π).
let cdf = 0.5 * (1.0 + (x * 0.7978845608 * (1.0 + 0.044715 * x * x)).tanh());
let pdf = 0.3989422804 * (-0.5 * x * x).exp();
cdf + x * pdf
}
ActivationFunction::Swish => {
let sig_x = 1.0 / (1.0 + (-x).exp());
sig_x + x * sig_x * (1.0 - sig_x)
}
ActivationFunction::Linear => 1.0,
}
}
}
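/// A single LSTM cell with input (`w_*`), recurrent (`u_*`), and bias (`b_*`)
/// parameters for the input, forget, output, and candidate gates. Weights use an
/// Xavier/Glorot-style uniform initialization and the forget-gate bias starts at 1.0.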
#[derive(Debug, Clone)]
pub struct LSTMCell {
w_i: Array2<f64>,
w_f: Array2<f64>,
w_o: Array2<f64>,
w_c: Array2<f64>,
u_i: Array2<f64>,
u_f: Array2<f64>,
u_o: Array2<f64>,
u_c: Array2<f64>,
b_i: Array1<f64>,
b_f: Array1<f64>,
b_o: Array1<f64>,
b_c: Array1<f64>,
input_size: usize,
hidden_size: usize,
}
impl LSTMCell {
pub fn new(input_size: usize, hidden_size: usize) -> Self {
// Xavier/Glorot-style scale based on fan-in + fan-out.
let scale = (2.0 / (input_size + hidden_size) as f64).sqrt();
let init = |rows: usize, cols: usize| {
Array2::from_shape_fn((rows, cols), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
})
};
// Input-to-hidden weights for the input, forget, output, and candidate gates.
let w_i = init(hidden_size, input_size);
let w_f = init(hidden_size, input_size);
let w_o = init(hidden_size, input_size);
let w_c = init(hidden_size, input_size);
// Recurrent (hidden-to-hidden) weights.
let u_i = init(hidden_size, hidden_size);
let u_f = init(hidden_size, hidden_size);
let u_o = init(hidden_size, hidden_size);
let u_c = init(hidden_size, hidden_size);
// The forget-gate bias starts at 1.0 so the cell initially retains its state.
let b_i = Array1::zeros(hidden_size);
let b_f = Array1::ones(hidden_size);
let b_o = Array1::zeros(hidden_size);
let b_c = Array1::zeros(hidden_size);
Self {
w_i,
w_f,
w_o,
w_c,
u_i,
u_f,
u_o,
u_c,
b_i,
b_f,
b_o,
b_c,
input_size,
hidden_size,
}
}
pub fn forward(
&self,
x: ArrayView1<f64>,
h_prev: ArrayView1<f64>,
c_prev: ArrayView1<f64>,
) -> Result<(Array1<f64>, Array1<f64>)> {
if x.len() != self.input_size {
return Err(TextError::InvalidInput(format!(
"Expected input size {}, got {}",
self.input_size,
x.len()
)));
}
if h_prev.len() != self.hidden_size || c_prev.len() != self.hidden_size {
return Err(TextError::InvalidInput(format!(
"Expected hidden size {}, got h: {}, c: {}",
self.hidden_size,
h_prev.len(),
c_prev.len()
)));
}
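// Gate activations: input (i_t), forget (f_t), output (o_t), and candidate state (c̃_t).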
let i_t = ActivationFunction::Sigmoid
.apply_array(&(self.w_i.dot(&x) + self.u_i.dot(&h_prev) + &self.b_i));
let f_t = ActivationFunction::Sigmoid
.apply_array(&(self.w_f.dot(&x) + self.u_f.dot(&h_prev) + &self.b_f));
let o_t = ActivationFunction::Sigmoid
.apply_array(&(self.w_o.dot(&x) + self.u_o.dot(&h_prev) + &self.b_o));
let c_tilde = ActivationFunction::Tanh
.apply_array(&(self.w_c.dot(&x) + self.u_c.dot(&h_prev) + &self.b_c));
let c_t = &f_t * &c_prev + &i_t * &c_tilde;
let h_t = &o_t * &ActivationFunction::Tanh.apply_array(&c_t);
Ok((h_t, c_t))
}
}
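/// A single GRU cell with update (`z`), reset (`r`), and candidate (`h`) gates.
/// One step computes h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t.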
#[derive(Debug, Clone)]
pub struct GRUCell {
w_z: Array2<f64>,
w_r: Array2<f64>,
w_h: Array2<f64>,
u_z: Array2<f64>,
u_r: Array2<f64>,
u_h: Array2<f64>,
b_z: Array1<f64>,
b_r: Array1<f64>,
b_h: Array1<f64>,
input_size: usize,
hidden_size: usize,
}
impl GRUCell {
pub fn new(input_size: usize, hidden_size: usize) -> Self {
// Xavier/Glorot-style scale based on fan-in + fan-out.
let scale = (2.0 / (input_size + hidden_size) as f64).sqrt();
let init = |rows: usize, cols: usize| {
Array2::from_shape_fn((rows, cols), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
})
};
// Input-to-hidden weights for the update (z), reset (r), and candidate (h) gates.
let w_z = init(hidden_size, input_size);
let w_r = init(hidden_size, input_size);
let w_h = init(hidden_size, input_size);
// Recurrent (hidden-to-hidden) weights.
let u_z = init(hidden_size, hidden_size);
let u_r = init(hidden_size, hidden_size);
let u_h = init(hidden_size, hidden_size);
let b_z = Array1::zeros(hidden_size);
let b_r = Array1::zeros(hidden_size);
let b_h = Array1::zeros(hidden_size);
Self {
w_z,
w_r,
w_h,
u_z,
u_r,
u_h,
b_z,
b_r,
b_h,
input_size,
hidden_size,
}
}
pub fn forward(&self, x: ArrayView1<f64>, h_prev: ArrayView1<f64>) -> Result<Array1<f64>> {
if x.len() != self.input_size {
return Err(TextError::InvalidInput(format!(
"Expected input size {}, got {}",
self.input_size,
x.len()
)));
}
if h_prev.len() != self.hidden_size {
return Err(TextError::InvalidInput(format!(
"Expected hidden size {}, got {}",
self.hidden_size,
h_prev.len()
)));
}
// Update gate z_t and reset gate r_t.
let z_t = ActivationFunction::Sigmoid
.apply_array(&(self.w_z.dot(&x) + self.u_z.dot(&h_prev) + &self.b_z));
let r_t = ActivationFunction::Sigmoid
.apply_array(&(self.w_r.dot(&x) + self.u_r.dot(&h_prev) + &self.b_r));
// Candidate state computed from the reset-gated previous hidden state.
let h_tilde = ActivationFunction::Tanh
.apply_array(&(self.w_h.dot(&x) + self.u_h.dot(&(&r_t * &h_prev)) + &self.b_h));
// h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t
let h_t = &(&Array1::ones(self.hidden_size) - &z_t) * &h_prev + &z_t * &h_tilde;
Ok(h_t)
}
}
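/// A multi-layer bidirectional LSTM. Each layer runs one LSTM cell forward and one
/// backward over the sequence and concatenates their hidden states, so every layer
/// (and the final output) has `hidden_size * 2` features per time step.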
pub struct BiLSTM {
forward_cells: Vec<LSTMCell>,
backward_cells: Vec<LSTMCell>,
num_layers: usize,
hidden_size: usize,
}
impl BiLSTM {
pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
let mut forward_cells = Vec::new();
let mut backward_cells = Vec::new();
for i in 0..num_layers {
// Layers after the first consume the concatenated forward/backward hidden states.
let layer_input_size = if i == 0 { input_size } else { hidden_size * 2 };
forward_cells.push(LSTMCell::new(layer_input_size, hidden_size));
backward_cells.push(LSTMCell::new(layer_input_size, hidden_size));
}
Self {
forward_cells,
backward_cells,
num_layers,
hidden_size,
}
}
pub fn forward(&self, sequence: ArrayView2<f64>) -> Result<Array2<f64>> {
let (seq_len, _) = sequence.dim();
let output_size = self.hidden_size * 2;
let mut current_input = sequence.to_owned();
for layer in 0..self.num_layers {
let mut forward_outputs = Vec::new();
let mut backward_outputs = Vec::new();
let mut h_forward = Array1::zeros(self.hidden_size);
let mut c_forward = Array1::zeros(self.hidden_size);
for t in 0..seq_len {
let (h_new, c_new) = self.forward_cells[layer].forward(
current_input.row(t),
h_forward.view(),
c_forward.view(),
)?;
h_forward = h_new;
c_forward = c_new;
forward_outputs.push(h_forward.clone());
}
let mut h_backward = Array1::zeros(self.hidden_size);
let mut c_backward = Array1::zeros(self.hidden_size);
for t in (0..seq_len).rev() {
let (h_new, c_new) = self.backward_cells[layer].forward(
current_input.row(t),
h_backward.view(),
c_backward.view(),
)?;
h_backward = h_new;
c_backward = c_new;
backward_outputs.push(h_backward.clone());
}
backward_outputs.reverse();
let mut layer_output = Array2::zeros((seq_len, output_size));
for t in 0..seq_len {
let mut concat_output = Array1::zeros(output_size);
concat_output
.slice_mut(s![..self.hidden_size])
.assign(&forward_outputs[t]);
concat_output
.slice_mut(s![self.hidden_size..])
.assign(&backward_outputs[t]);
layer_output.row_mut(t).assign(&concat_output);
}
current_input = layer_output;
}
Ok(current_input)
}
}
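/// A 1-D convolution over a (sequence_length, channels) input with "valid" padding,
/// so the output length is `seq_len - kernel_size + 1`.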
#[derive(Debug, Clone)]
pub struct Conv1D {
filters: Array3<f64>,
bias: Array1<f64>,
num_filters: usize,
kernel_size: usize,
input_channels: usize,
activation: ActivationFunction,
}
impl Conv1D {
pub fn new(
input_channels: usize,
num_filters: usize,
kernel_size: usize,
activation: ActivationFunction,
) -> Self {
let scale = (2.0 / (input_channels * kernel_size) as f64).sqrt();
let filters = Array3::from_shape_fn((num_filters, input_channels, kernel_size), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let bias = Array1::zeros(num_filters);
Self {
filters,
bias,
num_filters,
kernel_size,
input_channels,
activation,
}
}
pub fn forward(&self, input: ArrayView2<f64>) -> Result<Array2<f64>> {
let (seq_len, input_dim) = input.dim();
if input_dim != self.input_channels {
return Err(TextError::InvalidInput(format!(
"Expected {} input channels, got {}",
self.input_channels, input_dim
)));
}
let output_len = seq_len.saturating_sub(self.kernel_size - 1);
let mut output = Array2::zeros((output_len, self.num_filters));
for filter_idx in 0..self.num_filters {
for pos in 0..output_len {
let mut conv_sum = 0.0;
for ch in 0..self.input_channels {
for k in 0..self.kernel_size {
if pos + k < seq_len {
conv_sum += input[[pos + k, ch]] * self.filters[[filter_idx, ch, k]];
}
}
}
conv_sum += self.bias[filter_idx];
output[[pos, filter_idx]] = self.activation.apply(conv_sum);
}
}
Ok(output)
}
}
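/// 1-D max pooling over strided windows of a (sequence_length, channels) input.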
#[derive(Debug)]
pub struct MaxPool1D {
pool_size: usize,
stride: usize,
}
impl MaxPool1D {
pub fn new(pool_size: usize, stride: usize) -> Self {
Self { pool_size, stride }
}
pub fn forward(&self, input: ArrayView2<f64>) -> Array2<f64> {
let (seq_len, channels) = input.dim();
// Guard against sequences shorter than the pooling window to avoid underflow.
let output_len = if seq_len >= self.pool_size {
(seq_len - self.pool_size) / self.stride + 1
} else {
0
};
let mut output = Array2::zeros((output_len, channels));
for ch in 0..channels {
for i in 0..output_len {
let start = i * self.stride;
let end = (start + self.pool_size).min(seq_len);
let mut max_val = f64::NEG_INFINITY;
for j in start..end {
max_val = max_val.max(input[[j, ch]]);
}
output[[i, ch]] = max_val;
}
}
output
}
}
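/// A 1-D residual block: two valid convolutions with per-channel batch normalization
/// and ReLU, plus a skip connection. Because valid convolutions shorten the sequence,
/// the skip path is center-cropped (and linearly projected when the channel counts
/// differ) before being added back.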
#[derive(Debug, Clone)]
pub struct ResidualBlock1D {
conv1: Conv1D,
conv2: Conv1D,
skip_projection: Option<Array2<f64>>,
bn1_scale: Array1<f64>,
bn1_shift: Array1<f64>,
bn2_scale: Array1<f64>,
bn2_shift: Array1<f64>,
}
impl ResidualBlock1D {
pub fn new(input_channels: usize, output_channels: usize, kernel_size: usize) -> Self {
let conv1 = Conv1D::new(
input_channels,
output_channels,
kernel_size,
ActivationFunction::Linear,
);
let conv2 = Conv1D::new(
output_channels,
output_channels,
kernel_size,
ActivationFunction::Linear,
);
// A linear projection aligns channel counts on the skip path when they differ.
let skip_projection = if input_channels != output_channels {
let scale = (2.0 / input_channels as f64).sqrt();
Some(Array2::from_shape_fn(
(output_channels, input_channels),
|_| scirs2_core::random::rng().random_range(-scale..scale),
))
} else {
None
};
let bn1_scale = Array1::ones(output_channels);
let bn1_shift = Array1::zeros(output_channels);
let bn2_scale = Array1::ones(output_channels);
let bn2_shift = Array1::zeros(output_channels);
Self {
conv1,
conv2,
skip_projection,
bn1_scale,
bn1_shift,
bn2_scale,
bn2_shift,
}
}
pub fn forward(&self, input: ArrayView2<f64>) -> Result<Array2<f64>> {
let conv1_out = self.conv1.forward(input)?;
let bn1_out = self.batch_norm(&conv1_out, &self.bn1_scale, &self.bn1_shift);
let relu1_out = bn1_out.mapv(|x| ActivationFunction::ReLU.apply(x));
let conv2_out = self.conv2.forward(relu1_out.view())?;
let bn2_out = self.batch_norm(&conv2_out, &self.bn2_scale, &self.bn2_shift);
let skip_out = if let Some(ref projection) = self.skip_projection {
let projected = input.dot(&projection.t());
let conv_output_len = bn2_out.shape()[0];
let skip_len = projected.shape()[0];
if conv_output_len < skip_len {
let start = (skip_len - conv_output_len) / 2;
let end = start + conv_output_len;
projected.slice(s![start..end, ..]).to_owned()
} else {
projected
}
} else {
let conv_output_len = bn2_out.shape()[0];
let skip_len = input.shape()[0];
if conv_output_len < skip_len {
let start = (skip_len - conv_output_len) / 2;
let end = start + conv_output_len;
input.slice(s![start..end, ..]).to_owned()
} else {
input.to_owned()
}
};
let output = &bn2_out + &skip_out;
Ok(output.mapv(|x| ActivationFunction::ReLU.apply(x)))
}
fn batch_norm(
&self,
input: &Array2<f64>,
scale: &Array1<f64>,
shift: &Array1<f64>,
) -> Array2<f64> {
let mut result = input.clone();
let eps = 1e-5;
for ch in 0..input.shape()[1] {
let channel_data = input.column(ch);
let mean = channel_data.mean();
let var = channel_data.mapv(|x| (x - mean).powi(2)).mean();
let std = (var + eps).sqrt();
let mut normalized = channel_data.mapv(|x| (x - mean) / std);
normalized = normalized * scale[ch] + shift[ch];
result.column_mut(ch).assign(&normalized);
}
result
}
}
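/// A multi-branch CNN that applies convolutions with several kernel sizes in parallel,
/// global-max-pools each branch over time, concatenates the pooled features, and
/// projects them to a fixed-size output vector.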
#[derive(Debug)]
pub struct MultiScaleCNN {
conv_branches: Vec<Conv1D>,
bn_branches: Vec<(Array1<f64>, Array1<f64>)>,
combination_weights: Array2<f64>,
#[allow(dead_code)]
global_pool: MaxPool1D,
}
impl MultiScaleCNN {
pub fn new(
input_channels: usize,
num_filters_per_scale: usize,
kernel_sizes: Vec<usize>,
output_size: usize,
) -> Self {
let mut conv_branches = Vec::new();
let mut bn_branches = Vec::new();
for &kernel_size in &kernel_sizes {
conv_branches.push(Conv1D::new(
input_channels,
num_filters_per_scale,
kernel_size,
ActivationFunction::ReLU,
));
bn_branches.push((
Array1::ones(num_filters_per_scale),
Array1::zeros(num_filters_per_scale),
));
}
let total_features = kernel_sizes.len() * num_filters_per_scale;
let scale = (2.0 / total_features as f64).sqrt();
let combination_weights = Array2::from_shape_fn((output_size, total_features), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let global_pool = MaxPool1D::new(2, 2);
Self {
conv_branches,
bn_branches,
combination_weights,
global_pool,
}
}
pub fn forward(&self, input: ArrayView2<f64>) -> Result<Array1<f64>> {
let mut branch_outputs = Vec::new();
for (i, conv) in self.conv_branches.iter().enumerate() {
let conv_out = conv.forward(input)?;
let (scale, shift) = &self.bn_branches[i];
let bn_out = self.batch_norm_branch(&conv_out, scale, shift);
let global_max = bn_out.map_axis(Axis(0), |row| {
row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
});
branch_outputs.push(global_max);
}
let mut concatenated = Array1::zeros(branch_outputs.iter().map(|x| x.len()).sum::<usize>());
let mut offset = 0;
for branch_output in branch_outputs {
let end = offset + branch_output.len();
concatenated
.slice_mut(s![offset..end])
.assign(&branch_output);
offset = end;
}
Ok(self.combination_weights.dot(&concatenated))
}
fn batch_norm_branch(
&self,
input: &Array2<f64>,
scale: &Array1<f64>,
shift: &Array1<f64>,
) -> Array2<f64> {
let mut result = input.clone();
let eps = 1e-5;
for ch in 0..input.shape()[1] {
let channel_data = input.column(ch);
let mean = channel_data.mean();
let var = channel_data.mapv(|x| (x - mean).powi(2)).mean();
let std = (var + eps).sqrt();
let mut normalized = channel_data.mapv(|x| (x - mean) / std);
normalized = normalized * scale[ch] + shift[ch];
result.column_mut(ch).assign(&normalized);
}
result
}
}
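/// Bahdanau-style additive attention: each encoder output is scored against the decoder
/// query via `v_a · tanh(W_a [query; encoder_output])`, the scores are softmaxed, and the
/// weighted sum of encoder outputs is returned together with the attention weights.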
pub struct AdditiveAttention {
w_a: Array2<f64>,
#[allow(dead_code)]
w_q: Array2<f64>,
#[allow(dead_code)]
w_k: Array2<f64>,
#[allow(dead_code)]
w_v: Array2<f64>,
v_a: Array1<f64>,
}
impl AdditiveAttention {
pub fn new(encoder_dim: usize, decoder_dim: usize, attention_dim: usize) -> Self {
let scale = (2.0 / attention_dim as f64).sqrt();
let w_a = Array2::from_shape_fn((attention_dim, encoder_dim + decoder_dim), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let w_q = Array2::from_shape_fn((attention_dim, decoder_dim), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let w_k = Array2::from_shape_fn((attention_dim, encoder_dim), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let w_v = Array2::from_shape_fn((encoder_dim, encoder_dim), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let v_a = Array1::from_shape_fn(attention_dim, |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
Self {
w_a,
w_q,
w_k,
w_v,
v_a,
}
}
pub fn forward(
&self,
query: ArrayView1<f64>,
encoder_outputs: ArrayView2<f64>,
) -> Result<(Array1<f64>, Array1<f64>)> {
let seq_len = encoder_outputs.shape()[0];
let mut attention_scores = Array1::zeros(seq_len);
for i in 0..seq_len {
let encoder_output = encoder_outputs.row(i);
let mut combined = Array1::zeros(query.len() + encoder_output.len());
combined.slice_mut(s![..query.len()]).assign(&query);
combined
.slice_mut(s![query.len()..])
.assign(&encoder_output);
let attention_input = self.w_a.dot(&combined);
let activated = ActivationFunction::Tanh.apply_array(&attention_input);
attention_scores[i] = self.v_a.dot(&activated);
}
let attention_weights = self.softmax(&attention_scores);
let context = encoder_outputs.t().dot(&attention_weights);
Ok((context, attention_weights))
}
fn softmax(&self, scores: &Array1<f64>) -> Array1<f64> {
let max_score = scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let exp_scores = scores.mapv(|x| (x - max_score).exp());
let sum_exp = exp_scores.sum();
exp_scores / sum_exp
}
}
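/// Single-head scaled dot-product self-attention: Q, K, and V are linear projections of
/// the same input, and the output is `softmax(Q Kᵀ / √d_k) V` (optionally masked) followed
/// by an output projection.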
#[derive(Debug)]
pub struct SelfAttention {
w_q: Array2<f64>,
w_k: Array2<f64>,
w_v: Array2<f64>,
w_o: Array2<f64>,
d_k: usize,
#[allow(dead_code)]
dropout: f64,
}
impl SelfAttention {
pub fn new(d_model: usize, dropout: f64) -> Self {
let d_k = d_model;
let scale = (2.0 / d_model as f64).sqrt();
let w_q = Array2::from_shape_fn((d_model, d_k), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let w_k = Array2::from_shape_fn((d_model, d_k), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let w_v = Array2::from_shape_fn((d_model, d_k), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let w_o = Array2::from_shape_fn((d_k, d_model), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
Self {
w_q,
w_k,
w_v,
w_o,
d_k,
dropout,
}
}
pub fn forward(
&self,
input: ArrayView2<f64>,
mask: Option<ArrayView2<bool>>,
) -> Result<Array2<f64>> {
let _seq_len = input.shape()[0];
let q = input.dot(&self.w_q);
let k = input.dot(&self.w_k);
let v = input.dot(&self.w_v);
let attention_output =
self.scaled_dot_product_attention(q.view(), k.view(), v.view(), mask)?;
Ok(attention_output.dot(&self.w_o))
}
fn scaled_dot_product_attention(
&self,
q: ArrayView2<f64>,
k: ArrayView2<f64>,
v: ArrayView2<f64>,
mask: Option<ArrayView2<bool>>,
) -> Result<Array2<f64>> {
let d_k = self.d_k as f64;
let scores = q.dot(&k.t()) / d_k.sqrt();
let mut masked_scores = scores;
if let Some(mask) = mask {
for ((i, j), &should_mask) in mask.indexed_iter() {
if should_mask {
masked_scores[[i, j]] = f64::NEG_INFINITY;
}
}
}
let attention_weights = self.softmax_2d(&masked_scores)?;
Ok(attention_weights.dot(&v))
}
fn softmax_2d(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
let mut result = x.clone();
for mut row in result.rows_mut() {
let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
row.mapv_inplace(|x| (x - max_val).exp());
let sum: f64 = row.sum();
if sum > 0.0 {
row /= sum;
}
}
Ok(result)
}
}
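/// Scaled dot-product attention where queries come from one sequence and keys/values from
/// another, e.g. a decoder attending over encoder outputs.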
#[derive(Debug)]
pub struct CrossAttention {
w_q: Array2<f64>,
w_k: Array2<f64>,
w_v: Array2<f64>,
w_o: Array2<f64>,
d_k: usize,
}
impl CrossAttention {
pub fn new(d_model: usize) -> Self {
let d_k = d_model;
let scale = (2.0 / d_model as f64).sqrt();
let w_q = Array2::from_shape_fn((d_model, d_k), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let w_k = Array2::from_shape_fn((d_model, d_k), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let w_v = Array2::from_shape_fn((d_model, d_k), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let w_o = Array2::from_shape_fn((d_k, d_model), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
Self {
w_q,
w_k,
w_v,
w_o,
d_k,
}
}
pub fn forward(
&self,
query: ArrayView2<f64>,
key: ArrayView2<f64>,
value: ArrayView2<f64>,
mask: Option<ArrayView2<bool>>,
) -> Result<Array2<f64>> {
let q = query.dot(&self.w_q);
let k = key.dot(&self.w_k);
let v = value.dot(&self.w_v);
let attention_output =
self.scaled_dot_product_attention(q.view(), k.view(), v.view(), mask)?;
Ok(attention_output.dot(&self.w_o))
}
fn scaled_dot_product_attention(
&self,
q: ArrayView2<f64>,
k: ArrayView2<f64>,
v: ArrayView2<f64>,
mask: Option<ArrayView2<bool>>,
) -> Result<Array2<f64>> {
let d_k = self.d_k as f64;
let scores = q.dot(&k.t()) / d_k.sqrt();
let mut masked_scores = scores;
if let Some(mask) = mask {
for ((i, j), &should_mask) in mask.indexed_iter() {
if should_mask {
masked_scores[[i, j]] = f64::NEG_INFINITY;
}
}
}
let attention_weights = self.softmax_2d(&masked_scores)?;
Ok(attention_weights.dot(&v))
}
fn softmax_2d(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
let mut result = x.clone();
for mut row in result.rows_mut() {
let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
row.mapv_inplace(|x| (x - max_val).exp());
let sum: f64 = row.sum();
if sum > 0.0 {
row /= sum;
}
}
Ok(result)
}
}
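/// Transformer-style position-wise feed-forward network: a GELU-activated expansion to
/// the hidden width followed by a projection back to the model dimension, applied
/// independently at every position.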
#[derive(Debug)]
pub struct PositionwiseFeedForward {
w1: Array2<f64>,
w2: Array2<f64>,
b1: Array1<f64>,
b2: Array1<f64>,
dropout: f64,
}
impl PositionwiseFeedForward {
pub fn new(d_model: usize, d_ff: usize, dropout: f64) -> Self {
let scale1 = (2.0 / d_model as f64).sqrt();
let scale2 = (2.0 / d_ff as f64).sqrt();
let w1 = Array2::from_shape_fn((d_ff, d_model), |_| {
scirs2_core::random::rng().random_range(-scale1..scale1)
});
let w2 = Array2::from_shape_fn((d_model, d_ff), |_| {
scirs2_core::random::rng().random_range(-scale2..scale2)
});
let b1 = Array1::zeros(d_ff);
let b2 = Array1::zeros(d_model);
Self {
w1,
w2,
b1,
b2,
dropout,
}
}
pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
let hidden = x.dot(&self.w1.t()) + &self.b1;
let activated = hidden.mapv(|x| ActivationFunction::GELU.apply(x));
// Simplified dropout: scale activations by the keep probability (the expected effect
// of dropout) rather than randomly zeroing units.
let keep_prob = if self.dropout > 0.0 {
1.0 - self.dropout
} else {
1.0
};
let dropped = activated * keep_prob;
dropped.dot(&self.w2.t()) + &self.b2
}
}
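/// A TextCNN classifier in the spirit of Kim (2014): parallel convolutions with different
/// filter sizes over token embeddings, max pooling per filter, feature concatenation, and
/// a final linear layer over the pooled features.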
pub struct TextCNN {
conv_layers: Vec<Conv1D>,
pool_layers: Vec<MaxPool1D>,
fc_weights: Array2<f64>,
fc_bias: Array1<f64>,
dropout_rate: f64,
}
impl TextCNN {
#[allow(clippy::too_many_arguments)]
pub fn new(
_vocab_size: usize,
embedding_dim: usize,
num_filters: usize,
filter_sizes: Vec<usize>,
num_classes: usize,
dropout_rate: f64,
) -> Self {
let mut conv_layers = Vec::new();
let mut pool_layers = Vec::new();
for &filter_size in &filter_sizes {
conv_layers.push(Conv1D::new(
embedding_dim,
num_filters,
filter_size,
ActivationFunction::ReLU,
));
pool_layers.push(MaxPool1D::new(2, 2));
}
let fc_input_size = num_filters * filter_sizes.len();
let scale = (2.0 / fc_input_size as f64).sqrt();
let fc_weights = Array2::from_shape_fn((num_classes, fc_input_size), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let fc_bias = Array1::zeros(num_classes);
Self {
conv_layers,
pool_layers,
fc_weights,
fc_bias,
dropout_rate,
}
}
pub fn forward(&self, embeddings: ArrayView2<f64>) -> Result<Array1<f64>> {
let mut feature_maps = Vec::new();
for (conv_layer, pool_layer) in self.conv_layers.iter().zip(&self.pool_layers) {
let conv_output = conv_layer.forward(embeddings)?;
let pooled_output = pool_layer.forward(conv_output.view());
let global_max = pooled_output.map_axis(Axis(0), |row| {
row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
});
feature_maps.push(global_max);
}
let mut concatenated_features =
Array1::zeros(feature_maps.iter().map(|fm| fm.len()).sum::<usize>());
let mut offset = 0;
for feature_map in feature_maps {
let end = offset + feature_map.len();
concatenated_features
.slice_mut(s![offset..end])
.assign(&feature_map);
offset = end;
}
// Simplified dropout: scale the pooled features by the keep probability rather than
// randomly zeroing them.
let keep_prob = if self.dropout_rate > 0.0 {
1.0 - self.dropout_rate
} else {
1.0
};
concatenated_features *= keep_prob;
let output = self.fc_weights.dot(&concatenated_features) + &self.fc_bias;
Ok(output)
}
}
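/// A hybrid model that extracts a feature vector with a TextCNN, feeds it to a
/// bidirectional LSTM as a single-step sequence, and applies a linear classifier to the
/// final hidden state.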
pub struct CNNLSTMHybrid {
cnn: TextCNN,
lstm: BiLSTM,
classifier: Array2<f64>,
classifier_bias: Array1<f64>,
}
impl CNNLSTMHybrid {
#[allow(clippy::too_many_arguments)]
pub fn new(
embedding_dim: usize,
cnn_filters: usize,
filter_sizes: Vec<usize>,
lstm_hidden_size: usize,
lstm_layers: usize,
num_classes: usize,
) -> Self {
let cnn = TextCNN::new(
0, // vocabulary size is unused: the CNN operates directly on embeddings
embedding_dim,
cnn_filters,
filter_sizes.clone(),
cnn_filters * filter_sizes.len(),
0.0, // no dropout inside the feature extractor
);
let lstm_input_size = cnn_filters * filter_sizes.len();
let lstm = BiLSTM::new(lstm_input_size, lstm_hidden_size, lstm_layers);
// The BiLSTM concatenates forward and backward states, doubling the feature size.
let classifier_input_size = lstm_hidden_size * 2;
let scale = (2.0 / classifier_input_size as f64).sqrt();
let classifier = Array2::from_shape_fn((num_classes, classifier_input_size), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let classifier_bias = Array1::zeros(num_classes);
Self {
cnn,
lstm,
classifier,
classifier_bias,
}
}
pub fn forward(&self, embeddings: ArrayView2<f64>) -> Result<Array1<f64>> {
let cnn_features = self.cnn.forward(embeddings)?;
let lstm_input = Array2::from_shape_vec((1, cnn_features.len()), cnn_features.to_vec())
.map_err(|e| TextError::InvalidInput(format!("Reshape error: {e}")))?;
let lstm_output = self.lstm.forward(lstm_input.view())?;
let final_hidden = lstm_output.row(lstm_output.shape()[0] - 1);
let output = self.classifier.dot(&final_hidden) + &self.classifier_bias;
Ok(output)
}
}
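/// Layer normalization over the last dimension of a 2-D input, with learnable scale
/// (`weight`) and shift (`bias`) parameters.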
pub struct LayerNorm {
weight: Array1<f64>,
bias: Array1<f64>,
eps: f64,
}
impl LayerNorm {
pub fn new(normalized_shape: usize) -> Self {
Self {
weight: Array1::ones(normalized_shape),
bias: Array1::zeros(normalized_shape),
eps: 1e-6,
}
}
pub fn forward(&self, x: ArrayView2<f64>) -> Result<Array2<f64>> {
let mut output = Array2::zeros(x.raw_dim());
for (i, row) in x.outer_iter().enumerate() {
let mean = row.mean();
let variance = row.mapv(|v| (v - mean).powi(2)).mean();
let std = (variance + self.eps).sqrt();
for (j, &val) in row.iter().enumerate() {
let normalized = (val - mean) / std;
output[[i, j]] = normalized * self.weight[j] + self.bias[j];
}
}
Ok(output)
}
}
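/// Inverted dropout: during training each element is zeroed with probability `p` and the
/// survivors are scaled by `1 / (1 - p)`; in evaluation mode the input passes through
/// unchanged.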
pub struct Dropout {
p: f64,
training: bool,
}
impl Dropout {
pub fn new(p: f64) -> Self {
Self {
p: p.clamp(0.0, 1.0),
training: true,
}
}
pub fn set_training(&mut self, training: bool) {
self.training = training;
}
pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
if !self.training || self.p == 0.0 {
return x.to_owned();
}
let mut output = x.to_owned();
// Survivors are scaled by 1 / (1 - p) so the expected activation is unchanged.
let scale = 1.0 / (1.0 - self.p);
for elem in output.iter_mut() {
if scirs2_core::random::rng().random_range(0.0..1.0) < self.p {
*elem = 0.0;
} else {
*elem *= scale;
}
}
output
}
}
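/// Multi-head scaled dot-product attention: the model dimension is split into `num_heads`
/// heads of size `d_k = d_model / num_heads`, attention is computed per head, and the
/// concatenated head outputs are projected back with `w_o`.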
pub struct MultiHeadAttention {
num_heads: usize,
d_model: usize,
d_k: usize,
w_q: Array2<f64>,
w_k: Array2<f64>,
w_v: Array2<f64>,
w_o: Array2<f64>,
dropout: Dropout,
}
impl MultiHeadAttention {
pub fn new(d_model: usize, num_heads: usize, dropout_p: f64) -> Result<Self> {
if !d_model.is_multiple_of(num_heads) {
return Err(TextError::InvalidInput(
"Model dimension must be divisible by number of heads".to_string(),
));
}
let d_k = d_model / num_heads;
let scale = (2.0 / d_model as f64).sqrt();
let w_q = Array2::from_shape_fn((d_model, d_model), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let w_k = Array2::from_shape_fn((d_model, d_model), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let w_v = Array2::from_shape_fn((d_model, d_model), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
let w_o = Array2::from_shape_fn((d_model, d_model), |_| {
scirs2_core::random::rng().random_range(-scale..scale)
});
Ok(Self {
num_heads,
d_model,
d_k,
w_q,
w_k,
w_v,
w_o,
dropout: Dropout::new(dropout_p),
})
}
pub fn forward(
&self,
query: ArrayView2<f64>,
key: ArrayView2<f64>,
value: ArrayView2<f64>,
mask: Option<ArrayView2<bool>>,
) -> Result<Array2<f64>> {
let seq_len = query.shape()[0];
let _batch_size = 1;
let q = query.dot(&self.w_q);
let k = key.dot(&self.w_k);
let v = value.dot(&self.w_v);
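// Split the projected Q, K, and V into num_heads contiguous slices of width d_k.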
let mut q_heads = Array3::zeros((seq_len, self.num_heads, self.d_k));
let mut k_heads = Array3::zeros((seq_len, self.num_heads, self.d_k));
let mut v_heads = Array3::zeros((seq_len, self.num_heads, self.d_k));
for i in 0..seq_len {
for h in 0..self.num_heads {
let start = h * self.d_k;
let _end = start + self.d_k;
for j in 0..self.d_k {
q_heads[[i, h, j]] = q[[i, start + j]];
k_heads[[i, h, j]] = k[[i, start + j]];
v_heads[[i, h, j]] = v[[i, start + j]];
}
}
}
let mut attention_outputs = Array3::zeros((seq_len, self.num_heads, self.d_k));
for h in 0..self.num_heads {
let q_h = q_heads.slice(s![.., h, ..]);
let k_h = k_heads.slice(s![.., h, ..]);
let v_h = v_heads.slice(s![.., h, ..]);
let scores = q_h.dot(&k_h.t()) / (self.d_k as f64).sqrt();
let mut masked_scores = scores;
if let Some(mask) = mask {
for i in 0..seq_len {
for j in 0..seq_len {
if mask[[i, j]] {
masked_scores[[i, j]] = f64::NEG_INFINITY;
}
}
}
}
let mut attention_weights = Array2::zeros((seq_len, seq_len));
for i in 0..seq_len {
let row = masked_scores.row(i);
let max_val = row.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
let exp_sum: f64 = row.iter().map(|&x| (x - max_val).exp()).sum();
for j in 0..seq_len {
attention_weights[[i, j]] = (masked_scores[[i, j]] - max_val).exp() / exp_sum;
}
}
let attention_weights_dropped = self.dropout.forward(attention_weights.view());
let attended = attention_weights_dropped.dot(&v_h);
for i in 0..seq_len {
for j in 0..self.d_k {
attention_outputs[[i, h, j]] = attended[[i, j]];
}
}
}
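// Concatenate the per-head outputs back into a (seq_len, d_model) matrix before the final projection.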
let mut concatenated = Array2::zeros((seq_len, self.d_model));
for i in 0..seq_len {
for h in 0..self.num_heads {
let start = h * self.d_k;
for j in 0..self.d_k {
concatenated[[i, start + j]] = attention_outputs[[i, h, j]];
}
}
}
Ok(concatenated.dot(&self.w_o))
}
pub fn set_training(&mut self, training: bool) {
self.dropout.set_training(training);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_activation_functions() {
let x = 0.5;
assert!(ActivationFunction::Sigmoid.apply(x) > 0.0);
assert!(ActivationFunction::Sigmoid.apply(x) < 1.0);
assert!(ActivationFunction::Tanh.apply(x) > -1.0);
assert!(ActivationFunction::Tanh.apply(x) < 1.0);
assert_eq!(ActivationFunction::ReLU.apply(-1.0), 0.0);
assert_eq!(ActivationFunction::ReLU.apply(1.0), 1.0);
}
#[test]
fn test_lstm_cell() {
let lstm = LSTMCell::new(10, 20);
let input = Array1::ones(10);
let h_prev = Array1::zeros(20);
let c_prev = Array1::zeros(20);
let (h_new, c_new) = lstm
.forward(input.view(), h_prev.view(), c_prev.view())
.expect("Operation failed");
assert_eq!(h_new.len(), 20);
assert_eq!(c_new.len(), 20);
}
#[test]
fn test_conv1d() {
let conv = Conv1D::new(5, 10, 3, ActivationFunction::ReLU);
let input = Array2::ones((8, 5));
let output = conv.forward(input.view()).expect("Operation failed");
// Valid convolution over 8 steps with kernel 3: 8 - 3 + 1 = 6 output positions.
assert_eq!(output.shape(), &[6, 10]);
}
#[test]
fn test_bilstm() {
let bilstm = BiLSTM::new(10, 20, 2);
let input = Array2::ones((5, 10));
let output = bilstm.forward(input.view()).expect("Operation failed");
// Bidirectional output doubles the hidden size: 2 * 20 = 40 features per step.
assert_eq!(output.shape(), &[5, 40]);
}
#[test]
fn test_gru_cell() {
let gru = GRUCell::new(10, 20);
let input = Array1::ones(10);
let h_prev = Array1::zeros(20);
let h_new = gru
.forward(input.view(), h_prev.view())
.expect("Operation failed");
assert_eq!(h_new.len(), 20);
assert!(h_new.iter().any(|&x| x != 0.0));
}
#[test]
fn test_self_attention() {
let attention = SelfAttention::new(8, 0.1);
let input = Array2::ones((4, 8));
let output = attention
.forward(input.view(), None)
.expect("Operation failed");
assert_eq!(output.shape(), &[4, 8]);
}
#[test]
fn test_cross_attention() {
let attention = CrossAttention::new(8);
let query = Array2::ones((3, 8));
let key = Array2::ones((5, 8));
let value = Array2::ones((5, 8));
let output = attention
.forward(query.view(), key.view(), value.view(), None)
.expect("Operation failed");
assert_eq!(output.shape(), &[3, 8]);
}
#[test]
fn test_residual_block() {
let block = ResidualBlock1D::new(4, 8, 3);
let input = Array2::ones((10, 4));
let output = block.forward(input.view()).expect("Operation failed");
// Two stacked valid convolutions with kernel 3 shorten 10 steps to 10 - 2 * (3 - 1) = 6.
assert_eq!(output.shape(), &[6, 8]);
}
#[test]
fn test_multi_scale_cnn() {
let cnn = MultiScaleCNN::new(
5,             // input channels
10,            // filters per scale
vec![2, 3, 4], // kernel sizes
30,            // output size
);
let input = Array2::ones((8, 5));
let output = cnn.forward(input.view()).expect("Operation failed");
assert_eq!(output.len(), 30);
}
#[test]
fn test_positionwise_feedforward() {
let ff = PositionwiseFeedForward::new(8, 16, 0.1);
let input = Array2::ones((4, 8));
let output = ff.forward(input.view());
assert_eq!(output.shape(), &[4, 8]);
}
}