Skip to main content

atomr_accel_train/
data_parallel.rs

//! `DataParallelTrainer` — drives a model replicated across N replica
//! actors: it deals each mini-batch round-robin across the replicas (chunk
//! sizes differ by at most one), has every replica run forward+backward on
//! its chunk, and aggregates the replies into a sample-weighted mean
//! loss/grad-norm. The optimizer in [`TrainerConfig`] is carried along but
//! is not applied by the trainer itself in this version.
//!
//! The trainer is generic over a [`ReplicaProtocol`] trait that
//! describes the message contract to a single replica actor. F4.x
//! ships the protocol with a CPU-side `host_step` that completes
//! a synchronous forward/backward and returns a per-chunk
//! `(loss, grad_norm)` plus a sample count (see [`ReplicaStepResult`]).
//! F5 swaps that for a real GPU forward/backward+AllReduce path; the
//! public surface stays the same.
12
13use std::time::Instant;
14
15use async_trait::async_trait;
16use atomr_core::actor::{Actor, ActorRef, Context, Props};
17use tokio::sync::oneshot;
18
19use atomr_accel_cuda::error::GpuError;
20
21use crate::loss::LossKind;
22use crate::optimizer::{OptimizerKind, StepStats};
23
24/// Per-replica step contract. Each replica receives a chunk of the
25/// mini-batch and replies with `(loss, grad_norm)` for that chunk.
26pub trait ReplicaProtocol: Send + 'static {
27    type Msg: Send + 'static;
28    fn make_step(
29        chunk: Vec<TrainSample>,
30        reply: oneshot::Sender<Result<ReplicaStepResult, GpuError>>,
31    ) -> Self::Msg;
32}
33
/// One training example: a flat feature vector and its target values.
///
/// Both buffers are plain `f32` vectors. Their lengths are not validated by
/// this type.
///
/// `PartialEq` is derived so samples can be compared directly, e.g. in tests
/// or when deduplicating. `Eq` is impossible because of `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainSample {
    /// Model inputs for this example.
    pub features: Vec<f32>,
    /// Target values for this example.
    pub label: Vec<f32>,
}
39
/// One replica's answer to a step request, covering only its own chunk.
///
/// The data-parallel trainer combines these into sample-weighted means, so
/// `samples` should be the number of examples the replica actually processed.
///
/// `PartialEq` is derived so results can be compared directly. `Eq` is
/// impossible because of the `f32` fields.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct ReplicaStepResult {
    /// Loss for this replica's chunk. It is aggregated as a per-sample mean,
    /// weighted by `samples`.
    pub loss: f32,
    /// Gradient norm for this replica's chunk, aggregated the same way as `loss`.
    pub grad_norm: f32,
    /// Number of samples this result covers; used as the aggregation weight.
    pub samples: usize,
}
46
47#[derive(Debug, Clone)]
48pub struct TrainerConfig {
49    pub batch_size_per_device: usize,
50    pub gradient_clip: Option<f32>,
51    pub optimizer: OptimizerKind,
52    pub loss: LossKind,
53}
54
55pub enum TrainerMsg<P: ReplicaProtocol> {
56    Step {
57        batch: Vec<TrainSample>,
58        reply: oneshot::Sender<Result<StepStats, GpuError>>,
59    },
60    /// Set the replica refs after construction. Allows a
61    /// late-binding pattern when replicas are spawned by another
62    /// actor.
63    SetReplicas { replicas: Vec<ActorRef<P::Msg>> },
64}
65
66pub struct DataParallelTrainer<P: ReplicaProtocol> {
67    config: TrainerConfig,
68    replicas: Vec<ActorRef<P::Msg>>,
69}
70
71impl<P: ReplicaProtocol> DataParallelTrainer<P> {
72    pub fn props(config: TrainerConfig, replicas: Vec<ActorRef<P::Msg>>) -> Props<Self> {
73        Props::create(move || DataParallelTrainer {
74            config: config.clone(),
75            replicas: replicas.clone(),
76        })
77    }
78
79    fn split_batch(&self, batch: Vec<TrainSample>) -> Vec<Vec<TrainSample>> {
80        let n = self.replicas.len().max(1);
81        let mut out: Vec<Vec<TrainSample>> = (0..n).map(|_| Vec::new()).collect();
82        for (i, s) in batch.into_iter().enumerate() {
83            out[i % n].push(s);
84        }
85        out
86    }
87}
88
89#[async_trait]
90impl<P: ReplicaProtocol> Actor for DataParallelTrainer<P> {
91    type Msg = TrainerMsg<P>;
92
93    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: TrainerMsg<P>) {
94        match msg {
95            TrainerMsg::SetReplicas { replicas } => {
96                self.replicas = replicas;
97            }
98            TrainerMsg::Step { batch, reply } => {
99                if self.replicas.is_empty() {
100                    let _ = reply.send(Err(GpuError::Unrecoverable(
101                        "DataParallelTrainer::Step: no replicas configured".into(),
102                    )));
103                    return;
104                }
105                let _ = &self.config; // configured but not consumed inline
106                let started = Instant::now();
107                let chunks = self.split_batch(batch);
108                let mut rxs = Vec::with_capacity(chunks.len());
109                for (replica, chunk) in self.replicas.iter().zip(chunks) {
110                    let (tx, rx) = oneshot::channel();
111                    replica.tell(P::make_step(chunk, tx));
112                    rxs.push(rx);
113                }
114                tokio::spawn(async move {
115                    let mut total_loss = 0.0f32;
116                    let mut total_grad_norm = 0.0f32;
117                    let mut total_samples = 0usize;
118                    for rx in rxs {
119                        match rx.await {
120                            Ok(Ok(r)) => {
121                                total_loss += r.loss * r.samples as f32;
122                                total_grad_norm += r.grad_norm * r.samples as f32;
123                                total_samples += r.samples;
124                            }
125                            Ok(Err(e)) => {
126                                let _ = reply.send(Err(e));
127                                return;
128                            }
129                            Err(_) => {
130                                let _ = reply.send(Err(GpuError::Unrecoverable(
131                                    "trainer: replica dropped reply".into(),
132                                )));
133                                return;
134                            }
135                        }
136                    }
137                    if total_samples == 0 {
138                        let _ = reply
139                            .send(Err(GpuError::Unrecoverable("trainer: zero samples".into())));
140                        return;
141                    }
142                    let stats = StepStats {
143                        loss: total_loss / total_samples as f32,
144                        grad_norm: total_grad_norm / total_samples as f32,
145                        step_micros: started.elapsed().as_micros() as u64,
146                    };
147                    let _ = reply.send(Ok(stats));
148                });
149            }
150        }
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Echo replica.
    ///
    /// It reports `loss` as the sum of all feature values in its chunk
    /// divided by the chunk length, i.e. the mean per-sample feature sum
    /// (0.0 for an empty chunk). It also reports a constant `grad_norm` of 1.0.
    enum EchoMsg {
        Step {
            chunk: Vec<TrainSample>,
            reply: oneshot::Sender<Result<ReplicaStepResult, GpuError>>,
        },
    }

    struct EchoReplicaActor;

    #[async_trait]
    impl Actor for EchoReplicaActor {
        type Msg = EchoMsg;
        async fn handle(&mut self, _ctx: &mut Context<Self>, msg: EchoMsg) {
            match msg {
                EchoMsg::Step { chunk, reply } => {
                    let n = chunk.len();
                    let mut sum = 0.0f32;
                    for s in &chunk {
                        sum += s.features.iter().sum::<f32>();
                    }
                    let _ = reply.send(Ok(ReplicaStepResult {
                        loss: if n > 0 { sum / n as f32 } else { 0.0 },
                        grad_norm: 1.0,
                        samples: n,
                    }));
                }
            }
        }
    }

    struct EchoProtocol;
    impl ReplicaProtocol for EchoProtocol {
        type Msg = EchoMsg;
        fn make_step(
            chunk: Vec<TrainSample>,
            reply: oneshot::Sender<Result<ReplicaStepResult, GpuError>>,
        ) -> Self::Msg {
            EchoMsg::Step { chunk, reply }
        }
    }

    /// Trainer config shared by all tests; the trainer does not read it.
    fn test_config() -> TrainerConfig {
        TrainerConfig {
            batch_size_per_device: 1,
            gradient_clip: None,
            optimizer: OptimizerKind::Sgd {
                lr: 0.1,
                momentum: 0.0,
                weight_decay: 0.0,
            },
            loss: LossKind::Mse,
        }
    }

    /// Builds a sample with the given features and no label.
    fn sample(features: &[f32]) -> TrainSample {
        TrainSample {
            features: features.to_vec(),
            label: vec![],
        }
    }

    /// Sends one `Step` to `trainer` and waits (at most 2s) for its reply.
    async fn run_step(
        trainer: &ActorRef<TrainerMsg<EchoProtocol>>,
        batch: Vec<TrainSample>,
    ) -> Result<StepStats, GpuError> {
        let (tx, rx) = oneshot::channel();
        trainer.tell(TrainerMsg::Step { batch, reply: tx });
        tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("trainer did not reply within 2s")
            .expect("trainer dropped the reply sender")
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn step_aggregates_across_replicas() {
        let sys = ActorSystem::create("trainer-test", Config::empty())
            .await
            .unwrap();
        let r1 = sys
            .actor_of(atomr_core::actor::Props::create(|| EchoReplicaActor), "r1")
            .unwrap();
        let r2 = sys
            .actor_of(atomr_core::actor::Props::create(|| EchoReplicaActor), "r2")
            .unwrap();
        let trainer = sys
            .actor_of(
                DataParallelTrainer::<EchoProtocol>::props(test_config(), vec![r1, r2]),
                "trainer",
            )
            .unwrap();

        let stats = run_step(&trainer, vec![sample(&[1.0, 2.0]), sample(&[3.0, 4.0])])
            .await
            .unwrap();
        // Replica 0 sees [1.0, 2.0] → loss 3; replica 1 sees [3.0, 4.0] → loss 7.
        // Weighted avg = (3*1 + 7*1) / 2 = 5.
        assert!((stats.loss - 5.0).abs() < 1e-5);
        assert_eq!(stats.grad_norm, 1.0);

        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn step_weights_replicas_by_sample_count() {
        let sys = ActorSystem::create("trainer-test-weighted", Config::empty())
            .await
            .unwrap();
        let r1 = sys
            .actor_of(atomr_core::actor::Props::create(|| EchoReplicaActor), "r1")
            .unwrap();
        let r2 = sys
            .actor_of(atomr_core::actor::Props::create(|| EchoReplicaActor), "r2")
            .unwrap();
        let trainer = sys
            .actor_of(
                DataParallelTrainer::<EchoProtocol>::props(test_config(), vec![r1, r2]),
                "trainer",
            )
            .unwrap();

        // Round-robin: replica 0 gets [1.0] and [6.0] → loss 3.5 over 2 samples;
        // replica 1 gets [2.0] → loss 2.0 over 1 sample.
        // Weighted mean = (3.5*2 + 2.0*1) / 3 = 3.0. An unweighted mean of the
        // two replica losses would give 2.75 instead.
        let stats = run_step(&trainer, vec![sample(&[1.0]), sample(&[2.0]), sample(&[6.0])])
            .await
            .unwrap();
        assert!((stats.loss - 3.0).abs() < 1e-5);
        assert_eq!(stats.grad_norm, 1.0);

        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn step_without_replicas_fails() {
        let sys = ActorSystem::create("trainer-test-empty", Config::empty())
            .await
            .unwrap();
        let trainer = sys
            .actor_of(
                DataParallelTrainer::<EchoProtocol>::props(test_config(), Vec::new()),
                "trainer",
            )
            .unwrap();

        let result = run_step(&trainer, vec![sample(&[1.0])]).await;
        assert!(matches!(result, Err(GpuError::Unrecoverable(_))));

        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn set_replicas_enables_stepping() {
        let sys = ActorSystem::create("trainer-test-late-bind", Config::empty())
            .await
            .unwrap();
        let r1 = sys
            .actor_of(atomr_core::actor::Props::create(|| EchoReplicaActor), "r1")
            .unwrap();
        let trainer = sys
            .actor_of(
                DataParallelTrainer::<EchoProtocol>::props(test_config(), Vec::new()),
                "trainer",
            )
            .unwrap();

        // The mailbox processes messages in order, so the step sees the new replicas.
        trainer.tell(TrainerMsg::SetReplicas { replicas: vec![r1] });
        let stats = run_step(&trainer, vec![sample(&[1.0, 2.0])]).await.unwrap();
        assert!((stats.loss - 3.0).abs() < 1e-5);
        assert_eq!(stats.grad_norm, 1.0);

        sys.terminate().await;
    }
}