Skip to main content

atomr_accel_train/
pipeline_parallel.rs

1//! `PipelineParallelTrainer` — stage-pipelined model across N
2//! GPUs/actors.
3//!
4//! Each stage actor handles one slice of the model's layers and
5//! passes its activations to the next stage. The trainer drives a
6//! micro-batch through the pipeline (forward) then backwards
7//! (gradient flow), accumulating per-stage gradients before an
8//! optimizer step.
9//!
10//! F6 ships the public surface + a host-side reference where each
11//! stage is a generic actor implementing [`PipelineStageProtocol`].
12//! The trainer feeds activations through stages sequentially in
13//! reply order. F7 adds true micro-batch overlap.
14
15use std::time::Instant;
16
17use async_trait::async_trait;
18use atomr_core::actor::{Actor, ActorRef, Context, Props};
19use tokio::sync::oneshot;
20
21use atomr_accel_cuda::error::GpuError;
22
23use crate::loss::LossKind;
24use crate::optimizer::{OptimizerKind, StepStats};
25
26pub trait PipelineStageProtocol: Send + 'static {
27    type Msg: Send + 'static;
28    /// Activation type passed between stages.
29    type Activation: Send + 'static;
30    fn make_forward(
31        input: Self::Activation,
32        reply: oneshot::Sender<Result<Self::Activation, GpuError>>,
33    ) -> Self::Msg;
34    /// Final-stage forward also produces a (loss, grad_norm) pair.
35    fn make_final_forward(
36        input: Self::Activation,
37        reply: oneshot::Sender<Result<(f32, f32), GpuError>>,
38    ) -> Self::Msg;
39}
40
/// Configuration handed to [`PipelineParallelTrainer::props`].
///
/// NOTE(review): the reference `Step` handler stores this value but does not
/// read any of its fields yet — confirm which ones later phases must honour.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Micro-batch size (presumably samples per micro-batch — TODO confirm).
    pub micro_batch_size: usize,
    /// Optional gradient-clipping threshold (presumably `None` means no
    /// clipping — confirm).
    pub gradient_clip: Option<f32>,
    /// Which optimizer to use; see [`OptimizerKind`].
    pub optimizer: OptimizerKind,
    /// Which loss to use; see [`LossKind`].
    pub loss: LossKind,
}
48
/// Mailbox messages accepted by [`PipelineParallelTrainer`].
pub enum PipelineTrainerMsg<P: PipelineStageProtocol> {
    /// Run one step: push `input` through every stage in order and answer on
    /// `reply` with the resulting [`StepStats`] or the first error hit.
    Step {
        /// Activation fed to the first stage.
        input: P::Activation,
        /// Receives the outcome of the step.
        reply: oneshot::Sender<Result<StepStats, GpuError>>,
    },
}
55
/// Actor that drives a training step through an ordered chain of stage
/// actors: the first stage receives the step input, and the last stage
/// reports the `(loss, grad_norm)` pair.
pub struct PipelineParallelTrainer<P: PipelineStageProtocol> {
    /// Training configuration (stored; not yet consulted when handling `Step`).
    config: PipelineConfig,
    /// Stage actors in pipeline order.
    stages: Vec<ActorRef<P::Msg>>,
}
60
61impl<P: PipelineStageProtocol> PipelineParallelTrainer<P> {
62    pub fn props(config: PipelineConfig, stages: Vec<ActorRef<P::Msg>>) -> Props<Self> {
63        Props::create(move || PipelineParallelTrainer {
64            config: config.clone(),
65            stages: stages.clone(),
66        })
67    }
68}
69
70#[async_trait]
71impl<P: PipelineStageProtocol> Actor for PipelineParallelTrainer<P> {
72    type Msg = PipelineTrainerMsg<P>;
73
74    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: PipelineTrainerMsg<P>) {
75        match msg {
76            PipelineTrainerMsg::Step { input, reply } => {
77                if self.stages.is_empty() {
78                    let _ = reply.send(Err(GpuError::Unrecoverable(
79                        "PipelineParallelTrainer::Step: no stages".into(),
80                    )));
81                    return;
82                }
83                let _ = &self.config;
84                let started = Instant::now();
85                let stages = self.stages.clone();
86                tokio::spawn(async move {
87                    let n = stages.len();
88                    let mut activation: Option<P::Activation> = Some(input);
89                    // Forward through stages 0 .. n-1 (intermediate),
90                    // then n-1 (final).
91                    for (i, s) in stages.iter().enumerate() {
92                        let act = activation.take().expect("activation present");
93                        if i + 1 < n {
94                            let (tx, rx) = oneshot::channel();
95                            s.tell(P::make_forward(act, tx));
96                            match rx.await {
97                                Ok(Ok(next)) => activation = Some(next),
98                                Ok(Err(e)) => {
99                                    let _ = reply.send(Err(e));
100                                    return;
101                                }
102                                Err(_) => {
103                                    let _ = reply.send(Err(GpuError::Unrecoverable(
104                                        "pipeline: stage dropped reply".into(),
105                                    )));
106                                    return;
107                                }
108                            }
109                        } else {
110                            let (tx, rx) = oneshot::channel();
111                            s.tell(P::make_final_forward(act, tx));
112                            match rx.await {
113                                Ok(Ok((loss, grad_norm))) => {
114                                    let _ = reply.send(Ok(StepStats {
115                                        loss,
116                                        grad_norm,
117                                        step_micros: started.elapsed().as_micros() as u64,
118                                    }));
119                                    return;
120                                }
121                                Ok(Err(e)) => {
122                                    let _ = reply.send(Err(e));
123                                    return;
124                                }
125                                Err(_) => {
126                                    let _ = reply.send(Err(GpuError::Unrecoverable(
127                                        "pipeline: final stage dropped reply".into(),
128                                    )));
129                                    return;
130                                }
131                            }
132                        }
133                    }
134                });
135            }
136        }
137    }
138}