use std::collections::HashMap;
use std::io;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use futures::{SinkExt, StreamExt};
use tokio::runtime::Handle;
use tokio::sync::mpsc;
use tokio_util::codec::{FramedRead, FramedWrite};
use crate::bridge::protocol::truncate_worker_log;
/// Running tally of dropped log messages.
///
/// Bumped by [`increment_dropped_log_count`] and drained back to zero by
/// `report_dropped_logs`.
static DROPPED_SETUP_LOG_COUNT: AtomicUsize = AtomicUsize::new(0);

/// Records that one more log message was dropped.
pub fn increment_dropped_log_count() {
    // Only the total matters, so relaxed ordering is sufficient for this counter.
    let _previous = DROPPED_SETUP_LOG_COUNT.fetch_add(1, Ordering::Relaxed);
}
/// Drains the dropped-log counter and reports it on `tx` as `ControlResponse::DroppedLogs`.
///
/// * `tx` - channel the report is queued on (non-blocking `try_send`).
/// * `interval_millis` - reporting interval, forwarded verbatim in the message.
///
/// Nothing is sent when no logs were dropped. If the report itself cannot be queued
/// (channel full or closed) the drained count is added back to the counter so it is
/// included in the next report instead of being silently lost.
fn report_dropped_logs(tx: &mpsc::Sender<ControlResponse>, interval_millis: u64) {
    let dropped = DROPPED_SETUP_LOG_COUNT.swap(0, Ordering::Relaxed);
    if dropped == 0 {
        return;
    }

    let report = ControlResponse::DroppedLogs {
        count: dropped,
        interval_millis,
    };
    if tx.try_send(report).is_err() {
        // The channel being full is the most likely time for logs to be dropped, so a
        // failed report must not erase the tally. The next successful report will then
        // cover more than one interval.
        DROPPED_SETUP_LOG_COUNT.fetch_add(dropped, Ordering::Relaxed);
    }
}
/// State the panic hook needs in order to report a fatal error.
struct FatalContext {
    /// Channel on which `ControlResponse::Fatal` is queued by the panic hook.
    tx: mpsc::Sender<ControlResponse>,
}

/// Process-wide fatal context; it can be populated at most once.
static FATAL_CONTEXT: OnceLock<FatalContext> = OnceLock::new();

/// Installs `tx` as the channel used for fatal reports.
///
/// Only the first call has any effect: if a context is already stored, the new one is
/// discarded.
fn init_fatal_context(tx: mpsc::Sender<ControlResponse>) {
    let context = FatalContext { tx };
    // `set` hands the value back when the cell is already filled; just drop it.
    drop(FATAL_CONTEXT.set(context));
}
/// Replaces the process panic hook with one that reports the panic and then aborts.
///
/// The new hook first delegates to the previously installed hook, then builds a
/// human-readable reason from the panic payload and location, queues it as
/// `ControlResponse::Fatal` on the fatal context (if one was installed), and finally
/// calls `std::process::abort()`.
///
/// NOTE(review): the report is only queued with `try_send` before the abort; whether it
/// reaches the parent depends on something else draining the channel in time — confirm.
fn install_panic_hook() {
    let previous_hook = std::panic::take_hook();
    std::panic::set_hook(Box::new(move |info| {
        // Keep whatever reporting the earlier hook provided.
        previous_hook(info);

        // Panic payloads are normally `&str` or `String`; anything else is opaque.
        let payload = info.payload();
        let msg = payload
            .downcast_ref::<&str>()
            .map(|text| (*text).to_string())
            .or_else(|| payload.downcast_ref::<String>().cloned())
            .unwrap_or_else(|| "<unknown>".to_string());

        let reason = if let Some(loc) = info.location() {
            format!("panic at {}:{}: {}", loc.file(), loc.line(), msg)
        } else {
            format!("panic: {}", msg)
        };

        if let Some(ctx) = FATAL_CONTEXT.get() {
            // Best effort: a full or closed channel must not prevent the abort.
            let _ = ctx.tx.try_send(ControlResponse::Fatal { reason });
        }
        std::process::abort();
    }));
}
/// Installs the global tracing subscriber for the worker process.
///
/// Filter selection:
/// * if `RUST_LOG` is set, the filter comes from the environment;
/// * otherwise a filter is built from `COG_LOG_LEVEL` (`debug`, `warn`/`warning`,
///   `error`; anything else, including unset, means `info`).
///
/// Events that pass the filter go to a `WorkerTracingLayer` constructed from `tx`.
/// A failure of `try_init` is ignored.
fn init_worker_tracing(tx: mpsc::Sender<ControlResponse>) {
    use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};

    let filter = match std::env::var("RUST_LOG") {
        Ok(_) => EnvFilter::from_default_env(),
        Err(_) => {
            let requested = std::env::var("COG_LOG_LEVEL");
            let level = match requested.as_deref() {
                Ok("debug") => "debug",
                Ok("warn" | "warning") => "warn",
                Ok("error") => "error",
                _ => "info",
            };
            EnvFilter::new(format!(
                "coglet={level},coglet::setup=info,coglet::user=info,coglet_worker={level},coglet_worker::schema=off,coglet_worker::protocol=off"
            ))
        }
    };

    let worker_layer = WorkerTracingLayer::new(tx);
    let registry = tracing_subscriber::registry();
    let _ = registry.with(filter).with(worker_layer).try_init();
}
use crate::bridge::codec::JsonCodec;
use crate::bridge::protocol::{
ControlRequest, ControlResponse, FileOutputKind, LogSource, MAX_INLINE_IPC_SIZE, MetricMode,
SLOT_RESPONSE_PROTOCOL_VERSION, SlotId, SlotOutcome, SlotRequest, SlotResponse,
};
use crate::bridge::transport::{ChildTransportInfo, connect_transport};
use crate::orchestrator::HealthcheckResult;
use crate::worker_tracing_layer::WorkerTracingLayer;
/// Shared, mutex-guarded framed writer that encodes `SlotResponse` messages onto the
/// write half of a slot's Unix socket.
type SlotWriter =
Arc<tokio::sync::Mutex<FramedWrite<tokio::net::unix::OwnedWriteHalf, JsonCodec<SlotResponse>>>>;
/// Cloneable handle used to emit `SlotResponse` messages (logs, metrics, outputs, file
/// outputs) for a prediction.
///
/// Clones share the same channel and the same counters.
#[derive(Clone)]
pub struct SlotSender {
/// Unbounded channel the produced `SlotResponse` messages are pushed onto.
tx: mpsc::UnboundedSender<SlotResponse>,
/// Directory in which output files and oversized ("spilled") outputs are written.
output_dir: PathBuf,
/// Counter that numbers the files created by `next_output_path`.
file_counter: Arc<AtomicUsize>,
/// Counter that supplies the `index` of each chunk sent by `send_output`.
output_counter: Arc<AtomicU64>,
}
impl SlotSender {
    /// Builds a sender that pushes onto `tx` and writes files under `output_dir`.
    ///
    /// Both the file counter and the output counter start at zero.
    pub fn new(tx: mpsc::UnboundedSender<SlotResponse>, output_dir: PathBuf) -> Self {
        let file_counter = Arc::new(AtomicUsize::new(0));
        let output_counter = Arc::new(AtomicU64::new(0));
        Self {
            tx,
            output_dir,
            file_counter,
            output_counter,
        }
    }

    /// Pushes `msg` onto the slot channel, mapping a closed channel to `BrokenPipe`.
    fn enqueue_response(&self, msg: SlotResponse) -> io::Result<()> {
        match self.tx.send(msg) {
            Ok(()) => Ok(()),
            Err(_) => Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "slot channel closed",
            )),
        }
    }

    /// Returns the next output index (post-increment of the shared counter).
    fn next_output_index(&self) -> u64 {
        self.output_counter.fetch_add(1, Ordering::Relaxed)
    }

    /// Returns `<output_dir>/<n>.<extension>` where `n` is the next file number.
    fn next_output_path(&self, extension: &str) -> PathBuf {
        let n = self.file_counter.fetch_add(1, Ordering::Relaxed);
        let file_name = format!("{n}.{extension}");
        self.output_dir.join(file_name)
    }

    /// Sends a log line attributed to `source`.
    ///
    /// Empty `data` is ignored and reported as success. The text is passed through
    /// `truncate_worker_log` before being sent.
    ///
    /// # Errors
    /// `BrokenPipe` when the slot channel is closed.
    pub fn send_log(&self, source: LogSource, data: &str) -> io::Result<()> {
        if data.is_empty() {
            return Ok(());
        }
        let data = truncate_worker_log(data.to_string());
        self.enqueue_response(SlotResponse::LogLine { source, data })
    }

    /// Writes `data` to a freshly numbered file in the output directory and announces it.
    ///
    /// # Errors
    /// Any error from writing the file, or the errors of [`Self::send_file_output`].
    pub fn write_file_output(
        &self,
        data: &[u8],
        extension: &str,
        mime_type: Option<String>,
    ) -> io::Result<()> {
        let target = self.next_output_path(extension);
        std::fs::write(&target, data)?;
        self.send_file_output(target, mime_type)
    }

    /// Announces an already-written file at `path` as a `FileOutput` of kind `FileType`.
    ///
    /// # Errors
    /// `InvalidData` when `path` is not valid UTF-8; `BrokenPipe` when the slot channel
    /// is closed.
    pub fn send_file_output(&self, path: PathBuf, mime_type: Option<String>) -> io::Result<()> {
        let Some(utf8_path) = path.to_str() else {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "non-UTF-8 path"));
        };
        self.enqueue_response(SlotResponse::FileOutput {
            filename: utf8_path.to_string(),
            kind: FileOutputKind::FileType,
            mime_type,
        })
    }

    /// Sends a metric update.
    ///
    /// # Errors
    /// `BrokenPipe` when the slot channel is closed.
    pub fn send_metric(
        &self,
        name: String,
        value: serde_json::Value,
        mode: MetricMode,
    ) -> io::Result<()> {
        self.enqueue_response(SlotResponse::Metric { name, value, mode })
    }

    /// Sends one output value, tagged with the next output index.
    ///
    /// The message is built by `build_output_message`, so outputs larger than the
    /// inline limit are written to a file instead of being sent inline. The index is
    /// consumed even if building the message fails.
    ///
    /// # Errors
    /// Errors from building the message, or `BrokenPipe` when the slot channel is closed.
    pub fn send_output(&self, output: serde_json::Value) -> io::Result<()> {
        let index = self.next_output_index();
        let msg = build_output_message(&self.output_dir, output, index)?;
        self.enqueue_response(msg)
    }
}
/// Wraps `output` in the `SlotResponse` used to deliver it.
///
/// The value is serialized to measure its size. If it fits within
/// `MAX_INLINE_IPC_SIZE` it is returned inline as `OutputChunk { output, index }`.
/// Otherwise the serialized bytes are written to `spill_<uuid>.json` inside
/// `output_dir` and a `FileOutput` of kind `Oversized` pointing at that file is returned.
///
/// # Errors
/// `InvalidData` when serialization fails or the spill path is not valid UTF-8, plus
/// any I/O error from writing the spill file.
fn build_output_message(
    output_dir: &std::path::Path,
    output: serde_json::Value,
    index: u64,
) -> io::Result<SlotResponse> {
    let serialized = match serde_json::to_vec(&output) {
        Ok(bytes) => bytes,
        Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e)),
    };

    // Common case: small enough to travel inline.
    if serialized.len() <= MAX_INLINE_IPC_SIZE {
        return Ok(SlotResponse::OutputChunk { output, index });
    }

    let spill_name = format!("spill_{}.json", uuid::Uuid::new_v4());
    let path = output_dir.join(spill_name);
    std::fs::write(&path, &serialized)?;

    let Some(filename) = path.to_str().map(str::to_owned) else {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "non-UTF-8 path"));
    };
    Ok(SlotResponse::FileOutput {
        filename,
        kind: FileOutputKind::Oversized,
        mime_type: None,
    })
}
/// Errors a `PredictHandler` can return from `setup`.
///
/// Each variant carries a free-form message that is embedded in the `Display` output.
#[derive(Debug, thiserror::Error)]
pub enum SetupError {
/// Loading the predictor failed.
#[error("failed to load predictor: {message}")]
Load { message: String },
/// The setup step itself failed.
#[error("setup failed: {message}")]
Setup { message: String },
/// An internal error occurred.
#[error("internal error: {message}")]
Internal { message: String },
}
impl SetupError {
    /// Creates a [`SetupError::Load`] from anything convertible into a `String`.
    pub fn load(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::Load { message }
    }

    /// Creates a [`SetupError::Setup`] from anything convertible into a `String`.
    pub fn setup(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::Setup { message }
    }

    /// Creates a [`SetupError::Internal`] from anything convertible into a `String`.
    pub fn internal(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::Internal { message }
    }
}
/// Behaviour the worker needs from a predictor implementation.
///
/// `run_worker` calls `setup` once before announcing readiness, then calls `predict`
/// for each accepted slot request, `cancel` on a cancel control request and
/// `healthcheck` on a healthcheck control request.
#[async_trait::async_trait]
pub trait PredictHandler: Send + Sync + 'static {
/// Performs one-time initialisation; an error prevents the worker from becoming ready.
async fn setup(&self) -> Result<(), SetupError>;
/// Runs one prediction.
///
/// * `slot` - slot the request arrived on.
/// * `id` - prediction identifier.
/// * `input` - prediction input as JSON.
/// * `slot_sender` - handle for emitting logs, metrics and outputs while running.
/// * `context` - string key/value pairs supplied with the request.
async fn predict(
&self,
slot: SlotId,
id: String,
input: serde_json::Value,
slot_sender: Arc<SlotSender>,
context: std::collections::HashMap<String, String>,
) -> PredictResult;
/// Called when the parent requests cancellation for `slot`.
fn cancel(&self, slot: SlotId);
/// Reports handler health; the default implementation always reports healthy.
async fn healthcheck(&self) -> HealthcheckResult {
HealthcheckResult::healthy()
}
}
/// Relative path of the OpenAPI schema file read by `load_bundled_schema`.
const BUNDLED_SCHEMA_PATH: &str = ".cog/openapi_schema.json";

/// Reads and parses the bundled OpenAPI schema.
///
/// Returns `Some(schema)` on success. If the file cannot be read or does not contain
/// valid JSON, a warning is logged and `None` is returned.
fn load_bundled_schema() -> Option<serde_json::Value> {
    let Ok(contents) = std::fs::read_to_string(BUNDLED_SCHEMA_PATH) else {
        tracing::warn!(
            "No schema file at {}. Running without schema — all input types accepted. \
             Rebuild with a recent version of cog to generate the schema.",
            BUNDLED_SCHEMA_PATH,
        );
        return None;
    };

    match serde_json::from_str::<serde_json::Value>(&contents) {
        Ok(schema) => {
            tracing::info!("Loaded OpenAPI schema from {}", BUNDLED_SCHEMA_PATH);
            Some(schema)
        }
        Err(e) => {
            tracing::warn!(
                "Failed to parse {}: {}. Running without schema — all input types accepted.",
                BUNDLED_SCHEMA_PATH,
                e,
            );
            None
        }
    }
}
/// How a prediction ended.
///
/// NOTE(review): `predict_time` is presumably the prediction duration in seconds —
/// the unit is not established in this file; confirm against the handler.
#[derive(Debug, Clone, PartialEq)]
pub enum PredictionOutcome {
/// The prediction finished and produced `output`.
Success {
/// Final output value; `run_prediction` skips sending it when it is null or an empty array.
output: serde_json::Value,
predict_time: f64,
/// Forwarded unchanged in the `Done` response.
is_stream: bool,
},
/// The prediction failed with `error`.
Failed { error: String, predict_time: f64 },
/// The prediction was cancelled.
Cancelled { predict_time: f64 },
}
/// Value returned by `PredictHandler::predict`; a thin wrapper around the outcome.
#[derive(Debug)]
pub struct PredictResult {
/// How the prediction ended.
pub outcome: PredictionOutcome,
}
impl PredictResult {
    /// Builds a result whose outcome is [`PredictionOutcome::Success`].
    pub fn success(output: serde_json::Value, predict_time: f64, is_stream: bool) -> Self {
        let outcome = PredictionOutcome::Success {
            output,
            predict_time,
            is_stream,
        };
        Self { outcome }
    }

    /// Builds a result whose outcome is [`PredictionOutcome::Failed`].
    pub fn failed(error: String, predict_time: f64) -> Self {
        let outcome = PredictionOutcome::Failed {
            error,
            predict_time,
        };
        Self { outcome }
    }

    /// Builds a result whose outcome is [`PredictionOutcome::Cancelled`].
    pub fn cancelled(predict_time: f64) -> Self {
        let outcome = PredictionOutcome::Cancelled { predict_time };
        Self { outcome }
    }
}
/// One-shot hook that `run_worker` invokes with a `ControlResponse` sender before
/// running setup. The closure it returns is called once after the setup handler returns.
pub type SetupLogHook =
Box<dyn FnOnce(mpsc::Sender<ControlResponse>) -> Box<dyn FnOnce() + Send> + Send>;
/// Configuration consumed by `run_worker`.
pub struct WorkerConfig {
/// Number of slot ids the worker creates and announces.
pub num_slots: usize,
/// Optional hook run around setup; see `SetupLogHook`.
pub setup_log_hook: Option<SetupLogHook>,
}
/// Default configuration: a single slot and no setup log hook.
impl Default for WorkerConfig {
    fn default() -> Self {
        let setup_log_hook = None;
        let num_slots = 1;
        Self {
            num_slots,
            setup_log_hook,
        }
    }
}
/// Message sent to the `run_worker` event loop when a slot has finished its work.
struct SlotCompletion {
/// Final state of the slot; converted into a control response for the parent.
outcome: SlotOutcome,
}
impl SlotCompletion {
fn idle(slot: SlotId) -> Self {
Self {
outcome: SlotOutcome::idle(slot),
}
}
fn poisoned(slot: SlotId, error: impl Into<String>) -> Self {
Self {
outcome: SlotOutcome::poisoned(slot, error),
}
}
}
/// Runs the worker: performs setup, announces readiness, then serves control and slot
/// requests until shutdown.
///
/// Sequence:
/// 1. Create the bounded setup-log channel, initialise tracing, and obtain the control
///    channel fds from the fd-redirect helper.
/// 2. Connect the slot transport, install the fatal context and panic hook, run the
///    optional setup log hook, and start two background tasks: one forwarding queued
///    `ControlResponse` messages to the control writer, one periodically reporting
///    dropped logs.
/// 3. Run `handler.setup()`. On failure a `ControlResponse::Failed` is written and the
///    function returns `Ok(())`.
/// 4. Load the bundled schema and send `ControlResponse::Ready` with the slot ids.
/// 5. Split every slot socket into a reader task and a shared writer, and send the
///    protocol version on each writer.
/// 6. Enter the event loop, which handles control requests, slot completions and new
///    slot requests until shutdown, a control channel error/close, or all slots being
///    poisoned.
///
/// # Errors
/// Returns an error if fd redirection fails, the transport cannot be connected, or the
/// `Ready` message cannot be written.
pub async fn run_worker<H: PredictHandler>(
    handler: Arc<H>,
    config: WorkerConfig,
    transport_info: ChildTransportInfo,
) -> io::Result<()> {
    let num_slots = config.num_slots;
    let (setup_log_tx, mut setup_log_rx) = mpsc::channel::<ControlResponse>(5000);
    init_worker_tracing(setup_log_tx.clone());
    let control_fds =
        crate::fd_redirect::redirect_fds_for_subprocess_isolation(setup_log_tx.clone())?;
    tracing::trace!(?transport_info, "Connecting to slot transport");
    let mut transport = connect_transport(transport_info).await?;
    tracing::info!(num_slots, "Connected to slot transport");

    // The control channel runs over the fds handed back by the redirect helper.
    let ctrl_stdin = tokio::fs::File::from_std(control_fds.stdin_fd.into());
    let ctrl_stdout = tokio::fs::File::from_std(control_fds.stdout_fd.into());
    let mut ctrl_reader = FramedRead::new(ctrl_stdin, JsonCodec::<ControlRequest>::new());
    let ctrl_writer = Arc::new(tokio::sync::Mutex::new(FramedWrite::new(
        ctrl_stdout,
        JsonCodec::<ControlResponse>::new(),
    )));
    let slot_ids: Vec<SlotId> = (0..num_slots).map(|_| SlotId::new()).collect();
    init_fatal_context(setup_log_tx.clone());
    install_panic_hook();
    let setup_cleanup = config.setup_log_hook.map(|hook| hook(setup_log_tx.clone()));

    // Background task: drain the setup-log channel into the control writer.
    let ctrl_writer_for_logs = Arc::clone(&ctrl_writer);
    let _log_forwarder = tokio::spawn(async move {
        let mut log_count = 0;
        let mut total_bytes = 0;
        while let Some(msg) = setup_log_rx.recv().await {
            if let ControlResponse::Log { ref data, .. } = msg {
                let msg_size = data.len();
                log_count += 1;
                total_bytes += msg_size;
                tracing::trace!(
                    log_number = log_count,
                    msg_size_bytes = msg_size,
                    total_bytes,
                    "Forwarding log"
                );
            }
            let mut w = ctrl_writer_for_logs.lock().await;
            if let Err(e) = w.send(msg).await {
                tracing::warn!(
                    error = %e,
                    log_count,
                    total_bytes,
                    "Failed to forward log"
                );
                break;
            }
        }
        tracing::debug!(
            total_logs = log_count,
            total_bytes,
            total_kb = total_bytes / 1024,
            "Log forwarder exiting"
        );
    });

    // Background task: report the dropped-log tally every five seconds.
    let dropped_log_tx = setup_log_tx.clone();
    let _dropped_log_reporter = tokio::spawn(async move {
        let mut interval = tokio::time::interval(std::time::Duration::from_millis(5000));
        loop {
            interval.tick().await;
            report_dropped_logs(&dropped_log_tx, 5000);
        }
    });

    tracing::info!("Worker starting setup");
    let setup_start = std::time::Instant::now();
    let setup_result = handler.setup().await;
    let setup_elapsed = setup_start.elapsed();
    tracing::debug!(
        elapsed_ms = setup_elapsed.as_millis() as u64,
        success = setup_result.is_ok(),
        "Setup handler returned"
    );
    if let Some(cleanup) = setup_cleanup {
        tracing::debug!("Running cleanup (unregistering Python setup sender)");
        cleanup();
    }
    if let Err(e) = setup_result {
        tracing::error!(
            error = %e,
            elapsed_ms = setup_elapsed.as_millis() as u64,
            "Setup failed"
        );
        // The failure is attributed to the first slot (or a fresh id when there are none).
        let slot = slot_ids.first().copied().unwrap_or_else(SlotId::new);
        let mut w = ctrl_writer.lock().await;
        let _ = w
            .send(ControlResponse::Failed {
                slot,
                error: format!("Setup failed: {}", e),
            })
            .await;
        return Ok(());
    }

    let schema = load_bundled_schema();
    if let Some(ref s) = schema {
        let schema_json = serde_json::to_string(s).unwrap_or_else(|_| "{}".to_string());
        let schema_size = schema_json.len();
        tracing::info!(
            schema_size_bytes = schema_size,
            schema_size_kb = schema_size / 1024,
            "Schema loaded"
        );
        if schema_size > 1024 * 1024 {
            // Clamp the preview to a char boundary: slicing a `str` in the middle of a
            // multi-byte character panics, and serde_json writes non-ASCII text as raw
            // UTF-8. Index 0 is always a boundary, so the loop terminates.
            let mut preview_end = schema_json.len().min(500);
            while !schema_json.is_char_boundary(preview_end) {
                preview_end -= 1;
            }
            tracing::warn!(
                schema_preview = &schema_json[..preview_end],
                "Large schema detected"
            );
        }
    }

    tracing::trace!(num_slots, ?slot_ids, "Sending Ready to parent");
    {
        let mut w = ctrl_writer.lock().await;
        w.send(ControlResponse::Ready {
            slots: slot_ids.clone(),
            schema,
        })
        .await?;
    }

    let (completion_tx, mut completion_rx) = mpsc::channel::<SlotCompletion>(num_slots);
    let mut slot_busy: HashMap<SlotId, bool> = slot_ids.iter().map(|id| (*id, false)).collect();
    let mut slot_poisoned: HashMap<SlotId, bool> = slot_ids.iter().map(|id| (*id, false)).collect();

    // Pair each slot id with a socket and split it into framed read/write halves.
    let sockets = transport.drain_sockets();
    let mut slot_readers: HashMap<
        SlotId,
        FramedRead<tokio::net::unix::OwnedReadHalf, JsonCodec<SlotRequest>>,
    > = HashMap::new();
    let mut slot_writers: HashMap<
        SlotId,
        FramedWrite<tokio::net::unix::OwnedWriteHalf, JsonCodec<SlotResponse>>,
    > = HashMap::new();
    for (slot_id, socket) in slot_ids.iter().zip(sockets) {
        let (read_half, write_half) = socket.into_split();
        slot_readers.insert(*slot_id, FramedRead::new(read_half, JsonCodec::new()));
        slot_writers.insert(*slot_id, FramedWrite::new(write_half, JsonCodec::new()));
    }

    // One reader task per slot funnels requests into a single channel.
    let (request_tx, mut request_rx) = mpsc::channel::<(SlotId, SlotRequest)>(num_slots);
    for (slot_id, reader) in slot_readers {
        let tx = request_tx.clone();
        tokio::spawn(async move {
            slot_reader_task(slot_id, reader, tx).await;
        });
    }
    // Only the reader tasks keep the request channel open from here on.
    drop(request_tx);

    let slot_writers: HashMap<SlotId, SlotWriter> = slot_writers
        .into_iter()
        .map(|(id, w)| (id, Arc::new(tokio::sync::Mutex::new(w))))
        .collect();
    for (slot_id, writer) in &slot_writers {
        let mut w = writer.lock().await;
        if let Err(e) = w
            .send(SlotResponse::ProtocolVersion {
                version: SLOT_RESPONSE_PROTOCOL_VERSION,
            })
            .await
        {
            tracing::warn!(%slot_id, error = %e, "Failed to send protocol version");
        }
    }

    loop {
        tokio::select! {
            // Poll in declaration order: control requests, then completions, then new requests.
            biased;
            ctrl_msg = ctrl_reader.next() => {
                match ctrl_msg {
                    Some(Ok(ControlRequest::Init { .. })) => {
                        tracing::warn!("Received Init in event loop (should be at startup)");
                    }
                    Some(Ok(ControlRequest::Cancel { slot })) => {
                        tracing::trace!(%slot, "Cancel requested");
                        handler.cancel(slot);
                    }
                    Some(Ok(ControlRequest::Shutdown)) => {
                        tracing::info!("Shutdown requested");
                        let mut w = ctrl_writer.lock().await;
                        let _ = w.send(ControlResponse::ShuttingDown).await;
                        break;
                    }
                    Some(Ok(ControlRequest::Healthcheck { id })) => {
                        tracing::trace!(%id, "Healthcheck requested, invoking handler");
                        let result = handler.healthcheck().await;
                        tracing::trace!(
                            %id,
                            status = ?result.status,
                            error = ?result.error,
                            "Healthcheck handler returned"
                        );
                        let mut w = ctrl_writer.lock().await;
                        let _ = w.send(ControlResponse::HealthcheckResult {
                            id,
                            status: result.status,
                            error: result.error,
                        }).await;
                    }
                    Some(Err(e)) => {
                        tracing::error!(error = %e, "Control channel error");
                        break;
                    }
                    None => {
                        tracing::error!("Control channel closed (parent died?), exiting");
                        break;
                    }
                }
            }
            Some(completion) = completion_rx.recv() => {
                let slot = completion.outcome.slot_id();
                slot_busy.insert(slot, false);
                if completion.outcome.is_poisoned() {
                    slot_poisoned.insert(slot, true);
                }
                {
                    let mut w = ctrl_writer.lock().await;
                    let _ = w.send(completion.outcome.into_control_response()).await;
                }
                if slot_poisoned.values().all(|&p| p) {
                    tracing::error!("All slots poisoned, exiting");
                    break;
                }
            }
            Some((slot_id, request)) = request_rx.recv() => {
                if slot_busy.get(&slot_id).copied().unwrap_or(false) {
                    tracing::warn!(%slot_id, "Request received for busy slot, ignoring");
                    continue;
                }
                if slot_poisoned.get(&slot_id).copied().unwrap_or(false) {
                    tracing::warn!(%slot_id, "Request received for poisoned slot, ignoring");
                    continue;
                }
                let prediction_id = request.prediction_id().to_string();
                match request.rehydrate_input() {
                    Ok((id, input, output_dir, context)) => {
                        tracing::trace!(%slot_id, %id, "Prediction request received");
                        // Resolve the writer before marking the slot busy: bailing out
                        // after the flag is set would leave the slot busy forever, since
                        // no prediction task would exist to send a completion.
                        let Some(writer) = slot_writers.get(&slot_id).map(Arc::clone) else {
                            tracing::error!(%slot_id, "No writer for slot");
                            continue;
                        };
                        slot_busy.insert(slot_id, true);
                        let handler = Arc::clone(&handler);
                        let completion_tx = completion_tx.clone();
                        tokio::spawn(async move {
                            let completion = run_prediction(
                                slot_id,
                                id,
                                input,
                                PathBuf::from(output_dir),
                                handler,
                                writer,
                                context,
                            ).await;
                            let _ = completion_tx.send(completion).await;
                        });
                    }
                    Err(e) => {
                        tracing::error!(%slot_id, %prediction_id, error = %e, "Failed to rehydrate input");
                        if let Some(writer) = slot_writers.get(&slot_id) {
                            let mut w = writer.lock().await;
                            let fail_msg = SlotResponse::Failed {
                                id: prediction_id,
                                error: format!("Failed to rehydrate input: {}", e),
                            };
                            if let Err(send_err) = w.send(fail_msg).await {
                                tracing::error!(%slot_id, error = %send_err, "Failed to send rehydrate error response");
                            }
                        }
                        // Report the slot as idle again through the normal completion path.
                        let _ = completion_tx.send(SlotCompletion::idle(slot_id)).await;
                    }
                }
            }
        }
    }
    tracing::info!("Worker exiting");
    Ok(())
}
/// Pumps decoded requests from one slot socket into the shared request channel.
///
/// The task ends when the socket is closed, a frame fails to decode, or the receiving
/// side of `tx` has gone away.
async fn slot_reader_task(
    slot_id: SlotId,
    mut reader: FramedRead<tokio::net::unix::OwnedReadHalf, JsonCodec<SlotRequest>>,
    tx: mpsc::Sender<(SlotId, SlotRequest)>,
) {
    loop {
        let Some(frame) = reader.next().await else {
            tracing::trace!(%slot_id, "Slot socket closed");
            return;
        };
        match frame {
            Ok(request) => {
                // A send error means the event loop dropped its receiver; stop reading.
                if tx.send((slot_id, request)).await.is_err() {
                    return;
                }
            }
            Err(e) => {
                tracing::error!(%slot_id, error = %e, "Slot reader error");
                return;
            }
        }
    }
}
/// Executes a single prediction on `slot_id` and writes its result to the slot socket.
///
/// Steps:
/// 1. Create a `SlotSender` backed by an unbounded channel and spawn a task that
///    forwards everything the handler emits to `writer`.
/// 2. Run `handler.predict(..)` to completion via `block_in_place` + `block_on`.
/// 3. Wait for the forwarder to finish, so streamed messages precede the final response.
/// 4. On success with a non-null, non-empty-array output, send that output (index 0,
///    possibly spilled to a file), then send the terminal `Done` / `Cancelled` /
///    `Failed` response.
///
/// Returns an idle completion when everything was written, or a poisoned completion
/// when building the output or writing to the socket failed.
async fn run_prediction<H: PredictHandler>(
    slot_id: SlotId,
    prediction_id: String,
    input: serde_json::Value,
    output_dir: PathBuf,
    handler: Arc<H>,
    writer: SlotWriter,
    context: std::collections::HashMap<String, String>,
) -> SlotCompletion {
    tracing::trace!(%slot_id, %prediction_id, "run_prediction starting");

    let (stream_tx, mut stream_rx) = mpsc::unbounded_channel::<SlotResponse>();
    let slot_sender = Arc::new(SlotSender::new(stream_tx, output_dir.clone()));

    let forwarder_writer = Arc::clone(&writer);
    let log_forwarder = tokio::spawn(async move {
        while let Some(msg) = stream_rx.recv().await {
            let mut sink = forwarder_writer.lock().await;
            if let Err(e) = sink.send(msg).await {
                tracing::warn!(error = %e, "Failed to forward log");
                break;
            }
        }
        tracing::trace!("Prediction log forwarder exiting");
    });

    let result = tokio::task::block_in_place(|| {
        let runtime = Handle::current();
        runtime.block_on(handler.predict(
            slot_id,
            prediction_id.clone(),
            input,
            slot_sender,
            context,
        ))
    });
    tracing::trace!(%slot_id, %prediction_id, "handler.predict returned");

    // The forwarder stops once every sender handle is gone (or a write fails).
    tracing::trace!(%slot_id, %prediction_id, "Waiting for log forwarder");
    let _ = log_forwarder.await;
    tracing::trace!(%slot_id, %prediction_id, "Log forwarder done");

    let mut sink = writer.lock().await;
    let terminal = match result.outcome {
        PredictionOutcome::Failed { error, .. } => SlotResponse::Failed {
            id: prediction_id,
            error,
        },
        PredictionOutcome::Cancelled { .. } => SlotResponse::Cancelled { id: prediction_id },
        PredictionOutcome::Success {
            output,
            predict_time,
            is_stream,
        } => {
            // Null and empty-array outputs carry nothing worth sending.
            let has_output = match &output {
                serde_json::Value::Null => false,
                serde_json::Value::Array(items) => !items.is_empty(),
                _ => true,
            };
            if has_output {
                let output_msg = match build_output_message(&output_dir, output, 0) {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e, "Failed to build output message");
                        let reason = format!("Output spill error: {}", e);
                        return SlotCompletion::poisoned(slot_id, reason);
                    }
                };
                if let Err(e) = sink.send(output_msg).await {
                    tracing::error!(error = %e, "Failed to send prediction output");
                    let reason = format!("Socket write error: {}", e);
                    return SlotCompletion::poisoned(slot_id, reason);
                }
            }
            SlotResponse::Done {
                id: prediction_id,
                output: None,
                predict_time,
                is_stream,
            }
        }
    };

    match sink.send(terminal).await {
        Ok(()) => SlotCompletion::idle(slot_id),
        Err(e) => {
            tracing::error!(error = %e, "Failed to send prediction response");
            SlotCompletion::poisoned(slot_id, format!("Socket write error: {}", e))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `PredictResult::success` produces a `Success` outcome.
    #[test]
    fn predict_result_success() {
        let result = PredictResult::success(serde_json::json!("hello"), 0.5, false);
        let is_success = matches!(result.outcome, PredictionOutcome::Success { .. });
        assert!(is_success);
    }

    /// The `is_stream` flag is carried through to the outcome.
    #[test]
    fn predict_result_success_stream() {
        let result = PredictResult::success(serde_json::json!([]), 0.5, true);
        let is_streaming_success = matches!(
            result.outcome,
            PredictionOutcome::Success {
                is_stream: true,
                ..
            }
        );
        assert!(is_streaming_success);
    }

    /// `PredictResult::failed` keeps the error text.
    #[test]
    fn predict_result_failed() {
        let result = PredictResult::failed("oops".into(), 0.5);
        match result.outcome {
            PredictionOutcome::Failed { error, .. } => assert_eq!(error, "oops"),
            other => panic!("expected Failed outcome, got {:?}", other),
        }
    }

    /// `PredictResult::cancelled` produces a `Cancelled` outcome.
    #[test]
    fn predict_result_cancelled() {
        let result = PredictResult::cancelled(0.5);
        let is_cancelled = matches!(result.outcome, PredictionOutcome::Cancelled { .. });
        assert!(is_cancelled);
    }

    /// The default configuration uses exactly one slot.
    #[test]
    fn worker_config_default() {
        let WorkerConfig { num_slots, .. } = WorkerConfig::default();
        assert_eq!(num_slots, 1);
    }
}