use crate::{WgpuResource, stream::WgpuStream};
use alloc::sync::Arc;
use cubecl_common::{bytes::Bytes, profile::TimingMethod};
use cubecl_core::{
CubeCount, MemoryConfiguration,
ir::StorageType,
server::{MetadataBinding, ScalarBinding},
};
use cubecl_ir::MemoryDeviceProperties;
use cubecl_runtime::{
logging::ServerLogger,
stream::{StreamFactory, scheduler::SchedulerStreamBackend},
};
use std::collections::BTreeMap;
/// A unit of work handed to a [`WgpuStream`] by the scheduler backend.
///
/// Tasks are passed to `WgpuStream::enqueue_task` by
/// `ScheduledWgpuBackend::enqueue`; how each variant is processed is decided by the
/// stream, not here.
#[derive(Debug)]
pub enum ScheduleTask {
/// Carries host bytes together with a target buffer resource.
/// NOTE(review): presumably `data` is copied into `buffer` by the stream — confirm
/// against `WgpuStream::enqueue_task`.
Write {
/// Bytes to be written.
data: Bytes,
/// Destination GPU resource.
buffer: WgpuResource,
},
/// Carries everything needed to dispatch a compute pipeline.
Execute {
/// The compiled compute pipeline, shared behind an `Arc`.
pipeline: Arc<wgpu::ComputePipeline>,
/// Dispatch size for the kernel.
count: CubeCount,
/// Buffers, metadata and scalar bindings for the dispatch; see
/// `BindingsResource::into_resources` for how they are flattened.
resources: BindingsResource,
},
}
/// The raw binding inputs of a kernel launch before they are turned into GPU resources.
#[derive(Debug)]
pub struct BindingsResource {
/// Already-allocated resources; these come first in the flattened binding list.
pub resources: Vec<WgpuResource>,
/// Metadata for the launch; turned into one uniform only when its `data` is non-empty.
pub metadata: MetadataBinding,
/// Scalar bindings keyed by storage type; a `BTreeMap` keeps them in ascending key
/// order, which fixes the order of the scalar uniforms.
pub scalars: BTreeMap<StorageType, ScalarBinding>,
}
/// Scheduler backend for wgpu: owns the factory used to create [`WgpuStream`]s and
/// forwards tasks/flushes to those streams (see its `SchedulerStreamBackend` impl).
#[derive(Debug)]
pub struct ScheduledWgpuBackend {
/// Factory that builds new streams on demand.
factory: WgpuStreamFactory,
}
/// Holds the configuration needed to build a [`WgpuStream`]; every field is passed
/// (cloned or copied) to `WgpuStream::new` each time a stream is created.
#[derive(Debug)]
pub struct WgpuStreamFactory {
/// wgpu device given to every created stream.
device: wgpu::Device,
/// wgpu queue given to every created stream.
queue: wgpu::Queue,
/// Memory properties of the device, forwarded to each stream.
memory_properties: MemoryDeviceProperties,
/// Memory configuration, forwarded to each stream.
memory_config: MemoryConfiguration,
/// Timing method forwarded to each stream (copied, not cloned).
timing_method: TimingMethod,
/// Task limit forwarded unchanged to `WgpuStream::new`.
/// NOTE(review): exact meaning (e.g. tasks before an automatic flush) is defined by
/// `WgpuStream` — confirm there.
tasks_max: usize,
/// Shared server logger; each stream gets its own `Arc` clone.
logger: Arc<ServerLogger>,
}
impl StreamFactory for WgpuStreamFactory {
    type Stream = WgpuStream;

    /// Builds a new [`WgpuStream`] from this factory's configuration.
    ///
    /// Every field is cloned (or copied, for `timing_method` and `tasks_max`) out of the
    /// factory, so the factory is left unchanged and can keep producing streams.
    fn create(&mut self) -> Self::Stream {
        let Self {
            device,
            queue,
            memory_properties,
            memory_config,
            timing_method,
            tasks_max,
            logger,
        } = self;

        WgpuStream::new(
            device.clone(),
            queue.clone(),
            memory_properties.clone(),
            memory_config.clone(),
            *timing_method,
            *tasks_max,
            logger.clone(),
        )
    }
}
impl ScheduledWgpuBackend {
    /// Creates a backend whose stream factory stores the given device, queue, memory
    /// settings, timing method, task limit and logger.
    ///
    /// No stream is created here; the arguments are only kept until the factory's
    /// `StreamFactory::create` is called.
    pub fn new(
        device: wgpu::Device,
        queue: wgpu::Queue,
        memory_properties: MemoryDeviceProperties,
        memory_config: MemoryConfiguration,
        timing_method: TimingMethod,
        tasks_max: usize,
        logger: Arc<ServerLogger>,
    ) -> Self {
        let factory = WgpuStreamFactory {
            device,
            queue,
            memory_properties,
            memory_config,
            timing_method,
            tasks_max,
            logger,
        };

        Self { factory }
    }
}
impl BindingsResource {
    /// Consumes the bindings and returns the flattened, ordered list of resources.
    ///
    /// The returned vector contains, in order:
    /// 1. the original `resources`,
    /// 2. one uniform created from `metadata.data`, present only when that data is
    ///    non-empty,
    /// 3. one uniform per scalar binding, in ascending `StorageType` key order (the
    ///    iteration order of the `BTreeMap`).
    ///
    /// All uniforms are created through `stream.create_uniform`.
    pub fn into_resources(mut self, stream: &mut WgpuStream) -> Vec<WgpuResource> {
        let has_metadata = !self.metadata.data.is_empty();
        if has_metadata {
            let metadata_uniform =
                stream.create_uniform(bytemuck::cast_slice(&self.metadata.data));
            self.resources.push(metadata_uniform);
        }

        for scalar in self.scalars.values() {
            let scalar_uniform = stream.create_uniform(scalar.data());
            self.resources.push(scalar_uniform);
        }

        self.resources
    }
}
impl SchedulerStreamBackend for ScheduledWgpuBackend {
    type Factory = WgpuStreamFactory;
    type Stream = WgpuStream;
    type Task = ScheduleTask;

    /// Gives the scheduler mutable access to the factory that creates new streams.
    fn factory(&mut self) -> &mut Self::Factory {
        &mut self.factory
    }

    /// Passes a single task to `stream` via its `enqueue_task` method.
    fn enqueue(task: Self::Task, stream: &mut Self::Stream) {
        stream.enqueue_task(task);
    }

    /// Forwards a flush request to `stream`.
    fn flush(stream: &mut Self::Stream) {
        stream.flush();
    }
}