cubecl-cpu 0.10.0-pre.3

CPU runtime for CubeCL
use crate::{
    compiler::{MlirCompiler, register_supported_types},
    compute::server::CpuServer,
    device::CpuDevice,
};
use cubecl_common::{device::DeviceService, profile::TimingMethod};
use cubecl_core::{
    MemoryConfiguration, Runtime,
    client::ComputeClient,
    device::{DeviceId, ServerUtilitiesHandle},
    ir::{
        DeviceProperties, HardwareProperties, MemoryDeviceProperties, TargetProperties, VectorSize,
        features::Features,
    },
    server::ServerUtilities,
    zspace::{Shape, Strides},
};
use cubecl_runtime::{allocator::ContiguousMemoryLayoutPolicy, logging::ServerLogger};
use cubecl_std::tensor::is_contiguous;
use std::sync::Arc;
use sysinfo::System;

/// Configuration for the CPU runtime.
///
/// Currently only carries the memory-management settings that are forwarded to the
/// CPU server when it is constructed.
#[derive(Default)]
pub struct RuntimeOptions {
    /// Configures the memory management.
    pub memory_config: MemoryConfiguration,
}

/// Marker type implementing [`Runtime`] for the CPU backend.
///
/// Carries no state; the associated server, device and compiler types are declared in
/// its `Runtime` implementation.
#[derive(Debug, Clone)]
pub struct CpuRuntime;

/// Compiler used by the CPU runtime (an alias for [`MlirCompiler`]).
pub type CpuCompiler = MlirCompiler;

/// Total memory, in bytes, that the current process may use on this host.
///
/// Uses the cgroup memory limit when one is reported, otherwise the machine's total
/// memory. Only memory information is refreshed: `System::new_all()` would also
/// enumerate every process and CPU, none of which is needed here.
fn cpu_memory_limit_bytes() -> usize {
    let mut system = System::new();
    system.refresh_memory();
    let bytes = system
        .cgroup_limits()
        .map_or_else(|| system.total_memory(), |limits| limits.total_memory);
    // Saturate rather than silently truncate on targets where `usize` is 32 bits.
    usize::try_from(bytes).unwrap_or(usize::MAX)
}

/// Number of threads the platform reports as usable in parallel.
///
/// Falls back to 1 when the platform cannot report it. The value is only an optional
/// hardware hint, so that case should not abort device initialisation.
fn cpu_core_count() -> u32 {
    std::thread::available_parallelism()
        .map_or(1, |n| u32::try_from(n.get()).unwrap_or(u32::MAX))
}

impl DeviceService for CpuServer {
    /// Builds the CPU compute server, describing the host machine as a single device.
    ///
    /// Cube dimensions/counts, bindings and units per cube are advertised as `u32::MAX`;
    /// the shared-memory size and the maximum memory page size are both set to the
    /// memory available to the process (see `cpu_memory_limit_bytes`).
    fn init(_device_id: cubecl_common::device::DeviceId) -> Self {
        // Alignment (in bytes) used both for the device memory properties and the
        // contiguous memory layout policy, so the two always agree.
        const ALIGNMENT: u64 = 8;

        let options = RuntimeOptions::default();
        let max_cube_dim = (u32::MAX, u32::MAX, u32::MAX);
        let max_cube_count = (u32::MAX, u32::MAX, u32::MAX);
        let max_shared_memory_size = cpu_memory_limit_bytes();
        let logger = cubecl_common::stub::Arc::new(ServerLogger::default());

        let topology = HardwareProperties {
            load_width: 512,
            plane_size_min: 1,
            plane_size_max: 1,
            max_bindings: u32::MAX,
            max_shared_memory_size,
            max_cube_count,
            num_cpu_cores: Some(cpu_core_count()),
            max_units_per_cube: u32::MAX,
            max_cube_dim,
            num_streaming_multiprocessors: None,
            num_tensor_cores: None,
            min_tensor_cores_dim: None,
            max_vector_size: VectorSize::MAX,
        };

        let mem_properties = MemoryDeviceProperties {
            max_page_size: max_shared_memory_size as u64,
            alignment: ALIGNMENT,
        };

        let mut device_props = DeviceProperties::new(
            Features {
                unaligned_io: true,
                ..Default::default()
            },
            mem_properties.clone(),
            topology,
            TimingMethod::Device,
        );
        register_supported_types(&mut device_props);

        let utilities = ServerUtilities::new(
            device_props,
            logger,
            (),
            ContiguousMemoryLayoutPolicy::new(ALIGNMENT as usize),
        );
        CpuServer::new(mem_properties, options.memory_config, Arc::new(utilities))
    }

    /// Returns the server's utilities handle.
    ///
    /// NOTE(review): this relies on `CpuServer` having an inherent `utilities()` method.
    /// Method resolution prefers inherent methods over trait methods, so the call below
    /// does not recurse into this trait method. Confirm if that inherent method is ever
    /// renamed.
    fn utilities(&self) -> ServerUtilitiesHandle {
        self.utilities() as ServerUtilitiesHandle
    }
}

impl Runtime for CpuRuntime {
    type Compiler = CpuCompiler;
    type Server = CpuServer;
    type Device = CpuDevice;

    /// Obtains the compute client for `device` through `ComputeClient::load`.
    fn client(device: &Self::Device) -> ComputeClient<Self> {
        ComputeClient::load(device)
    }

    /// Backend identifier; the same for every client.
    fn name(_client: &ComputeClient<Self>) -> &'static str {
        "cpu"
    }

    /// Every axis of the cube count is capped only by the `u32` range.
    fn max_cube_count() -> (u32, u32, u32) {
        let unbounded = u32::MAX;
        (unbounded, unbounded, unbounded)
    }

    /// A tensor can be read back directly only when its layout is contiguous.
    fn can_read_tensor(shape: &Shape, strides: &Strides) -> bool {
        is_contiguous(shape, strides)
    }

    /// Target properties with the MMA settings left at their defaults.
    ///
    /// NOTE(review): presumably the MMA values go unused because this backend exposes
    /// no manual MMA. Confirm before relying on them.
    fn target_properties() -> TargetProperties {
        let mma = Default::default();
        TargetProperties { mma }
    }

    /// Exposes the host as exactly one device (`type_id` 0, `index_id` 0), whatever
    /// arguments are passed.
    fn enumerate_devices(
        _device_type: u16,
        _info: &<Self::Server as cubecl_core::server::ComputeServer>::Info,
    ) -> Vec<DeviceId> {
        let host = DeviceId {
            type_id: 0,
            index_id: 0,
        };
        Vec::from([host])
    }
}