use std::time::Instant;
use async_trait::async_trait;
use atomr_core::actor::{Actor, ActorRef, Context, Props};
use tokio::sync::oneshot;
use atomr_accel_cuda::error::GpuError;
use crate::optimizer::{OptimizerKind, StepStats};
/// Adapter that lets the tensor-parallel trainer talk to shard actors without
/// knowing their concrete message type.
///
/// The trainer only ever asks the protocol to build a "run one step" message;
/// everything else about the shard's message enum stays opaque to it.
pub trait ShardProtocol: Send + 'static {
/// Message type accepted by the shard actors (the trainer holds
/// `ActorRef<Self::Msg>` handles).
type Msg: Send + 'static;
/// Builds the shard message that requests one step.
///
/// * `input_slice` - the portion of the step input assigned to this shard.
/// * `reply` - one-shot channel on which the shard is expected to send its
///   [`ShardStepResult`] or a [`GpuError`].
fn make_step(
input_slice: Vec<f32>,
reply: oneshot::Sender<Result<ShardStepResult, GpuError>>,
) -> Self::Msg;
}
/// Result reported by a single shard for one step.
#[derive(Debug, Clone, Default)]
pub struct ShardStepResult {
/// This shard's contribution to the output; the trainer adds the
/// contributions of all shards element-wise.
pub partial_output: Vec<f32>,
/// Loss reported by the shard; the trainer averages it across shards
/// weighted by `samples`.
pub loss: f32,
/// Gradient norm reported by the shard; averaged across shards weighted by
/// `samples`, like `loss`.
pub grad_norm: f32,
/// Number of samples behind this result; used as the weight when averaging
/// `loss` and `grad_norm`.
pub samples: usize,
}
/// Configuration carried by [`TensorParallelTrainer`].
#[derive(Debug, Clone)]
pub struct TensorParallelConfig {
/// Configured number of shards.
// NOTE(review): the step handler in this file uses the length of the shard
// list it was given, not this value - confirm where this field is consumed.
pub shard_count: usize,
/// Optimizer selection.
// NOTE(review): not consulted by the step handler in this file - presumably
// used by the shard actors or other callers; verify.
pub optimizer: OptimizerKind,
}
/// Messages handled by [`TensorParallelTrainer`].
pub enum TensorParallelMsg<P: ShardProtocol> {
/// Run one step across all shards.
Step {
/// Flat input; the trainer splits it into one contiguous chunk per shard.
input: Vec<f32>,
/// Receives the element-wise sum of the shards' partial outputs together
/// with the aggregated [`StepStats`], or the first error encountered.
reply: oneshot::Sender<Result<(Vec<f32>, StepStats), GpuError>>,
},
/// Never carries data; exists only so the enum can mention `P`. The
/// handler ignores it.
#[doc(hidden)]
_Phantom(std::marker::PhantomData<fn() -> P>),
}
/// Actor that fans a step out to a set of shard actors and combines their
/// partial results into a single reply.
pub struct TensorParallelTrainer<P: ShardProtocol> {
// Stored configuration; the step handler only touches it, it does not
// branch on its contents.
config: TensorParallelConfig,
// Handles to the shard actors, in the order in which input chunks are
// assigned to them.
shards: Vec<ActorRef<P::Msg>>,
}
impl<P: ShardProtocol> TensorParallelTrainer<P> {
    /// Builds the [`Props`] from which a trainer over `shards` can be created.
    ///
    /// The factory closure takes ownership of `config` and `shards` and clones
    /// both each time it runs, so every trainer it produces gets its own copy
    /// of the configuration and of the shard handles.
    pub fn props(config: TensorParallelConfig, shards: Vec<ActorRef<P::Msg>>) -> Props<Self> {
        Props::create(move || {
            let config = config.clone();
            let shards = shards.clone();
            Self { config, shards }
        })
    }
}
#[async_trait]
impl<P: ShardProtocol> Actor for TensorParallelTrainer<P> {
    type Msg = TensorParallelMsg<P>;

    /// Handles one trainer message.
    ///
    /// For [`TensorParallelMsg::Step`] the input is cut into one contiguous
    /// chunk per shard (chunk length is `ceil(input.len() / shards)`, so
    /// trailing shards may receive a shorter or empty chunk), each shard is
    /// told to run a step, and a detached task gathers the replies:
    ///
    /// * partial outputs are added element-wise (shorter vectors are treated
    ///   as zero-padded),
    /// * `loss` and `grad_norm` are averaged weighted by each shard's
    ///   `samples` (both are `0.0` when no samples were reported),
    /// * `step_micros` is the wall-clock time since the step was accepted.
    ///
    /// The first shard error, or a shard dropping its reply channel, aborts
    /// the aggregation and is forwarded to the caller. With no shards
    /// configured the caller immediately receives an `Unrecoverable` error.
    /// A failed send on `reply` (caller gone) is ignored.
    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: TensorParallelMsg<P>) {
        let (input, reply) = match msg {
            TensorParallelMsg::_Phantom(_) => return,
            TensorParallelMsg::Step { input, reply } => (input, reply),
        };

        if self.shards.is_empty() {
            let _ = reply.send(Err(GpuError::Unrecoverable(
                "TensorParallelTrainer::Step: no shards".into(),
            )));
            return;
        }

        // The configuration is not consulted here; this keeps the field
        // counted as read.
        let _ = &self.config;
        let started = Instant::now();

        // Fan out: slice the input and dispatch each chunk, keeping the
        // receivers in shard order.
        let shard_total = self.shards.len();
        let len = input.len();
        let stride = len.div_ceil(shard_total);
        let mut pending = Vec::with_capacity(shard_total);
        for (idx, shard) in self.shards.iter().enumerate() {
            let lo = (idx * stride).min(len);
            let hi = ((idx + 1) * stride).min(len);
            let (tx, rx) = oneshot::channel();
            shard.tell(P::make_step(input[lo..hi].to_vec(), tx));
            pending.push(rx);
        }

        // Fan in on a separate task so the actor is free to process its next
        // message while the shards are working.
        tokio::spawn(async move {
            let mut reduced: Option<Vec<f32>> = None;
            let mut loss_sum = 0.0f32;
            let mut grad_sum = 0.0f32;
            let mut sample_count = 0usize;

            for rx in pending {
                let result = match rx.await {
                    Ok(Ok(result)) => result,
                    Ok(Err(e)) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "tensor-parallel: shard dropped reply".into(),
                        )));
                        return;
                    }
                };

                loss_sum += result.loss * result.samples as f32;
                grad_sum += result.grad_norm * result.samples as f32;
                sample_count += result.samples;

                match reduced.as_mut() {
                    // The first reply seeds the accumulator without copying.
                    None => reduced = Some(result.partial_output),
                    Some(acc) => accumulate_partial(acc, &result.partial_output),
                }
            }

            let out = reduced.unwrap_or_default();
            let (loss, grad_norm) = if sample_count > 0 {
                (
                    loss_sum / sample_count as f32,
                    grad_sum / sample_count as f32,
                )
            } else {
                (0.0, 0.0)
            };
            let stats = StepStats {
                loss,
                grad_norm,
                step_micros: started.elapsed().as_micros() as u64,
            };
            let _ = reply.send(Ok((out, stats)));
        });
    }
}

/// Adds `partial` element-wise into `acc`, first zero-extending `acc` when
/// `partial` is the longer of the two.
fn accumulate_partial(acc: &mut Vec<f32>, partial: &[f32]) {
    if acc.len() < partial.len() {
        acc.resize(partial.len(), 0.0);
    }
    // `acc` is now at least as long as `partial`, so the zip visits every
    // element of `partial`.
    for (slot, value) in acc.iter_mut().zip(partial) {
        *slot += *value;
    }
}