use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use parking_lot::RwLock;
use tokio::sync::Notify;
use ringkernel_core::error::{Result, RingKernelError};
use ringkernel_core::hlc::{HlcClock, HlcTimestamp};
use ringkernel_core::message::{CorrelationId, MessageEnvelope};
use ringkernel_core::runtime::{
KernelHandleInner, KernelId, KernelState, KernelStatus, LaunchOptions,
};
use ringkernel_core::telemetry::KernelMetrics;
use ringkernel_core::types::KernelMode;
use crate::adapter::WgpuAdapter;
use crate::memory::{WgpuControlBlock, WgpuMessageQueue};
use crate::shader::{create_bind_group, ComputePipeline};
/// A compute kernel driven through a [`WgpuAdapter`], together with the
/// host-side bookkeeping kept for it (lifecycle state, message queues,
/// clock, metrics and counters).
pub struct WgpuKernel {
/// Identifier built from the `id` string passed to `new`.
id: KernelId,
/// Numeric identifier; returned by `kernel_id_num` and passed to `HlcClock::new`.
id_num: u64,
/// Current lifecycle state; starts as `KernelState::Created`.
state: RwLock<KernelState>,
/// Launch options supplied at construction (queue capacities, block size, mode).
options: LaunchOptions,
/// Shared adapter used to create GPU resources, submit work and poll the device.
adapter: Arc<WgpuAdapter>,
/// Control block whose `is_active`, `should_terminate` and `has_terminated`
/// fields are rewritten by the lifecycle methods.
control_block: RwLock<WgpuControlBlock>,
/// Queue that `send_envelope` enqueues into.
input_queue: WgpuMessageQueue,
/// Queue that the receive methods dequeue from.
#[allow(dead_code)]
output_queue: WgpuMessageQueue,
/// Compute pipeline; `None` until `load_shader` succeeds.
pipeline: Option<ComputePipeline>,
/// Bind group matching `pipeline`; `None` until `load_shader` succeeds.
bind_group: Option<wgpu::BindGroup>,
/// Clock whose `tick` produces the timestamps returned by `current_timestamp`.
clock: HlcClock,
/// Metrics returned (cloned) by `metrics`; initialised to the default value.
metrics: RwLock<KernelMetrics>,
/// Incremented once per envelope successfully enqueued by `send_envelope`;
/// reported as `messages_processed` in `status`.
message_counter: AtomicU64,
/// Construction time; `status` reports the time elapsed since then as `uptime`.
created_at: Instant,
/// Notified by `terminate`; awaited by `wait`.
terminate_notify: Notify,
}
impl WgpuKernel {
/// Creates a kernel in the `Created` state with its GPU-side resources allocated.
///
/// The control block and both message queues are created through `adapter`;
/// the queue capacities are taken from `options`. No shader is loaded yet, so
/// `pipeline` and `bind_group` start out as `None`.
///
/// # Errors
///
/// The body shown here has no failing path; the `Result` return type is kept
/// for the sake of the existing interface.
pub fn new(
    id: &str,
    id_num: u64,
    adapter: Arc<WgpuAdapter>,
    options: LaunchOptions,
) -> Result<Self> {
    // Third argument handed to both queue constructors.
    // NOTE(review): presumably the maximum message size in bytes — confirm
    // against `WgpuMessageQueue::new`.
    let per_message_arg = 4096;

    let control_block = RwLock::new(WgpuControlBlock::new(&adapter));
    let input_queue =
        WgpuMessageQueue::new(&adapter, options.input_queue_capacity, per_message_arg);
    let output_queue =
        WgpuMessageQueue::new(&adapter, options.output_queue_capacity, per_message_arg);

    let kernel = Self {
        id: KernelId::new(id),
        id_num,
        state: RwLock::new(KernelState::Created),
        options,
        adapter,
        control_block,
        input_queue,
        output_queue,
        pipeline: None,
        bind_group: None,
        clock: HlcClock::new(id_num),
        metrics: RwLock::new(KernelMetrics::default()),
        message_counter: AtomicU64::new(0),
        created_at: Instant::now(),
        terminate_notify: Notify::new(),
    };
    Ok(kernel)
}
/// Returns a reference to this kernel's [`KernelId`].
#[allow(dead_code)]
pub fn kernel_id(&self) -> &KernelId {
&self.id
}
/// Compiles `wgsl_source` into a compute pipeline, builds the matching bind
/// group and marks the kernel as `Launched`.
///
/// The pipeline is created with the entry point `"main"` and a workgroup size
/// of `(options.block_size, 1, 1)`. The bind group receives, in this order:
/// the control block, the input queue headers and the output queue headers.
///
/// # Errors
///
/// Propagates any error returned by `ComputePipeline::new`. In that case the
/// previously stored pipeline, bind group and state are left untouched.
pub fn load_shader(&mut self, wgsl_source: &str) -> Result<()> {
    let workgroup_size = (self.options.block_size, 1, 1);
    let pipeline = ComputePipeline::new(&self.adapter, wgsl_source, "main", workgroup_size)?;

    let bind_group = create_bind_group(
        self.adapter.device(),
        pipeline.bind_group_layout(),
        self.control_block.read().as_binding(),
        self.input_queue.headers_binding(),
        // Bind the OUTPUT queue here. Binding the input headers a second time
        // would leave the queue drained by the receive methods unreachable
        // from the shader.
        self.output_queue.headers_binding(),
    );

    self.pipeline = Some(pipeline);
    self.bind_group = Some(bind_group);
    *self.state.write() = KernelState::Launched;
    Ok(())
}
/// Records a single compute pass and submits it to the adapter's queue.
///
/// The pass uses the stored pipeline, binds the stored bind group at index 0
/// and dispatches `workgroups` x 1 x 1 workgroups. The call returns right
/// after submission; it does not wait for the GPU.
///
/// # Errors
///
/// Returns `RingKernelError::LaunchFailed` when no pipeline or no bind group
/// has been stored yet (see `load_shader`).
#[allow(dead_code)]
pub fn dispatch(&self, workgroups: u32) -> Result<()> {
    let pipeline = match self.pipeline.as_ref() {
        Some(loaded) => loaded,
        None => {
            return Err(RingKernelError::LaunchFailed(
                "Shader not loaded".to_string(),
            ))
        }
    };
    let bind_group = match self.bind_group.as_ref() {
        Some(group) => group,
        None => {
            return Err(RingKernelError::LaunchFailed(
                "Bind group not created".to_string(),
            ))
        }
    };

    let device = self.adapter.device();
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("RingKernel Dispatch"),
    });

    // The pass borrows the encoder mutably; the inner scope ends that borrow
    // before `finish` is called.
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("RingKernel Compute Pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(pipeline.pipeline());
        pass.set_bind_group(0, bind_group, &[]);
        pass.dispatch_workgroups(workgroups, 1, 1);
    }

    let commands = encoder.finish();
    self.adapter.queue().submit(Some(commands));
    Ok(())
}
/// Polls the adapter with `wgpu::Maintain::Wait`.
///
/// NOTE(review): per the wgpu documentation `Maintain::Wait` blocks the
/// calling thread until outstanding submitted work has finished — avoid
/// calling this from latency-sensitive async code.
#[allow(dead_code)]
pub fn gpu_wait(&self) {
self.adapter.poll(wgpu::Maintain::Wait);
}
}
#[async_trait]
impl KernelHandleInner for WgpuKernel {
/// Returns the numeric identifier given to `WgpuKernel::new`.
fn kernel_id_num(&self) -> u64 {
self.id_num
}
/// Calls `tick` on the kernel's clock and returns the timestamp it yields.
fn current_timestamp(&self) -> HlcTimestamp {
self.clock.tick()
}
/// Builds a point-in-time [`KernelStatus`] snapshot.
///
/// Queue depths come from the control block. If the control block cannot be
/// read, its default value is used instead of reporting an error. The mode is
/// always reported as `KernelMode::EventDriven`, `messages_processed` is the
/// value of the send counter, and `uptime` is measured from construction.
fn status(&self) -> KernelStatus {
    let current_state = *self.state.read();
    // Best effort: a failed control-block read falls back to the default block.
    let block = self.control_block.read().read().unwrap_or_default();
    let input_queue_depth = block.input_queue_size() as usize;
    let output_queue_depth = block.output_queue_size() as usize;

    KernelStatus {
        id: self.id.clone(),
        state: current_state,
        mode: KernelMode::EventDriven,
        input_queue_depth,
        output_queue_depth,
        messages_processed: self.message_counter.load(Ordering::Relaxed),
        uptime: self.created_at.elapsed(),
    }
}
/// Returns a clone of the currently stored [`KernelMetrics`].
fn metrics(&self) -> KernelMetrics {
self.metrics.read().clone()
}
/// Moves the kernel to `Active` and sets `is_active = 1` in the control block.
///
/// # Errors
///
/// Returns `RingKernelError::InvalidStateTransition` unless the current state
/// is `Launched` or `Deactivated`. Control-block read/write errors are
/// propagated, in which case the state is left unchanged.
async fn activate(&self) -> Result<()> {
    let from_state = *self.state.read();
    let may_activate = matches!(
        from_state,
        KernelState::Launched | KernelState::Deactivated
    );
    if !may_activate {
        return Err(RingKernelError::InvalidStateTransition {
            from: format!("{:?}", from_state),
            to: "Active".to_string(),
        });
    }

    // Read-modify-write of the control block under the write lock; the scope
    // releases the lock before the state is updated.
    {
        let guard = self.control_block.write();
        let mut block = guard.read()?;
        block.is_active = 1;
        guard.write(&block)?;
    }

    *self.state.write() = KernelState::Active;
    tracing::info!(kernel_id = %self.id, "WebGPU kernel activated");
    Ok(())
}
/// Moves the kernel to `Deactivated` and sets `is_active = 0` in the control block.
///
/// # Errors
///
/// Returns `RingKernelError::InvalidStateTransition` unless the current state
/// is `Active`. Control-block read/write errors are propagated, in which case
/// the state is left unchanged.
async fn deactivate(&self) -> Result<()> {
    let from_state = *self.state.read();
    if !matches!(from_state, KernelState::Active) {
        return Err(RingKernelError::InvalidStateTransition {
            from: format!("{:?}", from_state),
            to: "Deactivated".to_string(),
        });
    }

    // Read-modify-write of the control block under the write lock; the scope
    // releases the lock before the state is updated.
    {
        let guard = self.control_block.write();
        let mut block = guard.read()?;
        block.is_active = 0;
        guard.write(&block)?;
    }

    *self.state.write() = KernelState::Deactivated;
    tracing::info!(kernel_id = %self.id, "WebGPU kernel deactivated");
    Ok(())
}
/// Requests termination, waits for the device and marks the kernel `Terminated`.
///
/// Steps, in order:
/// 1. set `should_terminate = 1` in the control block;
/// 2. move the state to `Terminating`;
/// 3. poll the adapter with `wgpu::Maintain::Wait`;
/// 4. set `has_terminated = 1` in the control block;
/// 5. move the state to `Terminated` and wake every task parked in `wait`.
///
/// # Errors
///
/// Propagates control-block read/write errors. Later steps are skipped on
/// error, so the state may stay at `Terminating` and waiters are not woken.
async fn terminate(&self) -> Result<()> {
    {
        let cb_lock = self.control_block.write();
        let mut cb = cb_lock.read()?;
        cb.should_terminate = 1;
        cb_lock.write(&cb)?;
    }
    *self.state.write() = KernelState::Terminating;

    // NOTE(review): this poll is a synchronous call inside an async fn.
    self.adapter.poll(wgpu::Maintain::Wait);

    {
        let cb_lock = self.control_block.write();
        let mut cb = cb_lock.read()?;
        cb.has_terminated = 1;
        cb_lock.write(&cb)?;
    }

    // Publish the final state BEFORE waking waiters. `receive` relies on
    // observing `Terminated` to leave its polling loop; without this
    // transition it would spin forever once the kernel has shut down.
    *self.state.write() = KernelState::Terminated;
    self.terminate_notify.notify_waiters();
    tracing::info!(kernel_id = %self.id, "WebGPU kernel terminated");
    Ok(())
}
/// Enqueues `envelope` on the input queue and, in event-driven mode, dispatches
/// one workgroup to process it.
///
/// The send counter is incremented after a successful enqueue. A dispatch is
/// only issued when `options.mode` is `KernelMode::EventDriven` and both the
/// pipeline and the bind group are present.
///
/// # Errors
///
/// Returns `RingKernelError::KernelNotActive` unless the kernel is `Active`.
/// Errors from `enqueue` and `dispatch` are propagated; a failing dispatch
/// leaves the message enqueued and counted.
async fn send_envelope(&self, envelope: MessageEnvelope) -> Result<()> {
    if *self.state.read() != KernelState::Active {
        return Err(RingKernelError::KernelNotActive(self.id.to_string()));
    }

    self.input_queue.enqueue(&envelope)?;
    self.message_counter.fetch_add(1, Ordering::Relaxed);

    let event_driven = self.options.mode == KernelMode::EventDriven;
    let shader_ready = self.pipeline.is_some() && self.bind_group.is_some();
    if event_driven && shader_ready {
        self.dispatch(1)?;
    }
    Ok(())
}
/// Waits for the next envelope on the output queue.
///
/// Each iteration tries to dequeue; when nothing is available it polls the
/// adapter with `wgpu::Maintain::Poll` and yields to the tokio scheduler
/// before retrying.
///
/// # Errors
///
/// Returns `RingKernelError::QueueEmpty` when the queue is empty and the
/// kernel state is `Terminated`.
async fn receive(&self) -> Result<MessageEnvelope> {
    loop {
        match self.output_queue.try_dequeue() {
            Some(envelope) => return Ok(envelope),
            None => {
                let terminated = *self.state.read() == KernelState::Terminated;
                if terminated {
                    return Err(RingKernelError::QueueEmpty);
                }
            }
        }

        // Nothing available yet: let the device make progress, then give other
        // tasks a chance to run before the next attempt.
        self.adapter.poll(wgpu::Maintain::Poll);
        tokio::task::yield_now().await;
    }
}
/// Like `receive`, but gives up after `timeout`.
///
/// # Errors
///
/// Returns `RingKernelError::Timeout(timeout)` when the deadline elapses
/// first; otherwise forwards whatever `receive` returned.
async fn receive_timeout(&self, timeout: Duration) -> Result<MessageEnvelope> {
    // Outer `Err` means the timer fired; the inner result is `receive`'s own.
    tokio::time::timeout(timeout, self.receive())
        .await
        .map_err(|_elapsed| RingKernelError::Timeout(timeout))?
}
/// Dequeues one envelope from the output queue without waiting.
///
/// # Errors
///
/// Returns `RingKernelError::QueueEmpty` when nothing is available.
fn try_receive(&self) -> Result<MessageEnvelope> {
    match self.output_queue.try_dequeue() {
        Some(envelope) => Ok(envelope),
        None => Err(RingKernelError::QueueEmpty),
    }
}
/// Waits up to `timeout` for an envelope whose header carries `correlation`.
///
/// Envelopes with a different correlation id are dequeued and discarded.
/// NOTE(review): those messages are lost to other receivers — confirm this is
/// acceptable for callers that mix correlated and uncorrelated traffic.
///
/// While the queue is empty the adapter is polled with
/// `wgpu::Maintain::Poll` and the task yields to the tokio scheduler.
///
/// # Errors
///
/// Returns `RingKernelError::Timeout(timeout)` once `timeout` has elapsed
/// without a match. Any error from `try_receive` other than `QueueEmpty` is
/// propagated.
async fn receive_correlated(
    &self,
    correlation: CorrelationId,
    timeout: Duration,
) -> Result<MessageEnvelope> {
    let start = Instant::now();
    loop {
        match self.try_receive() {
            Ok(envelope) => {
                if envelope.header.correlation_id == correlation {
                    return Ok(envelope);
                }
                // Enforce the deadline here as well: a steady stream of
                // unrelated envelopes would otherwise keep this loop running
                // well past `timeout`, because the empty-queue branch is never
                // reached.
                if start.elapsed() >= timeout {
                    return Err(RingKernelError::Timeout(timeout));
                }
            }
            Err(RingKernelError::QueueEmpty) => {
                if start.elapsed() >= timeout {
                    return Err(RingKernelError::Timeout(timeout));
                }
                self.adapter.poll(wgpu::Maintain::Poll);
                tokio::task::yield_now().await;
            }
            Err(e) => return Err(e),
        }
    }
}
async fn wait(&self) -> Result<()> {
self.terminate_notify.notified().await;
Ok(())
}
}