use crate::error::Result;
use scirs2_core::ndarray::{Array, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
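
/// Base trait implemented by all neural network layers.
///
/// A layer maps an input tensor to an output tensor in [`Layer::forward`],
/// propagates gradients in [`Layer::backward`], and adjusts its parameters in
/// [`Layer::update`]. Containers such as [`Sequential`] drive layers purely
/// through this trait.
///
/// A minimal sketch of a parameter-free identity layer, showing only the
/// required methods (everything else keeps its default implementation); it
/// assumes this crate's `Result` alias from `crate::error`:
///
/// ```ignore
/// use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
/// use scirs2_core::numeric::{Float, NumAssign};
/// use std::fmt::Debug;
///
/// struct Identity;
///
/// impl<F: Float + Debug + ScalarOperand + NumAssign> Layer<F> for Identity {
///     fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
///         Ok(input.clone())
///     }
///
///     fn backward(
///         &self,
///         _input: &Array<F, IxDyn>,
///         grad_output: &Array<F, IxDyn>,
///     ) -> Result<Array<F, IxDyn>> {
///         // The identity layer passes gradients through unchanged.
///         Ok(grad_output.clone())
///     }
///
///     fn update(&mut self, _learning_rate: F) -> Result<()> {
///         // No parameters to update.
///         Ok(())
///     }
///
///     fn as_any(&self) -> &dyn std::any::Any {
///         self
///     }
///
///     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
///         self
///     }
/// }
/// ```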
pub trait Layer<F: Float + Debug + ScalarOperand + NumAssign>: Send + Sync {
    /// Compute the layer output for the given input.
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>>;

    /// Compute the gradient with respect to the input, given the original
    /// input and the gradient flowing back from the following layer.
    fn backward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>>;

    /// Update the layer parameters using the given learning rate.
    fn update(&mut self, learning_rate: F) -> Result<()>;

    /// Return the layer as `Any` for downcasting to a concrete type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Return the layer as mutable `Any` for downcasting to a concrete type.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;

    /// Return the layer's parameter tensors (empty for parameter-free layers).
    fn params(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        Vec::new()
    }

    /// Return the gradients of the layer's parameters (empty for parameter-free layers).
    fn gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        Vec::new()
    }

    /// Set the gradients of the layer's parameters.
    fn set_gradients(
        &mut self,
        _gradients: &[Array<F, scirs2_core::ndarray::IxDyn>],
    ) -> Result<()> {
        Ok(())
    }

    /// Set the layer's parameters.
    fn set_params(&mut self, _params: &[Array<F, scirs2_core::ndarray::IxDyn>]) -> Result<()> {
        Ok(())
    }

    /// Switch the layer between training and inference mode.
    fn set_training(&mut self, _training: bool) {}

    /// Return whether the layer is currently in training mode.
    fn is_training(&self) -> bool {
        true
    }

    /// Return a short identifier for the layer type.
    fn layer_type(&self) -> &str {
        "Unknown"
    }

    /// Return the number of trainable parameters in the layer.
    fn parameter_count(&self) -> usize {
        0
    }

    /// Return a human-readable description of the layer.
    fn layer_description(&self) -> String {
        format!("type:{}", self.layer_type())
    }

    /// Return the expected input shape, if known.
    fn inputshape(&self) -> Option<Vec<usize>> {
        None
    }

    /// Return the produced output shape, if known.
    fn outputshape(&self) -> Option<Vec<usize>> {
        None
    }

    /// Return the layer's name, if one was assigned.
    fn name(&self) -> Option<&str> {
        None
    }
}
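
/// Trait for layers that expose their parameters and gradients directly, for
/// example so that optimizers can read and write parameter tensors in bulk.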
pub trait ParamLayer<F: Float + Debug + ScalarOperand + NumAssign>: Layer<F> {
    /// Return the layer's parameter tensors.
    fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>>;

    /// Return the gradients of the layer's parameter tensors.
    fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>>;

    /// Replace the layer's parameter tensors.
    fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()>;
}
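
/// Summary of a single layer as reported by [`Sequential::layer_info`].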
#[derive(Debug, Clone)]
pub struct LayerInfo {
    pub index: usize,
    pub name: String,
    pub layer_type: String,
    pub parameter_count: usize,
    pub inputshape: Option<Vec<usize>>,
    pub outputshape: Option<Vec<usize>>,
}
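
/// A container that applies its layers sequentially, in the order they were added.
///
/// Two behavioral details are visible in the implementations below: `clone`
/// yields an empty container (boxed layers cannot be cloned), and `backward`
/// simply passes the incoming gradient through.
///
/// A minimal usage sketch; the `Dense::new` calls are illustrative only, since
/// the actual constructor signature is defined in the `dense` module:
///
/// ```ignore
/// let mut model: Sequential<f64> = Sequential::new();
/// model.add(Dense::new(/* input */ 784, /* output */ 128)?);
/// model.add(Dense::new(128, 10)?);
///
/// let logits = model.forward(&batch)?;
/// println!("{} layers, {} parameters", model.len(), model.total_parameters());
/// ```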
pub struct Sequential<F: Float + Debug + ScalarOperand + NumAssign> {
    layers: Vec<Box<dyn Layer<F> + Send + Sync>>,
    training: bool,
}

impl<F: Float + Debug + ScalarOperand + NumAssign> std::fmt::Debug for Sequential<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sequential")
            .field("num_layers", &self.layers.len())
            .field("training", &self.training)
            .finish()
    }
}

impl<F: Float + Debug + ScalarOperand + NumAssign + 'static> Clone for Sequential<F> {
    fn clone(&self) -> Self {
        // Boxed `dyn Layer` trait objects cannot be cloned, so the clone starts
        // with no layers and only preserves the training flag.
        Self {
            layers: Vec::new(),
            training: self.training,
        }
    }
}

impl<F: Float + Debug + ScalarOperand + NumAssign> Default for Sequential<F> {
    fn default() -> Self {
        Self::new()
    }
}

impl<F: Float + Debug + ScalarOperand + NumAssign> Sequential<F> {
    /// Create an empty model in training mode.
    pub fn new() -> Self {
        Self {
            layers: Vec::new(),
            training: true,
        }
    }

    /// Append a layer to the end of the model.
    pub fn add<L: Layer<F> + Send + Sync + 'static>(&mut self, layer: L) {
        self.layers.push(Box::new(layer));
    }

    /// Return the number of layers in the model.
    pub fn len(&self) -> usize {
        self.layers.len()
    }

    /// Return `true` if the model contains no layers.
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }

    /// Return the total number of trainable parameters across all layers.
    pub fn total_parameters(&self) -> usize {
        self.layers
            .iter()
            .map(|layer| layer.parameter_count())
            .sum()
    }

    /// Return a per-layer summary of the model architecture.
    pub fn layer_info(&self) -> Vec<LayerInfo> {
        self.layers
            .iter()
            .enumerate()
            .map(|(i, layer)| LayerInfo {
                index: i,
                name: layer
                    .name()
                    .map(String::from)
                    .unwrap_or_else(|| format!("Layer_{i}")),
                layer_type: layer.layer_type().to_string(),
                parameter_count: layer.parameter_count(),
                inputshape: layer.inputshape(),
                outputshape: layer.outputshape(),
            })
            .collect()
    }
}

impl<F: Float + Debug + ScalarOperand + NumAssign> Layer<F> for Sequential<F> {
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut output = input.clone();
        for layer in &self.layers {
            output = layer.forward(&output)?;
        }
        Ok(output)
    }

    fn backward(
        &self,
        _input: &Array<F, scirs2_core::ndarray::IxDyn>,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        // Pass-through: the container does not backpropagate through its
        // layers here; gradient handling is left to the individual layers.
        Ok(grad_output.clone())
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        for layer in &mut self.layers {
            layer.update(learning_rate)?;
        }
        Ok(())
    }

    fn params(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut params = Vec::new();
        for layer in &self.layers {
            params.extend(layer.params());
        }
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training = training;
        for layer in &mut self.layers {
            layer.set_training(training);
        }
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn layer_type(&self) -> &str {
        "Sequential"
    }

    fn parameter_count(&self) -> usize {
        self.layers
            .iter()
            .map(|layer| layer.parameter_count())
            .sum()
    }
}
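
/// Declarative description of a layer; the variants mirror common layer
/// constructors and can be used to build models from configuration data.
///
/// ```ignore
/// let dense = LayerConfig::Dense {
///     input_size: 784,
///     output_size: 128,
///     activation: Some("relu".to_string()),
/// };
/// let dropout = LayerConfig::Dropout { rate: 0.5 };
/// ```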
#[derive(Debug, Clone)]
pub enum LayerConfig {
    Dense {
        input_size: usize,
        output_size: usize,
        activation: Option<String>,
    },
    Conv2D {
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
    },
    Dropout { rate: f64 },
}
pub mod conv;
pub mod dense;
pub mod dropout;
pub mod graph_conv;
pub mod normalization;
pub mod recurrent;
mod attention;
mod embedding;
mod flash_attention;
mod flash_attention_v2;
mod grouped_query_attention;
mod multi_query_attention;
mod regularization;
pub mod rnn_thread_safe;
pub use attention::{AttentionConfig, AttentionMask, MultiHeadAttention, SelfAttention};
pub use conv::{AvgPool2D, Conv2D, GlobalAvgPool2D, MaxPool2D};
pub use dense::Dense;
pub use dropout::Dropout;
pub use embedding::{Embedding, EmbeddingConfig, PositionalEmbedding};
pub use flash_attention::{flash_attention_compute, FlashAttention, FlashAttentionConfig};
pub use flash_attention_v2::{
    flash_attention_v2_compute, FlashAttentionV2, FlashAttentionV2Config,
};
pub use graph_conv::{
    GraphActivation, GraphAttentionLayer, GraphConvLayer, GraphSageLayer, SageAggregator,
};
pub use grouped_query_attention::{
    GqaKvCache, GroupedQueryAttention, GroupedQueryAttentionConfig, RotaryPositionEmbedding,
};
pub use multi_query_attention::{KvCache, MultiQueryAttention, MultiQueryAttentionConfig};
pub use normalization::{BatchNorm, LayerNorm};
pub use recurrent::rnn::{RNNConfig, RecurrentActivation as RecurrentActivationRNN};
pub use recurrent::{LSTM, RNN};
pub use regularization::{
    ActivityRegularization, L1ActivityRegularization, L2ActivityRegularization,
};
pub use rnn_thread_safe::{
    RecurrentActivation as ThreadSafeRecurrentActivation, ThreadSafeBidirectional, ThreadSafeRNN,
};