use crate::EcosystemError;
use rustkernel_core::registry::KernelRegistry;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
/// Configuration for the kernel gRPC service.
///
/// Default values come from the `Default` impl for this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrpcConfig {
/// Server address as a `host:port` string (default `[::1]:50051`).
pub address: String,
/// gRPC reflection flag; not read by the code in this file (presumably consumed by transport setup).
pub reflection: bool,
/// Health-service flag; not read by the code in this file (presumably consumed by transport setup).
pub health_service: bool,
/// Maximum message size, presumably in bytes (default 4 * 1024 * 1024); not read in this file.
pub max_message_size: usize,
/// Connection timeout in milliseconds (per the suffix); not read in this file.
pub connect_timeout_ms: u64,
/// Default kernel execution timeout in milliseconds, used by `execute_kernel`
/// when the request does not carry its own `timeout_ms`.
pub request_timeout_ms: u64,
}
impl Default for GrpcConfig {
fn default() -> Self {
Self {
address: "[::1]:50051".to_string(),
reflection: true,
health_service: true,
max_message_size: 4 * 1024 * 1024, connect_timeout_ms: 5000,
request_timeout_ms: 30000,
}
}
}
/// Request to execute a single batch kernel via `KernelGrpcServer::execute_kernel`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrpcKernelRequest {
/// Registry id of the kernel to run.
pub kernel_id: String,
/// Kernel input; must parse as JSON or the request is rejected as invalid.
pub input_json: String,
/// Optional caller trace id. It is echoed back in the response and, when
/// present, also used as the response's `request_id`.
pub trace_id: Option<String>,
/// Tenant identifier; not read by `KernelGrpcServer` in this file.
pub tenant_id: Option<String>,
/// Requested priority; not read by `KernelGrpcServer` in this file.
pub priority: Option<i32>,
/// Per-request execution timeout in milliseconds; overrides
/// `GrpcConfig::request_timeout_ms` when set.
pub timeout_ms: Option<u64>,
}
/// Result of a successful batch kernel execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrpcKernelResponse {
/// The request's `trace_id` if one was supplied, otherwise a freshly generated UUID v4.
pub request_id: String,
/// Id of the kernel that was executed (echoed from the request).
pub kernel_id: String,
/// Kernel output as a UTF-8 string.
pub output_json: String,
/// Wall-clock microseconds measured from the start of request handling
/// until kernel execution finished.
pub duration_us: u64,
/// Upper-cased name of the kernel's execution mode.
pub backend: String,
/// GPU memory used; `execute_kernel` in this file always sets `None`.
pub gpu_memory_bytes: Option<u64>,
/// Trace id echoed from the request.
pub trace_id: Option<String>,
}
/// Request for a single kernel's metadata by id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetKernelRequest {
/// Registry id of the kernel to describe.
pub kernel_id: String,
}
/// Public description of a registered kernel, built from registry metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelInfo {
/// Registry id of the kernel.
pub id: String,
/// `Debug` rendering of the kernel's domain (`format!("{:?}", ..)`).
pub domain: String,
/// `Debug` rendering of the kernel's execution mode (`format!("{:?}", ..)`).
pub mode: String,
/// Human-readable description copied from the metadata.
pub description: String,
/// Expected throughput copied from the metadata; units are defined by the registry (not visible here).
pub expected_throughput: u64,
/// Target latency copied from the metadata, presumably in microseconds (per the suffix).
pub target_latency_us: f64,
}
/// Filters and pagination parameters for listing kernels.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ListKernelsRequest {
/// Optional domain filter, matched case-insensitively against the `Debug` name of the domain.
pub domain: Option<String>,
/// Optional mode filter, matched case-insensitively against the `Debug` name of the mode.
pub mode: Option<String>,
/// Maximum kernels per page; defaults to 100, and values below 1 are treated as 1.
pub page_size: Option<i32>,
/// `next_page_token` from a previous response, used to continue the listing.
pub page_token: Option<String>,
}
/// One page of kernels matching a `ListKernelsRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListKernelsResponse {
/// Kernels on this page.
pub kernels: Vec<KernelInfo>,
/// Token for the next page; `None` when this is the last page.
pub next_page_token: Option<String>,
/// Number of kernels matching the filters across all pages.
pub total_count: i32,
}
/// Error returned by the service handlers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrpcError {
/// Numeric status code; the values produced in this file follow gRPC's
/// canonical codes (e.g. 3 invalid argument, 4 deadline exceeded, 5 not found, 13 internal).
pub code: i32,
/// Human-readable error message.
pub message: String,
/// Extra details; always `None` for errors built in this file.
pub details: Option<String>,
}
impl From<EcosystemError> for GrpcError {
fn from(err: EcosystemError) -> Self {
let (code, message) = match &err {
EcosystemError::KernelNotFound(_) => (5, err.to_string()), EcosystemError::InvalidRequest(_) => (3, err.to_string()), EcosystemError::AuthenticationRequired => (16, err.to_string()), EcosystemError::PermissionDenied(_) => (7, err.to_string()), EcosystemError::RateLimitExceeded => (8, err.to_string()), EcosystemError::ServiceUnavailable(_) => (14, err.to_string()), _ => (13, err.to_string()), };
Self {
code,
message,
details: None,
}
}
}
/// Request handlers for the kernel service, backed by a shared kernel registry.
pub struct KernelGrpcServer {
/// Registry used to look up and instantiate kernels; shared via `Arc`.
registry: Arc<KernelRegistry>,
/// Service configuration (timeouts, address, limits).
config: GrpcConfig,
}
impl KernelGrpcServer {
    /// Creates a server backed by `registry`, using `GrpcConfig::default()`.
    pub fn new(registry: Arc<KernelRegistry>) -> Self {
        Self {
            registry,
            config: GrpcConfig::default(),
        }
    }

    /// Replaces the server configuration (builder style).
    pub fn with_config(mut self, config: GrpcConfig) -> Self {
        self.config = config;
        self
    }

    /// Executes a batch kernel once with the request's JSON input.
    ///
    /// The response's `request_id` is the caller's `trace_id` when supplied,
    /// otherwise a new UUID v4. The kernel runs under a timeout of
    /// `request.timeout_ms`, falling back to `config.request_timeout_ms`.
    ///
    /// # Errors
    ///
    /// - `KernelNotFound` (code 5) if no kernel with that id is registered.
    /// - `InvalidRequest` (code 3) if the id names a Ring kernel, or if the input
    ///   is not valid JSON.
    /// - Code 4 if execution does not finish within the timeout.
    /// - `ExecutionFailed`, converted via `From<EcosystemError>`, if the kernel
    ///   returns an error or produces output that is not valid UTF-8.
    pub async fn execute_kernel(
        &self,
        request: GrpcKernelRequest,
    ) -> Result<GrpcKernelResponse, GrpcError> {
        let start = Instant::now();
        let request_id = request
            .trace_id
            .clone()
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

        let entry = match self.registry.get_batch(&request.kernel_id) {
            Some(entry) => entry,
            None => return Err(self.missing_batch_kernel_error(request.kernel_id)),
        };

        // Validate before instantiating the kernel so malformed input is rejected cheaply.
        let input_bytes = request.input_json.as_bytes();
        if serde_json::from_slice::<serde_json::Value>(input_bytes).is_err() {
            return Err(GrpcError::from(EcosystemError::InvalidRequest(
                "Input must be valid JSON".to_string(),
            )));
        }

        let kernel = entry.create();
        let timeout_ms = request.timeout_ms.unwrap_or(self.config.request_timeout_ms);
        let timeout = std::time::Duration::from_millis(timeout_ms);
        let output_bytes =
            match tokio::time::timeout(timeout, kernel.execute_dyn(input_bytes)).await {
                Ok(Ok(bytes)) => bytes,
                Ok(Err(e)) => {
                    return Err(GrpcError::from(EcosystemError::ExecutionFailed(
                        e.to_string(),
                    )));
                }
                Err(_) => {
                    return Err(GrpcError {
                        code: 4, // DEADLINE_EXCEEDED
                        message: format!("Kernel execution timed out after {}ms", timeout_ms),
                        details: None,
                    });
                }
            };
        // `as_micros` is u128; saturate instead of silently truncating.
        let duration_us = u64::try_from(start.elapsed().as_micros()).unwrap_or(u64::MAX);

        // Report non-UTF-8 output as a failure rather than fabricating "{}" and
        // claiming success.
        let output_json = String::from_utf8(output_bytes).map_err(|e| {
            GrpcError::from(EcosystemError::ExecutionFailed(format!(
                "Kernel produced non-UTF-8 output: {}",
                e
            )))
        })?;

        Ok(GrpcKernelResponse {
            request_id,
            kernel_id: request.kernel_id,
            output_json,
            duration_us,
            backend: entry.metadata.mode.as_str().to_uppercase(),
            gpu_memory_bytes: None,
            trace_id: request.trace_id,
        })
    }

    /// Builds the error for an id that has no batch entry.
    ///
    /// If the kernel exists (i.e. it is a Ring kernel), the error is
    /// `InvalidRequest`; otherwise it is `KernelNotFound`.
    fn missing_batch_kernel_error(&self, kernel_id: String) -> GrpcError {
        if self.registry.get(&kernel_id).is_some() {
            GrpcError::from(EcosystemError::InvalidRequest(format!(
                "Kernel '{}' is a Ring kernel. Use bidirectional streaming RPC for Ring kernel dispatch.",
                kernel_id
            )))
        } else {
            GrpcError::from(EcosystemError::KernelNotFound(kernel_id))
        }
    }

    /// Returns the metadata of a single kernel.
    ///
    /// # Errors
    ///
    /// Returns `KernelNotFound` (code 5) if the id is not registered.
    pub async fn get_kernel(&self, request: GetKernelRequest) -> Result<KernelInfo, GrpcError> {
        let kernel_meta = self.registry.get(&request.kernel_id).ok_or_else(|| {
            GrpcError::from(EcosystemError::KernelNotFound(request.kernel_id.clone()))
        })?;
        Ok(KernelInfo {
            id: kernel_meta.id.clone(),
            domain: format!("{:?}", kernel_meta.domain),
            mode: format!("{:?}", kernel_meta.mode),
            description: kernel_meta.description.clone(),
            expected_throughput: kernel_meta.expected_throughput,
            target_latency_us: kernel_meta.target_latency_us,
        })
    }

    /// Lists registered kernels, optionally filtered by domain and/or mode, one page at a time.
    ///
    /// Filters are compared ASCII case-insensitively against the `Debug` names
    /// of the domain and mode. Results are ordered by kernel id, so pagination
    /// is deterministic across calls.
    ///
    /// `page_token` is the id of the last kernel on the previous page. The
    /// listing resumes with the first id that sorts strictly after it, so the
    /// cursor stays valid even if that kernel was removed in the meantime.
    pub async fn list_kernels(
        &self,
        request: ListKernelsRequest,
    ) -> Result<ListKernelsResponse, GrpcError> {
        // `max(1)` guarantees a positive value, so the cast to usize is lossless.
        let page_size = request.page_size.unwrap_or(100).max(1) as usize;
        let all_metadata = self.registry.all_metadata();
        let mut filtered: Vec<_> = all_metadata
            .iter()
            .filter(|k| match request.domain.as_deref() {
                Some(domain) => format!("{:?}", k.domain).eq_ignore_ascii_case(domain),
                None => true,
            })
            .filter(|k| match request.mode.as_deref() {
                Some(mode) => format!("{:?}", k.mode).eq_ignore_ascii_case(mode),
                None => true,
            })
            .collect();
        // The cursor is an id, so a stable id order is required for pages to
        // neither skip nor repeat kernels.
        filtered.sort_unstable_by(|a, b| a.id.cmp(&b.id));
        let total_count = i32::try_from(filtered.len()).unwrap_or(i32::MAX);

        let start_idx = match request.page_token.as_deref() {
            Some(token) => filtered.partition_point(|k| k.id.as_str() <= token),
            None => 0,
        };

        let page: Vec<KernelInfo> = filtered
            .iter()
            .skip(start_idx)
            .take(page_size)
            .map(|k| KernelInfo {
                id: k.id.clone(),
                domain: format!("{:?}", k.domain),
                mode: format!("{:?}", k.mode),
                description: k.description.clone(),
                expected_throughput: k.expected_throughput,
                target_latency_us: k.target_latency_us,
            })
            .collect();

        let next_page_token = if start_idx.saturating_add(page_size) < filtered.len() {
            page.last().map(|k| k.id.clone())
        } else {
            None
        };

        Ok(ListKernelsResponse {
            total_count,
            kernels: page,
            next_page_token,
        })
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &GrpcConfig {
        &self.config
    }
}
impl Clone for KernelGrpcServer {
fn clone(&self) -> Self {
Self {
registry: self.registry.clone(),
config: self.config.clone(),
}
}
}
/// Health check whose status is derived from the kernel registry's contents.
pub struct HealthService {
/// Registry inspected on each `check`.
registry: Arc<KernelRegistry>,
}
impl HealthService {
pub fn new(registry: Arc<KernelRegistry>) -> Self {
Self { registry }
}
pub fn check(&self) -> HealthStatus {
if self.registry.stats().total > 0 {
HealthStatus::Serving
} else {
HealthStatus::NotServing
}
}
}
/// Serving status reported by `HealthService`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
/// Status not determined; not produced by `HealthService::check` in this file.
Unknown,
/// The registry contains at least one kernel.
Serving,
/// The registry is empty.
NotServing,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a server over a fresh, empty kernel registry.
    fn server_with_empty_registry() -> KernelGrpcServer {
        KernelGrpcServer::new(Arc::new(KernelRegistry::new()))
    }

    #[test]
    fn test_grpc_config() {
        let GrpcConfig {
            address, reflection, ..
        } = GrpcConfig::default();
        assert_eq!(address, "[::1]:50051");
        assert!(reflection);
    }

    #[tokio::test]
    async fn test_kernel_grpc_server() {
        let request = GrpcKernelRequest {
            kernel_id: String::from("nonexistent"),
            input_json: String::from("{}"),
            trace_id: None,
            tenant_id: None,
            priority: None,
            timeout_ms: None,
        };
        let outcome = server_with_empty_registry().execute_kernel(request).await;
        // An unregistered kernel id must surface as NOT_FOUND (code 5).
        match outcome {
            Ok(response) => panic!("expected an error, got {:?}", response),
            Err(err) => assert_eq!(err.code, 5),
        }
    }

    #[tokio::test]
    async fn test_list_kernels() {
        let response = server_with_empty_registry()
            .list_kernels(ListKernelsRequest::default())
            .await
            .expect("listing an empty registry should succeed");
        assert_eq!(response.total_count, 0);
    }

    #[test]
    fn test_health_service() {
        let status = HealthService::new(Arc::new(KernelRegistry::new())).check();
        assert_eq!(status, HealthStatus::NotServing);
    }
}