#![doc = r"
Provider-agnostic cloud node provisioning for `ZLayer` autoscaling.

This crate defines the [`CloudProvisioner`] trait and the value types used to
request and describe worker nodes that join a `ZLayer` cluster. The trait plus the
core value types depend only on `async-trait`, `serde`, and `thiserror`, so a
downstream consumer (for example a `ZataCloudDeploy` backend) can implement
the trait against its own cloud SDK without pulling in the reference
implementation's runtime dependencies.

The built-in [`CloudInitProvisioner`] (behind the default `cloud-init` feature)
is a provider-agnostic implementation that shells out to operator-supplied
commands and feeds each node a cloud-init `#cloud-config` that runs
`zlayer node join` on boot. It requires no cloud SDK.

# Identifiers
[`ProviderNodeId`] is the provider-scoped identifier for a node (for example an
`EC2` instance id or an opaque token printed by a provisioning script). It is
deliberately distinct from the raft layer's numeric node id.
"]
use std::collections::BTreeMap;
/// Provider-scoped identifier for a node (for example an `EC2` instance id or
/// an opaque token printed by a provisioning script).
///
/// This is a plain `String` alias. It is deliberately distinct from the raft
/// layer's numeric node id.
pub type ProviderNodeId = String;
/// Purchasing model for a node.
///
/// Defaults to [`CapacityType::OnDemand`]. There are no serde rename
/// attributes, so values serialize by variant name (`"OnDemand"` / `"Spot"`);
/// renaming a variant is therefore a wire-format change.
///
/// Derives `Hash` (alongside `Eq`) so the type can key hash maps and sets,
/// e.g. when grouping prices or nodes by capacity type.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum CapacityType {
    /// On-demand capacity (the default).
    #[default]
    OnDemand,
    /// Spot capacity; on most clouds this is discounted and may be reclaimed
    /// by the provider.
    Spot,
}
/// Requested shape of a worker node: compute, memory, accelerators,
/// placement, and purchasing model.
///
/// Build one with [`NodeShape::new`] or [`NodeShape::default`] and adjust the
/// public fields as needed.
///
/// Derives `PartialEq` so shapes can be compared directly. `Eq` is not
/// derivable because `cpu` is an `f64`.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct NodeShape {
    /// CPU count; fractional values are representable.
    pub cpu: f64,
    /// Memory size in bytes.
    pub memory_bytes: u64,
    /// Number of GPUs; `0` means none.
    pub gpu: u32,
    /// GPU vendor (for example `"nvidia"`), when one is specified.
    pub gpu_vendor: Option<String>,
    /// Free-form key/value labels attached to the request.
    pub labels: BTreeMap<String, String>,
    /// Placement zone (for example `"us-east-1a"`), when one is specified.
    pub zone: Option<String>,
    /// Purchasing model to request.
    pub capacity_type: CapacityType,
}
impl Default for NodeShape {
    /// Returns a single-CPU, 1 GiB, on-demand shape with no GPU, no labels,
    /// and no zone.
    fn default() -> Self {
        // 1 GiB written as a power of two: 2^30 == 1024 * 1024 * 1024 bytes.
        const ONE_GIB: u64 = 1 << 30;

        Self {
            capacity_type: CapacityType::OnDemand,
            zone: None,
            labels: BTreeMap::default(),
            gpu_vendor: None,
            gpu: 0,
            memory_bytes: ONE_GIB,
            cpu: 1.0,
        }
    }
}
impl NodeShape {
    /// Creates a shape with the given CPU count and memory size in bytes.
    ///
    /// Every other field (GPU, vendor, labels, zone, capacity type) takes the
    /// value from [`NodeShape::default`].
    #[must_use]
    pub fn new(cpu: f64, memory_bytes: u64) -> Self {
        let base = Self::default();
        Self {
            cpu,
            memory_bytes,
            ..base
        }
    }
}
/// Lifecycle state of a node, as carried in [`NodeHandle::join_state`].
///
/// The variants are listed in the order a successful node presumably moves
/// through them, with [`JoinState::Failed`] as the error outcome. The exact
/// transition points are decided by each [`CloudProvisioner`] implementation.
///
/// Derives `Hash` (alongside `Eq`) so states can key hash maps and sets,
/// e.g. when counting nodes per state.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum JoinState {
    /// The node has been requested from the provider.
    Provisioning,
    /// The node exists and is booting.
    Booting,
    /// The node is joining the cluster (for example via `zlayer node join`).
    Joining,
    /// The node has joined the cluster.
    Joined,
    /// Provisioning or joining failed.
    Failed,
}
/// A node created or reported by a [`CloudProvisioner`].
///
/// Derives `PartialEq`/`Eq`; every field is comparable, so handles can be
/// compared directly (e.g. when diffing two `describe` snapshots).
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct NodeHandle {
    /// Provider-scoped identifier of the node (the same type that
    /// [`CloudProvisioner::terminate`] takes).
    pub provider_id: ProviderNodeId,
    /// Network address of the node (for example `"10.0.0.5"`), if known.
    pub address: Option<String>,
    /// Zone the node is in, if known.
    pub zone: Option<String>,
    /// Purchasing model of the node.
    pub capacity_type: CapacityType,
    /// Current lifecycle state of the node.
    pub join_state: JoinState,
}
/// Price estimate for running a node of some shape, as returned by
/// [`CloudProvisioner::price_hint`].
///
/// Derives `PartialEq`. `Eq` is not derivable because `hourly_usd` is an
/// `f64`.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct PriceHint {
    /// Estimated cost per hour, in US dollars.
    pub hourly_usd: f64,
    /// Purchasing model the estimate applies to.
    pub capacity_type: CapacityType,
}
/// Errors returned by [`CloudProvisioner`] operations.
///
/// Every variant carries a human-readable message. The `Display` output is
/// the variant's prefix followed by that message; [`ProvisionerError::Other`]
/// prints the message alone.
#[derive(thiserror::Error, Debug)]
pub enum ProvisionerError {
/// Something the caller asked for is not supported by this provisioner.
#[error("unsupported: {0}")]
Unsupported(String),
/// The requested capacity is unavailable.
#[error("capacity unavailable: {0}")]
Capacity(String),
/// An authentication/authorization (`auth`) failure.
#[error("auth: {0}")]
Auth(String),
/// A transport-level failure, presumably while talking to the provider.
#[error("transport: {0}")]
Transport(String),
/// Any other failure; displayed as the bare message.
#[error("{0}")]
Other(String),
}
/// Result type used throughout this crate, with [`ProvisionerError`] as the
/// error.
pub type Result<T> = std::result::Result<T, ProvisionerError>;
/// A backend that provisions, terminates, and lists worker nodes for a
/// `ZLayer` cluster.
///
/// Implementors must be `Send + Sync`. The async methods are declared through
/// `#[async_trait::async_trait]`, so implementations must apply the same
/// attribute to their `impl` block.
#[async_trait::async_trait]
pub trait CloudProvisioner: Send + Sync {
/// Requests a new node matching `shape` and returns a handle describing it.
///
/// # Errors
///
/// Returns a [`ProvisionerError`] when the node cannot be provisioned.
async fn provision(&self, shape: &NodeShape) -> Result<NodeHandle>;
/// Terminates the node identified by `id`.
///
/// # Errors
///
/// Returns a [`ProvisionerError`] when termination fails.
// `&ProviderNodeId` is `&String`, which `clippy::ptr_arg` would rewrite to
// `&str`. The alias is kept so the signature names the kind of id expected,
// and changing a trait method signature would break every implementor.
#[allow(clippy::ptr_arg)]
async fn terminate(&self, id: &ProviderNodeId) -> Result<()>;
/// Returns handles for the nodes this provisioner reports (presumably the
/// nodes it currently manages; confirm against implementations).
///
/// # Errors
///
/// Returns a [`ProvisionerError`] when the listing cannot be obtained.
async fn describe(&self) -> Result<Vec<NodeHandle>>;
/// Capacity types this provisioner can supply.
fn capacity_types(&self) -> &[CapacityType];
/// Optional price estimate for a node of `shape`; `None` when the
/// provisioner offers no estimate.
fn price_hint(&self, shape: &NodeShape) -> Option<PriceHint>;
/// Name identifying this provisioner implementation.
fn name(&self) -> &str;
}
// Built-in, SDK-free reference provisioner that feeds each node a cloud-init
// `#cloud-config` running `zlayer node join`; compiled only with the
// `cloud-init` feature.
#[cfg(feature = "cloud-init")]
pub mod cloud_init;
// Re-export the reference implementation's public types at the crate root.
#[cfg(feature = "cloud-init")]
pub use cloud_init::{CloudInitConfig, CloudInitProvisioner};
#[cfg(test)]
mod tests {
    use super::{
        CapacityType, CloudProvisioner, JoinState, NodeHandle, NodeShape, PriceHint,
        ProviderNodeId, ProvisionerError,
    };
    use serde::{Deserialize, Serialize};

    /// Compile-time check that `T` can be both serialized and deserialized.
    fn assert_serde<T: Serialize + for<'de> Deserialize<'de>>() {}

    /// Compile-time check that `T` can be shared and sent across threads.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn value_types_implement_serde() {
        assert_serde::<NodeShape>();
        assert_serde::<NodeHandle>();
        assert_serde::<PriceHint>();
        assert_serde::<CapacityType>();
        assert_serde::<JoinState>();
    }

    #[test]
    fn cloud_provisioner_is_object_safe_and_thread_safe() {
        // Fails to compile if a trait change breaks object safety or drops
        // the `Send + Sync` supertraits, either of which would stop callers
        // from holding provisioners as shared trait objects.
        let none: Option<Box<dyn CloudProvisioner>> = None;
        assert!(none.is_none());
        assert_send_sync::<Box<dyn CloudProvisioner>>();
    }

    #[test]
    fn node_shape_default_is_one_cpu_one_gib() {
        let shape = NodeShape::default();
        assert!((shape.cpu - 1.0).abs() < f64::EPSILON);
        assert_eq!(shape.memory_bytes, 1024 * 1024 * 1024);
        assert_eq!(shape.gpu, 0);
        assert!(shape.gpu_vendor.is_none());
        assert!(shape.labels.is_empty());
        assert!(shape.zone.is_none());
        assert_eq!(shape.capacity_type, CapacityType::OnDemand);
    }

    #[test]
    fn node_shape_new_sets_cpu_and_memory() {
        let shape = NodeShape::new(4.0, 8 * 1024 * 1024 * 1024);
        assert!((shape.cpu - 4.0).abs() < f64::EPSILON);
        assert_eq!(shape.memory_bytes, 8 * 1024 * 1024 * 1024);
        assert_eq!(shape.capacity_type, CapacityType::OnDemand);
    }

    #[test]
    fn capacity_type_default_is_on_demand() {
        assert_eq!(CapacityType::default(), CapacityType::OnDemand);
    }

    #[test]
    fn node_shape_clone_preserves_fields() {
        let mut shape = NodeShape::new(2.0, 4 * 1024 * 1024 * 1024);
        shape.gpu = 1;
        shape.gpu_vendor = Some("nvidia".to_string());
        shape.zone = Some("us-east-1a".to_string());
        shape.capacity_type = CapacityType::Spot;
        shape
            .labels
            .insert("role".to_string(), "worker".to_string());
        let back = shape.clone();
        assert!((back.cpu - shape.cpu).abs() < f64::EPSILON);
        assert_eq!(back.memory_bytes, shape.memory_bytes);
        assert_eq!(back.gpu, shape.gpu);
        assert_eq!(back.gpu_vendor, shape.gpu_vendor);
        assert_eq!(back.zone, shape.zone);
        assert_eq!(back.capacity_type, shape.capacity_type);
        assert_eq!(back.labels, shape.labels);
    }

    #[test]
    fn node_handle_carries_state() {
        let handle = NodeHandle {
            provider_id: "i-0123".to_string(),
            address: Some("10.0.0.5".to_string()),
            zone: Some("us-east-1a".to_string()),
            capacity_type: CapacityType::Spot,
            join_state: JoinState::Joining,
        };
        assert_eq!(handle.provider_id, "i-0123");
        assert_eq!(handle.join_state, JoinState::Joining);
        assert_eq!(handle.capacity_type, CapacityType::Spot);
    }

    #[test]
    fn price_hint_carries_fields() {
        let hint = PriceHint {
            hourly_usd: 0.42,
            capacity_type: CapacityType::OnDemand,
        };
        assert!((hint.hourly_usd - 0.42).abs() < f64::EPSILON);
        assert_eq!(hint.capacity_type, CapacityType::OnDemand);
    }

    #[test]
    fn provider_node_id_is_string() {
        let id: ProviderNodeId = "node-7".to_string();
        assert_eq!(id, "node-7");
    }

    #[test]
    fn provisioner_error_display() {
        assert_eq!(
            ProvisionerError::Capacity("none left".to_string()).to_string(),
            "capacity unavailable: none left"
        );
        assert_eq!(
            ProvisionerError::Unsupported("gpu".to_string()).to_string(),
            "unsupported: gpu"
        );
        assert_eq!(
            ProvisionerError::Auth("denied".to_string()).to_string(),
            "auth: denied"
        );
        assert_eq!(
            ProvisionerError::Transport("timeout".to_string()).to_string(),
            "transport: timeout"
        );
        assert_eq!(
            ProvisionerError::Other("boom".to_string()).to_string(),
            "boom"
        );
    }

    #[test]
    fn provisioner_error_boxes_as_thread_safe_std_error() {
        assert_send_sync::<ProvisionerError>();
        let boxed: Box<dyn std::error::Error + Send + Sync> =
            ProvisionerError::Auth("expired token".to_string()).into();
        assert_eq!(boxed.to_string(), "auth: expired token");
    }
}