use crate::{WgpuResource, stream::WgpuStream};
use alloc::sync::Arc;
use cubecl_common::{bytes::Bytes, profile::TimingMethod};
use cubecl_core::{
CubeCount, MemoryConfiguration,
server::{MetadataBindingInfo, StreamErrorMode},
};
use cubecl_ir::MemoryDeviceProperties;
use cubecl_runtime::{
logging::ServerLogger,
stream::{StreamFactory, scheduler::SchedulerStreamBackend},
};
/// A unit of work handed to a [`WgpuStream`] by the stream scheduler.
///
/// `ScheduledWgpuBackend`'s `SchedulerStreamBackend::enqueue` forwards each task to
/// `WgpuStream::enqueue_task`; how each variant is executed is decided by the stream
/// (not visible in this file).
pub enum ScheduleTask {
    /// Write `data` into `buffer`.
    Write {
        /// Bytes to write; only its length is shown by the `Debug` impl.
        data: Bytes,
        /// Destination resource for the write.
        buffer: WgpuResource,
    },
    /// Run `pipeline` with the given cube count and bound resources.
    Execute {
        /// Compute pipeline to run, shared behind an `Arc`.
        pipeline: Arc<wgpu::ComputePipeline>,
        /// Cube count for the dispatch.
        count: CubeCount,
        /// Resources and metadata to bind; see `BindingsResource::into_resources`.
        resources: BindingsResource,
    },
}
impl core::fmt::Debug for ScheduleTask {
    /// Prints a compact summary of the task without dumping its payload.
    ///
    /// `Write` tasks report only the byte count. `Execute` tasks report the number of
    /// bound resources and the cube count. The pipeline handle and the write destination
    /// are intentionally omitted.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Execute {
                count, resources, ..
            } => {
                let bound = resources.resources.len();
                write!(f, "Execute(resources={}, cube_count={count:?})", bound)
            }
            Self::Write { data, .. } => write!(f, "Write(bytes={})", data.len()),
        }
    }
}
/// The resources bound to a single `Execute` task, together with its metadata.
///
/// `BindingsResource::into_resources` turns this into the final resource list. It
/// appends one extra uniform built from `info.data` when that data is non-empty.
#[derive(Debug)]
pub struct BindingsResource {
    /// Resources supplied by the caller; the metadata uniform, if any, is pushed after these.
    pub resources: Vec<WgpuResource>,
    /// Metadata for the dispatch; its `data` is cast to bytes and uploaded as a uniform.
    pub info: MetadataBindingInfo,
}
/// The wgpu implementation of [`SchedulerStreamBackend`].
///
/// It only owns the [`WgpuStreamFactory`] exposed through `SchedulerStreamBackend::factory`.
#[derive(Debug)]
pub struct ScheduledWgpuBackend {
    /// Factory that builds every [`WgpuStream`] for this backend.
    factory: WgpuStreamFactory,
}
/// Builds [`WgpuStream`]s that are all configured from the same stored settings.
///
/// Every stream receives clones of the device, queue, memory properties, memory
/// configuration and logger, plus copies of the timing method and `tasks_max`.
#[derive(Debug)]
pub struct WgpuStreamFactory {
    /// Device handle; a clone is passed to each new stream.
    device: wgpu::Device,
    /// Queue handle; a clone is passed to each new stream.
    queue: wgpu::Queue,
    /// Memory properties; a clone is passed to each new stream.
    memory_properties: MemoryDeviceProperties,
    /// Memory configuration; a clone is passed to each new stream.
    memory_config: MemoryConfiguration,
    /// Timing method copied into each new stream.
    timing_method: TimingMethod,
    /// Forwarded unchanged to `WgpuStream::new`; its meaning is defined there.
    tasks_max: usize,
    /// Logger whose `Arc` is cloned into each new stream.
    logger: Arc<ServerLogger>,
    /// Number of streams created so far (incremented in `create`).
    // NOTE(review): never read in this file — confirm whether it is still needed.
    count: u64,
}
impl StreamFactory for WgpuStreamFactory {
    type Stream = WgpuStream;

    /// Creates a fresh [`WgpuStream`] from this factory's stored settings.
    ///
    /// The creation counter is bumped first. Then every handle and configuration value is
    /// cloned (or copied) and passed to `WgpuStream::new`.
    fn create(&mut self) -> Self::Stream {
        self.count += 1;

        let device = self.device.clone();
        let queue = self.queue.clone();
        let memory_properties = self.memory_properties.clone();
        let memory_config = self.memory_config.clone();
        let timing_method = self.timing_method;
        let tasks_max = self.tasks_max;
        let logger = self.logger.clone();

        WgpuStream::new(
            device,
            queue,
            memory_properties,
            memory_config,
            timing_method,
            tasks_max,
            logger,
        )
    }
}
impl ScheduledWgpuBackend {
    /// Creates a backend whose stream factory is configured with the given settings.
    ///
    /// Every argument is stored in the [`WgpuStreamFactory`] and later handed to each
    /// `WgpuStream::new` call it makes. The factory's stream counter starts at zero.
    pub fn new(
        device: wgpu::Device,
        queue: wgpu::Queue,
        memory_properties: MemoryDeviceProperties,
        memory_config: MemoryConfiguration,
        timing_method: TimingMethod,
        tasks_max: usize,
        logger: Arc<ServerLogger>,
    ) -> Self {
        let factory = WgpuStreamFactory {
            count: 0,
            device,
            queue,
            logger,
            memory_properties,
            memory_config,
            timing_method,
            tasks_max,
        };

        Self { factory }
    }
}
impl BindingsResource {
    /// Consumes the bindings and returns the resource list to bind.
    ///
    /// If `info.data` is empty, the caller-supplied resources are returned as-is.
    /// Otherwise the data is cast to bytes and passed to `WgpuStream::create_uniform`,
    /// and the resulting resource is appended after the existing ones.
    pub fn into_resources(self, stream: &mut WgpuStream) -> Vec<WgpuResource> {
        let Self {
            mut resources,
            info,
        } = self;

        if info.data.is_empty() {
            return resources;
        }

        let metadata = stream.create_uniform(bytemuck::cast_slice(&info.data));
        resources.push(metadata);
        resources
    }
}
impl SchedulerStreamBackend for ScheduledWgpuBackend {
    type Task = ScheduleTask;
    type Stream = WgpuStream;
    type Factory = WgpuStreamFactory;

    /// Forwards `task` to `stream` via `WgpuStream::enqueue_task`.
    fn enqueue(task: Self::Task, stream: &mut Self::Stream) {
        stream.enqueue_task(task);
    }

    /// Flushes `stream` on a best-effort basis.
    ///
    /// The flush runs with `StreamErrorMode { ignore: true, flush: false }`, and its
    /// result is discarded, so no failure propagates out of this call.
    fn flush(stream: &mut Self::Stream) {
        // Intentional best-effort: this trait method returns `()`, so there is nowhere to
        // report a failure, and the error mode already sets `ignore: true`. The result is
        // dropped directly; there is no need to convert it to an `Option` first.
        let _ = stream.flush(StreamErrorMode {
            ignore: true,
            flush: false,
        });
    }

    /// Gives the scheduler mutable access to the stream factory.
    fn factory(&mut self) -> &mut Self::Factory {
        &mut self.factory
    }
}