use crate::{EcosystemError, KernelRequest, KernelResponse, ResponseMetadata};
use rustkernel_core::registry::KernelRegistry;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
pub struct KernelService {
registry: Arc<KernelRegistry>,
default_timeout: std::time::Duration,
}
impl KernelService {
pub fn new(registry: Arc<KernelRegistry>) -> Self {
Self {
registry,
default_timeout: std::time::Duration::from_secs(30),
}
}
pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
self.default_timeout = timeout;
self
}
pub async fn execute(&self, request: KernelRequest) -> Result<KernelResponse, EcosystemError> {
let start = Instant::now();
let request_id = uuid::Uuid::new_v4().to_string();
if let Some(entry) = self.registry.get_batch(&request.kernel_id) {
let kernel = entry.create();
let input_bytes = serde_json::to_vec(&request.input)
.map_err(|e| EcosystemError::InvalidRequest(format!("Invalid input: {}", e)))?;
let timeout_ms = request
.metadata
.timeout_ms
.unwrap_or(self.default_timeout.as_millis() as u64);
let timeout = std::time::Duration::from_millis(timeout_ms);
let result = tokio::time::timeout(timeout, kernel.execute_dyn(&input_bytes)).await;
match result {
Ok(Ok(output_bytes)) => {
let output: serde_json::Value =
serde_json::from_slice(&output_bytes).map_err(|e| {
EcosystemError::InternalError(format!("Output deserialization: {}", e))
})?;
let duration_us = start.elapsed().as_micros() as u64;
Ok(KernelResponse {
request_id,
kernel_id: request.kernel_id,
output,
metadata: ResponseMetadata {
duration_us,
backend: entry.metadata.mode.as_str().to_uppercase(),
gpu_memory_bytes: None,
trace_id: request.metadata.trace_id,
},
})
}
Ok(Err(e)) => Err(EcosystemError::ExecutionFailed(e.to_string())),
Err(_) => Err(EcosystemError::ServiceUnavailable(format!(
"Kernel execution timed out after {}ms",
timeout_ms
))),
}
} else if self.registry.get(&request.kernel_id).is_some() {
Err(EcosystemError::InvalidRequest(format!(
"Kernel '{}' is a Ring kernel and cannot be executed via this interface. \
Use the Ring protocol or gRPC streaming API.",
request.kernel_id
)))
} else {
Err(EcosystemError::KernelNotFound(request.kernel_id))
}
}
}
impl Clone for KernelService {
fn clone(&self) -> Self {
Self {
registry: self.registry.clone(),
default_timeout: self.default_timeout,
}
}
}
impl tower::Service<KernelRequest> for KernelService {
type Response = KernelResponse;
type Error = EcosystemError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: KernelRequest) -> Self::Future {
let service = self.clone();
Box::pin(async move { service.execute(req).await })
}
}
pub struct KernelLayer {
registry: Arc<KernelRegistry>,
}
impl KernelLayer {
pub fn new(registry: Arc<KernelRegistry>) -> Self {
Self { registry }
}
}
impl<S> tower::Layer<S> for KernelLayer {
type Service = KernelMiddleware<S>;
fn layer(&self, inner: S) -> Self::Service {
KernelMiddleware {
inner,
registry: self.registry.clone(),
}
}
}
pub struct KernelMiddleware<S> {
inner: S,
registry: Arc<KernelRegistry>,
}
impl<S> KernelMiddleware<S> {
pub fn registry(&self) -> &Arc<KernelRegistry> {
&self.registry
}
}
impl<S: Clone> Clone for KernelMiddleware<S> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
registry: self.registry.clone(),
}
}
}
pub struct TimeoutLayer {
timeout: std::time::Duration,
}
impl TimeoutLayer {
pub fn new(timeout: std::time::Duration) -> Self {
Self { timeout }
}
}
impl<S> tower::Layer<S> for TimeoutLayer {
type Service = TimeoutService<S>;
fn layer(&self, inner: S) -> Self::Service {
TimeoutService {
inner,
timeout: self.timeout,
}
}
}
pub struct TimeoutService<S> {
inner: S,
timeout: std::time::Duration,
}
impl<S: Clone> Clone for TimeoutService<S> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
timeout: self.timeout,
}
}
}
impl<S, Request> tower::Service<Request> for TimeoutService<S>
where
S: tower::Service<Request> + Clone + Send + 'static,
S::Future: Send,
Request: Send + 'static,
{
type Response = S::Response;
type Error = TimeoutError<S::Error>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(TimeoutError::Inner)
}
fn call(&mut self, req: Request) -> Self::Future {
let mut inner = self.inner.clone();
let timeout = self.timeout;
Box::pin(async move {
tokio::time::timeout(timeout, inner.call(req))
.await
.map_err(|_| TimeoutError::Timeout)?
.map_err(TimeoutError::Inner)
})
}
}
#[derive(Debug, thiserror::Error)]
pub enum TimeoutError<E> {
#[error("Request timed out")]
Timeout,
#[error("Service error: {0}")]
Inner(E),
}
pub struct RateLimiterLayer {
requests_per_second: u32,
burst_size: u32,
}
impl RateLimiterLayer {
pub fn new(requests_per_second: u32, burst_size: u32) -> Self {
Self {
requests_per_second,
burst_size,
}
}
}
impl<S> tower::Layer<S> for RateLimiterLayer {
type Service = RateLimiterService<S>;
fn layer(&self, inner: S) -> Self::Service {
RateLimiterService {
inner,
requests_per_second: self.requests_per_second as f64,
burst_size: self.burst_size as f64,
state: Arc::new(tokio::sync::Mutex::new(TokenBucketState {
tokens: self.burst_size as f64,
last_refill: std::time::Instant::now(),
})),
}
}
}
struct TokenBucketState {
tokens: f64,
last_refill: std::time::Instant,
}
pub struct RateLimiterService<S> {
inner: S,
requests_per_second: f64,
burst_size: f64,
state: Arc<tokio::sync::Mutex<TokenBucketState>>,
}
impl<S: Clone> Clone for RateLimiterService<S> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
requests_per_second: self.requests_per_second,
burst_size: self.burst_size,
state: self.state.clone(),
}
}
}
impl<S, Request> tower::Service<Request> for RateLimiterService<S>
where
S: tower::Service<Request> + Clone + Send + 'static,
S::Future: Send,
Request: Send + 'static,
{
type Response = S::Response;
type Error = RateLimitError<S::Error>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(RateLimitError::Inner)
}
fn call(&mut self, req: Request) -> Self::Future {
let mut inner = self.inner.clone();
let rps = self.requests_per_second;
let burst = self.burst_size;
let state = self.state.clone();
Box::pin(async move {
let mut bucket = state.lock().await;
let elapsed = bucket.last_refill.elapsed().as_secs_f64();
bucket.tokens = (bucket.tokens + elapsed * rps).min(burst);
bucket.last_refill = std::time::Instant::now();
if bucket.tokens < 1.0 {
return Err(RateLimitError::RateLimitExceeded);
}
bucket.tokens -= 1.0;
drop(bucket);
inner.call(req).await.map_err(RateLimitError::Inner)
})
}
}
#[derive(Debug, thiserror::Error)]
pub enum RateLimitError<E> {
#[error("Rate limit exceeded")]
RateLimitExceeded,
#[error("Service error: {0}")]
Inner(E),
}
#[cfg(test)]
mod tests {
use super::*;
use crate::RequestMetadata;
#[tokio::test]
async fn test_kernel_service() {
let registry = Arc::new(KernelRegistry::new());
let service = KernelService::new(registry);
let request = KernelRequest {
kernel_id: "nonexistent".to_string(),
input: serde_json::json!({}),
metadata: RequestMetadata::default(),
};
let result = service.execute(request).await;
assert!(matches!(result, Err(EcosystemError::KernelNotFound(_))));
}
#[tokio::test]
async fn test_timeout_layer() {
use std::time::Duration;
let layer = TimeoutLayer::new(Duration::from_millis(100));
assert_eq!(layer.timeout, Duration::from_millis(100));
}
#[test]
fn test_rate_limiter_layer() {
let layer = RateLimiterLayer::new(100, 200);
assert_eq!(layer.requests_per_second, 100);
}
}