training/lib.rs

//! Burn-based training and evaluation for CortenForge detection models.
//!
//! This crate provides:
//! - Dataset loading and collation (`collate`, `collate_from_burn_batch`).
//! - Training loop utilities (`run_train`, `TrainArgs`).
//! - Model checkpoint loading/saving helpers.
//!
//! Supports both `LinearClassifier` and `MultiboxModel` from the `models` crate.
//!
//! ## Backend Selection
//!
//! - `backend-wgpu`: Uses WGPU for GPU-accelerated training.
//! - Default: Falls back to the NdArray CPU backend.
//!
//! ## Stability
//!
//! Training APIs are **experimental** and may change as the training pipeline evolves.
//! Core model types (`TinyDet`, `BigDet`) are stable, but training utilities and loss functions
//! are subject to refinement.
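//!
//! ## Example
//!
//! A minimal training sketch, assuming the crate is imported as `training`. The
//! `TrainArgs` fields and the exact signature of `run_train` shown here are
//! illustrative assumptions, not the confirmed API; see the `util` module for the
//! real definitions.
//!
//! ```ignore
//! use training::{run_train, TrainArgs, TrainBackend};
//!
//! // Hypothetical field names; consult `TrainArgs` for the actual ones.
//! let args = TrainArgs {
//!     epochs: 10,
//!     batch_size: 16,
//!     learning_rate: 1e-3,
//!     ..Default::default()
//! };
//!
//! // Assumed to be generic over the backend and fallible.
//! run_train::<TrainBackend>(args)?;
//! ```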

#![recursion_limit = "256"]

pub mod dataset;
pub mod util;

pub use dataset::{collate, collate_from_burn_batch, CollatedBatch, DatasetPathConfig, RunSample};
pub use models::{LinearClassifier, LinearClassifierConfig, MultiboxModel, MultiboxModelConfig};
pub use util::{run_train, TrainArgs};

/// Backend alias for training/eval: WGPU when the `backend-wgpu` feature is enabled.
#[cfg(feature = "backend-wgpu")]
pub type TrainBackend = burn_wgpu::Wgpu<f32>;

/// Backend alias for training/eval: NdArray (CPU) by default.
#[cfg(not(feature = "backend-wgpu"))]
pub type TrainBackend = burn_ndarray::NdArray<f32>;
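
// A sketch of a device helper to pair with `TrainBackend`; the name
// `default_train_device` is hypothetical and not part of the existing API.
// Both device types implement `Default` in burn, so this only assumes the
// `burn_wgpu` / `burn_ndarray` dependencies already used above.

/// Default device for the active `TrainBackend` (WGPU variant).
#[cfg(feature = "backend-wgpu")]
pub fn default_train_device() -> burn_wgpu::WgpuDevice {
    burn_wgpu::WgpuDevice::default()
}

/// Default device for the active `TrainBackend` (NdArray CPU variant).
#[cfg(not(feature = "backend-wgpu"))]
pub fn default_train_device() -> burn_ndarray::NdArrayDevice {
    burn_ndarray::NdArrayDevice::default()
}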