1use std::time::Instant;
14
15use async_trait::async_trait;
16use atomr_core::actor::{Actor, ActorRef, Context, Props};
17use tokio::sync::oneshot;
18
19use atomr_accel_cuda::error::GpuError;
20
21use crate::loss::LossKind;
22use crate::optimizer::{OptimizerKind, StepStats};
23
24pub trait ReplicaProtocol: Send + 'static {
27 type Msg: Send + 'static;
28 fn make_step(
29 chunk: Vec<TrainSample>,
30 reply: oneshot::Sender<Result<ReplicaStepResult, GpuError>>,
31 ) -> Self::Msg;
32}
33
/// One training example handed to a replica.
///
/// Derives `PartialEq` in addition to `Debug` and `Clone` so samples can be
/// compared directly (for example with `assert_eq!`). Equality follows `f32`
/// semantics, so a sample containing NaN never equals itself.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainSample {
    /// Feature values for this sample.
    pub features: Vec<f32>,
    /// Label / target values for this sample; may be empty.
    pub label: Vec<f32>,
}
39
/// Outcome reported by a single replica for one training step.
///
/// Derives `PartialEq` in addition to the existing traits so results can be
/// compared directly (for example with `assert_eq!`). Equality follows `f32`
/// semantics, so a result containing NaN never equals itself.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct ReplicaStepResult {
    /// Loss reported by the replica for its chunk.
    pub loss: f32,
    /// Gradient norm reported by the replica for its chunk.
    pub grad_norm: f32,
    /// Number of samples the replica processed; the trainer uses this as the
    /// weight when it averages `loss` and `grad_norm` across replicas.
    pub samples: usize,
}
46
47#[derive(Debug, Clone)]
48pub struct TrainerConfig {
49 pub batch_size_per_device: usize,
50 pub gradient_clip: Option<f32>,
51 pub optimizer: OptimizerKind,
52 pub loss: LossKind,
53}
54
55pub enum TrainerMsg<P: ReplicaProtocol> {
56 Step {
57 batch: Vec<TrainSample>,
58 reply: oneshot::Sender<Result<StepStats, GpuError>>,
59 },
60 SetReplicas { replicas: Vec<ActorRef<P::Msg>> },
64}
65
66pub struct DataParallelTrainer<P: ReplicaProtocol> {
67 config: TrainerConfig,
68 replicas: Vec<ActorRef<P::Msg>>,
69}
70
71impl<P: ReplicaProtocol> DataParallelTrainer<P> {
72 pub fn props(config: TrainerConfig, replicas: Vec<ActorRef<P::Msg>>) -> Props<Self> {
73 Props::create(move || DataParallelTrainer {
74 config: config.clone(),
75 replicas: replicas.clone(),
76 })
77 }
78
79 fn split_batch(&self, batch: Vec<TrainSample>) -> Vec<Vec<TrainSample>> {
80 let n = self.replicas.len().max(1);
81 let mut out: Vec<Vec<TrainSample>> = (0..n).map(|_| Vec::new()).collect();
82 for (i, s) in batch.into_iter().enumerate() {
83 out[i % n].push(s);
84 }
85 out
86 }
87}
88
89#[async_trait]
90impl<P: ReplicaProtocol> Actor for DataParallelTrainer<P> {
91 type Msg = TrainerMsg<P>;
92
93 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: TrainerMsg<P>) {
94 match msg {
95 TrainerMsg::SetReplicas { replicas } => {
96 self.replicas = replicas;
97 }
98 TrainerMsg::Step { batch, reply } => {
99 if self.replicas.is_empty() {
100 let _ = reply.send(Err(GpuError::Unrecoverable(
101 "DataParallelTrainer::Step: no replicas configured".into(),
102 )));
103 return;
104 }
105 let _ = &self.config; let started = Instant::now();
107 let chunks = self.split_batch(batch);
108 let mut rxs = Vec::with_capacity(chunks.len());
109 for (replica, chunk) in self.replicas.iter().zip(chunks) {
110 let (tx, rx) = oneshot::channel();
111 replica.tell(P::make_step(chunk, tx));
112 rxs.push(rx);
113 }
114 tokio::spawn(async move {
115 let mut total_loss = 0.0f32;
116 let mut total_grad_norm = 0.0f32;
117 let mut total_samples = 0usize;
118 for rx in rxs {
119 match rx.await {
120 Ok(Ok(r)) => {
121 total_loss += r.loss * r.samples as f32;
122 total_grad_norm += r.grad_norm * r.samples as f32;
123 total_samples += r.samples;
124 }
125 Ok(Err(e)) => {
126 let _ = reply.send(Err(e));
127 return;
128 }
129 Err(_) => {
130 let _ = reply.send(Err(GpuError::Unrecoverable(
131 "trainer: replica dropped reply".into(),
132 )));
133 return;
134 }
135 }
136 }
137 if total_samples == 0 {
138 let _ = reply
139 .send(Err(GpuError::Unrecoverable("trainer: zero samples".into())));
140 return;
141 }
142 let stats = StepStats {
143 loss: total_loss / total_samples as f32,
144 grad_norm: total_grad_norm / total_samples as f32,
145 step_micros: started.elapsed().as_micros() as u64,
146 };
147 let _ = reply.send(Ok(stats));
148 });
149 }
150 }
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Message accepted by the echo replica used in these tests.
    enum EchoMsg {
        Step {
            chunk: Vec<TrainSample>,
            reply: oneshot::Sender<Result<ReplicaStepResult, GpuError>>,
        },
    }

    /// Replica stand-in: reports the mean per-sample feature sum as the
    /// loss and a constant gradient norm of 1.0.
    struct EchoReplicaActor;

    #[async_trait]
    impl Actor for EchoReplicaActor {
        type Msg = EchoMsg;

        async fn handle(&mut self, _ctx: &mut Context<Self>, msg: EchoMsg) {
            let EchoMsg::Step { chunk, reply } = msg;
            let count = chunk.len();
            // Fold from +0.0 so the accumulation matches a plain `+=` loop.
            let feature_total = chunk
                .iter()
                .fold(0.0f32, |acc, sample| acc + sample.features.iter().sum::<f32>());
            let loss = if count > 0 {
                feature_total / count as f32
            } else {
                0.0
            };
            let _ = reply.send(Ok(ReplicaStepResult {
                loss,
                grad_norm: 1.0,
                samples: count,
            }));
        }
    }

    /// Protocol adapter that wraps a step request into [`EchoMsg::Step`].
    struct EchoProtocol;

    impl ReplicaProtocol for EchoProtocol {
        type Msg = EchoMsg;

        fn make_step(
            chunk: Vec<TrainSample>,
            reply: oneshot::Sender<Result<ReplicaStepResult, GpuError>>,
        ) -> Self::Msg {
            EchoMsg::Step { chunk, reply }
        }
    }

    /// Two replicas, two samples: each replica gets one sample (feature sums
    /// 3 and 7), so the weighted mean loss is 5 and the mean grad norm is 1.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn step_aggregates_across_replicas() {
        let sys = ActorSystem::create("trainer-test", Config::empty())
            .await
            .unwrap();

        let first_replica = sys
            .actor_of(Props::create(|| EchoReplicaActor), "r1")
            .unwrap();
        let second_replica = sys
            .actor_of(Props::create(|| EchoReplicaActor), "r2")
            .unwrap();

        let config = TrainerConfig {
            batch_size_per_device: 1,
            gradient_clip: None,
            optimizer: OptimizerKind::Sgd {
                lr: 0.1,
                momentum: 0.0,
                weight_decay: 0.0,
            },
            loss: LossKind::Mse,
        };
        let trainer_props =
            DataParallelTrainer::<EchoProtocol>::props(config, vec![first_replica, second_replica]);
        let trainer = sys.actor_of(trainer_props, "trainer").unwrap();

        let unlabeled = |features: Vec<f32>| TrainSample {
            features,
            label: Vec::new(),
        };
        let batch = vec![unlabeled(vec![1.0, 2.0]), unlabeled(vec![3.0, 4.0])];

        let (tx, rx) = oneshot::channel();
        trainer.tell(TrainerMsg::Step { batch, reply: tx });

        let within_deadline = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap();
        let step_outcome = within_deadline.unwrap();
        let stats = step_outcome.unwrap();

        assert!((stats.loss - 5.0).abs() < 1e-5);
        assert_eq!(stats.grad_norm, 1.0);

        sys.terminate().await;
    }
}