atomr_accel_train/
tensor_parallel.rs1use std::time::Instant;
12
13use async_trait::async_trait;
14use atomr_core::actor::{Actor, ActorRef, Context, Props};
15use tokio::sync::oneshot;
16
17use atomr_accel_cuda::error::GpuError;
18
19use crate::optimizer::{OptimizerKind, StepStats};
20
21pub trait ShardProtocol: Send + 'static {
22 type Msg: Send + 'static;
23 fn make_step(
26 input_slice: Vec<f32>,
27 reply: oneshot::Sender<Result<ShardStepResult, GpuError>>,
28 ) -> Self::Msg;
29}
30
/// Result of one step on a single shard, as reported back to the trainer.
#[derive(Clone, Debug, Default)]
pub struct ShardStepResult {
    /// This shard's contribution to the output; the trainer sums these
    /// element-wise across shards.
    pub partial_output: Vec<f32>,
    /// Loss reported by the shard; combined across shards as a mean weighted
    /// by `samples`.
    pub loss: f32,
    /// Gradient norm reported by the shard; combined the same way as `loss`.
    pub grad_norm: f32,
    /// Number of samples this result covers; used as the averaging weight.
    pub samples: usize,
}
38
39#[derive(Debug, Clone)]
40pub struct TensorParallelConfig {
41 pub shard_count: usize,
42 pub optimizer: OptimizerKind,
43}
44
45pub enum TensorParallelMsg<P: ShardProtocol> {
46 Step {
47 input: Vec<f32>,
48 reply: oneshot::Sender<Result<(Vec<f32>, StepStats), GpuError>>,
49 },
50 #[doc(hidden)]
51 _Phantom(std::marker::PhantomData<fn() -> P>),
52}
53
54pub struct TensorParallelTrainer<P: ShardProtocol> {
55 config: TensorParallelConfig,
56 shards: Vec<ActorRef<P::Msg>>,
57}
58
59impl<P: ShardProtocol> TensorParallelTrainer<P> {
60 pub fn props(config: TensorParallelConfig, shards: Vec<ActorRef<P::Msg>>) -> Props<Self> {
61 Props::create(move || TensorParallelTrainer {
62 config: config.clone(),
63 shards: shards.clone(),
64 })
65 }
66}
67
68#[async_trait]
69impl<P: ShardProtocol> Actor for TensorParallelTrainer<P> {
70 type Msg = TensorParallelMsg<P>;
71
72 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: TensorParallelMsg<P>) {
73 match msg {
74 TensorParallelMsg::_Phantom(_) => {}
75 TensorParallelMsg::Step { input, reply } => {
76 if self.shards.is_empty() {
77 let _ = reply.send(Err(GpuError::Unrecoverable(
78 "TensorParallelTrainer::Step: no shards".into(),
79 )));
80 return;
81 }
82 let _ = &self.config;
83 let started = Instant::now();
84 let n = self.shards.len();
85 let chunk_size = input.len().div_ceil(n);
88 let mut chunks: Vec<Vec<f32>> = Vec::with_capacity(n);
89 for i in 0..n {
90 let lo = (i * chunk_size).min(input.len());
91 let hi = ((i + 1) * chunk_size).min(input.len());
92 chunks.push(input[lo..hi].to_vec());
93 }
94 let mut rxs = Vec::with_capacity(n);
95 for (s, chunk) in self.shards.iter().zip(chunks) {
96 let (tx, rx) = oneshot::channel();
97 s.tell(P::make_step(chunk, tx));
98 rxs.push(rx);
99 }
100 tokio::spawn(async move {
101 let mut summed: Option<Vec<f32>> = None;
102 let mut total_loss = 0.0f32;
103 let mut total_grad = 0.0f32;
104 let mut total_samples = 0usize;
105 for rx in rxs {
106 match rx.await {
107 Ok(Ok(r)) => {
108 match summed.as_mut() {
109 None => summed = Some(r.partial_output),
110 Some(acc) => {
111 if acc.len() != r.partial_output.len() {
112 let m = acc.len().max(r.partial_output.len());
114 acc.resize(m, 0.0);
115 for (i, v) in r.partial_output.iter().enumerate() {
116 acc[i] += *v;
117 }
118 } else {
119 for (i, v) in r.partial_output.iter().enumerate() {
120 acc[i] += *v;
121 }
122 }
123 }
124 }
125 total_loss += r.loss * r.samples as f32;
126 total_grad += r.grad_norm * r.samples as f32;
127 total_samples += r.samples;
128 }
129 Ok(Err(e)) => {
130 let _ = reply.send(Err(e));
131 return;
132 }
133 Err(_) => {
134 let _ = reply.send(Err(GpuError::Unrecoverable(
135 "tensor-parallel: shard dropped reply".into(),
136 )));
137 return;
138 }
139 }
140 }
141 let out = summed.unwrap_or_default();
142 let stats = StepStats {
143 loss: if total_samples > 0 {
144 total_loss / total_samples as f32
145 } else {
146 0.0
147 },
148 grad_norm: if total_samples > 0 {
149 total_grad / total_samples as f32
150 } else {
151 0.0
152 },
153 step_micros: started.elapsed().as_micros() as u64,
154 };
155 let _ = reply.send(Ok((out, stats)));
156 });
157 }
158 }
159 }
160}