#![warn(missing_docs)]
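//! Neural network building blocks: layers, activation functions, loss
//! functions, weight initializers, and containers such as [`Sequential`],
//! built on the `axonml_tensor` and `axonml_autograd` crates.
//!
//! A minimal usage sketch, mirroring `test_simple_mlp` at the bottom of this
//! file (the crate name `axonml_nn` is an assumption; substitute the real
//! crate name):
//!
//! ```ignore
//! use axonml_autograd::Variable;
//! use axonml_nn::{Linear, Module, ReLU, Sequential}; // crate name assumed
//! use axonml_tensor::Tensor;
//!
//! // A two-layer MLP: 10 inputs -> 5 hidden units -> 2 outputs.
//! let model = Sequential::new()
//!     .add(Linear::new(10, 5))
//!     .add(ReLU)
//!     .add(Linear::new(5, 2));
//!
//! // A batch of 2 samples with 10 features each.
//! let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
//! let output = model.forward(&input);
//! assert_eq!(output.shape(), vec![2, 2]);
//! ```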
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]

/// Activation function layers such as ReLU, Sigmoid, Tanh, Softmax, and GELU.
pub mod activation;
/// Functional (stateless) forms of common operations.
pub mod functional;
/// Weight initialization schemes (Xavier/Glorot, Kaiming/He, orthogonal, and more).
pub mod init;
/// Neural network layers: linear, convolutional, recurrent, pooling, and normalization.
pub mod layers;
/// Loss functions and the [`Reduction`] mode.
pub mod loss;
/// The [`Module`] trait and [`ModuleList`] container.
pub mod module;
/// Learnable [`Parameter`]s held by modules.
pub mod parameter;
/// The [`Sequential`] container for chaining modules.
pub mod sequential;

pub use module::{Module, ModuleList};
pub use parameter::Parameter;
pub use sequential::Sequential;

pub use layers::{
    AdaptiveAvgPool2d, AvgPool1d, AvgPool2d, BatchNorm1d, BatchNorm2d, Conv1d, Conv2d, Dropout,
    Embedding, GRUCell, GroupNorm, InstanceNorm2d, LSTMCell, LayerNorm, Linear, MaxPool1d,
    MaxPool2d, MultiHeadAttention, RNNCell, GRU, LSTM, RNN,
};

pub use activation::{
    Identity, LeakyReLU, LogSoftmax, ReLU, SiLU, Sigmoid, Softmax, Tanh, ELU, GELU,
};

pub use loss::{
    BCELoss, BCEWithLogitsLoss, CrossEntropyLoss, L1Loss, MSELoss, NLLLoss, Reduction, SmoothL1Loss,
};

pub use init::{
    constant, diag, eye, glorot_normal, glorot_uniform, he_normal, he_uniform, kaiming_normal,
    kaiming_uniform, normal, ones, orthogonal, randn, sparse, uniform, uniform_range,
    xavier_normal, xavier_uniform, zeros, InitMode,
};

pub mod prelude {
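    //! Commonly used layers, losses, activations, and traits, re-exported for
    //! convenient glob imports.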
    pub use crate::{
        functional,
        AdaptiveAvgPool2d,
        AvgPool1d,
        AvgPool2d,
        BCELoss,
        BatchNorm1d,
        BatchNorm2d,
        Conv1d,
        Conv2d,
        CrossEntropyLoss,
        Dropout,
        Embedding,
        GroupNorm,
        Identity,
        InstanceNorm2d,
        L1Loss,
        LayerNorm,
        LeakyReLU,
        Linear,
        MSELoss,
        MaxPool1d,
        MaxPool2d,
        Module,
        ModuleList,
        MultiHeadAttention,
        NLLLoss,
        Parameter,
        ReLU,
        Reduction,
        Sequential,
        SiLU,
        Sigmoid,
        Softmax,
        Tanh,
        ELU,
        GELU,
        GRU,
        LSTM,
        RNN,
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_tensor::Tensor;

    #[test]
    fn test_simple_mlp() {
        let model = Sequential::new()
            .add(Linear::new(10, 5))
            .add(ReLU)
            .add(Linear::new(5, 2));

        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let output = model.forward(&input);
        assert_eq!(output.shape(), vec![2, 2]);
    }

    #[test]
    fn test_module_parameters() {
        let model = Sequential::new()
            .add(Linear::new(10, 5))
            .add(Linear::new(5, 2));

        let params = model.parameters();
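        // Each Linear layer contributes a weight and a bias parameter.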
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_conv_model() {
        let model = Sequential::new()
            .add(Conv2d::new(1, 16, 3))
            .add(ReLU)
            .add(MaxPool2d::new(2));

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 784], &[1, 1, 28, 28]).unwrap(),
            false,
        );
        let output = model.forward(&input);
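        // 3x3 conv without padding: 28 -> 26; 2x2 max pool halves it: 26 -> 13.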
        assert_eq!(output.shape(), vec![1, 16, 13, 13]);
    }

    #[test]
    fn test_loss_computation() {
        let pred = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
            true,
        );
        let target = Variable::new(Tensor::from_vec(vec![0.0, 2.0], &[2]).unwrap(), false);

        let loss_fn = CrossEntropyLoss::new();
        let loss = loss_fn.compute(&pred, &target);
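        // The loss is reduced over the batch to a single scalar value.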
        assert_eq!(loss.numel(), 1);
    }

    #[test]
    fn test_embedding_model() {
        let emb = Embedding::new(100, 32);
        let indices = Variable::new(
            Tensor::from_vec(vec![0.0, 5.0, 10.0, 15.0], &[2, 2]).unwrap(),
            false,
        );
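        // Each index is looked up in the 100 x 32 table, appending an embedding dimension of 32.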
        let output = emb.forward(&indices);
        assert_eq!(output.shape(), vec![2, 2, 32]);
    }

    #[test]
    fn test_rnn_model() {
        let rnn = LSTM::new(10, 20, 1);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 60], &[2, 3, 10]).unwrap(), false);
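        // Input is (batch = 2, seq_len = 3, input_size = 10); the LSTM maps features to hidden size 20.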
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 3, 20]);
    }

    #[test]
    fn test_attention_model() {
        let attn = MultiHeadAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 640], &[2, 5, 64]).unwrap(),
            false,
        );
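        // Self-attention (embed_dim = 64, 4 heads) preserves the (batch, seq_len, embed_dim) shape.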
        let output = attn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }
}