Skip to main content

atomr_accel_train/
tensor_parallel.rs

//! `TensorParallelTrainer` — weight-sharded matmul: each replica
//! owns a slice of the weight matrix; activations are split, each
//! shard runs a partial matmul, and the partial results are summed
//! via `AllReduce`.
//!
//! F6 ships the public surface plus a host-side reference: each shard
//! implements [`ShardProtocol`], receives a contiguous slice of the
//! input, and returns its partial output. The trainer collects all
//! partials and sums them element-wise on the host.
10
11use std::time::Instant;
12
13use async_trait::async_trait;
14use atomr_core::actor::{Actor, ActorRef, Context, Props};
15use tokio::sync::oneshot;
16
17use atomr_accel_cuda::error::GpuError;
18
19use crate::optimizer::{OptimizerKind, StepStats};
20
21pub trait ShardProtocol: Send + 'static {
22    type Msg: Send + 'static;
23    /// Per-shard step: takes a slice of input + the local weight
24    /// shard, returns `(partial_output, partial_loss, partial_grad_norm)`.
25    fn make_step(
26        input_slice: Vec<f32>,
27        reply: oneshot::Sender<Result<ShardStepResult, GpuError>>,
28    ) -> Self::Msg;
29}
30
/// One shard's answer to a step request built by [`ShardProtocol::make_step`].
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ShardStepResult {
    /// This shard's partial output. The trainer sums partials element-wise,
    /// treating shorter ones as zero-padded.
    pub partial_output: Vec<f32>,
    /// Shard-local loss; the trainer averages it weighted by `samples`.
    pub loss: f32,
    /// Shard-local gradient norm; the trainer averages it weighted by `samples`.
    pub grad_norm: f32,
    /// Number of samples this shard processed; used as the averaging weight.
    pub samples: usize,
}
38
39#[derive(Debug, Clone)]
40pub struct TensorParallelConfig {
41    pub shard_count: usize,
42    pub optimizer: OptimizerKind,
43}
44
45pub enum TensorParallelMsg<P: ShardProtocol> {
46    Step {
47        input: Vec<f32>,
48        reply: oneshot::Sender<Result<(Vec<f32>, StepStats), GpuError>>,
49    },
50    #[doc(hidden)]
51    _Phantom(std::marker::PhantomData<fn() -> P>),
52}
53
54pub struct TensorParallelTrainer<P: ShardProtocol> {
55    config: TensorParallelConfig,
56    shards: Vec<ActorRef<P::Msg>>,
57}
58
59impl<P: ShardProtocol> TensorParallelTrainer<P> {
60    pub fn props(config: TensorParallelConfig, shards: Vec<ActorRef<P::Msg>>) -> Props<Self> {
61        Props::create(move || TensorParallelTrainer {
62            config: config.clone(),
63            shards: shards.clone(),
64        })
65    }
66}
67
68#[async_trait]
69impl<P: ShardProtocol> Actor for TensorParallelTrainer<P> {
70    type Msg = TensorParallelMsg<P>;
71
72    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: TensorParallelMsg<P>) {
73        match msg {
74            TensorParallelMsg::_Phantom(_) => {}
75            TensorParallelMsg::Step { input, reply } => {
76                if self.shards.is_empty() {
77                    let _ = reply.send(Err(GpuError::Unrecoverable(
78                        "TensorParallelTrainer::Step: no shards".into(),
79                    )));
80                    return;
81                }
82                let _ = &self.config;
83                let started = Instant::now();
84                let n = self.shards.len();
85                // Split input by chunks across shards (round-robin
86                // along the row dimension).
87                let chunk_size = input.len().div_ceil(n);
88                let mut chunks: Vec<Vec<f32>> = Vec::with_capacity(n);
89                for i in 0..n {
90                    let lo = (i * chunk_size).min(input.len());
91                    let hi = ((i + 1) * chunk_size).min(input.len());
92                    chunks.push(input[lo..hi].to_vec());
93                }
94                let mut rxs = Vec::with_capacity(n);
95                for (s, chunk) in self.shards.iter().zip(chunks) {
96                    let (tx, rx) = oneshot::channel();
97                    s.tell(P::make_step(chunk, tx));
98                    rxs.push(rx);
99                }
100                tokio::spawn(async move {
101                    let mut summed: Option<Vec<f32>> = None;
102                    let mut total_loss = 0.0f32;
103                    let mut total_grad = 0.0f32;
104                    let mut total_samples = 0usize;
105                    for rx in rxs {
106                        match rx.await {
107                            Ok(Ok(r)) => {
108                                match summed.as_mut() {
109                                    None => summed = Some(r.partial_output),
110                                    Some(acc) => {
111                                        if acc.len() != r.partial_output.len() {
112                                            // Pad to longest.
113                                            let m = acc.len().max(r.partial_output.len());
114                                            acc.resize(m, 0.0);
115                                            for (i, v) in r.partial_output.iter().enumerate() {
116                                                acc[i] += *v;
117                                            }
118                                        } else {
119                                            for (i, v) in r.partial_output.iter().enumerate() {
120                                                acc[i] += *v;
121                                            }
122                                        }
123                                    }
124                                }
125                                total_loss += r.loss * r.samples as f32;
126                                total_grad += r.grad_norm * r.samples as f32;
127                                total_samples += r.samples;
128                            }
129                            Ok(Err(e)) => {
130                                let _ = reply.send(Err(e));
131                                return;
132                            }
133                            Err(_) => {
134                                let _ = reply.send(Err(GpuError::Unrecoverable(
135                                    "tensor-parallel: shard dropped reply".into(),
136                                )));
137                                return;
138                            }
139                        }
140                    }
141                    let out = summed.unwrap_or_default();
142                    let stats = StepStats {
143                        loss: if total_samples > 0 {
144                            total_loss / total_samples as f32
145                        } else {
146                            0.0
147                        },
148                        grad_norm: if total_samples > 0 {
149                            total_grad / total_samples as f32
150                        } else {
151                            0.0
152                        },
153                        step_micros: started.elapsed().as_micros() as u64,
154                    };
155                    let _ = reply.send(Ok((out, stats)));
156                });
157            }
158        }
159    }
160}