# torsh-nn

Neural network modules for ToRSh with a PyTorch-compatible API, powered by scirs2-neural.
## Overview

This crate provides comprehensive neural network building blocks, including:
- Common layers (Linear, Conv2d, BatchNorm, etc.)
- Activation functions (ReLU, Sigmoid, GELU, etc.)
- Container modules (Sequential, ModuleList, ModuleDict)
- Parameter initialization utilities
- Functional API for stateless operations
## Usage

### Basic Module Definition

```rust
use torsh::prelude::*;
use torsh_nn::prelude::*;

// Define a simple neural network from individual layers
let fc1 = Linear::new(784, 256);
let fc2 = Linear::new(256, 10);
```
### Using Sequential Container

```rust
use torsh_nn::prelude::*;

let model = Sequential::new()
    .add(Linear::new(784, 256))
    .add(ReLU::new())
    .add(Linear::new(256, 128))
    .add(ReLU::new())
    .add(Linear::new(128, 10))
    .add(Softmax::new(1));

let output = model.forward(&input)?;
```
### Functional API

```rust
use torsh_nn::functional as F;

// Activation functions
let x = F::relu(&x);
let x = F::gelu(&x);
let x = F::softmax(&x, -1)?;

// Pooling
let x = F::max_pool2d(&x, 2)?;
let x = F::global_avg_pool2d(&x)?;

// Loss functions
let loss = F::cross_entropy(&logits, &targets)?;
let loss = F::mse_loss(&predictions, &targets)?;
```
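For intuition, `softmax` normalizes a vector of scores into a probability distribution. The numerically stable form of the computation can be sketched in plain Rust (an illustrative sketch only, not the scirs2-backed implementation):

```rust
// Numerically stable softmax over a slice, for illustration only.
fn softmax(xs: &[f64]) -> Vec<f64> {
    // Subtracting the max before exponentiating avoids overflow
    // without changing the result.
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let probs = softmax(&[1.0, 2.0, 3.0]);
    // Outputs are positive, ordered like the inputs, and sum to 1.
    assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-12);
    assert!(probs[2] > probs[1] && probs[1] > probs[0]);
    println!("{probs:?}");
}
```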
### Parameter Initialization

```rust
use torsh_nn::init;

// Xavier/Glorot initialization
let weight = init::xavier_uniform(&[256, 784])?;

// Kaiming/He initialization for ReLU
let weight = init::kaiming_normal(&[256, 784])?;

// Initialize an existing tensor
let mut tensor = Tensor::zeros(&[256, 784]);
init::init_tensor(&mut tensor, init::Init::XavierUniform)?;
```
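The scale factors behind these schemes are simple functions of a layer's fan-in and fan-out, shown here in plain Rust for reference (illustrative, independent of torsh-nn's `init` module):

```rust
// Xavier/Glorot: weights drawn from U(-bound, bound), which keeps
// activation variance roughly constant across layers.
fn xavier_uniform_bound(fan_in: usize, fan_out: usize) -> f64 {
    (6.0 / (fan_in + fan_out) as f64).sqrt()
}

// Kaiming/He: weights drawn from N(0, std^2); the factor 2
// compensates for ReLU zeroing roughly half of the activations.
fn kaiming_normal_std(fan_in: usize) -> f64 {
    (2.0 / fan_in as f64).sqrt()
}

fn main() {
    println!("xavier bound for 784 -> 256: {}", xavier_uniform_bound(784, 256));
    println!("kaiming std for fan_in 784: {}", kaiming_normal_std(784));
}
```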
## Common Layers

### Linear Layer

```rust
// 784 input features, 256 output features
let linear = Linear::new(784, 256);
```
### Convolutional Layer

```rust
// 3 input channels, 64 output channels, 3x3 kernel
let conv = Conv2d::new(3, 64, 3);
```
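As in PyTorch, the spatial output size of a convolution follows from the input size, kernel size, stride, and padding. The standard formula can be sketched in plain Rust (illustrative only):

```rust
// Output size along one spatial dimension:
// out = (n + 2*padding - kernel) / stride + 1  (integer division)
fn conv_out_size(n: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (n + 2 * padding - kernel) / stride + 1
}

fn main() {
    // A 3x3 kernel with stride 1 and padding 1 preserves spatial size.
    assert_eq!(conv_out_size(32, 3, 1, 1), 32);
    // Without padding, the output shrinks by kernel - 1.
    assert_eq!(conv_out_size(32, 3, 1, 0), 30);
    println!("ok");
}
```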
### Batch Normalization

```rust
// 64 feature channels
let bn = BatchNorm2d::new(64);
```

### LSTM

```rust
// input size 128, hidden size 256
let lstm = LSTM::new(128, 256);
```
## Container Modules

### ModuleList

```rust
let mut layers = ModuleList::new();
layers.append(Linear::new(784, 256));
layers.append(Linear::new(256, 10));

// Access by index
if let Some(layer) = layers.get(0) {
    // use `layer`
}
```
### ModuleDict

```rust
let mut blocks = ModuleDict::new();
blocks.insert("encoder", Linear::new(784, 256));
blocks.insert("decoder", Linear::new(256, 784));

// Access by key
if let Some(block) = blocks.get("encoder") {
    // use `block`
}
```
## Parameter Management

```rust
use torsh_nn::utils;

// Count parameters
let total = utils::count_parameters(&model);
let trainable = utils::count_trainable_parameters(&model);

// Freeze/unfreeze parameters
utils::freeze_parameters(&mut model);
utils::unfreeze_parameters(&mut model);

// Get parameter statistics
let stats = utils::parameter_stats(&model);
println!("{stats:?}");

// Gradient clipping
utils::clip_grad_norm_(&mut model, 1.0);
utils::clip_grad_value_(&mut model, 0.5);
```
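Global-norm clipping, the operation behind `clip_grad_norm_` in PyTorch-style APIs, scales every gradient by the same factor whenever the combined L2 norm exceeds `max_norm`. The core computation can be sketched in plain Rust (illustrative, over a flat slice rather than a model):

```rust
// If the L2 norm of all gradients exceeds max_norm, scale them all
// down uniformly so the total norm equals max_norm. Returns the
// pre-clipping norm, matching the usual convention.
fn clip_grad_norm(grads: &mut [f64], max_norm: f64) -> f64 {
    let total_norm = grads.iter().map(|g| g * g).sum::<f64>().sqrt();
    if total_norm > max_norm {
        let scale = max_norm / total_norm;
        for g in grads.iter_mut() {
            *g *= scale;
        }
    }
    total_norm
}

fn main() {
    let mut grads = vec![3.0, 4.0]; // L2 norm = 5
    let norm = clip_grad_norm(&mut grads, 1.0);
    assert_eq!(norm, 5.0);
    // After clipping, the norm is exactly max_norm.
    let new_norm = grads.iter().map(|g| g * g).sum::<f64>().sqrt();
    assert!((new_norm - 1.0).abs() < 1e-12);
    println!("{grads:?}");
}
```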
## Integration with SciRS2
This crate leverages scirs2-neural for:
- Optimized layer implementations
- Automatic differentiation support
- Hardware acceleration
- Memory-efficient operations
All modules are designed to work seamlessly with ToRSh's autograd system while benefiting from scirs2's performance optimizations.
## License
Licensed under the Apache License, Version 2.0. See LICENSE for details.