//! Crate root.
//!
//! Declares the public `loss`, `modules` and `activation` modules plus a private
//! `padding` module, and flattens part of that tree into the crate namespace:
//! everything public in `modules`, in `padding`, and in a fixed list of
//! `activation` submodules is re-exported here, together with
//! `burn_core::module::Initializer`.
//!
//! NOTE(review): judging by the module names this crate presumably provides
//! neural-network building blocks; confirm against the crate manifest/README.

// Build as `no_std` whenever the `std` cargo feature is disabled.
#![cfg_attr(not(feature = "std"), no_std)]
// Emit a warning for every public item that lacks documentation.
#![warn(missing_docs)]
// Turn on the unstable `doc_cfg` rustdoc feature only when the `docsrs` cfg is set.
#![cfg_attr(docsrs, feature(doc_cfg))]
// Raise the compiler's recursion limit to 256.
// NOTE(review): the reason is not visible in this file; presumably deep macro or
// trait expansion somewhere in the crate needs it. Confirm before lowering.
#![recursion_limit = "256"]
/// The `loss` module. It is public but *not* glob re-exported at the crate root,
/// so its items are reached through the `loss::` path.
///
/// NOTE(review): presumably loss functions, going by the name; confirm against
/// the module's own documentation.
pub mod loss;
/// The `modules` module. Everything public in it is also re-exported at the
/// crate root (see the `pub use modules::*` that follows).
pub mod modules;
// Flatten `modules` into the crate namespace.
pub use modules::*;
/// The `activation` module. The public contents of ten of its submodules
/// (`gelu`, `glu`, `hard_sigmoid`, `leaky_relu`, `prelu`, `relu`, `sigmoid`,
/// `softplus`, `swiglu`, `tanh`) are also re-exported at the crate root.
pub mod activation;
// Only the submodules listed here are flattened into the crate root; anything
// else under `activation` stays reachable through the `activation::` path only.
pub use activation::{
gelu::*, glu::*, hard_sigmoid::*, leaky_relu::*, prelu::*, relu::*, sigmoid::*, softplus::*,
swiglu::*, tanh::*,
};
// `padding` itself is private; its public items are exposed solely through the
// glob re-export below.
mod padding;
pub use padding::*;
// Re-export `Initializer` from `burn_core` so users of this crate can name it
// without importing `burn_core` themselves.
pub use burn_core::module::Initializer;
// Link the `alloc` crate explicitly so `alloc::…` paths resolve in both `std`
// and `no_std` builds (the crate is conditionally `no_std`, see the crate attributes).
extern crate alloc;
/// Backend used by this crate's tests.
///
/// Default choice: `burn_ndarray::NdArray<f32>`, selected for test builds when
/// none of the `test-tch`, `test-wgpu`, `test-cuda` or `test-rocm` features is
/// enabled.
///
/// NOTE(review): the feature-specific aliases are each gated on a single
/// feature only, so enabling two `test-*` backend features at once would define
/// `TestBackend` more than once and fail to compile. Enable at most one.
#[cfg(all(
test,
not(feature = "test-tch"),
not(feature = "test-wgpu"),
not(feature = "test-cuda"),
not(feature = "test-rocm")
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;
/// Backend used by this crate's tests when the `test-tch` feature is enabled:
/// `burn_tch::LibTorch<f32>`.
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::LibTorch<f32>;
/// Backend used by this crate's tests when the `test-wgpu` feature is enabled:
/// `burn_wgpu::Wgpu` with its default type parameters.
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;
/// Backend used by this crate's tests when the `test-cuda` feature is enabled:
/// `burn_cuda::Cuda` with its default type parameters.
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;
/// Backend used by this crate's tests when the `test-rocm` feature is enabled:
/// `burn_rocm::Rocm` with its default type parameters.
#[cfg(all(test, feature = "test-rocm"))]
pub type TestBackend = burn_rocm::Rocm;
/// Test-only backend that wraps whichever `TestBackend` is selected in
/// `burn_autodiff::Autodiff`, for tests that need the autodiff decorator.
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
// Compiled only for test builds that also enable the `test-memory-checks`
// feature; the module's entire content is generated by a `burn_fusion` macro.
#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
// NOTE(review): `memory_checks!` is defined in `burn_fusion` and is opaque from
// here; presumably it expands to memory-related test items. Confirm in `burn_fusion`.
burn_fusion::memory_checks!();
}