1#![warn(missing_docs)]
75#![warn(clippy::all)]
76#![warn(clippy::pedantic)]
77#![allow(clippy::cast_possible_truncation)]
79#![allow(clippy::cast_sign_loss)]
80#![allow(clippy::cast_precision_loss)]
81#![allow(clippy::cast_possible_wrap)]
82#![allow(clippy::missing_errors_doc)]
83#![allow(clippy::missing_panics_doc)]
84#![allow(clippy::must_use_candidate)]
85#![allow(clippy::module_name_repetitions)]
86#![allow(clippy::similar_names)]
87#![allow(clippy::many_single_char_names)]
88#![allow(clippy::too_many_arguments)]
89#![allow(clippy::doc_markdown)]
90#![allow(clippy::cast_lossless)]
91#![allow(clippy::needless_pass_by_value)]
92#![allow(clippy::redundant_closure_for_method_calls)]
93#![allow(clippy::uninlined_format_args)]
94#![allow(clippy::ptr_arg)]
95#![allow(clippy::return_self_not_must_use)]
96#![allow(clippy::not_unsafe_ptr_arg_deref)]
97#![allow(clippy::items_after_statements)]
98#![allow(clippy::unreadable_literal)]
99#![allow(clippy::if_same_then_else)]
100#![allow(clippy::needless_range_loop)]
101#![allow(clippy::trivially_copy_pass_by_ref)]
102#![allow(clippy::unnecessary_wraps)]
103#![allow(clippy::match_same_arms)]
104#![allow(clippy::unused_self)]
105#![allow(clippy::too_many_lines)]
106#![allow(clippy::single_match_else)]
107#![allow(clippy::fn_params_excessive_bools)]
108#![allow(clippy::struct_excessive_bools)]
109#![allow(clippy::format_push_string)]
110#![allow(clippy::erasing_op)]
111#![allow(clippy::type_repetition_in_bounds)]
112#![allow(clippy::iter_without_into_iter)]
113#![allow(clippy::should_implement_trait)]
114#![allow(clippy::use_debug)]
115#![allow(clippy::case_sensitive_file_extension_comparisons)]
116#![allow(clippy::large_enum_variant)]
117#![allow(clippy::panic)]
118#![allow(clippy::struct_field_names)]
119#![allow(clippy::missing_fields_in_debug)]
120#![allow(clippy::upper_case_acronyms)]
121#![allow(clippy::assigning_clones)]
122#![allow(clippy::option_if_let_else)]
123#![allow(clippy::manual_let_else)]
124#![allow(clippy::explicit_iter_loop)]
125#![allow(clippy::default_trait_access)]
126#![allow(clippy::only_used_in_recursion)]
127#![allow(clippy::manual_clamp)]
128#![allow(clippy::ref_option)]
129#![allow(clippy::multiple_bound_locations)]
130#![allow(clippy::comparison_chain)]
131#![allow(clippy::manual_assert)]
132#![allow(clippy::unnecessary_debug_formatting)]
133
134#[cfg(feature = "core")]
139pub use axonml_core as core;
140
141#[cfg(feature = "core")]
142pub use axonml_tensor as tensor;
143
144#[cfg(feature = "core")]
145pub use axonml_autograd as autograd;
146
147#[cfg(feature = "nn")]
152pub use axonml_nn as nn;
153
154#[cfg(feature = "nn")]
155pub use axonml_optim as optim;
156
157#[cfg(feature = "data")]
162pub use axonml_data as data;
163
164#[cfg(feature = "vision")]
169pub use axonml_vision as vision;
170
171#[cfg(feature = "text")]
172pub use axonml_text as text;
173
174#[cfg(feature = "audio")]
175pub use axonml_audio as audio;
176
177#[cfg(feature = "distributed")]
178pub use axonml_distributed as distributed;
179
180#[cfg(feature = "profile")]
181pub use axonml_profile as profile;
182
183#[cfg(feature = "llm")]
184pub use axonml_llm as llm;
185
186#[cfg(feature = "jit")]
187pub use axonml_jit as jit;
188
/// Convenience re-exports of the most commonly used AxonML items.
///
/// Intended to be glob-imported (`use …::prelude::*`). Every group below is gated on the Cargo
/// feature named in its `#[cfg]` attribute. What this module exports therefore depends on the
/// features the crate was compiled with; with no features enabled it is empty.
///
/// Note: with the `core` feature, the glob import also brings `Error` and `Result` from
/// `axonml_core` into scope. Glob-imported names take precedence over the standard prelude, so
/// `Result` in the importing module refers to `axonml_core::Result`, not `std::result::Result`.
pub mod prelude {
    // ---- core: dtype/device descriptors, error/result, tensors, autograd ----
    #[cfg(feature = "core")]
    pub use axonml_core::{DType, Device, Error, Result};

    #[cfg(feature = "core")]
    pub use axonml_tensor::Tensor;

    #[cfg(feature = "core")]
    pub use axonml_autograd::{no_grad, Variable};

    // ---- nn: layers, activations, losses, and the optimizers/LR schedulers ----
    #[cfg(feature = "nn")]
    pub use axonml_nn::{
        AvgPool2d, BCELoss, BatchNorm1d, BatchNorm2d, Conv2d, CrossEntropyLoss, Dropout, Embedding,
        L1Loss, LayerNorm, LeakyReLU, Linear, MSELoss, MaxPool2d, Module, MultiHeadAttention,
        Parameter, ReLU, Sequential, SiLU, Sigmoid, Softmax, Tanh, GELU, GRU, LSTM, RNN,
    };

    #[cfg(feature = "nn")]
    pub use axonml_optim::{
        Adam, AdamW, CosineAnnealingLR, ExponentialLR, LRScheduler, Optimizer, RMSprop, StepLR, SGD,
    };

    // ---- data: datasets, loaders, samplers, transforms ----
    #[cfg(feature = "data")]
    pub use axonml_data::{DataLoader, Dataset, RandomSampler, SequentialSampler, Transform};

    // ---- vision ----
    #[cfg(feature = "vision")]
    pub use axonml_vision::{
        CenterCrop, ImageNormalize, LeNet, RandomHorizontalFlip, Resize, SimpleCNN, SyntheticCIFAR,
        SyntheticMNIST,
    };

    // ---- text ----
    #[cfg(feature = "text")]
    pub use axonml_text::{
        BasicBPETokenizer, CharTokenizer, LanguageModelDataset, SyntheticSentimentDataset,
        TextDataset, Tokenizer, Vocab, WhitespaceTokenizer,
    };

    // ---- audio ----
    #[cfg(feature = "audio")]
    pub use axonml_audio::{
        AddNoise, MelSpectrogram, NormalizeAudio, Resample, SyntheticCommandDataset,
        SyntheticMusicDataset, MFCC,
    };

    // ---- distributed ----
    #[cfg(feature = "distributed")]
    pub use axonml_distributed::{
        all_reduce_mean, all_reduce_sum, barrier, broadcast, DistributedDataParallel, ProcessGroup,
        World, DDP,
    };

    // ---- profiling ----
    #[cfg(feature = "profile")]
    pub use axonml_profile::{
        Bottleneck, BottleneckAnalyzer, ComputeProfiler, MemoryProfiler, ProfileGuard,
        ProfileReport, Profiler, TimelineProfiler,
    };

    // ---- large language models ----
    #[cfg(feature = "llm")]
    pub use axonml_llm::{
        Bert, BertConfig, BertForMaskedLM, BertForSequenceClassification, GPT2Config, GPT2LMHead,
        GenerationConfig, TextGenerator, GPT2,
    };

    // ---- JIT ----
    // `Optimizer` is renamed to `JitOptimizer` because the nn group above already re-exports
    // `axonml_optim::Optimizer` under that name; two `pub use` items with the same name in one
    // module would not compile when both features are enabled.
    #[cfg(feature = "jit")]
    pub use axonml_jit::{
        trace, CompiledFunction, Graph, JitCompiler, Optimizer as JitOptimizer, TracedValue,
    };
}
280
281#[must_use] pub fn version() -> &'static str {
287 env!("CARGO_PKG_VERSION")
288}
289
/// Returns a comma-separated list of the Cargo features this crate was built with.
///
/// Features appear in a fixed order: `core`, `nn`, `data`, `vision`, `text`, `audio`,
/// `distributed`, `profile`, `llm`, `jit`. Entries are joined with `", "`. If none of them is
/// enabled, the string `"none"` is returned instead of an empty string.
#[must_use]
pub fn features() -> String {
    // `cfg!` expands to a plain `bool` constant, so every feature is always present in this
    // table. The element type of the result is therefore always known and no `mut` binding can
    // go unused. This matters for builds with no features enabled (`--no-default-features`),
    // where conditionally pushing onto an untyped `Vec::new()` fails type inference.
    const FEATURES: [(&str, bool); 10] = [
        ("core", cfg!(feature = "core")),
        ("nn", cfg!(feature = "nn")),
        ("data", cfg!(feature = "data")),
        ("vision", cfg!(feature = "vision")),
        ("text", cfg!(feature = "text")),
        ("audio", cfg!(feature = "audio")),
        ("distributed", cfg!(feature = "distributed")),
        ("profile", cfg!(feature = "profile")),
        ("llm", cfg!(feature = "llm")),
        ("jit", cfg!(feature = "jit")),
    ];

    let enabled: Vec<&str> = FEATURES
        .iter()
        .filter_map(|&(name, on)| on.then_some(name))
        .collect();

    if enabled.is_empty() {
        "none".to_string()
    } else {
        enabled.join(", ")
    }
}
330
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version() {
        let v = version();
        assert!(!v.is_empty());
    }

    /// `features()` must list exactly the features this build was compiled with.
    #[test]
    fn test_features() {
        let f = features();
        assert!(!f.is_empty());

        // Compare each entry against `cfg!` rather than hard-coding "core". That keeps the test
        // valid for `--no-default-features` and other partial feature builds; with no features
        // the output is "none", which matches none of the names below.
        let listed: Vec<&str> = f.split(", ").collect();
        let expected = [
            ("core", cfg!(feature = "core")),
            ("nn", cfg!(feature = "nn")),
            ("data", cfg!(feature = "data")),
            ("vision", cfg!(feature = "vision")),
            ("text", cfg!(feature = "text")),
            ("audio", cfg!(feature = "audio")),
            ("distributed", cfg!(feature = "distributed")),
            ("profile", cfg!(feature = "profile")),
            ("llm", cfg!(feature = "llm")),
            ("jit", cfg!(feature = "jit")),
        ];
        for &(name, enabled) in &expected {
            assert_eq!(
                listed.contains(&name),
                enabled,
                "feature `{}` mis-reported in {:?}",
                name,
                f
            );
        }
    }

    #[cfg(feature = "core")]
    #[test]
    fn test_tensor_creation() {
        use tensor::Tensor;

        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.shape(), &[2, 2]);
    }

    #[cfg(feature = "core")]
    #[test]
    fn test_variable_creation() {
        use autograd::Variable;
        use tensor::Tensor;

        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let v = Variable::new(t, true);
        assert_eq!(v.data().shape(), &[3]);
    }

    #[cfg(feature = "nn")]
    #[test]
    fn test_linear_layer() {
        use autograd::Variable;
        use nn::Linear;
        use nn::Module;
        use tensor::Tensor;

        let layer = Linear::new(4, 2);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), false);
        let output = layer.forward(&input);

        assert_eq!(output.data().shape(), &[1, 2]);
    }

    /// Smoke test: building an optimizer over a model's parameters and zeroing gradients must
    /// not panic.
    #[cfg(feature = "nn")]
    #[test]
    fn test_optimizer() {
        use nn::Linear;
        use nn::Module;
        use optim::{Adam, Optimizer};

        let model = Linear::new(4, 2);
        let mut optimizer = Adam::new(model.parameters(), 0.001);

        optimizer.zero_grad();
    }

    #[cfg(feature = "data")]
    #[test]
    fn test_dataloader() {
        use data::{DataLoader, Dataset};
        use tensor::Tensor;

        /// Fixed-size dataset that returns the same sample for every index.
        struct DummyDataset;

        impl Dataset for DummyDataset {
            type Item = (Tensor<f32>, Tensor<f32>);

            fn len(&self) -> usize {
                100
            }

            fn get(&self, _index: usize) -> Option<Self::Item> {
                let x = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
                let y = Tensor::from_vec(vec![1.0], &[1]).unwrap();
                Some((x, y))
            }
        }

        let dataset = DummyDataset;
        let loader = DataLoader::new(dataset, 10);

        // 100 samples with a batch size of 10; the loader's `len()` is expected to count batches.
        assert_eq!(loader.len(), 10);
    }

    #[cfg(feature = "vision")]
    #[test]
    fn test_vision_dataset() {
        use data::Dataset;
        use vision::SyntheticMNIST;

        let dataset = SyntheticMNIST::new(100);
        assert_eq!(dataset.len(), 100);
    }

    #[cfg(feature = "text")]
    #[test]
    fn test_tokenizer() {
        use text::{Tokenizer, WhitespaceTokenizer};

        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("hello world");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[cfg(feature = "audio")]
    #[test]
    fn test_audio_transform() {
        use audio::MelSpectrogram;
        use data::Transform;
        use std::f32::consts::PI;
        use tensor::Tensor;

        // 4096 samples of a 440 Hz sine at a 16 kHz sample rate.
        let data: Vec<f32> = (0..4096)
            .map(|i| (2.0 * PI * 440.0 * i as f32 / 16000.0).sin())
            .collect();
        let audio = Tensor::from_vec(data, &[4096]).unwrap();

        let mel = MelSpectrogram::with_params(16000, 512, 256, 40);
        let spec = mel.apply(&audio);

        // The leading dimension is expected to equal the last `with_params` argument
        // (presumably the number of mel bins).
        assert_eq!(spec.shape()[0], 40);
    }

    #[cfg(feature = "distributed")]
    #[test]
    fn test_distributed_world() {
        use distributed::World;

        let world = World::mock();
        assert_eq!(world.rank(), 0);
        assert_eq!(world.world_size(), 1);
    }

    #[test]
    fn test_prelude_imports() {
        // The prelude is empty when no feature is enabled, which would otherwise trigger the
        // `unused_imports` lint for this glob import.
        #[allow(unused_imports)]
        use crate::prelude::*;

        #[cfg(feature = "core")]
        {
            let _ = Device::Cpu;
        }
    }
}