use clap::Parser;
use somatize_worker::detect::ResourceLimits;
use somatize_worker::protocol::Capabilities;
use somatize_worker::worker::Worker;
// Command-line options for the worker. This is a plain `//` comment on purpose: a `///`
// doc comment on the struct would be fed into clap's about/long_about text, which the
// `about` attribute below already sets.
#[derive(Parser, Debug)]
#[command(name = "soma-worker", about = "Soma distributed execution worker")]
struct Args {
    /// Port to listen on (the worker binds 0.0.0.0:<PORT>)
    #[arg(short, long, default_value_t = 8080)]
    port: u16,
    /// Limit the number of CPUs this worker reports
    #[arg(long)]
    cpus: Option<usize>,
    /// Limit the memory this worker reports, e.g. 512M, 8G, 1T (plain numbers are bytes)
    #[arg(long)]
    memory: Option<String>,
    /// Limit the number of GPUs this worker reports
    #[arg(long)]
    gpus: Option<usize>,
    /// Maximum number of concurrent executions
    #[arg(long, default_value_t = 4)]
    max_concurrent: usize,
    /// Extra capability tags, comma-separated (e.g. --tags gpu,fast)
    #[arg(long, value_delimiter = ',')]
    tags: Vec<String>,
    /// Auth token; enables authentication and is sent when registering with the coordinator.
    /// Prefer the SOMA_TOKEN environment variable so the secret stays out of process listings.
    // `hide_env_values` stops `--help` from printing the token taken from SOMA_TOKEN.
    #[arg(long, env = "SOMA_TOKEN", hide_env_values = true)]
    token: Option<String>,
    /// Coordinator base URL to register with on startup (POSTs to <URL>/register)
    #[arg(long, env = "SOMA_COORDINATOR")]
    coordinator: Option<String>,
    /// Worker ID (defaults to this machine's hostname)
    #[arg(long)]
    id: Option<String>,
    /// Directory for environments (currently only used when a token is set)
    #[arg(long, default_value = "/tmp/soma-envs")]
    env_dir: String,
    /// Working directory (currently only used when a token is set)
    #[arg(long, default_value = "/tmp/soma-work")]
    work_dir: String,
    /// Override the worker's temporary directory
    #[arg(long)]
    temp_dir: Option<String>,
    /// Path of a local data store to attach to the worker
    #[arg(long, env = "SOMA_DATA_STORE")]
    data_store: Option<String>,
}
/// Parses a human-readable memory size into bytes.
///
/// Accepts a plain byte count (`"1073741824"`) or a number followed by a binary unit
/// `K`, `M`, `G` or `T` (powers of 1024), case-insensitive. The unit may be followed by
/// `B` or `iB`, so `"8G"`, `"8GB"` and `"8GiB"` all mean 8 × 1024³. Surrounding whitespace
/// and a space between the number and the unit are ignored.
///
/// Returns `0` when the input cannot be parsed. Absurdly large values saturate at
/// `u64::MAX` instead of overflowing.
fn parse_memory(s: &str) -> u64 {
    // (unit suffix, power-of-two exponent): K = 2^10, M = 2^20, G = 2^30, T = 2^40.
    const UNITS: [(char, u32); 4] = [('K', 10), ('M', 20), ('G', 30), ('T', 40)];

    let upper = s.trim().to_uppercase();
    // Drop an optional byte marker so "8GB" / "8GiB" behave exactly like "8G".
    let spec = upper
        .strip_suffix("IB")
        .or_else(|| upper.strip_suffix('B'))
        .unwrap_or(upper.as_str());
    let (digits, shift) = UNITS
        .iter()
        .find_map(|&(unit, shift)| spec.strip_suffix(unit).map(|rest| (rest, shift)))
        .unwrap_or((spec, 0));
    digits
        .trim()
        .parse::<u64>()
        .map_or(0, |n| n.saturating_mul(1u64 << shift))
}
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
let args = Args::parse();
let mut caps = Capabilities::detect();
let limits = ResourceLimits {
max_cpus: args.cpus,
max_memory_bytes: args.memory.as_deref().map(parse_memory),
max_gpus: args.gpus,
max_concurrent: args.max_concurrent,
};
caps = caps.with_limits(&limits);
for tag in &args.tags {
if !caps.tags.contains(tag) {
caps.tags.push(tag.clone());
}
}
let worker_id = args.id.unwrap_or_else(|| {
hostname::get()
.map(|h| h.to_string_lossy().to_string())
.unwrap_or_else(|_| format!("worker_{}", std::process::id()))
});
tracing::info!("Starting worker: {worker_id}");
tracing::info!("Capabilities: {}", caps.summary());
let mut worker = Worker::new(&worker_id, caps.clone());
if let Some(temp_dir) = args.temp_dir {
worker = worker.with_temp_dir(temp_dir.into());
}
if let Some(store_path) = &args.data_store {
let store = somatize_core::store::LocalDataStore::new(store_path);
worker = worker.with_data_store(std::sync::Arc::new(store));
tracing::info!("DataStore configured: {store_path}");
}
let addr = format!("0.0.0.0:{}", args.port);
if let Some(coordinator_url) = &args.coordinator {
let url = format!("{coordinator_url}/register");
let body = serde_json::json!({
"worker_id": worker_id,
"address": format!("ws://{}:{}", local_ip(), args.port),
"capabilities": caps,
});
let mut request = reqwest::Client::new().post(&url).json(&body);
if let Some(token) = &args.token {
request = request.query(&[("token", token.as_str())]);
}
match request.send().await {
Ok(resp) if resp.status().is_success() => {
tracing::info!("Registered with coordinator at {coordinator_url}");
}
Ok(resp) => {
tracing::warn!(
"Coordinator registration failed: {} {}",
resp.status(),
resp.text().await.unwrap_or_default()
);
}
Err(e) => {
tracing::warn!("Could not reach coordinator at {coordinator_url}: {e}");
}
}
}
let shutdown = async {
tokio::signal::ctrl_c()
.await
.expect("failed to listen for Ctrl+C");
tracing::info!("Ctrl+C received, shutting down...");
};
if let Some(token) = args.token {
tracing::info!("Authentication enabled");
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
tracing::info!("Worker listening on {addr}");
let router = somatize_worker::worker_router_authenticated(
worker,
&args.env_dir,
&args.work_dir,
&token,
);
axum::serve(listener, router)
.with_graceful_shutdown(shutdown)
.await
.unwrap();
} else {
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
tracing::info!("Worker listening on {addr}");
let router = somatize_worker::worker_router(worker);
axum::serve(listener, router)
.with_graceful_shutdown(shutdown)
.await
.unwrap();
}
tracing::info!("Worker stopped.");
}
/// Best-effort discovery of the IPv4 address this host uses for outbound traffic.
///
/// Falls back to `"127.0.0.1"` if the socket cannot be bound or the route cannot be
/// determined.
fn local_ip() -> String {
    // `connect` on a UDP socket transmits nothing; it only makes the OS choose a route,
    // and with it the local source address we read back.
    let probe = || -> std::io::Result<std::net::SocketAddr> {
        let socket = std::net::UdpSocket::bind("0.0.0.0:0")?;
        socket.connect("8.8.8.8:80")?;
        socket.local_addr()
    };
    match probe() {
        Ok(local) => local.ip().to_string(),
        Err(_) => String::from("127.0.0.1"),
    }
}