use anyhow::{Context, Result};
use std::path::PathBuf;
use std::time::Duration;
use tokio::process::Command;
use tokio::time::interval;
use tracing::{error, info, warn};
#[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind};
/// Runtime configuration of the fleet agent, populated by `Config::from_env`.
#[derive(Debug)]
struct Config {
/// S3 bucket that holds the queue (`CELLOS_FLEET_BUCKET`, required).
bucket: String,
/// Base key prefix inside the bucket (`CELLOS_FLEET_PREFIX`, defaults to `"fleet"`).
prefix: String,
/// Optional queue name (`CELLOS_FLEET_QUEUE_NAME`); empty means no queue segment in keys.
queue_name: String,
/// Optional pool this runner belongs to (`CELLOS_FLEET_POOL_ID`); empty means no placement gate.
pool_id: String,
/// Program launched for each claimed spec (`CELLOS_FLEET_SUPERVISOR`, defaults to `cellos-supervisor`).
supervisor: PathBuf,
/// Period between listings of the pending prefix (`CELLOS_FLEET_POLL_INTERVAL_MS`).
poll_interval: Duration,
/// Period between heartbeat log events (`CELLOS_FLEET_HEARTBEAT_INTERVAL_MS`).
heartbeat_interval: Duration,
/// Identifier attached to log events (`CELLOS_FLEET_NODE_ID`, else the host name, else `unknown-node`).
node_id: String,
}
impl Config {
    /// Builds the agent configuration from `CELLOS_FLEET_*` environment variables.
    ///
    /// Only `CELLOS_FLEET_BUCKET` is mandatory. Fallbacks for the others:
    /// - `CELLOS_FLEET_PREFIX`: `"fleet"`
    /// - `CELLOS_FLEET_QUEUE_NAME`, `CELLOS_FLEET_POOL_ID`: empty (values are trimmed)
    /// - `CELLOS_FLEET_SUPERVISOR`: `"cellos-supervisor"`
    /// - `CELLOS_FLEET_POLL_INTERVAL_MS`: `5000`
    /// - `CELLOS_FLEET_HEARTBEAT_INTERVAL_MS`: `30000`
    /// - `CELLOS_FLEET_NODE_ID`: the host name, then `"unknown-node"`
    ///
    /// # Errors
    ///
    /// Fails when `CELLOS_FLEET_BUCKET` is unset or is not valid Unicode.
    fn from_env() -> Result<Self> {
        let bucket =
            std::env::var("CELLOS_FLEET_BUCKET").context("CELLOS_FLEET_BUCKET is required")?;
        let prefix = std::env::var("CELLOS_FLEET_PREFIX").unwrap_or_else(|_| "fleet".to_string());
        let queue_name = std::env::var("CELLOS_FLEET_QUEUE_NAME")
            .map(|value| value.trim().to_string())
            .unwrap_or_default();
        let pool_id = std::env::var("CELLOS_FLEET_POOL_ID")
            .map(|value| value.trim().to_string())
            .unwrap_or_default();
        let supervisor = PathBuf::from(
            std::env::var("CELLOS_FLEET_SUPERVISOR")
                .unwrap_or_else(|_| "cellos-supervisor".to_string()),
        );
        let poll_ms = Self::interval_ms_from_env("CELLOS_FLEET_POLL_INTERVAL_MS", 5000);
        let heartbeat_ms = Self::interval_ms_from_env("CELLOS_FLEET_HEARTBEAT_INTERVAL_MS", 30_000);
        let node_id = std::env::var("CELLOS_FLEET_NODE_ID").unwrap_or_else(|_| {
            hostname::get()
                .ok()
                .and_then(|h| h.into_string().ok())
                .unwrap_or_else(|| "unknown-node".to_string())
        });
        Ok(Config {
            bucket,
            prefix,
            queue_name,
            pool_id,
            supervisor,
            poll_interval: Duration::from_millis(poll_ms),
            heartbeat_interval: Duration::from_millis(heartbeat_ms),
            node_id,
        })
    }

    /// Reads a millisecond interval from the environment variable `var`.
    ///
    /// Unset, unparsable and zero values all yield `default_ms`. Zero is
    /// rejected because `tokio::time::interval` panics on a zero period.
    fn interval_ms_from_env(var: &str, default_ms: u64) -> u64 {
        std::env::var(var)
            .ok()
            .and_then(|raw| raw.parse::<u64>().ok())
            .filter(|ms| *ms > 0)
            .unwrap_or(default_ms)
    }

    /// Returns the key prefix shared by every queue state, without a trailing `/`.
    ///
    /// With an empty `queue_name` this is `prefix` minus trailing slashes;
    /// otherwise it is `<prefix>/<queue_name>`, again with trailing slashes removed.
    fn queue_prefix(&self) -> String {
        let base_prefix = self.prefix.trim_end_matches('/');
        if self.queue_name.is_empty() {
            base_prefix.to_string()
        } else {
            format!("{}/{}/", base_prefix, self.queue_name)
                .trim_end_matches('/')
                .to_string()
        }
    }

    /// Key prefix (with trailing `/`) under which pending specs are listed.
    fn pending_prefix(&self) -> String {
        format!("{}/pending/", self.queue_prefix())
    }

    /// Object key for `spec_id` once it has been claimed.
    fn claimed_key(&self, spec_id: &str) -> String {
        format!("{}/claimed/{}.json", self.queue_prefix(), spec_id)
    }

    /// Object key for `spec_id` after its cell exited with code zero.
    fn completed_key(&self, spec_id: &str) -> String {
        format!("{}/completed/{}.json", self.queue_prefix(), spec_id)
    }

    /// Object key for `spec_id` after its cell exited with a non-zero code.
    fn failed_key(&self, spec_id: &str) -> String {
        format!("{}/failed/{}.json", self.queue_prefix(), spec_id)
    }

    /// Decides whether this runner may take a spec with the given placement pool.
    ///
    /// A runner without a `pool_id` accepts everything. A pool-bound runner
    /// accepts specs with no pool constraint and specs whose pool equals its own.
    fn should_dispatch(&self, spec_pool_id: Option<&str>) -> bool {
        if self.pool_id.is_empty() {
            return true;
        }
        match spec_pool_id {
            None => true,
            Some(spec_pool) => spec_pool == self.pool_id,
        }
    }
}
/// Partial view of a spec document: only the top-level `spec` object is
/// deserialized, and a missing `spec` falls back to `FleetSpecBody::default()`.
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct FleetSpecView {
#[serde(default)]
spec: FleetSpecBody,
}
/// The part of the `spec` object this agent reads: just the optional `placement`.
#[derive(Debug, Default, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct FleetSpecBody {
// `None` when the document carries no `placement` object.
#[serde(default)]
placement: Option<FleetPlacementView>,
}
/// Placement constraints of a spec; only the pool identifier is read.
#[derive(Debug, Default, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct FleetPlacementView {
// Read from the JSON key `poolId` (camelCase rename); `None` when absent.
#[serde(default)]
pool_id: Option<String>,
}
/// Extracts `spec.placement.poolId` from the JSON spec file at `spec_path`.
///
/// Returns `Ok(None)` when the document has no placement or no pool id.
///
/// # Errors
///
/// Fails when the file cannot be read or is not JSON matching `FleetSpecView`.
fn read_spec_pool_id(spec_path: &std::path::Path) -> Result<Option<String>> {
    let raw = std::fs::read(spec_path)
        .with_context(|| format!("read spec file {}", spec_path.display()))?;
    let parsed: FleetSpecView = serde_json::from_slice(&raw)
        .with_context(|| format!("parse spec file {}", spec_path.display()))?;
    let pool_id = match parsed.spec.placement {
        Some(placement) => placement.pool_id,
        None => None,
    };
    Ok(pool_id)
}
/// Lists the keys of pending `.json` specs by shelling out to
/// `aws s3api list-objects-v2` with a JMESPath `--query` filter.
///
/// A non-zero exit of the CLI is logged as a warning and reported as an empty
/// list; so is empty or `null` output.
///
/// # Errors
///
/// Fails when the `aws` process cannot be run or its output is not a JSON
/// array of strings.
async fn list_pending(cfg: &Config) -> Result<Vec<String>> {
    let pending_prefix = cfg.pending_prefix();
    let mut aws = Command::new("aws");
    aws.args(["s3api", "list-objects-v2"])
        .args(["--bucket", cfg.bucket.as_str()])
        .args(["--prefix", pending_prefix.as_str()])
        .args(["--query", "Contents[?ends_with(Key, '.json')].Key"])
        .args(["--output", "json"]);
    let listing = aws
        .output()
        .await
        .context("aws s3api list-objects-v2 failed")?;

    if !listing.status.success() {
        let stderr = String::from_utf8_lossy(&listing.stderr);
        warn!("list_pending stderr: {stderr}");
        return Ok(Vec::new());
    }

    let decoded = String::from_utf8_lossy(&listing.stdout);
    let body = decoded.trim();
    // The query prints `null` when nothing matches.
    if body.is_empty() || body == "null" {
        return Ok(Vec::new());
    }
    serde_json::from_str(body).context("failed to parse list-objects-v2 output")
}
/// Attempts to claim `pending_key` by moving it to the claimed key with `aws s3 mv`.
///
/// Returns `Ok(true)` when the CLI exits successfully and `Ok(false)` otherwise.
///
/// The spec id is `pending_key` with the pending prefix and `.json` removed via
/// `trim_start_matches` / `trim_end_matches`, which strip *repeated*
/// occurrences (`a.json.json` becomes `a`).
/// NOTE(review): whether `aws s3 mv` gives an exclusive claim between
/// competing nodes is not visible here — confirm.
///
/// # Errors
///
/// Fails when the `aws` process cannot be run.
async fn try_claim(cfg: &Config, pending_key: &str) -> Result<bool> {
    let pending_prefix = cfg.pending_prefix();
    let spec_id = pending_key
        .trim_start_matches(pending_prefix.as_str())
        .trim_end_matches(".json");

    let bucket = &cfg.bucket;
    let from_uri = format!("s3://{}/{}", bucket, pending_key);
    let to_uri = format!("s3://{}/{}", bucket, cfg.claimed_key(spec_id));

    let mut mv = Command::new("aws");
    mv.arg("s3").arg("mv").arg(&from_uri).arg(&to_uri);
    let moved = mv
        .status()
        .await
        .context("aws s3 mv (claim) failed")?
        .success();
    Ok(moved)
}
/// Downloads the claimed spec `spec_id` into a fresh temporary `.json` file
/// using `aws s3 cp` and returns the handle that owns that file.
///
/// # Errors
///
/// Fails when the temp file cannot be created, the `aws` process cannot be
/// run, or the copy exits non-zero.
async fn download_spec(cfg: &Config, spec_id: &str) -> Result<tempfile::NamedTempFile> {
    let tmp = tempfile::Builder::new()
        .prefix("cellos-fleet-spec-")
        .suffix(".json")
        .tempfile()
        .context("failed to create temp file for spec")?;
    let s3_key = format!("s3://{}/{}", cfg.bucket, cfg.claimed_key(spec_id));
    let status = Command::new("aws")
        .args(["s3", "cp", s3_key.as_str()])
        // Pass the path as an OS string: converting to `&str` and unwrapping
        // would panic when the temp directory is not valid UTF-8.
        .arg(tmp.path())
        .status()
        .await
        .context("aws s3 cp (download spec) failed")?;
    anyhow::ensure!(status.success(), "aws s3 cp exited non-zero");
    Ok(tmp)
}
/// Copies the still-pending object `pending_key` into a temporary `.json` file
/// (without claiming it) so its placement can be inspected.
///
/// # Errors
///
/// Fails when the temp file cannot be created, the `aws` process cannot be
/// run, or the copy exits non-zero.
async fn peek_pending_spec(cfg: &Config, pending_key: &str) -> Result<tempfile::NamedTempFile> {
    let tmp = tempfile::Builder::new()
        .prefix("cellos-fleet-peek-")
        .suffix(".json")
        .tempfile()
        .context("failed to create temp file for peek")?;
    let s3_key = format!("s3://{}/{}", cfg.bucket, pending_key);
    let status = Command::new("aws")
        .args(["s3", "cp", s3_key.as_str()])
        // Pass the path as an OS string: converting to `&str` and unwrapping
        // would panic when the temp directory is not valid UTF-8.
        .arg(tmp.path())
        .status()
        .await
        .context("aws s3 cp (peek pending spec) failed")?;
    anyhow::ensure!(status.success(), "aws s3 cp (peek) exited non-zero");
    Ok(tmp)
}
/// Runs the configured supervisor with `spec_path` as its only argument and
/// waits for it to exit.
///
/// Returns the process exit code, or `1` when no exit code is available.
///
/// # Errors
///
/// Fails when the supervisor process cannot be started.
async fn run_cell(cfg: &Config, spec_path: &std::path::Path) -> Result<i32> {
    let mut supervisor = Command::new(&cfg.supervisor);
    supervisor.arg(spec_path);
    let exit = supervisor
        .status()
        .await
        .context("cellos-supervisor failed to launch")?;
    match exit.code() {
        Some(code) => Ok(code),
        None => Ok(1),
    }
}
async fn finalize(cfg: &Config, spec_id: &str, exit_code: i32) -> Result<()> {
let src = format!("s3://{}/{}", cfg.bucket, cfg.claimed_key(spec_id));
let dst = if exit_code == 0 {
format!("s3://{}/{}", cfg.bucket, cfg.completed_key(spec_id))
} else {
format!("s3://{}/{}", cfg.bucket, cfg.failed_key(spec_id))
};
Command::new("aws")
.args(["s3", "mv", &src, &dst])
.status()
.await
.context("aws s3 mv (finalize) failed")?;
Ok(())
}
/// Handles one pending key end to end: placement gate, claim, download, run,
/// finalize.
///
/// Steps, in order:
/// 1. When this runner has a `pool_id`, the pending spec is peeked and skipped
///    if its `placement.poolId` names a different pool. A spec whose placement
///    cannot be read is treated as unconstrained.
/// 2. The spec is claimed; a failed claim is treated as "another node got it".
/// 3. The claimed spec is downloaded and handed to the supervisor.
/// 4. The spec is finalized according to the supervisor's exit code.
///
/// The spec id is derived exactly as in `try_claim` (repeated prefix/suffix
/// matches are stripped).
///
/// # Errors
///
/// Propagates failures from peeking, claiming, downloading, running and
/// finalizing. Nothing here moves a claimed spec back when a later step fails.
async fn process_spec(cfg: &Config, pending_key: &str) -> Result<()> {
    let pending_prefix = cfg.pending_prefix();
    let spec_id = pending_key
        .trim_start_matches(pending_prefix.as_str())
        .trim_end_matches(".json");

    let pool_bound = !cfg.pool_id.is_empty();
    if pool_bound {
        let peeked = peek_pending_spec(cfg, pending_key).await?;
        let spec_pool = match read_spec_pool_id(peeked.path()) {
            Ok(pool) => pool,
            Err(e) => {
                warn!(spec_id, error = %e, "failed to read placement.poolId — treating as no constraint");
                None
            }
        };
        let wanted_pool = spec_pool.as_deref();
        if !cfg.should_dispatch(wanted_pool) {
            info!(
                node = %cfg.node_id,
                spec_id,
                "skipping spec {}: placement.poolId={} != runner poolId={}",
                spec_id,
                wanted_pool.unwrap_or("<none>"),
                cfg.pool_id,
            );
            return Ok(());
        }
    }

    info!(node = %cfg.node_id, spec_id, "claiming spec");
    let claimed = try_claim(cfg, pending_key).await?;
    if !claimed {
        info!(spec_id, "spec already claimed by another node, skipping");
        return Ok(());
    }

    info!(node = %cfg.node_id, spec_id, "claimed — downloading");
    // The temp file must outlive the supervisor run, so keep the handle bound.
    let spec_file = download_spec(cfg, spec_id).await?;

    info!(node = %cfg.node_id, spec_id, path = %spec_file.path().display(), "running cell");
    let exit_code = run_cell(cfg, spec_file.path()).await?;

    info!(node = %cfg.node_id, spec_id, exit_code, "cell completed — finalizing");
    finalize(cfg, spec_id, exit_code).await?;

    match exit_code {
        0 => info!(node = %cfg.node_id, spec_id, "spec completed successfully"),
        _ => warn!(node = %cfg.node_id, spec_id, exit_code, "spec completed with failure"),
    }
    Ok(())
}
/// Resolves once the process receives SIGTERM (Unix builds).
///
/// # Errors
///
/// Fails when the SIGTERM handler cannot be installed.
#[cfg(unix)]
async fn wait_for_shutdown_signal() -> Result<()> {
    let installed = signal(SignalKind::terminate());
    let mut term_stream = installed.context("failed to install SIGTERM handler")?;
    // The received value carries no information; only the wake-up matters.
    let _ = term_stream.recv().await;
    Ok(())
}
/// Resolves once Ctrl+C is received (non-Unix builds).
///
/// # Errors
///
/// Fails when waiting on the Ctrl+C notification reports an I/O error.
#[cfg(not(unix))]
async fn wait_for_shutdown_signal() -> Result<()> {
    let interrupted = tokio::signal::ctrl_c().await;
    interrupted.context("failed to install Ctrl+C handler")
}
/// Main agent loop: polls for pending specs, emits heartbeats, and exits when
/// the shutdown signal future resolves.
///
/// Each poll lists the pending keys and processes them one after another.
/// Per-spec and listing errors are logged and do not stop the loop. The
/// shutdown future is only observed between `select!` iterations, so a batch
/// that is already being processed finishes before the loop exits.
///
/// # Errors
///
/// As written this always returns `Ok(())` once the loop ends.
async fn run(cfg: Config) -> Result<()> {
    info!(
        node = %cfg.node_id,
        bucket = %cfg.bucket,
        prefix = %cfg.prefix,
        queue_name = %cfg.queue_name,
        pool_id = %cfg.pool_id,
        supervisor = %cfg.supervisor.display(),
        poll_interval_ms = cfg.poll_interval.as_millis(),
        heartbeat_interval_ms = cfg.heartbeat_interval.as_millis(),
        "cellos-fleet agent starting"
    );

    let mut poll_tick = interval(cfg.poll_interval);
    let mut heartbeat_tick = interval(cfg.heartbeat_interval);
    // Processing a batch can take far longer than either period. With the
    // default `Burst` behaviour every missed tick would then fire
    // back-to-back, causing a flurry of listings and heartbeats. Resume
    // normal pacing instead.
    poll_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    heartbeat_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

    let shutdown = wait_for_shutdown_signal();
    tokio::pin!(shutdown);

    // An interval's first tick completes immediately; consume it so the first
    // poll and heartbeat happen one full period after start-up.
    poll_tick.tick().await;
    heartbeat_tick.tick().await;

    loop {
        tokio::select! {
            _ = poll_tick.tick() => {
                match list_pending(&cfg).await {
                    Err(e) => error!("list_pending error: {e:#}"),
                    Ok(keys) => {
                        for key in &keys {
                            if let Err(e) = process_spec(&cfg, key).await {
                                error!(key, "process_spec error: {e:#}");
                            }
                        }
                    }
                }
            }
            _ = heartbeat_tick.tick() => {
                info!(
                    event_type = "dev.cellos.events.fleet.v1.heartbeat",
                    node = %cfg.node_id,
                    bucket = %cfg.bucket,
                    prefix = %cfg.prefix,
                    queue_name = %cfg.queue_name,
                    pool_id = %cfg.pool_id,
                    "heartbeat"
                );
            }
            _ = &mut shutdown => {
                info!(
                    node = %cfg.node_id,
                    "SIGTERM received — draining (no new work accepted)"
                );
                break;
            }
        }
    }

    info!(node = %cfg.node_id, "cellos-fleet agent stopped (drain complete)");
    Ok(())
}
/// Commit identifier baked in at compile time: `VERGEN_GIT_SHA` if it was set
/// during the build, otherwise `GIT_COMMIT`, otherwise the literal `"unknown"`.
const BUILD_SHA: &str = if let Some(vergen_sha) = option_env!("VERGEN_GIT_SHA") {
    vergen_sha
} else if let Some(git_commit) = option_env!("GIT_COMMIT") {
    git_commit
} else {
    "unknown"
};
/// Entry point.
///
/// With `--version` or `-V` as the first argument, prints the package version
/// and a short build SHA, then exits. Otherwise installs a JSON tracing
/// subscriber (filtered by the default env filter, falling back to `info`),
/// loads the configuration from the environment and runs the agent loop.
#[tokio::main]
async fn main() -> Result<()> {
    // Only the first argument matters. `args_os` avoids the panic that
    // `std::env::args` raises when any argument is not valid Unicode.
    let wants_version = std::env::args_os()
        .nth(1)
        .map_or(false, |arg| arg == "--version" || arg == "-V");
    if wants_version {
        // `get` returns `None` for a too-short string or a cut that is not on
        // a char boundary, so an unusual build-time value cannot panic.
        let sha_short = BUILD_SHA.get(..7).unwrap_or(BUILD_SHA);
        println!("cellos-fleet {} ({})", env!("CARGO_PKG_VERSION"), sha_short);
        return Ok(());
    }
    {
        use tracing_subscriber::layer::SubscriberExt;
        use tracing_subscriber::util::SubscriberInitExt;
        use tracing_subscriber::Layer;
        let fmt_layer = tracing_subscriber::fmt::layer()
            .json()
            .with_filter(cellos_core::observability::redacted_filter());
        tracing_subscriber::registry()
            .with(
                tracing_subscriber::EnvFilter::try_from_default_env()
                    .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
            )
            .with(fmt_layer)
            .init();
    }
    let cfg = Config::from_env().context("failed to read configuration")?;
    run(cfg).await
}
// Unit tests for key layout, the pool placement gate, and spec parsing.
#[cfg(test)]
mod tests {
use super::Config;
use std::path::PathBuf;
use std::time::Duration;
// Builds a config with no pool binding.
fn config(prefix: &str, queue_name: &str) -> Config {
config_with_pool(prefix, queue_name, "")
}
// Builds a config with fixed bucket/supervisor/intervals/node id and the
// given prefix, queue name and pool id.
fn config_with_pool(prefix: &str, queue_name: &str, pool_id: &str) -> Config {
Config {
bucket: "bucket".into(),
prefix: prefix.into(),
queue_name: queue_name.into(),
pool_id: pool_id.into(),
supervisor: PathBuf::from("cellos-supervisor"),
poll_interval: Duration::from_secs(5),
heartbeat_interval: Duration::from_secs(30),
node_id: "node-a".into(),
}
}
// Smoke test: the version string can be formatted from the build constants.
#[test]
fn version_compiles() {
let _ = format!(
"cellos-fleet {} ({})",
env!("CARGO_PKG_VERSION"),
super::BUILD_SHA
);
}
// With an empty queue name, keys sit directly under the prefix.
#[test]
fn uses_legacy_layout_when_queue_name_is_empty() {
let cfg = config("fleet", "");
assert_eq!(cfg.pending_prefix(), "fleet/pending/");
assert_eq!(cfg.claimed_key("spec-1"), "fleet/claimed/spec-1.json");
assert_eq!(cfg.completed_key("spec-1"), "fleet/completed/spec-1.json");
assert_eq!(cfg.failed_key("spec-1"), "fleet/failed/spec-1.json");
}
// With a queue name, it becomes an extra path segment after the prefix.
#[test]
fn uses_queue_qualified_layout_when_queue_name_is_set() {
let cfg = config("fleet", "gpu-runners");
assert_eq!(cfg.pending_prefix(), "fleet/gpu-runners/pending/");
assert_eq!(
cfg.claimed_key("spec-1"),
"fleet/gpu-runners/claimed/spec-1.json"
);
assert_eq!(
cfg.completed_key("spec-1"),
"fleet/gpu-runners/completed/spec-1.json"
);
assert_eq!(
cfg.failed_key("spec-1"),
"fleet/gpu-runners/failed/spec-1.json"
);
}
// A trailing slash on the prefix must not produce a double slash in keys.
#[test]
fn trims_trailing_slash_from_prefix() {
let cfg = config("fleet/", "gpu-runners");
assert_eq!(cfg.pending_prefix(), "fleet/gpu-runners/pending/");
assert_eq!(
cfg.claimed_key("spec-1"),
"fleet/gpu-runners/claimed/spec-1.json"
);
}
// Covers every combination of runner pool (unset / set) and spec pool
// (absent / matching / mismatching) for `should_dispatch`.
#[test]
fn dispatch_matrix_for_pool_id_placement_gate() {
let unbounded = config_with_pool("fleet", "", "");
assert!(
unbounded.should_dispatch(None),
"no-pool runner must accept specs without a poolId constraint"
);
assert!(
unbounded.should_dispatch(Some("runner-pool-amd64")),
"no-pool runner must accept specs with any poolId constraint"
);
let amd64 = config_with_pool("fleet", "", "runner-pool-amd64");
assert!(
amd64.should_dispatch(None),
"pool-bound runner must accept specs with no poolId constraint"
);
assert!(
amd64.should_dispatch(Some("runner-pool-amd64")),
"pool-bound runner must accept matching poolId"
);
assert!(
!amd64.should_dispatch(Some("runner-pool-arm64")),
"pool-bound runner must skip mismatching poolId"
);
}
// Parses one spec file that carries `placement.poolId` and one that has no
// placement at all; extra top-level fields are present in both fixtures.
#[test]
fn read_spec_pool_id_parses_placement_and_handles_absence() {
use std::io::Write;
let mut with_pool = tempfile::NamedTempFile::new().unwrap();
// `{{` / `}}` are escaped braces for `write!`; the file receives plain JSON.
write!(
with_pool,
r#"{{
"apiVersion": "cellos.io/v1",
"kind": "ExecutionCell",
"spec": {{
"id": "test",
"placement": {{ "poolId": "runner-pool-amd64" }}
}}
}}"#
)
.unwrap();
let pool = super::read_spec_pool_id(with_pool.path()).unwrap();
assert_eq!(pool.as_deref(), Some("runner-pool-amd64"));
let mut without = tempfile::NamedTempFile::new().unwrap();
write!(
without,
r#"{{
"apiVersion": "cellos.io/v1",
"kind": "ExecutionCell",
"spec": {{ "id": "test" }}
}}"#
)
.unwrap();
assert_eq!(super::read_spec_pool_id(without.path()).unwrap(), None);
}
}