// Build without the standard library unless the `std` feature is enabled.
#![cfg_attr(not(feature = "std"), no_std)]
// Warn on public items that lack documentation.
#![warn(missing_docs)]
// Enable the nightly `doc_cfg` feature only when docs are built with `--cfg docsrs`.
#![cfg_attr(docsrs, feature(doc_cfg))]
// NOTE(review): raised recursion limit, presumably needed by deep macro/trait expansion
// somewhere in this crate — confirm why 256 is required.
#![recursion_limit = "256"]

//! Neural network components built on top of `burn_core`.
//!
//! The crate root exposes:
//! - [`loss`]: the loss module;
//! - [`modules`]: neural network modules, whose contents are also re-exported at the root;
//! - [`activation`]: activation submodules, whose contents are re-exported at the root;
//! - the public items of the private `padding` module, re-exported at the root;
//! - `Initializer`, re-exported from `burn_core::module`.
//!
//! The crate is `no_std` unless the `std` feature is enabled.
5
/// Loss module.
pub mod loss;
10
11pub mod modules;
13pub use modules::*;
14
15pub mod activation;
16pub use activation::{
17 gelu::*, glu::*, hard_sigmoid::*, leaky_relu::*, prelu::*, relu::*, sigmoid::*, softplus::*,
18 swiglu::*, tanh::*,
19};
20
21mod padding;
22pub use padding::*;
23
24pub use burn_core::module::Initializer;
26
27extern crate alloc;
28
// Exactly one `TestBackend` alias is selected for test builds. The NdArray alias is the
// fallback when no `test-*` backend feature is enabled. Each other alias is gated only on
// its own feature, so enabling more than one `test-*` feature defines `TestBackend` twice
// and fails to compile — enable at most one.

/// Backend used by this crate's unit tests (NdArray, the default when no `test-*`
/// backend feature is enabled).
#[cfg(all(
    test,
    not(any(
        feature = "test-tch",
        feature = "test-wgpu",
        feature = "test-cuda",
        feature = "test-rocm"
    ))
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;

/// Backend used by this crate's unit tests (LibTorch, selected by the `test-tch` feature).
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::LibTorch<f32>;

/// Backend used by this crate's unit tests (WGPU, selected by the `test-wgpu` feature).
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;

/// Backend used by this crate's unit tests (CUDA, selected by the `test-cuda` feature).
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;

/// Backend used by this crate's unit tests (ROCm, selected by the `test-rocm` feature).
#[cfg(all(test, feature = "test-rocm"))]
pub type TestBackend = burn_rocm::Rocm;

/// The selected [`TestBackend`] wrapped in `burn_autodiff::Autodiff`, for this crate's tests.
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
58
// Extra checks compiled only for test builds with the `test-memory-checks` feature enabled.
// NOTE(review): the module body comes entirely from `burn_fusion::memory_checks!`; presumably
// it expands to tests that verify memory handling of fusion backends — confirm in burn_fusion.
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
    burn_fusion::memory_checks!();
}