pub struct RowParallelLinear {
    pub weight: Parameter,
    pub bias: Option<Parameter>,
    pub rank: usize,
    pub world_size: usize,
    pub barrier: Arc<CollectiveBarrier>,
}
Linear layer with weight sharded along rows (K dimension).
Forward: Y_t = X_t @ W_t (partial sum), then Y = AllReduce(Y_t).
Backward: grad_X_t = grad_Y @ W_t^T (no collective).
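The forward identity can be checked numerically. The following is a minimal sketch, not this crate's implementation: it uses plain `Vec<Vec<f64>>` matrices instead of `Parameter`, simulates the ranks in a loop instead of synchronizing through `CollectiveBarrier`, omits the bias, and assumes AllReduce is an element-wise sum. The helpers `matmul`, `shard`, `all_reduce_sum` and `row_parallel_forward` are hypothetical names introduced only for illustration.

```rust
/// Dense matrix multiply: (m x k) @ (k x n) -> (m x n).
fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = b[0].len();
    a.iter()
        .map(|row| {
            (0..n)
                .map(|j| row.iter().zip(b).map(|(x, b_row)| x * b_row[j]).sum::<f64>())
                .collect()
        })
        .collect()
}

/// Shard for one rank: a block of columns of X and the matching rows of W.
/// Assumes K (the number of rows of W) divides evenly by `world_size`.
fn shard(
    x: &[Vec<f64>],
    w: &[Vec<f64>],
    rank: usize,
    world_size: usize,
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    let k_t = w.len() / world_size;
    let range = rank * k_t..(rank + 1) * k_t;
    let x_t: Vec<Vec<f64>> = x.iter().map(|row| row[range.clone()].to_vec()).collect();
    let w_t: Vec<Vec<f64>> = w[range].to_vec();
    (x_t, w_t)
}

/// Stand-in for AllReduce: element-wise sum of every rank's partial output.
fn all_reduce_sum(partials: &[Vec<Vec<f64>>]) -> Vec<Vec<f64>> {
    let mut y = partials[0].clone();
    for p in &partials[1..] {
        for (y_row, p_row) in y.iter_mut().zip(p) {
            for (y_v, p_v) in y_row.iter_mut().zip(p_row) {
                *y_v += p_v;
            }
        }
    }
    y
}

/// Y_t = X_t @ W_t on each simulated rank, then Y = AllReduce(Y_t).
fn row_parallel_forward(x: &[Vec<f64>], w: &[Vec<f64>], world_size: usize) -> Vec<Vec<f64>> {
    let partials: Vec<Vec<Vec<f64>>> = (0..world_size)
        .map(|rank| {
            let (x_t, w_t) = shard(x, w, rank, world_size);
            matmul(&x_t, &w_t)
        })
        .collect();
    all_reduce_sum(&partials)
}

fn main() {
    let x = vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]]; // 2 x 4
    let w = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![2.0, 0.0], vec![0.0, 2.0]]; // 4 x 2
    // The summed partial products match the unsharded product X @ W.
    assert_eq!(matmul(&x, &w), row_parallel_forward(&x, &w, 2));
}
```

Because each rank holds only its own rows of W, the matching gradient grad_X_t = grad_Y @ W_t^T needs no data from other ranks, which is consistent with the backward pass requiring no collective.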
Fields§
§weight: Parameter
§bias: Option<Parameter>
§rank: usize
§world_size: usize
§barrier: Arc<CollectiveBarrier>
Implementations§
Auto Trait Implementations§
impl Freeze for RowParallelLinear
impl !RefUnwindSafe for RowParallelLinear
impl Send for RowParallelLinear
impl Sync for RowParallelLinear
impl Unpin for RowParallelLinear
impl UnsafeUnpin for RowParallelLinear
impl !UnwindSafe for RowParallelLinear
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.