// burn_nn/lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![recursion_limit = "256"]
5
6//! Burn neural network module.
7
8/// Loss module
9pub mod loss;
10
11/// Neural network modules implementations.
12pub mod modules;
13pub use modules::*;
14
15pub mod activation;
16pub use activation::{
17    celu::*, elu::*, gelu::*, glu::*, hard_shrink::*, hard_sigmoid::*, leaky_relu::*, prelu::*,
18    relu::*, selu::*, shrink::*, sigmoid::*, soft_shrink::*, softplus::*, softsign::*, swiglu::*,
19    tanh::*, thresholded_relu::*,
20};
21
22mod padding;
23pub use padding::*;
24
25// For backward compat, `burn::nn::Initializer`
26pub use burn_core::module::Initializer;
27
28extern crate alloc;
29
/// Backend for test cases.
///
/// Default alias used when no backend-specific test feature
/// (`test-tch`, `test-wgpu`, `test-cuda`, `test-rocm`) is enabled.
#[cfg(all(
    test,
    not(feature = "test-tch"),
    not(feature = "test-wgpu"),
    not(feature = "test-cuda"),
    not(feature = "test-rocm")
))]
pub type TestBackend = burn_flex::Flex;
39
40#[cfg(all(test, feature = "test-tch"))]
41/// Backend for test cases
42pub type TestBackend = burn_tch::LibTorch<f32>;
43
44#[cfg(all(test, feature = "test-wgpu"))]
45/// Backend for test cases
46pub type TestBackend = burn_wgpu::Wgpu;
47
48#[cfg(all(test, feature = "test-cuda"))]
49/// Backend for test cases
50pub type TestBackend = burn_cuda::Cuda;
51
52#[cfg(all(test, feature = "test-rocm"))]
53/// Backend for test cases
54pub type TestBackend = burn_rocm::Rocm;
55
56/// Backend for autodiff test cases
57#[cfg(test)]
58pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
59
60#[cfg(all(test, feature = "test-memory-checks"))]
61mod tests {
62    burn_fusion::memory_checks!();
63}