Skip to main content

atomr_accel_train/
lib.rs

1//! Distributed training blueprints on atomr-accel-cuda.
2//!
3//! ```ignore
4//! use atomr_accel_train::prelude::*;
5//! ```
6//!
7//! - [`data_parallel::DataParallelTrainer`] — N-replica trainer
8//!   wired to NCCL all-reduce.
9//! - [`pipeline_parallel::PipelineParallelTrainer`] — staged
10//!   forward/backward across pipeline ranks.
11//! - [`tensor_parallel::TensorParallelTrainer`] — sharded matmul
12//!   coordinator.
13//! - [`parameter_server::AsyncParameterServer`] — async PS protocol.
14//! - [`optimizer`] / [`loss`] — typed enums for the common choices.
15
16pub mod data_parallel;
17pub mod loss;
18pub mod optimizer;
19pub mod parameter_server;
20pub mod pipeline_parallel;
21pub mod tensor_parallel;
22
23pub mod prelude {
24    //! Canonical re-exports. `use atomr_accel_train::prelude::*;`.
25    pub use crate::data_parallel::{
26        DataParallelTrainer, ReplicaStepResult, TrainSample, TrainerConfig, TrainerMsg,
27    };
28    pub use crate::loss::LossKind;
29    pub use crate::optimizer::{OptimizerKind, StepStats};
30    pub use crate::parameter_server::{
31        AsyncParameterServer, ParameterServerMsg, ParameterServerStats, WorkerId,
32    };
33    pub use crate::pipeline_parallel::{
34        PipelineConfig, PipelineParallelTrainer, PipelineTrainerMsg,
35    };
36    pub use crate::tensor_parallel::{
37        ShardStepResult, TensorParallelConfig, TensorParallelMsg, TensorParallelTrainer,
38    };
39}