atomr_accel_train/lib.rs
1//! Distributed training blueprints on atomr-accel-cuda.
2//!
3//! ```ignore
4//! use atomr_accel_train::prelude::*;
5//! ```
6//!
7//! - [`data_parallel::DataParallelTrainer`] — N-replica trainer
8//! wired to NCCL all-reduce.
9//! - [`pipeline_parallel::PipelineParallelTrainer`] — staged
10//! forward/backward across pipeline ranks.
11//! - [`tensor_parallel::TensorParallelTrainer`] — sharded matmul
12//! coordinator.
//! - [`parameter_server::AsyncParameterServer`] — asynchronous
//!   parameter-server (PS) protocol.
//! - [`optimizer`] / [`loss`] — typed enums for the common optimizer and
//!   loss-function choices.
15
// Public modules of the crate. Each module's primary type is listed in the
// crate-level docs and re-exported through `prelude` below.

// `DataParallelTrainer`: N-replica trainer wired to NCCL all-reduce.
pub mod data_parallel;
// `LossKind`: typed enum for the loss-function choice.
pub mod loss;
// `OptimizerKind` (typed enum for the optimizer choice) and `StepStats`.
pub mod optimizer;
// `AsyncParameterServer`: asynchronous parameter-server protocol.
pub mod parameter_server;
// `PipelineParallelTrainer`: staged forward/backward across pipeline ranks.
pub mod pipeline_parallel;
// `TensorParallelTrainer`: sharded-matmul coordinator.
pub mod tensor_parallel;
22
23pub mod prelude {
24 //! Canonical re-exports. `use atomr_accel_train::prelude::*;`.
25 pub use crate::data_parallel::{
26 DataParallelTrainer, ReplicaStepResult, TrainSample, TrainerConfig, TrainerMsg,
27 };
28 pub use crate::loss::LossKind;
29 pub use crate::optimizer::{OptimizerKind, StepStats};
30 pub use crate::parameter_server::{
31 AsyncParameterServer, ParameterServerMsg, ParameterServerStats, WorkerId,
32 };
33 pub use crate::pipeline_parallel::{
34 PipelineConfig, PipelineParallelTrainer, PipelineTrainerMsg,
35 };
36 pub use crate::tensor_parallel::{
37 ShardStepResult, TensorParallelConfig, TensorParallelMsg, TensorParallelTrainer,
38 };
39}