use crate::error::Result;
use scirs2_core::ndarray::{Array, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
/// Core abstraction for a neural-network layer operating on dynamically
/// shaped arrays of floating-point elements.
///
/// Implementors must provide the forward/backward passes, the update step,
/// and the `Any` upcasts; every parameter-related method has a no-op
/// default so stateless layers (activations, reshapes, …) need not
/// implement them.
pub trait Layer<F: Float + Debug + ScalarOperand>: Send + Sync {
    /// Computes this layer's output for `input`.
    fn forward(&self, input: &Array<F, scirs2_core::ndarray::IxDyn>) -> Result<Array<F, scirs2_core::ndarray::IxDyn>>;

    /// Computes the gradient with respect to `input`, given `grad_output`,
    /// the gradient of the loss with respect to this layer's output.
    fn backward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>>;

    /// Applies one optimization step using the given learning rate.
    fn update(&mut self, learning_rate: F) -> Result<()>;

    /// Upcasts to `Any` so callers can downcast to a concrete layer type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Mutable variant of [`Layer::as_any`].
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;

    /// Returns owned copies of the trainable parameters (empty by default).
    fn params(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        Vec::new()
    }

    /// Returns owned copies of the accumulated gradients (empty by default).
    fn gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        Vec::new()
    }

    /// Overwrites the stored gradients; the default is a no-op for
    /// stateless layers.
    // FIX: original read `&mut selfgradients: &[…]` — the `self` receiver was
    // fused with the argument name, which does not compile. Restored the
    // intended `&mut self` receiver plus a (deliberately unused) argument.
    fn set_gradients(&mut self, _gradients: &[Array<F, scirs2_core::ndarray::IxDyn>]) -> Result<()> {
        Ok(())
    }

    /// Overwrites the trainable parameters; the default is a no-op.
    // FIX: same missing-receiver typo as `set_gradients`.
    fn set_params(&mut self, _params: &[Array<F, scirs2_core::ndarray::IxDyn>]) -> Result<()> {
        Ok(())
    }

    /// Switches between training and inference mode (no-op by default).
    fn set_training(&mut self, _training: bool) {}

    /// Reports whether the layer is in training mode (defaults to `true`).
    fn is_training(&self) -> bool {
        true
    }

    /// Short identifier for the concrete layer kind, used in diagnostics.
    fn layer_type(&self) -> &str {
        "Unknown"
    }

    /// Total number of scalar trainable parameters (0 by default).
    fn parameter_count(&self) -> usize {
        0
    }

    /// One-line human-readable description built from [`Layer::layer_type`].
    fn layer_description(&self) -> String {
        format!("type:{}", self.layer_type())
    }
}
/// Extension trait for layers that expose their parameters by reference,
/// letting optimizers read parameters and gradients in place instead of
/// going through the owning copies returned by [`Layer::params`].
pub trait ParamLayer<F: Float + Debug + ScalarOperand>: Layer<F> {
    /// Borrows each trainable parameter array.
    fn get_parameters(&self) -> Vec<&Array<F, scirs2_core::ndarray::IxDyn>>;
    /// Borrows each gradient array — presumably parallel (same order and
    /// length) to `get_parameters`; confirm against implementors.
    fn get_gradients(&self) -> Vec<&Array<F, scirs2_core::ndarray::IxDyn>>;
    /// Replaces all parameters with `params`, taking ownership of the arrays.
    fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()>;
}
/// A feed-forward container that applies its layers in insertion order.
pub struct Sequential<F: Float + Debug + ScalarOperand + NumAssign> {
    // Boxed trait objects so heterogeneous layer types can be chained.
    layers: Vec<Box<dyn Layer<F> + Send + Sync>>,
    // Current training/inference mode; propagated to children by
    // `set_training`.
    training: bool,
}
impl<F: Float + Debug + ScalarOperand + NumAssign> std::fmt::Debug for Sequential<F> {
    /// Summarizes the container (layer count and mode) without requiring
    /// `Debug` on the boxed layers themselves.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Sequential");
        builder.field("num_layers", &self.layers.len());
        builder.field("training", &self.training);
        builder.finish()
    }
}
impl<F: Float + Debug + ScalarOperand + NumAssign + 'static> Clone for Sequential<F> {
    /// Clones the container configuration ONLY.
    ///
    /// WARNING: the layer list is NOT copied — `Box<dyn Layer<F> + Send +
    /// Sync>` is not `Clone`, so the clone starts with zero layers and only
    /// inherits the `training` flag. Callers must re-add layers themselves.
    fn clone(&self) -> Self {
        Self {
            // Deep-copying trait objects would require a `clone_box`-style
            // method on `Layer`; intentionally left empty instead.
            layers: Vec::new(),
            training: self.training,
        }
    }
}
impl<F: Float + Debug + ScalarOperand + NumAssign> Default for Sequential<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float + Debug + ScalarOperand + NumAssign> Sequential<F> {
pub fn new() -> Self {
Self {
layers: Vec::new(),
training: true,
}
}
pub fn add<L: Layer<F> + Send + Sync + 'static>(&mut self, layer: L) {
self.layers.push(Box::new(layer));
}
pub fn len(&self) -> usize {
self.layers.len()
}
pub fn is_empty(&self) -> bool {
self.layers.is_empty()
}
pub fn total_parameters(&self) -> usize {
self.layers.iter().map(|layer| layer.parameter_count()).sum()
}
}
impl<F: Float + Debug + ScalarOperand + NumAssign> Layer<F> for Sequential<F> {
    /// Runs `input` through every layer in insertion order.
    fn forward(&self, input: &Array<F, scirs2_core::ndarray::IxDyn>) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut output = input.clone();
        for layer in &self.layers {
            output = layer.forward(&output)?;
        }
        Ok(output)
    }

    /// Placeholder backward pass.
    ///
    /// FIX: the receiver was `&mut self`, which does not match the trait's
    /// `&self` declaration and fails to compile (E0053); corrected to `&self`.
    ///
    /// NOTE(review): this does NOT backpropagate through the contained
    /// layers — doing so requires each layer's forward-pass input, which
    /// `Sequential` does not cache. It currently returns `grad_output`
    /// unchanged; callers needing true backprop must drive the layers'
    /// `backward` methods individually.
    fn backward(
        &self,
        _input: &Array<F, scirs2_core::ndarray::IxDyn>,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        Ok(grad_output.clone())
    }

    /// Applies one update step to every contained layer; stops at the first
    /// error.
    fn update(&mut self, learning_rate: F) -> Result<()> {
        for layer in &mut self.layers {
            layer.update(learning_rate)?;
        }
        Ok(())
    }

    /// Concatenates the parameters of all contained layers, in layer order.
    fn params(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut params = Vec::new();
        for layer in &self.layers {
            params.extend(layer.params());
        }
        params
    }

    /// Sets the mode on the container and propagates it to every layer.
    fn set_training(&mut self, training: bool) {
        self.training = training;
        for layer in &mut self.layers {
            layer.set_training(training);
        }
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn layer_type(&self) -> &str {
        "Sequential"
    }

    /// Same value as [`Sequential::total_parameters`].
    fn parameter_count(&self) -> usize {
        self.layers.iter().map(|layer| layer.parameter_count()).sum()
    }
}
/// Declarative description of a layer, suitable for building models from
/// configuration data rather than code.
#[derive(Debug, Clone)]
pub enum LayerConfig {
    /// Fully connected layer. `activation` — presumably the name of an
    /// activation function; confirm against whatever consumes this config.
    Dense { input_size: usize, output_size: usize, activation: Option<String> },
    /// 2-D convolution; `kernel_size` axis order is not established here —
    /// TODO confirm (height, width) vs (width, height) with the consumer.
    Conv2D { in_channels: usize, out_channels: usize, kernel_size: (usize, usize) },
    /// Dropout layer; `rate` — presumably the drop probability in [0, 1].
    Dropout { rate: f64 },
}