use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use crate::error::Result;
use crate::message::{MessageEnvelope, RingMessage};
use crate::telemetry::KernelMetrics;
use crate::types::KernelMode;
/// String-backed identifier for a kernel.
///
/// A thin newtype over `String`; the wrapped value is public and can also be
/// borrowed through [`KernelId::as_str`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KernelId(pub String);

impl KernelId {
    /// Builds an identifier from anything convertible into a `String`.
    pub fn new(id: impl Into<String>) -> Self {
        let owned: String = id.into();
        KernelId(owned)
    }

    /// Borrows the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl std::fmt::Display for KernelId {
    /// Writes the raw identifier text; width/fill flags of the outer
    /// formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl From<&str> for KernelId {
    /// Copies the slice into a freshly allocated identifier.
    fn from(s: &str) -> Self {
        KernelId(String::from(s))
    }
}

impl From<String> for KernelId {
    /// Wraps an already-owned string without copying it.
    fn from(s: String) -> Self {
        KernelId(s)
    }
}
/// Lifecycle state of a kernel.
///
/// The predicate methods below define which operations each state permits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelState {
    /// No predicate on this type returns `true` for this state.
    Created,
    /// May be activated or terminated.
    Launched,
    /// The only state reported as running; may be deactivated or terminated.
    Active,
    /// May be activated again or terminated.
    Deactivated,
    /// No predicate on this type returns `true` for this state.
    Terminating,
    /// The only state reported as finished.
    Terminated,
}

impl KernelState {
    /// `true` for `Launched` and `Deactivated`.
    pub fn can_activate(&self) -> bool {
        match self {
            Self::Launched | Self::Deactivated => true,
            Self::Created | Self::Active | Self::Terminating | Self::Terminated => false,
        }
    }

    /// `true` only for `Active`.
    pub fn can_deactivate(&self) -> bool {
        *self == Self::Active
    }

    /// `true` for `Active`, `Deactivated` and `Launched`.
    pub fn can_terminate(&self) -> bool {
        match self {
            Self::Launched | Self::Active | Self::Deactivated => true,
            Self::Created | Self::Terminating | Self::Terminated => false,
        }
    }

    /// `true` only for `Active`.
    pub fn is_running(&self) -> bool {
        *self == Self::Active
    }

    /// `true` only for `Terminated`.
    pub fn is_finished(&self) -> bool {
        *self == Self::Terminated
    }
}
/// Snapshot of a kernel's status.
///
/// This is the value returned by `KernelHandle::status` and
/// `KernelHandleInner::status`.
#[derive(Debug, Clone)]
pub struct KernelStatus {
/// Identifier of the kernel the snapshot describes.
pub id: KernelId,
/// Lifecycle state; `KernelHandle::state` returns exactly this field.
pub state: KernelState,
/// Mode of the kernel.
pub mode: KernelMode,
/// Depth of the input queue.
// NOTE(review): presumably a count of pending messages — confirm with the backend.
pub input_queue_depth: usize,
/// Depth of the output queue.
// NOTE(review): presumably a count of pending messages — confirm with the backend.
pub output_queue_depth: usize,
/// Number of messages processed.
pub messages_processed: u64,
/// Kernel uptime.
// NOTE(review): the reference point the uptime is measured from is not visible in this file.
pub uptime: Duration,
}
/// Compute backend selector.
///
/// `Backend::Auto` is the `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Backend {
    /// Displayed as `"CPU"`.
    Cpu,
    /// Displayed as `"CUDA"`.
    Cuda,
    /// Displayed as `"Metal"`.
    Metal,
    /// Displayed as `"WebGPU"`.
    Wgpu,
    /// Displayed as `"Auto"`; the default variant.
    #[default]
    Auto,
}

impl Backend {
    /// Returns the static display name of the backend.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Auto => "Auto",
            Self::Wgpu => "WebGPU",
            Self::Metal => "Metal",
            Self::Cuda => "CUDA",
            Self::Cpu => "CPU",
        }
    }
}

impl std::fmt::Display for Backend {
    /// Writes [`Backend::name`]; width/fill flags of the outer formatter are
    /// not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.name();
        f.write_str(label)
    }
}
/// Options supplied when launching a kernel.
///
/// Start from [`LaunchOptions::default`], [`LaunchOptions::single_block`] or
/// [`LaunchOptions::multi_block`] and refine with the chainable `with_*`
/// methods; every builder method consumes `self` and returns the updated value.
#[derive(Debug, Clone)]
pub struct LaunchOptions {
    /// Kernel mode (default: `KernelMode::Persistent`).
    pub mode: KernelMode,
    /// Grid size (default: `1`).
    pub grid_size: u32,
    /// Block size (default: `256`).
    pub block_size: u32,
    /// Input queue capacity (default: `1024`).
    pub input_queue_capacity: usize,
    /// Output queue capacity (default: `1024`).
    pub output_queue_capacity: usize,
    /// Shared memory size (default: `0`).
    // NOTE(review): presumably bytes — confirm against the backend that consumes it.
    pub shared_memory_size: usize,
    /// Auto-activation flag (default: `true`).
    pub auto_activate: bool,
    /// Cooperative flag (default: `false`).
    pub cooperative: bool,
    /// Kernel-to-kernel flag (default: `false`).
    pub enable_k2k: bool,
}

impl Default for LaunchOptions {
    /// Persistent mode, one block of 256, both queues at 1024, no shared
    /// memory, auto-activation on, cooperative and k2k off.
    fn default() -> Self {
        LaunchOptions {
            enable_k2k: false,
            cooperative: false,
            auto_activate: true,
            shared_memory_size: 0,
            output_queue_capacity: 1024,
            input_queue_capacity: 1024,
            block_size: 256,
            grid_size: 1,
            mode: KernelMode::Persistent,
        }
    }
}

impl LaunchOptions {
    /// Default options with `block_size` overridden (grid size stays `1`).
    pub fn single_block(block_size: u32) -> Self {
        Self::default().with_block_size(block_size)
    }

    /// Default options with both `grid_size` and `block_size` overridden.
    pub fn multi_block(grid_size: u32, block_size: u32) -> Self {
        Self::default()
            .with_grid_size(grid_size)
            .with_block_size(block_size)
    }

    /// Sets the kernel mode.
    pub fn with_mode(self, mode: KernelMode) -> Self {
        Self { mode, ..self }
    }

    /// Sets the input and output queue capacities to the same value.
    pub fn with_queue_capacity(self, capacity: usize) -> Self {
        self.with_input_queue_capacity(capacity)
            .with_output_queue_capacity(capacity)
    }

    /// Sets the shared memory size.
    pub fn with_shared_memory(self, size: usize) -> Self {
        Self {
            shared_memory_size: size,
            ..self
        }
    }

    /// Turns `auto_activate` off.
    pub fn without_auto_activate(self) -> Self {
        Self {
            auto_activate: false,
            ..self
        }
    }

    /// Sets the grid size.
    pub fn with_grid_size(self, grid_size: u32) -> Self {
        Self { grid_size, ..self }
    }

    /// Sets the block size.
    pub fn with_block_size(self, block_size: u32) -> Self {
        Self { block_size, ..self }
    }

    /// Sets the cooperative flag.
    pub fn with_cooperative(self, cooperative: bool) -> Self {
        Self {
            cooperative,
            ..self
        }
    }

    /// Sets the kernel-to-kernel flag.
    pub fn with_k2k(self, enable: bool) -> Self {
        Self {
            enable_k2k: enable,
            ..self
        }
    }

    /// Accepts a priority but currently ignores it.
    ///
    /// There is no priority field on `LaunchOptions`, so the options are
    /// returned unchanged.
    pub fn with_priority(self, _priority: u8) -> Self {
        self
    }

    /// Sets only the input queue capacity.
    pub fn with_input_queue_capacity(self, capacity: usize) -> Self {
        Self {
            input_queue_capacity: capacity,
            ..self
        }
    }

    /// Sets only the output queue capacity.
    pub fn with_output_queue_capacity(self, capacity: usize) -> Self {
        Self {
            output_queue_capacity: capacity,
            ..self
        }
    }
}
/// Heap-allocated, pinned, `Send` future that yields `T` and is bounded by lifetime `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Interface exposed by a ring-kernel runtime implementation.
///
/// Implementors must be `Send + Sync`. The `async` methods are declared
/// through `#[async_trait]`.
#[async_trait]
pub trait RingKernelRuntime: Send + Sync {
/// Returns the [`Backend`] of this runtime.
fn backend(&self) -> Backend;
/// Reports whether `backend` is available.
// NOTE(review): what "available" requires is decided by the implementor; not visible in this file.
fn is_backend_available(&self, backend: Backend) -> bool;
/// Launches a kernel with the given `options` and returns a handle to it.
///
/// # Errors
/// Returns the crate's error type on failure; the failure conditions are
/// implementation-defined.
// NOTE(review): `kernel_id` is presumably the kernel's name/identifier — confirm with implementors.
async fn launch(&self, kernel_id: &str, options: LaunchOptions) -> Result<KernelHandle>;
/// Looks up a kernel handle by id, returning `None` when there is no match.
fn get_kernel(&self, kernel_id: &KernelId) -> Option<KernelHandle>;
/// Lists kernel ids.
// NOTE(review): presumably every kernel known to the runtime — confirm whether terminated kernels are included.
fn list_kernels(&self) -> Vec<KernelId>;
/// Returns runtime-wide metrics.
fn metrics(&self) -> RuntimeMetrics;
/// Shuts the runtime down.
///
/// # Errors
/// Returns the crate's error type on failure; the failure conditions are
/// implementation-defined.
async fn shutdown(&self) -> Result<()>;
}
/// Handle to a kernel.
///
/// Pairs a [`KernelId`] with a shared, backend-provided
/// [`KernelHandleInner`]. Cloning the handle clones the id and the `Arc`, so
/// clones refer to the same backend object.
#[derive(Clone)]
pub struct KernelHandle {
    id: KernelId,
    inner: Arc<dyn KernelHandleInner>,
}

impl KernelHandle {
    /// Wraps a backend implementation together with its kernel id.
    pub fn new(id: KernelId, inner: Arc<dyn KernelHandleInner>) -> Self {
        KernelHandle { id, inner }
    }

    /// Returns the id this handle was created with.
    pub fn id(&self) -> &KernelId {
        &self.id
    }

    /// Forwards to the backend's `activate`.
    pub async fn activate(&self) -> Result<()> {
        self.inner.activate().await
    }

    /// Forwards to the backend's `deactivate`.
    pub async fn deactivate(&self) -> Result<()> {
        self.inner.deactivate().await
    }

    /// Forwards to the backend's `terminate`.
    pub async fn terminate(&self) -> Result<()> {
        self.inner.terminate().await
    }

    /// Builds the envelope shared by `send` and `call`.
    ///
    /// The arguments after the message are `0`, the backend's numeric kernel
    /// id and the backend's current timestamp, in that order.
    // NOTE(review): `0` looks like the source id used for the host side — confirm
    // against `MessageEnvelope::new`.
    fn wrap_message<M: RingMessage>(&self, message: &M) -> MessageEnvelope {
        MessageEnvelope::new(
            message,
            0,
            self.inner.kernel_id_num(),
            self.inner.current_timestamp(),
        )
    }

    /// Wraps `message` in an envelope and sends it to the kernel.
    pub async fn send<M: RingMessage>(&self, message: M) -> Result<()> {
        let outgoing = self.wrap_message(&message);
        self.inner.send_envelope(outgoing).await
    }

    /// Sends an envelope the caller has already built.
    pub async fn send_envelope(&self, envelope: MessageEnvelope) -> Result<()> {
        self.inner.send_envelope(envelope).await
    }

    /// Forwards to the backend's `receive`.
    pub async fn receive(&self) -> Result<MessageEnvelope> {
        self.inner.receive().await
    }

    /// Forwards to the backend's `receive_timeout` with the given `timeout`.
    pub async fn receive_timeout(&self, timeout: Duration) -> Result<MessageEnvelope> {
        self.inner.receive_timeout(timeout).await
    }

    /// Synchronous receive; forwards to the backend's `try_receive`.
    pub fn try_receive(&self) -> Result<MessageEnvelope> {
        self.inner.try_receive()
    }

    /// Sends `message` tagged with a fresh correlation id, then waits for the
    /// correlated reply.
    ///
    /// `timeout` is passed to the backend's `receive_correlated`; it is not
    /// applied to the send step. A send failure is returned immediately and
    /// no receive is attempted.
    pub async fn call<M: RingMessage>(
        &self,
        message: M,
        timeout: Duration,
    ) -> Result<MessageEnvelope> {
        let correlation = crate::message::CorrelationId::generate();
        let mut request = self.wrap_message(&message);
        request.header.correlation_id = correlation;
        self.inner.send_envelope(request).await?;
        self.inner.receive_correlated(correlation, timeout).await
    }

    /// Returns the backend's status snapshot.
    pub fn status(&self) -> KernelStatus {
        self.inner.status()
    }

    /// Returns the backend's kernel metrics.
    pub fn metrics(&self) -> KernelMetrics {
        self.inner.metrics()
    }

    /// Forwards to the backend's `wait`.
    pub async fn wait(&self) -> Result<()> {
        self.inner.wait().await
    }

    /// Current state, taken from a fresh status snapshot.
    pub fn state(&self) -> KernelState {
        let snapshot = self.status();
        snapshot.state
    }

    /// Equivalent to [`KernelHandle::deactivate`].
    pub async fn suspend(&self) -> Result<()> {
        self.inner.deactivate().await
    }

    /// Equivalent to [`KernelHandle::activate`].
    pub async fn resume(&self) -> Result<()> {
        self.inner.activate().await
    }

    /// `true` when the current state is `KernelState::Active`.
    pub fn is_active(&self) -> bool {
        matches!(self.state(), KernelState::Active)
    }

    /// `true` when the current state is `KernelState::Terminated`.
    pub fn is_terminated(&self) -> bool {
        matches!(self.state(), KernelState::Terminated)
    }
}

impl std::fmt::Debug for KernelHandle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `KernelHandleInner` carries no `Debug` bound, so only the id is printed.
        let mut out = f.debug_struct("KernelHandle");
        out.field("id", &self.id);
        out.finish()
    }
}
/// Backend-side implementation behind a `KernelHandle`.
///
/// `KernelHandle` stores an `Arc<dyn KernelHandleInner>` and forwards its
/// public methods to this trait. Implementors must be `Send + Sync`.
#[async_trait]
pub trait KernelHandleInner: Send + Sync {
/// Numeric kernel id; `KernelHandle::send` and `KernelHandle::call` pass it
/// as the third argument to `MessageEnvelope::new`.
fn kernel_id_num(&self) -> u64;
/// Current timestamp; `KernelHandle::send` and `KernelHandle::call` pass it
/// as the fourth argument to `MessageEnvelope::new`.
fn current_timestamp(&self) -> crate::hlc::HlcTimestamp;
/// Activates the kernel.
async fn activate(&self) -> Result<()>;
/// Deactivates the kernel.
async fn deactivate(&self) -> Result<()>;
/// Terminates the kernel.
async fn terminate(&self) -> Result<()>;
/// Sends an envelope to the kernel.
async fn send_envelope(&self, envelope: MessageEnvelope) -> Result<()>;
/// Receives an envelope from the kernel.
async fn receive(&self) -> Result<MessageEnvelope>;
/// Receives an envelope, taking a `timeout`.
// NOTE(review): presumably returns an error when the timeout elapses — confirm with implementors.
async fn receive_timeout(&self, timeout: Duration) -> Result<MessageEnvelope>;
/// Synchronous receive.
// NOTE(review): presumably returns an error instead of waiting when nothing is queued — confirm.
fn try_receive(&self) -> Result<MessageEnvelope>;
/// Receives the envelope associated with `correlation`, taking a `timeout`.
///
/// `KernelHandle::call` invokes this right after sending an envelope whose
/// header carries the same correlation id.
async fn receive_correlated(
&self,
correlation: crate::message::CorrelationId,
timeout: Duration,
) -> Result<MessageEnvelope>;
/// Returns a status snapshot; `KernelHandle::state` reads its `state` field.
fn status(&self) -> KernelStatus;
/// Returns the kernel's metrics.
fn metrics(&self) -> KernelMetrics;
/// Waits on the kernel.
// NOTE(review): presumably resolves once the kernel has terminated — confirm with implementors.
async fn wait(&self) -> Result<()>;
}
/// Runtime-wide counters returned by `RingKernelRuntime::metrics`.
///
/// `Default` is derived, so every counter starts at zero.
#[derive(Debug, Clone, Default)]
pub struct RuntimeMetrics {
/// Number of active kernels.
pub active_kernels: usize,
/// Total number of kernels launched.
pub total_launched: u64,
/// Number of messages sent.
pub messages_sent: u64,
/// Number of messages received.
pub messages_received: u64,
/// GPU memory in use.
// NOTE(review): presumably bytes — confirm with the runtime that fills this in.
pub gpu_memory_used: u64,
/// Host memory in use.
// NOTE(review): presumably bytes — confirm with the runtime that fills this in.
pub host_memory_used: u64,
}
/// Chainable configuration for constructing a runtime.
///
/// All fields are public; the builder methods consume `self` and return the
/// updated value. No `build` step is defined in this block.
#[derive(Debug, Clone)]
pub struct RuntimeBuilder {
    /// Requested backend (default: `Backend::Auto`).
    pub backend: Backend,
    /// Device index (default: `0`).
    pub device_index: usize,
    /// Debug flag (default: `false`).
    pub debug: bool,
    /// Profiling flag (default: `false`).
    pub profiling: bool,
}

impl Default for RuntimeBuilder {
    /// `Backend::Auto`, device `0`, debug and profiling off.
    fn default() -> Self {
        RuntimeBuilder {
            profiling: false,
            debug: false,
            device_index: 0,
            backend: Backend::Auto,
        }
    }
}

impl RuntimeBuilder {
    /// Same as [`RuntimeBuilder::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Selects the backend.
    pub fn backend(self, backend: Backend) -> Self {
        Self { backend, ..self }
    }

    /// Selects the device index.
    pub fn device(self, index: usize) -> Self {
        Self {
            device_index: index,
            ..self
        }
    }

    /// Sets the debug flag.
    pub fn debug(self, enable: bool) -> Self {
        Self {
            debug: enable,
            ..self
        }
    }

    /// Sets the profiling flag.
    pub fn profiling(self, enable: bool) -> Self {
        Self {
            profiling: enable,
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the transition predicates for every state, including the
    /// `Created` and `Terminating` states that permit no transition.
    #[test]
    fn test_kernel_state_transitions() {
        assert!(KernelState::Launched.can_activate());
        assert!(KernelState::Deactivated.can_activate());
        assert!(!KernelState::Active.can_activate());
        assert!(!KernelState::Terminated.can_activate());
        assert!(!KernelState::Created.can_activate());
        assert!(!KernelState::Terminating.can_activate());

        assert!(KernelState::Active.can_deactivate());
        assert!(!KernelState::Launched.can_deactivate());
        assert!(!KernelState::Deactivated.can_deactivate());
        assert!(!KernelState::Created.can_deactivate());
        assert!(!KernelState::Terminating.can_deactivate());
        assert!(!KernelState::Terminated.can_deactivate());

        assert!(KernelState::Active.can_terminate());
        assert!(KernelState::Deactivated.can_terminate());
        assert!(KernelState::Launched.can_terminate());
        assert!(!KernelState::Terminated.can_terminate());
        assert!(!KernelState::Created.can_terminate());
        assert!(!KernelState::Terminating.can_terminate());
    }

    /// `is_running` is true only for `Active`; `is_finished` only for `Terminated`.
    #[test]
    fn test_kernel_state_queries() {
        let states = [
            KernelState::Created,
            KernelState::Launched,
            KernelState::Active,
            KernelState::Deactivated,
            KernelState::Terminating,
            KernelState::Terminated,
        ];
        for state in states.iter() {
            assert_eq!(state.is_running(), *state == KernelState::Active);
            assert_eq!(state.is_finished(), *state == KernelState::Terminated);
        }
    }

    #[test]
    fn test_launch_options_builder() {
        let opts = LaunchOptions::multi_block(4, 128)
            .with_mode(KernelMode::EventDriven)
            .with_queue_capacity(2048)
            .with_shared_memory(4096)
            .without_auto_activate();
        assert_eq!(opts.grid_size, 4);
        assert_eq!(opts.block_size, 128);
        assert_eq!(opts.mode, KernelMode::EventDriven);
        assert_eq!(opts.input_queue_capacity, 2048);
        // `with_queue_capacity` must update both queues.
        assert_eq!(opts.output_queue_capacity, 2048);
        assert_eq!(opts.shared_memory_size, 4096);
        assert!(!opts.auto_activate);
        // Untouched flags keep their defaults.
        assert!(!opts.cooperative);
        assert!(!opts.enable_k2k);
    }

    /// Defaults, the single-block constructor and the per-field setters.
    #[test]
    fn test_launch_options_defaults_and_setters() {
        let defaults = LaunchOptions::default();
        assert_eq!(defaults.mode, KernelMode::Persistent);
        assert_eq!(defaults.grid_size, 1);
        assert_eq!(defaults.block_size, 256);
        assert_eq!(defaults.input_queue_capacity, 1024);
        assert_eq!(defaults.output_queue_capacity, 1024);
        assert_eq!(defaults.shared_memory_size, 0);
        assert!(defaults.auto_activate);

        let single = LaunchOptions::single_block(64);
        assert_eq!(single.grid_size, 1);
        assert_eq!(single.block_size, 64);

        let opts = LaunchOptions::default()
            .with_grid_size(8)
            .with_block_size(32)
            .with_input_queue_capacity(16)
            .with_output_queue_capacity(64)
            .with_cooperative(true)
            .with_k2k(true)
            // Currently a no-op: nothing below may change because of it.
            .with_priority(200);
        assert_eq!(opts.grid_size, 8);
        assert_eq!(opts.block_size, 32);
        assert_eq!(opts.input_queue_capacity, 16);
        assert_eq!(opts.output_queue_capacity, 64);
        assert!(opts.cooperative);
        assert!(opts.enable_k2k);
        assert_eq!(opts.mode, KernelMode::Persistent);
    }

    #[test]
    fn test_kernel_id() {
        let id1 = KernelId::new("test_kernel");
        let id2: KernelId = "test_kernel".into();
        let id3: KernelId = String::from("test_kernel").into();
        assert_eq!(id1, id2);
        assert_eq!(id2, id3);
        assert_eq!(id1.as_str(), "test_kernel");
        assert_eq!(id1.to_string(), "test_kernel");
        assert_ne!(id1, KernelId::new("other_kernel"));
    }

    #[test]
    fn test_backend_name() {
        assert_eq!(Backend::Cpu.name(), "CPU");
        assert_eq!(Backend::Cuda.name(), "CUDA");
        assert_eq!(Backend::Metal.name(), "Metal");
        assert_eq!(Backend::Wgpu.name(), "WebGPU");
        assert_eq!(Backend::Auto.name(), "Auto");
        assert_eq!(Backend::default(), Backend::Auto);
        assert_eq!(Backend::Wgpu.to_string(), "WebGPU");
    }

    #[test]
    fn test_runtime_builder() {
        let defaults = RuntimeBuilder::new();
        assert_eq!(defaults.backend, Backend::Auto);
        assert_eq!(defaults.device_index, 0);
        assert!(!defaults.debug);
        assert!(!defaults.profiling);

        let builder = RuntimeBuilder::new()
            .backend(Backend::Cpu)
            .device(2)
            .debug(true)
            .profiling(true);
        assert_eq!(builder.backend, Backend::Cpu);
        assert_eq!(builder.device_index, 2);
        assert!(builder.debug);
        assert!(builder.profiling);
    }
}