use std::sync::Arc;
use morok_dtype::DeviceSpec;
use crate::buffer::Buffer;
use crate::error::Result;
use crate::sync::TimelineSignal;
/// Launch dimensions passed to a queue's `exec` call.
///
/// Both sizes are always stored as three-component arrays. The 1-D and 2-D
/// constructors fill the unused trailing components with `1`.
///
/// The type is plain data, so besides `Debug`/`Clone` it also derives
/// `PartialEq`, `Eq` and `Hash`; this lets callers compare parameter sets and
/// use them as map/set keys.
// NOTE(review): whether the sizes count work-items or work-groups is decided
// by the queue implementation and is not visible from this file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ExecParams {
    /// Global size, one entry per dimension.
    pub global_size: [usize; 3],
    /// Local size, one entry per dimension.
    pub local_size: [usize; 3],
}

impl ExecParams {
    /// Builds one-dimensional parameters: `[global, 1, 1]` and `[local, 1, 1]`.
    pub fn new_1d(global: usize, local: usize) -> Self {
        Self::new_3d([global, 1, 1], [local, 1, 1])
    }

    /// Builds two-dimensional parameters; the third component of both sizes
    /// is set to `1`.
    pub fn new_2d(global: [usize; 2], local: [usize; 2]) -> Self {
        let [gx, gy] = global;
        let [lx, ly] = local;
        Self::new_3d([gx, gy, 1], [lx, ly, 1])
    }

    /// Builds three-dimensional parameters, storing both arrays unchanged.
    pub fn new_3d(global: [usize; 3], local: [usize; 3]) -> Self {
        Self { global_size: global, local_size: local }
    }
}

impl Default for ExecParams {
    /// Returns parameters with every component of both sizes equal to `1`.
    ///
    /// Implemented by hand because a derived `Default` would produce zeros.
    fn default() -> Self {
        Self::new_3d([1, 1, 1], [1, 1, 1])
    }
}
/// A program that can be passed to a queue's `exec` as `&dyn Program`.
///
/// Implementors must be `Send + Sync` and implement `Debug`.
pub trait Program: Send + Sync + std::fmt::Debug {
    /// Returns the program's name.
    fn name(&self) -> &str;

    /// Returns the device specification associated with this program.
    fn device(&self) -> &DeviceSpec;
}
/// Chainable command-queue interface.
///
/// Every method except [`submit`](HardwareQueue::submit) and
/// [`device`](HardwareQueue::device) returns `&mut Self`, so calls can be
/// chained before a final `submit`.
///
/// The precise semantics of each operation are defined by the implementor;
/// this trait only fixes the signatures.
pub trait HardwareQueue: Send + std::fmt::Debug {
    /// Signal type accepted by [`wait`](HardwareQueue::wait) and
    /// [`signal`](HardwareQueue::signal).
    type Signal: TimelineSignal;

    /// Returns the device specification associated with this queue.
    fn device(&self) -> &DeviceSpec;

    /// Adds a wait on `signal` with `value` to the queue.
    // NOTE(review): presumably waits until the signal reaches `value` — confirm
    // against `TimelineSignal` and the implementors.
    fn wait(&mut self, signal: &Self::Signal, value: u64) -> &mut Self;

    /// Adds a signal operation on `signal` with `value` to the queue.
    fn signal(&mut self, signal: &Self::Signal, value: u64) -> &mut Self;

    /// Adds an execution of `program` over `buffers` using `params`.
    fn exec(&mut self, program: &dyn Program, buffers: &[&Buffer], params: &ExecParams) -> &mut Self;

    /// Adds a copy from `src` into `dst`.
    fn copy(&mut self, dst: &Buffer, src: &Buffer) -> &mut Self;

    /// Adds a memory barrier.
    fn memory_barrier(&mut self) -> &mut Self;

    /// Submits the queue, returning an error if the implementor reports one.
    fn submit(&mut self) -> Result<()>;
}
/// Factory for queues and signals.
///
/// Implementors must be `Send + Sync` and implement `Debug`.
// NOTE(review): `Queue::Signal` is not constrained to equal `Self::Signal`;
// if they are meant to match, `type Queue: HardwareQueue<Signal = Self::Signal>`
// would make that explicit — confirm with implementors before tightening.
pub trait QueueFactory: Send + Sync + std::fmt::Debug {
    /// Queue type produced by this factory.
    type Queue: HardwareQueue;
    /// Signal type produced by this factory.
    type Signal: TimelineSignal;

    /// Returns the device specification associated with this factory.
    fn device(&self) -> &DeviceSpec;

    /// Creates a new signal wrapped in an `Arc`.
    fn create_signal(&self) -> Result<Arc<Self::Signal>>;

    /// Creates a compute queue.
    fn create_compute_queue(&self) -> Result<Self::Queue>;

    /// Creates a copy queue, or `Ok(None)` when none is provided.
    // NOTE(review): `None` presumably means the device has no dedicated copy
    // queue — confirm against implementors.
    fn create_copy_queue(&self) -> Result<Option<Self::Queue>>;
}
/// Type-erased queue: owns a boxed `DynQueueInner` so that queues with
/// different concrete types can be handled through one non-generic type.
pub struct DynQueue {
    /// The wrapped queue behind dynamic dispatch.
    inner: Box<dyn DynQueueInner>,
}

impl std::fmt::Debug for DynQueue {
    /// Formats as a `DynQueue` struct with a single `device` field taken from
    /// the wrapped queue.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let device = self.inner.device();
        let mut out = f.debug_struct("DynQueue");
        out.field("device", &device);
        out.finish()
    }
}
impl DynQueue {
    /// Wraps a concrete [`HardwareQueue`] in a type-erased `DynQueue`.
    ///
    /// Both the queue type and its signal type must be `'static` because the
    /// queue is stored in a `Box<dyn DynQueueInner>`.
    pub fn new<Q>(queue: Q) -> Self
    where
        Q: HardwareQueue + 'static,
        Q::Signal: 'static,
    {
        let wrapper = DynQueueWrapper { queue, _phantom: std::marker::PhantomData };
        Self { inner: Box::new(wrapper) }
    }

    /// Passes a wait request to the wrapped queue's `wait_dyn` and returns
    /// `self` for chaining.
    // NOTE(review): `DynQueueWrapper::wait_dyn` currently has an empty body, so
    // nothing reaches the underlying queue — confirm that is intended.
    pub fn wait(&mut self, signal: &dyn TimelineSignal, value: u64) -> &mut Self {
        let erased = &mut *self.inner;
        erased.wait_dyn(signal, value);
        self
    }

    /// Passes a signal request to the wrapped queue's `signal_dyn` and returns
    /// `self` for chaining.
    // NOTE(review): `DynQueueWrapper::signal_dyn` currently has an empty body,
    // so nothing reaches the underlying queue — confirm that is intended.
    pub fn signal(&mut self, signal: &dyn TimelineSignal, value: u64) -> &mut Self {
        let erased = &mut *self.inner;
        erased.signal_dyn(signal, value);
        self
    }

    /// Forwards an execution of `program` over `buffers` with `params` to the
    /// wrapped queue and returns `self` for chaining.
    pub fn exec(&mut self, program: &dyn Program, buffers: &[&Buffer], params: &ExecParams) -> &mut Self {
        let erased = &mut *self.inner;
        erased.exec_dyn(program, buffers, params);
        self
    }

    /// Forwards a copy from `src` into `dst` to the wrapped queue and returns
    /// `self` for chaining.
    pub fn copy(&mut self, dst: &Buffer, src: &Buffer) -> &mut Self {
        let erased = &mut *self.inner;
        erased.copy_dyn(dst, src);
        self
    }

    /// Forwards a memory barrier to the wrapped queue and returns `self` for
    /// chaining.
    pub fn memory_barrier(&mut self) -> &mut Self {
        let erased = &mut *self.inner;
        erased.memory_barrier_dyn();
        self
    }

    /// Submits the wrapped queue and returns its result unchanged.
    pub fn submit(&mut self) -> Result<()> {
        let erased = &mut *self.inner;
        erased.submit_dyn()
    }

    /// Returns the device specification reported by the wrapped queue.
    pub fn device(&self) -> &DeviceSpec {
        let erased = &*self.inner;
        erased.device()
    }
}
/// Private, object-safe counterpart of [`HardwareQueue`] used inside
/// `DynQueue`.
///
/// Signals are taken as `&dyn TimelineSignal` instead of an associated type,
/// and the recording methods return `()` instead of `&mut Self`, so the trait
/// can be used as `Box<dyn DynQueueInner>`.
trait DynQueueInner: Send + std::fmt::Debug {
    /// Returns the device specification of the underlying queue.
    fn device(&self) -> &DeviceSpec;

    /// Type-erased form of `HardwareQueue::wait`.
    fn wait_dyn(&mut self, signal: &dyn TimelineSignal, value: u64);

    /// Type-erased form of `HardwareQueue::signal`.
    fn signal_dyn(&mut self, signal: &dyn TimelineSignal, value: u64);

    /// Type-erased form of `HardwareQueue::exec`.
    fn exec_dyn(&mut self, program: &dyn Program, buffers: &[&Buffer], params: &ExecParams);

    /// Type-erased form of `HardwareQueue::copy`.
    fn copy_dyn(&mut self, dst: &Buffer, src: &Buffer);

    /// Type-erased form of `HardwareQueue::memory_barrier`.
    fn memory_barrier_dyn(&mut self);

    /// Type-erased form of `HardwareQueue::submit`.
    fn submit_dyn(&mut self) -> Result<()>;
}
/// Adapter that holds a concrete [`HardwareQueue`] so it can be exposed
/// through the `DynQueueInner` trait.
struct DynQueueWrapper<Q: HardwareQueue> {
    /// The concrete queue that operations are forwarded to.
    queue: Q,
    /// Zero-sized marker naming the queue's signal type; stores no data.
    _phantom: std::marker::PhantomData<Q::Signal>,
}

impl<Q: HardwareQueue> std::fmt::Debug for DynQueueWrapper<Q> {
    /// Formats as a `DynQueueWrapper` struct with a single `queue` field; the
    /// marker field is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { queue, .. } = self;
        let mut out = f.debug_struct("DynQueueWrapper");
        out.field("queue", queue);
        out.finish()
    }
}
/// Forwards the type-erased operations to the wrapped concrete queue.
///
/// The `&mut Self` chaining handles returned by the concrete queue are
/// discarded, since `DynQueueInner` methods return `()`.
impl<Q> DynQueueInner for DynQueueWrapper<Q>
where
    Q: HardwareQueue + 'static,
    Q::Signal: 'static,
{
    fn wait_dyn(&mut self, _signal: &dyn TimelineSignal, _value: u64) {
        // Empty body: `self.queue.wait` takes `&Q::Signal`, and the type-erased
        // `_signal` is never converted or forwarded.
        // NOTE(review): waits issued through `DynQueue` are therefore dropped
        // silently — confirm this is intended or add a downcast path.
    }

    fn signal_dyn(&mut self, _signal: &dyn TimelineSignal, _value: u64) {
        // Empty body: `self.queue.signal` takes `&Q::Signal`, and the
        // type-erased `_signal` is never converted or forwarded.
        // NOTE(review): signals issued through `DynQueue` are therefore dropped
        // silently — confirm this is intended or add a downcast path.
    }

    fn exec_dyn(&mut self, program: &dyn Program, buffers: &[&Buffer], params: &ExecParams) {
        let _ = self.queue.exec(program, buffers, params);
    }

    fn copy_dyn(&mut self, dst: &Buffer, src: &Buffer) {
        let _ = self.queue.copy(dst, src);
    }

    fn memory_barrier_dyn(&mut self) {
        let _ = self.queue.memory_barrier();
    }

    fn submit_dyn(&mut self) -> Result<()> {
        let queue = &mut self.queue;
        queue.submit()
    }

    fn device(&self) -> &DeviceSpec {
        let queue = &self.queue;
        queue.device()
    }
}