#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![recursion_limit = "256"]
pub mod loss;
pub mod modules;
pub use modules::*;
pub mod activation;
pub use activation::{
celu::*, elu::*, gelu::*, glu::*, hard_shrink::*, hard_sigmoid::*, leaky_relu::*, prelu::*,
relu::*, selu::*, shrink::*, sigmoid::*, soft_shrink::*, softplus::*, softsign::*, swiglu::*,
tanh::*, thresholded_relu::*,
};
mod padding;
pub use padding::*;
pub use burn_core::module::Initializer;
extern crate alloc;
/// Backend that this crate's unit tests run on.
///
/// Selected at compile time by the `test-tch`, `test-wgpu`, `test-cuda` and `test-rocm`
/// cargo features. When none of them is enabled, the tests fall back to `burn_flex::Flex`.
/// Each feature defines this same alias. Enabling more than one of them therefore defines
/// `TestBackend` more than once, and the crate fails to compile, so treat them as mutually
/// exclusive.
#[cfg(all(
    test,
    not(any(
        feature = "test-tch",
        feature = "test-wgpu",
        feature = "test-cuda",
        feature = "test-rocm"
    ))
))]
pub type TestBackend = burn_flex::Flex;

/// Test backend provided by `burn_tch::LibTorch` (parameterized with `f32`), selected by the
/// `test-tch` feature.
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::LibTorch<f32>;

/// Test backend provided by `burn_wgpu::Wgpu`, selected by the `test-wgpu` feature.
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;

/// Test backend provided by `burn_cuda::Cuda`, selected by the `test-cuda` feature.
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;

/// Test backend provided by `burn_rocm::Rocm`, selected by the `test-rocm` feature.
#[cfg(all(test, feature = "test-rocm"))]
pub type TestBackend = burn_rocm::Rocm;
/// The test backend selected by the `test-*` features (see [`TestBackend`]), wrapped in
/// `burn_autodiff::Autodiff`.
///
/// Only compiled for test builds.
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<crate::TestBackend>;
/// Test-only module whose entire contents come from `burn_fusion::memory_checks!`.
///
/// Compiled only in test builds with the `test-memory-checks` feature enabled.
// NOTE(review): what the expanded macro actually checks (e.g. leaked allocations) is defined
// inside burn_fusion, not here — confirm against that macro before relying on specifics.
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
burn_fusion::memory_checks!();
}