use crate::device::Device;
use crate::error::{NnlError, Result};
use crate::layers::{Layer, LayerConfig, create_layer};
use crate::losses::LossFunction;
use crate::network::Network;
use crate::optimizers::{OptimizerConfig, create_optimizer};
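/// Fluent builder for constructing [`Network`] instances.
///
/// Layers are added in order via [`LayerConfig`] values; a loss function and
/// an optimizer are required before [`NetworkBuilder::build`] succeeds, while
/// the device, name, and description are optional.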
#[derive(Debug)]
pub struct NetworkBuilder {
layers: Vec<LayerConfig>,
loss_function: Option<LossFunction>,
optimizer: Option<OptimizerConfig>,
device: Option<Device>,
name: Option<String>,
description: Option<String>,
}
impl NetworkBuilder {
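    /// Creates an empty builder with no layers, loss function, optimizer,
    /// or device configured.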
pub fn new() -> Self {
Self {
layers: Vec::new(),
loss_function: None,
optimizer: None,
device: None,
name: None,
description: None,
}
}
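    /// Appends a single layer configuration to the end of the network.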
pub fn add_layer(mut self, layer_config: LayerConfig) -> Self {
self.layers.push(layer_config);
self
}
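    /// Appends multiple layer configurations, preserving their order.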
pub fn add_layers(mut self, layer_configs: Vec<LayerConfig>) -> Self {
self.layers.extend(layer_configs);
self
}
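    /// Sets the loss function. Required before calling [`NetworkBuilder::build`].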
pub fn loss(mut self, loss_function: LossFunction) -> Self {
self.loss_function = Some(loss_function);
self
}
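    /// Sets the optimizer configuration. Required before calling
    /// [`NetworkBuilder::build`].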
pub fn optimizer(mut self, optimizer: OptimizerConfig) -> Self {
self.optimizer = Some(optimizer);
self
}
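    /// Sets the compute device. If omitted, [`Device::auto_select`] is used
    /// at build time.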
pub fn device(mut self, device: Device) -> Self {
self.device = Some(device);
self
}
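    /// Sets an optional human-readable name. Note that the name is currently
    /// stored on the builder only and is not forwarded to the built [`Network`].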
pub fn name<S: Into<String>>(mut self, name: S) -> Self {
self.name = Some(name.into());
self
}
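    /// Sets an optional description. Like the name, this is currently stored
    /// on the builder only.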
pub fn description<S: Into<String>>(mut self, description: S) -> Self {
self.description = Some(description.into());
self
}
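    /// Consumes the builder and constructs the [`Network`].
    ///
    /// Fails if no layers were added or if the loss function or optimizer is
    /// missing; when no device was set, one is auto-selected.
    ///
    /// # Examples
    ///
    /// A minimal sketch; the `nnl` crate paths are illustrative and should be
    /// adjusted to match the actual crate layout:
    ///
    /// ```ignore
    /// use nnl::builder::NetworkBuilder;
    /// use nnl::layers::LayerConfig;
    /// use nnl::losses::LossFunction;
    /// use nnl::optimizers::OptimizerConfig;
    ///
    /// let network = NetworkBuilder::new()
    ///     .add_layer(LayerConfig::dense_relu(2, 4))
    ///     .add_layer(LayerConfig::dense_sigmoid(4, 1))
    ///     .loss(LossFunction::MeanSquaredError)
    ///     .optimizer(OptimizerConfig::sgd(0.1))
    ///     .build()?;
    /// ```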
pub fn build(self) -> Result<Network> {
if self.layers.is_empty() {
return Err(NnlError::network("Network must have at least one layer"));
}
let loss_function = self
.loss_function
.ok_or_else(|| NnlError::network("Loss function must be specified"))?;
let optimizer_config = self
.optimizer
.ok_or_else(|| NnlError::network("Optimizer must be specified"))?;
        // Fall back to automatic device selection when none was specified.
        let device = match self.device {
            Some(device) => device,
            None => Device::auto_select()?,
        };
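        // Instantiate each configured layer on the selected device.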
let mut layers: Vec<Box<dyn Layer>> = Vec::new();
for layer_config in &self.layers {
let layer = create_layer(layer_config.clone(), device.clone())?;
layers.push(layer);
}
        Self::validate_layer_compatibility(&self.layers, &layers)?;
let optimizer = create_optimizer(optimizer_config.clone())?;
Network::new_with_configs(
layers,
self.layers,
loss_function,
optimizer_config,
optimizer,
device,
)
}
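    /// Validates compatibility between consecutive layers.
    ///
    /// Currently a placeholder: it only verifies that at least one layer was
    /// created. Detailed shape checking is not yet implemented.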
fn validate_layer_compatibility(
_layer_configs: &[LayerConfig],
layers: &[Box<dyn Layer>],
) -> Result<()> {
if layers.is_empty() {
return Err(NnlError::network("No layers created"));
}
Ok(())
}
}
impl Default for NetworkBuilder {
fn default() -> Self {
Self::new()
}
}
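// Convenience constructors for common architectures.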
impl NetworkBuilder {
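    /// Builds a fully-connected feedforward network: ReLU hidden layers
    /// followed by a linear output layer. With no hidden sizes, a single
    /// linear layer maps the input directly to the output. A loss function
    /// and optimizer must still be set before building.
    ///
    /// A minimal sketch of the resulting topology:
    ///
    /// ```ignore
    /// // 784 -> 128 (ReLU) -> 64 (ReLU) -> 10 (linear)
    /// let builder = NetworkBuilder::feedforward(784, vec![128, 64], 10);
    /// ```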
pub fn feedforward(input_size: usize, hidden_sizes: Vec<usize>, output_size: usize) -> Self {
let mut builder = Self::new();
if let Some(&first_hidden) = hidden_sizes.first() {
builder = builder.add_layer(LayerConfig::dense_relu(input_size, first_hidden));
}
for window in hidden_sizes.windows(2) {
builder = builder.add_layer(LayerConfig::dense_relu(window[0], window[1]));
}
if let Some(&last_hidden) = hidden_sizes.last() {
builder = builder.add_layer(LayerConfig::dense_linear(last_hidden, output_size));
} else {
builder = builder.add_layer(LayerConfig::dense_linear(input_size, output_size));
}
builder
}
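    /// Builds a simple CNN classifier: three conv/max-pool blocks (32, 64,
    /// and 128 filters), then flatten, a 512-unit ReLU layer, dropout (0.5),
    /// and a linear output layer. Assumes square inputs with `image_size`
    /// divisible by 8 (three 2x2 pools) and that `conv2d_3x3` preserves
    /// spatial dimensions. A loss function and optimizer must still be set.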
pub fn cnn_classifier(input_channels: usize, num_classes: usize, image_size: usize) -> Self {
let mut builder = Self::new();
builder = builder
.add_layer(LayerConfig::conv2d_3x3(input_channels, 32))
.add_layer(LayerConfig::max_pool2d())
.add_layer(LayerConfig::conv2d_3x3(32, 64))
.add_layer(LayerConfig::max_pool2d())
.add_layer(LayerConfig::conv2d_3x3(64, 128))
.add_layer(LayerConfig::max_pool2d());
        // Three 2x2 max-pools halve the spatial dims, giving a
        // (image_size / 8) x (image_size / 8) feature map with 128 channels;
        // assumes image_size is divisible by 8 and same-padded convolutions.
        let pooled_size = image_size / 8;
        let flattened_size = 128 * pooled_size * pooled_size;
builder = builder
.add_layer(LayerConfig::flatten())
.add_layer(LayerConfig::dense_relu(flattened_size, 512))
.add_layer(LayerConfig::dropout(0.5))
.add_layer(LayerConfig::dense_linear(512, num_classes));
builder
}
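    /// Builds a binary classifier: a feedforward network with a single
    /// linear output unit, binary cross-entropy loss, and Adam with a
    /// learning rate of 0.001.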
pub fn binary_classifier(input_size: usize, hidden_sizes: Vec<usize>) -> Self {
let mut builder = Self::feedforward(input_size, hidden_sizes, 1);
builder = builder.loss(LossFunction::BinaryCrossEntropy);
builder = builder.optimizer(OptimizerConfig::adam(0.001));
builder
}
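    /// Builds a multi-class classifier: a feedforward network with
    /// `num_classes` linear outputs, cross-entropy loss, and Adam with a
    /// learning rate of 0.001.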
pub fn multiclass_classifier(
input_size: usize,
hidden_sizes: Vec<usize>,
num_classes: usize,
) -> Self {
let mut builder = Self::feedforward(input_size, hidden_sizes, num_classes);
builder = builder.loss(LossFunction::CrossEntropy);
builder = builder.optimizer(OptimizerConfig::adam(0.001));
builder
}
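    /// Builds a regressor: a feedforward network with linear outputs, mean
    /// squared error loss, and Adam with a learning rate of 0.001.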
pub fn regressor(input_size: usize, hidden_sizes: Vec<usize>, output_size: usize) -> Self {
let mut builder = Self::feedforward(input_size, hidden_sizes, output_size);
builder = builder.loss(LossFunction::MeanSquaredError);
builder = builder.optimizer(OptimizerConfig::adam(0.001));
builder
}
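    /// Builds a dense autoencoder with MSE loss and Adam (learning rate
    /// 0.001). The encoder repeatedly steps the width toward `encoding_size`
    /// by averaging the current and target sizes; the decoder doubles the
    /// width back up, capping at `input_size`. The decoder is therefore not
    /// an exact mirror of the encoder.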
pub fn autoencoder(input_size: usize, encoding_size: usize) -> Self {
let mut builder = Self::new();
let mut current_size = input_size;
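        // Encoder: shrink toward encoding_size via the midpoint at each step.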
while current_size > encoding_size {
let next_size = (current_size + encoding_size) / 2;
builder = builder.add_layer(LayerConfig::dense_relu(current_size, next_size));
current_size = next_size;
}
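        // Decoder: expand back by doubling, capped at input_size.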
while current_size < input_size {
let next_size = if current_size * 2 > input_size {
input_size
} else {
current_size * 2
};
builder = builder.add_layer(LayerConfig::dense_relu(current_size, next_size));
current_size = next_size;
}
builder = builder
.loss(LossFunction::MeanSquaredError)
.optimizer(OptimizerConfig::adam(0.001));
builder
}
}
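/// Ready-made builders for classic architectures and common toy problems.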
pub mod presets {
use super::*;
use crate::activations::Activation;
use crate::layers::WeightInit;
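    /// LeNet-5 for 32x32 single-channel images: two tanh conv/pool blocks,
    /// then 120-84-10 dense layers, with cross-entropy loss and SGD
    /// (learning rate 0.01). The `16 * 5 * 5` flatten size assumes a 32x32
    /// input.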
pub fn lenet5() -> NetworkBuilder {
NetworkBuilder::new()
.add_layer(LayerConfig::Conv2D {
in_channels: 1,
out_channels: 6,
kernel_size: (5, 5),
stride: (1, 1),
padding: (0, 0),
dilation: (1, 1),
activation: Activation::Tanh,
use_bias: true,
weight_init: WeightInit::Xavier,
})
.add_layer(LayerConfig::MaxPool2D {
kernel_size: (2, 2),
stride: Some((2, 2)),
padding: (0, 0),
})
.add_layer(LayerConfig::Conv2D {
in_channels: 6,
out_channels: 16,
kernel_size: (5, 5),
stride: (1, 1),
padding: (0, 0),
dilation: (1, 1),
activation: Activation::Tanh,
use_bias: true,
weight_init: WeightInit::Xavier,
})
.add_layer(LayerConfig::MaxPool2D {
kernel_size: (2, 2),
stride: Some((2, 2)),
padding: (0, 0),
})
.add_layer(LayerConfig::Flatten {
start_dim: 1,
end_dim: None,
})
.add_layer(LayerConfig::Dense {
input_size: 16 * 5 * 5,
output_size: 120,
activation: Activation::Tanh,
use_bias: true,
weight_init: WeightInit::Xavier,
})
.add_layer(LayerConfig::Dense {
input_size: 120,
output_size: 84,
activation: Activation::Tanh,
use_bias: true,
weight_init: WeightInit::Xavier,
})
.add_layer(LayerConfig::Dense {
input_size: 84,
output_size: 10,
activation: Activation::Linear,
use_bias: true,
weight_init: WeightInit::Xavier,
})
.loss(LossFunction::CrossEntropy)
.optimizer(OptimizerConfig::sgd(0.01))
}
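    /// Minimal 2-4-1 network for the XOR problem: a ReLU hidden layer, a
    /// sigmoid output, binary cross-entropy loss, and Adam (learning rate
    /// 0.01).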
pub fn xor_network() -> NetworkBuilder {
NetworkBuilder::new()
.add_layer(LayerConfig::Dense {
input_size: 2,
output_size: 4,
activation: Activation::ReLU,
use_bias: true,
weight_init: WeightInit::Xavier,
})
.add_layer(LayerConfig::Dense {
input_size: 4,
output_size: 1,
activation: Activation::Sigmoid,
use_bias: true,
weight_init: WeightInit::Xavier,
})
.loss(LossFunction::BinaryCrossEntropy)
.optimizer(OptimizerConfig::adam(0.01))
}
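    /// Dense MNIST classifier: 784-128-64-10 with ReLU hidden layers,
    /// dropout (0.2) after each, a softmax output, cross-entropy loss, and
    /// Adam (learning rate 0.001).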
pub fn mnist_classifier() -> NetworkBuilder {
NetworkBuilder::new()
.add_layer(LayerConfig::Dense {
input_size: 784,
output_size: 128,
activation: Activation::ReLU,
use_bias: true,
weight_init: WeightInit::HeNormal,
})
.add_layer(LayerConfig::Dropout { dropout_rate: 0.2 })
.add_layer(LayerConfig::Dense {
input_size: 128,
output_size: 64,
activation: Activation::ReLU,
use_bias: true,
weight_init: WeightInit::HeNormal,
})
.add_layer(LayerConfig::Dropout { dropout_rate: 0.2 })
.add_layer(LayerConfig::Dense {
input_size: 64,
output_size: 10,
activation: Activation::Softmax,
use_bias: true,
weight_init: WeightInit::Xavier,
})
.loss(LossFunction::CrossEntropy)
.optimizer(OptimizerConfig::adam(0.001))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_builder_basic() {
let network = NetworkBuilder::new()
.add_layer(LayerConfig::dense_relu(2, 4))
.add_layer(LayerConfig::dense_sigmoid(4, 1))
.loss(LossFunction::MeanSquaredError)
.optimizer(OptimizerConfig::sgd(0.1))
.build();
assert!(network.is_ok());
let network = network.unwrap();
assert_eq!(network.num_layers(), 2);
}
#[test]
fn test_builder_missing_loss() {
let result = NetworkBuilder::new()
.add_layer(LayerConfig::dense_relu(2, 1))
.optimizer(OptimizerConfig::sgd(0.1))
.build();
assert!(result.is_err());
}
#[test]
fn test_builder_missing_optimizer() {
let result = NetworkBuilder::new()
.add_layer(LayerConfig::dense_relu(2, 1))
.loss(LossFunction::MeanSquaredError)
.build();
assert!(result.is_err());
}
#[test]
fn test_builder_no_layers() {
let result = NetworkBuilder::new()
.loss(LossFunction::MeanSquaredError)
.optimizer(OptimizerConfig::sgd(0.1))
.build();
assert!(result.is_err());
}
#[test]
fn test_feedforward_builder() {
let network = NetworkBuilder::feedforward(784, vec![128, 64], 10)
.loss(LossFunction::CrossEntropy)
.optimizer(OptimizerConfig::adam(0.001))
.build();
assert!(network.is_ok());
let network = network.unwrap();
        assert_eq!(network.num_layers(), 3);
    }
#[test]
fn test_binary_classifier_builder() {
let network = NetworkBuilder::binary_classifier(10, vec![5, 3]).build();
assert!(network.is_ok());
let network = network.unwrap();
assert_eq!(network.num_layers(), 3);
}
#[test]
fn test_multiclass_classifier_builder() {
let network = NetworkBuilder::multiclass_classifier(784, vec![128], 10).build();
assert!(network.is_ok());
let network = network.unwrap();
assert_eq!(network.num_layers(), 2);
}
#[test]
fn test_regressor_builder() {
let network = NetworkBuilder::regressor(5, vec![10, 5], 1).build();
assert!(network.is_ok());
let network = network.unwrap();
assert_eq!(network.num_layers(), 3);
}
#[test]
fn test_preset_xor_network() {
let network = presets::xor_network().build();
assert!(network.is_ok());
let network = network.unwrap();
assert_eq!(network.num_layers(), 2);
}
#[test]
fn test_preset_mnist_classifier() {
let network = presets::mnist_classifier().build();
assert!(network.is_ok());
let network = network.unwrap();
        assert_eq!(network.num_layers(), 5);
    }
#[test]
fn test_builder_fluent_api() {
let network = NetworkBuilder::new()
.name("Test Network")
.description("A test neural network")
.add_layer(LayerConfig::dense_relu(10, 5))
.add_layer(LayerConfig::dense_linear(5, 1))
.loss(LossFunction::MeanSquaredError)
.optimizer(OptimizerConfig::adam(0.001))
.build();
assert!(network.is_ok());
}
#[test]
fn test_autoencoder_builder() {
let network = NetworkBuilder::autoencoder(784, 32).build();
assert!(network.is_ok());
let network = network.unwrap();
        assert!(network.num_layers() > 2);
    }
}