use cubecl::{prelude::barrier::Barrier, prelude::*};
use crate::{
components::{
global::{SharedGlobalMatmulConfig, read::SyncStrategy},
stage::StageConfig,
},
definition::MatmulTypes,
};
/// [`SyncStrategy`] implementation whose `sync` step only calls
/// `arrive_and_wait` on a `Shared<Barrier>`.
///
/// Unlike `AsyncCopy`, it does not call `commit_copy_async` before waiting;
/// barrier creation is otherwise identical. The type carries no data: its
/// `SyncStrategy` impl consists only of associated functions (no `self`).
pub struct AsyncBarrier {}
// Sync strategy backed by a shared barrier: `create_barrier` builds the
// barrier and `sync` only waits on it (no `commit_copy_async` call).
#[cube]
impl SyncStrategy for AsyncBarrier {
    type Barrier = Shared<Barrier>;

    // Creates the shared barrier. `CUBE_DIM` is passed as the first argument
    // (presumably the participant count) and `UNIT_POS == 0` as the second,
    // so unit 0 is the only unit for which that flag is true.
    // NOTE(review): presumably the flagged ("elected") unit initializes the
    // barrier — confirm against the cubecl `Barrier::shared` documentation.
    fn create_barrier() -> Self::Barrier {
        Barrier::shared(CUBE_DIM, UNIT_POS == 0)
    }

    // Synchronizes by calling `arrive_and_wait` on the barrier.
    // NOTE(review): presumably each unit blocks until all participants have
    // arrived — confirm against cubecl's `Barrier` docs.
    //
    // `MP` and `_config` are unused here; they exist only because this impl
    // must match the `SyncStrategy::sync` signature.
    fn sync<MP: MatmulTypes, S: StageConfig>(
        barrier: &mut Self::Barrier,
        #[comptime] _config: SharedGlobalMatmulConfig<S>,
    ) {
        barrier.arrive_and_wait();
    }
}
/// [`SyncStrategy`] implementation whose `sync` step calls
/// `commit_copy_async` and then `arrive_and_wait` on a `Shared<Barrier>`.
///
/// It differs from `AsyncBarrier` only by that extra `commit_copy_async`
/// call; barrier creation is identical. The type carries no data: its
/// `SyncStrategy` impl consists only of associated functions (no `self`).
pub struct AsyncCopy {}
// Sync strategy backed by a shared barrier that, before waiting, commits
// pending async copies to it via `commit_copy_async`.
#[cube]
impl SyncStrategy for AsyncCopy {
    type Barrier = Shared<Barrier>;

    // Creates the shared barrier exactly like `AsyncBarrier`: `CUBE_DIM` as
    // the first argument (presumably the participant count) and
    // `UNIT_POS == 0` as the second, so only unit 0 gets a `true` flag.
    // NOTE(review): confirm argument semantics against cubecl's
    // `Barrier::shared` documentation.
    fn create_barrier() -> Self::Barrier {
        Barrier::shared(CUBE_DIM, UNIT_POS == 0)
    }

    // Commits async copies to the barrier, then arrives and waits.
    //
    // `MP` and `_config` are unused here; they exist only because this impl
    // must match the `SyncStrategy::sync` signature.
    fn sync<MP: MatmulTypes, S: StageConfig>(
        barrier: &mut Self::Barrier,
        #[comptime] _config: SharedGlobalMatmulConfig<S>,
    ) {
        // The commit is issued before this unit arrives at the barrier.
        // NOTE(review): presumably this ties the unit's in-flight async copies
        // to the current barrier phase so the following wait also covers them
        // — confirm in cubecl's `commit_copy_async` docs before reordering.
        barrier.commit_copy_async();
        barrier.arrive_and_wait();
    }
}