axonml_nn/lib.rs

//! axonml-nn - Neural Network Module Library
//!
//! Provides neural network layers, activation functions, loss functions,
//! and utilities for building deep learning models in Axonml.
//!
//! # Key Components
//!
//! - **Module trait**: Core interface for all neural network modules
//! - **Parameter**: Wrapper for learnable parameters
//! - **Sequential**: Container for chaining modules
//! - **Layers**: Linear, Conv, RNN, LSTM, Attention, etc.
//! - **Activations**: ReLU, Sigmoid, Tanh, GELU, etc.
//! - **Loss Functions**: MSE, CrossEntropy, BCE, etc.
//! - **Initialization**: Xavier, Kaiming, orthogonal, etc.
//! - **Functional API**: Stateless operations
//!
//! # Example
//!
//! ```ignore
//! use axonml_nn::prelude::*;
//!
//! // Build a simple MLP
//! let model = Sequential::new()
//!     .add(Linear::new(784, 256))
//!     .add(ReLU)
//!     .add(Linear::new(256, 10));
//!
//! // Forward pass
//! let output = model.forward(&input);
//!
//! // Compute loss
//! let loss = CrossEntropyLoss::new().compute(&output, &target);
//!
//! // Backward pass
//! loss.backward();
//! ```
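//!
//! Convolutional and pooling layers compose the same way. A second sketch,
//! using only constructors exercised by this crate's own tests:
//!
//! ```ignore
//! // 1 input channel -> 16 feature maps with a 3x3 kernel, then 2x2 max pooling
//! let cnn = Sequential::new()
//!     .add(Conv2d::new(1, 16, 3))
//!     .add(ReLU)
//!     .add(MaxPool2d::new(2));
//!
//! let features = cnn.forward(&images);
//! ```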
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// ML/tensor-specific allowances
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]

// =============================================================================
// Module Declarations
// =============================================================================

pub mod activation;
pub mod functional;
pub mod init;
pub mod layers;
pub mod loss;
pub mod module;
pub mod parameter;
pub mod sequential;

// =============================================================================
// Re-exports
// =============================================================================

pub use module::{Module, ModuleList};
pub use parameter::Parameter;
pub use sequential::Sequential;

// Layer re-exports
pub use layers::{
    AdaptiveAvgPool2d, AvgPool1d, AvgPool2d, BatchNorm1d, BatchNorm2d, Conv1d, Conv2d, Dropout,
    Embedding, GRUCell, GroupNorm, InstanceNorm2d, LSTMCell, LayerNorm, Linear, MaxPool1d,
    MaxPool2d, MultiHeadAttention, RNNCell, GRU, LSTM, RNN,
};

// Activation re-exports
pub use activation::{
    Identity, LeakyReLU, LogSoftmax, ReLU, SiLU, Sigmoid, Softmax, Tanh, ELU, GELU,
};

// Loss re-exports
pub use loss::{
    BCELoss, BCEWithLogitsLoss, CrossEntropyLoss, L1Loss, MSELoss, NLLLoss, Reduction, SmoothL1Loss,
};

// Init re-exports
pub use init::{
    constant, diag, eye, glorot_normal, glorot_uniform, he_normal, he_uniform, kaiming_normal,
    kaiming_uniform, normal, ones, orthogonal, randn, sparse, uniform, uniform_range,
    xavier_normal, xavier_uniform, zeros, InitMode,
};

// =============================================================================
// Prelude
// =============================================================================

/// Common imports for neural network development.
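///
/// For example (using only items re-exported by this crate):
///
/// ```ignore
/// use axonml_nn::prelude::*;
///
/// let layer = Linear::new(128, 64);
/// let model = Sequential::new().add(layer).add(ReLU);
/// ```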
pub mod prelude {
    pub use crate::{
        // Functional
        functional,
        AdaptiveAvgPool2d,
        AvgPool1d,
        AvgPool2d,
        BCELoss,
        BatchNorm1d,
        BatchNorm2d,
        Conv1d,
        Conv2d,
        CrossEntropyLoss,
        Dropout,
        Embedding,
        GroupNorm,
        Identity,
        InstanceNorm2d,
        L1Loss,
        LayerNorm,
        LeakyReLU,
        // Layers
        Linear,
        MSELoss,
        MaxPool1d,
        MaxPool2d,
        // Core traits and types
        Module,
        ModuleList,
        MultiHeadAttention,
        NLLLoss,
        Parameter,
        // Activations
        ReLU,
        // Loss functions
        Reduction,
        Sequential,
        SiLU,
        Sigmoid,
        Softmax,
        Tanh,
        ELU,
        GELU,
        GRU,
        LSTM,
        RNN,
    };
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_tensor::Tensor;

    #[test]
    fn test_simple_mlp() {
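        // Two-layer MLP (10 -> 5 -> 2) applied to a batch of two inputs.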
        let model = Sequential::new()
            .add(Linear::new(10, 5))
            .add(ReLU)
            .add(Linear::new(5, 2));

        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let output = model.forward(&input);
        assert_eq!(output.shape(), vec![2, 2]);
    }

    #[test]
    fn test_module_parameters() {
        let model = Sequential::new()
            .add(Linear::new(10, 5))
            .add(Linear::new(5, 2));

        let params = model.parameters();
        // 2 Linear layers with weight + bias each = 4 parameters
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_conv_model() {
        let model = Sequential::new()
            .add(Conv2d::new(1, 16, 3))
            .add(ReLU)
            .add(MaxPool2d::new(2));

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 784], &[1, 1, 28, 28]).unwrap(),
            false,
        );
        let output = model.forward(&input);
        // Conv2d: 28 -> 26, MaxPool2d: 26 -> 13
        assert_eq!(output.shape(), vec![1, 16, 13, 13]);
    }

    #[test]
    fn test_loss_computation() {
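        // Logits of shape [2, 3] scored against class-index targets; the loss is a single scalar.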
        let pred = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
            true,
        );
        let target = Variable::new(Tensor::from_vec(vec![0.0, 2.0], &[2]).unwrap(), false);

        let loss_fn = CrossEntropyLoss::new();
        let loss = loss_fn.compute(&pred, &target);
        assert_eq!(loss.numel(), 1);
    }

    #[test]
    fn test_embedding_model() {
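        // 100-entry vocabulary with 32-dim embeddings; indices are supplied as a float tensor.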
        let emb = Embedding::new(100, 32);
        let indices = Variable::new(
            Tensor::from_vec(vec![0.0, 5.0, 10.0, 15.0], &[2, 2]).unwrap(),
            false,
        );
        let output = emb.forward(&indices);
        assert_eq!(output.shape(), vec![2, 2, 32]);
    }

    #[test]
    fn test_rnn_model() {
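        // Single-layer LSTM with input size 10 and hidden size 20; the [2, 3, 10] input maps to [2, 3, 20].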
        let rnn = LSTM::new(10, 20, 1);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 60], &[2, 3, 10]).unwrap(), false);
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 3, 20]);
    }

    #[test]
    fn test_attention_model() {
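        // 4 attention heads over a 64-dim embedding; the output keeps the input's [2, 5, 64] shape.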
        let attn = MultiHeadAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 640], &[2, 5, 64]).unwrap(),
            false,
        );
        let output = attn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }
}