use std::collections::{HashMap, HashSet};
use std::process::Stdio;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::time::Duration;
use async_trait::async_trait;
use futures::{SinkExt, StreamExt};
use tokio::process::{Child, Command};
use tokio::sync::mpsc;
use tokio_util::codec::{FramedRead, FramedWrite};
use crate::PredictionOutput;
use crate::bridge::codec::JsonCodec;
use crate::bridge::protocol::{
ControlRequest, ControlResponse, FileOutputKind, HealthcheckStatus, SlotId, SlotRequest,
SlotResponse,
};
use crate::bridge::transport::create_transport;
use crate::permit::{InactiveSlotIdleToken, PermitPool, SlotIdleToken};
use crate::prediction::Prediction;
/// Maximum number of prediction IDs remembered as "cancel requested before
/// registration". Once this many are stored, further early cancellations are
/// dropped (with a warning) instead of growing the set.
const MAX_PENDING_CANCELLATIONS: usize = 1000;
/// Uploads `data` with an HTTP `PUT` to `{endpoint}{filename}` and returns the
/// URL under which the object is reachable.
///
/// `endpoint` is concatenated with `filename` verbatim, so callers are expected
/// to pass an endpoint that already ends in `/`.
///
/// The returned URL is taken from the response's `Location` header when present
/// (and valid UTF-8), otherwise from the final response URL. Any query string is
/// stripped when the value parses as a URL; an unparsable `Location` value is
/// returned unchanged.
///
/// # Errors
///
/// Returns a human-readable message when the request cannot be sent (including
/// the 25 second timeout) or when the server answers with a non-success status.
async fn upload_file(
    endpoint: &str,
    filename: &str,
    data: &[u8],
    content_type: &str,
) -> Result<String, String> {
    // A `reqwest::Client` owns a connection pool; building a fresh one for every
    // upload throws that pool away each time. Share one client for the process.
    static CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
    let client = CLIENT.get_or_init(reqwest::Client::new);

    let url = format!("{endpoint}{filename}");
    let resp = client
        .put(&url)
        .header("Content-Type", content_type)
        .body(data.to_vec())
        .timeout(std::time::Duration::from_secs(25))
        .send()
        .await
        .map_err(|e| format!("upload request failed: {e}"))?;
    if !resp.status().is_success() {
        return Err(format!("upload returned status {}", resp.status()));
    }

    let final_url = resp
        .headers()
        .get("location")
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_string())
        .unwrap_or_else(|| resp.url().to_string());

    // Drop the query string from the reported URL whenever it parses.
    match reqwest::Url::parse(&final_url) {
        Ok(mut parsed) => {
            parsed.set_query(None);
            Ok(parsed.to_string())
        }
        Err(_) => Ok(final_url),
    }
}
/// Returns `s` as an owned string that is guaranteed to end with `/`.
///
/// A string that already ends with a slash is copied unchanged; otherwise a
/// single `/` is appended (so the empty string becomes `"/"`).
fn ensure_trailing_slash(s: &str) -> String {
    let mut with_slash = String::with_capacity(s.len() + 1);
    with_slash.push_str(s);
    if !with_slash.ends_with('/') {
        with_slash.push('/');
    }
    with_slash
}
/// Locks the prediction mutex, treating a poisoned mutex as a failed prediction.
///
/// Returns `Some(guard)` when the lock is acquired normally. When the mutex is
/// poisoned, an error is logged, the prediction is marked failed (unless it is
/// already terminal) and `None` is returned so the caller can stop tracking it.
fn try_lock_prediction(
    pred: &Arc<StdMutex<Prediction>>,
) -> Option<std::sync::MutexGuard<'_, Prediction>> {
    let poisoned = match pred.lock() {
        Ok(guard) => return Some(guard),
        Err(poisoned) => poisoned,
    };

    tracing::error!("Prediction mutex poisoned - failing prediction");
    // Recover the guard only to record the failure; it is not handed to the caller.
    let mut recovered = poisoned.into_inner();
    if !recovered.is_terminal() {
        recovered.set_failed("Internal error: mutex poisoned".to_string());
    }
    None
}
/// Converts the accumulated output values of a prediction into its final
/// [`PredictionOutput`] shape.
///
/// * If the schema declares an array output (`output_is_array`) or the
///   predictor reported a stream (`is_stream`), the result is always
///   `Stream`, even for zero or one value.
/// * Otherwise no value yields `Single(null)`, exactly one value yields
///   `Single(value)`, and more than one value falls back to `Stream`.
fn wrap_outputs(
    outputs: Vec<serde_json::Value>,
    output_is_array: bool,
    is_stream: bool,
) -> PredictionOutput {
    let should_stream = output_is_array || is_stream;
    if should_stream || outputs.len() > 1 {
        return PredictionOutput::Stream(outputs);
    }
    // At most one element remains. Move it out of the vector instead of cloning:
    // a single output can be a large value (e.g. a base64 data URL).
    let single = outputs
        .into_iter()
        .next()
        .unwrap_or(serde_json::Value::Null);
    PredictionOutput::Single(single)
}
/// Re-emits a log record received from the worker through the current
/// `tracing` dispatcher, using the worker-supplied `target` and `level`.
///
/// `tracing` events require `'static` metadata, but the target only exists as a
/// runtime string. Metadata is therefore created lazily and cached for the life
/// of the process: each distinct target string and each distinct
/// `(target, level)` metadata record is leaked exactly once, so the leaked
/// memory is bounded by the number of distinct pairs rather than by the number
/// of log lines.
///
/// `level` is matched against `"error"`, `"warn"`, `"info"`, `"debug"` and
/// `"trace"`; anything else is emitted at `INFO`. If the metadata cache mutex is
/// poisoned the record is dropped after logging an error.
fn emit_worker_log(target: &str, level: &str, msg: &str) {
    use std::collections::HashMap;
    use std::sync::OnceLock;
    use tracing::{
        Level, Metadata,
        callsite::{Callsite, Identifier},
        field::FieldSet,
    };

    // Placeholder callsite that only provides an identity for the `FieldSet`.
    // NOTE(review): `metadata()` panics if anything ever calls it through the
    // identifier; this code never registers the callsite — confirm no subscriber
    // in use resolves metadata via the callsite.
    struct DummyCallsite;
    impl Callsite for DummyCallsite {
        fn set_interest(&self, _: tracing::subscriber::Interest) {}
        fn metadata(&self) -> &Metadata<'static> {
            unreachable!()
        }
    }
    static DUMMY: DummyCallsite = DummyCallsite;

    // target -> level -> leaked metadata. Keying the outer map by `&'static str`
    // lets it be queried with a borrowed `&str`, so nothing is allocated or
    // leaked for a target that has been seen before.
    type PerLevel = HashMap<Level, &'static Metadata<'static>>;
    static CALLSITES: OnceLock<std::sync::Mutex<HashMap<&'static str, PerLevel>>> =
        OnceLock::new();
    static FIELDS: &[&str] = &["message"];

    let lvl = match level {
        "error" => Level::ERROR,
        "warn" => Level::WARN,
        "info" => Level::INFO,
        "debug" => Level::DEBUG,
        "trace" => Level::TRACE,
        _ => Level::INFO,
    };

    let callsites = CALLSITES.get_or_init(|| std::sync::Mutex::new(HashMap::new()));
    // The lock is held only while looking up / creating the metadata. The value
    // copied out is a genuine `&'static` reference to a leaked box, so it stays
    // valid even if another thread grows the map afterwards.
    let meta: &'static Metadata<'static> = {
        let mut map = match callsites.lock() {
            Ok(guard) => guard,
            Err(_poisoned) => {
                tracing::error!("Worker log callsite cache poisoned");
                return;
            }
        };
        let interned = map.get_key_value(target).map(|(key, _)| *key);
        let target_static: &'static str = match interned {
            Some(existing) => existing,
            None => {
                let leaked: &'static str = Box::leak(target.to_owned().into_boxed_str());
                leaked
            }
        };
        *map.entry(target_static)
            .or_default()
            .entry(lvl)
            .or_insert_with(|| {
                let leaked: &'static Metadata<'static> = Box::leak(Box::new(Metadata::new(
                    "worker_log",
                    target_static,
                    lvl,
                    Some(file!()),
                    Some(line!()),
                    Some(module_path!()),
                    FieldSet::new(FIELDS, Identifier(&DUMMY)),
                    tracing::metadata::Kind::EVENT,
                )));
                leaked
            })
    };

    tracing::dispatcher::get_default(|dispatch| {
        if dispatch.enabled(meta) {
            let fields = meta.fields();
            if let Some(field) = fields.field("message") {
                let value_array = [(&field, Some(&msg as &dyn tracing::Value))];
                let values = fields.value_set(&value_array);
                dispatch.event(&tracing::Event::new(meta, &values));
            }
        }
    });
}
/// Outcome of a worker healthcheck as reported to callers of
/// `Orchestrator::healthcheck`.
#[derive(Debug, Clone)]
pub struct HealthcheckResult {
/// Whether the worker reported itself healthy or unhealthy.
pub status: HealthcheckStatus,
/// Description of the problem; `None` for results built with `healthy()`.
pub error: Option<String>,
}
impl HealthcheckResult {
    /// Builds a result with status `Healthy` and no error message.
    pub fn healthy() -> Self {
        let status = HealthcheckStatus::Healthy;
        Self {
            status,
            error: None,
        }
    }

    /// Builds a result with status `Unhealthy` carrying `error` as its message.
    pub fn unhealthy(error: impl Into<String>) -> Self {
        let message: String = error.into();
        Self {
            status: HealthcheckStatus::Unhealthy,
            error: Some(message),
        }
    }

    /// Returns `true` when the status is `Healthy`.
    pub fn is_healthy(&self) -> bool {
        matches!(self.status, HealthcheckStatus::Healthy)
    }
}
/// Operations a prediction service needs from whatever manages the worker.
///
/// The method descriptions below reflect the `OrchestratorHandle`
/// implementation in this file; other implementors may differ.
#[async_trait]
pub trait Orchestrator: Send + Sync {
/// Hands `prediction` to the orchestrator as the prediction running on
/// `slot_id`. `idle_sender` is used to deliver a `SlotIdleToken` once the
/// worker reports the slot idle.
async fn register_prediction(
&self,
slot_id: SlotId,
prediction: Arc<StdMutex<Prediction>>,
idle_sender: tokio::sync::oneshot::Sender<SlotIdleToken>,
);
/// Requests cancellation of the prediction with the given ID.
async fn cancel_by_prediction_id(&self, prediction_id: &str) -> Result<(), OrchestratorError>;
/// Asks the worker whether it is healthy.
async fn healthcheck(&self) -> Result<HealthcheckResult, OrchestratorError>;
/// Asks the worker to shut down.
async fn shutdown(&self) -> Result<(), OrchestratorError>;
}
/// Parameters passed to a `WorkerSpawner` when a worker process is created.
#[derive(Debug, Clone)]
pub struct WorkerSpawnConfig {
/// Number of prediction slots the orchestrator is setting up for the worker.
pub num_slots: usize,
}
/// Errors a `WorkerSpawner` can report.
#[derive(Debug, thiserror::Error)]
pub enum SpawnError {
/// The operating system failed to start the process.
#[error("failed to spawn process: {0}")]
Spawn(#[from] std::io::Error),
/// Any other spawner-specific failure, described by the message.
#[error("spawn failed: {0}")]
Other(String),
}
/// Strategy for creating the worker subprocess.
///
/// `spawn_worker` takes the child's stdin and stdout for the control channel
/// and fails if either was not captured, so implementations must pipe both.
pub trait WorkerSpawner: Send + Sync {
/// Starts a worker process and returns its handle.
fn spawn(&self, config: &WorkerSpawnConfig) -> Result<Child, SpawnError>;
}
pub struct SimpleSpawner;
impl WorkerSpawner for SimpleSpawner {
fn spawn(&self, _config: &WorkerSpawnConfig) -> Result<Child, SpawnError> {
let child = Command::new("python")
.args(["-c", "import coglet; coglet.server._run_worker()"])
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.spawn()?;
Ok(child)
}
}
/// Settings consumed by `spawn_worker`.
pub struct OrchestratorConfig {
/// Predictor reference forwarded to the worker in the `Init` message.
pub predictor_ref: String,
/// Number of slots to create transport connections and permits for.
pub num_slots: usize,
/// Forwarded in `Init`; also selects the `TrainingOutput` schema entry
/// (instead of `Output`) when deciding whether the output is an array.
pub is_train: bool,
/// Forwarded to the worker in the `Init` message.
pub is_async: bool,
/// Maximum time to wait for the worker's `Ready` message; `None` waits
/// indefinitely.
pub setup_timeout: Option<Duration>,
/// Creates the worker subprocess.
pub spawner: Arc<dyn WorkerSpawner>,
/// When set, file outputs are uploaded beneath this URL; otherwise they are
/// inlined as base64 `data:` URLs.
pub upload_url: Option<String>,
}
impl OrchestratorConfig {
    /// Creates a configuration for `predictor_ref` with the defaults: one slot,
    /// predict (not train) mode, synchronous predictor, no setup timeout, the
    /// [`SimpleSpawner`] and no upload URL.
    pub fn new(predictor_ref: impl Into<String>) -> Self {
        let spawner: Arc<dyn WorkerSpawner> = Arc::new(SimpleSpawner);
        Self {
            predictor_ref: predictor_ref.into(),
            num_slots: 1,
            is_train: false,
            is_async: false,
            setup_timeout: None,
            spawner,
            upload_url: None,
        }
    }

    /// Replaces the upload URL (`None` disables uploading).
    pub fn with_upload_url(self, upload_url: Option<String>) -> Self {
        Self { upload_url, ..self }
    }

    /// Replaces the number of slots.
    pub fn with_num_slots(self, n: usize) -> Self {
        Self {
            num_slots: n,
            ..self
        }
    }

    /// Sets whether the worker runs in training mode.
    pub fn with_train(self, is_train: bool) -> Self {
        Self { is_train, ..self }
    }

    /// Sets the `is_async` flag forwarded to the worker.
    pub fn with_async(self, is_async: bool) -> Self {
        Self { is_async, ..self }
    }

    /// Replaces the setup timeout (`None` waits indefinitely).
    pub fn with_setup_timeout(self, timeout: Option<Duration>) -> Self {
        Self {
            setup_timeout: timeout,
            ..self
        }
    }

    /// Replaces the spawner used to create the worker process.
    pub fn with_spawner(self, spawner: Arc<dyn WorkerSpawner>) -> Self {
        Self { spawner, ..self }
    }
}
/// Everything `spawn_worker` returns once the worker has reported `Ready`.
pub struct OrchestratorReady {
/// Permit pool holding one permit (with its slot writer) per connected slot.
pub pool: Arc<PermitPool>,
/// Schema reported by the worker in its `Ready` message, if any.
pub schema: Option<serde_json::Value>,
/// Handle for talking to the running worker and its event loop.
pub handle: OrchestratorHandle,
/// Setup logs drained from the accumulator after setup completed.
pub setup_logs: String,
}
/// Message sent from `OrchestratorHandle::register_prediction` to the event loop.
struct RegisterPredictionMessage {
/// Slot the prediction runs on.
slot_id: SlotId,
/// Shared prediction state that the event loop updates from slot messages.
prediction: Arc<StdMutex<Prediction>>,
/// Receives the idle token when the worker reports the slot idle.
idle_sender: tokio::sync::oneshot::Sender<SlotIdleToken>,
/// Signalled by the event loop once it has processed this message.
registered_ack: tokio::sync::oneshot::Sender<()>,
}
/// Handle to a running worker: owns the child process and the channels used to
/// talk to the background event loop.
pub struct OrchestratorHandle {
/// The worker subprocess.
child: Child,
/// Writer for control requests on the worker's stdin; shared with the event loop.
ctrl_writer:
Arc<tokio::sync::Mutex<FramedWrite<tokio::process::ChildStdin, JsonCodec<ControlRequest>>>>,
/// Sends prediction registrations to the event loop.
register_tx: mpsc::Sender<RegisterPredictionMessage>,
/// Sends healthcheck requests (each carrying its reply channel) to the event loop.
healthcheck_tx: mpsc::Sender<tokio::sync::oneshot::Sender<HealthcheckResult>>,
/// Sends prediction IDs to cancel to the event loop.
cancel_tx: mpsc::Sender<String>,
/// Slot IDs reported by the worker in its `Ready` message.
slot_ids: Vec<SlotId>,
}
#[async_trait]
impl Orchestrator for OrchestratorHandle {
async fn register_prediction(
&self,
slot_id: SlotId,
prediction: Arc<StdMutex<Prediction>>,
idle_sender: tokio::sync::oneshot::Sender<SlotIdleToken>,
) {
let (ack_tx, ack_rx) = tokio::sync::oneshot::channel();
let _ = self
.register_tx
.send(RegisterPredictionMessage {
slot_id,
prediction,
idle_sender,
registered_ack: ack_tx,
})
.await;
let _ = ack_rx.await;
}
async fn cancel_by_prediction_id(&self, prediction_id: &str) -> Result<(), OrchestratorError> {
self.cancel_tx
.send(prediction_id.to_string())
.await
.map_err(|_| OrchestratorError::Protocol("cancel channel closed".to_string()))
}
async fn healthcheck(&self) -> Result<HealthcheckResult, OrchestratorError> {
tracing::trace!("Healthcheck requested via orchestrator handle");
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
self.healthcheck_tx
.send(response_tx)
.await
.map_err(|_| OrchestratorError::Protocol("healthcheck channel closed".to_string()))?;
match tokio::time::timeout(Duration::from_secs(10), response_rx).await {
Ok(Ok(result)) => {
tracing::trace!(healthy = result.is_healthy(), "Healthcheck completed");
Ok(result)
}
Ok(Err(_)) => {
tracing::debug!("Healthcheck response channel dropped");
Err(OrchestratorError::Protocol(
"healthcheck response channel dropped".to_string(),
))
}
Err(_) => {
tracing::debug!("Healthcheck timed out after 10s");
Ok(HealthcheckResult::unhealthy("healthcheck timed out"))
}
}
}
async fn shutdown(&self) -> Result<(), OrchestratorError> {
let mut writer = self.ctrl_writer.lock().await;
writer
.send(ControlRequest::Shutdown)
.await
.map_err(|e| OrchestratorError::Protocol(format!("failed to send shutdown: {}", e)))
}
}
impl OrchestratorHandle {
    /// Sends a `Cancel` control request for `slot_id` directly to the worker.
    ///
    /// # Errors
    ///
    /// Returns `OrchestratorError::Protocol` when the request cannot be written.
    pub async fn cancel(&self, slot_id: SlotId) -> Result<(), OrchestratorError> {
        let request = ControlRequest::Cancel { slot: slot_id };
        let mut writer = self.ctrl_writer.lock().await;
        match writer.send(request).await {
            Ok(()) => Ok(()),
            Err(e) => Err(OrchestratorError::Protocol(format!(
                "failed to send cancel: {}",
                e
            ))),
        }
    }

    /// Slot IDs the worker reported when it became ready.
    pub fn slot_ids(&self) -> &[SlotId] {
        self.slot_ids.as_slice()
    }

    /// Waits for the worker process to exit; the exit status is discarded.
    ///
    /// # Errors
    ///
    /// Returns `OrchestratorError::Protocol` when waiting on the process fails.
    pub async fn wait(&mut self) -> Result<(), OrchestratorError> {
        match self.child.wait().await {
            Ok(_status) => Ok(()),
            Err(e) => Err(OrchestratorError::Protocol(format!(
                "failed to wait for worker: {}",
                e
            ))),
        }
    }
}
/// Errors produced while starting or talking to the worker.
#[derive(Debug, thiserror::Error)]
pub enum OrchestratorError {
/// The worker could not be started: transport creation, the spawner, stdio
/// capture or accepting the slot connections failed.
#[error("failed to spawn worker: {0}")]
Spawn(String),
/// The worker reported `Failed` or `Fatal` before becoming ready.
#[error("worker setup failed: {0}")]
Setup(String),
/// The worker did not report `Ready` within the configured setup timeout.
#[error("worker setup timed out")]
SetupTimeout,
/// A control message could not be sent or received, or an internal channel
/// to the event loop was closed.
#[error("protocol error: {0}")]
Protocol(String),
/// The control channel closed before the worker reported `Ready`.
#[error("worker crashed")]
WorkerCrashed,
}
pub async fn spawn_worker(
config: OrchestratorConfig,
setup_log_rx: &mut tokio::sync::mpsc::UnboundedReceiver<String>,
) -> Result<OrchestratorReady, OrchestratorError> {
let num_slots = config.num_slots;
tracing::info!(num_slots, "Creating slot transport");
let (mut transport, child_transport_info) = create_transport(num_slots)
.await
.map_err(|e| OrchestratorError::Spawn(format!("failed to create transport: {}", e)))?;
tracing::info!("Spawning worker subprocess");
let spawn_config = WorkerSpawnConfig { num_slots };
let mut child = config
.spawner
.spawn(&spawn_config)
.map_err(|e| OrchestratorError::Spawn(format!("spawner failed: {}", e)))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| OrchestratorError::Spawn("stdin not captured".to_string()))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| OrchestratorError::Spawn("stdout not captured".to_string()))?;
let mut ctrl_writer = FramedWrite::new(stdin, JsonCodec::<ControlRequest>::new());
let mut ctrl_reader = FramedRead::new(stdout, JsonCodec::<ControlResponse>::new());
tracing::debug!("Sending Init to worker");
ctrl_writer
.send(ControlRequest::Init {
predictor_ref: config.predictor_ref.clone(),
num_slots,
transport_info: child_transport_info,
is_train: config.is_train,
is_async: config.is_async,
})
.await
.map_err(|e| OrchestratorError::Protocol(format!("failed to send Init: {}", e)))?;
tracing::debug!("Waiting for worker to connect to slot sockets");
transport
.accept_connections(num_slots)
.await
.map_err(|e| OrchestratorError::Spawn(format!("failed to accept connections: {}", e)))?;
tracing::debug!("Waiting for Ready from worker");
let setup_fut = async {
loop {
match ctrl_reader.next().await {
Some(Ok(ControlResponse::Ready { slots, schema })) => {
return Ok((slots, schema));
}
Some(Ok(ControlResponse::Log { source, data })) => {
for line in data.lines() {
tracing::info!(target: "coglet::setup", source = ?source, "{}", line);
}
}
Some(Ok(ControlResponse::WorkerLog {
target,
level,
message,
})) => {
emit_worker_log(&target, &level, &message);
}
Some(Ok(ControlResponse::DroppedLogs {
count,
interval_millis,
})) => {
tracing::trace!(count, interval_millis, "Received DroppedLogs during setup");
let interval_secs = interval_millis as f64 / 1000.0;
tracing::warn!(
"Log production exceeds consumption rate during setup. {} logs dropped in last {:.1}s",
count,
interval_secs
);
}
Some(Ok(ControlResponse::Failed { slot, error })) => {
return Err(OrchestratorError::Setup(format!(
"worker setup failed (slot {}): {}",
slot, error
)));
}
Some(Ok(ControlResponse::Fatal { reason })) => {
return Err(OrchestratorError::Setup(format!(
"worker fatal: {}",
reason
)));
}
Some(Ok(other)) => {
tracing::warn!(?other, "Unexpected message during setup");
}
Some(Err(e)) => {
return Err(OrchestratorError::Protocol(format!(
"control channel error: {}",
e
)));
}
None => {
return Err(OrchestratorError::WorkerCrashed);
}
}
}
};
let (slot_ids, schema) = match config.setup_timeout {
Some(timeout) => {
tracing::debug!(
timeout_secs = timeout.as_secs(),
"Waiting for setup with timeout"
);
match tokio::time::timeout(timeout, setup_fut).await {
Ok(Ok((slots, schema))) => {
tracing::debug!(num_slots = slots.len(), "Setup completed within timeout");
(slots, schema)
}
Ok(Err(e)) => {
tracing::debug!(error = %e, "Setup failed");
return Err(e);
}
Err(_) => {
tracing::debug!(timeout_secs = timeout.as_secs(), "Setup timed out");
return Err(OrchestratorError::SetupTimeout);
}
}
}
None => {
tracing::debug!("Waiting for setup with no timeout");
setup_fut.await?
}
};
let setup_logs = crate::setup_log_accumulator::drain_accumulated_logs(setup_log_rx);
tracing::debug!(
setup_logs_len = setup_logs.len(),
"Drained accumulated setup logs"
);
tracing::debug!(num_slots = slot_ids.len(), "Worker ready");
if let Some(ref s) = schema
&& let Ok(json) = serde_json::to_string_pretty(s)
{
tracing::trace!(target: "coglet::schema", schema = %json, "OpenAPI schema");
}
let output_is_array = schema
.as_ref()
.and_then(|s| s.get("components"))
.and_then(|c| c.get("schemas"))
.and_then(|schemas| {
let key = if config.is_train {
"TrainingOutput"
} else {
"Output"
};
schemas.get(key)
})
.and_then(|output| output.get("type"))
.and_then(|t| t.as_str())
.is_some_and(|t| t == "array");
let pool = Arc::new(PermitPool::new(num_slots));
let sockets = transport.drain_sockets();
let mut slot_readers = Vec::with_capacity(num_slots);
for (slot_id, socket) in slot_ids.iter().zip(sockets) {
let (read_half, write_half) = socket.into_split();
let writer = FramedWrite::new(write_half, JsonCodec::<SlotRequest>::new());
pool.add_permit(*slot_id, writer);
let reader = FramedRead::new(read_half, JsonCodec::<SlotResponse>::new());
slot_readers.push((*slot_id, reader));
}
let (register_tx, register_rx) = mpsc::channel(num_slots);
let (healthcheck_tx, healthcheck_rx) = mpsc::channel(1);
let (cancel_tx, cancel_rx) = mpsc::channel(16);
let ctrl_writer = Arc::new(tokio::sync::Mutex::new(ctrl_writer));
let handle = OrchestratorHandle {
child,
ctrl_writer: Arc::clone(&ctrl_writer),
register_tx,
healthcheck_tx,
cancel_tx,
slot_ids: slot_ids.clone(),
};
let pool_for_loop = Arc::clone(&pool);
let ctrl_writer_for_loop = Arc::clone(&ctrl_writer);
let upload_url = config.upload_url.clone();
tokio::spawn(async move {
run_event_loop(
ctrl_reader,
ctrl_writer_for_loop,
slot_readers,
register_rx,
healthcheck_rx,
cancel_rx,
pool_for_loop,
upload_url,
output_is_array,
)
.await;
});
Ok(OrchestratorReady {
pool,
schema,
handle,
setup_logs,
})
}
/// Remembers `prediction_id` as cancelled-before-registration.
///
/// The set is capped at `MAX_PENDING_CANCELLATIONS` entries: when it is already
/// full the ID is not stored and a warning is logged instead.
fn record_pending_cancellation(pending_cancellations: &mut HashSet<String>, prediction_id: String) {
    let has_room = pending_cancellations.len() < MAX_PENDING_CANCELLATIONS;
    if has_room {
        pending_cancellations.insert(prediction_id);
    } else {
        tracing::warn!(
            prediction_id = %prediction_id,
            cap = MAX_PENDING_CANCELLATIONS,
            "Dropping pending cancellation because the pending cancellation buffer is full"
        );
    }
}
/// Drives all communication with a running worker until the control channel
/// ends.
///
/// One task per slot forwards that slot's responses into a shared channel; the
/// main loop then multiplexes, with `biased` priority in this order:
///
/// 1. control messages from the worker (idle/failed/fatal notifications, logs,
///    healthcheck results, shutdown);
/// 2. healthcheck requests from the handle (coalesced while one is in flight);
/// 3. cancellation requests by prediction ID (remembered as pending when the
///    prediction is not registered yet);
/// 4. prediction registrations;
/// 5. slot responses (logs, metrics, output chunks, file outputs and the
///    terminal `Done` / `Failed` / `Cancelled` messages).
///
/// File outputs of kind `FileType` are uploaded to `upload_url` in background
/// tasks when it is set, and inlined as base64 `data:` URLs otherwise. A `Done`
/// message waits for the slot's outstanding uploads before the prediction is
/// marked succeeded.
///
/// The loop exits on `Fatal`, `ShuttingDown`, a control channel error, or when
/// the control channel closes. On `Fatal`, a control channel error and channel
/// closure, every tracked prediction that is not yet terminal is failed and all
/// pending healthchecks are answered as unhealthy, because no further slot
/// messages will be processed after the loop has exited.
#[allow(clippy::too_many_arguments)]
async fn run_event_loop(
    mut ctrl_reader: FramedRead<tokio::process::ChildStdout, JsonCodec<ControlResponse>>,
    ctrl_writer: Arc<
        tokio::sync::Mutex<FramedWrite<tokio::process::ChildStdin, JsonCodec<ControlRequest>>>,
    >,
    slot_readers: Vec<(
        SlotId,
        FramedRead<tokio::net::unix::OwnedReadHalf, JsonCodec<SlotResponse>>,
    )>,
    mut register_rx: mpsc::Receiver<RegisterPredictionMessage>,
    mut healthcheck_rx: mpsc::Receiver<tokio::sync::oneshot::Sender<HealthcheckResult>>,
    mut cancel_rx: mpsc::Receiver<String>,
    pool: Arc<PermitPool>,
    upload_url: Option<String>,
    output_is_array: bool,
) {
    // Per-slot state of the prediction currently running on that slot.
    let mut predictions: HashMap<SlotId, Arc<StdMutex<Prediction>>> = HashMap::new();
    let mut idle_senders: HashMap<SlotId, tokio::sync::oneshot::Sender<SlotIdleToken>> =
        HashMap::new();
    let mut pending_healthchecks: Vec<tokio::sync::oneshot::Sender<HealthcheckResult>> = Vec::new();
    let mut healthcheck_counter: u64 = 0;
    let mut pending_uploads: HashMap<SlotId, Vec<tokio::task::JoinHandle<()>>> = HashMap::new();
    let mut pending_cancellations: HashSet<String> = HashSet::new();

    // Fan all slot sockets into one channel so the select below has a single
    // source for slot messages.
    let (slot_msg_tx, mut slot_msg_rx) =
        mpsc::channel::<(SlotId, Result<SlotResponse, std::io::Error>)>(100);
    for (slot_id, mut reader) in slot_readers {
        let tx = slot_msg_tx.clone();
        tokio::spawn(async move {
            loop {
                let msg = reader.next().await;
                match msg {
                    Some(Ok(response)) => {
                        if tx.send((slot_id, Ok(response))).await.is_err() {
                            break;
                        }
                    }
                    Some(Err(e)) => {
                        let _ = tx.send((slot_id, Err(e))).await;
                        break;
                    }
                    None => {
                        break;
                    }
                }
            }
            tracing::debug!(%slot_id, "Slot reader task exiting");
        });
    }
    // Only the reader tasks keep senders, so the channel closes once they exit.
    drop(slot_msg_tx);

    loop {
        tokio::select! {
            biased;
            ctrl_msg = ctrl_reader.next() => {
                match ctrl_msg {
                    Some(Ok(ControlResponse::Idle { slot })) => {
                        tracing::debug!(%slot, "Slot idle notification received (control channel)");
                        match idle_senders.remove(&slot) {
                            Some(sender) => {
                                let token = InactiveSlotIdleToken::new(slot);
                                if sender.send(token.activate()).is_err() {
                                    tracing::warn!(%slot, "Idle token receiver dropped before idle confirmation");
                                }
                            }
                            None => {
                                tracing::warn!(%slot, "Received Idle for slot with no pending idle confirmation");
                            }
                        }
                    }
                    Some(Ok(ControlResponse::Cancelled { slot })) => {
                        tracing::debug!(%slot, "Slot cancelled (control channel)");
                    }
                    Some(Ok(ControlResponse::Failed { slot, error })) => {
                        tracing::warn!(%slot, %error, "Slot poisoned");
                        pool.poison(slot);
                        if let Some(pred) = predictions.remove(&slot)
                            && let Some(mut p) = try_lock_prediction(&pred)
                            && !p.is_terminal()
                        {
                            p.set_failed(error);
                        }
                    }
                    Some(Ok(ControlResponse::Fatal { reason })) => {
                        tracing::error!(%reason, "Worker fatal");
                        for (slot, pred) in predictions.drain() {
                            tracing::warn!(%slot, "Failing prediction due to worker fatal error");
                            pool.poison(slot);
                            if let Some(mut p) = try_lock_prediction(&pred)
                                && !p.is_terminal()
                            {
                                p.set_failed(reason.clone());
                            }
                        }
                        let result = HealthcheckResult::unhealthy(&reason);
                        for tx in pending_healthchecks.drain(..) {
                            let _ = tx.send(result.clone());
                        }
                        break;
                    }
                    Some(Ok(ControlResponse::Ready { .. })) => {
                        tracing::warn!("Unexpected Ready in event loop");
                    }
                    Some(Ok(ControlResponse::Log { source: _, data })) => {
                        for line in data.lines() {
                            tracing::info!(target: "coglet::user", "{}", line);
                        }
                    }
                    Some(Ok(ControlResponse::WorkerLog { target, level, message })) => {
                        emit_worker_log(&target, &level, &message);
                    }
                    Some(Ok(ControlResponse::DroppedLogs { count, interval_millis })) => {
                        tracing::trace!(count, interval_millis, "Received DroppedLogs message");
                        let interval_secs = interval_millis as f64 / 1000.0;
                        tracing::warn!(
                            "Log production exceeds consumption rate. {} logs dropped in last {:.1}s",
                            count, interval_secs
                        );
                    }
                    Some(Ok(ControlResponse::HealthcheckResult { id: _, status, error })) => {
                        tracing::trace!(
                            ?status,
                            ?error,
                            pending_count = pending_healthchecks.len(),
                            "Received healthcheck result from worker"
                        );
                        if pending_healthchecks.is_empty() {
                            tracing::warn!("Received healthcheck result but no pending requests");
                        } else {
                            let result = match status {
                                HealthcheckStatus::Healthy => HealthcheckResult::healthy(),
                                HealthcheckStatus::Unhealthy => {
                                    HealthcheckResult::unhealthy(error.unwrap_or_else(|| "unhealthy".to_string()))
                                }
                            };
                            tracing::trace!(
                                pending_count = pending_healthchecks.len(),
                                "Distributing healthcheck result to pending callers"
                            );
                            // One worker answer satisfies every coalesced caller.
                            for tx in pending_healthchecks.drain(..) {
                                let _ = tx.send(result.clone());
                            }
                        }
                    }
                    Some(Ok(ControlResponse::ShuttingDown)) => {
                        tracing::info!("Worker shutting down");
                        break;
                    }
                    Some(Err(e)) => {
                        tracing::error!(error = %e, "Control channel error");
                        // The loop stops here, so no slot message will ever
                        // complete the in-flight predictions: fail them now, as
                        // is done when the channel closes.
                        let reason = format!("Control channel error: {}", e);
                        for (slot, pred) in predictions.drain() {
                            tracing::warn!(%slot, "Failing prediction due to control channel error");
                            if let Some(mut p) = try_lock_prediction(&pred)
                                && !p.is_terminal()
                            {
                                p.set_failed(reason.clone());
                            }
                        }
                        let result = HealthcheckResult::unhealthy(&reason);
                        for tx in pending_healthchecks.drain(..) {
                            let _ = tx.send(result.clone());
                        }
                        break;
                    }
                    None => {
                        tracing::warn!("Control channel closed (worker crashed?)");
                        for (slot, pred) in predictions.drain() {
                            tracing::warn!(%slot, "Failing prediction due to worker crash");
                            // Never overwrite a state that is already terminal.
                            if let Some(mut p) = try_lock_prediction(&pred)
                                && !p.is_terminal()
                            {
                                p.set_failed("Worker crashed".to_string());
                            }
                        }
                        for tx in pending_healthchecks.drain(..) {
                            let _ = tx.send(HealthcheckResult::unhealthy("Worker crashed"));
                        }
                        break;
                    }
                }
            }
            Some(response_tx) = healthcheck_rx.recv() => {
                // Only the first caller triggers a request to the worker; callers
                // arriving while it is outstanding share its result.
                let in_flight = !pending_healthchecks.is_empty();
                pending_healthchecks.push(response_tx);
                if !in_flight {
                    healthcheck_counter += 1;
                    let hc_id = format!("hc_{}", healthcheck_counter);
                    tracing::trace!(%hc_id, "Sending healthcheck request to worker");
                    let mut writer = ctrl_writer.lock().await;
                    if let Err(e) = writer.send(ControlRequest::Healthcheck { id: hc_id }).await {
                        tracing::error!(error = %e, "Failed to send healthcheck request");
                        let result = HealthcheckResult::unhealthy(format!("Failed to send: {}", e));
                        for tx in pending_healthchecks.drain(..) {
                            let _ = tx.send(result.clone());
                        }
                    }
                } else {
                    tracing::trace!(
                        pending_count = pending_healthchecks.len(),
                        "Healthcheck already in-flight, coalescing request"
                    );
                }
            }
            Some(prediction_id) = cancel_rx.recv() => {
                // Find the slot whose registered prediction has this ID.
                let slot = predictions.iter().find_map(|(sid, pred)| {
                    try_lock_prediction(pred)
                        .filter(|p| p.id() == prediction_id)
                        .map(|_| *sid)
                });
                match slot {
                    Some(slot_id) => {
                        tracing::info!(
                            target: "coglet::prediction",
                            %prediction_id,
                            %slot_id,
                            "Cancelling prediction"
                        );
                        let mut writer = ctrl_writer.lock().await;
                        if let Err(e) = writer.send(ControlRequest::Cancel { slot: slot_id }).await {
                            tracing::error!(
                                %slot_id,
                                error = %e,
                                "Failed to send cancel request to worker"
                            );
                        }
                        if let Some(handles) = pending_uploads.remove(&slot_id) {
                            for h in handles { h.abort(); }
                        }
                    }
                    None => {
                        tracing::debug!(%prediction_id, "Cancel requested for unknown prediction; storing pending cancellation");
                        record_pending_cancellation(&mut pending_cancellations, prediction_id);
                    }
                }
            }
            Some(RegisterPredictionMessage { slot_id, prediction, idle_sender, registered_ack }) = register_rx.recv() => {
                let prediction_id = match try_lock_prediction(&prediction) {
                    Some(p) => p.id().to_string(),
                    None => {
                        tracing::error!(%slot_id, "Prediction mutex poisoned during registration");
                        let _ = registered_ack.send(());
                        continue;
                    }
                };
                idle_senders.insert(slot_id, idle_sender);
                tracing::info!(
                    target: "coglet::prediction",
                    %prediction_id,
                    "Starting prediction"
                );
                tracing::debug!(%slot_id, %prediction_id, "Registered prediction");
                predictions.insert(slot_id, prediction);
                // A cancel that arrived before registration is applied now.
                let pending_cancel = pending_cancellations.remove(&prediction_id);
                let _ = registered_ack.send(());
                if pending_cancel {
                    tracing::info!(
                        target: "coglet::prediction",
                        %prediction_id,
                        %slot_id,
                        "Applying pending cancellation"
                    );
                    let mut writer = ctrl_writer.lock().await;
                    if let Err(e) = writer.send(ControlRequest::Cancel { slot: slot_id }).await {
                        tracing::error!(
                            %slot_id,
                            error = %e,
                            "Failed to send pending cancel request to worker"
                        );
                    }
                }
            }
            Some((slot_id, result)) = slot_msg_rx.recv() => {
                match result {
                    Ok(SlotResponse::ProtocolVersion { version }) => {
                        if version != crate::bridge::protocol::SLOT_RESPONSE_PROTOCOL_VERSION {
                            tracing::warn!(
                                %slot_id,
                                version,
                                expected = crate::bridge::protocol::SLOT_RESPONSE_PROTOCOL_VERSION,
                                "Worker reported unexpected slot response protocol version"
                            );
                        }
                    }
                    Ok(SlotResponse::LogLine { source, data }) => {
                        let (prediction_id, poisoned) = if let Some(pred) = predictions.get(&slot_id) {
                            if let Some(mut p) = try_lock_prediction(pred) {
                                p.append_log_source(source, &data);
                                (Some(p.id().to_string()), false)
                            } else {
                                (None, true)
                            }
                        } else {
                            (None, false)
                        };
                        // A poisoned prediction has already been failed; stop tracking it.
                        if poisoned {
                            predictions.remove(&slot_id);
                        }
                        let trimmed = data.trim();
                        if !trimmed.is_empty() {
                            if let Some(id) = prediction_id {
                                tracing::info!(
                                    target: "coglet::prediction",
                                    prediction_id = %id,
                                    source = ?source,
                                    "{}",
                                    trimmed
                                );
                            } else {
                                tracing::warn!(
                                    target: "coglet::prediction",
                                    prediction_id = "NO_ACTIVE_PREDICTION",
                                    source = ?source,
                                    "{}",
                                    trimmed
                                );
                            }
                        }
                    }
                    Ok(SlotResponse::Metric { name, value, mode }) => {
                        let poisoned = if let Some(pred) = predictions.get(&slot_id) {
                            if let Some(mut p) = try_lock_prediction(pred) {
                                p.set_metric(name, value, mode);
                                false
                            } else {
                                true
                            }
                        } else {
                            false
                        };
                        if poisoned {
                            predictions.remove(&slot_id);
                        }
                    }
                    Ok(SlotResponse::OutputChunk { output, index }) => {
                        let poisoned = if let Some(pred) = predictions.get(&slot_id) {
                            if let Some(mut p) = try_lock_prediction(pred) {
                                p.append_output_chunk(output, index);
                                false
                            } else {
                                true
                            }
                        } else {
                            false
                        };
                        if poisoned {
                            predictions.remove(&slot_id);
                        }
                    }
                    Ok(SlotResponse::FileOutput { filename, kind, mime_type }) => {
                        tracing::debug!(%slot_id, %filename, ?kind, "FileOutput received");
                        // Read the file on the blocking pool so a large output
                        // does not stall the async runtime thread.
                        let read_path = filename.clone();
                        let read_result =
                            tokio::task::spawn_blocking(move || std::fs::read(read_path)).await;
                        let bytes = match read_result {
                            Ok(Ok(b)) => b,
                            Ok(Err(e)) => {
                                tracing::error!(%slot_id, %filename, error = %e, "Failed to read FileOutput");
                                continue;
                            }
                            Err(e) => {
                                tracing::error!(%slot_id, %filename, error = %e, "Failed to read FileOutput");
                                continue;
                            }
                        };
                        match kind {
                            FileOutputKind::Oversized => {
                                // The file holds a JSON value.
                                let output: serde_json::Value = match serde_json::from_slice(&bytes) {
                                    Ok(val) => val,
                                    Err(e) => {
                                        tracing::error!(%slot_id, %filename, error = %e, "Failed to parse oversized JSON");
                                        continue;
                                    }
                                };
                                let poisoned = if let Some(pred) = predictions.get(&slot_id) {
                                    if let Some(mut p) = try_lock_prediction(pred) {
                                        p.append_output(output);
                                        false
                                    } else {
                                        true
                                    }
                                } else {
                                    false
                                };
                                if poisoned {
                                    predictions.remove(&slot_id);
                                }
                            }
                            FileOutputKind::FileType => {
                                let mime = mime_type.unwrap_or_else(|| {
                                    mime_guess::from_path(&filename)
                                        .first_or_octet_stream()
                                        .to_string()
                                });
                                if let Some(ref url) = upload_url {
                                    // Upload in the background; `Done` waits for
                                    // the handles collected in `pending_uploads`.
                                    let pred = predictions.get(&slot_id).cloned();
                                    let endpoint = ensure_trailing_slash(url);
                                    let basename = std::path::Path::new(&filename)
                                        .file_name()
                                        .and_then(|n| n.to_str())
                                        .unwrap_or("output")
                                        .to_string();
                                    let handle = tokio::spawn(async move {
                                        match upload_file(&endpoint, &basename, &bytes, &mime).await {
                                            Ok(url) => {
                                                if let Some(pred) = pred
                                                    && let Some(mut p) = try_lock_prediction(&pred)
                                                {
                                                    p.append_output(serde_json::Value::String(url));
                                                }
                                            }
                                            Err(e) => {
                                                tracing::error!(error = %e, "Failed to upload file output");
                                            }
                                        }
                                    });
                                    pending_uploads.entry(slot_id).or_default().push(handle);
                                } else {
                                    use base64::Engine;
                                    let encoded = base64::engine::general_purpose::STANDARD
                                        .encode(&bytes);
                                    let output = serde_json::Value::String(format!(
                                        "data:{mime};base64,{encoded}"
                                    ));
                                    let poisoned = if let Some(pred) = predictions.get(&slot_id) {
                                        if let Some(mut p) = try_lock_prediction(pred) {
                                            p.append_output(output);
                                            false
                                        } else {
                                            true
                                        }
                                    } else {
                                        false
                                    };
                                    if poisoned {
                                        predictions.remove(&slot_id);
                                    }
                                }
                            }
                        }
                    }
                    Ok(SlotResponse::Done { id, output: _, predict_time, is_stream }) => {
                        tracing::info!(
                            target: "coglet::prediction",
                            prediction_id = %id,
                            predict_time,
                            is_stream,
                            output_is_array,
                            "Prediction succeeded"
                        );
                        let uploads = pending_uploads.remove(&slot_id).unwrap_or_default();
                        if let Some(pred) = predictions.remove(&slot_id) {
                            if uploads.is_empty() {
                                if let Some(mut p) = try_lock_prediction(&pred) {
                                    let pred_output = wrap_outputs(
                                        p.take_outputs(),
                                        output_is_array,
                                        is_stream,
                                    );
                                    p.set_succeeded(pred_output);
                                }
                            } else {
                                // Finish off the loop: wait for the uploads (or a
                                // cancellation) in a separate task.
                                let (cancel_token, upload_pred_id) = match try_lock_prediction(&pred) {
                                    Some(p) => (Some(p.cancel_token()), p.id().to_string()),
                                    None => (None, id.clone()),
                                };
                                tokio::spawn(async move {
                                    if let Some(token) = cancel_token {
                                        let upload_fut = futures::future::join_all(uploads);
                                        tokio::pin!(upload_fut);
                                        tokio::select! {
                                            _ = &mut upload_fut => {}
                                            _ = token.cancelled() => {
                                                tracing::info!(
                                                    target: "coglet::prediction",
                                                    prediction_id = %upload_pred_id,
                                                    "Aborting in-flight uploads due to cancellation"
                                                );
                                                if let Some(mut p) = try_lock_prediction(&pred) {
                                                    p.set_canceled();
                                                }
                                                return;
                                            }
                                        }
                                    } else {
                                        for h in uploads {
                                            let _ = h.await;
                                        }
                                    }
                                    if let Some(mut p) = try_lock_prediction(&pred) {
                                        let pred_output = wrap_outputs(
                                            p.take_outputs(),
                                            output_is_array,
                                            is_stream,
                                        );
                                        p.set_succeeded(pred_output);
                                    }
                                });
                            }
                        } else {
                            tracing::warn!(%slot_id, %id, "Prediction not found for Done message");
                        }
                    }
                    Ok(SlotResponse::Failed { id, error }) => {
                        tracing::info!(
                            target: "coglet::prediction",
                            prediction_id = %id,
                            %error,
                            "Prediction failed"
                        );
                        if let Some(handles) = pending_uploads.remove(&slot_id) {
                            for h in handles { h.abort(); }
                        }
                        if let Some(pred) = predictions.remove(&slot_id)
                            && let Some(mut p) = try_lock_prediction(&pred)
                        {
                            p.set_failed(error);
                        }
                    }
                    Ok(SlotResponse::Cancelled { id }) => {
                        tracing::info!(
                            target: "coglet::prediction",
                            prediction_id = %id,
                            "Prediction cancelled"
                        );
                        if let Some(handles) = pending_uploads.remove(&slot_id) {
                            for h in handles { h.abort(); }
                        }
                        if let Some(pred) = predictions.remove(&slot_id)
                            && let Some(mut p) = try_lock_prediction(&pred)
                        {
                            p.set_canceled();
                        }
                    }
                    Err(e) => {
                        tracing::error!(%slot_id, error = %e, "Slot socket error");
                        if let Some(handles) = pending_uploads.remove(&slot_id) {
                            for h in handles { h.abort(); }
                        }
                        if let Some(pred) = predictions.remove(&slot_id)
                            && let Some(mut p) = try_lock_prediction(&pred)
                        {
                            p.set_failed(format!("Slot socket error: {}", e));
                        }
                    }
                }
            }
        }
    }
    tracing::info!("Event loop exiting");
}
// Unit tests for the pure helpers in this file (`wrap_outputs`,
// `record_pending_cancellation`) and for the wire format of
// `SlotResponse::Done`'s `is_stream` flag.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// --- wrap_outputs: schema declares an array output ---
#[test]
fn wrap_outputs_schema_array_empty() {
let result = wrap_outputs(vec![], true, true);
assert!(result.is_stream());
assert_eq!(result.into_values(), Vec::<serde_json::Value>::new());
}
// Filling the set to the cap and adding one more must not grow it.
#[test]
fn record_pending_cancellation_caps_stored_ids() {
let mut pending = HashSet::new();
for index in 0..MAX_PENDING_CANCELLATIONS {
record_pending_cancellation(&mut pending, format!("pred-{index}"));
}
record_pending_cancellation(&mut pending, "overflow".to_string());
assert_eq!(pending.len(), MAX_PENDING_CANCELLATIONS);
assert!(!pending.contains("overflow"));
}
#[test]
fn wrap_outputs_schema_array_single_item() {
let result = wrap_outputs(vec![json!("https://example.com/img.png")], true, true);
assert!(result.is_stream());
assert_eq!(
result.into_values(),
vec![json!("https://example.com/img.png")]
);
}
#[test]
fn wrap_outputs_schema_array_multiple_items() {
let items = vec![
json!("https://example.com/1.png"),
json!("https://example.com/2.png"),
json!("https://example.com/3.png"),
json!("https://example.com/4.png"),
];
let result = wrap_outputs(items.clone(), true, true);
assert!(result.is_stream());
assert_eq!(result.into_values(), items);
}
// An array schema yields a stream even when the predictor did not stream.
#[test]
fn wrap_outputs_schema_array_overrides_is_stream_false() {
let result = wrap_outputs(vec![json!("url")], true, false);
assert!(result.is_stream());
}
// --- wrap_outputs: predictor reported a stream, schema is not an array ---
#[test]
fn wrap_outputs_predictor_stream_empty() {
let result = wrap_outputs(vec![], false, true);
assert!(result.is_stream());
assert_eq!(result.into_values(), Vec::<serde_json::Value>::new());
}
#[test]
fn wrap_outputs_predictor_stream_single_item() {
let result = wrap_outputs(vec![json!("only_item")], false, true);
assert!(result.is_stream());
assert_eq!(result.into_values(), vec![json!("only_item")]);
}
#[test]
fn wrap_outputs_predictor_stream_multiple_items() {
let items = vec![json!("a"), json!("b"), json!("c")];
let result = wrap_outputs(items.clone(), false, true);
assert!(result.is_stream());
assert_eq!(result.into_values(), items);
}
// --- wrap_outputs: scalar output (neither array schema nor stream) ---
#[test]
fn wrap_outputs_scalar_empty() {
let result = wrap_outputs(vec![], false, false);
assert!(!result.is_stream());
assert_eq!(result.final_value(), &json!(null));
}
#[test]
fn wrap_outputs_scalar_single() {
let result = wrap_outputs(vec![json!("https://example.com/output.png")], false, false);
assert!(!result.is_stream());
assert_eq!(
result.final_value(),
&json!("https://example.com/output.png")
);
}
// More than one value cannot be a scalar, so it is reported as a stream.
#[test]
fn wrap_outputs_scalar_multiple_falls_back_to_stream() {
let items = vec![json!("a"), json!("b")];
let result = wrap_outputs(items.clone(), false, false);
assert!(result.is_stream());
assert_eq!(result.into_values(), items);
}
// --- SlotResponse::Done serialization of `is_stream` ---
#[test]
fn done_is_stream_false_omitted_from_json() {
let msg = SlotResponse::Done {
id: "p1".into(),
output: None,
predict_time: 1.0,
is_stream: false,
};
let json = serde_json::to_value(&msg).unwrap();
assert!(
json.get("is_stream").is_none(),
"is_stream=false should be omitted"
);
}
#[test]
fn done_is_stream_true_present_in_json() {
let msg = SlotResponse::Done {
id: "p1".into(),
output: None,
predict_time: 1.0,
is_stream: true,
};
let json = serde_json::to_value(&msg).unwrap();
assert_eq!(json.get("is_stream"), Some(&json!(true)));
}
// A message without the field must still deserialize, defaulting to false.
#[test]
fn done_without_is_stream_deserializes_as_false() {
let json = json!({
"type": "done",
"id": "p1",
"predict_time": 1.0
});
let msg: SlotResponse = serde_json::from_value(json).unwrap();
match msg {
SlotResponse::Done { is_stream, .. } => assert!(!is_stream),
_ => panic!("wrong variant"),
}
}
}