use std::sync::Arc;
use rumus::nn::Parameter;
use rumus::tensor::Tensor;
use crate::collective::CollectiveBarrier;
/// Linear layer whose weight is split across ranks of a tensor-parallel group,
/// presumably along the output (column) dimension — confirm against the code
/// that constructs the shards.
///
/// [`forward`](Self::forward) is purely local: it does not gather outputs.
/// [`allreduce_grad_x`](Self::allreduce_grad_x) sums the input gradient across
/// ranks through the shared [`CollectiveBarrier`].
pub struct ColumnParallelLinear {
    /// This rank's weight shard; the input is right-multiplied by it in `forward`.
    pub weight: Parameter,
    /// Optional bias shard, added to the matmul output in `forward`.
    pub bias: Option<Parameter>,
    /// Index of this rank in the group (not read by the methods below).
    pub rank: usize,
    /// Number of ranks in the group (not read by the methods below).
    pub world_size: usize,
    /// Collective shared by all ranks; used by `allreduce_grad_x`.
    pub barrier: Arc<CollectiveBarrier>,
}

impl ColumnParallelLinear {
    /// Computes `x · weight`, then adds `bias` when one is present.
    ///
    /// No cross-rank communication happens here; the result holds only this
    /// rank's slice of the output.
    pub fn forward(&self, x: &Tensor) -> Tensor {
        let out = x.matmul(&self.weight.tensor);
        if let Some(bias_param) = &self.bias {
            out.add_bias(&bias_param.tensor)
        } else {
            out
        }
    }

    /// Sums `grad_x` across all ranks and returns the result as a new tensor
    /// with the same shape.
    ///
    /// The data is copied to a host `Vec` and handed to `barrier.reduce`, so
    /// this call blocks until the other ranks take part. The returned tensor
    /// is a fresh tensor built from the reduced values.
    pub fn allreduce_grad_x(&self, grad_x: &Tensor) -> Tensor {
        let dims = grad_x.shape().to_vec();
        // The temporary returned by `data()` is dropped at the end of this
        // statement, so it is not held while waiting on the barrier below.
        let host_values = grad_x.data().to_vec();
        let summed = self.barrier.reduce(host_values);
        let gathered = Tensor::new(summed, dims);
        // NOTE(review): the return value of `to_gpu` is ignored, which assumes
        // it moves storage in place; it also runs even if `grad_x` was on the
        // CPU — confirm against rumus.
        gathered.to_gpu();
        gathered
    }
}
/// Linear layer whose weight is split across ranks of a tensor-parallel group,
/// presumably along the input (row) dimension — confirm against the code that
/// constructs the shards.
///
/// Each rank multiplies its input shard by its weight shard. The partial
/// products are then summed across ranks with the shared
/// [`CollectiveBarrier`], so every rank ends up with the same full output.
pub struct RowParallelLinear {
    /// This rank's weight shard; the input shard is right-multiplied by it.
    pub weight: Parameter,
    /// Optional bias, replicated on every rank and added after the reduce.
    pub bias: Option<Parameter>,
    /// Index of this rank in the group (not read by `forward`).
    pub rank: usize,
    /// Number of ranks in the group (not read by `forward`).
    pub world_size: usize,
    /// Collective shared by all ranks; used to sum the partial products.
    pub barrier: Arc<CollectiveBarrier>,
}

impl RowParallelLinear {
    /// Computes `sum_over_ranks(x_t · weight)`, then adds `bias` when present.
    ///
    /// Blocks in `barrier.reduce` until all ranks take part. Every rank
    /// returns the same tensor.
    ///
    /// The bias is added on *every* rank after the reduce. The reduced
    /// activations are identical across ranks, so adding the replicated bias
    /// only on rank 0 (as before) left rank 0's output different from all
    /// other ranks' outputs. That desynchronised downstream layers and meant
    /// only rank 0's bias replica was used.
    pub fn forward(&self, x_t: &Tensor) -> Tensor {
        let y_partial = x_t.matmul(&self.weight.tensor);
        // Copy to the host in a scope of its own so the `data()` guard is
        // released before blocking on the barrier.
        let host_partial = {
            let guard = y_partial.data();
            guard.to_vec()
        };
        let reduced = self.barrier.reduce(host_partial);
        // NOTE(review): `Tensor::new` builds a fresh tensor from host values.
        // Presumably this is not linked to `y_partial` in any autograd graph,
        // so the weight gradient must be propagated by the caller — confirm.
        let y = Tensor::new(reduced, y_partial.shape().to_vec());
        // NOTE(review): the return value of `to_gpu` is ignored, which assumes
        // it moves storage in place — confirm against rumus.
        y.to_gpu();
        match &self.bias {
            Some(b) => y.add_bias(&b.tensor),
            None => y,
        }
    }
}