use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::Duration;
/// Cluster operating mode of a node.
///
/// (De)serialized as an internally tagged enum: the `mode` key selects the
/// variant (kebab-case names: `single-node`, `raft`, `static`, `worker-tier`)
/// and unknown keys are rejected. The default is [`ClusterMode::SingleNode`].
// `WorkerTier` carries many more fields than the other variants, which is what
// `clippy::large_enum_variant` reports.
// NOTE(review): presumably left unboxed because only a few config values exist
// at a time — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(tag = "mode", rename_all = "kebab-case", deny_unknown_fields)]
pub enum ClusterMode {
/// No clustering; written as `mode: single-node`.
#[default]
#[serde(rename = "single-node")]
SingleNode,
/// Raft mode with an explicit peer list.
Raft {
/// Numeric node id (presumably the local node's own id — confirm).
node_id: u64,
/// Peers, each with a Raft address and an API address.
peers: Vec<RaftPeer>,
},
/// Static mode with a fixed peer list and heartbeat settings.
Static {
/// Numeric node id (presumably the local node's own id — confirm).
node_id: u64,
/// Statically configured peers.
peers: Vec<StaticPeer>,
/// Heartbeat interval, stored as whole seconds; 5 s when omitted.
#[serde(default = "default_heartbeat_interval", with = "duration_secs")]
heartbeat_interval: Duration,
/// Failure threshold, stored as whole seconds; 15 s when omitted.
#[serde(default = "default_failure_threshold", with = "duration_secs")]
failure_threshold: Duration,
},
/// Worker-tier mode; `role` selects server or worker. Written as
/// `mode: worker-tier`.
#[serde(rename = "worker-tier")]
WorkerTier {
/// Whether this node is a `server` or a `worker`.
role: WorkerTierRole,
/// Optional node id; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
node_id: Option<u64>,
/// Raft-style peer list; omitted from output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
peers: Vec<RaftPeer>,
/// Worker gRPC socket address; defaults to `0.0.0.0:3670` and is omitted
/// from output when it equals that default.
#[serde(
default = "default_worker_grpc_addr",
skip_serializing_if = "is_default_worker_grpc_addr"
)]
worker_grpc_addr: SocketAddr,
/// Server endpoints as plain strings (the tests use `http://host:port`);
/// omitted from output when empty. Presumably read by worker-role nodes —
/// confirm.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
servers: Vec<String>,
/// Optional path to a token file (presumably the registration token —
/// confirm).
#[serde(default, skip_serializing_if = "Option::is_none")]
token_file: Option<String>,
/// Optional identity directory path.
#[serde(default, skip_serializing_if = "Option::is_none")]
identity_dir: Option<String>,
/// Optional worker CA directory path.
#[serde(default, skip_serializing_if = "Option::is_none")]
worker_ca_dir: Option<String>,
/// Lower heartbeat TTL bound in whole seconds; 10 s when omitted.
#[serde(default = "default_min_ttl", with = "duration_secs")]
heartbeat_min_ttl: Duration,
/// Upper heartbeat TTL bound in whole seconds; 600 s when omitted.
#[serde(default = "default_max_ttl", with = "duration_secs")]
heartbeat_max_ttl: Duration,
/// Heartbeat grace period in whole seconds; 10 s when omitted.
#[serde(default = "default_grace", with = "duration_secs")]
heartbeat_grace: Duration,
/// Heartbeat rate budget used when computing TTLs; 50 when omitted.
#[serde(default = "default_max_hb")]
max_heartbeats_per_second: u32,
/// Failover heartbeat TTL in whole seconds; 300 s when omitted.
#[serde(default = "default_failover_ttl", with = "duration_secs")]
failover_heartbeat_ttl: Duration,
/// Free-form string labels; omitted from output when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
labels: HashMap<String, String>,
},
}
/// Role of a node inside [`ClusterMode::WorkerTier`].
///
/// Serialized in snake_case, i.e. `server` or `worker`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WorkerTierRole {
/// Matched by `ClusterMode::is_worker_tier_server`.
Server,
/// Matched by `ClusterMode::is_worker_tier_worker`.
Worker,
}
impl ClusterMode {
#[must_use]
pub fn adaptive_ttl_config(&self) -> Option<AdaptiveTtlConfig> {
if let ClusterMode::WorkerTier {
heartbeat_min_ttl,
heartbeat_max_ttl,
heartbeat_grace,
max_heartbeats_per_second,
failover_heartbeat_ttl,
..
} = self
{
Some(AdaptiveTtlConfig {
min_ttl_secs: u32::try_from(heartbeat_min_ttl.as_secs()).unwrap_or(u32::MAX),
max_ttl_secs: u32::try_from(heartbeat_max_ttl.as_secs()).unwrap_or(u32::MAX),
grace_secs: u32::try_from(heartbeat_grace.as_secs()).unwrap_or(u32::MAX),
max_heartbeats_per_second: *max_heartbeats_per_second,
failover_ttl_secs: u32::try_from(failover_heartbeat_ttl.as_secs())
.unwrap_or(u32::MAX),
})
} else {
None
}
}
#[must_use]
pub fn is_worker_tier_server(&self) -> bool {
matches!(
self,
ClusterMode::WorkerTier {
role: WorkerTierRole::Server,
..
}
)
}
#[must_use]
pub fn is_worker_tier_worker(&self) -> bool {
matches!(
self,
ClusterMode::WorkerTier {
role: WorkerTierRole::Worker,
..
}
)
}
}
/// Peer entry used by `ClusterMode::Raft` and `ClusterMode::WorkerTier`.
///
/// Unknown fields are rejected during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct RaftPeer {
/// Numeric id of the peer.
pub id: u64,
/// Socket address for the peer's Raft endpoint.
pub raft_addr: SocketAddr,
/// Socket address for the peer's API endpoint.
pub api_addr: SocketAddr,
}
/// Summary record for a single container; embedded in
/// `NodeServiceState::containers`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ClusterContainerSummary {
/// Node id associated with the container (presumably the hosting node —
/// confirm).
pub node_id: u64,
/// Container identifier string.
pub id: String,
/// Service name.
pub service: String,
/// Replica number within the service.
pub replica: u32,
/// Image reference string.
pub image: String,
/// State as a free-form string; the set of values is not defined here.
pub state: String,
/// Process id, when known.
pub pid: Option<u32>,
/// Overlay IP as a string, when known.
pub overlay_ip: Option<String>,
}
/// Per-node state record holding a running count, a health flag and the
/// node's container summaries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeServiceState {
/// Node id this record belongs to.
pub node_id: u64,
/// Running count (presumably the number of running replicas — confirm).
pub running: u32,
/// Health flag.
pub healthy: bool,
/// Container summaries for this node.
pub containers: Vec<ClusterContainerSummary>,
}
/// Peer entry used by `ClusterMode::Static`.
///
/// Unknown fields are rejected during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct StaticPeer {
/// Numeric id of the peer.
pub id: u64,
/// Socket address for the peer's API endpoint.
pub api_addr: SocketAddr,
/// Operating system name; `"linux"` when omitted.
#[serde(default = "default_os")]
pub os: String,
/// Free-form string labels; omitted from output when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub labels: HashMap<String, String>,
}
/// Serde default for `heartbeat_interval` in `ClusterMode::Static`: 5 seconds.
fn default_heartbeat_interval() -> Duration {
    const DEFAULT_SECS: u64 = 5;
    Duration::from_secs(DEFAULT_SECS)
}
/// Serde default for `failure_threshold` in `ClusterMode::Static`: 15 seconds.
fn default_failure_threshold() -> Duration {
    const DEFAULT_SECS: u64 = 15;
    Duration::from_secs(DEFAULT_SECS)
}
/// Serde default for `StaticPeer::os`.
fn default_os() -> String {
    String::from("linux")
}
/// Serde default for `worker_grpc_addr`: the IPv4 wildcard address on port 3670
/// (`0.0.0.0:3670`).
fn default_worker_grpc_addr() -> SocketAddr {
    const WILDCARD_V4: [u8; 4] = [0, 0, 0, 0];
    const PORT: u16 = 3670;
    SocketAddr::from((WILDCARD_V4, PORT))
}
/// `skip_serializing_if` predicate: `true` when `addr` equals the value
/// returned by `default_worker_grpc_addr`.
///
/// Takes a reference because serde passes the field by reference.
fn is_default_worker_grpc_addr(addr: &SocketAddr) -> bool {
    let default_addr = default_worker_grpc_addr();
    default_addr == *addr
}
/// Serde default for `heartbeat_min_ttl` in `ClusterMode::WorkerTier`: 10 seconds.
fn default_min_ttl() -> Duration {
    const DEFAULT_SECS: u64 = 10;
    Duration::from_secs(DEFAULT_SECS)
}
/// Serde default for `heartbeat_max_ttl` in `ClusterMode::WorkerTier`: 600 seconds.
fn default_max_ttl() -> Duration {
    const DEFAULT_SECS: u64 = 600;
    Duration::from_secs(DEFAULT_SECS)
}
/// Serde default for `heartbeat_grace` in `ClusterMode::WorkerTier`: 10 seconds.
fn default_grace() -> Duration {
    const DEFAULT_SECS: u64 = 10;
    Duration::from_secs(DEFAULT_SECS)
}
/// Serde default for `max_heartbeats_per_second` in `ClusterMode::WorkerTier`.
fn default_max_hb() -> u32 {
    const DEFAULT_MAX_HEARTBEATS_PER_SECOND: u32 = 50;
    DEFAULT_MAX_HEARTBEATS_PER_SECOND
}
/// Serde default for `failover_heartbeat_ttl` in `ClusterMode::WorkerTier`:
/// 300 seconds.
fn default_failover_ttl() -> Duration {
    const DEFAULT_SECS: u64 = 300;
    Duration::from_secs(DEFAULT_SECS)
}
/// Serde `with` adapter that stores a `Duration` as an unsigned number of
/// whole seconds.
///
/// Any sub-second part is dropped when serializing, so a value only
/// round-trips exactly if it is a whole number of seconds.
mod duration_secs {
    use serde::{Deserialize, Deserializer, Serializer};
    use std::time::Duration;

    /// Writes `dur` as its whole-second count (`u64`).
    pub fn serialize<S: Serializer>(dur: &Duration, ser: S) -> Result<S::Ok, S::Error> {
        ser.serialize_u64(dur.as_secs())
    }

    /// Reads a `u64` second count and turns it into a `Duration`.
    pub fn deserialize<'de, D: Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
        u64::deserialize(de).map(Duration::from_secs)
    }
}
/// Scale request payload for a single service.
///
/// Only `service` is required on input; every other field has a serde
/// default, so the legacy `{"service": ..., "replicas": ...}` shape still
/// parses.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct InternalScaleRequest {
/// Name of the service to scale.
pub service: String,
/// Replica count; `0` when omitted on input.
#[serde(default)]
pub replicas: u32,
/// Optional per-role assignments; omitted from output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub assignments: Vec<ScaleAssignment>,
/// Optional service spec, boxed; omitted from output when `None` and
/// described as a plain object in the generated schema.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(value_type = Option<Object>)]
pub spec: Option<Box<crate::spec::types::ServiceSpec>>,
}
/// A role together with the replica indices assigned to it.
///
/// `InternalScaleRequest::with_assignments` counts every index as one replica.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ScaleAssignment {
/// Role name (the tests use values such as `"primary"` and `"read"`).
pub role: String,
/// Replica indices assigned to this role.
pub indices: Vec<u32>,
}
impl InternalScaleRequest {
    /// Builds a request in the legacy shape: a service name and a replica
    /// count, with no assignments and no spec.
    #[must_use]
    pub fn new(service: impl Into<String>, replicas: u32) -> Self {
        Self {
            service: service.into(),
            replicas,
            assignments: Vec::new(),
            spec: None,
        }
    }

    /// Returns the request with `spec` attached (boxed), replacing any
    /// previously set spec.
    #[must_use]
    pub fn with_spec(mut self, spec: crate::spec::types::ServiceSpec) -> Self {
        self.spec = Some(Box::new(spec));
        self
    }

    /// Builds a request from explicit per-role assignments.
    ///
    /// `replicas` is set to the total number of indices across all
    /// assignments. The total saturates at `u32::MAX` instead of overflowing.
    #[must_use]
    pub fn with_assignments(service: impl Into<String>, assignments: Vec<ScaleAssignment>) -> Self {
        // Add up in `usize` with saturation and convert once: a plain `sum()`
        // over `u32` could overflow after each length had been capped.
        let total_indices = assignments
            .iter()
            .fold(0usize, |acc, a| acc.saturating_add(a.indices.len()));
        let replicas = u32::try_from(total_indices).unwrap_or(u32::MAX);
        Self {
            service: service.into(),
            replicas,
            assignments,
            spec: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ClusterMode::default()` is the single-node variant.
    #[test]
    fn default_is_single_node() {
        let cfg = ClusterMode::default();
        assert_eq!(cfg, ClusterMode::SingleNode);
    }

    /// A payload with only `service` and `replicas` parses with no spec and
    /// no assignments.
    #[test]
    fn scale_request_legacy_shape_has_no_spec() {
        let req: InternalScaleRequest =
            serde_json::from_str(r#"{"service":"web","replicas":3}"#).unwrap();
        assert_eq!(req.service, "web");
        assert_eq!(req.replicas, 3);
        assert!(req.spec.is_none());
        assert!(req.assignments.is_empty());
    }

    /// An attached spec survives a JSON round-trip.
    #[test]
    fn scale_request_with_spec_roundtrips() {
        let spec = crate::spec::types::ServiceSpec::default();
        let req = InternalScaleRequest::new("web", 3).with_spec(spec);
        assert!(req.spec.is_some());
        let json = serde_json::to_string(&req).unwrap();
        let back: InternalScaleRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(back.service, "web");
        assert_eq!(back.replicas, 3);
        assert!(back.spec.is_some(), "spec must survive the round-trip");
    }

    /// Static mode parses from YAML, including whole-second durations.
    ///
    /// The mapping keys of each peer must be indented under its `-` entry;
    /// otherwise they become top-level keys, which `deny_unknown_fields`
    /// rejects.
    #[test]
    fn yaml_static_roundtrip() {
        let yaml = r"
mode: static
node_id: 2
peers:
  - id: 1
    api_addr: 10.0.0.10:3669
  - id: 2
    api_addr: 10.0.0.11:3669
heartbeat_interval: 5
failure_threshold: 15
";
        let parsed: ClusterMode = serde_yaml::from_str(yaml).unwrap();
        match parsed {
            ClusterMode::Static {
                node_id,
                peers,
                heartbeat_interval,
                failure_threshold,
            } => {
                assert_eq!(node_id, 2);
                assert_eq!(peers.len(), 2);
                assert_eq!(heartbeat_interval, Duration::from_secs(5));
                assert_eq!(failure_threshold, Duration::from_secs(15));
            }
            _ => panic!("expected Static variant"),
        }
    }

    /// `mode: single-node` on its own selects the unit variant.
    #[test]
    fn yaml_single_node_roundtrip() {
        let yaml = "mode: single-node";
        let parsed: ClusterMode = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(parsed, ClusterMode::SingleNode);
    }

    /// The legacy shape re-serializes without an `assignments` key.
    #[test]
    fn internal_scale_request_legacy_shape() {
        let json = r#"{"service":"web","replicas":3}"#;
        let req: InternalScaleRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.service, "web");
        assert_eq!(req.replicas, 3);
        assert!(req.assignments.is_empty());
        let out = serde_json::to_string(&req).unwrap();
        assert!(!out.contains("assignments"), "got: {out}");
        assert!(out.contains(r#""service":"web""#));
        assert!(out.contains(r#""replicas":3"#));
    }

    /// `with_assignments` derives `replicas` from the indices and the
    /// assignments survive a JSON round-trip.
    #[test]
    fn internal_scale_request_with_assignments_roundtrip() {
        let req = InternalScaleRequest::with_assignments(
            "db",
            vec![
                ScaleAssignment {
                    role: "primary".to_string(),
                    indices: vec![0],
                },
                ScaleAssignment {
                    role: "read".to_string(),
                    indices: vec![1, 2],
                },
            ],
        );
        assert_eq!(req.replicas, 3);
        let json = serde_json::to_string(&req).unwrap();
        let parsed: InternalScaleRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.service, "db");
        assert_eq!(parsed.replicas, 3);
        assert_eq!(parsed.assignments.len(), 2);
        assert_eq!(parsed.assignments[0].role, "primary");
        assert_eq!(parsed.assignments[0].indices, vec![0]);
        assert_eq!(parsed.assignments[1].role, "read");
        assert_eq!(parsed.assignments[1].indices, vec![1, 2]);
    }

    /// `new` produces the legacy shape with no assignments.
    #[test]
    fn internal_scale_request_new_constructs_legacy_shape() {
        let req = InternalScaleRequest::new("api", 5);
        assert_eq!(req.service, "api");
        assert_eq!(req.replicas, 5);
        assert!(req.assignments.is_empty());
    }

    /// A server-role worker-tier config parses and exposes its TTL settings.
    #[test]
    fn worker_tier_server_yaml_round_trips() {
        let yaml = r"
mode: worker-tier
role: server
node_id: 1
peers:
  - id: 1
    raft_addr: 10.0.0.1:9001
    api_addr: 10.0.0.1:3669
  - id: 2
    raft_addr: 10.0.0.2:9001
    api_addr: 10.0.0.2:3669
  - id: 3
    raft_addr: 10.0.0.3:9001
    api_addr: 10.0.0.3:3669
worker_grpc_addr: 0.0.0.0:3670
worker_ca_dir: /var/lib/zlayer/cluster
heartbeat_min_ttl: 15
heartbeat_max_ttl: 600
heartbeat_grace: 10
max_heartbeats_per_second: 100
failover_heartbeat_ttl: 300
";
        let parsed: ClusterMode = serde_yaml::from_str(yaml).unwrap();
        assert!(parsed.is_worker_tier_server());
        assert!(!parsed.is_worker_tier_worker());
        let ttl = parsed.adaptive_ttl_config().expect("ttl");
        assert_eq!(ttl.max_heartbeats_per_second, 100);
    }

    /// A worker-role config parses and falls back to the default TTL settings.
    #[test]
    fn worker_tier_worker_yaml_round_trips() {
        let yaml = r"
mode: worker-tier
role: worker
servers:
  - http://10.0.0.1:3670
  - http://10.0.0.2:3670
token_file: /etc/zlayer/worker.token
identity_dir: /var/lib/zlayer/worker
";
        let parsed: ClusterMode = serde_yaml::from_str(yaml).unwrap();
        assert!(parsed.is_worker_tier_worker());
        assert!(!parsed.is_worker_tier_server());
        let ttl = parsed.adaptive_ttl_config().expect("ttl");
        assert_eq!(ttl.min_ttl_secs, 10);
        assert_eq!(ttl.max_heartbeats_per_second, 50);
    }
}
/// Registration request payload carrying a token, an optional requested node
/// id and the worker's profile.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerRegisterRequest {
/// Token string presented with the request.
pub token: String,
/// Node id the worker would like to get; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub desired_node_id: Option<u64>,
/// Description of the registering worker.
pub profile: WorkerProfile,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerProfile {
pub api_addr: SocketAddr,
pub os: String,
pub arch: String,
#[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
pub labels: std::collections::HashMap<String, String>,
#[serde(default)]
pub cpu_total: u32,
#[serde(default)]
pub memory_total_bytes: u64,
}
/// Response payload for a worker registration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerRegisterResponse {
/// Node id assigned to the worker.
pub node_id: u64,
/// Cluster identifier string.
pub cluster_id: String,
/// Heartbeat TTL in seconds.
pub heartbeat_ttl_secs: u32,
/// Heartbeat grace period in seconds.
pub heartbeat_grace_secs: u32,
/// Internal token string (its use is not visible in this file).
pub internal_token: String,
}
/// Poll request payload sent on behalf of a worker node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerPollRequest {
/// Node id of the polling worker.
pub node_id: u64,
/// Last revision the worker reports having seen; `0` when omitted.
#[serde(default)]
pub last_revision: u64,
/// Maximum wait in seconds; `30` when omitted. Presumably a long-poll
/// timeout — confirm against the handler.
#[serde(default = "default_poll_wait_secs")]
pub max_wait_secs: u32,
}
/// Serde default for `WorkerPollRequest::max_wait_secs`.
fn default_poll_wait_secs() -> u32 {
    const DEFAULT_WAIT_SECS: u32 = 30;
    DEFAULT_WAIT_SECS
}
/// Poll response payload: a revision number plus zero or more assignment
/// events.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerPollResponse {
/// Revision number reported with this response.
pub revision: u64,
/// Assignment events; omitted from output when empty and empty when the
/// key is absent on input.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub events: Vec<WorkerAssignmentEvent>,
}
/// Event delivered in `WorkerPollResponse::events`.
///
/// Internally tagged: the `kind` key holds the snake_case variant name
/// (`set`, `delete`, `drain`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum WorkerAssignmentEvent {
/// Carries the assignments for `service` at `revision`.
Set {
service: String,
assignments: Vec<ScaleAssignment>,
revision: u64,
},
/// Names a `service` to delete at `revision`.
Delete { service: String, revision: u64 },
/// Drain event at `revision`; it names no service.
Drain { revision: u64 },
}
/// Status report payload for one worker node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerStatusReport {
/// Node id of the reporting worker.
pub node_id: u64,
/// Timestamp in nanoseconds (the epoch is not defined in this file).
pub ts_ns: u64,
/// Per-container statuses; empty when the key is absent on input.
#[serde(default)]
pub containers: Vec<WorkerContainerStatus>,
/// Resource usage figures.
pub resources: WorkerResourceUsage,
}
/// Status of one container as listed in `WorkerStatusReport::containers`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerContainerStatus {
/// Service name.
pub service: String,
/// Role name.
pub role: String,
/// Replica number.
pub replica: u32,
/// State as a free-form string; the set of values is not defined here.
pub state: String,
/// Overlay IP address, when known; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub overlay_ip: Option<std::net::IpAddr>,
}
/// Resource usage figures carried in `WorkerStatusReport::resources`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerResourceUsage {
/// CPU usage as a float (the unit is not defined in this file).
pub cpu_used: f64,
/// Memory in use, in bytes.
pub memory_used_bytes: u64,
/// GPU usage count (presumably the number of GPUs in use — confirm).
pub gpu_used: u32,
}
/// Acknowledgement payload carrying the next TTL for the worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerStatusAck {
/// TTL in seconds to use next.
pub next_ttl_secs: u32,
}
/// Lease record for a worker node; expiry is evaluated by
/// `WorkerLease::is_expired`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerLease {
/// Node id the lease belongs to.
pub node_id: u64,
/// Profile of the worker holding the lease.
pub profile: WorkerProfile,
/// Acquisition time as Unix seconds.
pub acquired_unix_secs: i64,
/// Last renewal time as Unix seconds; the reference point for expiry.
pub renewed_unix_secs: i64,
/// Lease TTL in seconds.
pub ttl_secs: u32,
/// Grace period in seconds stored with the lease.
/// NOTE(review): `is_expired` takes the grace period as an argument and does
/// not read this field.
pub grace_secs: u32,
}
impl WorkerLease {
    /// Reports whether the lease has expired at `now_unix_secs`.
    ///
    /// The lease is expired once strictly more than `ttl_secs + grace_secs`
    /// seconds have passed since `renewed_unix_secs`. The grace period is the
    /// `grace_secs` argument, not the field of the same name. A `now` that is
    /// earlier than the renewal time counts as zero elapsed seconds.
    #[must_use]
    pub fn is_expired(&self, now_unix_secs: i64, grace_secs: u32) -> bool {
        // A negative difference fails the conversion and is treated as zero.
        let age_secs =
            u64::try_from(now_unix_secs.saturating_sub(self.renewed_unix_secs)).unwrap_or(0);
        // Two `u32` values widened to `u64` cannot overflow when added.
        let allowance_secs = u64::from(self.ttl_secs) + u64::from(grace_secs);
        age_secs > allowance_secs
    }
}
/// Heartbeat TTL settings in whole seconds, as produced by
/// `ClusterMode::adaptive_ttl_config` and used by `compute_ttl`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveTtlConfig {
/// Lower bound of the TTL returned by `compute_ttl`.
pub min_ttl_secs: u32,
/// Upper bound of the TTL returned by `compute_ttl`; also the result when
/// `max_heartbeats_per_second` is zero.
pub max_ttl_secs: u32,
/// Grace period in seconds.
pub grace_secs: u32,
/// Divisor in `compute_ttl`: the TTL is the worker count divided by this
/// value, rounded up.
pub max_heartbeats_per_second: u32,
/// Failover TTL in seconds (not read by `compute_ttl`).
pub failover_ttl_secs: u32,
}
impl Default for AdaptiveTtlConfig {
    /// Returns TTL bounds of 10 s and 600 s, a 10 s grace period, 50
    /// heartbeats per second and a 300 s failover TTL. These are the same
    /// numbers as the serde defaults of the `WorkerTier` heartbeat fields.
    fn default() -> Self {
        let (min_ttl_secs, max_ttl_secs) = (10, 600);
        let grace_secs = 10;
        let max_heartbeats_per_second = 50;
        let failover_ttl_secs = 300;
        Self {
            min_ttl_secs,
            max_ttl_secs,
            grace_secs,
            max_heartbeats_per_second,
            failover_ttl_secs,
        }
    }
}
impl AdaptiveTtlConfig {
    /// Computes the heartbeat TTL in seconds for a cluster of `n_workers`.
    ///
    /// The TTL is `n_workers / max_heartbeats_per_second` rounded up, limited
    /// to the range spanned by `min_ttl_secs` and `max_ttl_secs`.
    ///
    /// * If `max_heartbeats_per_second` is zero, `max_ttl_secs` is returned.
    /// * If the bounds are inverted (`min_ttl_secs > max_ttl_secs`), they are
    ///   swapped instead of panicking. Both come straight from configuration.
    #[must_use]
    pub fn compute_ttl(&self, n_workers: u32) -> u32 {
        let rate = self.max_heartbeats_per_second;
        if rate == 0 {
            return self.max_ttl_secs;
        }
        // Exact ceiling division that cannot overflow: adding `rate - 1`
        // first would saturate, and so round down, near `u32::MAX`.
        let raw = n_workers / rate + u32::from(n_workers % rate != 0);
        // `clamp` panics when its lower bound exceeds its upper bound, so put
        // the bounds in order first.
        let lower = self.min_ttl_secs.min(self.max_ttl_secs);
        let upper = self.min_ttl_secs.max(self.max_ttl_secs);
        raw.clamp(lower, upper)
    }
}
// Unit tests for `AdaptiveTtlConfig::compute_ttl` and `WorkerLease::is_expired`.
#[cfg(test)]
mod worker_tier_tests {
use super::*;
// With the default config (50 heartbeats/s, bounds 10..=600) the TTL is
// ceil(n / 50) clamped into those bounds.
#[test]
fn adaptive_ttl_scales_with_cluster() {
let cfg = AdaptiveTtlConfig::default();
// Small clusters are lifted to the 10 s lower bound.
assert_eq!(cfg.compute_ttl(10), 10);
assert_eq!(cfg.compute_ttl(100), 10);
assert_eq!(cfg.compute_ttl(500), 10);
// 1000 / 50 = 20 and 10000 / 50 = 200 fall inside the bounds.
assert_eq!(cfg.compute_ttl(1000), 20);
assert_eq!(cfg.compute_ttl(10000), 200);
// 100000 / 50 = 2000 is capped at the 600 s upper bound.
assert_eq!(cfg.compute_ttl(100_000), 600);
}
// A lease renewed at t=1000 with ttl 30 and grace 10 expires only once more
// than 40 seconds have elapsed.
#[test]
fn worker_lease_expiration() {
let lease = WorkerLease {
node_id: 1,
profile: WorkerProfile {
api_addr: "127.0.0.1:3669".parse().unwrap(),
os: "linux".to_string(),
arch: "x86_64".to_string(),
labels: HashMap::default(),
cpu_total: 4,
memory_total_bytes: 8_000_000_000,
},
acquired_unix_secs: 1000,
renewed_unix_secs: 1000,
ttl_secs: 30,
grace_secs: 10,
};
// 25 s elapsed: within the TTL.
assert!(!lease.is_expired(1025, 10));
// Exactly 40 s elapsed: the comparison is strict, so not yet expired.
assert!(!lease.is_expired(1040, 10));
// 41 s elapsed: past TTL plus grace.
assert!(lease.is_expired(1041, 10));
}
}