use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use somatize_compiler::ExecutionPlan;
use somatize_core::event::Event;
use somatize_core::store::{DataRef, DataStore};
use somatize_core::value::Value;
/// Identifier of a worker; carried as `worker_id` in every `WorkerToCoordinator` message.
pub type WorkerId = String;
/// Identifier of an execution plan (see `SerializedPlan::plan_id`).
pub type PlanId = String;
/// Resources and environments a worker advertises, sent in `WorkerToCoordinator::Register`.
///
/// Field names and declaration order define the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capabilities {
/// CPU core count reported by the worker.
pub cpu_cores: usize,
/// RAM size, in bytes.
pub ram_bytes: u64,
/// One entry per GPU reported by the worker (may be empty).
pub gpus: Vec<GpuInfo>,
/// Names of Python environments offered by the worker (e.g. `"py310"`).
pub python_envs: Vec<String>,
/// Free-form labels such as `"gpu"` or `"training"`.
/// NOTE(review): presumably used for scheduling/placement — confirm with the coordinator.
pub tags: Vec<String>,
}
/// Description of a single GPU, listed in `Capabilities::gpus`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuInfo {
/// Device name (e.g. `"A100"`).
pub name: String,
/// Device memory size, in bytes.
pub memory_bytes: u64,
}
/// Load snapshot reported by a worker in `WorkerToCoordinator::Heartbeat`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadMetrics {
/// CPU utilisation. NOTE(review): presumably a fraction in `0.0..=1.0` — confirm.
pub cpu_usage: f32,
/// Memory utilisation. NOTE(review): presumably a fraction like `cpu_usage` — confirm.
pub memory_usage: f32,
/// Per-GPU utilisation values.
/// NOTE(review): presumably ordered like `Capabilities::gpus` — confirm.
pub gpu_usage: Vec<f32>,
/// Number of active plans on the worker.
pub active_plans: usize,
/// Queue depth on the worker. NOTE(review): confirm exactly what is being counted.
pub queue_depth: usize,
/// Timestamp of this snapshot, in UTC.
pub timestamp: DateTime<Utc>,
}
/// Where the input value of a `SerializedPlan` comes from.
///
/// Internally tagged: the serialized object's `"source"` field holds the variant name
/// (`"Inline"` or `"Reference"`). Call `InputSource::resolve` to obtain the actual value.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "source")]
#[non_exhaustive]
pub enum InputSource {
/// The value is embedded directly in the message.
Inline { value: Value },
/// The value lives in a data store and is addressed by `data_ref`.
Reference { data_ref: DataRef },
}
impl InputSource {
    /// Produces the concrete [`Value`] described by this source.
    ///
    /// * `Inline` — returns a clone of the embedded value.
    /// * `Reference` — asks `data_store` first, when one is supplied. If it is absent or
    ///   its lookup fails, that error is discarded and `temp_store` is consulted instead.
    ///   If `temp_store` fails too, a warning is logged and `Value::Empty` is returned.
    pub fn resolve(
        &self,
        data_store: Option<&dyn somatize_core::store::DataStore>,
        temp_store: &somatize_core::store::LocalDataStore,
    ) -> Value {
        let data_ref = match self {
            InputSource::Inline { value } => return value.clone(),
            InputSource::Reference { data_ref } => data_ref,
        };

        // The primary store is best-effort: any failure simply falls through to the temp store.
        let from_primary = data_store.and_then(|primary| primary.get(data_ref).ok());

        from_primary.unwrap_or_else(|| match temp_store.get(data_ref) {
            Ok(found) => found,
            Err(e) => {
                tracing::warn!("Failed to resolve DataRef: {e}");
                Value::Empty
            }
        })
    }
}
/// A filter shipped to a worker inside `SerializedPlan::filters`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedFilter {
/// Plan node id associated with this filter.
pub node_id: String,
/// Serialized filter object as opaque bytes; encoded as a base64 string on the wire.
/// NOTE(review): the name suggests a Python pickle. If so, only unpickle bytes from a
/// trusted coordinator — unpickling untrusted data can execute arbitrary code.
#[serde(with = "base64_bytes")]
pub pickled_filter: Vec<u8>,
/// Optional filter state (`None` when absent or `null`).
pub state: Option<Value>,
/// Requirement strings for this filter; defaults to empty when the field is absent.
#[serde(default)]
pub requirements: Vec<String>,
/// Whether the filter is trainable; defaults to `false` when the field is absent.
#[serde(default)]
pub trainable: bool,
}
mod base64_bytes {
use base64::engine::{Engine, general_purpose::STANDARD};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
STANDARD.encode(bytes).serialize(s)
}
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
let s = String::deserialize(d)?;
STANDARD.decode(s).map_err(serde::de::Error::custom)
}
}
/// How a worker should run a plan; defaults to `ExecutionMode::Forward`.
///
/// There is no `#[serde(tag = ...)]` here, so serde's externally tagged representation is
/// used: `"Forward"`, or `{"Fit": {"y": ...}}`. This differs from the internally tagged
/// message enums in this file.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ExecutionMode {
/// Fit mode with an optional `y` value.
/// NOTE(review): `y` is presumably the training target — confirm.
Fit {
y: Option<Value>,
},
/// Forward execution; the default.
#[default]
Forward,
}
/// A plan plus the data a worker needs to run it.
///
/// This is the payload of `CoordinatorToWorker::AssignPlan` and is boxed inside
/// `StreamMessage::StreamBegin`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedPlan {
/// Identifier of this plan.
pub plan_id: PlanId,
/// The execution plan (from `somatize_compiler`).
pub plan: ExecutionPlan,
/// Plan input, if any; `None` when absent or `null`.
pub input: Option<InputSource>,
/// Filters for this plan; defaults to empty when the field is absent.
#[serde(default)]
pub filters: Vec<SerializedFilter>,
/// Execution mode; defaults to `ExecutionMode::Forward` when the field is absent.
#[serde(default)]
pub mode: ExecutionMode,
/// Arbitrary JSON metadata. The field must be present (it may be `null`), because
/// it has no `#[serde(default)]`.
pub metadata: serde_json::Value,
}
/// Messages sent from a worker to the coordinator.
///
/// Internally tagged: the serialized object's `"type"` field holds the variant name.
/// Every variant carries the sending worker's `worker_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerToCoordinator {
/// Announces the worker together with its capabilities.
Register {
worker_id: WorkerId,
capabilities: Capabilities,
},
/// Reports the worker's current load.
Heartbeat {
worker_id: WorkerId,
load: LoadMetrics,
},
/// Forwards an execution event for plan `plan_id`.
Event {
worker_id: WorkerId,
plan_id: PlanId,
event: Event,
},
/// Reports the outcome of plan `plan_id`.
PlanResult {
worker_id: WorkerId,
plan_id: PlanId,
result: PlanResult,
},
/// Progress of job `job_id`: `step` of `total` within `phase`, plus free-form metrics.
/// NOTE(review): presumably a job sent via `CoordinatorToWorker::AssignPythonJob` — confirm.
JobProgress {
worker_id: WorkerId,
job_id: String,
phase: String,
step: u32,
total: u32,
metrics: serde_json::Value,
},
/// Completion report for job `job_id`: success flag, metrics, output text and
/// duration in milliseconds.
JobResult {
worker_id: WorkerId,
job_id: String,
success: bool,
metrics: serde_json::Value,
output: String,
duration_ms: u64,
},
/// States for plan `plan_id`, presumably keyed by node id and answering
/// `CoordinatorToWorker::GetState` — confirm.
StateResult {
worker_id: WorkerId,
plan_id: PlanId,
states: std::collections::HashMap<String, Value>,
},
/// Gradients for plan `plan_id`, presumably keyed by node id and answering
/// `CoordinatorToWorker::GetGradients` — confirm.
GradientsResult {
worker_id: WorkerId,
plan_id: PlanId,
gradients: std::collections::HashMap<String, Value>,
},
}
/// A Python pipeline job, sent to a worker via `CoordinatorToWorker::AssignPythonJob`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PythonPipelineJob {
/// Job identifier.
/// NOTE(review): presumably echoed as `job_id` in `JobProgress`/`JobResult` — confirm.
pub job_id: String,
/// Identifier of the pipeline.
pub pipeline_id: String,
/// Identifier of the investigation this job is associated with.
pub investigation_id: String,
/// Source files that make up the pipeline.
pub files: Vec<PipelineFile>,
/// Requirements as a single string.
/// NOTE(review): presumably `requirements.txt`-style content — confirm.
pub requirements: String,
/// Entry point to run. NOTE(review): confirm whether this is a file path, module or callable.
pub entry_point: String,
/// Optional JSON input data.
pub input_data: Option<serde_json::Value>,
/// JSON parameters for the job.
pub params: serde_json::Value,
}
/// One file of a `PythonPipelineJob`: its path and full text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineFile {
/// File path. NOTE(review): presumably relative to the job's working directory — confirm.
pub path: String,
/// File contents as text (UTF-8, since this is a `String`).
pub content: String,
}
/// Messages sent from the coordinator to a worker.
///
/// Internally tagged: the serialized object's `"type"` field holds the variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CoordinatorToWorker {
/// Registration acknowledgement carrying the worker id
/// (presumably the reply to `WorkerToCoordinator::Register` — confirm).
Registered { worker_id: WorkerId },
/// Assigns a plan to the worker.
AssignPlan { plan: SerializedPlan },
/// Assigns a Python pipeline job to the worker.
AssignPythonJob { job: PythonPipelineJob },
/// Cancels plan `plan_id`.
CancelPlan { plan_id: PlanId },
/// Requests status from the worker. NOTE(review): the reply is not defined in this file.
StatusRequest,
/// Ping with no payload.
Ping,
/// Asks the worker to shut down, giving a reason.
Shutdown { reason: String },
/// Requests the states of `node_ids` in plan `plan_id`.
GetState {
plan_id: PlanId,
node_ids: Vec<String>,
},
/// Sets states for plan `plan_id`; map keys are presumably node ids — confirm.
SetState {
plan_id: PlanId,
states: std::collections::HashMap<String, Value>,
},
/// Requests the gradients of `node_ids` in plan `plan_id`.
GetGradients {
plan_id: PlanId,
node_ids: Vec<String>,
},
/// Sends gradients to apply for plan `plan_id`; map keys are presumably node ids — confirm.
ApplyGradients {
plan_id: PlanId,
gradients: std::collections::HashMap<String, Value>,
},
}
/// How a plan's output is delivered in `PlanResult::Success`.
///
/// Internally tagged: the serialized object's `"delivery"` field holds the variant name.
/// Call `OutputDelivery::resolve` to obtain the actual value.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "delivery")]
#[non_exhaustive]
pub enum OutputDelivery {
/// The output value is embedded directly in the message.
Inline { value: Value },
/// The output value must be downloaded using `data_ref`.
Reference {
data_ref: somatize_core::store::DataRef,
},
}
impl OutputDelivery {
pub fn resolve(&self, addr: &str, token: &Option<String>) -> Value {
match self {
OutputDelivery::Inline { value } => value.clone(),
OutputDelivery::Reference { data_ref } => {
let http_addr = addr
.replace("ws://", "http://")
.replace("wss://", "https://");
let url = format!("{http_addr}/download");
let ref_json = serde_json::to_string(data_ref).unwrap_or_default();
let token = token.clone();
std::thread::spawn(move || {
let client = reqwest::blocking::Client::new();
let mut req = client.get(&url).query(&[("ref", &ref_json)]);
if let Some(t) = &token {
req = req.query(&[("token", t.as_str())]);
}
let resp = req.send().ok()?;
let bytes = resp.bytes().ok()?;
serde_json::from_slice(&bytes).ok()
})
.join()
.ok()
.flatten()
.unwrap_or(Value::Empty)
}
}
}
}
/// Outcome of running a plan.
///
/// Reported in `WorkerToCoordinator::PlanResult` and `StreamMessage::StreamComplete`.
/// Internally tagged: the serialized object's `"status"` field is `"Success"` or `"Failed"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status")]
pub enum PlanResult {
/// The plan completed.
Success {
/// The plan output, inline or by reference.
output: OutputDelivery,
/// Run duration in milliseconds.
duration_ms: u64,
/// States returned with the result; defaults to empty when the field is absent.
/// NOTE(review): keys are presumably node ids — confirm.
#[serde(default)]
states: std::collections::HashMap<String, Value>,
},
/// The plan failed.
Failed {
/// Error description.
error: String,
/// Run duration in milliseconds.
duration_ms: u64,
},
}
/// Messages of the chunked streaming protocol, correlated by `stream_id`.
///
/// Internally tagged: the serialized object's `"type"` field holds the variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub enum StreamMessage {
/// Opens stream `stream_id` for plan `plan_id`.
/// NOTE(review): `total_chunks` is presumably `None` when the count is unknown up front.
/// The plan is boxed, presumably to keep this enum small.
StreamBegin {
stream_id: String,
plan_id: PlanId,
total_chunks: Option<usize>,
plan: Box<SerializedPlan>,
},
/// A data chunk at position `chunk_index` (presumably plan input — confirm).
ChunkData {
stream_id: String,
chunk_index: usize,
value: Value,
},
/// Marks the end of stream `stream_id`.
StreamEnd { stream_id: String },
/// Result value for chunk `chunk_index`.
/// NOTE(review): presumably the output for the `ChunkData` with the same index — confirm.
ChunkResult {
stream_id: String,
chunk_index: usize,
value: Value,
},
/// Final result for stream `stream_id`.
StreamComplete {
stream_id: String,
result: PlanResult,
},
}
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::event::PlanSummary;

    #[test]
    fn capabilities_serde() {
        let caps = Capabilities {
            cpu_cores: 8,
            ram_bytes: 32 * 1024 * 1024 * 1024,
            gpus: vec![GpuInfo {
                name: "A100".into(),
                memory_bytes: 80 * 1024 * 1024 * 1024,
            }],
            python_envs: vec!["py310".into(), "py311".into()],
            tags: vec!["gpu".into(), "training".into()],
        };
        let json = serde_json::to_string(&caps).unwrap();
        let deserialized: Capabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cpu_cores, 8);
        // Byte counts above 4 GiB must survive the JSON round trip without truncation.
        assert_eq!(deserialized.ram_bytes, 32_u64 * 1024 * 1024 * 1024);
        assert_eq!(deserialized.gpus.len(), 1);
        assert_eq!(deserialized.gpus[0].name, "A100");
        assert_eq!(deserialized.gpus[0].memory_bytes, 80_u64 * 1024 * 1024 * 1024);
        assert_eq!(deserialized.python_envs, vec!["py310", "py311"]);
        assert_eq!(deserialized.tags, vec!["gpu", "training"]);
    }

    #[test]
    fn worker_message_serde() {
        let msg = WorkerToCoordinator::Register {
            worker_id: "worker_01".into(),
            capabilities: Capabilities {
                cpu_cores: 4,
                ram_bytes: 16_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into()],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        // Internally tagged: the variant name lives in the `type` field.
        assert!(json.contains(r#""type":"Register""#));
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        let WorkerToCoordinator::Register {
            worker_id,
            capabilities,
        } = deserialized
        else {
            panic!("wrong variant");
        };
        assert_eq!(worker_id, "worker_01");
        assert_eq!(capabilities.ram_bytes, 16_000_000_000);
        assert_eq!(capabilities.tags, vec!["cpu"]);
    }

    #[test]
    fn coordinator_message_serde() {
        let msg = CoordinatorToWorker::AssignPlan {
            plan: SerializedPlan {
                plan_id: "plan_001".into(),
                plan: ExecutionPlan::Execute {
                    node_id: "train".into(),
                },
                input: Some(InputSource::Inline {
                    value: Value::tensor(vec![1.0, 2.0], vec![2]),
                }),
                filters: vec![],
                mode: ExecutionMode::default(),
                metadata: serde_json::json!({"experiment": "test"}),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: CoordinatorToWorker = serde_json::from_str(&json).unwrap();
        let CoordinatorToWorker::AssignPlan { plan } = deserialized else {
            panic!("wrong variant");
        };
        assert_eq!(plan.plan_id, "plan_001");
        assert!(plan.input.is_some());
        assert!(matches!(plan.mode, ExecutionMode::Forward));
        assert_eq!(plan.metadata["experiment"], "test");
    }

    #[test]
    fn serialized_plan_optional_fields_default() {
        let plan = SerializedPlan {
            plan_id: "plan_002".into(),
            plan: ExecutionPlan::Execute {
                node_id: "train".into(),
            },
            input: None,
            filters: vec![],
            mode: ExecutionMode::default(),
            metadata: serde_json::Value::Null,
        };
        let mut json = serde_json::to_value(&plan).unwrap();
        let fields = json
            .as_object_mut()
            .expect("SerializedPlan serializes as a JSON object");
        fields.remove("filters");
        fields.remove("mode");
        let restored: SerializedPlan = serde_json::from_value(json).unwrap();
        assert_eq!(restored.plan_id, "plan_002");
        assert!(restored.filters.is_empty());
        assert!(matches!(restored.mode, ExecutionMode::Forward));
        assert!(restored.metadata.is_null());
    }

    #[test]
    fn serialized_filter_base64_roundtrip() {
        let filter = SerializedFilter {
            node_id: "scaler".into(),
            pickled_filter: vec![0, 1, 255],
            state: None,
            requirements: vec!["numpy".into()],
            trainable: true,
        };
        let json = serde_json::to_string(&filter).unwrap();
        // [0x00, 0x01, 0xFF] in standard padded base64.
        assert!(json.contains(r#""pickled_filter":"AAH/""#));
        let deserialized: SerializedFilter = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.node_id, "scaler");
        assert_eq!(deserialized.pickled_filter, vec![0, 1, 255]);
        assert_eq!(deserialized.requirements, vec!["numpy"]);
        assert!(deserialized.trainable);
    }

    #[test]
    fn serialized_filter_defaults_and_invalid_base64() {
        let minimal = r#"{"node_id":"n","pickled_filter":"","state":null}"#;
        let filter: SerializedFilter = serde_json::from_str(minimal).unwrap();
        assert!(filter.pickled_filter.is_empty());
        assert!(filter.state.is_none());
        assert!(filter.requirements.is_empty());
        assert!(!filter.trainable);

        let bad = r#"{"node_id":"n","pickled_filter":"not base64!","state":null}"#;
        assert!(serde_json::from_str::<SerializedFilter>(bad).is_err());
    }

    #[test]
    fn plan_result_serde() {
        let success = PlanResult::Success {
            output: OutputDelivery::Inline {
                value: Value::tensor(vec![0.95], vec![1]),
            },
            duration_ms: 1234,
            states: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&success).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        let PlanResult::Success { duration_ms, .. } = deserialized else {
            panic!("expected Success");
        };
        assert_eq!(duration_ms, 1234);

        // `states` is optional on the wire and defaults to an empty map.
        let mut value = serde_json::to_value(&success).unwrap();
        value
            .as_object_mut()
            .expect("PlanResult serializes as a JSON object")
            .remove("states");
        let PlanResult::Success { states, .. } =
            serde_json::from_value::<PlanResult>(value).unwrap()
        else {
            panic!("expected Success");
        };
        assert!(states.is_empty());

        let failed = PlanResult::Failed {
            error: "OOM".into(),
            duration_ms: 500,
        };
        let json = serde_json::to_string(&failed).unwrap();
        let deserialized: PlanResult = serde_json::from_str(&json).unwrap();
        let PlanResult::Failed { error, duration_ms } = deserialized else {
            panic!("expected Failed");
        };
        assert_eq!(error, "OOM");
        assert_eq!(duration_ms, 500);
    }

    #[test]
    fn event_message_serde() {
        let msg = WorkerToCoordinator::Event {
            worker_id: "w1".into(),
            plan_id: "p1".into(),
            event: Event::RunStarted {
                run_id: "r1".into(),
                plan_summary: PlanSummary {
                    total_nodes: 3,
                    cached_nodes: 1,
                    parallel_branches: 0,
                },
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        let WorkerToCoordinator::Event {
            worker_id, plan_id, ..
        } = deserialized
        else {
            panic!("wrong variant");
        };
        assert_eq!(worker_id, "w1");
        assert_eq!(plan_id, "p1");
    }

    #[test]
    fn heartbeat_serde() {
        let msg = WorkerToCoordinator::Heartbeat {
            worker_id: "w1".into(),
            load: LoadMetrics {
                cpu_usage: 0.45,
                memory_usage: 0.72,
                gpu_usage: vec![0.88],
                active_plans: 2,
                queue_depth: 5,
                timestamp: Utc::now(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: WorkerToCoordinator = serde_json::from_str(&json).unwrap();
        // let-else so a wrong variant fails the test instead of passing vacuously.
        let WorkerToCoordinator::Heartbeat { worker_id, load } = deserialized else {
            panic!("wrong variant");
        };
        assert_eq!(worker_id, "w1");
        assert!(load.cpu_usage > 0.0);
        assert_eq!(load.gpu_usage.len(), 1);
        assert_eq!(load.active_plans, 2);
        assert_eq!(load.queue_depth, 5);
    }
}