use std::sync::{Arc, Mutex as StdMutex};
use dashmap::DashMap;
use tokio::sync::{RwLock, watch};
use crate::bridge::protocol::{MAX_INLINE_IPC_SIZE, SlotRequest};
use crate::health::{Health, SetupResult};
use crate::input_validation::InputValidator;
use crate::orchestrator::{HealthcheckResult, Orchestrator};
use crate::permit::{PermitPool, PredictionSlot, UnregisteredPredictionSlot};
use crate::prediction::{
CancellationToken, Prediction, PredictionStatus, STREAM_CHANNEL_CAPACITY,
SharedPredictionStreamEvent,
};
use crate::predictor::{PredictionError, PredictionOutput, PredictionResult};
use crate::version::VersionInfo;
use crate::webhook::WebhookSender;
/// Locks a shared prediction, treating a poisoned mutex as a fatal prediction error.
///
/// Returns the guard on success. If the mutex is poisoned, the error is logged, the inner
/// prediction is recovered and marked failed (unless it is already terminal), and `None`
/// is returned so callers can bail out.
fn try_lock_prediction(
    pred: &Arc<StdMutex<Prediction>>,
) -> Option<std::sync::MutexGuard<'_, Prediction>> {
    let poisoned = match pred.lock() {
        Ok(guard) => return Some(guard),
        Err(poisoned) => poisoned,
    };
    tracing::error!("Prediction mutex poisoned - failing prediction");
    let mut recovered = poisoned.into_inner();
    if !recovered.is_terminal() {
        recovered.set_failed("Internal error: mutex poisoned".to_string());
    }
    None
}
/// Reasons [`PredictionService::submit_prediction`] can refuse to admit a prediction.
#[derive(Debug, thiserror::Error)]
pub enum CreatePredictionError {
/// Health is not `Ready`, or no orchestrator/pool has been configured yet.
#[error("Service not ready")]
NotReady,
/// The permit pool had no free slot (`try_acquire` returned `None`).
#[error("At capacity (no slots available)")]
AtCapacity,
}
/// Maximum number of concurrent stream subscribers allowed per prediction; checked in
/// `subscribe_prediction_stream`. NOTE(review): reuses `STREAM_CHANNEL_CAPACITY` as the
/// limit — presumably to keep subscriber count in line with the broadcast buffer; confirm.
const MAX_STREAM_SUBSCRIBERS: usize = STREAM_CHANNEL_CAPACITY;
/// Reasons [`PredictionService::subscribe_prediction_stream`] can fail.
#[derive(Debug, thiserror::Error)]
pub enum SubscribePredictionStreamError {
/// No prediction with the requested id is tracked by the service.
#[error("Prediction not found")]
NotFound,
/// The prediction already has `MAX_STREAM_SUBSCRIBERS` stream receivers.
#[error("Too many stream subscribers")]
TooManySubscribers,
/// The prediction's mutex was poisoned, so no stream could be attached.
#[error("Prediction stream unavailable")]
Unavailable,
}
/// Point-in-time view of service health, as built by [`PredictionService::health`].
#[derive(Debug, Clone)]
pub struct HealthSnapshot {
/// Current health state.
pub state: Health,
/// Free permits in the pool (0 when no pool is configured).
pub available_slots: usize,
/// Total permits in the pool (0 when no pool is configured).
pub total_slots: usize,
/// Result of model setup, if one has been recorded.
pub setup_result: Option<SetupResult>,
/// Version information of this service.
pub version: VersionInfo,
}
impl HealthSnapshot {
    /// Whether the snapshot's health state is `Health::Ready`.
    pub fn is_ready(&self) -> bool {
        self.state == Health::Ready
    }

    /// Whether the service is ready but every slot is currently taken.
    pub fn is_busy(&self) -> bool {
        self.is_ready() && self.available_slots == 0
    }
}
/// Per-prediction bookkeeping stored in `PredictionService::predictions`.
struct PredictionEntry {
/// Shared prediction state, also held by the slot and the orchestrator.
prediction: Arc<StdMutex<Prediction>>,
/// Token used by `PredictionService::cancel` to signal cancellation.
cancel_token: CancellationToken,
/// Input as submitted; echoed back in `get_prediction_response`.
input: serde_json::Value,
/// If true, dropping the last stream subscriber of a non-terminal prediction cancels it.
cancel_on_stream_drop: bool,
}
/// Caller-side handle returned by `submit_prediction`: the prediction id and its cancel token.
pub struct PredictionHandle {
id: String,
cancel_token: CancellationToken,
}
impl PredictionHandle {
pub fn id(&self) -> &str {
&self.id
}
pub fn cancel_token(&self) -> CancellationToken {
self.cancel_token.clone()
}
pub fn sync_guard(&self, service: Arc<PredictionService>) -> SyncPredictionGuard {
SyncPredictionGuard::new(self.id.clone(), service)
}
}
/// A live subscription to a prediction's event stream, from `subscribe_prediction_stream`.
///
/// Field order matters: Rust drops fields in declaration order, so `receiver` is dropped
/// before `guard`. The guard's `Drop` only cancels when the prediction's stream receiver
/// count is zero, so this subscription's own receiver must already be gone by then.
pub struct PredictionStreamSubscription {
id: String,
/// Buffered events handed back by `subscribe_stream_replay` for the subscriber to replay first.
replay: std::collections::VecDeque<SharedPredictionStreamEvent>,
/// Skip count reported by `subscribe_stream_replay` (presumably events no longer buffered — TODO confirm).
skipped: u64,
/// Live broadcast receiver for subsequent events.
receiver: tokio::sync::broadcast::Receiver<SharedPredictionStreamEvent>,
/// Must stay last so it drops after `receiver` (see type docs).
guard: PredictionStreamGuard,
}
impl PredictionStreamSubscription {
    /// Id of the prediction this subscription is attached to.
    pub fn prediction_id(&self) -> &str {
        self.id.as_str()
    }

    /// Splits the subscription into `(replay, skipped, receiver, guard)`.
    ///
    /// Callers should drop `receiver` before `guard`: the guard's drop cancels only when the
    /// prediction has no remaining stream receivers, so a still-alive receiver keeps it running.
    pub fn into_parts(
        self,
    ) -> (
        std::collections::VecDeque<SharedPredictionStreamEvent>,
        u64,
        tokio::sync::broadcast::Receiver<SharedPredictionStreamEvent>,
        PredictionStreamGuard,
    ) {
        let Self {
            replay,
            skipped,
            receiver,
            guard,
            ..
        } = self;
        (replay, skipped, receiver, guard)
    }
}
/// Drop guard tied to a stream subscription; may cancel the prediction when the last
/// subscriber goes away (see its `Drop` impl).
pub struct PredictionStreamGuard {
id: String,
service: Arc<PredictionService>,
/// Copied from the prediction's entry; when false the guard never cancels.
cancel_on_stream_drop: bool,
}
/// Cancels the prediction once nobody is streaming it anymore, but only when the
/// subscription opted in, no stream receivers remain, and the prediction is not yet
/// terminal. Missing predictions count as terminal, so they are never cancelled here.
impl Drop for PredictionStreamGuard {
    fn drop(&mut self) {
        let abandoned = self.cancel_on_stream_drop
            && self.service.stream_receiver_count(&self.id) == 0
            && !self.service.prediction_is_terminal(&self.id);
        if abandoned {
            self.service.cancel(&self.id);
        }
    }
}
/// Cancels a prediction when dropped, unless `disarm` was called first.
pub struct SyncPredictionGuard {
/// `Some(id)` while armed; `None` after `disarm`.
prediction_id: Option<String>,
service: Arc<PredictionService>,
}
impl SyncPredictionGuard {
    /// Creates an armed guard for `prediction_id`.
    pub fn new(prediction_id: String, service: Arc<PredictionService>) -> Self {
        let armed_id = Some(prediction_id);
        Self {
            service,
            prediction_id: armed_id,
        }
    }

    /// Disarms the guard so dropping it no longer cancels the prediction.
    pub fn disarm(&mut self) {
        self.prediction_id.take();
    }
}
/// An armed guard cancels its prediction through the service on drop.
impl Drop for SyncPredictionGuard {
    fn drop(&mut self) {
        if let Some(id) = self.prediction_id.as_deref() {
            self.service.cancel(id);
        }
    }
}
/// The permit pool and orchestrator installed by `PredictionService::set_orchestrator`.
pub struct OrchestratorState {
pub pool: Arc<PermitPool>,
pub orchestrator: Arc<dyn Orchestrator>,
}
impl Clone for OrchestratorState {
fn clone(&self) -> Self {
Self {
pool: Arc::clone(&self.pool),
orchestrator: Arc::clone(&self.orchestrator),
}
}
}
/// Central service tracking health, schema/validators, and in-flight predictions.
pub struct PredictionService {
/// `None` until `set_orchestrator` is called.
orchestrator: RwLock<Option<OrchestratorState>>,
/// Current health; starts as `Health::Unknown` in `new_no_pool`.
health: RwLock<Health>,
/// Last value passed to `set_setup_result`.
setup_result: RwLock<Option<SetupResult>>,
/// Tracked predictions keyed by id; inserted by `submit_prediction`, removed by `remove_prediction`.
predictions: DashMap<String, PredictionEntry>,
/// Watch channel (initially `false`); `trigger_shutdown` sends `true`.
shutdown_tx: watch::Sender<bool>,
shutdown_rx: watch::Receiver<bool>,
version: VersionInfo,
/// Last schema passed to `set_schema`.
schema: RwLock<Option<serde_json::Value>>,
/// Prediction input validator derived from the schema, if any.
input_validator: RwLock<Option<InputValidator>>,
/// Training input validator derived from the schema's `TrainingInput`, if any.
train_validator: RwLock<Option<InputValidator>>,
/// Whether the schema marks `POST /predictions` with `x-cog-streaming: true`.
supports_prediction_streaming: RwLock<bool>,
}
impl PredictionService {
/// Creates a service with no orchestrator/pool, `Health::Unknown`, no schema or
/// validators, no tracked predictions, and a shutdown channel initialised to `false`.
pub fn new_no_pool() -> Self {
    let (tx, rx) = watch::channel(false);
    Self {
        orchestrator: RwLock::new(None),
        health: RwLock::new(Health::Unknown),
        setup_result: RwLock::new(None),
        predictions: DashMap::new(),
        shutdown_tx: tx,
        shutdown_rx: rx,
        version: VersionInfo::new(),
        schema: RwLock::new(None),
        input_validator: RwLock::new(None),
        train_validator: RwLock::new(None),
        supports_prediction_streaming: RwLock::new(false),
    }
}
/// Installs (or replaces) the permit pool and orchestrator used for predictions.
pub async fn set_orchestrator(
    &self,
    pool: Arc<PermitPool>,
    orchestrator: Arc<dyn Orchestrator>,
) {
    let installed = OrchestratorState { pool, orchestrator };
    let mut current = self.orchestrator.write().await;
    *current = Some(installed);
}
/// Whether `set_orchestrator` has been called.
pub async fn has_orchestrator(&self) -> bool {
    matches!(*self.orchestrator.read().await, Some(_))
}
/// Asks the configured orchestrator, if any, to shut down.
///
/// Best-effort: an orchestrator error is logged at `warn` level and otherwise ignored.
/// The orchestrator state is cloned out before awaiting, so the `RwLock` read guard is not
/// held across the (possibly slow) shutdown call. Tokio's `RwLock` is write-preferring, so a
/// held read guard plus a queued writer would stall other readers and make `cancel`'s
/// `try_read` fail for the whole shutdown.
pub async fn shutdown(&self) {
    let state = self.orchestrator.read().await.clone();
    if let Some(state) = state
        && let Err(e) = state.orchestrator.shutdown().await
    {
        tracing::warn!(error = %e, "Error during orchestrator shutdown");
    }
}
/// Builder-style override of the initial health state.
///
/// `Health::Ready` is ignored here: `set_health` only accepts READY once an orchestrator
/// exists, and a freshly built service has none.
pub fn with_health(mut self, health: Health) -> Self {
    if health == Health::Ready {
        return self;
    }
    self.health = RwLock::new(health);
    self
}
pub fn with_version(mut self, version: VersionInfo) -> Self {
self.version = version;
self
}
/// Version information this service reports in health snapshots.
pub fn version(&self) -> &VersionInfo {
    let info: &VersionInfo = &self.version;
    info
}
/// Whether the schema defined a `TrainingInput`, i.e. a training validator is installed.
pub async fn supports_training(&self) -> bool {
    self.train_validator.read().await.as_ref().is_some()
}
/// Whether the last schema enabled `x-cog-streaming` on `POST /predictions`.
pub async fn supports_prediction_streaming(&self) -> bool {
    let flag = self.supports_prediction_streaming.read().await;
    *flag
}
pub async fn pool(&self) -> Option<Arc<PermitPool>> {
if let Some(ref state) = *self.orchestrator.read().await {
Some(Arc::clone(&state.pool))
} else {
None
}
}
/// Builds a [`HealthSnapshot`] from the current health, setup result, pool occupancy
/// (0/0 when no pool is configured) and version.
pub async fn health(&self) -> HealthSnapshot {
    let state = *self.health.read().await;
    let setup_result = self.setup_result.read().await.clone();
    let (available_slots, total_slots) = self
        .pool()
        .await
        .map_or((0, 0), |pool| (pool.available(), pool.num_slots()));
    tracing::trace!(
        ?state,
        available_slots,
        total_slots,
        setup_status = ?setup_result.as_ref().map(|r| r.status),
        "Building health snapshot"
    );
    let version = self.version.clone();
    HealthSnapshot {
        state,
        available_slots,
        total_slots,
        setup_result,
        version,
    }
}
/// Transitions the service health to `health`.
///
/// Setting `Health::Ready` is refused (with a warning) while no orchestrator is configured.
/// The previous value is read and replaced under a single write guard, so the logged
/// `from` state is the one actually overwritten even when callers race.
pub async fn set_health(&self, health: Health) {
    if health == Health::Ready && self.orchestrator.read().await.is_none() {
        tracing::warn!("Attempted to set READY without orchestrator, ignoring");
        return;
    }
    let mut current = self.health.write().await;
    let previous = *current;
    tracing::debug!(from = ?previous, to = ?health, "Health state transition");
    *current = health;
}
/// Records the model setup result (replacing any previous one) and logs a summary.
pub async fn set_setup_result(&self, result: SetupResult) {
    tracing::debug!(
        status = ?result.status,
        started_at = %result.started_at,
        completed_at = ?result.completed_at,
        logs_len = result.logs.len(),
        "Setting setup result"
    );
    let mut stored = self.setup_result.write().await;
    *stored = Some(result);
}
/// Installs an OpenAPI schema and everything derived from it.
///
/// Updates, in order: the prediction-streaming flag, the prediction input validator, the
/// training input validator (from `TrainingInput`), and finally the stored schema itself.
pub async fn set_schema(&self, schema: serde_json::Value) {
    *self.supports_prediction_streaming.write().await =
        Self::schema_supports_prediction_streaming(&schema);

    let predict_validator = InputValidator::from_openapi_schema(&schema);
    if let Some(validator) = predict_validator.as_ref() {
        tracing::info!(
            "Input validation enabled ({} required fields)",
            validator.required_count()
        );
    }
    *self.input_validator.write().await = predict_validator;

    let training_validator = InputValidator::from_openapi_schema_key(&schema, "TrainingInput");
    if let Some(validator) = training_validator.as_ref() {
        tracing::info!(
            "Training input validation enabled ({} required fields)",
            validator.required_count()
        );
    }
    *self.train_validator.write().await = training_validator;

    *self.schema.write().await = Some(schema);
}
fn schema_supports_prediction_streaming(schema: &serde_json::Value) -> bool {
schema
.get("paths")
.and_then(|paths| paths.get("/predictions"))
.and_then(|path| path.get("post"))
.and_then(|operation| operation.get("x-cog-streaming"))
.and_then(serde_json::Value::as_bool)
.unwrap_or(false)
}
/// A copy of the last schema passed to `set_schema`, if any.
pub async fn schema(&self) -> Option<serde_json::Value> {
    self.schema.read().await.as_ref().cloned()
}
/// Strips unknown fields, validates, and injects missing optionals into prediction `input`.
///
/// Returns the stripped field names and the validation outcome. Without an installed
/// validator the input is untouched and `(vec![], Ok(()))` is returned.
pub async fn strip_and_validate_input(
    &self,
    input: &mut serde_json::Value,
) -> (
    Vec<String>,
    Result<(), Vec<crate::input_validation::ValidationError>>,
) {
    match self.input_validator.read().await.as_ref() {
        Some(validator) => Self::strip_validate_inject(validator, input),
        None => (Vec::new(), Ok(())),
    }
}
/// Applies `validator` to `input`: strip unknown keys, validate, and only on success
/// inject defaults for missing optional fields. Returns `(stripped_keys, validation)`.
fn strip_validate_inject(
    validator: &InputValidator,
    input: &mut serde_json::Value,
) -> (
    Vec<String>,
    Result<(), Vec<crate::input_validation::ValidationError>>,
) {
    let removed_keys = validator.strip_unknown(input);
    let outcome = validator.validate(input);
    if let Ok(()) = &outcome {
        validator.inject_missing_optionals(input);
    }
    (removed_keys, outcome)
}
/// Like `strip_and_validate_input`, but prefers the training validator and falls back to
/// the prediction validator; with neither installed the input passes unchanged.
pub async fn strip_and_validate_train_input(
    &self,
    input: &mut serde_json::Value,
) -> (
    Vec<String>,
    Result<(), Vec<crate::input_validation::ValidationError>>,
) {
    {
        let training = self.train_validator.read().await;
        if let Some(validator) = training.as_ref() {
            return Self::strip_validate_inject(validator, input);
        }
    }
    match self.input_validator.read().await.as_ref() {
        Some(validator) => Self::strip_validate_inject(validator, input),
        None => (Vec::new(), Ok(())),
    }
}
/// Runs the orchestrator's healthcheck, or reports healthy when none is configured.
///
/// The orchestrator state is cloned out of the `RwLock` before awaiting, so the read guard
/// is not held across a potentially slow user healthcheck. Tokio's lock is write-preferring,
/// so a held guard plus a queued writer would block other readers and make `cancel`'s
/// `try_read` skip worker cancellation.
///
/// # Errors
/// Propagates any error returned by the orchestrator's `healthcheck`.
pub async fn healthcheck(
    &self,
) -> Result<HealthcheckResult, crate::orchestrator::OrchestratorError> {
    let state = self.orchestrator.read().await.clone();
    let Some(state) = state else {
        tracing::debug!("No orchestrator configured, returning default healthy");
        return Ok(HealthcheckResult::healthy());
    };
    tracing::trace!("Dispatching healthcheck to orchestrator");
    let result = state.orchestrator.healthcheck().await;
    tracing::trace!(
        healthy = result.as_ref().map(|r| r.is_healthy()).unwrap_or(false),
        error = ?result.as_ref().ok().and_then(|r| r.error.as_ref()),
        "Healthcheck result from orchestrator"
    );
    result
}
/// Admits a new prediction when the service is ready and a slot permit is free.
///
/// On success the prediction is tracked under `id` (with its input and
/// `cancel_on_stream_drop` flag). The caller gets a handle plus the slot, which must then be
/// passed to `predict` to run it.
///
/// NOTE(review): `DashMap::insert` silently replaces an existing entry with the same `id`;
/// presumably callers reject duplicate ids beforehand — TODO confirm.
///
/// # Errors
/// * `CreatePredictionError::NotReady` if health is not `Ready` or no pool is configured.
/// * `CreatePredictionError::AtCapacity` if no permit could be acquired.
pub async fn submit_prediction(
    &self,
    id: String,
    input: serde_json::Value,
    webhook: Option<WebhookSender>,
    cancel_on_stream_drop: bool,
) -> Result<(PredictionHandle, UnregisteredPredictionSlot), CreatePredictionError> {
    if *self.health.read().await != Health::Ready {
        return Err(CreatePredictionError::NotReady);
    }
    let Some(pool) = self.pool().await else {
        return Err(CreatePredictionError::NotReady);
    };
    let Some(permit) = pool.try_acquire() else {
        return Err(CreatePredictionError::AtCapacity);
    };

    let prediction = Prediction::new(id.clone(), webhook);
    let cancel_token = prediction.cancel_token();
    let (idle_tx, idle_rx) = tokio::sync::oneshot::channel();
    let slot = PredictionSlot::new(prediction, permit, idle_rx);

    let entry = PredictionEntry {
        prediction: slot.prediction(),
        cancel_token: cancel_token.clone(),
        input,
        cancel_on_stream_drop,
    };
    self.predictions.insert(id.clone(), entry);

    let handle = PredictionHandle { id, cancel_token };
    let unregistered = UnregisteredPredictionSlot::new(slot, idle_tx);
    Ok((handle, unregistered))
}
/// Whether a prediction with this id is currently tracked.
pub fn prediction_exists(&self, id: &str) -> bool {
    self.predictions.get(id).is_some()
}
/// JSON state snapshot of a tracked prediction with its original `input` attached.
///
/// Returns `None` if the id is unknown or the prediction mutex is poisoned.
pub fn get_prediction_response(&self, id: &str) -> Option<serde_json::Value> {
    let entry = self.predictions.get(id)?;
    let mut snapshot = entry.prediction.lock().ok()?.build_state_snapshot();
    snapshot["input"] = entry.input.clone();
    Some(snapshot)
}
/// Attaches a new stream subscriber to a tracked prediction.
///
/// The returned subscription carries the replay buffer, the skipped count, a live
/// receiver, and a guard that may cancel the prediction when the last subscriber drops
/// (if the prediction was submitted with `cancel_on_stream_drop`).
///
/// # Errors
/// `NotFound` for an unknown id, `Unavailable` if the prediction mutex is poisoned, and
/// `TooManySubscribers` once `MAX_STREAM_SUBSCRIBERS` receivers exist.
pub fn subscribe_prediction_stream(
    self: &Arc<Self>,
    id: &str,
) -> Result<PredictionStreamSubscription, SubscribePredictionStreamError> {
    let Some(entry) = self.predictions.get(id) else {
        return Err(SubscribePredictionStreamError::NotFound);
    };
    let stream = {
        let prediction = try_lock_prediction(&entry.prediction)
            .ok_or(SubscribePredictionStreamError::Unavailable)?;
        if prediction.stream_receiver_count() >= MAX_STREAM_SUBSCRIBERS {
            return Err(SubscribePredictionStreamError::TooManySubscribers);
        }
        prediction.subscribe_stream_replay()
    };
    let owned_id = id.to_owned();
    Ok(PredictionStreamSubscription {
        id: owned_id.clone(),
        replay: stream.replay,
        skipped: stream.skipped,
        receiver: stream.receiver,
        guard: PredictionStreamGuard {
            id: owned_id,
            service: Arc::clone(self),
            cancel_on_stream_drop: entry.cancel_on_stream_drop,
        },
    })
}
/// Number of stream receivers on a tracked prediction; 0 if unknown or poisoned.
fn stream_receiver_count(&self, id: &str) -> usize {
    let Some(entry) = self.predictions.get(id) else {
        return 0;
    };
    entry
        .prediction
        .lock()
        .map_or(0, |prediction| prediction.stream_receiver_count())
}
/// Whether a tracked prediction is terminal; unknown or poisoned predictions count as terminal.
fn prediction_is_terminal(&self, id: &str) -> bool {
    let Some(entry) = self.predictions.get(id) else {
        return true;
    };
    entry
        .prediction
        .lock()
        .map_or(true, |prediction| prediction.is_terminal())
}
/// Drives a prediction admitted by `submit_prediction` to a terminal state.
///
/// Steps, in order:
/// 1. Mark the prediction `processing`.
/// 2. Register it with the orchestrator, handing over the slot's idle sender.
/// 3. Create `/tmp/coglet/predictions/<id>/{outputs,inputs}`.
/// 4. Build the slot request; large inputs are spilled to a file by `build_slot_request`.
/// 5. Send the request over the slot's permit.
/// 6. Forward any cancellation requested before the send.
/// 7. Wait for completion.
/// 8. Return the slot to idle on a spawned task.
///
/// Any failure after registration (setup or send) also marks the shared prediction as
/// failed, unless it is already terminal. Otherwise it would stay `processing` indefinitely.
///
/// # Errors
/// * `PredictionError::Failed` if no orchestrator is configured, the prediction mutex is
///   poisoned, setup fails, the request cannot be sent, or the prediction ends failed or
///   in an unexpected status.
/// * `PredictionError::Cancelled` if the prediction ends canceled.
pub async fn predict(
    &self,
    unregistered_slot: UnregisteredPredictionSlot,
    input: serde_json::Value,
    context: std::collections::HashMap<String, String>,
) -> Result<PredictionResult, PredictionError> {
    let state = self.orchestrator.read().await.clone();
    let state = state
        .ok_or_else(|| PredictionError::Failed("No orchestrator configured".to_string()))?;
    let (idle_tx, mut slot) = unregistered_slot.into_parts();
    let prediction_id = slot.id();
    let slot_id = slot.slot_id();
    {
        let prediction = slot.prediction();
        let Some(mut pred) = try_lock_prediction(&prediction) else {
            return Err(PredictionError::Failed(
                "Prediction mutex poisoned".to_string(),
            ));
        };
        pred.set_processing();
    }
    let prediction_arc = slot.prediction();
    state
        .orchestrator
        .register_prediction(slot_id, Arc::clone(&prediction_arc), idle_tx)
        .await;

    // From here the prediction is visible as `processing`. Every error return must move it
    // to a terminal state, or snapshots and completion waiters would never see it end.
    // A prediction that is already terminal (e.g. concurrently canceled) is left alone.
    let fail = |message: String| -> PredictionError {
        if let Some(mut pred) = try_lock_prediction(&prediction_arc)
            && !pred.is_terminal()
        {
            pred.set_failed(message.clone());
        }
        PredictionError::Failed(message)
    };

    // NOTE(review): the id is joined directly onto the base dir; `PathBuf::join` with an
    // absolute id replaces the base, and `..` escapes it. Presumably ids are validated
    // upstream — TODO confirm.
    let prediction_dir =
        std::path::PathBuf::from("/tmp/coglet/predictions").join(&prediction_id);
    let output_dir = prediction_dir.join("outputs");
    let input_dir = prediction_dir.join("inputs");
    std::fs::create_dir_all(&output_dir)
        .map_err(|e| fail(format!("Failed to create output dir: {}", e)))?;
    std::fs::create_dir_all(&input_dir)
        .map_err(|e| fail(format!("Failed to create input dir: {}", e)))?;
    let request = build_slot_request(
        prediction_id.clone(),
        input,
        output_dir
            .to_str()
            .expect("output dir path is valid UTF-8")
            .to_string(),
        &input_dir,
        context,
    )
    .map_err(|e| fail(format!("Failed to build slot request: {}", e)))?;
    let permit = slot
        .permit_mut()
        .ok_or_else(|| fail("Permit not in use".to_string()))?;
    if let Err(e) = permit.send(request).await {
        tracing::error!(%slot_id, error = %e, "Failed to send prediction request");
        state.pool.poison(slot_id);
        return Err(fail(format!("Failed to send request: {}", e)));
    }

    // A cancel that arrived before registration could not reach the worker; forward it now.
    let was_cancelled_before_send = try_lock_prediction(&prediction_arc)
        .map(|p| p.is_canceled())
        .unwrap_or(false);
    if was_cancelled_before_send
        && let Err(e) = state
            .orchestrator
            .cancel_by_prediction_id(&prediction_id)
            .await
    {
        tracing::error!(
            prediction_id = %prediction_id,
            error = %e,
            "Failed to forward pending cancellation after registration"
        );
    }

    let (already_terminal, completion) = {
        let Some(pred) = try_lock_prediction(&prediction_arc) else {
            return Err(PredictionError::Failed(
                "Prediction mutex poisoned".to_string(),
            ));
        };
        (pred.is_terminal(), pred.completion())
    };
    // NOTE(review): the terminal check above and `notified()` below are not atomic. If
    // `completion()` is a `tokio::sync::Notify` signalled via `notify_waiters`, a completion
    // landing in between would be missed — TODO confirm how `Prediction` signals completion.
    if !already_terminal {
        completion.notified().await;
    }

    let (status, output, error, logs, predict_time, metrics) = {
        let Some(pred) = try_lock_prediction(&prediction_arc) else {
            return Err(PredictionError::Failed(
                "Prediction mutex poisoned".to_string(),
            ));
        };
        (
            pred.status(),
            pred.output().cloned(),
            pred.error().map(|s| s.to_string()),
            pred.logs().to_string(),
            pred.elapsed(),
            pred.metrics().clone(),
        )
    };

    // Return the slot to idle in the background; a failed transition poisons it.
    tokio::spawn(async move {
        if let Err(e) = slot.into_idle().await {
            tracing::error!(%slot_id, error = %e, "Failed to transition slot to idle, poisoning slot");
            state.pool.poison(slot_id);
        }
    });

    match status {
        PredictionStatus::Succeeded => Ok(PredictionResult {
            output: output.unwrap_or(PredictionOutput::Single(serde_json::Value::Null)),
            predict_time: Some(predict_time),
            logs,
            metrics,
        }),
        PredictionStatus::Failed => Err(PredictionError::Failed(
            error.unwrap_or_else(|| "Unknown error".to_string()),
        )),
        PredictionStatus::Canceled => Err(PredictionError::Cancelled),
        _ => Err(PredictionError::Failed(format!(
            "Prediction ended in unexpected state: {:?}",
            status
        ))),
    }
}
/// Requests cancellation of a tracked prediction; returns `false` for unknown ids.
///
/// Always trips the prediction's cancel token. It also forwards the cancel to the
/// orchestrator on a spawned task when the orchestrator lock can be taken without waiting.
/// This method is synchronous (it is called from `Drop` impls), so it only uses `try_read`.
pub fn cancel(&self, id: &str) -> bool {
    let Some(entry) = self.predictions.get(id) else {
        return false;
    };
    entry.cancel_token.cancel();
    let worker_orchestrator = if let Ok(guard) = self.orchestrator.try_read() {
        guard.as_ref().map(|state| Arc::clone(&state.orchestrator))
    } else {
        tracing::warn!(prediction_id = %id, "Skipped worker cancel: orchestrator lock unavailable");
        None
    };
    if let Some(orchestrator) = worker_orchestrator {
        spawn_orchestrator_cancel(orchestrator, id.to_string());
    }
    true
}
/// Stops tracking a prediction; unknown ids are ignored.
pub fn remove_prediction(&self, id: &str) {
    let _ = self.predictions.remove(id);
}
/// Publishes `true` on the shutdown watch channel. Send errors are ignored.
pub fn trigger_shutdown(&self) {
    self.shutdown_tx.send(true).ok();
}
/// A new receiver for the shutdown watch channel.
pub fn shutdown_rx(&self) -> watch::Receiver<bool> {
    watch::Receiver::clone(&self.shutdown_rx)
}
}
/// Fire-and-forget `cancel_by_prediction_id` on the current tokio runtime.
///
/// Logs a warning and does nothing when called outside a runtime. Logs an error if the
/// orchestrator rejects the cancel.
fn spawn_orchestrator_cancel(orch: Arc<dyn Orchestrator>, id: String) {
    match tokio::runtime::Handle::try_current() {
        Ok(runtime) => {
            runtime.spawn(async move {
                if let Err(e) = orch.cancel_by_prediction_id(&id).await {
                    tracing::error!(
                        prediction_id = %id,
                        error = %e,
                        "Failed to send cancel to orchestrator"
                    );
                }
            });
        }
        Err(_) => {
            tracing::warn!(prediction_id = %id, "No tokio runtime available to cancel prediction");
        }
    }
}
/// Builds a `SlotRequest::Predict`, inlining `input` when small enough.
///
/// If the serialized JSON exceeds `MAX_INLINE_IPC_SIZE` bytes, it is written to
/// `input_dir/spill_<uuid>.json` and the request carries that path instead.
///
/// # Errors
/// `InvalidData` if serialization fails or the spill path is not UTF-8; any I/O error from
/// writing the spill file.
fn build_slot_request(
    id: String,
    input: serde_json::Value,
    output_dir: String,
    input_dir: &std::path::Path,
    context: std::collections::HashMap<String, String>,
) -> std::io::Result<SlotRequest> {
    let serialized = serde_json::to_vec(&input)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
    let (inline_input, input_file) = if serialized.len() <= MAX_INLINE_IPC_SIZE {
        (Some(input), None)
    } else {
        let spill_path = input_dir.join(format!("spill_{}.json", uuid::Uuid::new_v4()));
        std::fs::write(&spill_path, &serialized)?;
        let spill_str = spill_path
            .to_str()
            .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "non-UTF-8 path"))?
            .to_string();
        (None, Some(spill_str))
    };
    Ok(SlotRequest::Predict {
        id,
        input: inline_input,
        input_file,
        output_dir,
        context,
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bridge::protocol::SlotId;
use crate::permit::{InactiveSlotIdleToken, SlotIdleToken};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
struct MockOrchestrator {
register_count: AtomicUsize,
complete_immediately: bool,
send_idle_ack: bool,
}
impl MockOrchestrator {
fn new() -> Self {
Self {
register_count: AtomicUsize::new(0),
complete_immediately: true,
send_idle_ack: false,
}
}
fn register_count(&self) -> usize {
self.register_count.load(Ordering::SeqCst)
}
fn with_idle_ack(mut self) -> Self {
self.send_idle_ack = true;
self
}
}
#[async_trait::async_trait]
impl Orchestrator for MockOrchestrator {
async fn register_prediction(
&self,
slot_id: SlotId,
prediction: Arc<std::sync::Mutex<crate::prediction::Prediction>>,
idle_sender: tokio::sync::oneshot::Sender<SlotIdleToken>,
) {
self.register_count.fetch_add(1, Ordering::SeqCst);
if self.complete_immediately {
let mut pred = prediction.lock().unwrap();
pred.set_succeeded(crate::PredictionOutput::Single(serde_json::json!(
"mock result"
)));
}
if self.send_idle_ack {
let _ = idle_sender.send(InactiveSlotIdleToken::new(slot_id).activate());
}
}
async fn cancel_by_prediction_id(
&self,
_prediction_id: &str,
) -> Result<(), crate::orchestrator::OrchestratorError> {
Ok(())
}
async fn healthcheck(
&self,
) -> Result<HealthcheckResult, crate::orchestrator::OrchestratorError> {
Ok(HealthcheckResult::healthy())
}
async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> {
Ok(())
}
}
struct CountingCancelOrchestrator {
cancel_count: AtomicUsize,
}
impl CountingCancelOrchestrator {
fn new() -> Self {
Self {
cancel_count: AtomicUsize::new(0),
}
}
fn cancel_count(&self) -> usize {
self.cancel_count.load(Ordering::SeqCst)
}
}
#[async_trait::async_trait]
impl Orchestrator for CountingCancelOrchestrator {
async fn register_prediction(
&self,
_slot_id: SlotId,
_prediction: Arc<std::sync::Mutex<crate::prediction::Prediction>>,
_idle_sender: tokio::sync::oneshot::Sender<SlotIdleToken>,
) {
}
async fn cancel_by_prediction_id(
&self,
_prediction_id: &str,
) -> Result<(), crate::orchestrator::OrchestratorError> {
self.cancel_count.fetch_add(1, Ordering::SeqCst);
Ok(())
}
async fn healthcheck(
&self,
) -> Result<HealthcheckResult, crate::orchestrator::OrchestratorError> {
Ok(HealthcheckResult::healthy())
}
async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> {
Ok(())
}
}
struct CancelRecordingOrchestrator {
cancel_count: AtomicUsize,
prediction: std::sync::Mutex<Option<Arc<std::sync::Mutex<crate::prediction::Prediction>>>>,
}
impl CancelRecordingOrchestrator {
fn new() -> Self {
Self {
cancel_count: AtomicUsize::new(0),
prediction: std::sync::Mutex::new(None),
}
}
fn cancel_count(&self) -> usize {
self.cancel_count.load(Ordering::SeqCst)
}
}
#[async_trait::async_trait]
impl Orchestrator for CancelRecordingOrchestrator {
async fn register_prediction(
&self,
slot_id: SlotId,
prediction: Arc<std::sync::Mutex<crate::prediction::Prediction>>,
idle_sender: tokio::sync::oneshot::Sender<SlotIdleToken>,
) {
*self.prediction.lock().unwrap() = Some(prediction);
let _ = idle_sender.send(InactiveSlotIdleToken::new(slot_id).activate());
}
async fn cancel_by_prediction_id(
&self,
_prediction_id: &str,
) -> Result<(), crate::orchestrator::OrchestratorError> {
self.cancel_count.fetch_add(1, Ordering::SeqCst);
if let Some(prediction) = self.prediction.lock().unwrap().as_ref() {
prediction.lock().unwrap().set_canceled();
}
Ok(())
}
async fn healthcheck(
&self,
) -> Result<HealthcheckResult, crate::orchestrator::OrchestratorError> {
Ok(HealthcheckResult::healthy())
}
async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> {
Ok(())
}
}
async fn create_test_pool(num_slots: usize) -> Arc<PermitPool> {
use crate::bridge::codec::JsonCodec;
use crate::bridge::protocol::SlotRequest;
use futures::StreamExt;
use tokio::net::UnixStream;
let pool = Arc::new(PermitPool::new(num_slots));
for _ in 0..num_slots {
let (a, b) = UnixStream::pair().unwrap();
let (_read_a, write_a) = a.into_split();
let (read_b, _write_b) = b.into_split();
let mut reader =
tokio_util::codec::FramedRead::new(read_b, JsonCodec::<SlotRequest>::new());
tokio::spawn(async move { while reader.next().await.is_some() {} });
let writer =
tokio_util::codec::FramedWrite::new(write_a, JsonCodec::<SlotRequest>::new());
pool.add_permit(SlotId::new(), writer);
}
pool
}
async fn create_test_pool_with_slots(num_slots: usize) -> (Arc<PermitPool>, Vec<SlotId>) {
use crate::bridge::codec::JsonCodec;
use crate::bridge::protocol::SlotRequest;
use futures::StreamExt;
use tokio::net::UnixStream;
let pool = Arc::new(PermitPool::new(num_slots));
let mut slot_ids = Vec::with_capacity(num_slots);
for _ in 0..num_slots {
let (a, b) = UnixStream::pair().unwrap();
let (_read_a, write_a) = a.into_split();
let (read_b, _write_b) = b.into_split();
let mut reader =
tokio_util::codec::FramedRead::new(read_b, JsonCodec::<SlotRequest>::new());
tokio::spawn(async move { while reader.next().await.is_some() {} });
let writer =
tokio_util::codec::FramedWrite::new(write_a, JsonCodec::<SlotRequest>::new());
let slot_id = SlotId::new();
pool.add_permit(slot_id, writer);
slot_ids.push(slot_id);
}
(pool, slot_ids)
}
async fn create_broken_test_pool() -> (Arc<PermitPool>, SlotId) {
use crate::bridge::codec::JsonCodec;
use crate::bridge::protocol::SlotRequest;
use tokio::net::UnixStream;
let pool = Arc::new(PermitPool::new(1));
let (a, b) = UnixStream::pair().unwrap();
let (_read_a, write_a) = a.into_split();
drop(b);
let writer = tokio_util::codec::FramedWrite::new(write_a, JsonCodec::<SlotRequest>::new());
let slot_id = SlotId::new();
pool.add_permit(slot_id, writer);
(pool, slot_id)
}
#[tokio::test]
async fn service_new_no_pool_works() {
let svc = PredictionService::new_no_pool();
let health = svc.health().await;
assert_eq!(health.state, Health::Unknown);
assert_eq!(health.total_slots, 0);
assert_eq!(health.available_slots, 0);
assert!(svc.pool().await.is_none());
}
#[tokio::test]
async fn service_no_pool_initially() {
let svc = PredictionService::new_no_pool();
assert!(svc.pool().await.is_none());
assert!(!svc.has_orchestrator().await);
}
#[tokio::test]
async fn shutdown_signal_works() {
let svc = PredictionService::new_no_pool();
let mut rx = svc.shutdown_rx();
assert!(!*rx.borrow());
svc.trigger_shutdown();
rx.changed().await.unwrap();
assert!(*rx.borrow());
}
#[tokio::test]
async fn submit_fails_when_not_ready() {
let svc = PredictionService::new_no_pool();
let result = svc
.submit_prediction("test".to_string(), serde_json::json!({}), None, false)
.await;
assert!(matches!(result, Err(CreatePredictionError::NotReady)));
}
#[tokio::test]
async fn cannot_set_ready_without_orchestrator() {
let svc = PredictionService::new_no_pool();
let svc2 = PredictionService::new_no_pool().with_health(Health::Ready);
assert_eq!(svc2.health().await.state, Health::Unknown);
svc.set_health(Health::Ready).await;
assert_eq!(svc.health().await.state, Health::Unknown);
}
#[tokio::test]
async fn set_orchestrator_enables_ready_health() {
let svc = PredictionService::new_no_pool();
let pool = create_test_pool(2).await;
let orchestrator = Arc::new(MockOrchestrator::new());
svc.set_orchestrator(pool, orchestrator).await;
assert!(svc.has_orchestrator().await);
svc.set_health(Health::Ready).await;
let health = svc.health().await;
assert_eq!(health.state, Health::Ready);
assert_eq!(health.total_slots, 2);
assert_eq!(health.available_slots, 2);
}
#[tokio::test]
async fn submit_prediction_succeeds_when_ready() {
let svc = PredictionService::new_no_pool();
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(MockOrchestrator::new());
svc.set_orchestrator(pool, orchestrator).await;
svc.set_health(Health::Ready).await;
let (handle, _slot) = svc
.submit_prediction("test-1".to_string(), serde_json::json!({}), None, false)
.await
.unwrap();
assert_eq!(handle.id(), "test-1");
assert!(svc.prediction_exists("test-1"));
}
#[tokio::test]
async fn subscribe_prediction_stream_returns_receiver_for_existing_prediction() {
let svc = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(MockOrchestrator::new());
svc.set_orchestrator(pool, orchestrator).await;
svc.set_health(Health::Ready).await;
let (_handle, _slot) = svc
.submit_prediction("stream-test".to_string(), serde_json::json!({}), None, true)
.await
.unwrap();
let subscription = svc.subscribe_prediction_stream("stream-test").unwrap();
assert_eq!(subscription.prediction_id(), "stream-test");
}
#[tokio::test]
async fn dropping_only_sync_stream_subscription_cancels_prediction() {
let svc = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(CountingCancelOrchestrator::new());
let orchestrator_ref = Arc::clone(&orchestrator);
svc.set_orchestrator(pool, orchestrator).await;
svc.set_health(Health::Ready).await;
let (_handle, _slot) = svc
.submit_prediction("sync-stream".to_string(), serde_json::json!({}), None, true)
.await
.unwrap();
let subscription = svc.subscribe_prediction_stream("sync-stream").unwrap();
drop(subscription);
tokio::time::sleep(Duration::from_millis(25)).await;
assert_eq!(orchestrator_ref.cancel_count(), 1);
}
#[tokio::test]
async fn dropping_async_json_stream_subscription_does_not_cancel_prediction() {
let svc = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(CountingCancelOrchestrator::new());
let orchestrator_ref = Arc::clone(&orchestrator);
svc.set_orchestrator(pool, orchestrator).await;
svc.set_health(Health::Ready).await;
let (_handle, _slot) = svc
.submit_prediction(
"async-json-stream".to_string(),
serde_json::json!({}),
None,
false,
)
.await
.unwrap();
let subscription = svc
.subscribe_prediction_stream("async-json-stream")
.unwrap();
drop(subscription);
tokio::time::sleep(Duration::from_millis(25)).await;
assert_eq!(orchestrator_ref.cancel_count(), 0);
}
#[tokio::test]
async fn dropping_live_sse_stream_subscription_cancels_prediction() {
let svc = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(CountingCancelOrchestrator::new());
let orchestrator_ref = Arc::clone(&orchestrator);
svc.set_orchestrator(pool, orchestrator).await;
svc.set_health(Health::Ready).await;
let (_handle, _slot) = svc
.submit_prediction(
"live-sse-stream".to_string(),
serde_json::json!({}),
None,
true,
)
.await
.unwrap();
let subscription = svc.subscribe_prediction_stream("live-sse-stream").unwrap();
drop(subscription);
tokio::time::sleep(Duration::from_millis(25)).await;
assert_eq!(orchestrator_ref.cancel_count(), 1);
}
#[tokio::test]
async fn dropping_one_of_two_sync_stream_subscriptions_does_not_cancel_prediction() {
let svc = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(CountingCancelOrchestrator::new());
let orchestrator_ref = Arc::clone(&orchestrator);
svc.set_orchestrator(pool, orchestrator).await;
svc.set_health(Health::Ready).await;
let (_handle, _slot) = svc
.submit_prediction(
"multi-sse-stream".to_string(),
serde_json::json!({}),
None,
true,
)
.await
.unwrap();
let first = svc.subscribe_prediction_stream("multi-sse-stream").unwrap();
let second = svc.subscribe_prediction_stream("multi-sse-stream").unwrap();
drop(first);
tokio::time::sleep(Duration::from_millis(25)).await;
assert_eq!(orchestrator_ref.cancel_count(), 0);
drop(second);
tokio::time::sleep(Duration::from_millis(25)).await;
assert_eq!(orchestrator_ref.cancel_count(), 1);
}
#[tokio::test]
async fn subscribe_prediction_stream_rejects_too_many_subscribers() {
let svc = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(MockOrchestrator::new());
svc.set_orchestrator(pool, orchestrator).await;
svc.set_health(Health::Ready).await;
let (_handle, _slot) = svc
.submit_prediction(
"subscriber-cap".to_string(),
serde_json::json!({}),
None,
true,
)
.await
.unwrap();
let mut subscriptions = Vec::new();
for _ in 0..MAX_STREAM_SUBSCRIBERS {
subscriptions.push(svc.subscribe_prediction_stream("subscriber-cap").unwrap());
}
assert!(matches!(
svc.subscribe_prediction_stream("subscriber-cap"),
Err(SubscribePredictionStreamError::TooManySubscribers)
));
}
#[tokio::test]
async fn dropping_completed_sync_stream_subscription_does_not_cancel_prediction() {
let svc = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(CountingCancelOrchestrator::new());
let orchestrator_ref = Arc::clone(&orchestrator);
svc.set_orchestrator(pool, orchestrator).await;
svc.set_health(Health::Ready).await;
let (_handle, _slot) = svc
.submit_prediction(
"completed-sync-stream".to_string(),
serde_json::json!({}),
None,
true,
)
.await
.unwrap();
{
let entry = svc.predictions.get("completed-sync-stream").unwrap();
let mut prediction = entry.prediction.lock().unwrap();
prediction.set_succeeded(crate::PredictionOutput::Single(serde_json::json!("done")));
}
let subscription = svc
.subscribe_prediction_stream("completed-sync-stream")
.unwrap();
drop(subscription);
tokio::time::sleep(Duration::from_millis(25)).await;
assert_eq!(orchestrator_ref.cancel_count(), 0);
}
/// With a single-slot pool, a second submission made while the first still holds its
/// slot is rejected with `AtCapacity`.
#[tokio::test]
async fn submit_returns_at_capacity_when_no_slots() {
    let service = PredictionService::new_no_pool();
    let single_slot_pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(single_slot_pool, mock).await;
    service.set_health(Health::Ready).await;

    // Occupy the only slot for the remainder of the test.
    let (_first_handle, _first_slot) = service
        .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false)
        .await
        .unwrap();

    let rejected = service
        .submit_prediction("test-2".to_string(), serde_json::json!({}), None, false)
        .await;
    assert!(matches!(rejected, Err(CreatePredictionError::AtCapacity)));
}
/// A successful `predict` call registers the prediction with the orchestrator exactly once.
#[tokio::test]
async fn predict_calls_orchestrator_register() {
    let request_input = serde_json::json!({"prompt": "hello"});

    let service = PredictionService::new_no_pool();
    let slot_pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    let registrations = Arc::clone(&mock);
    service.set_orchestrator(slot_pool, mock).await;
    service.set_health(Health::Ready).await;

    let (_handle, slot) = service
        .submit_prediction("test-1".to_string(), request_input.clone(), None, false)
        .await
        .unwrap();

    let outcome = service
        .predict(slot, request_input, Default::default())
        .await;
    assert!(outcome.is_ok(), "predict failed: {:?}", outcome.err());
    assert_eq!(registrations.register_count(), 1);
}
/// Cancelling through the handle *before* `predict` registers the prediction must still
/// take effect: `predict` resolves to `Cancelled` and the orchestrator records one cancel.
#[tokio::test]
async fn predict_forwards_cancel_token_set_before_registration() {
    let svc = PredictionService::new_no_pool();
    let pool = create_test_pool(1).await;
    let orchestrator = Arc::new(CancelRecordingOrchestrator::new());
    let orchestrator_ref = Arc::clone(&orchestrator);
    svc.set_orchestrator(pool, orchestrator).await;
    svc.set_health(Health::Ready).await;
    let (handle, slot) = svc
        .submit_prediction(
            "pre-register-cancel".to_string(),
            serde_json::json!({}),
            None,
            true,
        )
        .await
        .unwrap();

    // Trip the token before predict() has had a chance to register with the orchestrator.
    handle.cancel_token().cancel();

    // The timeout only guards against a hang, where cancellation is never observed. It
    // uses the same 1 s bound as the other polling tests in this module, so a slow CI
    // runner does not turn a correct run into a spurious failure.
    let result = tokio::time::timeout(
        Duration::from_secs(1),
        svc.predict(
            slot,
            serde_json::json!({}),
            std::collections::HashMap::new(),
        ),
    )
    .await
    .expect("prediction should observe cancellation after registration");
    assert!(matches!(result, Err(PredictionError::Cancelled)));
    assert_eq!(orchestrator_ref.cancel_count(), 1);
}
/// Health reports busy, with zero available slots, once every slot in the pool is taken.
#[tokio::test]
async fn health_shows_busy_when_all_slots_used() {
    let service = PredictionService::new_no_pool();
    let slot_pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(slot_pool, mock).await;
    service.set_health(Health::Ready).await;

    let idle = service.health().await;
    assert!(!idle.is_busy());
    assert_eq!(idle.available_slots, 1);

    let (_handle, _slot) = service
        .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false)
        .await
        .unwrap();

    let saturated = service.health().await;
    assert!(saturated.is_busy());
    assert_eq!(saturated.available_slots, 0);
}
/// After a successful `predict` with the default mock orchestrator, the slot is expected
/// to end up poisoned (per the failure message: the idle-token channel was closed).
/// Polls for up to 1 s.
#[tokio::test]
async fn predict_idle_channel_closed_poison_slot_async() {
    let service = PredictionService::new_no_pool();
    let (slot_pool, slot_ids) = create_test_pool_with_slots(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    let watched_slot = slot_ids[0];
    service.set_orchestrator(Arc::clone(&slot_pool), mock).await;
    service.set_health(Health::Ready).await;

    let request_input = serde_json::json!({"prompt": "hello"});
    let (_handle, slot) = service
        .submit_prediction("test-1".to_string(), request_input.clone(), None, false)
        .await
        .unwrap();

    let outcome = service
        .predict(slot, request_input, Default::default())
        .await;
    assert!(outcome.is_ok(), "predict failed: {:?}", outcome.err());

    let wait_for_poison = async {
        while !slot_pool.is_poisoned(watched_slot) {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    };
    tokio::time::timeout(Duration::from_secs(1), wait_for_poison)
        .await
        .expect("slot was not poisoned after idle token channel closed");
}
/// With an orchestrator that acknowledges idle, the slot's capacity returns to the pool
/// after `predict` completes. Polls for up to 1 s.
#[tokio::test]
async fn predict_idle_ack_returns_capacity_async() {
    let service = PredictionService::new_no_pool();
    let slot_pool = create_test_pool(1).await;
    let acking = Arc::new(MockOrchestrator::new().with_idle_ack());
    service.set_orchestrator(Arc::clone(&slot_pool), acking).await;
    service.set_health(Health::Ready).await;

    let request_input = serde_json::json!({"prompt": "hello"});
    let (_handle, slot) = service
        .submit_prediction("test-1".to_string(), request_input.clone(), None, false)
        .await
        .unwrap();

    let outcome = service
        .predict(slot, request_input, Default::default())
        .await;
    assert!(outcome.is_ok(), "predict failed: {:?}", outcome.err());

    let wait_for_capacity = async {
        while slot_pool.available() != 1 {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    };
    tokio::time::timeout(Duration::from_secs(1), wait_for_capacity)
        .await
        .expect("slot capacity was not returned after idle acknowledgement");
}
/// Using the pool from `create_broken_test_pool`, `predict` reports `Failed`, the slot
/// is poisoned, and no further permit can be acquired.
#[tokio::test]
async fn predict_send_failure_poison_slot() {
    let service = PredictionService::new_no_pool();
    let (broken_pool, broken_slot) = create_broken_test_pool().await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(Arc::clone(&broken_pool), mock).await;
    service.set_health(Health::Ready).await;

    let request_input = serde_json::json!({"prompt": "hello"});
    let (_handle, slot) = service
        .submit_prediction("test-1".to_string(), request_input.clone(), None, false)
        .await
        .unwrap();

    let outcome = service
        .predict(slot, request_input, Default::default())
        .await;

    assert!(matches!(outcome, Err(PredictionError::Failed(_))));
    assert!(broken_pool.is_poisoned(broken_slot));
    assert!(broken_pool.try_acquire().is_none());
}
/// `cancel` on a live prediction returns `true` and trips the handle's cancellation token.
#[tokio::test]
async fn cancel_prediction_works() {
    const ID: &str = "test-cancel";

    let service = PredictionService::new_no_pool();
    let slot_pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(slot_pool, mock).await;
    service.set_health(Health::Ready).await;

    let (handle, _slot) = service
        .submit_prediction(ID.to_string(), serde_json::json!({}), None, false)
        .await
        .unwrap();
    let token = handle.cancel_token();

    assert!(service.cancel(ID));
    assert!(token.is_cancelled());
}
/// `cancel` on an unknown prediction id reports `false`.
#[tokio::test]
async fn cancel_nonexistent_returns_false() {
    let service = PredictionService::new_no_pool();
    let was_cancelled = service.cancel("nonexistent");
    assert!(!was_cancelled);
}
/// A sync guard that was never disarmed cancels the prediction's token when it is dropped.
#[tokio::test]
async fn sync_guard_cancels_on_drop() {
    let service = Arc::new(PredictionService::new_no_pool());
    let slot_pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(slot_pool, mock).await;
    service.set_health(Health::Ready).await;

    let (handle, _slot) = service
        .submit_prediction("test-guard".to_string(), serde_json::json!({}), None, false)
        .await
        .unwrap();
    let token = handle.cancel_token();

    let guard = handle.sync_guard(Arc::clone(&service));
    assert!(!token.is_cancelled());
    drop(guard);

    assert!(token.is_cancelled());
}
/// Calling `disarm` on a sync guard prevents it from cancelling the token when dropped.
#[tokio::test]
async fn sync_guard_disarm_prevents_cancel() {
    let service = Arc::new(PredictionService::new_no_pool());
    let slot_pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(slot_pool, mock).await;
    service.set_health(Health::Ready).await;

    let (handle, _slot) = service
        .submit_prediction("test-disarm".to_string(), serde_json::json!({}), None, false)
        .await
        .unwrap();
    let token = handle.cancel_token();

    let mut guard = handle.sync_guard(Arc::clone(&service));
    guard.disarm();
    drop(guard);

    assert!(!token.is_cancelled());
}
/// `remove_prediction` drops the entry, so `prediction_exists` no longer finds it.
#[tokio::test]
async fn remove_prediction_cleans_up() {
    const ID: &str = "test-remove";

    let service = PredictionService::new_no_pool();
    let slot_pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(slot_pool, mock).await;
    service.set_health(Health::Ready).await;

    let (_handle, _slot) = service
        .submit_prediction(ID.to_string(), serde_json::json!({}), None, false)
        .await
        .unwrap();

    assert!(service.prediction_exists(ID));
    service.remove_prediction(ID);
    assert!(!service.prediction_exists(ID));
}
/// A small payload is sent inline: `input` is populated and no `input_file` is written.
#[test]
fn build_slot_request_small_input_inline() {
    let spill_dir = tempfile::tempdir().unwrap();
    let payload = serde_json::json!({"text": "hello"});

    let request = build_slot_request(
        "p1".into(),
        payload.clone(),
        "/tmp/out".into(),
        spill_dir.path(),
        Default::default(),
    )
    .unwrap();

    let SlotRequest::Predict {
        id,
        input: Some(inline),
        input_file: None,
        output_dir,
        ..
    } = request
    else {
        panic!("expected inline input");
    };

    assert_eq!(id, "p1");
    assert_eq!(inline, payload);
    assert_eq!(output_dir, "/tmp/out");
}
/// A ~7 MiB payload (assumed to exceed `MAX_INLINE_IPC_SIZE` — TODO confirm) is written
/// to a file, and the request carries only that file path. The file holds the exact JSON.
#[test]
fn build_slot_request_large_input_spills() {
    let spill_dir = tempfile::tempdir().unwrap();
    let oversized = "x".repeat(7 * 1024 * 1024);
    let payload = serde_json::json!({"data": oversized});

    let request = build_slot_request(
        "p2".into(),
        payload.clone(),
        "/tmp/out".into(),
        spill_dir.path(),
        Default::default(),
    )
    .unwrap();

    let SlotRequest::Predict {
        id,
        input: None,
        input_file: Some(spill_path),
        output_dir,
        ..
    } = request
    else {
        panic!("expected file-backed input");
    };

    assert_eq!(id, "p2");
    assert_eq!(output_dir, "/tmp/out");
    assert!(std::path::Path::new(&spill_path).exists());

    let spilled_bytes = std::fs::read(&spill_path).unwrap();
    let round_tripped: serde_json::Value = serde_json::from_slice(&spilled_bytes).unwrap();
    assert_eq!(round_tripped, payload);
}
/// A spilled (file-backed) request rehydrates to the original id, input and output dir.
#[test]
fn build_slot_request_roundtrip() {
    let dir = tempfile::tempdir().unwrap();
    // Must exceed the inline IPC limit so the request is file-backed (asserted below).
    let big = "y".repeat(7 * 1024 * 1024);
    let input = serde_json::json!({"payload": big});
    let req = build_slot_request(
        "p3".into(),
        input.clone(),
        "/tmp/out".into(),
        dir.path(),
        Default::default(),
    )
    .unwrap();

    // Without this check the test would also pass if the input stayed inline, and it would
    // never exercise the file rehydration path it exists to cover.
    assert!(
        matches!(
            &req,
            SlotRequest::Predict {
                input: None,
                input_file: Some(_),
                ..
            }
        ),
        "expected large input to be spilled to a file before rehydration"
    );

    let (id, rehydrated, output_dir, _context) = req.rehydrate_input().unwrap();
    assert_eq!(id, "p3");
    assert_eq!(rehydrated, input);
    assert_eq!(output_dir, "/tmp/out");
}
/// Builds a schema document shaped like OpenAPI. `components.schemas.<key>` is an object
/// with a required string `s` and a nullable string `value` that has no default.
fn optional_schema(key: &str) -> serde_json::Value {
    let object_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "s": {"type": "string", "title": "S"},
            "value": {"type": "string", "title": "Value", "nullable": true}
        },
        "required": ["s"]
    });

    // The caller chooses the schema name (e.g. "Input" or "TrainingInput").
    let mut schemas = serde_json::Map::new();
    schemas.insert(key.to_owned(), object_schema);

    serde_json::json!({
        "components": {
            "schemas": serde_json::Value::Object(schemas)
        }
    })
}
/// Valid predict input that omits an optional field with no default gets that field
/// injected as `null`.
#[tokio::test]
async fn strip_and_validate_input_injects_missing_optional() {
    let service = PredictionService::new_no_pool();
    service.set_schema(optional_schema("Input")).await;

    let mut payload = serde_json::json!({"s": "hello"});
    let (_removed, validation) = service.strip_and_validate_input(&mut payload).await;
    assert!(validation.is_ok(), "valid input should pass validation");

    let expected = serde_json::json!({"s": "hello", "value": null});
    assert_eq!(
        payload, expected,
        "omitted optional-with-no-default should be injected with null"
    );
}
/// When a required field is missing, validation fails and the input is left untouched,
/// with no optional-field injection.
#[tokio::test]
async fn strip_and_validate_input_does_not_inject_when_required_missing() {
    let service = PredictionService::new_no_pool();
    service.set_schema(optional_schema("Input")).await;

    let untouched = serde_json::json!({});
    let mut payload = untouched.clone();
    let (_removed, validation) = service.strip_and_validate_input(&mut payload).await;
    assert!(validation.is_err(), "missing required field should fail");
    assert_eq!(
        payload, untouched,
        "injection must not run when validation fails"
    );
}
/// The training-input path injects omitted optional fields that have no default as
/// `null`, the same as the predict path.
#[tokio::test]
async fn strip_and_validate_train_input_injects_missing_optional() {
    let service = PredictionService::new_no_pool();
    service.set_schema(optional_schema("TrainingInput")).await;

    let mut payload = serde_json::json!({"s": "hello"});
    let (_removed, validation) = service.strip_and_validate_train_input(&mut payload).await;
    assert!(validation.is_ok(), "valid training input should pass validation");

    let expected = serde_json::json!({"s": "hello", "value": null});
    assert_eq!(
        payload, expected,
        "train path should inject omitted optional-with-no-default"
    );
}
}