use crate::error::{KernelError, Result};
use crate::kernel::KernelMetadata;
use async_trait::async_trait;
use ringkernel_core::{RingContext, RingMessage};
use serde::{Deserialize, Serialize};
use std::fmt::{self, Debug};
use std::marker::PhantomData;
use std::time::Duration;
use uuid::Uuid;
/// Coarse health state reported by a kernel (see `GpuKernel::health_check`).
///
/// `Healthy` is the `Default` variant. The `Display` impl renders each
/// variant as its lowercase name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum HealthStatus {
/// Default state; displayed as `"healthy"`.
#[default]
Healthy,
/// Displayed as `"degraded"`.
Degraded,
/// Displayed as `"unhealthy"`.
Unhealthy,
/// Displayed as `"unknown"`.
Unknown,
}
impl std::fmt::Display for HealthStatus {
    /// Writes the lowercase name of the variant (`"healthy"`, `"degraded"`,
    /// `"unhealthy"` or `"unknown"`). Width and padding flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Healthy => "healthy",
            Self::Degraded => "degraded",
            Self::Unhealthy => "unhealthy",
            Self::Unknown => "unknown",
        };
        f.write_str(label)
    }
}
/// Per-request information passed alongside a kernel invocation.
///
/// Every field is optional. The derived `Default` leaves all of them unset
/// (and `metadata` empty); `ExecutionContext::new` additionally assigns a
/// fresh `request_id`.
#[derive(Debug, Clone, Default)]
pub struct ExecutionContext {
/// Request identifier; set to a random v4 UUID by `ExecutionContext::new`.
pub request_id: Option<Uuid>,
/// Trace identifier; set together with `span_id` by `with_trace`.
pub trace_id: Option<String>,
/// Span identifier; set together with `trace_id` by `with_trace`.
pub span_id: Option<String>,
/// Caller's user id. `SecureRingContext::is_authenticated` reports `true` when this is `Some`.
pub user_id: Option<String>,
/// Tenant id of the caller.
pub tenant_id: Option<String>,
/// Requested time limit.
/// NOTE(review): no code visible here enforces this value — confirm where it is honoured.
pub timeout: Option<Duration>,
/// Free-form string key/value pairs, added via `with_metadata`.
pub metadata: std::collections::HashMap<String, String>,
}
impl ExecutionContext {
    /// Builds a context whose `request_id` is a freshly generated v4 UUID;
    /// every other field takes its `Default` value.
    pub fn new() -> Self {
        let request_id = Some(Uuid::new_v4());
        Self {
            request_id,
            ..Self::default()
        }
    }

    /// Replaces the request id with `id`.
    pub fn with_request_id(self, id: Uuid) -> Self {
        Self {
            request_id: Some(id),
            ..self
        }
    }

    /// Sets both the trace id and the span id.
    pub fn with_trace(self, trace_id: impl Into<String>, span_id: impl Into<String>) -> Self {
        Self {
            trace_id: Some(trace_id.into()),
            span_id: Some(span_id.into()),
            ..self
        }
    }

    /// Sets the user id.
    pub fn with_user(self, user_id: impl Into<String>) -> Self {
        Self {
            user_id: Some(user_id.into()),
            ..self
        }
    }

    /// Sets the tenant id.
    pub fn with_tenant(self, tenant_id: impl Into<String>) -> Self {
        Self {
            tenant_id: Some(tenant_id.into()),
            ..self
        }
    }

    /// Sets the timeout.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Inserts one metadata entry; an existing entry under the same key is
    /// overwritten.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (name, data) = (key.into(), value.into());
        self.metadata.insert(name, data);
        self
    }
}
/// Pairs a mutable borrow of a `RingContext` with a shared borrow of the
/// `ExecutionContext` describing the caller.
///
/// Both borrows live for `'ctx`; `'ring` is the lifetime parameter of the
/// wrapped `RingContext`.
pub struct SecureRingContext<'ctx, 'ring> {
/// The underlying ring context, borrowed mutably.
pub ring_ctx: &'ctx mut RingContext<'ring>,
/// Caller information (user, tenant, ...) for the current invocation.
pub exec_ctx: &'ctx ExecutionContext,
}
impl<'ctx, 'ring> SecureRingContext<'ctx, 'ring> {
pub fn new(ring_ctx: &'ctx mut RingContext<'ring>, exec_ctx: &'ctx ExecutionContext) -> Self {
Self { ring_ctx, exec_ctx }
}
pub fn user_id(&self) -> Option<&str> {
self.exec_ctx.user_id.as_deref()
}
pub fn tenant_id(&self) -> Option<&str> {
self.exec_ctx.tenant_id.as_deref()
}
pub fn is_authenticated(&self) -> bool {
self.exec_ctx.user_id.is_some()
}
}
/// Runtime configuration handed to a kernel (see `GpuKernel::refresh_config`).
///
/// The derived `Default` yields no queue limit, no timeout, tracing and
/// metrics both disabled, and an empty `custom` map.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct KernelConfig {
/// Upper bound on queue depth, if any.
pub max_queue_depth: Option<usize>,
/// Time limit, if any.
/// NOTE(review): how kernels apply this is not visible here — confirm.
pub timeout: Option<Duration>,
/// Whether tracing is switched on.
pub tracing_enabled: bool,
/// Whether metrics are switched on.
pub metrics_enabled: bool,
/// Arbitrary additional settings stored as JSON values, keyed by name.
pub custom: std::collections::HashMap<String, serde_json::Value>,
}
impl KernelConfig {
    /// Returns the default configuration.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the maximum queue depth.
    pub fn with_queue_depth(self, depth: usize) -> Self {
        Self {
            max_queue_depth: Some(depth),
            ..self
        }
    }

    /// Sets the timeout.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Switches tracing on or off.
    pub fn with_tracing(self, enabled: bool) -> Self {
        Self {
            tracing_enabled: enabled,
            ..self
        }
    }

    /// Switches metrics on or off.
    pub fn with_metrics(self, enabled: bool) -> Self {
        Self {
            metrics_enabled: enabled,
            ..self
        }
    }

    /// Stores a custom JSON setting under `key`, overwriting any existing
    /// entry with the same key.
    pub fn with_custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let name = key.into();
        self.custom.insert(name, value);
        self
    }
}
/// Base trait shared by the kernel abstractions in this module.
///
/// Implementors must be `Send + Sync + Debug` and must provide `metadata`;
/// every other method has a default implementation.
pub trait GpuKernel: Send + Sync + Debug {
/// Returns the kernel's metadata record.
fn metadata(&self) -> &KernelMetadata;
/// Validation hook. The default accepts unconditionally.
fn validate(&self) -> Result<()> {
Ok(())
}
/// Kernel identifier; by default the `id` field of `metadata()`.
fn id(&self) -> &str {
&self.metadata().id
}
/// By default, the `requires_gpu_native` flag of `metadata()`.
fn requires_gpu_native(&self) -> bool {
self.metadata().requires_gpu_native
}
/// Reports the kernel's health. The default always reports `Healthy`.
fn health_check(&self) -> HealthStatus {
HealthStatus::Healthy
}
/// Shutdown hook. The default does nothing and succeeds.
fn shutdown(&self) -> Result<()> {
Ok(())
}
/// Hook for applying a new configuration. The default ignores `_config`
/// and succeeds.
fn refresh_config(&mut self, _config: &KernelConfig) -> Result<()> {
Ok(())
}
}
#[async_trait]
pub trait BatchKernel<I, O>: GpuKernel
where
I: Send + Sync,
O: Send + Sync,
{
async fn execute(&self, input: I) -> Result<O>;
fn validate_input(&self, _input: &I) -> Result<()> {
Ok(())
}
async fn execute_with_context(&self, ctx: &ExecutionContext, input: I) -> Result<O>
where
I: 'async_trait,
{
let _ = ctx;
self.execute(input).await
}
async fn execute_with_timeout(&self, input: I, timeout: Duration) -> Result<O>
where
I: 'async_trait,
{
match tokio::time::timeout(timeout, self.execute(input)).await {
Ok(result) => result,
Err(_elapsed) => Err(crate::error::KernelError::Timeout(timeout)),
}
}
}
/// A kernel that handles ring messages of type `M` and answers with
/// responses of type `R`.
#[async_trait]
pub trait RingKernelHandler<M, R>: GpuKernel
where
M: RingMessage + Send + Sync,
R: RingMessage + Send + Sync,
{
/// Processes one message and produces its response.
async fn handle(&self, ctx: &mut RingContext, msg: M) -> Result<R>;
/// Initialization hook. The default does nothing and succeeds.
async fn initialize(&self, _ctx: &mut RingContext) -> Result<()> {
Ok(())
}
/// Ring-side shutdown hook. The default does nothing and succeeds.
async fn ring_shutdown(&self, _ctx: &mut RingContext) -> Result<()> {
Ok(())
}
/// Variant of `handle` that receives a `SecureRingContext`.
///
/// The default forwards to `handle` with the wrapped ring context; the
/// execution context carried by `ctx` is not consulted.
async fn handle_secure(&self, ctx: &mut SecureRingContext<'_, '_>, msg: M) -> Result<R>
where
M: 'async_trait,
R: 'async_trait,
{
self.handle(ctx.ring_ctx, msg).await
}
}
/// A kernel that refines a state `S` step by step until it converges.
///
/// `I` is the input type and `O` the output produced by every step.
#[async_trait]
pub trait IterativeKernel<S, I, O>: GpuKernel
where
    S: Send + Sync + 'static,
    I: Send + Sync + 'static,
    O: Send + Sync + 'static,
{
    /// Builds the starting state for `input`.
    fn initial_state(&self, input: &I) -> S;

    /// Performs one step, updating `state` in place and reporting whether
    /// the step itself considers the computation converged.
    async fn iterate(&self, state: &mut S, input: &I) -> Result<IterationResult<O>>;

    /// Returns `true` when `state` is converged with respect to `threshold`.
    fn converged(&self, state: &S, threshold: f64) -> bool;

    /// Iteration budget used by the `run_to_convergence*` drivers
    /// (default: 100).
    fn max_iterations(&self) -> usize {
        100
    }

    /// Threshold used by [`run_to_convergence`](Self::run_to_convergence)
    /// (default: `1e-6`).
    fn default_threshold(&self) -> f64 {
        1e-6
    }

    /// Runs the iteration loop with [`default_threshold`](Self::default_threshold).
    async fn run_to_convergence(&self, input: I) -> Result<O> {
        self.run_to_convergence_with_threshold(input, self.default_threshold())
            .await
    }

    /// Runs [`iterate`](Self::iterate) until convergence or until the
    /// iteration budget is spent.
    ///
    /// The loop stops as soon as a step returns
    /// `IterationResult::Converged`, or as soon as
    /// [`converged`](Self::converged) accepts the state after a step. If
    /// neither happens within [`max_iterations`](Self::max_iterations)
    /// steps, the output of the last step is returned as a best effort.
    ///
    /// At most `max_iterations()` steps are executed. The one exception is a
    /// budget of zero: an output can only come from a step, so exactly one
    /// step is run in that case.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `iterate`.
    async fn run_to_convergence_with_threshold(&self, input: I, threshold: f64) -> Result<O> {
        let mut state = self.initial_state(&input);
        let max_iter = self.max_iterations();
        // Output of the most recent non-converged step, kept so that running
        // out of budget does not require an extra call to `iterate`.
        let mut last_output: Option<O> = None;

        for _ in 0..max_iter {
            let output = match self.iterate(&mut state, &input).await? {
                IterationResult::Converged(output) => return Ok(output),
                IterationResult::Continue(output) => output,
            };
            if self.converged(&state, threshold) {
                return Ok(output);
            }
            last_output = Some(output);
        }

        match last_output {
            Some(output) => Ok(output),
            // Zero budget: a single step is the only way to obtain an `O`.
            None => self
                .iterate(&mut state, &input)
                .await
                .map(IterationResult::into_output),
        }
    }
}
/// Outcome of a single iteration step; both variants carry the step's output.
#[derive(Debug, Clone)]
pub enum IterationResult<O> {
    /// The step reports that the computation has converged.
    Converged(O),
    /// The step reports that further iterations are expected.
    Continue(O),
}

impl<O> IterationResult<O> {
    /// Returns `true` for `Converged`, `false` for `Continue`.
    #[must_use]
    pub fn is_converged(&self) -> bool {
        match self {
            Self::Converged(_) => true,
            Self::Continue(_) => false,
        }
    }

    /// Consumes the result and returns the carried output, whichever
    /// variant it is.
    #[must_use]
    pub fn into_output(self) -> O {
        match self {
            Self::Converged(value) => value,
            Self::Continue(value) => value,
        }
    }
}
/// Batch kernel interface that exchanges raw bytes instead of typed values.
///
/// `TypeErasedBatchKernel` implements it by encoding and decoding JSON.
#[async_trait]
pub trait BatchKernelDyn: GpuKernel {
/// Executes the kernel on a byte-encoded input and returns the byte-encoded output.
async fn execute_dyn(&self, input: &[u8]) -> Result<Vec<u8>>;
}
/// Ring kernel interface that exchanges raw bytes instead of typed messages.
///
/// `TypeErasedRingKernel` implements it by encoding and decoding JSON.
#[async_trait]
pub trait RingKernelDyn: GpuKernel {
/// Handles a byte-encoded message and returns the byte-encoded response.
async fn handle_dyn(&self, ctx: &mut RingContext, msg: &[u8]) -> Result<Vec<u8>>;
}
/// Wrapper that remembers the input type `I` and output type `O` of a batch
/// kernel `K` without storing values of either type.
pub struct TypeErasedBatchKernel<K, I, O> {
    inner: K,
    _phantom: PhantomData<fn(I) -> O>,
}

impl<K: Debug, I, O> Debug for TypeErasedBatchKernel<K, I, O> {
    /// Formats as a struct named `TypeErasedBatchKernel` with the single
    /// field `inner`; the marker field is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("TypeErasedBatchKernel");
        repr.field("inner", &self.inner);
        repr.finish()
    }
}

impl<K, I, O> TypeErasedBatchKernel<K, I, O> {
    /// Wraps `kernel`.
    pub fn new(kernel: K) -> Self {
        let _phantom = PhantomData;
        Self {
            inner: kernel,
            _phantom,
        }
    }

    /// Borrows the wrapped kernel.
    pub fn inner(&self) -> &K {
        let Self { inner, .. } = self;
        inner
    }
}
/// Forwards every `GpuKernel` method to the wrapped kernel.
///
/// `id` and `requires_gpu_native` are forwarded explicitly as well, so that a
/// wrapped kernel which overrides them is reported faithfully instead of
/// falling back to the trait defaults derived from `metadata()`.
impl<K, I, O> GpuKernel for TypeErasedBatchKernel<K, I, O>
where
    K: GpuKernel,
    I: Send + Sync + 'static,
    O: Send + Sync + 'static,
{
    fn metadata(&self) -> &KernelMetadata {
        self.inner.metadata()
    }

    fn validate(&self) -> Result<()> {
        self.inner.validate()
    }

    fn id(&self) -> &str {
        self.inner.id()
    }

    fn requires_gpu_native(&self) -> bool {
        self.inner.requires_gpu_native()
    }

    fn health_check(&self) -> HealthStatus {
        self.inner.health_check()
    }

    fn shutdown(&self) -> Result<()> {
        self.inner.shutdown()
    }

    fn refresh_config(&mut self, config: &KernelConfig) -> Result<()> {
        self.inner.refresh_config(config)
    }
}
#[async_trait]
impl<K, I, O> BatchKernelDyn for TypeErasedBatchKernel<K, I, O>
where
    K: BatchKernel<I, O> + 'static,
    I: serde::de::DeserializeOwned + Send + Sync + 'static,
    O: serde::Serialize + Send + Sync + 'static,
{
    /// Decodes `input` as JSON into `I`, executes the wrapped kernel, and
    /// encodes its output as JSON.
    ///
    /// # Errors
    ///
    /// * `KernelError::DeserializationError` when `input` does not decode into `I`.
    /// * Any error returned by the wrapped kernel's `execute`.
    /// * `KernelError::SerializationError` when the output cannot be encoded.
    async fn execute_dyn(&self, input: &[u8]) -> Result<Vec<u8>> {
        let request = serde_json::from_slice::<I>(input)
            .map_err(|err| KernelError::DeserializationError(err.to_string()))?;

        let reply = self.inner.execute(request).await?;

        let encoded = serde_json::to_vec(&reply);
        encoded.map_err(|err| KernelError::SerializationError(err.to_string()))
    }
}
/// Wrapper that remembers the message type `M` and response type `R` of a
/// ring kernel `K` without storing values of either type.
pub struct TypeErasedRingKernel<K, M, R> {
    inner: K,
    _phantom: PhantomData<fn(M) -> R>,
}

impl<K: Debug, M, R> Debug for TypeErasedRingKernel<K, M, R> {
    /// Formats as a struct named `TypeErasedRingKernel` with the single
    /// field `inner`; the marker field is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TypeErasedRingKernel")
            .field("inner", &self.inner)
            .finish()
    }
}

impl<K, M, R> TypeErasedRingKernel<K, M, R> {
    /// Wraps `kernel`.
    pub fn new(kernel: K) -> Self {
        Self {
            inner: kernel,
            _phantom: PhantomData,
        }
    }

    /// Borrows the wrapped kernel, mirroring `TypeErasedBatchKernel::inner`.
    pub fn inner(&self) -> &K {
        &self.inner
    }
}
/// Forwards every `GpuKernel` method to the wrapped kernel.
///
/// `id` and `requires_gpu_native` are forwarded explicitly as well, so that a
/// wrapped kernel which overrides them is reported faithfully instead of
/// falling back to the trait defaults derived from `metadata()`.
impl<K, M, R> GpuKernel for TypeErasedRingKernel<K, M, R>
where
    K: GpuKernel,
    M: Send + Sync + 'static,
    R: Send + Sync + 'static,
{
    fn metadata(&self) -> &KernelMetadata {
        self.inner.metadata()
    }

    fn validate(&self) -> Result<()> {
        self.inner.validate()
    }

    fn id(&self) -> &str {
        self.inner.id()
    }

    fn requires_gpu_native(&self) -> bool {
        self.inner.requires_gpu_native()
    }

    fn health_check(&self) -> HealthStatus {
        self.inner.health_check()
    }

    fn shutdown(&self) -> Result<()> {
        self.inner.shutdown()
    }

    fn refresh_config(&mut self, config: &KernelConfig) -> Result<()> {
        self.inner.refresh_config(config)
    }
}
#[async_trait]
impl<K, M, R> RingKernelDyn for TypeErasedRingKernel<K, M, R>
where
    K: RingKernelHandler<M, R> + 'static,
    M: RingMessage + serde::de::DeserializeOwned + Send + Sync + 'static,
    R: RingMessage + serde::Serialize + Send + Sync + 'static,
{
    /// Decodes `msg` as JSON into `M`, lets the wrapped kernel handle it,
    /// and encodes the response as JSON.
    ///
    /// # Errors
    ///
    /// * `KernelError::DeserializationError` when `msg` does not decode into `M`.
    /// * Any error returned by the wrapped kernel's `handle`.
    /// * `KernelError::SerializationError` when the response cannot be encoded.
    async fn handle_dyn(&self, ctx: &mut RingContext, msg: &[u8]) -> Result<Vec<u8>> {
        let request = serde_json::from_slice::<M>(msg)
            .map_err(|err| KernelError::DeserializationError(err.to_string()))?;

        let reply = self.inner.handle(ctx, request).await?;

        let encoded = serde_json::to_vec(&reply);
        encoded.map_err(|err| KernelError::SerializationError(err.to_string()))
    }
}
/// A kernel whose state can be captured into a checkpoint and restored from one.
#[async_trait]
pub trait CheckpointableKernel: GpuKernel {
/// Checkpoint representation; must be serde-serializable and deserializable.
type Checkpoint: Serialize + serde::de::DeserializeOwned + Send + Sync;
/// Produces a checkpoint.
async fn checkpoint(&self) -> Result<Self::Checkpoint>;
/// Restores the kernel from `checkpoint`.
async fn restore(&mut self, checkpoint: Self::Checkpoint) -> Result<()>;
/// Whether a checkpoint can currently be taken. The default is always `true`.
fn can_checkpoint(&self) -> bool {
true
}
/// Estimated checkpoint size. The default is `0`.
/// NOTE(review): unit is presumably bytes — confirm against implementors.
fn checkpoint_size_estimate(&self) -> usize {
0
}
}
/// A kernel that can be switched into and out of a degraded mode.
pub trait DegradableKernel: GpuKernel {
/// Switches the kernel into degraded mode.
fn enter_degraded_mode(&mut self) -> Result<()>;
/// Switches the kernel out of degraded mode.
fn exit_degraded_mode(&mut self) -> Result<()>;
/// Reports whether the kernel is currently degraded.
fn is_degraded(&self) -> bool;
/// Optional human-readable detail about the degradation. The default is `None`.
fn degradation_info(&self) -> Option<String> {
None
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Converged` reports convergence and yields its payload; `Continue`
    /// does not report convergence.
    #[test]
    fn test_iteration_result() {
        let done = IterationResult::<i32>::Converged(42);
        assert!(done.is_converged());
        assert_eq!(done.into_output(), 42);

        let pending = IterationResult::<i32>::Continue(0);
        assert!(!pending.is_converged());
    }

    /// The default status is `Healthy`, and `Display` yields lowercase names.
    #[test]
    fn test_health_status() {
        let default_status = HealthStatus::default();
        assert_eq!(default_status, HealthStatus::Healthy);
        assert_eq!(HealthStatus::Healthy.to_string(), "healthy");
        assert_eq!(HealthStatus::Degraded.to_string(), "degraded");
    }

    /// Builder methods populate the matching fields and `new` assigns a
    /// request id.
    #[test]
    fn test_execution_context() {
        let limit = Duration::from_secs(30);
        let built = ExecutionContext::new()
            .with_user("user123")
            .with_tenant("tenant456")
            .with_timeout(limit);

        assert!(built.request_id.is_some());
        assert_eq!(built.user_id.as_deref(), Some("user123"));
        assert_eq!(built.tenant_id.as_deref(), Some("tenant456"));
        assert_eq!(built.timeout, Some(Duration::from_secs(30)));
    }

    /// Builder methods populate the matching configuration fields.
    #[test]
    fn test_kernel_config() {
        let limit = Duration::from_secs(60);
        let built = KernelConfig::new()
            .with_queue_depth(1000)
            .with_timeout(limit)
            .with_tracing(true)
            .with_metrics(true);

        assert_eq!(built.max_queue_depth, Some(1000));
        assert_eq!(built.timeout, Some(Duration::from_secs(60)));
        assert!(built.tracing_enabled);
        assert!(built.metrics_enabled);
    }
}