// cubecl-common 0.10.0-pre.4
//
// Common crate for CubeCL.
use crate::{
    device::{
        DeviceId, DeviceService,
        handle::{CallError, DeviceHandleSpec, ServerUtilitiesHandle, ServiceCreationError},
    },
    stream_id::StreamId,
    stub::{Arc, Mutex, RwLock},
};
use alloc::boxed::Box;
use core::{
    any::{Any, TypeId},
    marker::PhantomData,
};
use hashbrown::HashMap;

/// A handle to a specific device context (no-std version).
///
/// Handles for the same device and service type share one service instance (through the
/// `Arc` inside `state`); every call locks that service's mutex and runs synchronously.
pub struct MutexDeviceHandle<S: DeviceService> {
    /// Shared service state obtained from the global registry.
    state: MutexDeviceState,
    /// The device this handle targets.
    device_id: DeviceId,
    // `PhantomData<fn(S)>` is `Send + Sync` for every `S`, so this marker adds no auto-trait
    // bounds of its own: the handle never holds an `S` directly and only reaches it through
    // the type-erased service mutex in `state`.
    _phantom: PhantomData<fn(S)>,
}

/// State for one (device, service type) pair, stored in the global registry and cloned into
/// every handle that refers to it.
#[derive(Clone)]
struct MutexDeviceState {
    /// The type-erased service; handles downcast it back to their `S` on each call.
    service: Arc<Mutex<Box<dyn Any + Send>>>,
    /// Device the service was created for.
    device_id: DeviceId,
    /// Utilities obtained from the service when it was registered.
    utilities: ServerUtilitiesHandle,
}

/// The global storage for all device services, keyed by device.
/// In no-std, we use a global registry protected by a `spin::Mutex`. It starts as `None`;
/// the map is created on first access.
static DEVICE_REGISTRY: spin::Mutex<Option<HashMap<DeviceId, DeviceRegistry>>> =
    spin::Mutex::new(None);

/// The services registered for one device: maps the service's `TypeId` to the shared
/// `MutexDeviceState` that wraps the service instance.
type DeviceRegistry = HashMap<TypeId, MutexDeviceState>;

impl<S: DeviceService + 'static> DeviceHandleSpec<S> for MutexDeviceHandle<S> {
    /// Every call is executed synchronously on the calling thread.
    const BLOCKING: bool = true;

    /// Returns a handle to the service `S` on `device_id`, creating it with `S::init` the first
    /// time it is requested for that device.
    ///
    /// `S::init` runs while the global registry spin lock is held. It must not (directly or
    /// indirectly) create another `MutexDeviceHandle`, because `spin::Mutex` is not reentrant
    /// and the nested lock attempt would spin forever.
    fn new(device_id: DeviceId) -> Self {
        let mut registry = DEVICE_REGISTRY.lock();
        let services = registry
            .get_or_insert_with(HashMap::new)
            .entry(device_id)
            .or_insert_with(HashMap::new);

        let state = services
            .entry(TypeId::of::<S>())
            .or_insert_with(|| {
                let service = S::init(device_id);
                let utilities = service.utilities();
                MutexDeviceState {
                    service: Arc::new(Mutex::new(Box::new(service))),
                    device_id,
                    utilities,
                }
            })
            .clone();

        Self {
            state,
            device_id,
            _phantom: PhantomData,
        }
    }

    fn device_id(&self) -> DeviceId {
        self.device_id
    }

    fn utilities(&self) -> ServerUtilitiesHandle {
        self.state.utilities.clone()
    }

    /// No-op: tasks run synchronously in `submit`/`submit_blocking`, so nothing is ever queued.
    fn flush_queue(&self) {}

    /// Runs `task` immediately on the calling thread with exclusive access to the service and
    /// returns its result. This implementation never returns `Err`.
    ///
    /// # Panics
    /// If the stored service is not an `S` ("State type mismatch"), or if the service mutex's
    /// `lock()` reports an error.
    fn submit_blocking<'a, R: Send, T: FnOnce(&mut S) -> R + Send + 'a>(
        &self,
        task: T,
    ) -> Result<R, CallError> {
        let mut erased = self.state.service.lock().unwrap();
        let service = erased.downcast_mut::<S>().expect("State type mismatch");

        Ok(task(service))
    }

    /// Runs `task` immediately on the calling thread with exclusive access to the service.
    ///
    /// # Panics
    /// Same conditions as `submit_blocking`.
    fn submit<T: FnOnce(&mut S) + Send + 'static>(&self, task: T) {
        let mut erased = self.state.service.lock().unwrap();
        let service = erased.downcast_mut::<S>().expect("State type mismatch");

        task(service);
    }

    /// Registers an already-built `service` for `device_id` and returns a handle to it.
    ///
    /// # Errors
    /// Returns `ServiceCreationError` if a service of type `S` already exists for this device;
    /// in that case `service` is dropped and the existing one is left untouched.
    fn insert(device_id: DeviceId, service: S) -> Result<Self, ServiceCreationError> {
        let mut registry = DEVICE_REGISTRY.lock();
        let services = registry
            .get_or_insert_with(HashMap::new)
            .entry(device_id)
            .or_insert_with(HashMap::new);

        let type_id = TypeId::of::<S>();
        if services.contains_key(&type_id) {
            return Err(ServiceCreationError::new("Service already created".into()));
        }

        let utilities = service.utilities();
        let state = MutexDeviceState {
            service: Arc::new(Mutex::new(Box::new(service))),
            device_id,
            utilities,
        };
        // The registry and the returned handle share the same `Arc`'d service.
        services.insert(type_id, state.clone());

        Ok(Self {
            state,
            device_id,
            _phantom: PhantomData,
        })
    }

    /// Runs `task` while holding this device's `DeviceLock`.
    ///
    /// Calls from different streams on the same device are serialized. A nested call from the
    /// stream that already holds the lock re-enters without blocking.
    fn exclusive<R: Send, T: FnOnce() -> R + Send>(&self, task: T) -> Result<R, CallError> {
        let device_lock = self.device_lock();
        // Bound to a named variable (not `_`) so the guard lives until `task` has returned.
        let _guard = device_lock.lock();
        Ok(task())
    }
}

/// Global table of per-device locks used by `exclusive`. It starts as `None`; the map is
/// created on first access.
static DEVICE_LOCK: spin::Mutex<Option<HashMap<DeviceId, Arc<DeviceLock>>>> =
    spin::Mutex::new(None);

/// A per-device lock that can be re-acquired by the `StreamId` that currently owns it.
struct DeviceLock {
    /// The stream currently owning the lock, or `None` when it is free.
    /// (Only `write()` is used on this `RwLock` in this module.)
    lock: RwLock<Option<StreamId>>,
    /// Held by the owning stream for as long as its outermost `Guard::Main` is alive.
    main: spin::Mutex<()>,
}

/// Guard returned by `DeviceLock::lock`.
enum Guard<'a> {
    /// Nested acquisition by the stream that already owns the lock; dropping it does nothing.
    Reentrant,
    /// Outermost acquisition: holds the device's `main` mutex plus a reference to the
    /// `DeviceLock`, so the owner can be cleared when the guard is dropped.
    Main(spin::MutexGuard<'a, ()>, &'a DeviceLock),
}

impl<'a> Drop for Guard<'a> {
    /// Releases device ownership held by an outermost guard; reentrant guards are inert.
    fn drop(&mut self) {
        // The owner is cleared first. The `main` mutex guard (a field) is released afterwards,
        // when the fields themselves are dropped.
        if let Guard::Main(_, device_lock) = self {
            *device_lock.lock.write().unwrap() = None;
        }
    }
}

impl DeviceLock {
    /// Acquires the device for the current stream, busy-waiting while another stream owns it.
    ///
    /// Returns `Guard::Main` when the lock was free (the caller becomes the owner and holds
    /// `main`). Returns `Guard::Reentrant` when the current stream already owns it.
    pub fn lock(&self) -> Guard<'_> {
        // TODO: Use thread id when we can.
        let stream_id = StreamId::current();

        loop {
            let mut owner = self.lock.write().unwrap();

            let reentrant = match owner.as_ref() {
                Some(current) => *current == stream_id,
                None => {
                    *owner = Some(stream_id);
                    // `main` is acquired while `owner` is still write-locked; `owner` is only
                    // released when this function returns.
                    let guard = self.main.lock();
                    return Guard::Main(guard, self);
                }
            };

            if reentrant {
                core::mem::drop(owner);
                return Guard::Reentrant;
            }

            // Another stream owns the device. Release the state lock so the owner can clear
            // it on drop, then hint the CPU that we are spinning before retrying.
            core::mem::drop(owner);
            core::hint::spin_loop();
        }
    }
}

impl<S: DeviceService> MutexDeviceHandle<S> {
    /// Returns the `DeviceLock` for this handle's device, creating it on first use.
    ///
    /// The table is keyed by device only, so all service types on one device share the lock.
    fn device_lock(&self) -> Arc<DeviceLock> {
        let mut locks = DEVICE_LOCK.lock();
        let device_lock = locks
            .get_or_insert_with(HashMap::new)
            .entry(self.device_id)
            .or_insert_with(|| {
                Arc::new(DeviceLock {
                    lock: RwLock::new(None),
                    main: spin::Mutex::new(()),
                })
            });

        Arc::clone(device_lock)
    }
}
/// Implemented by hand because `#[derive(Clone)]` would add an `S: Clone` bound; the handle
/// never stores an `S`, so cloning only copies the shared state and the device id.
impl<S: DeviceService> Clone for MutexDeviceHandle<S> {
    fn clone(&self) -> Self {
        let Self {
            state,
            device_id,
            _phantom,
        } = self;

        Self {
            device_id: *device_id,
            state: state.clone(),
            _phantom: PhantomData,
        }
    }
}