use ::ndarray::{Array, Ix1};
use rand::{Rng, RngExt};
use crate::array_protocol::ml_ops::ActivationFunc;
use crate::array_protocol::operations::OperationError;
use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};
pub trait Layer: Send + Sync {
fn layer_type(&self) -> &str;
fn forward(&self, inputs: &dyn ArrayProtocol)
-> Result<Box<dyn ArrayProtocol>, OperationError>;
fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>>;
fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>>;
fn update_parameter(
&mut self,
name: &str,
value: Box<dyn ArrayProtocol>,
) -> Result<(), OperationError>;
fn parameter_names(&self) -> Vec<String>;
fn train(&mut self);
fn eval(&mut self);
fn is_training(&self) -> bool;
fn name(&self) -> &str;
}
pub struct Linear {
name: String,
weights: Box<dyn ArrayProtocol>,
bias: Option<Box<dyn ArrayProtocol>>,
activation: Option<ActivationFunc>,
training: bool,
}
impl Linear {
pub fn new(
name: &str,
weights: Box<dyn ArrayProtocol>,
bias: Option<Box<dyn ArrayProtocol>>,
activation: Option<ActivationFunc>,
) -> Self {
Self {
name: name.to_string(),
weights,
bias,
activation,
training: true,
}
}
pub fn new_random(
name: &str,
in_features: usize,
out_features: usize,
withbias: bool,
activation: Option<ActivationFunc>,
) -> Self {
let scale = (6.0 / (in_features + out_features) as f64).sqrt();
let mut rng = rand::rng();
let weights = Array::from_shape_fn((out_features, in_features), |_| {
(rng.random::<f64>() * 2.0_f64 - 1.0) * scale
});
let bias = if withbias {
let bias_array: Array<f64, Ix1> = Array::zeros(out_features);
Some(Box::new(NdarrayWrapper::new(bias_array)) as Box<dyn ArrayProtocol>)
} else {
None
};
Self {
name: name.to_string(),
weights: Box::new(NdarrayWrapper::new(weights)),
bias,
activation,
training: true,
}
}
}
impl Layer for Linear {
fn layer_type(&self) -> &str {
"Linear"
}
fn forward(
&self,
inputs: &dyn ArrayProtocol,
) -> Result<Box<dyn ArrayProtocol>, OperationError> {
let mut result = crate::array_protocol::matmul(self.weights.as_ref(), inputs)?;
if let Some(bias) = &self.bias {
let intermediate = crate::array_protocol::add(result.as_ref(), bias.as_ref())?;
result = intermediate;
}
if let Some(act_fn) = self.activation {
let intermediate = crate::array_protocol::ml_ops::activation(result.as_ref(), act_fn)?;
result = intermediate;
}
Ok(result)
}
fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
let mut params = vec![self.weights.clone()];
if let Some(bias) = &self.bias {
params.push(bias.clone());
}
params
}
fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
let mut params = vec![&mut self.weights];
if let Some(bias) = &mut self.bias {
params.push(bias);
}
params
}
fn update_parameter(
&mut self,
name: &str,
value: Box<dyn ArrayProtocol>,
) -> Result<(), OperationError> {
match name {
"weights" => {
self.weights = value;
Ok(())
}
"bias" => {
self.bias = Some(value);
Ok(())
}
_ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
}
}
fn parameter_names(&self) -> Vec<String> {
let mut names = vec!["weights".to_string()];
if self.bias.is_some() {
names.push("bias".to_string());
}
names
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn is_training(&self) -> bool {
self.training
}
fn name(&self) -> &str {
&self.name
}
}
pub struct Conv2D {
name: String,
filters: Box<dyn ArrayProtocol>,
bias: Option<Box<dyn ArrayProtocol>>,
stride: (usize, usize),
padding: (usize, usize),
activation: Option<ActivationFunc>,
training: bool,
}
impl Conv2D {
pub fn new(
name: &str,
filters: Box<dyn ArrayProtocol>,
bias: Option<Box<dyn ArrayProtocol>>,
stride: (usize, usize),
padding: (usize, usize),
activation: Option<ActivationFunc>,
) -> Self {
Self {
name: name.to_string(),
filters,
bias,
stride,
padding,
activation,
training: true,
}
}
#[allow(clippy::too_many_arguments)]
pub fn withshape(
name: &str,
filter_height: usize,
filter_width: usize,
in_channels: usize,
out_channels: usize,
stride: (usize, usize),
padding: (usize, usize),
withbias: bool,
activation: Option<ActivationFunc>,
) -> Self {
let fan_in = filter_height * filter_width * in_channels;
let scale = (2.0 / fan_in as f64).sqrt();
let mut rng = rand::rng();
let filters = Array::from_shape_fn(
(filter_height, filter_width, in_channels, out_channels),
|_| (rng.random::<f64>() * 2.0_f64 - 1.0) * scale,
);
let bias = if withbias {
let bias_array: Array<f64, Ix1> = Array::zeros(out_channels);
Some(Box::new(NdarrayWrapper::new(bias_array)) as Box<dyn ArrayProtocol>)
} else {
None
};
Self {
name: name.to_string(),
filters: Box::new(NdarrayWrapper::new(filters)),
bias,
stride,
padding,
activation,
training: true,
}
}
}
impl Layer for Conv2D {
fn layer_type(&self) -> &str {
"Conv2D"
}
fn forward(
&self,
inputs: &dyn ArrayProtocol,
) -> Result<Box<dyn ArrayProtocol>, OperationError> {
let mut result = crate::array_protocol::ml_ops::conv2d(
inputs,
self.filters.as_ref(),
self.stride,
self.padding,
)?;
if let Some(bias) = &self.bias {
result = crate::array_protocol::add(result.as_ref(), bias.as_ref())?;
}
if let Some(act_fn) = self.activation {
result = crate::array_protocol::ml_ops::activation(result.as_ref(), act_fn)?;
}
Ok(result)
}
fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
let mut params = vec![self.filters.clone()];
if let Some(bias) = &self.bias {
params.push(bias.clone());
}
params
}
fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
let mut params = vec![&mut self.filters];
if let Some(bias) = &mut self.bias {
params.push(bias);
}
params
}
fn update_parameter(
&mut self,
name: &str,
value: Box<dyn ArrayProtocol>,
) -> Result<(), OperationError> {
match name {
"filters" => {
self.filters = value;
Ok(())
}
"bias" => {
self.bias = Some(value);
Ok(())
}
_ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
}
}
fn parameter_names(&self) -> Vec<String> {
let mut names = vec!["filters".to_string()];
if self.bias.is_some() {
names.push("bias".to_string());
}
names
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn is_training(&self) -> bool {
self.training
}
fn name(&self) -> &str {
&self.name
}
}
pub struct Conv2DBuilder {
name: String,
filter_height: usize,
filter_width: usize,
in_channels: usize,
out_channels: usize,
stride: (usize, usize),
padding: (usize, usize),
withbias: bool,
activation: Option<ActivationFunc>,
}
impl Conv2DBuilder {
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
filter_height: 3,
filter_width: 3,
in_channels: 1,
out_channels: 1,
stride: (1, 1),
padding: (0, 0),
withbias: true,
activation: None,
}
}
pub const fn filter_size(mut self, height: usize, width: usize) -> Self {
self.filter_height = height;
self.filter_width = width;
self
}
pub const fn channels(mut self, input: usize, output: usize) -> Self {
self.in_channels = input;
self.out_channels = output;
self
}
pub fn stride(mut self, stride: (usize, usize)) -> Self {
self.stride = stride;
self
}
pub fn padding(mut self, padding: (usize, usize)) -> Self {
self.padding = padding;
self
}
pub fn withbias(mut self, withbias: bool) -> Self {
self.withbias = withbias;
self
}
pub fn activation(mut self, activation: ActivationFunc) -> Self {
self.activation = Some(activation);
self
}
pub fn build(self) -> Conv2D {
Conv2D::withshape(
&self.name,
self.filter_height,
self.filter_width,
self.in_channels,
self.out_channels,
self.stride,
self.padding,
self.withbias,
self.activation,
)
}
}
#[allow(dead_code)]
pub struct MaxPool2D {
name: String,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
training: bool,
}
impl MaxPool2D {
pub fn new(
name: &str,
kernel_size: (usize, usize),
stride: Option<(usize, usize)>,
padding: (usize, usize),
) -> Self {
let stride = stride.unwrap_or(kernel_size);
Self {
name: name.to_string(),
kernel_size,
stride,
padding,
training: true,
}
}
}
impl Layer for MaxPool2D {
fn layer_type(&self) -> &str {
"MaxPool2D"
}
fn forward(
&self,
inputs: &dyn ArrayProtocol,
) -> Result<Box<dyn ArrayProtocol>, OperationError> {
crate::array_protocol::ml_ops::max_pool2d(
inputs,
self.kernel_size,
self.stride,
self.padding,
)
}
fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
Vec::new()
}
fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
Vec::new()
}
fn update_parameter(
&mut self,
name: &str,
_value: Box<dyn ArrayProtocol>,
) -> Result<(), OperationError> {
Err(OperationError::Other(format!(
"MaxPool2D has no parameter: {name}"
)))
}
fn parameter_names(&self) -> Vec<String> {
Vec::new()
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn is_training(&self) -> bool {
self.training
}
fn name(&self) -> &str {
&self.name
}
}
pub struct BatchNorm {
name: String,
scale: Box<dyn ArrayProtocol>,
offset: Box<dyn ArrayProtocol>,
running_mean: Box<dyn ArrayProtocol>,
running_var: Box<dyn ArrayProtocol>,
epsilon: f64,
training: bool,
}
impl BatchNorm {
pub fn new(
name: &str,
scale: Box<dyn ArrayProtocol>,
offset: Box<dyn ArrayProtocol>,
running_mean: Box<dyn ArrayProtocol>,
running_var: Box<dyn ArrayProtocol>,
epsilon: f64,
) -> Self {
Self {
name: name.to_string(),
scale,
offset,
running_mean,
running_var,
epsilon,
training: true,
}
}
pub fn withshape(
name: &str,
num_features: usize,
epsilon: Option<f64>,
_momentum: Option<f64>,
) -> Self {
let scale: Array<f64, Ix1> = Array::ones(num_features);
let offset: Array<f64, Ix1> = Array::zeros(num_features);
let running_mean: Array<f64, Ix1> = Array::zeros(num_features);
let running_var: Array<f64, Ix1> = Array::ones(num_features);
Self {
name: name.to_string(),
scale: Box::new(NdarrayWrapper::new(scale)),
offset: Box::new(NdarrayWrapper::new(offset)),
running_mean: Box::new(NdarrayWrapper::new(running_mean)),
running_var: Box::new(NdarrayWrapper::new(running_var)),
epsilon: epsilon.unwrap_or(1e-5),
training: true,
}
}
}
impl Layer for BatchNorm {
fn layer_type(&self) -> &str {
"BatchNorm"
}
fn forward(
&self,
inputs: &dyn ArrayProtocol,
) -> Result<Box<dyn ArrayProtocol>, OperationError> {
crate::array_protocol::ml_ops::batch_norm(
inputs,
self.scale.as_ref(),
self.offset.as_ref(),
self.running_mean.as_ref(),
self.running_var.as_ref(),
self.epsilon,
)
}
fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
vec![self.scale.clone(), self.offset.clone()]
}
fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
vec![&mut self.scale, &mut self.offset]
}
fn update_parameter(
&mut self,
name: &str,
value: Box<dyn ArrayProtocol>,
) -> Result<(), OperationError> {
match name {
"scale" => {
self.scale = value;
Ok(())
}
"offset" => {
self.offset = value;
Ok(())
}
_ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
}
}
fn parameter_names(&self) -> Vec<String> {
vec!["scale".to_string(), "offset".to_string()]
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn is_training(&self) -> bool {
self.training
}
fn name(&self) -> &str {
&self.name
}
}
pub struct Dropout {
name: String,
rate: f64,
seed: Option<u64>,
training: bool,
}
impl Dropout {
pub fn new(name: &str, rate: f64, seed: Option<u64>) -> Self {
Self {
name: name.to_string(),
rate,
seed,
training: true,
}
}
}
impl Layer for Dropout {
fn layer_type(&self) -> &str {
"Dropout"
}
fn forward(
&self,
inputs: &dyn ArrayProtocol,
) -> Result<Box<dyn ArrayProtocol>, OperationError> {
crate::array_protocol::ml_ops::dropout(inputs, self.rate, self.training, self.seed)
}
fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
Vec::new()
}
fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
Vec::new()
}
fn update_parameter(
&mut self,
name: &str,
_value: Box<dyn ArrayProtocol>,
) -> Result<(), OperationError> {
Err(OperationError::Other(format!(
"Dropout has no parameter: {name}"
)))
}
fn parameter_names(&self) -> Vec<String> {
Vec::new()
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn is_training(&self) -> bool {
self.training
}
fn name(&self) -> &str {
&self.name
}
}
pub struct MultiHeadAttention {
name: String,
wq: Box<dyn ArrayProtocol>,
wk: Box<dyn ArrayProtocol>,
wv: Box<dyn ArrayProtocol>,
wo: Box<dyn ArrayProtocol>,
num_heads: usize,
dmodel: usize,
training: bool,
}
impl MultiHeadAttention {
pub fn new(
name: &str,
wq: Box<dyn ArrayProtocol>,
wk: Box<dyn ArrayProtocol>,
wv: Box<dyn ArrayProtocol>,
wo: Box<dyn ArrayProtocol>,
num_heads: usize,
dmodel: usize,
) -> Self {
Self {
name: name.to_string(),
wq,
wk,
wv,
wo,
num_heads,
dmodel,
training: true,
}
}
pub fn with_params(name: &str, num_heads: usize, dmodel: usize) -> Self {
assert!(
dmodel % num_heads == 0,
"dmodel must be divisible by num_heads"
);
let scale = (1.0_f64 / dmodel as f64).sqrt();
let mut rng = rand::rng();
let wq = Array::from_shape_fn((dmodel, dmodel), |_| {
(rng.random::<f64>() * 2.0_f64 - 1.0) * scale
});
let wk = Array::from_shape_fn((dmodel, dmodel), |_| {
(rng.random::<f64>() * 2.0_f64 - 1.0) * scale
});
let wv = Array::from_shape_fn((dmodel, dmodel), |_| {
(rng.random::<f64>() * 2.0_f64 - 1.0) * scale
});
let wo = Array::from_shape_fn((dmodel, dmodel), |_| {
(rng.random::<f64>() * 2.0_f64 - 1.0) * scale
});
Self {
name: name.to_string(),
wq: Box::new(NdarrayWrapper::new(wq)),
wk: Box::new(NdarrayWrapper::new(wk)),
wv: Box::new(NdarrayWrapper::new(wv)),
wo: Box::new(NdarrayWrapper::new(wo)),
num_heads,
dmodel,
training: true,
}
}
}
impl Layer for MultiHeadAttention {
fn layer_type(&self) -> &str {
"MultiHeadAttention"
}
fn forward(
&self,
inputs: &dyn ArrayProtocol,
) -> Result<Box<dyn ArrayProtocol>, OperationError> {
let queries = crate::array_protocol::matmul(self.wq.as_ref(), inputs)?;
let keys = crate::array_protocol::matmul(self.wk.as_ref(), inputs)?;
let values = crate::array_protocol::matmul(self.wv.as_ref(), inputs)?;
let attention = crate::array_protocol::ml_ops::self_attention(
queries.as_ref(),
keys.as_ref(),
values.as_ref(),
None,
Some((self.dmodel / self.num_heads) as f64),
)?;
let output = crate::array_protocol::matmul(self.wo.as_ref(), attention.as_ref())?;
Ok(output)
}
fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
vec![
self.wq.clone(),
self.wk.clone(),
self.wv.clone(),
self.wo.clone(),
]
}
fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
vec![&mut self.wq, &mut self.wk, &mut self.wv, &mut self.wo]
}
fn update_parameter(
&mut self,
name: &str,
value: Box<dyn ArrayProtocol>,
) -> Result<(), OperationError> {
match name {
"wq" => {
self.wq = value;
Ok(())
}
"wk" => {
self.wk = value;
Ok(())
}
"wv" => {
self.wv = value;
Ok(())
}
"wo" => {
self.wo = value;
Ok(())
}
_ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
}
}
fn parameter_names(&self) -> Vec<String> {
vec![
"wq".to_string(),
"wk".to_string(),
"wv".to_string(),
"wo".to_string(),
]
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn is_training(&self) -> bool {
self.training
}
fn name(&self) -> &str {
&self.name
}
}
pub struct Sequential {
name: String,
layers: Vec<Box<dyn Layer>>,
training: bool,
}
impl Sequential {
pub fn new(name: &str, layers: Vec<Box<dyn Layer>>) -> Self {
Self {
name: name.to_string(),
layers,
training: true,
}
}
pub fn add_layer(&mut self, layer: Box<dyn Layer>) {
self.layers.push(layer);
}
pub fn forward(
&self,
inputs: &dyn ArrayProtocol,
) -> Result<Box<dyn ArrayProtocol>, OperationError> {
let mut x: Box<dyn ArrayProtocol> = inputs.box_clone();
for layer in &self.layers {
let x_ref: &dyn ArrayProtocol = x.as_ref();
x = layer.forward(x_ref)?;
}
Ok(x)
}
pub fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
let mut params = Vec::new();
for layer in &self.layers {
params.extend(layer.parameters());
}
params
}
pub fn train(&mut self) {
self.training = true;
for layer in &mut self.layers {
layer.train();
}
}
pub fn eval(&mut self) {
self.training = false;
for layer in &mut self.layers {
layer.eval();
}
}
pub fn name(&self) -> &str {
&self.name
}
pub fn layers(&self) -> &[Box<dyn Layer>] {
&self.layers
}
pub fn backward(
&self,
_output: &dyn ArrayProtocol,
_target: &dyn ArrayProtocol,
) -> Result<crate::array_protocol::grad::GradientDict, crate::error::CoreError> {
Ok(crate::array_protocol::grad::GradientDict::new())
}
pub fn update_parameter(
&mut self,
param_name: &str,
gradient: &dyn ArrayProtocol,
learningrate: f64,
) -> Result<(), crate::error::CoreError> {
let parts: Vec<&str> = param_name.split('.').collect();
if parts.len() != 2 {
return Err(crate::error::CoreError::ValueError(
crate::error::ErrorContext::new(format!(
"Invalid parameter name format. Expected 'layer_index.param_name', got: {param_name}"
)),
));
}
let layer_index: usize = parts[0].parse().map_err(|_| {
crate::error::CoreError::ValueError(crate::error::ErrorContext::new(format!(
"Invalid layer index: {layer_idx}",
layer_idx = parts[0]
)))
})?;
let param_name = parts[1];
if layer_index >= self.layers.len() {
return Err(crate::error::CoreError::ValueError(
crate::error::ErrorContext::new(format!(
"Layer index {layer_index} out of bounds (model has {num_layers} layers)",
num_layers = self.layers.len()
)),
));
}
let layer = &mut self.layers[layer_index];
let current_params = layer.parameters();
let param_names = layer.parameter_names();
let param_idx = param_names
.iter()
.position(|name| name == param_name)
.ok_or_else(|| {
crate::error::CoreError::ValueError(crate::error::ErrorContext::new(format!(
"Parameter '{param_name}' not found in layer {layer_index}"
)))
})?;
let current_param = ¤t_params[param_idx];
let scaled_gradient =
crate::array_protocol::operations::multiply_by_scalar_f64(gradient, learningrate)
.map_err(|e| {
crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(
format!("Failed to scale gradient: {e}"),
))
})?;
let updated_param = crate::array_protocol::operations::subtract(
current_param.as_ref(),
scaled_gradient.as_ref(),
)
.map_err(|e| {
crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(format!(
"Failed to update parameter: {e}"
)))
})?;
layer
.update_parameter(param_name, updated_param)
.map_err(|e| {
crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(format!(
"Failed to set parameter in layer: {e}"
)))
})?;
Ok(())
}
pub fn all_parameter_names(&self) -> Vec<String> {
let mut all_names = Vec::new();
for (layer_idx, layer) in self.layers.iter().enumerate() {
let layer_param_names = layer.parameter_names();
for param_name in layer_param_names {
all_names.push(format!("{layer_idx}.{param_name}"));
}
}
all_names
}
pub fn all_parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
let mut all_params = Vec::new();
for layer in &self.layers {
all_params.extend(layer.parameters());
}
all_params
}
}
#[allow(dead_code)]
pub fn create_simple_cnn(inputshape: (usize, usize, usize), num_classes: usize) -> Sequential {
let (height, width, channels) = inputshape;
let mut model = Sequential::new("SimpleCNN", Vec::new());
model.add_layer(Box::new(Conv2D::withshape(
"conv1",
3,
3, channels,
32, (1, 1), (1, 1), true, Some(ActivationFunc::ReLU),
)));
model.add_layer(Box::new(MaxPool2D::new(
"pool1",
(2, 2), None, (0, 0), )));
model.add_layer(Box::new(Conv2D::withshape(
"conv2",
3,
3, 32,
64, (1, 1), (1, 1), true, Some(ActivationFunc::ReLU),
)));
model.add_layer(Box::new(MaxPool2D::new(
"pool2",
(2, 2), None, (0, 0), )));
model.add_layer(Box::new(Linear::new_random(
"fc1",
64 * (height / 4) * (width / 4), 128, true, Some(ActivationFunc::ReLU),
)));
model.add_layer(Box::new(Dropout::new(
"dropout", 0.5, None, )));
model.add_layer(Box::new(Linear::new_random(
"fc2",
128, num_classes, true, None, )));
model
}
#[cfg(test)]
mod tests {
use super::*;
use crate::array_protocol::{self, NdarrayWrapper};
use ndarray::{Array1, Array2};
#[test]
fn test_linear_layer() {
array_protocol::init();
let weights = Array2::<f64>::eye(3);
let bias = Array1::<f64>::ones(3);
let layer = Linear::new(
"linear",
Box::new(NdarrayWrapper::new(weights)),
Some(Box::new(NdarrayWrapper::new(bias))),
Some(ActivationFunc::ReLU),
);
assert_eq!(layer.name(), "linear");
assert!(layer.is_training());
}
#[test]
fn test_sequential_model() {
array_protocol::init();
let mut model = Sequential::new("test_model", Vec::new());
model.add_layer(Box::new(Linear::new_random(
"fc1",
3, 2, true, Some(ActivationFunc::ReLU),
)));
model.add_layer(Box::new(Linear::new_random(
"fc2",
2, 1, true, Some(ActivationFunc::Sigmoid),
)));
assert_eq!(model.name(), "test_model");
assert_eq!(model.layers().len(), 2);
assert!(model.training);
}
#[test]
fn test_simple_cnn_creation() {
array_protocol::init();
let model = create_simple_cnn((28, 28, 1), 10);
assert_eq!(model.layers().len(), 7);
assert_eq!(model.name(), "SimpleCNN");
let params = model.parameters();
assert!(!params.is_empty());
}
}