use crate::{data::Role, GpuValue, WARP_SIZE};
use std::marker::PhantomData;
/// Fixed-capacity buffer of `WARP_SIZE` values of type `T`, tagged with an owning `Role`.
///
/// `OWNER` is a const-generic tag carried in the type; nothing in this
/// definition stores or reads it at runtime.
pub struct SharedRegion<T: GpuValue, const OWNER: u8> {
/// Backing storage: one slot for every index in `0..WARP_SIZE`.
data: [T; WARP_SIZE as usize],
/// Role recorded at construction time and handed back by `owner()`.
owner: Role,
/// Zero-sized marker (`PhantomData<()>`); it carries no type information.
_phantom: PhantomData<()>,
}
impl<T, const OWNER: u8> SharedRegion<T, OWNER>
where
    T: GpuValue + Default,
{
    /// Builds a region whose every slot holds `T::default()` and records
    /// `owner` as the owning role.
    pub fn new(owner: Role) -> Self {
        // Fill all `WARP_SIZE` slots with the default value up front.
        let data = [T::default(); WARP_SIZE as usize];
        Self {
            data,
            owner,
            _phantom: PhantomData,
        }
    }
}
impl<T: GpuValue, const OWNER: u8> SharedRegion<T, OWNER> {
    /// Stores `value` in slot `index`.
    ///
    /// # Panics
    ///
    /// Panics with "Index out of bounds" when `index` is not smaller than the
    /// region's capacity (`WARP_SIZE` slots).
    pub fn write(&mut self, index: usize, value: T) {
        // The bound comes from the storage itself rather than a literal `32`,
        // so the check stays correct if `WARP_SIZE` is ever changed.
        assert!(index < self.data.len(), "Index out of bounds");
        self.data[index] = value;
    }

    /// Returns a copy of the value held in slot `index`.
    ///
    /// # Panics
    ///
    /// Panics with "Index out of bounds" when `index` is not smaller than the
    /// region's capacity (`WARP_SIZE` slots).
    pub fn read(&self, index: usize) -> T {
        // Same storage-derived bound as `write`; see the comment there.
        assert!(index < self.data.len(), "Index out of bounds");
        self.data[index]
    }

    /// Hands out a `SharedView` that borrows this region immutably.
    ///
    /// The view exposes only `read`, and the borrow it holds prevents `write`
    /// from being called on the region while the view is alive.
    pub fn grant_read(&self) -> SharedView<'_, T, OWNER> {
        SharedView {
            region: self,
            _phantom: PhantomData,
        }
    }

    /// Returns the role that was supplied when the region was created.
    pub fn owner(&self) -> Role {
        self.owner
    }
}
/// Read-only handle onto a `SharedRegion`, valid for the borrow lifetime `'a`.
///
/// Carries the same `T` and `OWNER` parameters as the region it refers to.
pub struct SharedView<'a, T: GpuValue, const OWNER: u8> {
/// Shared borrow of the region that reads are forwarded to.
region: &'a SharedRegion<T, OWNER>,
/// Zero-sized marker (`PhantomData<()>`); it carries no type information.
_phantom: PhantomData<()>,
}
impl<T: GpuValue, const OWNER: u8> SharedView<'_, T, OWNER> {
    /// Returns a copy of the value in slot `index` of the viewed region.
    ///
    /// This simply forwards to `SharedRegion::read`, so it panics under the
    /// same out-of-bounds condition as that method.
    pub fn read(&self, index: usize) -> T {
        let region = self.region;
        region.read(index)
    }
}
/// Ring-buffer queue of tasks stored in a `SharedRegion` tagged with `PRODUCER`.
///
/// `CONSUMER` is a const-generic tag carried in the type only; no field of
/// this struct uses it.
pub struct WorkQueue<T: GpuValue, const PRODUCER: u8, const CONSUMER: u8> {
/// Backing storage for queued tasks.
tasks: SharedRegion<T, PRODUCER>,
/// Slot index that the next `push` writes to.
head: usize,
/// Slot index that the next `pop` reads from.
tail: usize,
/// Zero-sized marker (`PhantomData<()>`); it carries no type information.
_phantom: PhantomData<()>,
}
impl<T, const PRODUCER: u8, const CONSUMER: u8> WorkQueue<T, PRODUCER, CONSUMER>
where
    T: GpuValue + Default,
{
    /// Creates an empty queue whose backing region is owned by `producer_role`.
    ///
    /// `_consumer_role` is accepted but not stored or inspected.
    pub fn new(producer_role: Role, _consumer_role: Role) -> Self {
        let tasks = SharedRegion::new(producer_role);
        Self {
            tasks,
            head: 0,
            tail: 0,
            _phantom: PhantomData,
        }
    }

    /// Appends `task` at the write position.
    ///
    /// # Errors
    ///
    /// Returns `Err(QueueFull)` when advancing the write position would make
    /// it collide with the read position. One slot is always left unused, so
    /// at most `WARP_SIZE - 1` tasks can be queued at once.
    pub fn push(&mut self, task: T) -> Result<(), QueueFull> {
        if self.is_full() {
            return Err(QueueFull);
        }
        let slot = self.head;
        self.tasks.write(slot, task);
        // Advance the write position, wrapping at the end of the buffer.
        self.head = (slot + 1) % WARP_SIZE as usize;
        Ok(())
    }

    /// Removes and returns the oldest queued task, or `None` when the queue
    /// is empty.
    pub fn pop(&mut self) -> Option<T> {
        (!self.is_empty()).then(|| {
            let slot = self.tail;
            let task = self.tasks.read(slot);
            // Advance the read position, wrapping at the end of the buffer.
            self.tail = (slot + 1) % WARP_SIZE as usize;
            task
        })
    }

    /// Reports whether there is nothing to pop (read and write positions match).
    pub fn is_empty(&self) -> bool {
        self.head == self.tail
    }

    /// Reports whether the next `push` would be rejected.
    pub fn is_full(&self) -> bool {
        let next_head = (self.head + 1) % WARP_SIZE as usize;
        next_head == self.tail
    }
}
/// Error returned by `WorkQueue::push` when the queue has no free slot.
///
/// Implements `Display` and `std::error::Error` so it can be propagated with
/// `?` into boxed or wrapped error types, and `PartialEq`/`Eq` so callers can
/// compare results directly (e.g. `assert_eq!(q.push(x), Err(QueueFull))`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QueueFull;

impl std::fmt::Display for QueueFull {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("work queue is full")
    }
}

impl std::error::Error for QueueFull {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::Role;

    // Const-generic tags used to parameterise the types under test.
    const COORDINATOR: u8 = 0;
    const WORKER: u8 = 1;

    /// A value written to a region is readable both directly and through a
    /// granted view.
    #[test]
    fn test_shared_region_ownership() {
        let mut region =
            SharedRegion::<i32, COORDINATOR>::new(Role::lanes(0, 4, "coordinator"));

        region.write(0, 42);
        assert_eq!(region.read(0), 42);

        assert_eq!(region.grant_read().read(0), 42);
    }

    /// Tasks come out of the queue in the order they were pushed, and the
    /// queue reports empty before the first push and after the last pop.
    #[test]
    fn test_work_queue() {
        let mut queue = WorkQueue::<i32, COORDINATOR, WORKER>::new(
            Role::lanes(0, 4, "coordinator"),
            Role::lanes(4, 32, "worker"),
        );
        assert!(queue.is_empty());

        for task in 1..=3 {
            queue.push(task).unwrap();
        }
        assert!(!queue.is_empty());

        for expected in 1..=3 {
            assert_eq!(queue.pop(), Some(expected));
        }
        assert_eq!(queue.pop(), None);
        assert!(queue.is_empty());
    }
}