use crate::{compiler::mlir_engine::MlirEngine, compute::stream::CpuStream};
use cubecl_common::bytes::Bytes;
use cubecl_core::{
CubeDim, ExecutionMode, MemoryConfiguration, ir::MemoryDeviceProperties,
server::MetadataBindingInfo,
};
use cubecl_runtime::{
logging::ServerLogger,
storage::BytesResource,
stream::{StreamFactory, scheduler::SchedulerStreamBackend},
};
use std::sync::Arc;
/// Unit of work accepted by `ScheduledCpuBackend` (it is that backend's
/// `SchedulerStreamBackend::Task` type) and handed to a `CpuStream` through
/// `enqueue_task`.
pub enum ScheduleTask {
/// Write request pairing a `Bytes` payload with a `BytesResource`.
///
/// NOTE(review): presumably `data` is copied into `buffer` when the task runs;
/// the consumer is not visible in this file, confirm against `CpuStream`.
Write { data: Bytes, buffer: BytesResource },
/// Execution request bundling an engine, its bindings and launch parameters.
///
/// NOTE(review): the per-field descriptions below are inferred from names and
/// types only; confirm against the stream implementation.
Execute {
/// Engine used for the execution (presumably holds the compiled kernel, TODO confirm).
mlir_engine: MlirEngine,
/// Byte resources plus metadata binding info supplied with the task.
bindings: BindingsResource,
/// Execution mode carried along with the task.
kind: ExecutionMode,
/// Cube dimensions carried along with the task.
cube_dim: CubeDim,
/// Cube count as three `u32` values (presumably x, y, z, TODO confirm).
cube_count: [u32; 3],
},
}
/// Manual `Debug` implementation for [`ScheduleTask`].
///
/// `Write` prints both of its fields. `Execute` prints only `kind`, `cube_dim`
/// and `cube_count`; `mlir_engine` and `bindings` are deliberately left out and
/// the output ends with `..` so readers can tell the dump is partial.
impl core::fmt::Debug for ScheduleTask {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Write { data, buffer } => f
                .debug_struct("Write")
                .field("data", data)
                .field("buffer", buffer)
                .finish(),
            // The two skipped fields are named explicitly (rather than using `..`
            // in the pattern) so that adding a field to the variant forces a
            // decision about whether it should be printed.
            Self::Execute {
                mlir_engine: _,
                bindings: _,
                kind,
                cube_dim,
                cube_count,
            } => f
                .debug_struct("Execute")
                .field("kind", kind)
                .field("cube_dim", cube_dim)
                .field("cube_count", cube_count)
                // Signals that fields were omitted instead of presenting the
                // three printed fields as the complete variant.
                .finish_non_exhaustive(),
        }
    }
}
/// Bundle of byte resources and metadata binding info, used as the `bindings`
/// field of `ScheduleTask::Execute`.
#[derive(Debug)]
pub struct BindingsResource {
/// Byte resources supplied with the task.
/// NOTE(review): ordering presumably matches the kernel's binding order; confirm against the producer.
pub resources: Vec<BytesResource>,
/// Metadata binding information accompanying `resources`.
pub info: MetadataBindingInfo,
}
/// `SchedulerStreamBackend` implementation for the CPU runtime: tasks are
/// `ScheduleTask`s and streams are `CpuStream`s created by the owned factory.
#[derive(Debug)]
pub struct ScheduledCpuBackend {
/// Factory returned by `SchedulerStreamBackend::factory` to create new streams.
factory: CpuStreamFactory,
}
/// `StreamFactory` that builds `CpuStream`s, holding the three arguments that are
/// cloned and passed to `CpuStream::new` on every `create` call.
#[derive(Debug)]
pub struct CpuStreamFactory {
/// Memory properties cloned into each created stream.
memory_properties: MemoryDeviceProperties,
/// Memory configuration cloned into each created stream.
memory_config: MemoryConfiguration,
/// Logger handle; the `Arc` is cloned for each created stream.
logger: Arc<ServerLogger>,
}
impl StreamFactory for CpuStreamFactory {
type Stream = CpuStream;
fn create(&mut self) -> Self::Stream {
CpuStream::new(
self.memory_properties.clone(),
self.memory_config.clone(),
self.logger.clone(),
)
}
}
impl ScheduledCpuBackend {
    /// Creates a backend whose stream factory is built from the given memory
    /// properties, memory configuration and logger.
    ///
    /// The three arguments are moved into the internal [`CpuStreamFactory`]
    /// unchanged.
    pub fn new(
        memory_properties: MemoryDeviceProperties,
        memory_config: MemoryConfiguration,
        logger: Arc<ServerLogger>,
    ) -> Self {
        let factory = CpuStreamFactory {
            memory_properties,
            memory_config,
            logger,
        };

        Self { factory }
    }
}
impl SchedulerStreamBackend for ScheduledCpuBackend {
    type Task = ScheduleTask;
    type Stream = CpuStream;
    type Factory = CpuStreamFactory;

    /// Forwards `task` to the stream's `enqueue_task`.
    fn enqueue(task: Self::Task, stream: &mut Self::Stream) {
        stream.enqueue_task(task);
    }

    /// Flushes `stream` with `ignore: true, flush: false` and discards whatever
    /// the call returns.
    ///
    /// NOTE(review): the exact meaning of the two `StreamErrorMode` flags is
    /// defined in `cubecl_core`; confirm there before changing them.
    fn flush(stream: &mut Self::Stream) {
        let mode = cubecl_core::server::StreamErrorMode {
            ignore: true,
            flush: false,
        };

        // Best-effort: this function has no return channel, so the outcome of
        // the flush is intentionally dropped.
        let _ = stream.flush(mode);
    }

    /// Gives mutable access to the factory owned by this backend.
    fn factory(&mut self) -> &mut Self::Factory {
        let Self { factory } = self;
        factory
    }
}