axonml-nn 0.4.3

Neural network modules for the AxonML machine-learning framework

Overview

axonml-nn provides neural network building blocks for the AxonML framework. It includes layers, activation functions, loss functions, and utilities for constructing and training deep learning models with a PyTorch-like API.

Features

  • Module Trait - Core interface for all neural network components with parameter management and train/eval modes.

  • Comprehensive Layers - Linear, Conv1d/Conv2d, RNN/LSTM/GRU, Embedding, BatchNorm, LayerNorm, Dropout, and MultiHeadAttention. RNN/LSTM/GRU layers use batched matmul to process all timesteps efficiently, avoiding per-step allocation overhead.

  • Activation Functions - ReLU, Sigmoid, Tanh, GELU, SiLU, ELU, LeakyReLU, Softmax, and LogSoftmax.

  • Loss Functions - MSELoss, CrossEntropyLoss, BCELoss, BCEWithLogitsLoss, NLLLoss, L1Loss, and SmoothL1Loss.

  • Weight Initialization - Xavier/Glorot, Kaiming/He, orthogonal, sparse, and custom initialization schemes.

  • Sequential Container - Easy model composition by chaining layers together.
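The features above all hang off the `Module` trait. As a rough std-only sketch of what such an interface looks like (the trait name comes from this crate, but every signature below is an assumption for illustration, not axonml-nn's actual API):

```rust
// Hypothetical Module-style interface; the real axonml-nn trait will differ.
// A flat Vec<f32> stands in for a tensor type.
struct Tensor {
    data: Vec<f32>,
}

trait Module {
    fn forward(&self, input: &Tensor) -> Tensor;
    fn num_parameters(&self) -> usize;
    fn set_training(&mut self, training: bool);
}

// A toy layer implementing the trait: y = x * scale.
struct Scale {
    scale: f32,
    training: bool,
}

impl Module for Scale {
    fn forward(&self, input: &Tensor) -> Tensor {
        Tensor { data: input.data.iter().map(|x| x * self.scale).collect() }
    }
    fn num_parameters(&self) -> usize { 1 }
    fn set_training(&mut self, training: bool) { self.training = training; }
}

fn main() {
    let mut layer = Scale { scale: 2.0, training: true };
    layer.set_training(false); // switch to eval mode
    let out = layer.forward(&Tensor { data: vec![1.0, 2.0] });
    assert_eq!(out.data, vec![2.0, 4.0]);
    println!("output: {:?}", out.data);
}
```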

Modules

| Module | Description |
|--------------|-------------|
| `module` | Core `Module` trait and `ModuleList` container for neural network components |
| `parameter` | `Parameter` wrapper for learnable weights with gradient tracking |
| `sequential` | `Sequential` container for chaining modules in order |
| `layers` | Neural network layers (Linear, Conv, RNN, Attention, Norm, Pooling, Embedding, Dropout) |
| `activation` | Activation function modules (ReLU, Sigmoid, Tanh, GELU, etc.) |
| `loss` | Loss function modules (MSE, CrossEntropy, BCE, etc.) |
| `init` | Weight initialization functions (Xavier, Kaiming, orthogonal, etc.) |
| `functional` | Stateless functional versions of operations |

Usage

Add this to your Cargo.toml:

[dependencies]
axonml-nn = "0.4.3"

Building a Simple MLP

use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// Build model using Sequential
let model = Sequential::new()
    .add(Linear::new(784, 256))
    .add(ReLU)
    .add(Linear::new(256, 128))
    .add(ReLU)
    .add(Linear::new(128, 10));

// Create input
let input = Variable::new(
    Tensor::from_vec(vec![0.5; 784], &[1, 784]).unwrap(),
    false
);

// Forward pass
let output = model.forward(&input);
assert_eq!(output.shape(), vec![1, 10]);

// Get all parameters
let params = model.parameters();
println!("Total parameters: {}", model.num_parameters());
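For the 784→256→128→10 MLP above, the count reported by `num_parameters()` should match the usual weights-plus-biases arithmetic. A quick check in plain Rust with no axonml dependency:

```rust
// Parameter count for a Linear layer: an (in x out) weight matrix
// plus an out-dimensional bias vector.
fn linear_params(inp: usize, out: usize) -> usize {
    inp * out + out
}

fn main() {
    let total = linear_params(784, 256)  // 200_960
        + linear_params(256, 128)        //  32_896
        + linear_params(128, 10);        //   1_290
    assert_eq!(total, 235_146);
    println!("Total parameters: {total}");
}
```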

Convolutional Neural Network

use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

let model = Sequential::new()
    .add(Conv2d::new(1, 32, 3))      // [B, 1, 28, 28] -> [B, 32, 26, 26]
    .add(ReLU)
    .add(MaxPool2d::new(2))          // -> [B, 32, 13, 13]
    .add(Conv2d::new(32, 64, 3))     // -> [B, 64, 11, 11]
    .add(ReLU)
    .add(MaxPool2d::new(2));         // -> [B, 64, 5, 5]

let input = Variable::new(
    Tensor::from_vec(vec![0.5; 784], &[1, 1, 28, 28]).unwrap(),
    false
);
let features = model.forward(&input);
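The shape comments above follow the standard valid-convolution formula, `out = (in - kernel) / stride + 1`, with stride 1 for the convolutions and stride equal to the window size for pooling. Verifying the whole chain in plain Rust:

```rust
// Spatial size after a valid (no padding) convolution or pooling window.
fn out_size(input: usize, kernel: usize, stride: usize) -> usize {
    (input - kernel) / stride + 1
}

fn main() {
    let mut s = 28;
    s = out_size(s, 3, 1); // Conv2d(1, 32, 3)  -> 26
    s = out_size(s, 2, 2); // MaxPool2d(2)      -> 13
    s = out_size(s, 3, 1); // Conv2d(32, 64, 3) -> 11
    s = out_size(s, 2, 2); // MaxPool2d(2)      -> 5
    assert_eq!(s, 5);
    println!("final feature map: {0}x{0}", s);
}
```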

Recurrent Neural Network

use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// LSTM for sequence modeling
let lstm = LSTM::new(
    64,   // input_size
    128,  // hidden_size
    2     // num_layers
);

// Input: [batch, seq_len, input_size]
let input = Variable::new(
    Tensor::from_vec(vec![0.5; 640], &[2, 5, 64]).unwrap(),
    false
);
let output = lstm.forward(&input);  // [2, 5, 128]
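Each LSTM timestep combines four gates (input, forget, cell candidate, output). A single-unit scalar sketch of the standard cell update, for intuition only; the crate's layers batch these as matmuls across all timesteps rather than looping like this:

```rust
fn sigmoid(x: f32) -> f32 { 1.0 / (1.0 + (-x).exp()) }

// One step of a single-unit LSTM cell with scalar weights.
// w_* multiply the input x; u_* multiply the previous hidden state h.
fn lstm_step(x: f32, h: f32, c: f32, w: [f32; 4], u: [f32; 4], b: [f32; 4]) -> (f32, f32) {
    let i = sigmoid(w[0] * x + u[0] * h + b[0]); // input gate
    let f = sigmoid(w[1] * x + u[1] * h + b[1]); // forget gate
    let g = (w[2] * x + u[2] * h + b[2]).tanh(); // cell candidate
    let o = sigmoid(w[3] * x + u[3] * h + b[3]); // output gate
    let c_new = f * c + i * g;                   // blend old and new cell state
    let h_new = o * c_new.tanh();                // gated hidden output
    (h_new, c_new)
}

fn main() {
    let (mut h, mut c) = (0.0, 0.0);
    for &x in &[0.5, -0.3, 0.8] {
        let (h2, c2) = lstm_step(x, h, c, [0.5; 4], [0.5; 4], [0.0; 4]);
        h = h2;
        c = c2;
    }
    assert!(h.abs() < 1.0); // tanh keeps the hidden state bounded
    println!("h = {h:.4}, c = {c:.4}");
}
```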

Transformer Attention

use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

let attention = MultiHeadAttention::new(
    512,  // embed_dim
    8     // num_heads
);

// Input: [batch, seq_len, embed_dim]
let input = Variable::new(
    Tensor::from_vec(vec![0.5; 5120], &[2, 5, 512]).unwrap(),
    false
);
let output = attention.forward(&input);  // [2, 5, 512]
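Under the hood, multi-head attention is built on scaled dot-product attention: `softmax(Q·Kᵀ / sqrt(d)) · V`. A single-head, single-query sketch in plain Rust (illustrating the math, not the crate's batched implementation):

```rust
// Scaled dot-product attention for one query over a set of keys/values.
fn attend(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let scale = (q.len() as f32).sqrt();
    // Raw scores: q . k_i / sqrt(d)
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() / scale)
        .collect();
    // Softmax with max subtraction for numerical stability.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = exps.iter().sum();
    // Output is the attention-weighted sum of values.
    let mut out = vec![0.0; values[0].len()];
    for (w, v) in exps.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += w / denom * x;
        }
    }
    out
}

fn main() {
    let keys = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let values = vec![vec![10.0, 0.0], vec![0.0, 10.0]];
    let out = attend(&[1.0, 0.0], &keys, &values);
    // The query aligns with the first key, so the first value dominates.
    assert!(out[0] > out[1]);
    println!("{out:?}");
}
```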

Loss Functions

use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// Cross Entropy Loss for classification
let logits = Variable::new(
    Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3]).unwrap(),
    true
);
let targets = Variable::new(
    Tensor::from_vec(vec![2.0, 0.0], &[2]).unwrap(),
    false
);

let loss_fn = CrossEntropyLoss::new();
let loss = loss_fn.compute(&logits, &targets);
loss.backward();

// MSE Loss for regression
let mse = MSELoss::new();
let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
let target = Variable::new(Tensor::from_vec(vec![1.5, 2.5], &[2]).unwrap(), false);
let loss = mse.compute(&pred, &target);
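The numbers in this example can be checked by hand: cross-entropy from logits is `logsumexp(logits) - logit[target]`, averaged over the batch, and MSE averages squared differences. In plain Rust:

```rust
// Cross-entropy for one sample, straight from logits:
// loss = log(sum_j exp(z_j)) - z_target
fn cross_entropy(logits: &[f32], target: usize) -> f32 {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let lse = max + logits.iter().map(|z| (z - max).exp()).sum::<f32>().ln();
    lse - logits[target]
}

// Mean squared error over a flat slice.
fn mse(pred: &[f32], target: &[f32]) -> f32 {
    pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum::<f32>() / pred.len() as f32
}

fn main() {
    // Same batch as above: two samples of logits [1, 2, 3], targets [2, 0].
    let ce = (cross_entropy(&[1.0, 2.0, 3.0], 2) + cross_entropy(&[1.0, 2.0, 3.0], 0)) / 2.0;
    assert!((ce - 1.4076).abs() < 1e-3);

    // Same regression pair as above: [1.0, 2.0] vs [1.5, 2.5].
    assert!((mse(&[1.0, 2.0], &[1.5, 2.5]) - 0.25).abs() < 1e-6);
    println!("mean CE = {ce:.4}");
}
```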

Weight Initialization

use axonml_nn::init::*;

// Xavier/Glorot initialization
let weights = xavier_uniform(256, 128);
let weights = xavier_normal(256, 128);

// Kaiming/He initialization (for ReLU networks)
let weights = kaiming_uniform(256, 128);
let weights = kaiming_normal(256, 128);

// Other initializations
let zeros_tensor = zeros(&[3, 3]);
let ones_tensor = ones(&[3, 3]);
let eye_tensor = eye(4);
let ortho_tensor = orthogonal(64, 64);
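For reference, Xavier-uniform conventionally draws from `U(-b, b)` with `b = sqrt(6 / (fan_in + fan_out))`, and Kaiming-uniform for ReLU folds in a gain of `sqrt(2)`, giving `b = sqrt(6 / fan_in)`. Computing the bounds for a 256x128 layer, assuming these standard formulas rather than the crate's exact internals:

```rust
// Xavier/Glorot uniform bound: sqrt(6 / (fan_in + fan_out)).
fn xavier_uniform_bound(fan_in: usize, fan_out: usize) -> f32 {
    (6.0 / (fan_in + fan_out) as f32).sqrt()
}

// Kaiming/He uniform bound for ReLU: sqrt(2) * sqrt(3 / fan_in) = sqrt(6 / fan_in).
fn kaiming_uniform_bound(fan_in: usize) -> f32 {
    (6.0 / fan_in as f32).sqrt()
}

fn main() {
    // 6 / (256 + 128) = 6 / 384 = 0.015625, whose sqrt is exactly 0.125.
    assert_eq!(xavier_uniform_bound(256, 128), 0.125);
    println!("kaiming bound: {:.4}", kaiming_uniform_bound(256));
}
```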

Training/Evaluation Mode

use axonml_nn::prelude::*;

let mut model = Sequential::new()
    .add(Linear::new(10, 5))
    .add(Dropout::new(0.5))
    .add(Linear::new(5, 2));

// Training mode (dropout active)
model.train();
assert!(model.is_training());

// Evaluation mode (dropout disabled)
model.eval();
assert!(!model.is_training());

// Zero gradients before backward pass
model.zero_grad();
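Dropout is the main reason the train/eval switch matters: with inverted dropout, each activation is zeroed with probability p at train time and the survivors are scaled by 1/(1-p), so eval mode can simply be the identity. A minimal sketch assuming inverted dropout (a fixed `mask` stands in for random sampling):

```rust
// Inverted dropout: scale surviving activations at train time so that
// evaluation needs no rescaling.
fn dropout(input: &[f32], mask: &[bool], p: f32, training: bool) -> Vec<f32> {
    if !training {
        return input.to_vec(); // eval mode: identity
    }
    let scale = 1.0 / (1.0 - p);
    input
        .iter()
        .zip(mask)
        .map(|(x, &keep)| if keep { x * scale } else { 0.0 })
        .collect()
}

fn main() {
    let x = [1.0, 2.0, 3.0, 4.0];
    let mask = [true, false, true, false];
    let train = dropout(&x, &mask, 0.5, true);
    assert_eq!(train, vec![2.0, 0.0, 6.0, 0.0]); // kept values doubled (p = 0.5)
    let eval = dropout(&x, &mask, 0.5, false);
    assert_eq!(eval, x.to_vec()); // eval mode passes input through unchanged
}
```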

Differentiable Structured Sparsity (novel)

Learn which weights to prune end-to-end: the pruning mask itself is differentiable, so sparsity patterns can be discovered by gradient descent during training.

use axonml_nn::layers::sparse::{SparseLinear, GroupSparsity, LotteryTicket};
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// SparseLinear: soft-thresholded magnitude pruning
let sparse = SparseLinear::new(256, 128, 0.5, 10.0); // threshold=0.5, temperature=10.0
let input = Variable::new(Tensor::randn(&[4, 256]), true);
let output = sparse.forward(&input); // [4, 128]

// Check actual sparsity
let sparsity = sparse.sparsity(); // fraction of weights effectively pruned
println!("Sparsity: {:.1}%", sparsity * 100.0);

// Group sparsity regularization
let group_reg = GroupSparsity::new(0.01, "row"); // L1 penalty on row norms
let reg_loss = group_reg.compute(&sparse);

// Lottery Ticket Hypothesis
let mut ticket = LotteryTicket::new(&sparse);
ticket.snapshot(); // save initial weights
// ... train for a while ...
ticket.prune(0.2); // prune bottom 20% by magnitude
ticket.rewind(&mut sparse); // rewind to initial weights with discovered mask
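One way to make magnitude pruning differentiable is to gate each weight with a sigmoid of its distance from the threshold, so the temperature controls how sharply the soft mask approaches a hard 0/1 mask. This is an illustrative guess at the idea, not `SparseLinear`'s actual formula:

```rust
fn sigmoid(x: f32) -> f32 { 1.0 / (1.0 + (-x).exp()) }

// Differentiable soft mask: weights well below `threshold` in magnitude
// are driven toward zero; `temperature` sharpens the transition.
fn soft_prune(weights: &[f32], threshold: f32, temperature: f32) -> Vec<f32> {
    weights
        .iter()
        .map(|&w| w * sigmoid(temperature * (w.abs() - threshold)))
        .collect()
}

// Fraction of weights effectively pruned (gated below a small epsilon).
fn sparsity(masked: &[f32], eps: f32) -> f32 {
    masked.iter().filter(|w| w.abs() < eps).count() as f32 / masked.len() as f32
}

fn main() {
    let w = [0.05, -0.02, 0.9, -1.2];
    let masked = soft_prune(&w, 0.5, 10.0);
    // Small weights are squashed; large weights pass nearly unchanged.
    assert!(masked[0].abs() < 0.01);
    assert!(masked[2] > 0.8);
    println!("sparsity: {:.2}", sparsity(&masked, 0.01));
}
```

Because the gate is a smooth function of the weight, gradients flow through the mask itself, which is what lets the threshold-versus-weight trade-off be learned end-to-end.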

Tests

Run the test suite (171 tests):

cargo test -p axonml-nn

License

Licensed under either of:

at your option.