rstorch/
lib.rs

1#![allow(dead_code)]
2pub mod data;
3mod iterator;
4pub mod loss;
5mod model;
6pub mod module;
7pub mod optim;
8pub mod utils;
9
10#[cfg(feature = "dataset_hub")]
11pub use data::dataset::hub;
12pub use loss::CrossEntropyLoss;
13pub use module::{Identity, Linear, ReLU, SafeModule, Sequential, Softmax};
14pub use optim::SGD;
15
// Crate-internal home for `macro_rules!` definitions. `#[macro_export]` places
// each macro at the crate root, so callers use `$crate::name!` / `crate::name!`.
mod macros {
    /// Expands to its argument unchanged.
    ///
    /// The `$e:expr` matcher forces the tokens to be parsed as a single
    /// expression. Hidden from rustdoc because it is an implementation helper.
    #[doc(hidden)]
    #[macro_export]
    macro_rules! __rust_force_expr {
        ($e:expr) => {
            $e
        };
    }

    /// Asserts that two arrays have the same shape and agree element-wise
    /// within an absolute tolerance (default `1e-6`).
    ///
    /// Each argument (including the tolerance) is evaluated exactly once, so it
    /// is safe to pass expressions with side effects or expensive computations.
    /// The arrays must provide `shape()`, `iter()`, `IntoIterator` for a
    /// reference to the right-hand side, and `Display` (as `ndarray` arrays do).
    /// NOTE(review): element types are presumably floats; confirm for other uses.
    ///
    /// # Panics
    ///
    /// * If the two shapes differ.
    /// * If a pair of unequal elements differs by more than the tolerance, or if
    ///   their difference is NaN (e.g. one side is NaN). A NaN is reported as a
    ///   mismatch rather than silently accepted.
    ///
    /// Only compiled in test builds (`#[cfg(test)]`).
    #[cfg(test)]
    #[macro_export]
    macro_rules! assert_array_eq {
        ($lhs:expr, $rhs:expr) => {
            $crate::assert_array_eq!($lhs, $rhs, 1e-6)
        };
        ($lhs:expr, $rhs:expr, $tol:expr) => {{
            // Bind every argument once (same technique as `assert_eq!`); the
            // `match` also keeps temporaries alive for the whole check.
            match (&$lhs, &$rhs, $tol) {
                (lhs, rhs, tol) => {
                    if lhs.shape() != rhs.shape() {
                        panic!(
                            "Incompatible shape \n- a={:?} \n\n- b={:?}",
                            lhs.shape(),
                            rhs.shape()
                        );
                    }

                    for (a, b) in lhs.iter().zip(rhs) {
                        // Subtract the smaller from the larger so unsigned
                        // element types cannot underflow.
                        let diff = if a < b { b - a } else { a - b };
                        // Any comparison with NaN is false, so a NaN difference
                        // is never "within tolerance" and is reported. `a != b`
                        // lets identical infinities (inf - inf = NaN) pass.
                        let within_tol = diff <= tol;
                        if a != b && !within_tol {
                            panic!(
                                "Not equal with tolerance={}\n- a={} \n\n- b={}",
                                tol, lhs, rhs
                            );
                        }
                    }
                }
            }
        }};
    }
}
52
53pub mod prelude {
54    // traits
55    pub use crate::module::init::InitParameters;
56    pub use crate::module::Module;
57
58    pub use crate::data::dataset::{Dataset, IterableDataset};
59    pub use crate::data::sampler::Sampler;
60
61    pub use crate::loss::Loss;
62
63    pub use crate::optim::Optimizer;
64
65    // macros
66    pub use crate::{safe, sequential};
67
68    pub use ndarray::prelude::*;
69}