use std::convert::Infallible;
use std::sync::Arc;
use std::time::Duration;
use axum::{
Router,
extract::{DefaultBodyLimit, Path, State},
http::{HeaderMap, StatusCode},
response::{
IntoResponse, Json, Response,
sse::{Event, KeepAlive, Sse},
},
routing::{get, post, put},
};
use serde::{Deserialize, Serialize};
#[cfg(test)]
use crate::health::Health;
use crate::health::{HealthResponse, SetupResult};
use crate::prediction::SharedPredictionStreamEvent;
use crate::predictor::PredictionError;
use crate::service::{
CreatePredictionError, HealthSnapshot, PredictionService, PredictionStreamSubscription,
SubscribePredictionStreamError,
};
use crate::version::VersionInfo;
use crate::webhook::{TraceContext, WebhookConfig, WebhookEventType, WebhookSender};
/// JSON body returned by the `/health-check` endpoint.
#[derive(Debug, Serialize)]
pub struct HealthCheckResponse {
/// Overall status reported to the caller.
pub status: HealthResponse,
/// Setup result copied from the health snapshot; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub setup: Option<SetupResult>,
/// Version information copied from the health snapshot.
pub version: VersionInfo,
/// Error produced by the user-defined healthcheck; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub user_healthcheck_error: Option<String>,
}
impl HealthCheckResponse {
    /// Builds the response body from a health snapshot and the outcome of the
    /// user-defined healthcheck.
    ///
    /// Status precedence: a user healthcheck error yields `Unhealthy`; otherwise a
    /// busy snapshot yields `Busy`; otherwise the snapshot's own state is converted.
    pub fn from_snapshot(snapshot: HealthSnapshot, user_healthcheck_error: Option<String>) -> Self {
        let status = match user_healthcheck_error {
            Some(_) => HealthResponse::Unhealthy,
            None if snapshot.is_busy() => HealthResponse::Busy,
            None => snapshot.state.into(),
        };
        let setup = snapshot.setup_result;
        let version = snapshot.version;
        Self {
            status,
            setup,
            version,
            user_healthcheck_error,
        }
    }
}
/// Request body accepted by the prediction and training creation endpoints.
#[derive(Debug, Deserialize)]
pub struct PredictionRequest {
/// Optional caller-supplied ID; the non-idempotent handlers generate one when absent.
pub id: Option<String>,
/// Model input. A missing field or an explicit JSON `null` both become `{}`.
#[serde(
default = "default_empty_input",
deserialize_with = "deserialize_input"
)]
pub input: serde_json::Value,
/// String key/value context forwarded with the prediction; empty when omitted.
#[serde(default)]
pub context: std::collections::HashMap<String, String>,
/// Optional webhook URL; when present a webhook sender is built for the prediction.
pub webhook: Option<String>,
/// Webhook event types to deliver; defaults to start, output, logs and completed.
#[serde(default = "default_webhook_events_filter")]
pub webhook_events_filter: Vec<WebhookEventType>,
}
/// Serde default for `PredictionRequest::input`: an empty JSON object.
fn default_empty_input() -> serde_json::Value {
    serde_json::Value::Object(serde_json::Map::new())
}
/// Deserializes the `input` field, mapping an explicit JSON `null` to `{}`.
///
/// Any other JSON value is passed through unchanged.
fn deserialize_input<'de, D>(deserializer: D) -> Result<serde_json::Value, D::Error>
where
    D: serde::Deserializer<'de>,
{
    match serde_json::Value::deserialize(deserializer)? {
        serde_json::Value::Null => Ok(serde_json::json!({})),
        other => Ok(other),
    }
}
/// Serde default for `PredictionRequest::webhook_events_filter`:
/// start, output, logs and completed, in that order.
fn default_webhook_events_filter() -> Vec<WebhookEventType> {
    Vec::from([
        WebhookEventType::Start,
        WebhookEventType::Output,
        WebhookEventType::Logs,
        WebhookEventType::Completed,
    ])
}
/// Generates a prediction ID of the form `pred_<hex nanos>_<hex sequence>`.
///
/// The timestamp alone is not unique: two requests handled within the same clock
/// tick (clock resolution is platform dependent) would otherwise receive the same
/// ID. A process-wide sequence number is appended so IDs generated by this process
/// never collide.
///
/// A system clock set before the Unix epoch no longer panics the request handler;
/// the timestamp part falls back to `0` and the sequence still keeps IDs unique.
fn generate_prediction_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Only uniqueness matters here, so relaxed ordering is sufficient.
    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or_default();
    let sequence = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    format!("pred_{timestamp:x}_{sequence:x}")
}
/// `GET /`: returns a JSON index of the server's endpoints and the reported version.
///
/// `cog_version` is the Python SDK version when known, otherwise the coglet version.
/// The three `trainings_*` URLs are only advertised when the service supports training.
async fn root(State(service): State<Arc<PredictionService>>) -> Json<serde_json::Value> {
    let version = service.version();
    let cog_version = version.python_sdk.as_deref().unwrap_or(version.coglet);
    let mut doc = serde_json::json!({
        "cog_version": cog_version,
        "docs_url": "/docs",
        "openapi_url": "/openapi.json",
        "shutdown_url": "/shutdown",
        "healthcheck_url": "/health-check",
        "predictions_url": "/predictions",
        "predictions_idempotent_url": "/predictions/{prediction_id}",
        "predictions_cancel_url": "/predictions/{prediction_id}/cancel",
    });

    if service.supports_training().await {
        let endpoints = doc.as_object_mut().expect("doc is an object");
        let training_urls = [
            ("trainings_url", "/trainings"),
            ("trainings_idempotent_url", "/trainings/{training_id}"),
            ("trainings_cancel_url", "/trainings/{training_id}/cancel"),
        ];
        for (key, url) in training_urls {
            endpoints.insert(key.to_string(), serde_json::Value::from(url));
        }
    }

    Json(doc)
}
async fn health_check(State(service): State<Arc<PredictionService>>) -> Json<HealthCheckResponse> {
tracing::trace!("Health check endpoint called");
let snapshot = service.health().await;
tracing::trace!(
state = ?snapshot.state,
available_slots = snapshot.available_slots,
total_slots = snapshot.total_slots,
has_setup_result = snapshot.setup_result.is_some(),
"Health snapshot retrieved"
);
let user_healthcheck_error = if snapshot.is_ready() {
write_readiness_file();
tracing::trace!("Running user-defined healthcheck");
match service.healthcheck().await {
Ok(result) if result.is_healthy() => {
tracing::trace!("User healthcheck passed");
None
}
Ok(result) => {
tracing::debug!(error = ?result.error, "User healthcheck reported unhealthy");
result.error
}
Err(e) => {
tracing::debug!(error = %e, "User healthcheck returned error");
Some(format!("Healthcheck error: {}", e))
}
}
} else {
tracing::trace!(state = ?snapshot.state, "Skipping user healthcheck (not ready)");
None
};
let response = HealthCheckResponse::from_snapshot(snapshot, user_healthcheck_error);
tracing::trace!(status = ?response.status, "Health check response");
Json(response)
}
/// Creates the empty marker file `/var/run/cog/ready` when running under Kubernetes.
///
/// Does nothing unless `KUBERNETES_SERVICE_HOST` is set, and nothing if the marker
/// already exists. Failures are logged as warnings and otherwise ignored.
fn write_readiness_file() {
    let in_kubernetes = std::env::var("KUBERNETES_SERVICE_HOST").is_ok();
    if !in_kubernetes {
        return;
    }

    let ready_dir = std::path::Path::new("/var/run/cog");
    let marker = ready_dir.join("ready");
    if marker.exists() {
        return;
    }

    match std::fs::create_dir_all(ready_dir) {
        Err(e) => {
            tracing::warn!(error = %e, "Failed to create /var/run/cog directory");
        }
        Ok(()) => {
            if let Err(e) = std::fs::write(&marker, b"") {
                tracing::warn!(error = %e, "Failed to write readiness file");
            }
        }
    }
}
/// Returns `true` when the request carries the `respond-async` preference.
///
/// The `Prefer` header may list several comma-separated preferences (for example
/// `respond-async, wait=10`), may be split across multiple header lines, and its
/// preference tokens are case-insensitive. Each preference is therefore examined
/// individually instead of comparing the whole header value to a fixed string.
/// Header values that are not valid visible ASCII are skipped.
fn should_respond_async(headers: &HeaderMap) -> bool {
    headers
        .get_all("prefer")
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|preference| {
            // Drop any `;parameter` suffix before comparing the token itself.
            preference
                .split(';')
                .next()
                .is_some_and(|token| token.trim().eq_ignore_ascii_case("respond-async"))
        })
}
/// How a create-prediction/training request is answered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PredictionResponseMode {
/// Wait for the prediction to finish and return its result as JSON.
SyncJson,
/// Start the prediction in the background and immediately return a JSON acknowledgement.
AsyncJson,
/// Start the prediction in the background and stream its events as Server-Sent Events.
AsyncSse,
}
/// Returns `true` when the request's `Accept` header lists `text/event-stream`.
///
/// Media types are case-insensitive and may be followed by optional whitespace and
/// `;parameters`, and `Accept` may be split across several header lines, so every
/// media range of every `Accept` value is checked after trimming and with an
/// ASCII case-insensitive comparison. Quality values (`q=`) are not interpreted.
/// Header values that are not valid visible ASCII are skipped.
fn wants_sse(headers: &HeaderMap) -> bool {
    headers
        .get_all(axum::http::header::ACCEPT)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|accept| accept.split(','))
        .any(|media_range| {
            media_range
                .split(';')
                .next()
                .is_some_and(|media_type| {
                    media_type.trim().eq_ignore_ascii_case("text/event-stream")
                })
        })
}
/// Picks the response mode for prediction endpoints from the request headers.
///
/// An SSE `Accept` header takes precedence over `Prefer: respond-async`;
/// with neither present the request is handled synchronously.
fn prediction_response_mode(headers: &HeaderMap) -> PredictionResponseMode {
    if wants_sse(headers) {
        return PredictionResponseMode::AsyncSse;
    }
    if should_respond_async(headers) {
        return PredictionResponseMode::AsyncJson;
    }
    PredictionResponseMode::SyncJson
}
/// Builds the `406 Not Acceptable` response returned when SSE is requested for a
/// model that has not opted in to streaming.
fn streaming_not_supported_response() -> Response {
    let body = serde_json::json!({
        "error": "This model does not support streaming responses. Add @cog.streaming to predict() to enable SSE."
    });
    (StatusCode::NOT_ACCEPTABLE, Json(body)).into_response()
}
/// Builds the `406 Not Acceptable` response returned when SSE is requested on a
/// training endpoint.
fn training_streaming_not_supported_response() -> Response {
    let body = serde_json::json!({
        "error": "Training endpoints do not support streaming responses."
    });
    (StatusCode::NOT_ACCEPTABLE, Json(body)).into_response()
}
/// Picks between the two JSON response modes, ignoring any SSE `Accept` header:
/// `AsyncJson` when `respond-async` is requested, `SyncJson` otherwise.
fn json_response_mode(headers: &HeaderMap) -> PredictionResponseMode {
    match should_respond_async(headers) {
        true => PredictionResponseMode::AsyncJson,
        false => PredictionResponseMode::SyncJson,
    }
}
/// Copies the `traceparent` and `tracestate` request headers into a `TraceContext`.
///
/// A header that is missing or cannot be read as a string yields `None`.
fn extract_trace_context(headers: &HeaderMap) -> TraceContext {
    let header_text = |name: &str| -> Option<String> {
        let value = headers.get(name)?;
        value.to_str().ok().map(str::to_owned)
    };
    let traceparent = header_text("traceparent");
    let tracestate = header_text("tracestate");
    TraceContext {
        traceparent,
        tracestate,
    }
}
/// `POST /predictions`: creates a prediction.
///
/// A missing body is treated as a request with empty input and default webhook
/// settings. The ID comes from the body when present, otherwise one is generated.
/// The response mode (sync JSON, async JSON or SSE) is derived from the headers.
async fn create_prediction(
    State(service): State<Arc<PredictionService>>,
    headers: HeaderMap,
    body: Option<Json<PredictionRequest>>,
) -> Response {
    let PredictionRequest {
        id,
        input,
        context,
        webhook,
        webhook_events_filter,
    } = match body {
        Some(Json(parsed)) => parsed,
        None => PredictionRequest {
            id: None,
            input: serde_json::json!({}),
            context: Default::default(),
            webhook: None,
            webhook_events_filter: default_webhook_events_filter(),
        },
    };

    let prediction_id = id.unwrap_or_else(generate_prediction_id);
    let response_mode = prediction_response_mode(&headers);
    let trace_context = extract_trace_context(&headers);

    create_prediction_with_id(
        service,
        prediction_id,
        input,
        context,
        webhook,
        webhook_events_filter,
        response_mode,
        trace_context,
        false,
    )
    .await
}
/// `PUT /predictions/{id}`: creates a prediction under a caller-chosen ID.
///
/// * A body `id` that differs from the URL ID is rejected with `422`.
/// * If a prediction with this ID is already known, it is not created again:
///   an SSE request attaches to its event stream (or gets `406` when the model
///   does not support streaming), any other request receives `202` with the
///   current prediction response.
/// * Otherwise the prediction is created like `POST /predictions`.
async fn create_prediction_idempotent(
    State(service): State<Arc<PredictionService>>,
    Path(prediction_id): Path<String>,
    headers: HeaderMap,
    body: Option<Json<PredictionRequest>>,
) -> Response {
    let request = match body {
        Some(Json(parsed)) => parsed,
        None => PredictionRequest {
            id: None,
            input: serde_json::json!({}),
            context: Default::default(),
            webhook: None,
            webhook_events_filter: default_webhook_events_filter(),
        },
    };

    let id_mismatch = request
        .id
        .as_deref()
        .is_some_and(|req_id| req_id != prediction_id.as_str());
    if id_mismatch {
        let body = serde_json::json!({
            "detail": [{
                "loc": ["body", "id"],
                "msg": "prediction ID must match the ID supplied in the URL",
                "type": "value_error"
            }]
        });
        return (StatusCode::UNPROCESSABLE_ENTITY, Json(body)).into_response();
    }

    let response_mode = prediction_response_mode(&headers);

    if let Some(existing) = service.get_prediction_response(&prediction_id) {
        if response_mode != PredictionResponseMode::AsyncSse {
            return (StatusCode::ACCEPTED, Json(existing)).into_response();
        }
        if !service.supports_prediction_streaming().await {
            return streaming_not_supported_response();
        }
        return stream_prediction_response(service, &prediction_id);
    }

    let trace_context = extract_trace_context(&headers);
    create_prediction_with_id(
        service,
        prediction_id,
        request.input,
        request.context,
        request.webhook,
        request.webhook_events_filter,
        response_mode,
        trace_context,
        false,
    )
    .await
}
/// Builds the webhook sender for a request, if a webhook URL was supplied.
///
/// Returns `None` when no URL is given. If the sender cannot be constructed the
/// error is logged and `None` is returned, so the prediction proceeds without
/// webhooks rather than failing.
fn build_webhook_sender(
    webhook: Option<String>,
    events_filter: Vec<WebhookEventType>,
    trace_context: TraceContext,
) -> Option<WebhookSender> {
    let Some(webhook_url) = webhook else {
        return None;
    };

    let allowed_events: std::collections::HashSet<_> = events_filter.into_iter().collect();
    let config = WebhookConfig {
        events_filter: allowed_events,
        ..Default::default()
    };

    WebhookSender::with_trace_context(webhook_url.clone(), config, trace_context)
        .inspect_err(|e| {
            tracing::error!(url = %webhook_url, error = %e, "Failed to create webhook sender");
        })
        .ok()
}
/// Shared implementation behind the prediction and training creation endpoints.
///
/// Steps, in order:
/// 1. For predictions requested as SSE, reject with `406` when the model does not
///    support streaming.
/// 2. Strip unknown input fields (logged as a warning) and validate the input;
///    validation errors produce `422` with a `detail` list.
/// 3. Build the optional webhook sender and submit the prediction. `NotReady`
///    maps to `503`, `AtCapacity` to `409`.
/// 4. For the async modes, run the prediction in a background task and return
///    either the SSE stream or `202` with `status: "starting"`.
/// 5. For the sync mode, run the prediction in a background task, wait for its
///    result and translate it into the final JSON response.
///
/// `is_training` selects the training input validator and disables the streaming
/// capability check (training callers never pass `AsyncSse`).
// The argument list mirrors the fields of the request plus per-request settings.
#[allow(clippy::too_many_arguments)]
async fn create_prediction_with_id(
    service: Arc<PredictionService>,
    prediction_id: String,
    mut input: serde_json::Value,
    context: std::collections::HashMap<String, String>,
    webhook: Option<String>,
    webhook_events_filter: Vec<WebhookEventType>,
    response_mode: PredictionResponseMode,
    trace_context: TraceContext,
    is_training: bool,
) -> Response {
    if !is_training
        && response_mode == PredictionResponseMode::AsyncSse
        && !service.supports_prediction_streaming().await
    {
        return streaming_not_supported_response();
    }

    let (stripped, validation_result) = if is_training {
        service.strip_and_validate_train_input(&mut input).await
    } else {
        service.strip_and_validate_input(&mut input).await
    };
    if !stripped.is_empty() {
        tracing::warn!(
            prediction_id = %prediction_id,
            fields = ?stripped,
            "Stripped unknown input fields"
        );
    }
    if let Err(errors) = validation_result {
        let detail: Vec<serde_json::Value> = errors
            .into_iter()
            .map(|e| {
                serde_json::json!({
                    "loc": ["body", "input", e.field],
                    "msg": e.msg,
                    "type": e.error_type
                })
            })
            .collect();
        return (
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(serde_json::json!({ "detail": detail })),
        )
            .into_response();
    }

    // These arguments are owned and not used again, so they are moved, not cloned.
    let webhook_sender = build_webhook_sender(webhook, webhook_events_filter, trace_context);

    let (handle, unregistered_slot) = match service
        .submit_prediction(
            prediction_id.clone(),
            // `input` is still needed for `predict` below, hence the clone here.
            input.clone(),
            webhook_sender,
            response_mode == PredictionResponseMode::AsyncSse,
        )
        .await
    {
        Ok(r) => r,
        Err(CreatePredictionError::NotReady) => {
            let msg = PredictionError::NotReady.to_string();
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                Json(serde_json::json!({
                    "error": msg,
                    "status": "failed"
                })),
            )
                .into_response();
        }
        Err(CreatePredictionError::AtCapacity) => {
            return (
                StatusCode::CONFLICT,
                Json(serde_json::json!({
                    "error": "At capacity - all prediction slots busy",
                    "status": "failed"
                })),
            )
                .into_response();
        }
    };

    // Grab the shared prediction before the slot is moved into a background task;
    // the sync path reads timing and metrics from it afterwards.
    let prediction = unregistered_slot.prediction();

    if response_mode != PredictionResponseMode::SyncJson {
        // Subscribe before the prediction starts running; on failure the
        // just-submitted prediction is removed again.
        let sse_subscription = if response_mode == PredictionResponseMode::AsyncSse {
            match service.subscribe_prediction_stream(&prediction_id) {
                Ok(subscription) => Some(subscription),
                Err(error) => {
                    service.remove_prediction(&prediction_id);
                    return stream_subscription_error_response(error);
                }
            }
        } else {
            None
        };

        let service_clone = Arc::clone(&service);
        let id_for_cleanup = prediction_id.clone();
        // This branch always returns, so `input` and `context` can be moved here.
        tokio::spawn(async move {
            let _result = service_clone
                .predict(unregistered_slot, input, context)
                .await;
            service_clone.remove_prediction(&id_for_cleanup);
        });

        if response_mode == PredictionResponseMode::AsyncSse {
            let subscription = sse_subscription.expect("SSE subscription requested");
            return stream_prediction_subscription_response(subscription);
        }
        return (
            StatusCode::ACCEPTED,
            Json(serde_json::json!({
                "id": prediction_id,
                "status": "starting"
            })),
        )
            .into_response();
    }

    // NOTE(review): the guard is armed while this handler waits and disarmed once
    // the result is in; presumably it cleans up if the handler future is dropped
    // early (e.g. client disconnect) — confirm against `sync_guard`.
    let mut sync_guard = handle.sync_guard(Arc::clone(&service));

    let service_bg = Arc::clone(&service);
    let id_bg = prediction_id.clone();
    let result_rx = {
        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let result = service_bg.predict(unregistered_slot, input, context).await;
            service_bg.remove_prediction(&id_bg);
            let _ = tx.send(result);
        });
        rx
    };

    // A receive error means the task ended without sending (e.g. it panicked).
    let result = result_rx
        .await
        .unwrap_or_else(|_| Err(PredictionError::Failed("prediction task lost".to_string())));

    // Best effort: if the prediction is currently locked, fall back to defaults
    // (zero elapsed time, no user metrics) instead of blocking.
    let (predict_time, user_metrics) = prediction
        .try_lock()
        .map(|p| (p.elapsed().as_secs_f64(), p.metrics().clone()))
        .unwrap_or_default();
    sync_guard.disarm();

    // Merges user metrics with `predict_time`; `predict_time` wins on key clash.
    let build_metrics = |user_metrics: &std::collections::HashMap<String, serde_json::Value>| {
        let mut m = serde_json::Map::new();
        for (k, v) in user_metrics {
            m.insert(k.clone(), v.clone());
        }
        m.insert("predict_time".to_string(), serde_json::json!(predict_time));
        serde_json::Value::Object(m)
    };

    match result {
        Ok(r) => {
            let metrics = build_metrics(&r.metrics);
            (
                StatusCode::OK,
                Json(serde_json::json!({
                    "id": prediction_id,
                    "output": r.output,
                    "logs": r.logs,
                    "status": "succeeded",
                    "metrics": metrics
                })),
            )
                .into_response()
        }
        Err(PredictionError::InvalidInput(msg)) => {
            let metrics = build_metrics(&user_metrics);
            (
                StatusCode::UNPROCESSABLE_ENTITY,
                Json(serde_json::json!({
                    "id": prediction_id,
                    "error": msg,
                    "logs": "",
                    "status": "failed",
                    "metrics": metrics
                })),
            )
                .into_response()
        }
        Err(PredictionError::NotReady) => {
            let msg = PredictionError::NotReady.to_string();
            (
                StatusCode::SERVICE_UNAVAILABLE,
                Json(serde_json::json!({
                    "id": prediction_id,
                    "error": msg,
                    "logs": "",
                    "status": "failed"
                })),
            )
                .into_response()
        }
        // A failed or canceled prediction is still a successfully handled request,
        // so both are reported with `200` and a terminal `status` in the body.
        Err(PredictionError::Failed(msg)) => {
            let metrics = build_metrics(&user_metrics);
            (
                StatusCode::OK,
                Json(serde_json::json!({
                    "id": prediction_id,
                    "error": msg,
                    "logs": "",
                    "status": "failed",
                    "metrics": metrics
                })),
            )
                .into_response()
        }
        Err(PredictionError::Cancelled) => {
            let metrics = build_metrics(&user_metrics);
            (
                StatusCode::OK,
                Json(serde_json::json!({
                    "id": prediction_id,
                    "logs": "",
                    "status": "canceled",
                    "metrics": metrics
                })),
            )
                .into_response()
        }
    }
}
/// `POST /predictions/{id}/cancel`: requests cancellation of a prediction.
///
/// Responds with `200` when the service reports the cancellation as accepted and
/// `404` otherwise; the body is an empty JSON object in both cases.
async fn cancel_prediction(
    State(service): State<Arc<PredictionService>>,
    Path(prediction_id): Path<String>,
) -> impl IntoResponse {
    let status = if service.cancel(&prediction_id) {
        StatusCode::OK
    } else {
        StatusCode::NOT_FOUND
    };
    (status, Json(serde_json::json!({})))
}
/// Converts a prediction stream event into an SSE event carrying the same event
/// name and a JSON-encoded payload.
///
/// # Panics
/// Panics if the event payload cannot be serialized to JSON.
fn stream_event_to_sse(event: SharedPredictionStreamEvent) -> Event {
    let named = Event::default().event(event.event_name());
    let payload = event.json_data();
    named
        .json_data(payload)
        .expect("prediction stream events serialize to JSON")
}
/// Turns a prediction stream subscription into a stream of SSE events.
///
/// Emission order, evaluated on every poll:
/// 1. If the subscription reports skipped replay events, emit a single `error`
///    event describing the truncation and end the stream (no replayed events are
///    sent in that case, because this check runs before the replay queue is read).
/// 2. Otherwise drain the replay queue front to back.
/// 3. Then forward live events from the broadcast receiver.
///
/// The stream ends after a `completed` event, after a lag `error` event, or when
/// the broadcast channel is closed. Items are never `Err` (`Infallible`).
fn prediction_sse_stream(
subscription: PredictionStreamSubscription,
) -> impl futures::Stream<Item = Result<Event, Infallible>> {
let (replay, replay_skipped, receiver, guard) = subscription.into_parts();
// State threaded through `unfold` between polls.
struct StreamState {
replay: std::collections::VecDeque<SharedPredictionStreamEvent>,
replay_skipped: u64,
receiver: tokio::sync::broadcast::Receiver<SharedPredictionStreamEvent>,
// Never read; stored so the guard lives exactly as long as the stream state.
_guard: crate::service::PredictionStreamGuard,
// Set when the event just emitted is the last one; the next poll returns `None`.
done: bool,
}
futures::stream::unfold(
StreamState {
replay,
replay_skipped,
receiver,
_guard: guard,
done: false,
},
|mut state| async move {
if state.done {
return None;
}
// Replay truncation: report it once and terminate the stream.
if state.replay_skipped > 0 {
let skipped = state.replay_skipped;
state.replay_skipped = 0;
state.done = true;
let event = Event::default()
.event("error")
.json_data(serde_json::json!({
"error": "SSE stream replay truncated; events were dropped",
"skipped": skipped,
}))
.expect("SSE replay truncation error serializes to JSON");
return Some((Ok(event), state));
}
// Replayed events come before any live event.
if let Some(event) = state.replay.pop_front() {
state.done = event.event_name() == "completed";
return Some((Ok(stream_event_to_sse(event)), state));
}
match state.receiver.recv().await {
Ok(event) => {
state.done = event.event_name() == "completed";
Some((Ok(stream_event_to_sse(event)), state))
}
// The receiver fell behind and events were dropped: tell the client and stop.
Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
tracing::warn!(skipped, "SSE prediction stream receiver lagged");
state.done = true;
let event = Event::default()
.event("error")
.json_data(serde_json::json!({
"error": "SSE stream lagged; events were dropped",
"skipped": skipped,
}))
.expect("SSE lag error serializes to JSON");
Some((Ok(event), state))
}
// All senders are gone: end the stream without a further event.
Err(tokio::sync::broadcast::error::RecvError::Closed) => None,
}
},
)
}
fn stream_prediction_response(service: Arc<PredictionService>, prediction_id: &str) -> Response {
let subscription = match service.subscribe_prediction_stream(prediction_id) {
Ok(subscription) => subscription,
Err(error) => return stream_subscription_error_response(error),
};
stream_prediction_subscription_response(subscription)
}
fn stream_subscription_error_response(error: SubscribePredictionStreamError) -> Response {
match error {
SubscribePredictionStreamError::NotFound => (
StatusCode::NOT_FOUND,
Json(serde_json::json!({"error": "Prediction not found"})),
)
.into_response(),
SubscribePredictionStreamError::TooManySubscribers => (
StatusCode::TOO_MANY_REQUESTS,
Json(serde_json::json!({"error": "Too many stream subscribers"})),
)
.into_response(),
SubscribePredictionStreamError::Unavailable => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error": "Prediction stream unavailable"})),
)
.into_response(),
}
}
/// Wraps a subscription in an SSE response that sends a `keep-alive` comment
/// every 15 seconds.
fn stream_prediction_subscription_response(subscription: PredictionStreamSubscription) -> Response {
    let keep_alive = KeepAlive::new()
        .interval(Duration::from_secs(15))
        .text("keep-alive");
    let events = prediction_sse_stream(subscription);
    Sse::new(events).keep_alive(keep_alive).into_response()
}
/// `POST /shutdown`: logs the request, triggers service shutdown and replies
/// with `200` and an empty JSON object.
async fn shutdown(State(service): State<Arc<PredictionService>>) -> impl IntoResponse {
    tracing::info!("Shutdown requested via HTTP");
    service.trigger_shutdown();
    let empty_body = serde_json::json!({});
    (StatusCode::OK, Json(empty_body))
}
/// `GET /openapi.json`: returns the schema held by the service, or `503` with an
/// error body when no schema is available.
async fn openapi_schema(State(service): State<Arc<PredictionService>>) -> impl IntoResponse {
    let Some(schema) = service.schema().await else {
        let unavailable = serde_json::json!({
            "error": "OpenAPI schema not available"
        });
        return (StatusCode::SERVICE_UNAVAILABLE, Json(unavailable));
    };
    (StatusCode::OK, Json(schema))
}
/// `POST /trainings`: creates a training.
///
/// SSE requests are rejected with `406`. A missing body is treated as a request
/// with empty input and default webhook settings; the ID comes from the body or
/// is generated. Only the JSON response modes (sync or async) are used.
async fn create_training(
    State(service): State<Arc<PredictionService>>,
    headers: HeaderMap,
    body: Option<Json<PredictionRequest>>,
) -> Response {
    if wants_sse(&headers) {
        return training_streaming_not_supported_response();
    }

    let PredictionRequest {
        id,
        input,
        context,
        webhook,
        webhook_events_filter,
    } = match body {
        Some(Json(parsed)) => parsed,
        None => PredictionRequest {
            id: None,
            input: serde_json::json!({}),
            context: Default::default(),
            webhook: None,
            webhook_events_filter: default_webhook_events_filter(),
        },
    };

    let training_id = id.unwrap_or_else(generate_prediction_id);
    let response_mode = json_response_mode(&headers);
    let trace_context = extract_trace_context(&headers);

    create_prediction_with_id(
        service,
        training_id,
        input,
        context,
        webhook,
        webhook_events_filter,
        response_mode,
        trace_context,
        true,
    )
    .await
}
/// `PUT /trainings/{id}`: creates a training under a caller-chosen ID.
///
/// * SSE requests are rejected with `406`.
/// * A body `id` that differs from the URL ID is rejected with `422`.
/// * If a training with this ID is already known, `202` with its current
///   response is returned and nothing new is started.
/// * Otherwise the training is created like `POST /trainings`.
async fn create_training_idempotent(
    State(service): State<Arc<PredictionService>>,
    Path(training_id): Path<String>,
    headers: HeaderMap,
    body: Option<Json<PredictionRequest>>,
) -> Response {
    if wants_sse(&headers) {
        return training_streaming_not_supported_response();
    }

    let request = match body {
        Some(Json(parsed)) => parsed,
        None => PredictionRequest {
            id: None,
            input: serde_json::json!({}),
            context: Default::default(),
            webhook: None,
            webhook_events_filter: default_webhook_events_filter(),
        },
    };

    let id_mismatch = request
        .id
        .as_deref()
        .is_some_and(|req_id| req_id != training_id.as_str());
    if id_mismatch {
        let body = serde_json::json!({
            "detail": [{
                "loc": ["body", "id"],
                "msg": "training ID must match the ID supplied in the URL",
                "type": "value_error"
            }]
        });
        return (StatusCode::UNPROCESSABLE_ENTITY, Json(body)).into_response();
    }

    if let Some(existing) = service.get_prediction_response(&training_id) {
        return (StatusCode::ACCEPTED, Json(existing)).into_response();
    }

    let response_mode = json_response_mode(&headers);
    let trace_context = extract_trace_context(&headers);
    create_prediction_with_id(
        service,
        training_id,
        request.input,
        request.context,
        request.webhook,
        request.webhook_events_filter,
        response_mode,
        trace_context,
        true,
    )
    .await
}
/// `POST /trainings/{id}/cancel`: cancels a training by delegating to
/// `cancel_prediction`, so both endpoints respond identically.
async fn cancel_training(
    state: State<Arc<PredictionService>>,
    training_id: Path<String>,
) -> impl IntoResponse {
    // The extractors are forwarded untouched; the training ID is used as the
    // prediction ID.
    cancel_prediction(state, training_id).await
}
/// Maximum accepted HTTP request body size in bytes (100 MiB); installed on the
/// router through `DefaultBodyLimit::max` in `routes`.
const MAX_HTTP_BODY_SIZE: usize = 100 * 1024 * 1024;
/// Builds the HTTP router for the prediction service.
///
/// Registers the index, health-check, OpenAPI, shutdown, prediction and training
/// endpoints, applies the request body size limit to all of them and attaches
/// `service` as shared state.
pub fn routes(service: Arc<PredictionService>) -> Router {
    let body_limit = DefaultBodyLimit::max(MAX_HTTP_BODY_SIZE);

    let router = Router::new()
        .route("/", get(root))
        .route("/health-check", get(health_check))
        .route("/openapi.json", get(openapi_schema))
        .route("/shutdown", post(shutdown))
        .route("/predictions", post(create_prediction))
        .route("/predictions/{id}", put(create_prediction_idempotent))
        .route("/predictions/{id}/cancel", post(cancel_prediction))
        .route("/trainings", post(create_training))
        .route("/trainings/{id}", put(create_training_idempotent))
        .route("/trainings/{id}/cancel", post(cancel_training));

    router.layer(body_limit).with_state(service)
}
#[cfg(test)]
mod tests {
use super::*;
use axum::body::Body;
use axum::http::{Request, StatusCode};
use http_body_util::BodyExt;
use tower::ServiceExt;
async fn response_json(response: axum::response::Response) -> serde_json::Value {
let body = response.into_body();
let bytes = body.collect().await.unwrap().to_bytes();
serde_json::from_slice(&bytes).unwrap()
}
#[tokio::test]
async fn health_check_returns_status_and_version() {
let service = Arc::new(PredictionService::new_no_pool().with_health(Health::Starting));
let app = routes(service);
let response = app
.oneshot(Request::get("/health-check").body(Body::empty()).unwrap())
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let json = response_json(response).await;
assert_eq!(json["status"], "STARTING");
assert!(json["version"]["coglet"].is_string());
}
#[tokio::test]
async fn health_check_unknown_when_no_predictor() {
let service = Arc::new(PredictionService::new_no_pool());
let app = routes(service);
let response = app
.oneshot(Request::get("/health-check").body(Body::empty()).unwrap())
.await
.unwrap();
let json = response_json(response).await;
assert_eq!(json["status"], "UNKNOWN");
}
#[tokio::test]
async fn predictions_returns_503_when_not_ready() {
let service = Arc::new(PredictionService::new_no_pool());
let app = routes(service);
let response = app
.oneshot(
Request::post("/predictions")
.header("content-type", "application/json")
.body(Body::from(r#"{"input":{}}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
let json = response_json(response).await;
assert_eq!(json["status"], "failed");
assert!(
json["error"]
.as_str()
.unwrap()
.contains("Setup has not finished yet")
);
}
#[tokio::test]
async fn openapi_returns_503_when_schema_not_available() {
let service = Arc::new(PredictionService::new_no_pool());
let app = routes(service);
let response = app
.oneshot(Request::get("/openapi.json").body(Body::empty()).unwrap())
.await
.unwrap();
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
let json = response_json(response).await;
assert!(json["error"].as_str().unwrap().contains("not available"));
}
#[tokio::test]
async fn openapi_returns_schema_when_available() {
let service = Arc::new(PredictionService::new_no_pool());
service
.set_schema(serde_json::json!({
"openapi": "3.0.2",
"info": {"title": "Cog", "version": "0.1.0"}
}))
.await;
let app = routes(service);
let response = app
.oneshot(Request::get("/openapi.json").body(Body::empty()).unwrap())
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let json = response_json(response).await;
assert_eq!(json["openapi"], "3.0.2");
assert_eq!(json["info"]["title"], "Cog");
}
use crate::PredictionOutput;
use crate::bridge::protocol::SlotId;
use crate::orchestrator::Orchestrator;
use crate::permit::PermitPool;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::{AtomicUsize, Ordering};
struct MockOrchestrator {
register_count: AtomicUsize,
complete_immediately: bool,
}
impl MockOrchestrator {
fn new() -> Self {
Self {
register_count: AtomicUsize::new(0),
complete_immediately: true,
}
}
fn never_complete() -> Self {
Self {
register_count: AtomicUsize::new(0),
complete_immediately: false,
}
}
}
#[async_trait::async_trait]
impl Orchestrator for MockOrchestrator {
async fn register_prediction(
&self,
_slot_id: SlotId,
prediction: Arc<StdMutex<crate::prediction::Prediction>>,
_idle_sender: tokio::sync::oneshot::Sender<crate::permit::SlotIdleToken>,
) {
self.register_count.fetch_add(1, Ordering::SeqCst);
if self.complete_immediately {
let mut pred = prediction.lock().unwrap();
pred.set_succeeded(PredictionOutput::Single(serde_json::json!("mock output")));
}
}
async fn cancel_by_prediction_id(
&self,
_prediction_id: &str,
) -> Result<(), crate::orchestrator::OrchestratorError> {
Ok(())
}
async fn healthcheck(
&self,
) -> Result<crate::orchestrator::HealthcheckResult, crate::orchestrator::OrchestratorError>
{
Ok(crate::orchestrator::HealthcheckResult::healthy())
}
async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> {
Ok(())
}
}
async fn create_test_pool(num_slots: usize) -> Arc<PermitPool> {
use crate::bridge::codec::JsonCodec;
use crate::bridge::protocol::SlotRequest;
use futures::StreamExt;
use tokio::net::UnixStream;
let pool = Arc::new(PermitPool::new(num_slots));
for _ in 0..num_slots {
let (a, b) = UnixStream::pair().unwrap();
let (_read_a, write_a) = a.into_split();
let (read_b, _write_b) = b.into_split();
let mut reader =
tokio_util::codec::FramedRead::new(read_b, JsonCodec::<SlotRequest>::new());
tokio::spawn(async move { while reader.next().await.is_some() {} });
let writer =
tokio_util::codec::FramedWrite::new(write_a, JsonCodec::<SlotRequest>::new());
pool.add_permit(SlotId::new(), writer);
}
pool
}
async fn create_ready_service() -> Arc<PredictionService> {
let service = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(2).await;
let orchestrator = Arc::new(MockOrchestrator::new());
service.set_orchestrator(pool, orchestrator).await;
service.set_health(Health::Ready).await;
service
}
async fn enable_prediction_streaming(service: &PredictionService) {
service
.set_schema(serde_json::json!({
"paths": {"/predictions": {"post": {"x-cog-streaming": true}}},
"components": {"schemas": {"Input": {"type": "object", "properties": {}}}}
}))
.await;
}
#[tokio::test]
async fn health_check_ready_with_orchestrator() {
let service = create_ready_service().await;
let app = routes(service);
let response = app
.oneshot(Request::get("/health-check").body(Body::empty()).unwrap())
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let json = response_json(response).await;
assert_eq!(json["status"], "READY");
}
#[tokio::test]
async fn prediction_sync_success() {
let service = create_ready_service().await;
let app = routes(service);
let response = app
.oneshot(
Request::post("/predictions")
.header("content-type", "application/json")
.body(Body::from(r#"{"input":{"prompt":"hello"}}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let json = response_json(response).await;
assert_eq!(json["status"], "succeeded");
assert_eq!(json["output"], "mock output");
assert!(json["id"].is_string());
}
#[tokio::test]
async fn prediction_async_returns_accepted() {
let service = create_ready_service().await;
let app = routes(service);
let response = app
.oneshot(
Request::post("/predictions")
.header("content-type", "application/json")
.header("prefer", "respond-async")
.body(Body::from(r#"{"input":{}}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::ACCEPTED);
let json = response_json(response).await;
assert_eq!(json["status"], "starting");
}
#[tokio::test]
async fn prediction_post_with_sse_accept_returns_sse() {
let service = create_ready_service().await;
enable_prediction_streaming(&service).await;
let app = routes(service);
let response = app
.oneshot(
Request::post("/predictions")
.header("content-type", "application/json")
.header("accept", "text/event-stream")
.body(Body::from(r#"{"input":{}}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let content_type = response.headers().get("content-type").unwrap();
assert!(
content_type
.to_str()
.unwrap()
.starts_with("text/event-stream"),
"unexpected content-type: {:?}",
content_type
);
let body = response.into_body();
let bytes = body.collect().await.unwrap().to_bytes();
let sse = String::from_utf8(bytes.to_vec()).unwrap();
assert!(sse.contains("event: completed"), "SSE body: {sse}");
assert!(sse.contains(r#""status":"succeeded""#), "SSE body: {sse}");
}
#[tokio::test]
async fn prediction_post_with_sse_accept_rejects_when_not_opted_in() {
let service = create_ready_service().await;
let app = routes(service);
let response = app
.oneshot(
Request::post("/predictions")
.header("content-type", "application/json")
.header("accept", "text/event-stream")
.body(Body::from(r#"{"input":{}}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::NOT_ACCEPTABLE);
let json = response_json(response).await;
assert_eq!(
json["error"],
"This model does not support streaming responses. Add @cog.streaming to predict() to enable SSE."
);
}
#[tokio::test]
async fn lagged_prediction_sse_stream_emits_error_and_closes() {
let service = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(MockOrchestrator::never_complete());
service.set_orchestrator(pool, orchestrator).await;
service.set_health(Health::Ready).await;
let (_handle, slot) = service
.submit_prediction(
"lagged-stream".to_string(),
serde_json::json!({}),
None,
true,
)
.await
.unwrap();
let subscription = service
.subscribe_prediction_stream("lagged-stream")
.unwrap();
{
let prediction = slot.prediction();
let mut prediction = prediction.lock().unwrap();
for index in 0..1030 {
prediction.append_output_chunk(serde_json::json!(index), index);
}
}
let response = Sse::new(prediction_sse_stream(subscription)).into_response();
let collected =
tokio::time::timeout(Duration::from_millis(100), response.into_body().collect())
.await
.expect("lagged SSE stream should close after emitting an error")
.unwrap();
let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap();
assert!(sse.contains("event: error"), "SSE body: {sse}");
assert!(sse.contains("SSE stream lagged"), "SSE body: {sse}");
assert!(sse.contains("skipped"), "SSE body: {sse}");
}
#[tokio::test]
async fn truncated_replay_prediction_sse_stream_emits_error_and_closes() {
let service = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(MockOrchestrator::never_complete());
service.set_orchestrator(pool, orchestrator).await;
service.set_health(Health::Ready).await;
let (_handle, slot) = service
.submit_prediction(
"truncated-replay".to_string(),
serde_json::json!({}),
None,
true,
)
.await
.unwrap();
{
let prediction = slot.prediction();
let mut prediction = prediction.lock().unwrap();
for index in 0..1030 {
prediction.append_output_chunk(serde_json::json!(index), index);
}
}
let subscription = service
.subscribe_prediction_stream("truncated-replay")
.unwrap();
let response = Sse::new(prediction_sse_stream(subscription)).into_response();
let collected =
tokio::time::timeout(Duration::from_millis(100), response.into_body().collect())
.await
.expect("truncated replay SSE stream should close after emitting an error")
.unwrap();
let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap();
assert!(sse.contains("event: error"), "SSE body: {sse}");
assert!(
sse.contains("SSE stream replay truncated"),
"SSE body: {sse}"
);
assert!(sse.contains("skipped"), "SSE body: {sse}");
}
#[tokio::test]
async fn failed_prediction_sse_stream_emits_completed_event() {
let service = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(MockOrchestrator::never_complete());
service.set_orchestrator(pool, orchestrator).await;
service.set_health(Health::Ready).await;
let (_handle, slot) = service
.submit_prediction(
"failed-stream".to_string(),
serde_json::json!({}),
None,
true,
)
.await
.unwrap();
let subscription = service
.subscribe_prediction_stream("failed-stream")
.unwrap();
{
let prediction = slot.prediction();
let mut prediction = prediction.lock().unwrap();
prediction.set_processing();
prediction.set_failed("boom".to_string());
}
let response = Sse::new(prediction_sse_stream(subscription)).into_response();
let collected = response.into_body().collect().await.unwrap();
let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap();
assert!(sse.contains("event: completed"), "SSE body: {sse}");
assert!(sse.contains(r#""status":"failed""#), "SSE body: {sse}");
assert!(sse.contains(r#""error":"boom""#), "SSE body: {sse}");
}
#[tokio::test]
async fn canceled_prediction_sse_stream_emits_completed_event() {
let service = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(MockOrchestrator::never_complete());
service.set_orchestrator(pool, orchestrator).await;
service.set_health(Health::Ready).await;
let (_handle, slot) = service
.submit_prediction(
"canceled-stream".to_string(),
serde_json::json!({}),
None,
true,
)
.await
.unwrap();
let subscription = service
.subscribe_prediction_stream("canceled-stream")
.unwrap();
{
let prediction = slot.prediction();
let mut prediction = prediction.lock().unwrap();
prediction.set_processing();
prediction.set_canceled();
}
let response = Sse::new(prediction_sse_stream(subscription)).into_response();
let collected = response.into_body().collect().await.unwrap();
let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap();
assert!(sse.contains("event: completed"), "SSE body: {sse}");
assert!(sse.contains(r#""status":"canceled""#), "SSE body: {sse}");
}
/// `PUT /predictions/sse-put` with `Accept: text/event-stream`, on a service
/// that had `enable_prediction_streaming` applied, answers 200 with a
/// `text/event-stream` content type.
#[tokio::test]
async fn prediction_put_with_sse_accept_returns_sse() {
    let service = create_ready_service().await;
    enable_prediction_streaming(&service).await;
    let app = routes(service);

    let request = Request::put("/predictions/sse-put")
        .header("content-type", "application/json")
        .header("accept", "text/event-stream")
        .body(Body::from(r#"{"input":{}}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);

    // Only the prefix is compared, so any parameters after the media type
    // are accepted.
    let header_value = response.headers().get("content-type").unwrap();
    let is_event_stream = header_value
        .to_str()
        .unwrap()
        .starts_with("text/event-stream");
    assert!(
        is_event_stream,
        "unexpected content-type: {:?}",
        header_value
    );
}
/// When a prediction with the same id was already submitted directly to the
/// service, a `PUT /predictions/{id}` with `Accept: text/event-stream` still
/// answers 200 with a `text/event-stream` content type.
#[tokio::test]
async fn prediction_put_existing_with_sse_accept_returns_sse() {
    let service = create_ready_service().await;
    enable_prediction_streaming(&service).await;

    // Register the prediction before the HTTP request arrives. Both
    // bindings are kept alive until the end of the test.
    let submitted = service
        .submit_prediction(
            "existing-sse-put".to_string(),
            serde_json::json!({}),
            None,
            true,
        )
        .await;
    let (_handle, _slot) = submitted.unwrap();

    let app = routes(service);
    let request = Request::put("/predictions/existing-sse-put")
        .header("content-type", "application/json")
        .header("accept", "text/event-stream")
        .body(Body::from(r#"{"input":{}}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);

    let header_value = response.headers().get("content-type").unwrap();
    let is_event_stream = header_value
        .to_str()
        .unwrap()
        .starts_with("text/event-stream");
    assert!(
        is_event_stream,
        "unexpected content-type: {:?}",
        header_value
    );
}
/// `GET /predictions/{id}/stream` must answer 404 even when a prediction
/// with that id has been submitted.
#[tokio::test]
async fn stream_prediction_route_is_removed() {
    let service = create_ready_service().await;

    // Submit a prediction under the id used in the request path below.
    let submitted = service
        .submit_prediction(
            "removed-stream-route".to_string(),
            serde_json::json!({}),
            None,
            true,
        )
        .await;
    let (_handle, _slot) = submitted.unwrap();

    let app = routes(service);
    let request = Request::get("/predictions/removed-stream-route/stream")
        .header("accept", "text/event-stream")
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::NOT_FOUND);
}
/// `POST /predictions` with an `id` in the body echoes that id back and
/// reports a succeeded prediction.
#[tokio::test]
async fn prediction_with_custom_id() {
    let app = routes(create_ready_service().await);

    let request = Request::post("/predictions")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"id":"my-pred-123","input":{}}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);

    let payload = response_json(response).await;
    assert_eq!(payload["id"], "my-pred-123");
    assert_eq!(payload["status"], "succeeded");
}
/// `PUT /predictions/{id}` without an `id` in the body uses the id from the
/// URL and reports a succeeded prediction.
#[tokio::test]
async fn prediction_idempotent_put() {
    let app = routes(create_ready_service().await);

    let request = Request::put("/predictions/idempotent-123")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"input":{}}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);

    let payload = response_json(response).await;
    assert_eq!(payload["id"], "idempotent-123");
    assert_eq!(payload["status"], "succeeded");
}
/// `PUT /predictions/{id}` whose body `id` differs from the URL id is
/// rejected with 422 and a validation message containing "must match".
#[tokio::test]
async fn prediction_idempotent_id_mismatch() {
    let app = routes(create_ready_service().await);

    // URL says `url-id`, body says `body-id`.
    let request = Request::put("/predictions/url-id")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"id":"body-id","input":{}}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);

    let json = response_json(response).await;
    assert!(
        json["detail"][0]["msg"]
            .as_str()
            .unwrap()
            .contains("must match")
    );
}
#[tokio::test]
async fn prediction_at_capacity() {
let service = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await; let orchestrator = Arc::new(MockOrchestrator::never_complete());
service.set_orchestrator(pool, orchestrator).await;
service.set_health(Health::Ready).await;
let app = routes(Arc::clone(&service));
let _resp1 = app
.oneshot(
Request::post("/predictions")
.header("content-type", "application/json")
.header("prefer", "respond-async")
.body(Body::from(r#"{"input":{}}"#))
.unwrap(),
)
.await
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let app2 = routes(service);
let response = app2
.oneshot(
Request::post("/predictions")
.header("content-type", "application/json")
.body(Body::from(r#"{"input":{}}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::CONFLICT);
let json = response_json(response).await;
assert!(json["error"].as_str().unwrap().contains("capacity"));
}
#[tokio::test]
async fn health_check_busy_when_at_capacity() {
let service = Arc::new(PredictionService::new_no_pool());
let pool = create_test_pool(1).await;
let orchestrator = Arc::new(MockOrchestrator::never_complete());
service.set_orchestrator(pool, orchestrator).await;
service.set_health(Health::Ready).await;
let app = routes(Arc::clone(&service));
let _resp = app
.oneshot(
Request::post("/predictions")
.header("content-type", "application/json")
.header("prefer", "respond-async")
.body(Body::from(r#"{"input":{}}"#))
.unwrap(),
)
.await
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let app2 = routes(service);
let response = app2
.oneshot(Request::get("/health-check").body(Body::empty()).unwrap())
.await
.unwrap();
let json = response_json(response).await;
assert_eq!(json["status"], "BUSY");
}
/// `POST /trainings` on a ready service answers 200 with a succeeded status.
#[tokio::test]
async fn training_routes_work() {
    let app = routes(create_ready_service().await);

    let request = Request::post("/trainings")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"input":{}}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);

    let payload = response_json(response).await;
    assert_eq!(payload["status"], "succeeded");
}
/// `POST /trainings` with `Accept: text/event-stream` is refused with 406
/// and a fixed error message.
#[tokio::test]
async fn training_post_with_sse_accept_rejects() {
    let app = routes(create_ready_service().await);

    let request = Request::post("/trainings")
        .header("content-type", "application/json")
        .header("accept", "text/event-stream")
        .body(Body::from(r#"{"input":{}}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::NOT_ACCEPTABLE);

    let payload = response_json(response).await;
    assert_eq!(
        payload["error"],
        "Training endpoints do not support streaming responses."
    );
}
/// `PUT /trainings/{id}` with `Accept: text/event-stream` is refused with
/// 406 and a fixed error message.
#[tokio::test]
async fn training_put_with_sse_accept_rejects() {
    let app = routes(create_ready_service().await);

    let request = Request::put("/trainings/train-sse")
        .header("content-type", "application/json")
        .header("accept", "text/event-stream")
        .body(Body::from(r#"{"input":{}}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::NOT_ACCEPTABLE);

    let payload = response_json(response).await;
    assert_eq!(
        payload["error"],
        "Training endpoints do not support streaming responses."
    );
}
/// `PUT /trainings/{id}` without an `id` in the body uses the id from the
/// URL and reports a succeeded training.
#[tokio::test]
async fn training_idempotent_put() {
    let app = routes(create_ready_service().await);

    let request = Request::put("/trainings/train-123")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"input":{}}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);

    let payload = response_json(response).await;
    assert_eq!(payload["id"], "train-123");
    assert_eq!(payload["status"], "succeeded");
}
/// `PUT /trainings/{id}` whose body `id` differs from the URL id is rejected
/// with 422 and a validation message containing "must match".
#[tokio::test]
async fn training_idempotent_id_mismatch() {
    let app = routes(create_ready_service().await);

    // URL says `url-id`, body says `body-id`.
    let request = Request::put("/trainings/url-id")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"id":"body-id","input":{}}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);

    let json = response_json(response).await;
    assert!(
        json["detail"][0]["msg"]
            .as_str()
            .unwrap()
            .contains("must match")
    );
}
/// `POST /shutdown` answers 200 and flips the value seen through the
/// service's shutdown receiver from `false` to `true`.
#[tokio::test]
async fn shutdown_triggers_service_shutdown() {
    let service = create_ready_service().await;
    // Take the receiver before the service is moved into the router.
    let mut rx = service.shutdown_rx();
    let app = routes(service);

    // Nothing has requested shutdown yet.
    assert!(!*rx.borrow());

    let request = Request::post("/shutdown").body(Body::empty()).unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    // Wait for the change notification, then read the new value.
    rx.changed().await.unwrap();
    assert!(*rx.borrow());
}
/// `GET /` on a service with no schema set returns a JSON discovery document
/// with the version, the fixed endpoint URLs, and no training URLs.
#[tokio::test]
async fn root_returns_discovery_document() {
    let app = routes(Arc::new(PredictionService::new_no_pool()));

    let request = Request::get("/").body(Body::empty()).unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);
    assert_eq!(
        response.headers().get("content-type").unwrap(),
        "application/json"
    );

    let json = response_json(response).await;
    assert_eq!(json["cog_version"], crate::version::COGLET_VERSION);

    // Fixed URL entries, checked in the same order as they are listed here.
    let expected_urls = [
        ("docs_url", "/docs"),
        ("openapi_url", "/openapi.json"),
        ("shutdown_url", "/shutdown"),
        ("healthcheck_url", "/health-check"),
        ("predictions_url", "/predictions"),
        (
            "predictions_idempotent_url",
            "/predictions/{prediction_id}",
        ),
        (
            "predictions_cancel_url",
            "/predictions/{prediction_id}/cancel",
        ),
    ];
    for (key, expected) in expected_urls {
        assert_eq!(json[key], expected);
    }

    // Training entries must be absent altogether, not merely null.
    assert!(json.get("trainings_url").is_none());
    assert!(json.get("trainings_idempotent_url").is_none());
    assert!(json.get("trainings_cancel_url").is_none());
}
/// When the schema set on the service contains a `TrainingInput` component,
/// `GET /` lists the training URLs next to the prediction URL.
#[tokio::test]
async fn root_includes_training_urls_when_schema_has_training() {
    let service = Arc::new(PredictionService::new_no_pool());

    // Minimal OpenAPI document whose only component schema is `TrainingInput`.
    let schema = serde_json::json!({
        "openapi": "3.0.2",
        "info": {"title": "Cog", "version": "0.1.0"},
        "components": {
            "schemas": {
                "TrainingInput": {
                    "type": "object",
                    "properties": {
                        "data": {"type": "string"}
                    }
                }
            }
        }
    });
    service.set_schema(schema).await;

    let app = routes(service);
    let request = Request::get("/").body(Body::empty()).unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::OK);

    let payload = response_json(response).await;
    let expected_urls = [
        ("predictions_url", "/predictions"),
        ("trainings_url", "/trainings"),
        ("trainings_idempotent_url", "/trainings/{training_id}"),
        ("trainings_cancel_url", "/trainings/{training_id}/cancel"),
    ];
    for (key, expected) in expected_urls {
        assert_eq!(payload[key], expected);
    }
}
/// When the service's `VersionInfo` carries a Python SDK version, `GET /`
/// reports that value as `cog_version`.
#[tokio::test]
async fn root_cog_version_prefers_python_sdk() {
    let sdk_version = VersionInfo::new().with_python_sdk("0.14.0".to_string());
    let service = Arc::new(PredictionService::new_no_pool().with_version(sdk_version));
    let app = routes(service);

    let request = Request::get("/").body(Body::empty()).unwrap();
    let response = app.oneshot(request).await.unwrap();

    let payload = response_json(response).await;
    assert_eq!(payload["cog_version"], "0.14.0");
}
}