use std::sync::{Arc, Mutex as StdMutex};
use dashmap::DashMap;
use tokio::sync::{RwLock, watch};
use crate::bridge::protocol::{MAX_INLINE_IPC_SIZE, SlotRequest};
use crate::health::{Health, SetupResult};
use crate::input_validation::InputValidator;
use crate::orchestrator::{HealthcheckResult, Orchestrator};
use crate::permit::{PermitPool, PredictionSlot, UnregisteredPredictionSlot};
use crate::prediction::{CancellationToken, Prediction, PredictionStatus};
use crate::predictor::{PredictionError, PredictionOutput, PredictionResult};
use crate::version::VersionInfo;
use crate::webhook::WebhookSender;
/// Locks a shared prediction, treating a poisoned mutex as a failed prediction.
///
/// Returns `Some(guard)` when the lock is acquired normally. When the mutex is
/// poisoned, an error is logged, the prediction is marked failed (unless it is
/// already terminal) and `None` is returned so the caller can bail out.
fn try_lock_prediction(
    pred: &Arc<StdMutex<Prediction>>,
) -> Option<std::sync::MutexGuard<'_, Prediction>> {
    let poisoned = match pred.lock() {
        Ok(guard) => return Some(guard),
        Err(poisoned) => poisoned,
    };

    tracing::error!("Prediction mutex poisoned - failing prediction");

    // Recover the guard only to record the failure; it is not handed back.
    let mut recovered = poisoned.into_inner();
    if !recovered.is_terminal() {
        recovered.set_failed("Internal error: mutex poisoned".to_string());
    }
    None
}
/// Reasons `PredictionService::submit_prediction` can refuse a new prediction.
#[derive(Debug, thiserror::Error)]
pub enum CreatePredictionError {
/// Health is not `Ready`, or no pool/orchestrator has been configured.
#[error("Service not ready")]
NotReady,
/// The pool had no permit available when the prediction was submitted.
#[error("At capacity (no slots available)")]
AtCapacity,
}
/// Point-in-time view of service health, as built by `PredictionService::health`.
#[derive(Debug, Clone)]
pub struct HealthSnapshot {
/// Health state at the time the snapshot was taken.
pub state: Health,
/// Slots reported as available by the pool (0 when no pool is configured).
pub available_slots: usize,
/// Total slots reported by the pool (0 when no pool is configured).
pub total_slots: usize,
/// Most recently stored setup result, if any.
pub setup_result: Option<SetupResult>,
/// Version information copied from the service.
pub version: VersionInfo,
}
impl HealthSnapshot {
    /// Returns `true` when the recorded state is [`Health::Ready`].
    pub fn is_ready(&self) -> bool {
        self.state == Health::Ready
    }

    /// Returns `true` when the service is ready but has no available slots.
    pub fn is_busy(&self) -> bool {
        self.is_ready() && self.available_slots == 0
    }
}
/// Per-prediction record kept in the service's `predictions` map.
struct PredictionEntry {
/// Shared prediction state (the same `Arc` the slot hands out).
prediction: Arc<StdMutex<Prediction>>,
/// Token that `PredictionService::cancel` triggers.
cancel_token: CancellationToken,
/// Original input, echoed back under `"input"` in `get_prediction_response`.
input: serde_json::Value,
}
/// Handle returned by `PredictionService::submit_prediction` for a submitted prediction.
pub struct PredictionHandle {
/// Prediction identifier supplied at submission.
id: String,
/// Clone of the prediction's cancellation token.
cancel_token: CancellationToken,
}
impl PredictionHandle {
    /// Identifier of the prediction this handle refers to.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Returns a clone of the prediction's cancellation token.
    pub fn cancel_token(&self) -> CancellationToken {
        self.cancel_token.clone()
    }

    /// Builds a [`SyncPredictionGuard`] for this prediction.
    ///
    /// The guard cancels the prediction through `service` when dropped,
    /// unless it has been disarmed first.
    pub fn sync_guard(&self, service: Arc<PredictionService>) -> SyncPredictionGuard {
        let prediction_id = self.id.clone();
        SyncPredictionGuard::new(prediction_id, service)
    }
}
/// Drop guard that cancels a prediction unless `disarm` was called first.
pub struct SyncPredictionGuard {
/// Prediction to cancel on drop; `None` once the guard is disarmed.
prediction_id: Option<String>,
/// Service through which the cancellation is issued.
service: Arc<PredictionService>,
}
impl SyncPredictionGuard {
    /// Creates an armed guard for `prediction_id`.
    pub fn new(prediction_id: String, service: Arc<PredictionService>) -> Self {
        let prediction_id = Some(prediction_id);
        Self {
            prediction_id,
            service,
        }
    }

    /// Disarms the guard so that dropping it no longer cancels the prediction.
    pub fn disarm(&mut self) {
        self.prediction_id = None;
    }
}
impl Drop for SyncPredictionGuard {
    /// Cancels the guarded prediction if the guard is still armed.
    fn drop(&mut self) {
        let Some(id) = self.prediction_id.as_deref() else {
            // Disarmed: nothing to cancel.
            return;
        };
        self.service.cancel(id);
    }
}
/// Pairs the permit pool with the orchestrator that serves it.
///
/// Both fields are `Arc`s, so the derived `Clone` only bumps reference counts
/// (exactly what the previous hand-written implementation did).
#[derive(Clone)]
pub struct OrchestratorState {
    /// Permit pool used to acquire prediction slots.
    pub pool: Arc<PermitPool>,
    /// Orchestrator used for registration, cancellation, healthchecks and shutdown.
    pub orchestrator: Arc<dyn Orchestrator>,
}
/// Service state: orchestrator wiring, health, tracked predictions,
/// shutdown signalling, schema and input validators.
pub struct PredictionService {
/// Pool/orchestrator pair; `None` until `set_orchestrator` is called.
orchestrator: RwLock<Option<OrchestratorState>>,
/// Current health state.
health: RwLock<Health>,
/// Result of setup, once recorded via `set_setup_result`.
setup_result: RwLock<Option<SetupResult>>,
/// Tracked predictions keyed by prediction id.
predictions: DashMap<String, PredictionEntry>,
/// Sender side of the shutdown watch channel (`true` means shutdown requested).
shutdown_tx: watch::Sender<bool>,
/// Receiver kept so `shutdown_rx()` can hand out clones.
shutdown_rx: watch::Receiver<bool>,
/// Version information reported in health snapshots.
version: VersionInfo,
/// Schema stored by `set_schema`, if any.
schema: RwLock<Option<serde_json::Value>>,
/// Validator for prediction inputs, derived from the schema when available.
input_validator: RwLock<Option<InputValidator>>,
/// Validator built from the schema's `TrainingInput` key, when available.
train_validator: RwLock<Option<InputValidator>>,
}
impl PredictionService {
/// Creates a service with no orchestrator attached.
///
/// Health starts as [`Health::Unknown`], no predictions are tracked, no
/// schema or validators are set, and the shutdown channel holds `false`.
pub fn new_no_pool() -> Self {
    let (shutdown_tx, shutdown_rx) = watch::channel(false);
    let health = RwLock::new(Health::Unknown);
    let predictions = DashMap::new();
    Self {
        orchestrator: RwLock::new(None),
        health,
        setup_result: RwLock::new(None),
        predictions,
        shutdown_tx,
        shutdown_rx,
        version: VersionInfo::new(),
        schema: RwLock::new(None),
        input_validator: RwLock::new(None),
        train_validator: RwLock::new(None),
    }
}
/// Installs the permit pool and orchestrator, replacing any previous pair.
pub async fn set_orchestrator(
    &self,
    pool: Arc<PermitPool>,
    orchestrator: Arc<dyn Orchestrator>,
) {
    let state = OrchestratorState { pool, orchestrator };
    let mut current = self.orchestrator.write().await;
    *current = Some(state);
}
/// Returns `true` once an orchestrator has been installed.
pub async fn has_orchestrator(&self) -> bool {
    let guard = self.orchestrator.read().await;
    guard.is_some()
}
pub async fn shutdown(&self) {
if let Some(ref state) = *self.orchestrator.read().await
&& let Err(e) = state.orchestrator.shutdown().await
{
tracing::warn!(error = %e, "Error during orchestrator shutdown");
}
}
/// Builder-style setter for the initial health state.
///
/// A request for [`Health::Ready`] is ignored and the current state is kept.
pub fn with_health(mut self, health: Health) -> Self {
    if health == Health::Ready {
        return self;
    }
    self.health = RwLock::new(health);
    self
}
/// Builder-style setter that replaces the service's version information.
pub fn with_version(mut self, version: VersionInfo) -> Self {
    self.version = version;
    self
}
/// Version information held by the service.
pub fn version(&self) -> &VersionInfo {
    &self.version
}
/// Returns `true` when a training-input validator is present.
pub async fn supports_training(&self) -> bool {
    let guard = self.train_validator.read().await;
    guard.is_some()
}
pub async fn pool(&self) -> Option<Arc<PermitPool>> {
if let Some(ref state) = *self.orchestrator.read().await {
Some(Arc::clone(&state.pool))
} else {
None
}
}
/// Builds a [`HealthSnapshot`] from the current state.
///
/// Slot counts come from the pool; both are 0 when no pool is configured.
pub async fn health(&self) -> HealthSnapshot {
    let state = *self.health.read().await;
    let setup_result = self.setup_result.read().await.clone();
    let pool = self.pool().await;
    let (available_slots, total_slots) = pool
        .as_deref()
        .map_or((0, 0), |p| (p.available(), p.num_slots()));

    tracing::trace!(
        ?state,
        available_slots,
        total_slots,
        setup_status = ?setup_result.as_ref().map(|r| r.status),
        "Building health snapshot"
    );

    let version = self.version.clone();
    HealthSnapshot {
        state,
        available_slots,
        total_slots,
        setup_result,
        version,
    }
}
/// Transitions the service to `health`.
///
/// A transition to [`Health::Ready`] is refused (with a warning) while no
/// orchestrator is configured. Otherwise the change is logged at debug level
/// and applied.
pub async fn set_health(&self, health: Health) {
    if health == Health::Ready && self.orchestrator.read().await.is_none() {
        tracing::warn!("Attempted to set READY without orchestrator, ignoring");
        return;
    }
    // Read the previous state and store the new one under a single write
    // guard, so the logged `from` is the state actually being replaced even
    // when transitions race.
    let mut current = self.health.write().await;
    let previous = *current;
    tracing::debug!(from = ?previous, to = ?health, "Health state transition");
    *current = health;
}
/// Stores the setup result, replacing any previous one, after logging a summary.
pub async fn set_setup_result(&self, result: SetupResult) {
    tracing::debug!(
        status = ?result.status,
        started_at = %result.started_at,
        completed_at = ?result.completed_at,
        logs_len = result.logs.len(),
        "Setting setup result"
    );
    let mut stored = self.setup_result.write().await;
    *stored = Some(result);
}
/// Stores `schema` and rebuilds both input validators from it.
///
/// The prediction validator comes from `InputValidator::from_openapi_schema`
/// and the training validator from the `"TrainingInput"` schema key. Either
/// may be absent, in which case the matching validator is cleared.
pub async fn set_schema(&self, schema: serde_json::Value) {
    let predict_validator = InputValidator::from_openapi_schema(&schema);
    if let Some(v) = predict_validator.as_ref() {
        tracing::info!(
            "Input validation enabled ({} required fields)",
            v.required_count()
        );
    }
    *self.input_validator.write().await = predict_validator;

    let training_validator = InputValidator::from_openapi_schema_key(&schema, "TrainingInput");
    if let Some(v) = training_validator.as_ref() {
        tracing::info!(
            "Training input validation enabled ({} required fields)",
            v.required_count()
        );
    }
    *self.train_validator.write().await = training_validator;

    *self.schema.write().await = Some(schema);
}
/// Returns a copy of the stored schema, if one has been set.
pub async fn schema(&self) -> Option<serde_json::Value> {
    let guard = self.schema.read().await;
    guard.clone()
}
pub async fn strip_and_validate_input(
&self,
input: &mut serde_json::Value,
) -> (
Vec<String>,
Result<(), Vec<crate::input_validation::ValidationError>>,
) {
let guard = self.input_validator.read().await;
if let Some(ref validator) = *guard {
let stripped = validator.strip_unknown(input);
let result = validator.validate(input);
(stripped, result)
} else {
(Vec::new(), Ok(()))
}
}
pub async fn strip_and_validate_train_input(
&self,
input: &mut serde_json::Value,
) -> (
Vec<String>,
Result<(), Vec<crate::input_validation::ValidationError>>,
) {
let train_guard = self.train_validator.read().await;
if let Some(ref validator) = *train_guard {
let stripped = validator.strip_unknown(input);
let result = validator.validate(input);
return (stripped, result);
}
drop(train_guard);
let predict_guard = self.input_validator.read().await;
if let Some(ref validator) = *predict_guard {
let stripped = validator.strip_unknown(input);
let result = validator.validate(input);
return (stripped, result);
}
(Vec::new(), Ok(()))
}
/// Runs a healthcheck through the orchestrator.
///
/// Without an orchestrator this returns `HealthcheckResult::healthy()`.
pub async fn healthcheck(
    &self,
) -> Result<HealthcheckResult, crate::orchestrator::OrchestratorError> {
    let guard = self.orchestrator.read().await;
    let Some(state) = guard.as_ref() else {
        tracing::debug!("No orchestrator configured, returning default healthy");
        return Ok(HealthcheckResult::healthy());
    };

    tracing::trace!("Dispatching healthcheck to orchestrator");
    let result = state.orchestrator.healthcheck().await;
    tracing::trace!(
        healthy = result.as_ref().map(|r| r.is_healthy()).unwrap_or(false),
        error = ?result.as_ref().ok().and_then(|r| r.error.as_ref()),
        "Healthcheck result from orchestrator"
    );
    result
}
/// Admits a new prediction and reserves a slot for it.
///
/// On success the prediction is recorded in the `predictions` map under `id`
/// and the caller receives a handle plus the still-unregistered slot.
///
/// # Errors
/// * [`CreatePredictionError::NotReady`] when health is not `Ready` or no pool
///   is configured.
/// * [`CreatePredictionError::AtCapacity`] when no permit can be acquired.
pub async fn submit_prediction(
    &self,
    id: String,
    input: serde_json::Value,
    webhook: Option<WebhookSender>,
) -> Result<(PredictionHandle, UnregisteredPredictionSlot), CreatePredictionError> {
    if *self.health.read().await != Health::Ready {
        return Err(CreatePredictionError::NotReady);
    }
    let Some(pool) = self.pool().await else {
        return Err(CreatePredictionError::NotReady);
    };
    let Some(permit) = pool.try_acquire() else {
        return Err(CreatePredictionError::AtCapacity);
    };

    let prediction = Prediction::new(id.clone(), webhook);
    let cancel_token = prediction.cancel_token();
    let (idle_tx, idle_rx) = tokio::sync::oneshot::channel();
    let slot = PredictionSlot::new(prediction, permit, idle_rx);

    let entry = PredictionEntry {
        prediction: slot.prediction(),
        cancel_token: cancel_token.clone(),
        input,
    };
    self.predictions.insert(id.clone(), entry);

    let handle = PredictionHandle { id, cancel_token };
    let unregistered = UnregisteredPredictionSlot::new(slot, idle_tx);
    Ok((handle, unregistered))
}
/// Returns `true` when a prediction with this id is currently tracked.
pub fn prediction_exists(&self, id: &str) -> bool {
    let tracked = &self.predictions;
    tracked.contains_key(id)
}
/// Builds the JSON state snapshot for a tracked prediction.
///
/// The stored input is attached under the `"input"` key. Returns `None` when
/// the id is unknown or the prediction mutex is poisoned.
pub fn get_prediction_response(&self, id: &str) -> Option<serde_json::Value> {
    let entry = self.predictions.get(id)?;
    let Ok(pred) = entry.prediction.lock() else {
        return None;
    };
    let mut snapshot = pred.build_state_snapshot();
    snapshot["input"] = entry.input.clone();
    Some(snapshot)
}
pub async fn predict(
&self,
unregistered_slot: UnregisteredPredictionSlot,
input: serde_json::Value,
context: std::collections::HashMap<String, String>,
) -> Result<PredictionResult, PredictionError> {
let state = self.orchestrator.read().await.clone();
let state = state
.ok_or_else(|| PredictionError::Failed("No orchestrator configured".to_string()))?;
let (idle_tx, mut slot) = unregistered_slot.into_parts();
let prediction_id = slot.id();
let slot_id = slot.slot_id();
{
let prediction = slot.prediction();
let Some(mut pred) = try_lock_prediction(&prediction) else {
return Err(PredictionError::Failed(
"Prediction mutex poisoned".to_string(),
));
};
pred.set_processing();
}
let prediction_arc = slot.prediction();
state
.orchestrator
.register_prediction(slot_id, Arc::clone(&prediction_arc), idle_tx)
.await;
let prediction_dir =
std::path::PathBuf::from("/tmp/coglet/predictions").join(&prediction_id);
let output_dir = prediction_dir.join("outputs");
let input_dir = prediction_dir.join("inputs");
std::fs::create_dir_all(&output_dir)
.map_err(|e| PredictionError::Failed(format!("Failed to create output dir: {}", e)))?;
std::fs::create_dir_all(&input_dir)
.map_err(|e| PredictionError::Failed(format!("Failed to create input dir: {}", e)))?;
let request = build_slot_request(
prediction_id.clone(),
input,
output_dir
.to_str()
.expect("output dir path is valid UTF-8")
.to_string(),
&input_dir,
context,
)
.map_err(|e| PredictionError::Failed(format!("Failed to build slot request: {}", e)))?;
let permit = slot
.permit_mut()
.ok_or_else(|| PredictionError::Failed("Permit not in use".to_string()))?;
if let Err(e) = permit.send(request).await {
tracing::error!(%slot_id, error = %e, "Failed to send prediction request");
state.pool.poison(slot_id);
if let Some(mut pred) = try_lock_prediction(&prediction_arc) {
pred.set_failed(format!("Failed to send request: {}", e));
}
return Err(PredictionError::Failed(format!(
"Failed to send request: {}",
e
)));
}
let (already_terminal, completion) = {
let Some(pred) = try_lock_prediction(&prediction_arc) else {
return Err(PredictionError::Failed(
"Prediction mutex poisoned".to_string(),
));
};
(pred.is_terminal(), pred.completion())
};
if !already_terminal {
completion.notified().await;
}
let (status, output, error, logs, predict_time, metrics) = {
let Some(pred) = try_lock_prediction(&prediction_arc) else {
return Err(PredictionError::Failed(
"Prediction mutex poisoned".to_string(),
));
};
(
pred.status(),
pred.output().cloned(),
pred.error().map(|s| s.to_string()),
pred.logs().to_string(),
pred.elapsed(),
pred.metrics().clone(),
)
};
tokio::spawn(async move {
if let Err(e) = slot.into_idle().await {
tracing::error!(%slot_id, error = %e, "Failed to transition slot to idle, poisoning slot");
state.pool.poison(slot_id);
}
});
match status {
PredictionStatus::Succeeded => Ok(PredictionResult {
output: output.unwrap_or(PredictionOutput::Single(serde_json::Value::Null)),
predict_time: Some(predict_time),
logs,
metrics,
}),
PredictionStatus::Failed => Err(PredictionError::Failed(
error.unwrap_or_else(|| "Unknown error".to_string()),
)),
PredictionStatus::Canceled => Err(PredictionError::Cancelled),
_ => Err(PredictionError::Failed(format!(
"Prediction ended in unexpected state: {:?}",
status
))),
}
}
/// Cancels the prediction with the given id.
///
/// Triggers the prediction's cancellation token and, when possible, forwards
/// the cancel to the orchestrator in a background task. Returns `true` when
/// the id is tracked and `false` otherwise.
///
/// The forward is best-effort and is skipped with a warning when the
/// orchestrator lock is currently held for writing or when no Tokio runtime
/// is available. The latter matters because this method is also reached from
/// `SyncPredictionGuard`'s `Drop`, where `tokio::spawn` would panic.
pub fn cancel(&self, id: &str) -> bool {
    let Some(entry) = self.predictions.get(id) else {
        return false;
    };
    entry.cancel_token.cancel();
    // The map reference is not needed beyond this point.
    drop(entry);

    let orchestrator = match self.orchestrator.try_read() {
        Ok(guard) => guard.as_ref().map(|s| Arc::clone(&s.orchestrator)),
        Err(_) => {
            tracing::warn!(
                prediction_id = %id,
                "Orchestrator lock busy, cancel not forwarded to orchestrator"
            );
            None
        }
    };
    let Some(orch) = orchestrator else {
        return true;
    };

    let id_owned = id.to_string();
    match tokio::runtime::Handle::try_current() {
        Ok(runtime) => {
            runtime.spawn(async move {
                if let Err(e) = orch.cancel_by_prediction_id(&id_owned).await {
                    tracing::error!(
                        prediction_id = %id_owned,
                        error = %e,
                        "Failed to send cancel to orchestrator"
                    );
                }
            });
        }
        Err(e) => {
            tracing::warn!(
                prediction_id = %id_owned,
                error = %e,
                "No Tokio runtime available, cancel not forwarded to orchestrator"
            );
        }
    }
    true
}
/// Stops tracking the prediction with this id; unknown ids are ignored.
pub fn remove_prediction(&self, id: &str) {
    drop(self.predictions.remove(id));
}
/// Publishes `true` on the shutdown channel; a send error is ignored.
pub fn trigger_shutdown(&self) {
    let shutdown_requested = true;
    let _ = self.shutdown_tx.send(shutdown_requested);
}
/// Returns a new receiver for the shutdown channel.
pub fn shutdown_rx(&self) -> watch::Receiver<bool> {
    let receiver = &self.shutdown_rx;
    receiver.clone()
}
}
/// Builds the `SlotRequest::Predict` message for a prediction.
///
/// The input is serialized to measure its size. Inputs no larger than
/// `MAX_INLINE_IPC_SIZE` bytes are sent inline. Larger inputs are written to
/// a uniquely named `spill_<uuid>.json` file inside `input_dir`, and the
/// request carries that file's path instead.
///
/// # Errors
/// Returns an `InvalidData` error when the input cannot be serialized or the
/// spill path is not valid UTF-8, and any I/O error from writing the file.
fn build_slot_request(
    id: String,
    input: serde_json::Value,
    output_dir: String,
    input_dir: &std::path::Path,
    context: std::collections::HashMap<String, String>,
) -> std::io::Result<SlotRequest> {
    let serialized = serde_json::to_vec(&input)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;

    if serialized.len() <= MAX_INLINE_IPC_SIZE {
        return Ok(SlotRequest::Predict {
            id,
            input: Some(input),
            input_file: None,
            output_dir,
            context,
        });
    }

    let path = input_dir.join(format!("spill_{}.json", uuid::Uuid::new_v4()));
    // Check the path before writing, so a non-UTF-8 directory does not leave
    // an orphaned spill file behind when the request cannot be built.
    let input_file = path
        .to_str()
        .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "non-UTF-8 path"))?
        .to_string();
    std::fs::write(&path, &serialized)?;

    Ok(SlotRequest::Predict {
        id,
        input: None,
        input_file: Some(input_file),
        output_dir,
        context,
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bridge::protocol::SlotId;
use crate::permit::{InactiveSlotIdleToken, SlotIdleToken};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
/// Test double for `Orchestrator` with configurable registration behavior.
struct MockOrchestrator {
/// Number of `register_prediction` calls observed.
register_count: AtomicUsize,
/// When set, `register_prediction` marks the prediction succeeded right away.
complete_immediately: bool,
/// When set, `register_prediction` sends an idle token through the idle sender.
send_idle_ack: bool,
}
impl MockOrchestrator {
    /// Mock that completes predictions immediately and sends no idle token.
    fn new() -> Self {
        let register_count = AtomicUsize::new(0);
        Self {
            register_count,
            complete_immediately: true,
            send_idle_ack: false,
        }
    }

    /// Number of `register_prediction` calls so far.
    fn register_count(&self) -> usize {
        self.register_count.load(Ordering::SeqCst)
    }

    /// Makes the mock send an idle token during registration.
    fn with_idle_ack(self) -> Self {
        Self {
            send_idle_ack: true,
            ..self
        }
    }
}
#[async_trait::async_trait]
impl Orchestrator for MockOrchestrator {
    /// Counts the call, optionally completes the prediction with a fixed
    /// output, and optionally sends an idle token for the slot.
    async fn register_prediction(
        &self,
        slot_id: SlotId,
        prediction: Arc<std::sync::Mutex<crate::prediction::Prediction>>,
        idle_sender: tokio::sync::oneshot::Sender<SlotIdleToken>,
    ) {
        self.register_count.fetch_add(1, Ordering::SeqCst);

        if self.complete_immediately {
            let output = crate::PredictionOutput::Single(serde_json::json!("mock result"));
            prediction.lock().unwrap().set_succeeded(output);
        }

        if self.send_idle_ack {
            let token = InactiveSlotIdleToken::new(slot_id).activate();
            let _ = idle_sender.send(token);
        }
        // Otherwise `idle_sender` is dropped here without sending anything.
    }

    /// Always succeeds.
    async fn cancel_by_prediction_id(
        &self,
        _prediction_id: &str,
    ) -> Result<(), crate::orchestrator::OrchestratorError> {
        Ok(())
    }

    /// Always reports healthy.
    async fn healthcheck(
        &self,
    ) -> Result<HealthcheckResult, crate::orchestrator::OrchestratorError> {
        let healthy = HealthcheckResult::healthy();
        Ok(healthy)
    }

    /// Always succeeds.
    async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> {
        Ok(())
    }
}
async fn create_test_pool(num_slots: usize) -> Arc<PermitPool> {
use crate::bridge::codec::JsonCodec;
use crate::bridge::protocol::SlotRequest;
use futures::StreamExt;
use tokio::net::UnixStream;
let pool = Arc::new(PermitPool::new(num_slots));
for _ in 0..num_slots {
let (a, b) = UnixStream::pair().unwrap();
let (_read_a, write_a) = a.into_split();
let (read_b, _write_b) = b.into_split();
let mut reader =
tokio_util::codec::FramedRead::new(read_b, JsonCodec::<SlotRequest>::new());
tokio::spawn(async move { while reader.next().await.is_some() {} });
let writer =
tokio_util::codec::FramedWrite::new(write_a, JsonCodec::<SlotRequest>::new());
pool.add_permit(SlotId::new(), writer);
}
pool
}
/// Builds a pool with `num_slots` working permits and returns their slot ids.
///
/// Each permit writes into one end of a Unix socket pair, and a spawned task
/// reads and discards everything arriving on the other end.
async fn create_test_pool_with_slots(num_slots: usize) -> (Arc<PermitPool>, Vec<SlotId>) {
    use crate::bridge::codec::JsonCodec;
    use crate::bridge::protocol::SlotRequest;
    use futures::StreamExt;
    use tokio::net::UnixStream;

    let pool = Arc::new(PermitPool::new(num_slots));
    let mut ids = Vec::with_capacity(num_slots);

    for _ in 0..num_slots {
        let (near, far) = UnixStream::pair().unwrap();
        let (_near_read, near_write) = near.into_split();
        let (far_read, _far_write) = far.into_split();

        let mut drain =
            tokio_util::codec::FramedRead::new(far_read, JsonCodec::<SlotRequest>::new());
        tokio::spawn(async move { while drain.next().await.is_some() {} });

        let writer =
            tokio_util::codec::FramedWrite::new(near_write, JsonCodec::<SlotRequest>::new());
        let id = SlotId::new();
        pool.add_permit(id, writer);
        ids.push(id);
    }

    (pool, ids)
}
/// Builds a single-slot pool whose permit writes to a socket whose peer has
/// already been dropped.
async fn create_broken_test_pool() -> (Arc<PermitPool>, SlotId) {
    use crate::bridge::codec::JsonCodec;
    use crate::bridge::protocol::SlotRequest;
    use tokio::net::UnixStream;

    let pool = Arc::new(PermitPool::new(1));
    let (near, far) = UnixStream::pair().unwrap();
    let (_near_read, near_write) = near.into_split();
    // Drop the peer before any write happens.
    drop(far);

    let id = SlotId::new();
    let writer =
        tokio_util::codec::FramedWrite::new(near_write, JsonCodec::<SlotRequest>::new());
    pool.add_permit(id, writer);
    (pool, id)
}
#[tokio::test]
async fn service_new_no_pool_works() {
    // A fresh service reports Unknown health, zero slots and no pool.
    let service = PredictionService::new_no_pool();
    let snapshot = service.health().await;
    assert_eq!(snapshot.state, Health::Unknown);
    assert_eq!(snapshot.total_slots, 0);
    assert_eq!(snapshot.available_slots, 0);
    assert!(service.pool().await.is_none());
}
#[tokio::test]
async fn service_no_pool_initially() {
    // Neither a pool nor an orchestrator exists before set_orchestrator.
    let service = PredictionService::new_no_pool();
    assert!(service.pool().await.is_none());
    assert!(!service.has_orchestrator().await);
}
#[tokio::test]
async fn shutdown_signal_works() {
    // The shutdown receiver flips from false to true after trigger_shutdown.
    let service = PredictionService::new_no_pool();
    let mut receiver = service.shutdown_rx();
    assert!(!*receiver.borrow());

    service.trigger_shutdown();
    receiver.changed().await.unwrap();
    assert!(*receiver.borrow());
}
#[tokio::test]
async fn submit_fails_when_not_ready() {
    // Submission is rejected while health is not Ready.
    let service = PredictionService::new_no_pool();
    let outcome = service
        .submit_prediction("test".to_string(), serde_json::json!({}), None)
        .await;
    assert!(matches!(outcome, Err(CreatePredictionError::NotReady)));
}
#[tokio::test]
async fn cannot_set_ready_without_orchestrator() {
    // Neither the builder nor set_health can reach Ready without an orchestrator.
    let via_setter = PredictionService::new_no_pool();
    let via_builder = PredictionService::new_no_pool().with_health(Health::Ready);
    assert_eq!(via_builder.health().await.state, Health::Unknown);

    via_setter.set_health(Health::Ready).await;
    assert_eq!(via_setter.health().await.state, Health::Unknown);
}
#[tokio::test]
async fn set_orchestrator_enables_ready_health() {
    // With an orchestrator installed, Ready is accepted and slot counts show up.
    let service = PredictionService::new_no_pool();
    let pool = create_test_pool(2).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(pool, mock).await;
    assert!(service.has_orchestrator().await);

    service.set_health(Health::Ready).await;
    let snapshot = service.health().await;
    assert_eq!(snapshot.state, Health::Ready);
    assert_eq!(snapshot.total_slots, 2);
    assert_eq!(snapshot.available_slots, 2);
}
#[tokio::test]
async fn submit_prediction_succeeds_when_ready() {
    // A ready service with a free slot accepts and tracks the prediction.
    let service = PredictionService::new_no_pool();
    let pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(pool, mock).await;
    service.set_health(Health::Ready).await;

    let submitted = service
        .submit_prediction("test-1".to_string(), serde_json::json!({}), None)
        .await;
    let (handle, _slot) = submitted.unwrap();
    assert_eq!(handle.id(), "test-1");
    assert!(service.prediction_exists("test-1"));
}
#[tokio::test]
async fn submit_returns_at_capacity_when_no_slots() {
    // With the only slot taken, a second submission reports AtCapacity.
    let service = PredictionService::new_no_pool();
    let pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(pool, mock).await;
    service.set_health(Health::Ready).await;

    // Keep the first handle and slot alive so the permit stays in use.
    let (_first_handle, _first_slot) = service
        .submit_prediction("test-1".to_string(), serde_json::json!({}), None)
        .await
        .unwrap();

    let second = service
        .submit_prediction("test-2".to_string(), serde_json::json!({}), None)
        .await;
    assert!(matches!(second, Err(CreatePredictionError::AtCapacity)));
}
#[tokio::test]
async fn predict_calls_orchestrator_register() {
    // predict() registers the prediction with the orchestrator exactly once.
    let service = PredictionService::new_no_pool();
    let pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    let observer = Arc::clone(&mock);
    service.set_orchestrator(pool, mock).await;
    service.set_health(Health::Ready).await;

    let payload = serde_json::json!({"prompt": "hello"});
    let (_handle, slot) = service
        .submit_prediction("test-1".to_string(), payload.clone(), None)
        .await
        .unwrap();

    let result = service.predict(slot, payload, Default::default()).await;
    assert!(result.is_ok(), "predict failed: {:?}", result.err());
    assert_eq!(observer.register_count(), 1);
}
#[tokio::test]
async fn health_shows_busy_when_all_slots_used() {
    // The snapshot turns busy once the single slot is taken.
    let service = PredictionService::new_no_pool();
    let pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(pool, mock).await;
    service.set_health(Health::Ready).await;

    let before = service.health().await;
    assert!(!before.is_busy());
    assert_eq!(before.available_slots, 1);

    // Hold the slot for the rest of the test so the permit stays in use.
    let (_handle, _slot) = service
        .submit_prediction("test-1".to_string(), serde_json::json!({}), None)
        .await
        .unwrap();

    let after = service.health().await;
    assert!(after.is_busy());
    assert_eq!(after.available_slots, 0);
}
#[tokio::test]
async fn predict_idle_channel_closed_poison_slot_async() {
    // The default mock drops the idle sender without sending, so the slot is
    // expected to be poisoned in the background after predict() returns.
    let service = PredictionService::new_no_pool();
    let (pool, slot_ids) = create_test_pool_with_slots(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    let slot_id = slot_ids[0];
    service.set_orchestrator(Arc::clone(&pool), mock).await;
    service.set_health(Health::Ready).await;

    let payload = serde_json::json!({"prompt": "hello"});
    let (_handle, slot) = service
        .submit_prediction("test-1".to_string(), payload.clone(), None)
        .await
        .unwrap();

    let result = service.predict(slot, payload, Default::default()).await;
    assert!(result.is_ok(), "predict failed: {:?}", result.err());

    // Poll until the slot is reported poisoned, or give up after a second.
    let wait_for_poison = async {
        while !pool.is_poisoned(slot_id) {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    };
    tokio::time::timeout(Duration::from_secs(1), wait_for_poison)
        .await
        .expect("slot was not poisoned after idle token channel closed");
}
#[tokio::test]
async fn predict_idle_ack_returns_capacity_async() {
    // With an idle acknowledgement, the slot becomes available again.
    let service = PredictionService::new_no_pool();
    let pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new().with_idle_ack());
    service.set_orchestrator(Arc::clone(&pool), mock).await;
    service.set_health(Health::Ready).await;

    let payload = serde_json::json!({"prompt": "hello"});
    let (_handle, slot) = service
        .submit_prediction("test-1".to_string(), payload.clone(), None)
        .await
        .unwrap();

    let result = service.predict(slot, payload, Default::default()).await;
    assert!(result.is_ok(), "predict failed: {:?}", result.err());

    // Poll until capacity is back, or give up after a second.
    let wait_for_capacity = async {
        while pool.available() != 1 {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    };
    tokio::time::timeout(Duration::from_secs(1), wait_for_capacity)
        .await
        .expect("slot capacity was not returned after idle acknowledgement");
}
#[tokio::test]
async fn predict_send_failure_poison_slot() {
    // With the broken pool, predict() fails and the slot is poisoned and not reusable.
    let service = PredictionService::new_no_pool();
    let (pool, slot_id) = create_broken_test_pool().await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(Arc::clone(&pool), mock).await;
    service.set_health(Health::Ready).await;

    let payload = serde_json::json!({"prompt": "hello"});
    let (_handle, slot) = service
        .submit_prediction("test-1".to_string(), payload.clone(), None)
        .await
        .unwrap();

    let result = service.predict(slot, payload, Default::default()).await;
    assert!(matches!(result, Err(PredictionError::Failed(_))));
    assert!(pool.is_poisoned(slot_id));
    assert!(pool.try_acquire().is_none());
}
#[tokio::test]
async fn cancel_prediction_works() {
    // cancel() returns true for a tracked id and triggers its token.
    let service = PredictionService::new_no_pool();
    let pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(pool, mock).await;
    service.set_health(Health::Ready).await;

    let (handle, _slot) = service
        .submit_prediction("test-cancel".to_string(), serde_json::json!({}), None)
        .await
        .unwrap();
    let token = handle.cancel_token();

    let was_cancelled = service.cancel("test-cancel");
    assert!(was_cancelled);
    assert!(token.is_cancelled());
}
#[tokio::test]
async fn cancel_nonexistent_returns_false() {
    // Unknown ids are reported as not cancelled.
    let service = PredictionService::new_no_pool();
    let was_cancelled = service.cancel("nonexistent");
    assert!(!was_cancelled);
}
#[tokio::test]
async fn sync_guard_cancels_on_drop() {
    // Dropping an armed guard cancels the prediction.
    let service = Arc::new(PredictionService::new_no_pool());
    let pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(pool, mock).await;
    service.set_health(Health::Ready).await;

    let (handle, _slot) = service
        .submit_prediction("test-guard".to_string(), serde_json::json!({}), None)
        .await
        .unwrap();
    let token = handle.cancel_token();

    let guard = handle.sync_guard(Arc::clone(&service));
    assert!(!token.is_cancelled());
    drop(guard);
    assert!(token.is_cancelled());
}
#[tokio::test]
async fn sync_guard_disarm_prevents_cancel() {
    // A disarmed guard leaves the prediction untouched when dropped.
    let service = Arc::new(PredictionService::new_no_pool());
    let pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(pool, mock).await;
    service.set_health(Health::Ready).await;

    let (handle, _slot) = service
        .submit_prediction("test-disarm".to_string(), serde_json::json!({}), None)
        .await
        .unwrap();
    let token = handle.cancel_token();

    let mut guard = handle.sync_guard(Arc::clone(&service));
    guard.disarm();
    drop(guard);
    assert!(!token.is_cancelled());
}
#[tokio::test]
async fn remove_prediction_cleans_up() {
    // remove_prediction stops tracking the id.
    let service = PredictionService::new_no_pool();
    let pool = create_test_pool(1).await;
    let mock = Arc::new(MockOrchestrator::new());
    service.set_orchestrator(pool, mock).await;
    service.set_health(Health::Ready).await;

    let (_handle, _slot) = service
        .submit_prediction("test-remove".to_string(), serde_json::json!({}), None)
        .await
        .unwrap();
    assert!(service.prediction_exists("test-remove"));

    service.remove_prediction("test-remove");
    assert!(!service.prediction_exists("test-remove"));
}
#[test]
fn build_slot_request_small_input_inline() {
    // A small input travels inline and no spill file is referenced.
    let scratch = tempfile::tempdir().unwrap();
    let input = serde_json::json!({"text": "hello"});
    let request = build_slot_request(
        "p1".into(),
        input.clone(),
        "/tmp/out".into(),
        scratch.path(),
        Default::default(),
    )
    .unwrap();

    let SlotRequest::Predict {
        id,
        input: Some(inline),
        input_file: None,
        output_dir,
        ..
    } = request
    else {
        panic!("expected inline input");
    };
    assert_eq!(id, "p1");
    assert_eq!(inline, input);
    assert_eq!(output_dir, "/tmp/out");
}
#[test]
fn build_slot_request_large_input_spills() {
    // A 7 MiB string payload is written to a spill file, which must
    // round-trip to the original JSON value.
    let scratch = tempfile::tempdir().unwrap();
    let big = "x".repeat(7 * 1024 * 1024);
    let input = serde_json::json!({"data": big});
    let request = build_slot_request(
        "p2".into(),
        input.clone(),
        "/tmp/out".into(),
        scratch.path(),
        Default::default(),
    )
    .unwrap();

    let SlotRequest::Predict {
        id,
        input: None,
        input_file: Some(spill_path),
        output_dir,
        ..
    } = request
    else {
        panic!("expected file-backed input");
    };
    assert_eq!(id, "p2");
    assert_eq!(output_dir, "/tmp/out");
    assert!(std::path::Path::new(&spill_path).exists());

    let bytes = std::fs::read(&spill_path).unwrap();
    let roundtrip: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    assert_eq!(roundtrip, input);
}
#[test]
fn build_slot_request_roundtrip() {
    // rehydrate_input() on a request built from a 7 MiB payload returns the original value.
    let scratch = tempfile::tempdir().unwrap();
    let big = "y".repeat(7 * 1024 * 1024);
    let input = serde_json::json!({"payload": big});
    let request = build_slot_request(
        "p3".into(),
        input.clone(),
        "/tmp/out".into(),
        scratch.path(),
        Default::default(),
    )
    .unwrap();

    let (id, rehydrated, output_dir, _context) = request.rehydrate_input().unwrap();
    assert_eq!(id, "p3");
    assert_eq!(rehydrated, input);
    assert_eq!(output_dir, "/tmp/out");
}
}