use std::process::Output;
use crate::{
CapacityType, CloudProvisioner, JoinState, NodeHandle, NodeShape, PriceHint, ProviderNodeId,
ProvisionerError, Result,
};
/// Capacity types advertised by [`CloudInitProvisioner`]: on-demand only.
const SUPPORTED_CAPACITY: &[CapacityType] = &[CapacityType::OnDemand];
/// Configuration for [`CloudInitProvisioner`]: the shell-command templates it runs and
/// the values rendered into each new node's cloud-init user data.
///
/// `Debug` is implemented by hand (not derived) so that `join_token` is redacted when the
/// config, or a provisioner holding it, is logged with `{:?}`.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct CloudInitConfig {
    /// Leader address substituted for `{leader_addr}` in the user-data template.
    pub leader_addr: String,
    /// Join token substituted for `{join_token}` in the user-data template. Treated as a
    /// secret: never printed by `Debug`.
    pub join_token: String,
    /// Shell command (run via `sh -c`) that creates a node. Supported placeholders:
    /// `{user_data}`, `{cpu}`, `{memory_mb}`, `{labels}`, `{zone}`, `{capacity}`.
    /// Its trimmed stdout is used as the new node's provider id.
    pub provision_cmd: String,
    /// Shell command (run via `sh -c`) that deletes a node; `{provider_id}` is replaced
    /// with the node's id.
    pub terminate_cmd: String,
    /// Optional shell command that lists nodes, one `provider_id[,address]` per line.
    /// When `None`, `describe` reports no nodes.
    pub describe_cmd: Option<String>,
    /// Cloud-init user-data template; see [`default_user_data_template`] for the default.
    pub user_data_template: String,
    /// Flat hourly price in USD reported by `price_hint` for every shape.
    pub hourly_usd: f64,
}

impl std::fmt::Debug for CloudInitConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CloudInitConfig")
            .field("leader_addr", &self.leader_addr)
            // Never print the join token: it lets arbitrary hosts join the cluster.
            .field("join_token", &"<redacted>")
            .field("provision_cmd", &self.provision_cmd)
            .field("terminate_cmd", &self.terminate_cmd)
            .field("describe_cmd", &self.describe_cmd)
            .field("user_data_template", &self.user_data_template)
            .field("hourly_usd", &self.hourly_usd)
            .finish()
    }
}
/// [`CloudProvisioner`] that provisions, terminates and lists nodes by running the shell
/// command templates held in its [`CloudInitConfig`], and renders cloud-init user data
/// from `config.user_data_template` for each new node.
#[derive(Clone, Debug)]
pub struct CloudInitProvisioner {
    // All behavior is driven by this configuration; there is no other state.
    config: CloudInitConfig,
}
/// Renders labels as comma-separated `key=value` pairs in key order (the map is a
/// `BTreeMap`), e.g. `a=2,z=1`. An empty map renders as the empty string.
///
/// Keys and values are not escaped, so a value containing `,` or `=` is emitted verbatim.
fn format_labels(labels: &std::collections::BTreeMap<String, String>) -> String {
    let mut rendered = String::new();
    for (key, value) in labels {
        // Every rendered pair contains at least '=', so a non-empty buffer means this is
        // not the first pair and a separator is needed.
        if !rendered.is_empty() {
            rendered.push(',');
        }
        rendered.push_str(key);
        rendered.push('=');
        rendered.push_str(value);
    }
    rendered
}
/// Maps a [`CapacityType`] to the token substituted for `{capacity}` in the provision
/// command: `"spot"` or `"on-demand"`.
fn capacity_token(capacity: CapacityType) -> &'static str {
    match capacity {
        CapacityType::Spot => "spot",
        CapacityType::OnDemand => "on-demand",
    }
}
/// Fills the `{leader_addr}`, `{join_token}` and `{labels}` placeholders of a user-data
/// template.
///
/// Substitution is sequential in that order, so a placeholder that appears inside an
/// earlier value (e.g. a `leader_addr` containing `{join_token}`) is expanded as well.
fn substitute_user_data(
    template: &str,
    leader_addr: &str,
    join_token: &str,
    labels: &str,
) -> String {
    let substitutions = [
        ("{leader_addr}", leader_addr),
        ("{join_token}", join_token),
        ("{labels}", labels),
    ];
    substitutions
        .iter()
        .fold(template.to_owned(), |text, &(placeholder, value)| {
            text.replace(placeholder, value)
        })
}
/// Renders the provision command template for `shape`.
///
/// Placeholders, applied in this order: `{user_data}`, `{cpu}`, `{memory_mb}` (memory in
/// whole MiB, rounded down), `{labels}` (comma-separated `key=value`), `{zone}` (empty when
/// unset) and `{capacity}` (`on-demand`/`spot`).
///
/// Because `{user_data}` is substituted first, any of the later placeholders that occur
/// inside the rendered user data are expanded too. Values are spliced in verbatim; no
/// shell quoting is applied.
fn substitute_provision_cmd(template: &str, user_data: &str, shape: &NodeShape) -> String {
    let cpu = shape.cpu.to_string();
    let memory_mb = (shape.memory_bytes / (1024 * 1024)).to_string();
    let labels = format_labels(&shape.labels);
    let zone = shape.zone.as_deref().unwrap_or("");
    let capacity = capacity_token(shape.capacity_type);

    let substitutions = [
        ("{user_data}", user_data),
        ("{cpu}", cpu.as_str()),
        ("{memory_mb}", memory_mb.as_str()),
        ("{labels}", labels.as_str()),
        ("{zone}", zone),
        ("{capacity}", capacity),
    ];
    substitutions
        .iter()
        .fold(template.to_owned(), |rendered, &(placeholder, value)| {
            rendered.replace(placeholder, value)
        })
}
/// Replaces every `{provider_id}` in `template` with `provider_id` (no shell quoting).
fn substitute_provider_id(template: &str, provider_id: &str) -> String {
    const PLACEHOLDER: &str = "{provider_id}";
    template.replace(PLACEHOLDER, provider_id)
}
/// Runs `cmd` with `sh -c`, waits for it to finish and returns its [`Output`].
///
/// A non-zero exit status is *not* an error here; callers inspect `Output::status`.
///
/// # Errors
/// [`ProvisionerError::Transport`] if the shell could not be spawned or waited on.
async fn run_shell(cmd: &str) -> Result<Output> {
    let mut shell = tokio::process::Command::new("sh");
    shell.args(["-c", cmd]);
    match shell.output().await {
        Ok(output) => Ok(output),
        Err(err) => Err(ProvisionerError::Transport(err.to_string())),
    }
}
/// Parses the describe command's stdout into node handles.
///
/// Each line is `provider_id[,address]`; fields are whitespace-trimmed and any further
/// comma-separated fields are ignored, so they can no longer corrupt the address. Blank
/// lines and lines with an empty provider id (e.g. `,10.0.0.1`) are skipped, because such
/// a handle could never be matched or terminated. An empty address becomes `None`.
///
/// Every returned node is reported as on-demand, already joined, and without a zone.
fn parse_describe_output(stdout: &str) -> Vec<NodeHandle> {
    stdout
        .lines()
        .filter_map(|line| {
            let mut fields = line.split(',').map(str::trim);
            let provider_id = fields.next().filter(|id| !id.is_empty())?;
            let address = fields
                .next()
                .filter(|addr| !addr.is_empty())
                .map(ToString::to_string);
            Some(NodeHandle {
                provider_id: provider_id.to_string(),
                address,
                zone: None,
                capacity_type: CapacityType::OnDemand,
                join_state: JoinState::Joined,
            })
        })
        .collect()
}
/// Built-in cloud-init user-data template: a `#cloud-config` document with a single
/// `runcmd` entry that runs `zlayer node join` against `{leader_addr}` using
/// `{join_token}`, in full mode, with `{labels}` and `--no-ingress`.
///
/// The advertise address is resolved on the instance at boot from the EC2-style
/// instance-metadata path `169.254.169.254/latest/meta-data/public-ipv4`.
///
/// NOTE(review): hosts that enforce IMDSv2 or have no public IPv4 will likely get an empty
/// or error body from that `curl -s`. Likewise, when `{labels}` renders empty, `--labels`
/// is immediately followed by `--no-ingress`. Confirm the join CLI tolerates both cases.
#[must_use]
pub fn default_user_data_template() -> String {
    let advertise_addr = "$(curl -s http://169.254.169.254/latest/meta-data/public-ipv4)";
    let advertise_flag = format!("--advertise-addr {advertise_addr}");
    let join_command = [
        "zlayer node join {leader_addr}",
        "--token {join_token}",
        advertise_flag.as_str(),
        "--mode full",
        "--labels {labels}",
        "--no-ingress",
    ]
    .join(" ");
    format!("#cloud-config\nruncmd:\n - {join_command}\n")
}
impl CloudInitProvisioner {
#[must_use]
pub fn new(config: CloudInitConfig) -> Self {
Self { config }
}
#[must_use]
pub fn render_user_data(&self, shape: &NodeShape) -> String {
substitute_user_data(
&self.config.user_data_template,
&self.config.leader_addr,
&self.config.join_token,
&format_labels(&shape.labels),
)
}
}
#[async_trait::async_trait]
impl CloudProvisioner for CloudInitProvisioner {
async fn provision(&self, shape: &NodeShape) -> Result<NodeHandle> {
let user_data = self.render_user_data(shape);
let cmd = substitute_provision_cmd(&self.config.provision_cmd, &user_data, shape);
tracing::info!(
provisioner = "cloud-init",
cpu = shape.cpu,
memory_bytes = shape.memory_bytes,
zone = shape.zone.as_deref().unwrap_or(""),
"provisioning node"
);
let output = run_shell(&cmd).await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
return Err(ProvisionerError::Capacity(stderr));
}
let provider_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
tracing::info!(provisioner = "cloud-init", %provider_id, "provisioned node");
Ok(NodeHandle {
provider_id,
address: None,
zone: shape.zone.clone(),
capacity_type: shape.capacity_type,
join_state: JoinState::Provisioning,
})
}
#[allow(clippy::ptr_arg)]
async fn terminate(&self, id: &ProviderNodeId) -> Result<()> {
let cmd = substitute_provider_id(&self.config.terminate_cmd, id);
tracing::info!(provisioner = "cloud-init", provider_id = %id, "terminating node");
let output = run_shell(&cmd).await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
return Err(ProvisionerError::Transport(stderr));
}
Ok(())
}
async fn describe(&self) -> Result<Vec<NodeHandle>> {
let Some(cmd) = self.config.describe_cmd.as_deref() else {
return Ok(Vec::new());
};
let output = run_shell(cmd).await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
return Err(ProvisionerError::Transport(stderr));
}
let stdout = String::from_utf8_lossy(&output.stdout);
Ok(parse_describe_output(&stdout))
}
fn capacity_types(&self) -> &[CapacityType] {
SUPPORTED_CAPACITY
}
fn price_hint(&self, shape: &NodeShape) -> Option<PriceHint> {
Some(PriceHint {
hourly_usd: self.config.hourly_usd,
capacity_type: shape.capacity_type,
})
}
fn name(&self) -> &'static str {
"cloud-init"
}
}
// Unit tests for the pure helpers (template substitution, label/capacity formatting,
// describe-output parsing) and the provisioner's metadata methods. No test here calls
// provision/terminate/describe, so no shell command is executed.
#[cfg(test)]
mod tests {
    use super::{
        capacity_token, default_user_data_template, format_labels, parse_describe_output,
        substitute_provider_id, substitute_provision_cmd, substitute_user_data, CloudInitConfig,
        CloudInitProvisioner,
    };
    use crate::{CapacityType, CloudProvisioner, JoinState, NodeShape};

    // Config with `echo` stub commands (never executed by these tests) and the built-in
    // user-data template.
    fn sample_config() -> CloudInitConfig {
        CloudInitConfig {
            leader_addr: "10.0.0.1:3669".to_string(),
            join_token: "tok-abc".to_string(),
            provision_cmd: "echo i-123".to_string(),
            terminate_cmd: "echo gone {provider_id}".to_string(),
            describe_cmd: None,
            user_data_template: default_user_data_template(),
            hourly_usd: 0.10,
        }
    }

    #[test]
    fn default_template_contains_join_line() {
        let tpl = default_user_data_template();
        assert!(tpl.contains("zlayer node join"));
        assert!(tpl.contains("#cloud-config"));
        assert!(tpl.contains("--mode full"));
        assert!(tpl.contains("--no-ingress"));
        assert!(tpl.contains("169.254.169.254"));
    }

    #[test]
    fn render_user_data_substitutes_placeholders() {
        let provisioner = CloudInitProvisioner::new(sample_config());
        let mut shape = NodeShape::new(2.0, 4 * 1024 * 1024 * 1024);
        shape
            .labels
            .insert("role".to_string(), "worker".to_string());
        let rendered = provisioner.render_user_data(&shape);
        assert!(rendered.contains("10.0.0.1:3669"));
        assert!(rendered.contains("tok-abc"));
        assert!(rendered.contains("role=worker"));
        // No user-data placeholder may survive rendering.
        assert!(!rendered.contains("{leader_addr}"));
        assert!(!rendered.contains("{join_token}"));
        assert!(!rendered.contains("{labels}"));
    }

    #[test]
    fn substitute_user_data_replaces_all() {
        let out = substitute_user_data(
            "join {leader_addr} tok {join_token} lbl {labels}",
            "host:1",
            "secret",
            "a=b",
        );
        assert_eq!(out, "join host:1 tok secret lbl a=b");
    }

    #[test]
    fn format_labels_is_sorted_and_joined() {
        // Inserted out of order: output must follow the BTreeMap's key order.
        let mut labels = std::collections::BTreeMap::new();
        labels.insert("z".to_string(), "1".to_string());
        labels.insert("a".to_string(), "2".to_string());
        assert_eq!(format_labels(&labels), "a=2,z=1");
        assert_eq!(format_labels(&std::collections::BTreeMap::new()), "");
    }

    #[test]
    fn capacity_token_maps_variants() {
        assert_eq!(capacity_token(CapacityType::OnDemand), "on-demand");
        assert_eq!(capacity_token(CapacityType::Spot), "spot");
    }

    #[test]
    fn substitute_provision_cmd_fills_shape_fields() {
        // 2048 MiB of memory must render as `--mem 2048`.
        let mut shape = NodeShape::new(2.0, 2048 * 1024 * 1024);
        shape.zone = Some("us-east-1a".to_string());
        shape.capacity_type = CapacityType::Spot;
        shape.labels.insert("k".to_string(), "v".to_string());
        let cmd = substitute_provision_cmd(
            "run --ud '{user_data}' --cpu {cpu} --mem {memory_mb} --labels {labels} --zone {zone} --cap {capacity}",
            "USERDATA",
            &shape,
        );
        assert!(cmd.contains("--ud 'USERDATA'"));
        assert!(cmd.contains("--cpu 2"));
        assert!(cmd.contains("--mem 2048"));
        assert!(cmd.contains("--labels k=v"));
        assert!(cmd.contains("--zone us-east-1a"));
        assert!(cmd.contains("--cap spot"));
        // Every placeholder was consumed.
        assert!(!cmd.contains('{'));
    }

    #[test]
    fn substitute_provision_cmd_empty_zone() {
        // An unset zone renders as the empty string.
        let shape = NodeShape::new(1.0, 1024 * 1024 * 1024);
        let cmd = substitute_provision_cmd("z=[{zone}]", "ud", &shape);
        assert_eq!(cmd, "z=[]");
    }

    #[test]
    fn substitute_provider_id_replaces() {
        assert_eq!(
            substitute_provider_id("delete {provider_id} now", "i-9"),
            "delete i-9 now"
        );
    }

    #[test]
    fn parse_describe_output_handles_id_and_address() {
        // Covers `id,addr`, a bare id padded with spaces, a blank line, and an address
        // padded with spaces.
        let out = parse_describe_output("i-1,10.0.0.1\n i-2 \n\ni-3, 10.0.0.3 \n");
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].provider_id, "i-1");
        assert_eq!(out[0].address.as_deref(), Some("10.0.0.1"));
        assert_eq!(out[0].join_state, JoinState::Joined);
        assert_eq!(out[1].provider_id, "i-2");
        assert!(out[1].address.is_none());
        assert_eq!(out[2].provider_id, "i-3");
        assert_eq!(out[2].address.as_deref(), Some("10.0.0.3"));
    }

    #[test]
    fn parse_describe_output_empty() {
        // Blank and whitespace-only lines yield no nodes.
        assert!(parse_describe_output("\n \n").is_empty());
    }

    #[test]
    fn metadata_methods() {
        let provisioner = CloudInitProvisioner::new(sample_config());
        assert_eq!(provisioner.name(), "cloud-init");
        assert_eq!(provisioner.capacity_types(), &[CapacityType::OnDemand]);
        let shape = NodeShape::new(1.0, 1024 * 1024 * 1024);
        let hint = provisioner.price_hint(&shape).expect("price hint");
        // Floats are compared with a tolerance rather than `==`.
        assert!((hint.hourly_usd - 0.10).abs() < f64::EPSILON);
        assert_eq!(hint.capacity_type, CapacityType::OnDemand);
    }
}