1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![recursion_limit = "256"]
5
6pub mod loss;
10
11pub mod modules;
13pub use modules::*;
14
15pub mod activation;
16pub use activation::{
17 celu::*, elu::*, gelu::*, glu::*, hard_shrink::*, hard_sigmoid::*, leaky_relu::*, prelu::*,
18 relu::*, selu::*, shrink::*, sigmoid::*, soft_shrink::*, softplus::*, softsign::*, swiglu::*,
19 tanh::*, thresholded_relu::*,
20};
21
22mod padding;
23pub use padding::*;
24
25pub use burn_core::module::Initializer;
27
28extern crate alloc;
29
// Backend selection for test builds. The `test-tch`, `test-wgpu`, `test-cuda`
// and `test-rocm` features each activate one alias below; with none of them
// enabled, the `burn_flex::Flex` alias is used. Enabling two of those features
// at once defines `TestBackend` twice, which is a compile error.

/// Backend alias for test builds when no `test-*` backend feature is enabled.
#[cfg(all(
    test,
    not(any(
        feature = "test-tch",
        feature = "test-wgpu",
        feature = "test-cuda",
        feature = "test-rocm"
    ))
))]
pub type TestBackend = burn_flex::Flex;

/// Backend alias for test builds with the `test-tch` feature: `burn_tch::LibTorch<f32>`.
#[cfg(all(feature = "test-tch", test))]
pub type TestBackend = burn_tch::LibTorch<f32>;

/// Backend alias for test builds with the `test-wgpu` feature: `burn_wgpu::Wgpu`.
#[cfg(all(feature = "test-wgpu", test))]
pub type TestBackend = burn_wgpu::Wgpu;

/// Backend alias for test builds with the `test-cuda` feature: `burn_cuda::Cuda`.
#[cfg(all(feature = "test-cuda", test))]
pub type TestBackend = burn_cuda::Cuda;

/// Backend alias for test builds with the `test-rocm` feature: `burn_rocm::Rocm`.
#[cfg(all(feature = "test-rocm", test))]
pub type TestBackend = burn_rocm::Rocm;
55
/// The crate-level [`TestBackend`] wrapped in `burn_autodiff::Autodiff`.
///
/// Only defined for test builds.
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<crate::TestBackend>;
59
// Compiled only for test builds with the `test-memory-checks` feature enabled.
// Its contents are whatever `burn_fusion::memory_checks!` expands to
// (presumably test items performing memory checks; see that macro's definition).
#[cfg(all(feature = "test-memory-checks", test))]
mod tests {
    burn_fusion::memory_checks! {}
}