use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::transport::ChildTransportInfo;
/// Identifier of a worker slot, wrapping a `uuid::Uuid`.
///
/// `#[serde(transparent)]` makes a `SlotId` serialize and deserialize exactly like the wrapped
/// `uuid::Uuid`, with no surrounding struct.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct SlotId(uuid::Uuid);

impl SlotId {
    /// Creates a new slot ID from a random (v4) UUID.
    pub fn new() -> Self {
        let fresh = uuid::Uuid::new_v4();
        SlotId(fresh)
    }

    /// Borrows the underlying UUID.
    pub fn as_uuid(&self) -> &uuid::Uuid {
        let SlotId(inner) = self;
        inner
    }

    /// Parses a slot ID from a UUID string.
    ///
    /// # Errors
    ///
    /// Returns the `uuid::Error` from `uuid::Uuid::parse_str` when `s` is not a valid UUID.
    pub fn parse(s: &str) -> Result<Self, uuid::Error> {
        uuid::Uuid::parse_str(s).map(SlotId)
    }
}
impl Default for SlotId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for SlotId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// Size threshold (6 MiB) for payloads sent inline over IPC.
///
/// NOTE(review): presumably inputs larger than this are written to a spill file
/// (`SlotRequest::Predict::input_file`) instead of `input` — confirm at the call site.
pub const MAX_INLINE_IPC_SIZE: usize = 1024 * 1024 * 6;

/// Maximum size in bytes of a worker log message after `truncate_worker_log`, notice included.
const MAX_WORKER_LOG_SIZE: usize = 1024 * 1024 * 4;

/// Marker appended to a worker log message that had to be truncated.
const WORKER_LOG_TRUNCATE_NOTICE: &str = "[**** LOG LINE TRUNCATED AT 4 MiB ****]";

/// Caps a worker log message at `MAX_WORKER_LOG_SIZE` bytes.
///
/// Messages within the limit are returned unchanged. A longer message is cut at the last UTF-8
/// character boundary that still leaves room for `WORKER_LOG_TRUNCATE_NOTICE`, and the notice
/// is appended. The result is therefore never longer than `MAX_WORKER_LOG_SIZE` bytes and is
/// always valid UTF-8. The truncation is done in place, reusing the input's allocation.
pub fn truncate_worker_log(mut log_message: String) -> String {
    if log_message.len() > MAX_WORKER_LOG_SIZE {
        let mut boundary = MAX_WORKER_LOG_SIZE - WORKER_LOG_TRUNCATE_NOTICE.len();
        // `String::truncate` panics when not on a char boundary, so step back to one. This
        // walks at most 3 bytes for UTF-8. Index 0 is always a boundary, so the loop ends. It
        // is done by hand because `str::floor_char_boundary` only became stable in Rust 1.91.
        while !log_message.is_char_boundary(boundary) {
            boundary -= 1;
        }
        log_message.truncate(boundary);
        log_message.push_str(WORKER_LOG_TRUNCATE_NOTICE);
    }
    log_message
}
/// A message on the control channel (per-slot work uses `SlotRequest` instead).
///
/// NOTE(review): presumably sent by the parent to the worker process, with replies in
/// `ControlResponse` — confirm against the transport code.
///
/// Internally tagged: serialized as a JSON object whose `"type"` field is the snake_case
/// variant name, e.g. `{"type": "cancel", "slot": "<uuid>"}`. Field order here determines the
/// serialized field order (and thus the insta snapshots).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ControlRequest {
    /// Initial configuration (presumably the first message sent — confirm).
    Init {
        /// Reference to the predictor to load, e.g. `"predict.py:Predictor"` in the tests.
        predictor_ref: String,
        /// Number of slots requested.
        num_slots: usize,
        /// How the per-slot transports are reached (see `ChildTransportInfo`).
        transport_info: ChildTransportInfo,
        /// Presumably selects training instead of prediction — confirm.
        is_train: bool,
        /// Presumably whether the predictor runs asynchronously — confirm.
        is_async: bool,
    },
    /// Request cancellation of the work in `slot`.
    Cancel {
        slot: SlotId,
    },
    /// Run a healthcheck. `id` presumably correlates with the `id` of the matching
    /// `ControlResponse::HealthcheckResult` — confirm.
    Healthcheck {
        id: String,
    },
    /// Request shutdown.
    Shutdown,
}
/// A reply or event on the control channel; the counterpart of `ControlRequest`.
///
/// Internally tagged: the `"type"` field holds the snake_case variant name. Fields marked
/// `skip_serializing_if = "Option::is_none"` are omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ControlResponse {
    /// Readiness report listing the available slot IDs.
    Ready {
        slots: Vec<SlotId>,
        /// Optional schema document (the tests use an OpenAPI document); omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        schema: Option<serde_json::Value>,
    },
    /// Captured output `data` from the given stream (`stdout`/`stderr`).
    Log {
        source: LogSource,
        data: String,
    },
    /// A log record with its `target` and `level`.
    /// NOTE(review): presumably `message` is capped with `truncate_worker_log` before it is
    /// sent — confirm at the call site.
    WorkerLog {
        target: String,
        level: String,
        message: String,
    },
    /// Reports `slot` as idle; this is what `SlotOutcome::Idle` converts into.
    Idle {
        slot: SlotId,
    },
    /// Reports that the work in `slot` was cancelled.
    Cancelled {
        slot: SlotId,
    },
    /// Reports that `slot` failed with `error`; this is what `SlotOutcome::Poisoned` converts into.
    Failed {
        slot: SlotId,
        error: String,
    },
    /// Unrecoverable error, with a human-readable `reason`.
    Fatal {
        reason: String,
    },
    /// Presumably reports that `count` log messages were dropped within the last
    /// `interval_millis` milliseconds — confirm against the producer.
    DroppedLogs {
        count: usize,
        interval_millis: u64,
    },
    /// Result of a healthcheck identified by `id`; `error` is omitted from the JSON when `None`.
    HealthcheckResult {
        id: String,
        status: HealthcheckStatus,
        #[serde(skip_serializing_if = "Option::is_none")]
        error: Option<String>,
    },
    /// Announces shutdown (presumably in response to `ControlRequest::Shutdown` — confirm).
    ShuttingDown,
}
/// Outcome reported in `ControlResponse::HealthcheckResult`.
///
/// Serialized as the bare strings `"healthy"` / `"unhealthy"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HealthcheckStatus {
    Healthy,
    Unhealthy,
}
#[derive(Debug)]
pub enum SlotOutcome {
Idle(SlotId),
Poisoned { slot: SlotId, error: String },
}
impl SlotOutcome {
pub fn idle(slot: SlotId) -> Self {
Self::Idle(slot)
}
pub fn poisoned(slot: SlotId, error: impl Into<String>) -> Self {
Self::Poisoned {
slot,
error: error.into(),
}
}
pub fn slot_id(&self) -> SlotId {
match self {
Self::Idle(slot) => *slot,
Self::Poisoned { slot, .. } => *slot,
}
}
pub fn is_poisoned(&self) -> bool {
matches!(self, Self::Poisoned { .. })
}
pub fn into_control_response(self) -> ControlResponse {
match self {
Self::Idle(slot) => ControlResponse::Idle { slot },
Self::Poisoned { slot, error } => ControlResponse::Failed { slot, error },
}
}
}
/// A request on a slot's channel.
///
/// Internally tagged by `"type"`; the only variant is `"predict"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SlotRequest {
    /// Run a single prediction.
    Predict {
        /// Prediction ID (e.g. `"pred_123"` in the tests).
        id: String,
        /// Inline JSON input; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        input: Option<serde_json::Value>,
        /// Path of a JSON file holding the input when it is not sent inline; omitted from the
        /// JSON when `None`. `rehydrate_input` reads and then deletes this file.
        /// NOTE(review): presumably used for inputs over `MAX_INLINE_IPC_SIZE` — confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        input_file: Option<String>,
        /// Output directory path.
        output_dir: String,
        /// Extra string key/value pairs; deserializes to an empty map when the field is absent.
        #[serde(default)]
        context: HashMap<String, String>,
    },
}
impl SlotRequest {
pub fn prediction_id(&self) -> &str {
match self {
SlotRequest::Predict { id, .. } => id,
}
}
pub fn rehydrate_input(
self,
) -> std::io::Result<(String, serde_json::Value, String, HashMap<String, String>)> {
match self {
SlotRequest::Predict {
id,
input: Some(value),
output_dir,
context,
..
} => Ok((id, value, output_dir, context)),
SlotRequest::Predict {
id,
input: None,
input_file: Some(path),
output_dir,
context,
} => {
let bytes = std::fs::read(&path)?;
if let Err(e) = std::fs::remove_file(&path) {
tracing::warn!(path = %path, error = %e, "Failed to remove input spill file");
}
let value: serde_json::Value = serde_json::from_slice(&bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok((id, value, output_dir, context))
}
SlotRequest::Predict { .. } => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"SlotRequest::Predict has neither input nor input_file",
)),
}
}
}
/// Kind of file reported in `SlotResponse::FileOutput`.
///
/// NOTE: unlike `MetricMode` and `LogSource`, this unit-only enum is *internally tagged*. It
/// therefore serializes as an object such as `{"type": "file_type"}` rather than a bare
/// string. Changing the attributes would change the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum FileOutputKind {
    /// Presumably a file produced as a regular file-typed output — confirm with the producer.
    FileType,
    /// Presumably an output written to a file because it was too large to send inline —
    /// confirm with the producer.
    Oversized,
}
/// Update mode carried by `SlotResponse::Metric`; serialized as `"replace"`, `"increment"`,
/// or `"append"`.
///
/// NOTE(review): presumably describes how the new value combines with an existing metric of
/// the same name. The tests also send `Replace` with a `null` value for a "delete" case —
/// confirm these semantics with the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MetricMode {
    Replace,
    Increment,
    Append,
}
/// Current version of the `SlotResponse` protocol, carried in `SlotResponse::ProtocolVersion`.
pub const SLOT_RESPONSE_PROTOCOL_VERSION: u32 = 1;
/// A message on a slot's channel, replying to a `SlotRequest`.
///
/// Internally tagged: the `"type"` field holds the snake_case variant name, e.g.
/// `{"type": "log_line", "source": "stdout", "data": "..."}`. Field order here determines the
/// serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SlotResponse {
    /// Announces the protocol version (the tests send `SLOT_RESPONSE_PROTOCOL_VERSION`).
    ProtocolVersion {
        version: u32,
    },
    /// A line of captured output from `source`.
    LogLine {
        source: LogSource,
        data: String,
    },
    /// An output file named `filename`; `mime_type` is omitted from the JSON when `None`.
    FileOutput {
        filename: String,
        kind: FileOutputKind,
        #[serde(skip_serializing_if = "Option::is_none")]
        mime_type: Option<String>,
    },
    /// One piece of a streamed output; `index` presumably orders the chunks — confirm.
    OutputChunk {
        output: serde_json::Value,
        index: u64,
    },
    /// A named metric value to apply according to `mode`.
    Metric {
        name: String,
        value: serde_json::Value,
        mode: MetricMode,
    },
    /// Prediction `id` completed.
    Done {
        id: String,
        /// Final output; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        output: Option<serde_json::Value>,
        /// Presumably wall-clock seconds spent predicting — confirm units with the producer.
        predict_time: f64,
        /// Omitted from the JSON when `false`, and defaults to `false` when absent.
        #[serde(default, skip_serializing_if = "std::ops::Not::not")]
        is_stream: bool,
    },
    /// Prediction `id` failed with `error`.
    Failed {
        id: String,
        error: String,
    },
    /// Prediction `id` was cancelled.
    Cancelled {
        id: String,
    },
}
/// Standard stream a captured log line came from; serialized as `"stdout"` / `"stderr"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LogSource {
    Stdout,
    Stderr,
}
#[cfg(test)]
mod tests {
    // Wire-format tests. Most use insta JSON snapshots, whose snapshot files are named after
    // the test function, so test names must stay stable.
    use super::*;
    use serde_json::json;
    use std::path::PathBuf;

    /// Fixed slot ID so snapshots are deterministic.
    fn test_slot_id() -> SlotId {
        SlotId(uuid::Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap())
    }

    #[test]
    fn control_init_serializes() {
        let req = ControlRequest::Init {
            predictor_ref: "predict.py:Predictor".to_string(),
            num_slots: 2,
            transport_info: ChildTransportInfo::NamedSockets {
                dir: PathBuf::from("/tmp/coglet-123"),
                num_slots: 2,
            },
            is_train: false,
            is_async: true,
        };
        insta::assert_json_snapshot!(req);
    }

    #[test]
    fn control_cancel_serializes() {
        let req = ControlRequest::Cancel {
            slot: test_slot_id(),
        };
        insta::assert_json_snapshot!(req);
    }

    #[test]
    fn control_shutdown_serializes() {
        let req = ControlRequest::Shutdown;
        insta::assert_json_snapshot!(req);
    }

    #[test]
    fn control_healthcheck_serializes() {
        let req = ControlRequest::Healthcheck {
            id: "hc_123".to_string(),
        };
        insta::assert_json_snapshot!(req);
    }

    #[test]
    fn control_healthcheck_result_healthy_serializes() {
        let resp = ControlResponse::HealthcheckResult {
            id: "hc_123".to_string(),
            status: HealthcheckStatus::Healthy,
            error: None,
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn control_healthcheck_result_unhealthy_serializes() {
        let resp = ControlResponse::HealthcheckResult {
            id: "hc_123".to_string(),
            status: HealthcheckStatus::Unhealthy,
            error: Some("user healthcheck returned False".to_string()),
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn control_ready_serializes() {
        let resp = ControlResponse::Ready {
            slots: vec![test_slot_id()],
            schema: None,
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn control_ready_with_schema_serializes() {
        let resp = ControlResponse::Ready {
            slots: vec![test_slot_id()],
            schema: Some(json!({
                "openapi": "3.0.2",
                "info": {"title": "Cog", "version": "0.1.0"}
            })),
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn control_idle_serializes() {
        let resp = ControlResponse::Idle {
            slot: test_slot_id(),
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn control_cancelled_serializes() {
        let resp = ControlResponse::Cancelled {
            slot: test_slot_id(),
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn control_failed_serializes() {
        let resp = ControlResponse::Failed {
            slot: test_slot_id(),
            error: "segfault".to_string(),
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn slot_predict_serializes() {
        let req = SlotRequest::Predict {
            id: "pred_123".to_string(),
            input: Some(json!({"text": "hello"})),
            input_file: None,
            output_dir: "/tmp/coglet/predictions/pred_123/outputs".to_string(),
            context: Default::default(),
        };
        insta::assert_json_snapshot!(req);
    }

    #[test]
    fn slot_predict_file_input_serializes() {
        let req = SlotRequest::Predict {
            id: "pred_456".to_string(),
            input: None,
            input_file: Some("/tmp/coglet/predictions/pred_456/inputs/spill_abc.json".to_string()),
            output_dir: "/tmp/coglet/predictions/pred_456/outputs".to_string(),
            context: Default::default(),
        };
        insta::assert_json_snapshot!(req);
    }

    #[test]
    fn slot_log_line_serializes() {
        let resp = SlotResponse::LogLine {
            source: LogSource::Stdout,
            data: "Processing...".to_string(),
        };
        assert_eq!(
            serde_json::to_value(resp).unwrap(),
            json!({
                "type": "log_line",
                "source": "stdout",
                "data": "Processing..."
            })
        );
    }

    #[test]
    fn slot_output_chunk_serializes() {
        let resp = SlotResponse::OutputChunk {
            output: json!("chunk 1"),
            index: 7,
        };
        assert_eq!(
            serde_json::to_value(resp).unwrap(),
            json!({
                "type": "output_chunk",
                "output": "chunk 1",
                "index": 7
            })
        );
    }

    #[test]
    fn slot_protocol_version_serializes() {
        let resp = SlotResponse::ProtocolVersion {
            version: SLOT_RESPONSE_PROTOCOL_VERSION,
        };
        assert_eq!(
            serde_json::to_value(resp).unwrap(),
            json!({
                "type": "protocol_version",
                "version": 1
            })
        );
    }

    #[test]
    fn slot_done_serializes() {
        let resp = SlotResponse::Done {
            id: "pred_123".to_string(),
            output: Some(json!("final result")),
            predict_time: 1.234,
            is_stream: false,
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn slot_failed_serializes() {
        let resp = SlotResponse::Failed {
            id: "pred_123".to_string(),
            error: "ValueError: invalid input".to_string(),
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn slot_cancelled_serializes() {
        let resp = SlotResponse::Cancelled {
            id: "pred_123".to_string(),
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn slot_metric_replace_serializes() {
        let resp = SlotResponse::Metric {
            name: "temperature".to_string(),
            value: json!(0.7),
            mode: MetricMode::Replace,
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn slot_metric_increment_serializes() {
        let resp = SlotResponse::Metric {
            name: "token_count".to_string(),
            value: json!(1),
            mode: MetricMode::Increment,
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn slot_metric_append_serializes() {
        let resp = SlotResponse::Metric {
            name: "logprobs".to_string(),
            value: json!(-1.2),
            mode: MetricMode::Append,
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn slot_metric_delete_serializes() {
        let resp = SlotResponse::Metric {
            name: "unwanted".to_string(),
            value: json!(null),
            mode: MetricMode::Replace,
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn slot_metric_complex_value_serializes() {
        let resp = SlotResponse::Metric {
            name: "timing".to_string(),
            value: json!({"preprocess": 0.1, "inference": 0.8}),
            mode: MetricMode::Replace,
        };
        insta::assert_json_snapshot!(resp);
    }

    #[test]
    fn rehydrate_input_inline() {
        let req = SlotRequest::Predict {
            id: "p1".to_string(),
            input: Some(json!({"text": "hello"})),
            input_file: None,
            output_dir: "/tmp/out".to_string(),
            context: Default::default(),
        };
        let (id, input, output_dir, _context) = req.rehydrate_input().unwrap();
        assert_eq!(id, "p1");
        assert_eq!(input, json!({"text": "hello"}));
        assert_eq!(output_dir, "/tmp/out");
    }

    #[test]
    fn rehydrate_input_from_file() {
        let dir = tempfile::tempdir().unwrap();
        let spill_path = dir.path().join("spill_test.json");
        std::fs::write(&spill_path, r#"{"key":"value"}"#).unwrap();
        let req = SlotRequest::Predict {
            id: "p2".to_string(),
            input: None,
            input_file: Some(spill_path.to_str().unwrap().to_string()),
            output_dir: "/tmp/out".to_string(),
            context: Default::default(),
        };
        let (id, input, output_dir, _context) = req.rehydrate_input().unwrap();
        assert_eq!(id, "p2");
        assert_eq!(input, json!({"key": "value"}));
        assert_eq!(output_dir, "/tmp/out");
        // The spill file is consumed.
        assert!(!spill_path.exists());
    }

    #[test]
    fn rehydrate_input_neither_errors() {
        let req = SlotRequest::Predict {
            id: "p3".to_string(),
            input: None,
            input_file: None,
            output_dir: "/tmp/out".to_string(),
            context: Default::default(),
        };
        let err = req.rehydrate_input().unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn rehydrate_input_corrupt_file_errors() {
        let dir = tempfile::tempdir().unwrap();
        let spill_path = dir.path().join("corrupt.json");
        std::fs::write(&spill_path, "not valid json!!!").unwrap();
        let req = SlotRequest::Predict {
            id: "p4".to_string(),
            input: None,
            input_file: Some(spill_path.to_str().unwrap().to_string()),
            output_dir: "/tmp/out".to_string(),
            context: Default::default(),
        };
        let err = req.rehydrate_input().unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn truncate_worker_log_truncates_long_messages() {
        // One 4-byte crab past the limit forces truncation without allocating more than
        // ~4 MiB. The previous count (4 GiB worth of crabs) allocated a 4 GiB string and
        // overflowed on 32-bit targets. The budget left for content is not necessarily a
        // multiple of 4, so this also exercises the char-boundary handling.
        let emoji = "🦀";
        let count = MAX_WORKER_LOG_SIZE / emoji.len() + 1;
        let message: String = truncate_worker_log(emoji.repeat(count));
        assert!(
            message.ends_with(WORKER_LOG_TRUNCATE_NOTICE),
            "log message didn't end with {}",
            WORKER_LOG_TRUNCATE_NOTICE
        );
        assert!(
            message.len() <= MAX_WORKER_LOG_SIZE,
            "truncated log message is {} bytes, over the {} byte limit",
            message.len(),
            MAX_WORKER_LOG_SIZE
        );
    }

    #[test]
    fn truncate_worker_log_does_not_truncate_short_messages() {
        let emoji = "🦀";
        let count = 10;
        let message: String = truncate_worker_log(emoji.repeat(count));
        assert!(
            !message.ends_with(WORKER_LOG_TRUNCATE_NOTICE),
            "short log message was truncated"
        );
        assert_eq!(message, emoji.repeat(count));
    }
}