cubecl-cpu 0.10.0-pre.3

CPU runtime for CubeCL
use cubecl_core::server::ExecutionMode;

use crate::compiler::{mlir_data::MlirData, mlir_engine::MlirEngine};
use std::sync::{
    atomic::{AtomicI32, Ordering},
    mpsc,
};

/// Number of `sync_cube` callers that have arrived at the barrier in the current round.
///
/// `sync_cube` increments it once per caller and the last caller to leave the
/// barrier stores 0 back into it.
pub static BARRIER_COUNTER: AtomicI32 = AtomicI32::new(0);
/// Number of `sync_cube` callers that have left the wait phase of the current round.
///
/// While it is non-zero, new `sync_cube` callers spin before registering their
/// arrival. The caller whose increment brings it up to `BARRIER_TARGET` resets it to 0.
pub static STOPPED_COUNTER: AtomicI32 = AtomicI32::new(0);
/// Number of callers `sync_cube` waits for; a value of 1 or less makes `sync_cube`
/// return immediately.
///
/// NOTE(review): nothing in this part of the file writes this value; it is
/// presumably set by the code that launches a cube — confirm.
pub static BARRIER_TARGET: AtomicI32 = AtomicI32::new(0);
/// Counter that starts at -1 and is decremented by one at the end of
/// `ComputeTask::compute`.
///
/// NOTE(review): presumably tracks how many units of the current cube are still
/// running; the code that initialises it per launch is not visible here — confirm.
pub static CURRENT_CUBE_DIM: AtomicI32 = AtomicI32::new(-1);

/// Spin-waits until `BARRIER_TARGET` callers have reached this function.
///
/// The barrier works in rounds built on two counters:
///
/// 1. **Entry gate** – a caller spins while `STOPPED_COUNTER` is non-zero, so it
///    cannot register for a new round while the counters of the previous one
///    have not been reset yet.
/// 2. **Arrival** – the caller increments `BARRIER_COUNTER`; unless it is the one
///    that brought the counter up to the target, it spins until the counter
///    reaches the target.
/// 3. **Departure** – the caller increments `STOPPED_COUNTER`; the caller whose
///    increment reaches the target resets `BARRIER_COUNTER` and then
///    `STOPPED_COUNTER` to 0, which reopens the entry gate.
///
/// Returns immediately when `BARRIER_TARGET` is 1 or less. Nothing here times
/// out: if fewer than `BARRIER_TARGET` callers ever arrive, the waiters keep
/// spinning.
pub fn sync_cube() {
    let participants = BARRIER_TARGET.load(Ordering::Acquire);
    if participants <= 1 {
        return;
    }

    // Entry gate: hold back until the previous round's counters have been reset.
    let previous_round_draining = || STOPPED_COUNTER.load(Ordering::Acquire) != 0;
    while previous_round_draining() {
        std::hint::spin_loop();
    }

    // Register this caller's arrival.
    std::sync::atomic::fence(Ordering::Release);
    let arrival_rank = BARRIER_COUNTER.fetch_add(1, Ordering::AcqRel) + 1;

    // The caller that completes the round skips the wait (and the extra load);
    // everyone else spins until the arrival count reaches the target.
    let others_missing = || BARRIER_COUNTER.load(Ordering::Acquire) < participants;
    while arrival_rank < participants && others_missing() {
        std::hint::spin_loop();
    }

    std::sync::atomic::fence(Ordering::Acquire);

    // Register the departure; the last caller out resets both counters.
    // `BARRIER_COUNTER` is cleared first so that a caller released by the entry
    // gate (which watches `STOPPED_COUNTER`) starts counting from zero.
    let departure_rank = STOPPED_COUNTER.fetch_add(1, Ordering::AcqRel) + 1;
    if departure_rank == participants {
        BARRIER_COUNTER.store(0, Ordering::Release);
        STOPPED_COUNTER.store(0, Ordering::Release);
    }
}

/// Message that carries either a unit of work or an end-of-work notification.
///
/// NOTE(review): presumably delivered to worker threads over an `mpsc` channel;
/// the receiving side is not in this part of the file — confirm.
pub enum Message {
    /// Carries a [`ComputeTask`].
    ComputeTask(ComputeTask),
    /// Carries a sender of `()`.
    ///
    /// NOTE(review): looks like an acknowledgement channel the receiver signals
    /// on once it has handled the end of the work — confirm against the receiver.
    EndTask(mpsc::Sender<()>),
}

/// Inputs for one kernel invocation at a single unit position.
///
/// Consumed by [`ComputeTask::compute`].
pub struct ComputeTask {
    /// Engine whose `run_kernel` is invoked by `compute`.
    pub mlir_engine: MlirEngine,
    /// Kernel data handed mutably to `run_kernel`; `compute` calls
    /// `push_builtin()` on it and sets the unit position in its `builtin` field first.
    pub mlir_data: MlirData,
    /// Position written into the builtins through `set_unit_pos` before the kernel runs.
    ///
    /// NOTE(review): presumably the (x, y, z) position of the unit inside its cube — confirm.
    pub unit_pos: [u32; 3],
    /// Execution mode attached to the task; `compute` does not read it.
    pub kind: ExecutionMode,
}

impl ComputeTask {
    /// Runs the kernel for this task's unit position, consuming the task.
    ///
    /// In order: calls `push_builtin()` on the kernel data, stores `unit_pos`
    /// in the builtins via `set_unit_pos`, invokes `run_kernel` on the engine
    /// with the kernel data, and finally decrements `CURRENT_CUBE_DIM` by one.
    pub fn compute(mut self) {
        let position = self.unit_pos;

        self.mlir_data.push_builtin();
        self.mlir_data.builtin.set_unit_pos(position);

        // NOTE(review): `run_kernel` needs `unsafe`, but its safety contract is
        // defined outside this part of the file — confirm the invariants it
        // relies on and record them here as a `SAFETY:` comment.
        unsafe {
            self.mlir_engine.run_kernel(&mut self.mlir_data);
        }

        // Decremented only after the kernel call has returned.
        CURRENT_CUBE_DIM.fetch_sub(1, Ordering::AcqRel);
    }
}