use crate::{Module, ModuleBase, Parameter};
use std::collections::HashMap;
use torsh_core::device::DeviceType;
use torsh_core::dtype::DType;
use torsh_core::error::{Result, TorshError};
use torsh_tensor::{creation::*, Tensor};
pub struct SciRS2NeuralProcessor {
#[allow(dead_code)]
config: NeuralConfig,
device: DeviceType,
}
#[derive(Debug, Clone)]
pub struct NeuralConfig {
pub device: DeviceType,
pub dtype: DType,
pub mixed_precision: bool,
pub memory_optimization: u8,
pub gradient_checkpointing: bool,
pub max_sequence_length: usize,
}
impl Default for NeuralConfig {
fn default() -> Self {
Self {
device: DeviceType::Cpu,
dtype: DType::F32,
mixed_precision: false,
memory_optimization: 1,
gradient_checkpointing: false,
max_sequence_length: 512,
}
}
}
impl SciRS2NeuralProcessor {
pub fn new(config: NeuralConfig) -> Self {
Self {
device: config.device,
config,
}
}
pub fn create_attention_layer(
&self,
embed_dim: usize,
num_heads: usize,
dropout: f32,
) -> Result<MultiHeadAttention> {
MultiHeadAttention::new(embed_dim, num_heads, dropout, true, self.device)
}
pub fn create_layer_norm(&self, normalized_shape: Vec<usize>, eps: f64) -> Result<LayerNorm> {
LayerNorm::new(normalized_shape, eps, true, self.device)
}
pub fn create_transformer_encoder(
&self,
d_model: usize,
nhead: usize,
dim_feedforward: usize,
dropout: f32,
) -> Result<TransformerEncoderLayer> {
TransformerEncoderLayer::new(d_model, nhead, dim_feedforward, dropout, self.device)
}
}
pub struct MultiHeadAttention {
base: ModuleBase,
#[allow(dead_code)]
embed_dim: usize,
#[allow(dead_code)]
num_heads: usize,
head_dim: usize,
#[allow(dead_code)]
dropout: f32,
batch_first: bool,
#[allow(dead_code)]
device: DeviceType,
}
impl MultiHeadAttention {
pub fn new(
embed_dim: usize,
num_heads: usize,
dropout: f32,
batch_first: bool,
device: DeviceType,
) -> Result<Self> {
if embed_dim % num_heads != 0 {
return Err(TorshError::dimension_error_with_context(
"embed_dim must be divisible by num_heads",
"MultiHeadAttention::new",
));
}
let head_dim = embed_dim / num_heads;
let mut base = ModuleBase::new();
let q_proj = Parameter::new(zeros(&[embed_dim, embed_dim])?);
let k_proj = Parameter::new(zeros(&[embed_dim, embed_dim])?);
let v_proj = Parameter::new(zeros(&[embed_dim, embed_dim])?);
let out_proj = Parameter::new(zeros(&[embed_dim, embed_dim])?);
base.register_parameter("q_proj".to_string(), q_proj);
base.register_parameter("k_proj".to_string(), k_proj);
base.register_parameter("v_proj".to_string(), v_proj);
base.register_parameter("out_proj".to_string(), out_proj);
Ok(Self {
base,
embed_dim,
num_heads,
head_dim,
dropout,
batch_first,
device,
})
}
pub fn forward(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
attn_mask: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
let params = self.base.named_parameters();
let q_proj = params.get("q_proj").expect("q_proj parameter should exist");
let k_proj = params.get("k_proj").expect("k_proj parameter should exist");
let v_proj = params.get("v_proj").expect("v_proj parameter should exist");
let out_proj = params
.get("out_proj")
.expect("out_proj parameter should exist");
let q = query.matmul(&*q_proj.tensor().read())?;
let k = key.matmul(&*k_proj.tensor().read())?;
let v = value.matmul(&*v_proj.tensor().read())?;
let shape = q.shape();
let (_batch_size, _seq_len) = if self.batch_first {
(shape.dims()[0], shape.dims()[1])
} else {
(shape.dims()[1], shape.dims()[0])
};
let scale = (self.head_dim as f64).sqrt();
let scores = q.matmul(&k.transpose(-2, -1)?)?;
let scaled_scores = scores.div_scalar(scale as f32)?;
let attn_weights = if let Some(mask) = attn_mask {
let masked_scores = scaled_scores.add(mask)?;
masked_scores.softmax(-1)?
} else {
scaled_scores.softmax(-1)?
};
let attn_output = attn_weights.matmul(&v)?;
let output = attn_output.matmul(&*out_proj.tensor().read())?;
Ok((output, attn_weights))
}
}
impl Module for MultiHeadAttention {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let (output, _) = self.forward(input, input, input, None)?;
Ok(output)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
}
pub struct TransformerEncoderLayer {
self_attn: MultiHeadAttention,
linear1: Parameter,
linear2: Parameter,
norm1: LayerNorm,
norm2: LayerNorm,
#[allow(dead_code)]
dropout: f32,
}
impl TransformerEncoderLayer {
pub fn new(
d_model: usize,
nhead: usize,
dim_feedforward: usize,
dropout: f32,
device: DeviceType,
) -> Result<Self> {
let self_attn = MultiHeadAttention::new(d_model, nhead, dropout, true, device)?;
let linear1 = Parameter::new(zeros(&[d_model, dim_feedforward])?);
let linear2 = Parameter::new(zeros(&[dim_feedforward, d_model])?);
let norm1 = LayerNorm::new(vec![d_model], 1e-5, true, device)?;
let norm2 = LayerNorm::new(vec![d_model], 1e-5, true, device)?;
Ok(Self {
self_attn,
linear1,
linear2,
norm1,
norm2,
dropout,
})
}
pub fn forward(&self, src: &Tensor, src_mask: Option<&Tensor>) -> Result<Tensor> {
let (attn_output, _) = self.self_attn.forward(src, src, src, src_mask)?;
let src2 = self.norm1.forward(&src.add(&attn_output)?)?;
let ff_output = src2.matmul(&*self.linear1.tensor().read())?;
let ff_output = ff_output.relu()?;
let ff_output = ff_output.matmul(&*self.linear2.tensor().read())?;
let output = self.norm2.forward(&src2.add(&ff_output)?)?;
Ok(output)
}
}
impl Module for TransformerEncoderLayer {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.forward(input, None)
}
fn parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
params.extend(self.self_attn.parameters());
params.insert("linear1".to_string(), self.linear1.clone());
params.insert("linear2".to_string(), self.linear2.clone());
params.extend(self.norm1.parameters());
params.extend(self.norm2.parameters());
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (name, param) in self.self_attn.named_parameters() {
params.insert(format!("self_attn.{}", name), param);
}
params.insert("linear1".to_string(), self.linear1.clone());
params.insert("linear2".to_string(), self.linear2.clone());
for (name, param) in self.norm1.named_parameters() {
params.insert(format!("norm1.{}", name), param);
}
for (name, param) in self.norm2.named_parameters() {
params.insert(format!("norm2.{}", name), param);
}
params
}
fn train(&mut self) {
self.self_attn.train();
self.norm1.train();
self.norm2.train();
}
fn eval(&mut self) {
self.self_attn.eval();
self.norm1.eval();
self.norm2.eval();
}
}
pub struct LayerNorm {
base: ModuleBase,
normalized_shape: Vec<usize>,
eps: f64,
elementwise_affine: bool,
#[allow(dead_code)]
device: DeviceType,
}
impl LayerNorm {
pub fn new(
normalized_shape: Vec<usize>,
eps: f64,
elementwise_affine: bool,
device: DeviceType,
) -> Result<Self> {
let mut base = ModuleBase::new();
if elementwise_affine {
let weight = Parameter::new(ones(&normalized_shape)?);
let bias = Parameter::new(zeros(&normalized_shape)?);
base.register_parameter("weight".to_string(), weight);
base.register_parameter("bias".to_string(), bias);
}
Ok(Self {
base,
normalized_shape,
eps,
elementwise_affine,
device,
})
}
pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
let input_shape = input.shape();
let dims = input_shape.dims();
let normalize_dims = self.normalized_shape.len();
let input_dims = dims.len();
if input_dims < normalize_dims {
return Err(TorshError::InvalidArgument(format!(
"Input has {} dims but normalized_shape has {} dims",
input_dims, normalize_dims
)));
}
let start_idx = input_dims - normalize_dims;
for (i, &norm_dim) in self.normalized_shape.iter().enumerate() {
if dims[start_idx + i] != norm_dim {
return Err(TorshError::InvalidArgument(format!(
"Expected dimension {} to be {}, got {}",
start_idx + i,
norm_dim,
dims[start_idx + i]
)));
}
}
let norm_elements: usize = self.normalized_shape.iter().product();
let batch_size: usize = dims[..start_idx].iter().product();
let input_data = input.to_vec()?;
let mut means = vec![0.0f32; batch_size];
let mut vars = vec![0.0f32; batch_size];
for batch in 0..batch_size {
let mut sum = 0.0;
let mut sum_sq = 0.0;
let batch_start = batch * norm_elements;
for i in 0..norm_elements {
let val = input_data[batch_start + i];
sum += val;
sum_sq += val * val;
}
let mean = sum / norm_elements as f32;
let var = (sum_sq / norm_elements as f32) - (mean * mean);
means[batch] = mean;
vars[batch] = var;
}
let mut output_data = vec![0.0f32; input_data.len()];
for batch in 0..batch_size {
let mean = means[batch];
let std = (vars[batch] + self.eps as f32).sqrt();
let batch_start = batch * norm_elements;
for i in 0..norm_elements {
let val = input_data[batch_start + i];
output_data[batch_start + i] = (val - mean) / std;
}
}
let mut normalized = Tensor::from_data(output_data, dims.to_vec(), input.device())?;
if self.elementwise_affine {
let params = self.base.named_parameters();
let weight = params
.get("weight")
.expect("weight parameter should exist for elementwise_affine");
let bias = params
.get("bias")
.expect("bias parameter should exist for elementwise_affine");
let weight_tensor = weight.tensor().read().clone();
let bias_tensor = bias.tensor().read().clone();
let weight_data = weight_tensor.to_vec()?;
let bias_data = bias_tensor.to_vec()?;
let mut output_data = normalized.to_vec()?;
for batch in 0..batch_size {
let batch_start = batch * norm_elements;
for i in 0..norm_elements {
let idx = batch_start + i;
output_data[idx] = output_data[idx] * weight_data[i] + bias_data[i];
}
}
normalized = Tensor::from_data(output_data, dims.to_vec(), input.device())?;
}
Ok(normalized)
}
}
impl Module for LayerNorm {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.forward(input)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
}
pub struct Swish;
impl Swish {
pub fn new() -> Self {
Self
}
pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
let sigmoid_x = input.sigmoid()?;
input.mul(&sigmoid_x)
}
}
impl Module for Swish {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.forward(input)
}
fn parameters(&self) -> HashMap<String, Parameter> {
HashMap::new()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
HashMap::new()
}
fn train(&mut self) {}
fn eval(&mut self) {}
}
pub struct GELU {
approximate: bool,
}
impl GELU {
pub fn new(approximate: bool) -> Self {
Self { approximate }
}
pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
if self.approximate {
let x_cubed = input.pow(3.0)?;
let term = input.add(&x_cubed.mul_scalar(0.044715)?)?;
let sqrt_2_pi = (2.0 / std::f64::consts::PI).sqrt() as f32;
let tanh_arg = term.mul_scalar(sqrt_2_pi)?;
let tanh_val = tanh_arg.tanh()?;
let one_plus_tanh = tanh_val.add_scalar(1.0)?;
input.mul(&one_plus_tanh)?.mul_scalar(0.5)
} else {
let sqrt_2 = (2.0_f64).sqrt() as f32;
let erf_arg = input.div_scalar(sqrt_2)?;
let erf_approx = erf_arg.tanh()?;
let one_plus_erf = erf_approx.add_scalar(1.0)?;
input.mul(&one_plus_erf)?.mul_scalar(0.5)
}
}
}
impl Module for GELU {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.forward(input)
}
fn parameters(&self) -> HashMap<String, Parameter> {
HashMap::new()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
HashMap::new()
}
fn train(&mut self) {}
fn eval(&mut self) {}
}
pub struct Mish;
impl Mish {
pub fn new() -> Self {
Self
}
pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
let softplus = input.exp()?.add_scalar(1.0)?.log()?;
let tanh_softplus = softplus.tanh()?;
input.mul(&tanh_softplus)
}
}
impl Module for Mish {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.forward(input)
}
fn parameters(&self) -> HashMap<String, Parameter> {
HashMap::new()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
HashMap::new()
}
fn train(&mut self) {}
fn eval(&mut self) {}
}
pub struct MemoryEfficientSequential {
modules: Vec<Box<dyn Module>>,
checkpointing: bool,
}
impl MemoryEfficientSequential {
pub fn new(checkpointing: bool) -> Self {
Self {
modules: Vec::new(),
checkpointing,
}
}
pub fn add_module(&mut self, module: Box<dyn Module>) {
self.modules.push(module);
}
pub fn forward(&self, mut input: Tensor) -> Result<Tensor> {
if self.checkpointing {
for module in &self.modules {
input = module.forward(&input)?;
}
} else {
for module in &self.modules {
input = module.forward(&input)?;
}
}
Ok(input)
}
}
impl Module for MemoryEfficientSequential {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.forward(input.clone())
}
fn parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (i, module) in self.modules.iter().enumerate() {
for (name, param) in module.parameters() {
params.insert(format!("{}.{}", i, name), param);
}
}
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (i, module) in self.modules.iter().enumerate() {
for (name, param) in module.named_parameters() {
params.insert(format!("{}.{}", i, name), param);
}
}
params
}
fn train(&mut self) {
for module in &mut self.modules {
module.train();
}
}
fn eval(&mut self) {
for module in &mut self.modules {
module.eval();
}
}
}