1#![allow(dead_code)]
2pub mod data;
3mod iterator;
4pub mod loss;
5mod model;
6pub mod module;
7pub mod optim;
8pub mod utils;
9
10#[cfg(feature = "dataset_hub")]
11pub use data::dataset::hub;
12pub use loss::CrossEntropyLoss;
13pub use module::{Identity, Linear, ReLU, SafeModule, Sequential, Softmax};
14pub use optim::SGD;
15
/// Crate-internal helper macros. Both are `#[macro_export]`ed, so they live at the
/// crate root (`crate::__rust_force_expr!`, `crate::assert_array_eq!`).
mod macros {
    /// Expands to its single expression argument unchanged.
    ///
    /// Hidden from the docs. NOTE(review): presumably used by other exported macros
    /// to force `$e` to be parsed in expression position — confirm against callers.
    #[doc(hidden)]
    #[macro_export]
    macro_rules! __rust_force_expr {
        ($e:expr) => {
            $e
        };
    }

    /// Asserts that two array-like values have the same shape and that every pair of
    /// corresponding elements differs by at most a tolerance (default `1e-6`).
    ///
    /// Operands must provide `shape()`, `iter()`, `Display`, and `IntoIterator` for a
    /// shared reference (as ndarray arrays do). Each operand, including the tolerance,
    /// is evaluated exactly once. The tolerance may be any expression (a literal,
    /// a `const`, ...).
    ///
    /// # Panics
    ///
    /// - If the two shapes differ.
    /// - If any element pair differs by more than the tolerance, or if their
    ///   difference is NaN (NaN never counts as "within tolerance").
    #[cfg(test)]
    #[macro_export]
    macro_rules! assert_array_eq {
        ($lhs:expr, $rhs:expr) => {
            $crate::assert_array_eq!($lhs, $rhs, 1e-6)
        };
        ($lhs:expr, $rhs:expr, $tol:expr) => {
            // Bind every operand once, like `std::assert_eq!`. A `match` (rather than
            // `let`) keeps all temporaries of the operand expressions alive for the
            // whole check, and makes the expansion a single expression.
            match (&$lhs, &$rhs, $tol) {
                (lhs, rhs, tol) => {
                    if lhs.shape() != rhs.shape() {
                        panic!(
                            "Incompatible shape \n- a={:?} \n\n- b={:?}",
                            lhs.shape(),
                            rhs.shape()
                        );
                    }

                    for (a, b) in lhs.iter().zip(rhs) {
                        // Absolute difference without requiring an `abs` method, so
                        // any `PartialOrd + Sub` element type works.
                        let diff = if a < b { b - a } else { a - b };
                        // Phrased positively: a NaN difference fails every comparison,
                        // so it is reported as a mismatch instead of slipping through.
                        let within_tol = diff <= tol;
                        if !within_tol {
                            panic!(
                                "Not equal with tolerance={}\n- a={} \n\n- b={}",
                                tol, lhs, rhs
                            );
                        }
                    }
                }
            }
        };
    }
}
52
53pub mod prelude {
54 pub use crate::module::init::InitParameters;
56 pub use crate::module::Module;
57
58 pub use crate::data::dataset::{Dataset, IterableDataset};
59 pub use crate::data::sampler::Sampler;
60
61 pub use crate::loss::Loss;
62
63 pub use crate::optim::Optimizer;
64
65 pub use crate::{safe, sequential};
67
68 pub use ndarray::prelude::*;
69}