burn_optim/
lib.rs

#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![recursion_limit = "256"]

//! Burn optimizers.

#[macro_use]
extern crate derive_new;

extern crate alloc;

/// Optimizer module.
pub mod optim;
pub use optim::*;

/// Gradient clipping module.
pub mod grad_clipping;

/// Learning rate scheduler module.
#[cfg(feature = "std")]
pub mod lr_scheduler;

/// Type alias for the learning rate.
///
/// `LearningRate` also implements the [learning rate scheduler](crate::lr_scheduler::LrScheduler)
/// trait, so it can be used as a constant learning rate.
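///
/// A minimal sketch of that use (assuming the crate is imported as `burn_optim`,
/// the `std` feature enables `lr_scheduler`, and `LrScheduler::step` returns the
/// rate for the next step):
///
/// ```ignore
/// use burn_optim::{LearningRate, lr_scheduler::LrScheduler};
///
/// // A plain `f64` acts as a constant-rate scheduler.
/// let mut lr: LearningRate = 1.0e-2;
/// assert_eq!(lr.step(), 1.0e-2); // same value at every step
/// ```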
pub type LearningRate = f64; // We could potentially change the type.

/// Backend for test cases.
#[cfg(all(
    test,
    not(feature = "test-tch"),
    not(feature = "test-wgpu"),
    not(feature = "test-cuda"),
    not(feature = "test-rocm")
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;

/// Backend for test cases.
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::LibTorch<f32>;

/// Backend for test cases.
#[cfg(all(test, feature = "test-wgpu"))]
pub type TestBackend = burn_wgpu::Wgpu;

/// Backend for test cases.
#[cfg(all(test, feature = "test-cuda"))]
pub type TestBackend = burn_cuda::Cuda;

/// Backend for test cases.
#[cfg(all(test, feature = "test-rocm"))]
pub type TestBackend = burn_rocm::Rocm;

/// Backend for autodiff test cases.
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
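
// A hedged sketch of how these feature-gated aliases are typically consumed by
// unit tests (illustrative only; the test module, its name, and the
// `burn_tensor::Tensor` usage are assumptions, not part of this crate):
//
// #[cfg(test)]
// mod backend_alias_usage {
//     use super::TestBackend;
//     use burn_tensor::Tensor;
//
//     #[test]
//     fn creates_tensor_on_the_selected_backend() {
//         let device = Default::default();
//         let t = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0], &device);
//         assert_eq!(t.dims(), [3]);
//     }
// }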

#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
    burn_fusion::memory_checks!();
}