# axonml-nn

## Overview

axonml-nn provides neural network building blocks for the AxonML framework: the `Module` trait, the `Parameter` wrapper, the `Sequential` container, and a broad catalog of layers (dense, conv, pooling, norm, recurrent, attention, transformer, MoE, ternary, sparse, graph, FFT/STFT), activations, losses, and weight initializers.
## Features

- **Module Trait** - Core interface for all neural network components with parameter collection, train/eval mode, and `zero_grad`.
- **Dense and Conv Layers** - `Linear`, `Conv1d`, `Conv2d`, `ConvTranspose2d`, `MaxPool1d`, `MaxPool2d`, `AvgPool1d`, `AvgPool2d`, `AdaptiveAvgPool2d`.
- **Normalization** - `BatchNorm1d`, `BatchNorm2d`, `LayerNorm`, `GroupNorm`, `InstanceNorm2d`.
- **Recurrent Networks** - `RNN`, `LSTM`, `GRU` plus single-step cell variants (`RNNCell`, `LSTMCell`, `GRUCell`). Timesteps are processed with batched matmul to avoid per-step allocation overhead.
- **Attention and Transformers** - `MultiHeadAttention`, `CrossAttention`, `DifferentialAttention`, `TransformerEncoderLayer`/`TransformerDecoderLayer`, `TransformerEncoder`/`TransformerDecoder`, and `Seq2SeqTransformer` end-to-end.
- **Mixture of Experts** - `MoELayer`, `MoERouter`, and `Expert` for sparse expert routing.
- **Ternary Weights** - `TernaryLinear` with `PackedTernaryWeights` for 1.58-bit weight quantization.
- **Graph Neural Networks** - `GCNConv` and `GATConv` for graph convolution / attention.
- **Spectral Layers** - `FFT1d` and `STFT` via rustfft for frequency-domain processing.
- **Differentiable Structured Sparsity** - `SparseLinear` (soft-thresholded magnitude pruning), `GroupSparsity` (row/col L1 regularization), `LotteryTicket` (snapshot/prune/rewind).
- **Other Building Blocks** - `Embedding`, `Dropout`, `Dropout2d`, `ResidualBlock`.
- **Activations** - `ReLU`, `LeakyReLU`, `Sigmoid`, `Tanh`, `GELU`, `SiLU`, `ELU`, `Softmax`, `LogSoftmax`, `Identity`, `Flatten`.
- **Losses** - `MSELoss`, `L1Loss`, `SmoothL1Loss`, `CrossEntropyLoss`, `NLLLoss`, `BCELoss`, `BCEWithLogitsLoss`, with `Reduction` (Mean/Sum/None).
- **Weight Initialization** - `xavier_uniform`, `xavier_normal`, `glorot_uniform`, `glorot_normal`, `kaiming_uniform`, `kaiming_normal`, `he_uniform`, `he_normal`, `orthogonal`, `sparse`, `uniform`, `uniform_range`, `normal`, `constant`, `zeros`, `ones`, `eye`, `diag`, plus the `InitMode` enum.
- **Sequential Container** - `Sequential::new().add(...)` for quick model composition; also `ModuleList` for heterogeneous collections.
## Modules

| Module | Description |
|---|---|
| `module` | `Module` trait and `ModuleList` container |
| `parameter` | `Parameter` wrapper for learnable weights with gradient tracking |
| `sequential` | `Sequential` container for chaining modules |
| `layers` | All layer types (see `layers/` submodules: linear, conv, pooling, norm, rnn, attention, diff_attention, transformer, embedding, dropout, residual, moe, ternary, sparse, graph, fft) |
| `activation` | Activation function modules |
| `loss` | Loss function modules and the `Reduction` enum |
| `init` | Weight initialization functions and `InitMode` |
| `functional` | Stateless functional versions of common operations |
## Cargo Features

| Feature | Purpose |
|---|---|
| `cuda` | Forwards CUDA support to the tensor / core crates |
| `cudnn` | Forwards cuDNN support (implies `cuda`) |
## Usage

Add this to your `Cargo.toml`:

```toml
[dependencies]
axonml-nn = "0.6.1"
```
### Building a Simple MLP

```rust
use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// Build model using Sequential (layer sizes are illustrative)
let model = Sequential::new()
    .add(Linear::new(784, 256))
    .add(ReLU::new())
    .add(Linear::new(256, 64))
    .add(ReLU::new())
    .add(Linear::new(64, 10));

// Create input: [batch, features]
let input = Variable::new(Tensor::randn(&[32, 784]), false);

// Forward pass
let output = model.forward(&input);
assert_eq!(output.shape(), &[32, 10]);

// Get all parameters
let params = model.parameters();
println!("{} parameter tensors", params.len());
```
### Convolutional Neural Network

```rust
use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

let model = Sequential::new()
    .add(Conv2d::new(1, 32, 3))  // [B, 1, 28, 28] -> [B, 32, 26, 26]
    .add(ReLU::new())
    .add(MaxPool2d::new(2))      // -> [B, 32, 13, 13]
    .add(Conv2d::new(32, 64, 3)) // -> [B, 64, 11, 11]
    .add(ReLU::new())
    .add(MaxPool2d::new(2));     // -> [B, 64, 5, 5]

let input = Variable::new(Tensor::randn(&[8, 1, 28, 28]), false);
let features = model.forward(&input);
```
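The shape annotations in the comments above follow the standard convolution/pooling output-size formula, `out = (in + 2*pad - kernel) / stride + 1`. A std-only sketch of that arithmetic (the helper `conv_out` is ours for illustration, not a crate API):

```rust
// Output spatial size of a convolution or pooling layer along one axis.
fn conv_out(input: usize, kernel: usize, stride: usize, pad: usize) -> usize {
    (input + 2 * pad - kernel) / stride + 1
}

fn main() {
    // Reproduce the shape trace of the CNN above (28x28 MNIST-sized input):
    assert_eq!(conv_out(28, 3, 1, 0), 26); // Conv2d 3x3, stride 1 -> 26
    assert_eq!(conv_out(26, 2, 2, 0), 13); // MaxPool2d 2x2, stride 2 -> 13
    assert_eq!(conv_out(13, 3, 1, 0), 11); // Conv2d 3x3 -> 11
    assert_eq!(conv_out(11, 2, 2, 0), 5);  // MaxPool2d 2x2 -> 5
    println!("shape trace checks out");
}
```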
### Recurrent Neural Network

```rust
use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// LSTM for sequence modeling: input_size=64, hidden_size=128
let lstm = LSTM::new(64, 128);

// Input: [batch, seq_len, input_size]
let input = Variable::new(Tensor::randn(&[2, 5, 64]), false);
let output = lstm.forward(&input); // [2, 5, 128]
```
### Transformer Attention

```rust
use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// 8 heads over a 512-dimensional embedding
let attention = MultiHeadAttention::new(512, 8);

// Input: [batch, seq_len, embed_dim]
let input = Variable::new(Tensor::randn(&[2, 5, 512]), false);
let output = attention.forward(&input); // [2, 5, 512]
```
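Per head, multi-head attention computes scaled dot-product attention, `softmax(q . K^T / sqrt(d)) . V`. A minimal single-query sketch of that math in plain Rust (std only, an illustration rather than the crate's implementation):

```rust
/// Numerically stable softmax over a slice.
fn softmax(xs: &[f64]) -> Vec<f64> {
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// attention(q, K, V) = softmax(q . K^T / sqrt(d)) . V for a single query.
fn attend(q: &[f64], keys: &[Vec<f64>], values: &[Vec<f64>]) -> Vec<f64> {
    let scale = (q.len() as f64).sqrt();
    // Similarity of the query to every key, scaled by sqrt(d).
    let scores: Vec<f64> = keys
        .iter()
        .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f64>() / scale)
        .collect();
    let weights = softmax(&scores);
    // Output is the attention-weighted mix of the value rows.
    let dim = values[0].len();
    let mut out = vec![0.0; dim];
    for (w, v) in weights.iter().zip(values) {
        for i in 0..dim {
            out[i] += w * v[i];
        }
    }
    out
}

fn main() {
    let q = vec![1.0, 0.0];
    let keys = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let values = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    // Query matches the first key more strongly, so the output leans
    // toward the first value row.
    let out = attend(&q, &keys, &values);
    println!("{out:?}");
}
```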
### Loss Functions

```rust
use axonml_nn::prelude::*;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// Cross Entropy Loss for classification: logits [batch, classes]
let logits = Variable::new(Tensor::randn(&[4, 10]), true);
let targets = Variable::new(Tensor::from_vec(vec![1.0, 0.0, 3.0, 7.0], &[4]), false);
let loss_fn = CrossEntropyLoss::new(Reduction::Mean);
let loss = loss_fn.compute(&logits, &targets);
loss.backward();

// MSE Loss for regression
let mse = MSELoss::new(Reduction::Mean);
let pred = Variable::new(Tensor::randn(&[4, 1]), true);
let target = Variable::new(Tensor::randn(&[4, 1]), false);
let loss = mse.compute(&pred, &target);
```
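A cross-entropy loss that takes raw logits folds log-softmax and negative log-likelihood into one numerically stable expression, `logsumexp(logits) - logits[target]`. A one-sample std-only sketch of that formula (not the crate's code):

```rust
/// Cross-entropy from raw logits for a single sample:
/// CE = -log_softmax(logits)[target] = logsumexp(logits) - logits[target].
fn cross_entropy(logits: &[f64], target: usize) -> f64 {
    // Subtract the max before exponentiating to avoid overflow.
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = max + logits.iter().map(|x| (x - max).exp()).sum::<f64>().ln();
    log_sum_exp - logits[target]
}

fn main() {
    // Confident, correct prediction -> small loss (about 0.013 here).
    let confident = cross_entropy(&[5.0, 0.0, 0.0], 0);
    // Uniform logits over 3 classes -> loss = ln(3) no matter the target.
    let uniform = cross_entropy(&[1.0, 1.0, 1.0], 2);
    println!("confident: {confident:.4}, uniform: {uniform:.4}");
}
```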
### Weight Initialization

```rust
use axonml_nn::init::*;

// Xavier/Glorot initialization
let weights = xavier_uniform(&[256, 128]);
let weights = xavier_normal(&[256, 128]);

// Kaiming/He initialization (for ReLU networks)
let weights = kaiming_uniform(&[256, 128]);
let weights = kaiming_normal(&[256, 128]);

// Other initializations
let zeros_tensor = zeros(&[10, 10]);
let ones_tensor = ones(&[10, 10]);
let eye_tensor = eye(10);
let ortho_tensor = orthogonal(&[64, 64]);
```
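The uniform Xavier/Glorot and Kaiming/He variants sample from `U(-b, b)`, where the bound depends only on the layer's fan-in and fan-out. A sketch of the standard formulas (helper names are ours, not crate APIs):

```rust
/// Xavier/Glorot uniform bound: b = sqrt(6 / (fan_in + fan_out)).
/// Keeps forward and backward variance roughly constant for linear/tanh layers.
fn xavier_uniform_bound(fan_in: usize, fan_out: usize) -> f64 {
    (6.0 / (fan_in + fan_out) as f64).sqrt()
}

/// Kaiming/He uniform bound with ReLU gain: b = sqrt(6 / fan_in).
/// Compensates for ReLU halving the activation variance.
fn kaiming_uniform_bound(fan_in: usize) -> f64 {
    (6.0 / fan_in as f64).sqrt()
}

fn main() {
    // For a 256-in, 128-out layer: sqrt(6/384) = 0.125 exactly.
    println!("xavier:  {:.4}", xavier_uniform_bound(256, 128));
    // sqrt(6/256) is about 0.1531.
    println!("kaiming: {:.4}", kaiming_uniform_bound(256));
}
```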
### Training/Evaluation Mode

```rust
use axonml_nn::prelude::*;

let mut model = Sequential::new()
    .add(Linear::new(128, 64))
    .add(Dropout::new(0.5))
    .add(Linear::new(64, 10));

// Training mode (dropout active)
model.train();
assert!(model.is_training());

// Evaluation mode (dropout disabled)
model.eval();
assert!(!model.is_training());

// Zero gradients before backward pass
model.zero_grad();
```
### Differentiable Structured Sparsity

Learn which weights to prune end-to-end: the pruning mask is differentiable.

```rust
use axonml_nn::layers::{SparseLinear, GroupSparsity, LotteryTicket};
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// SparseLinear: soft-thresholded magnitude pruning
let sparse = SparseLinear::new(256, 128, 0.5, 10.0); // threshold=0.5, temperature=10.0
let input = Variable::new(Tensor::randn(&[4, 256]), false);
let output = sparse.forward(&input); // [4, 128]

// Check actual sparsity
let sparsity = sparse.sparsity(); // fraction of weights effectively pruned
println!("sparsity: {sparsity:.2}");

// Group sparsity regularization
let group_reg = GroupSparsity::new(1e-4); // L1 penalty on row norms
let reg_loss = group_reg.compute(&sparse);

// Lottery Ticket Hypothesis
let mut ticket = LotteryTicket::new(Linear::new(256, 128));
ticket.snapshot();  // save initial weights
// ... train for a while ...
ticket.prune(0.2);  // prune bottom 20% by magnitude
ticket.rewind();    // rewind to initial weights with discovered mask
```
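The soft-thresholded gate behind this style of magnitude pruning is typically `sigmoid(T * (|w| - tau))`: smooth (so gradients flow through the mask) but approaching a hard `|w| > tau` mask as the temperature `T` grows. A std-only sketch of that gate (an assumed formulation, not the crate's code):

```rust
/// Differentiable pruning gate: sigmoid(temp * (|w| - tau)).
/// Near 1 for weights well above the threshold, near 0 well below it.
fn soft_mask(w: f64, tau: f64, temp: f64) -> f64 {
    1.0 / (1.0 + (-temp * (w.abs() - tau)).exp())
}

fn main() {
    // With tau = 0.5 and temp = 10.0:
    // a large weight survives, a small one is softly pruned.
    println!("{:.3}", soft_mask(1.2, 0.5, 10.0));  // close to 1
    println!("{:.3}", soft_mask(0.1, 0.5, 10.0));  // close to 0
    // The sign of the weight does not matter, only its magnitude.
    println!("{:.3}", soft_mask(-1.2, 0.5, 10.0)); // same as +1.2
}
```

Raising the temperature during training anneals the soft mask toward a hard one, which is how the learned mask can be binarized at the end.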
### Ternary Weights

```rust
use axonml_nn::layers::TernaryLinear;

// 1.58-bit ternary linear layer ({-1, 0, +1})
let layer = TernaryLinear::new(256, 128);
```
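A ternary weight needs only 2 bits, so four weights fit in one byte; that is the kind of layout a `PackedTernaryWeights`-style container uses (log2(3) is about 1.58 bits of information per weight, hence "1.58-bit"). An illustrative round-trip encoder in plain Rust (the 2-bit encoding scheme here is assumed, not taken from the crate):

```rust
/// Pack four ternary weights into one byte, 2 bits each.
/// Illustrative encoding: 0b00 = 0, 0b01 = +1, 0b10 = -1.
fn pack4(w: [i8; 4]) -> u8 {
    let mut byte = 0u8;
    for (i, &v) in w.iter().enumerate() {
        let code = match v {
            1 => 0b01,
            -1 => 0b10,
            _ => 0b00,
        };
        byte |= code << (2 * i);
    }
    byte
}

/// Recover the four ternary weights from a packed byte.
fn unpack4(byte: u8) -> [i8; 4] {
    let mut out = [0i8; 4];
    for (i, slot) in out.iter_mut().enumerate() {
        *slot = match (byte >> (2 * i)) & 0b11 {
            0b01 => 1,
            0b10 => -1,
            _ => 0,
        };
    }
    out
}

fn main() {
    let w = [1, -1, 0, 1];
    let packed = pack4(w);
    // Round trip: 4 weights stored in a single byte.
    assert_eq!(unpack4(packed), w);
    println!("packed byte: {packed:#010b}");
}
```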
### Mixture of Experts

```rust
use axonml_nn::layers::MoELayer;

// Sparse top-k expert routing (sizes illustrative): embed=512, 8 experts, top-2
let moe = MoELayer::new(512, 8, 2);
```
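Top-k routing keeps only the k highest-probability experts per token and renormalizes their gates so the kept weights sum to 1; the other experts are skipped entirely, which is where the sparsity comes from. A std-only sketch of the gating math (the function name is ours, not a crate API):

```rust
/// Softmax over expert logits, keep the k largest probabilities,
/// renormalize so the kept gates sum to 1.
/// Returns (expert_index, gate) pairs sorted by gate, descending.
fn top_k_gates(logits: &[f64], k: usize) -> Vec<(usize, f64)> {
    // Stable softmax.
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();

    // Sort experts by probability and keep the top k.
    let mut probs: Vec<(usize, f64)> = exps.iter().map(|e| e / sum).enumerate().collect();
    probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    probs.truncate(k);

    // Renormalize the surviving gates.
    let kept: f64 = probs.iter().map(|&(_, p)| p).sum();
    probs.into_iter().map(|(i, p)| (i, p / kept)).collect()
}

fn main() {
    // Four experts, route to the top 2: experts 0 and 1 win here.
    let gates = top_k_gates(&[2.0, 1.0, 0.1, -1.0], 2);
    println!("{gates:?}");
}
```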
## Tests

Run the test suite:

```shell
cargo test
```
## License

Licensed under either of:

- Apache License, Version 2.0 (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
- MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT)

at your option.
Last updated: 2026-04-16 (v0.6.1)