atomr_accel_train/
pipeline_parallel.rs1use std::time::Instant;
16
17use async_trait::async_trait;
18use atomr_core::actor::{Actor, ActorRef, Context, Props};
19use tokio::sync::oneshot;
20
21use atomr_accel_cuda::error::GpuError;
22
23use crate::loss::LossKind;
24use crate::optimizer::{OptimizerKind, StepStats};
25
26pub trait PipelineStageProtocol: Send + 'static {
27 type Msg: Send + 'static;
28 type Activation: Send + 'static;
30 fn make_forward(
31 input: Self::Activation,
32 reply: oneshot::Sender<Result<Self::Activation, GpuError>>,
33 ) -> Self::Msg;
34 fn make_final_forward(
36 input: Self::Activation,
37 reply: oneshot::Sender<Result<(f32, f32), GpuError>>,
38 ) -> Self::Msg;
39}
40
41#[derive(Debug, Clone)]
42pub struct PipelineConfig {
43 pub micro_batch_size: usize,
44 pub gradient_clip: Option<f32>,
45 pub optimizer: OptimizerKind,
46 pub loss: LossKind,
47}
48
49pub enum PipelineTrainerMsg<P: PipelineStageProtocol> {
50 Step {
51 input: P::Activation,
52 reply: oneshot::Sender<Result<StepStats, GpuError>>,
53 },
54}
55
56pub struct PipelineParallelTrainer<P: PipelineStageProtocol> {
57 config: PipelineConfig,
58 stages: Vec<ActorRef<P::Msg>>,
59}
60
61impl<P: PipelineStageProtocol> PipelineParallelTrainer<P> {
62 pub fn props(config: PipelineConfig, stages: Vec<ActorRef<P::Msg>>) -> Props<Self> {
63 Props::create(move || PipelineParallelTrainer {
64 config: config.clone(),
65 stages: stages.clone(),
66 })
67 }
68}
69
70#[async_trait]
71impl<P: PipelineStageProtocol> Actor for PipelineParallelTrainer<P> {
72 type Msg = PipelineTrainerMsg<P>;
73
74 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: PipelineTrainerMsg<P>) {
75 match msg {
76 PipelineTrainerMsg::Step { input, reply } => {
77 if self.stages.is_empty() {
78 let _ = reply.send(Err(GpuError::Unrecoverable(
79 "PipelineParallelTrainer::Step: no stages".into(),
80 )));
81 return;
82 }
83 let _ = &self.config;
84 let started = Instant::now();
85 let stages = self.stages.clone();
86 tokio::spawn(async move {
87 let n = stages.len();
88 let mut activation: Option<P::Activation> = Some(input);
89 for (i, s) in stages.iter().enumerate() {
92 let act = activation.take().expect("activation present");
93 if i + 1 < n {
94 let (tx, rx) = oneshot::channel();
95 s.tell(P::make_forward(act, tx));
96 match rx.await {
97 Ok(Ok(next)) => activation = Some(next),
98 Ok(Err(e)) => {
99 let _ = reply.send(Err(e));
100 return;
101 }
102 Err(_) => {
103 let _ = reply.send(Err(GpuError::Unrecoverable(
104 "pipeline: stage dropped reply".into(),
105 )));
106 return;
107 }
108 }
109 } else {
110 let (tx, rx) = oneshot::channel();
111 s.tell(P::make_final_forward(act, tx));
112 match rx.await {
113 Ok(Ok((loss, grad_norm))) => {
114 let _ = reply.send(Ok(StepStats {
115 loss,
116 grad_norm,
117 step_micros: started.elapsed().as_micros() as u64,
118 }));
119 return;
120 }
121 Ok(Err(e)) => {
122 let _ = reply.send(Err(e));
123 return;
124 }
125 Err(_) => {
126 let _ = reply.send(Err(GpuError::Unrecoverable(
127 "pipeline: final stage dropped reply".into(),
128 )));
129 return;
130 }
131 }
132 }
133 }
134 });
135 }
136 }
137 }
138}