use std::collections::{HashMap, VecDeque};
use std::env;
use std::fs;
use std::io::{self, BufRead, BufReader, Read, Write};
use std::net::{TcpListener, TcpStream};
use std::os::fd::{AsRawFd, IntoRawFd};
use std::os::unix::net::UnixStream;
use std::path::{Path, PathBuf};
use std::process::{Child, Command, Stdio};
use std::sync::{Arc, Mutex, OnceLock};
use std::thread;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use serde_json::Value;
#[cfg(target_os = "macos")]
/// Best-effort hint to the macOS scheduler to run the calling thread at
/// USER_INTERACTIVE QoS, which favors performance (P) cores.
/// The libc return value is deliberately ignored.
fn pin_to_p_core() {
    // Value of QOS_CLASS_USER_INTERACTIVE from <sys/qos.h>.
    const QOS_CLASS_USER_INTERACTIVE: u32 = 0x21;
    unsafe extern "C" {
        fn pthread_set_qos_class_self_np(qos_class: u32, relative_priority: i32) -> i32;
    }
    // SAFETY: plain FFI call taking two integers; no pointers cross the boundary.
    unsafe {
        let _ = pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0);
    }
}
#[cfg(not(target_os = "macos"))]
/// No-op on non-macOS targets; the QoS pinning above is macOS-specific.
fn pin_to_p_core() {}
/// Whether a metadata `runtime` field (or its absence) names a runtime this
/// router can serve. An absent field is treated as the default runtime.
fn metadata_runtime_is_known(s: Option<&str>) -> bool {
    match s {
        None => true,
        Some(runtime) => runtime == "supermachine" || runtime == "supermachine-worker",
    }
}
/// Parsed contents of a snapshot's metadata.json (see `read_metadata`).
#[derive(Clone, Debug, Default)]
struct SnapshotMeta {
    /// Snapshot name; the routing key in the `Snapshots` map.
    name: String,
    /// Optional init cpio path from metadata (not used in this file chunk).
    init_cpio: Option<PathBuf>,
    /// Base image layers; each becomes a `--virtio-blk` argument.
    layers: Vec<PathBuf>,
    /// (host_file, guest_path) pairs passed as `--volume host:guest`.
    volumes: Vec<(String, String)>,
    /// Restart policy string; defaults to "no" when absent.
    restart_policy: String,
    /// Health-check command; empty when not configured.
    health_cmd: String,
    /// Health-check interval; 0 when not configured.
    health_interval_secs: u32,
    /// Writable delta layer; used together with `layers` when both present.
    delta_squashfs: Option<PathBuf>,
    /// Single-rootfs fallback used when layers/delta are not both present.
    rootfs_squashfs: Option<PathBuf>,
    /// Snapshot file to `--restore-from`; discovery skips entries missing it.
    snapshot_base: Option<PathBuf>,
    /// Optional kernel path from metadata (not used in this file chunk).
    kernel: Option<PathBuf>,
    /// Guest memory in MiB (default 512).
    memory_mib: u32,
    /// Guest vCPU count (default 1).
    vcpus: u32,
    /// Egress policy; passed as both CLI flag and env var to the worker.
    egress_policy: Option<String>,
    /// Egress rate limit in bytes/sec (SUPERMACHINE_TSI_EGRESS_BPS).
    egress_bps: Option<u64>,
    // cpu_nice / cpu_affinity / cpu_qos are parsed from metadata but not
    // consumed in this chunk — presumably used elsewhere; verify.
    cpu_nice: Option<i32>,
    cpu_affinity: Option<i32>,
    cpu_qos: Option<String>,
    /// Balloon target (SUPERMACHINE_BALLOON_TARGET_PAGES) when set.
    balloon_target_pages: Option<u32>,
    /// Raw auth policy object; `{"type": "none"}` or absent means no auth.
    auth: Option<Value>,
    /// TTL from metadata (not consumed in this chunk).
    ttl_seconds: Option<u64>,
    /// Bake timestamp string (informational).
    baked_at: Option<String>,
    /// First 16 hex chars of the worker binary's SHA-256 at bake time;
    /// checked against the current binary in `spawn_worker`.
    runtime_sha16: Option<String>,
    /// 64-hex-char TSI token, lowercased; invalid values are dropped at parse.
    tsi_token: Option<String>,
}
/// Process-wide counters, exposed via the /_metrics endpoint.
/// All fields are relaxed atomics; exact cross-counter consistency is not needed.
#[derive(Default)]
struct Metrics {
    /// Requests successfully routed to a worker.
    requests_total: std::sync::atomic::AtomicU64,
    /// Requests rejected by `check_auth`.
    auth_failures_total: std::sync::atomic::AtomicU64,
    /// Failures talking to upstream workers (open/write/handoff errors).
    upstream_errors_total: std::sync::atomic::AtomicU64,
    // latency_* are not updated in this chunk — presumably maintained by the
    // access-log path elsewhere; verify.
    latency_us_sum: std::sync::atomic::AtomicU64,
    latency_us_count: std::sync::atomic::AtomicU64,
    /// Requests counted while SUPERMACHINE_ROUTER_PROFILE is on; the
    /// profile_*_us fields below accumulate per-phase microseconds.
    profile_requests: std::sync::atomic::AtomicU64,
    profile_read_request_us: std::sync::atomic::AtomicU64,
    profile_route_auth_us: std::sync::atomic::AtomicU64,
    profile_open_us: std::sync::atomic::AtomicU64,
    profile_upstream_write_us: std::sync::atomic::AtomicU64,
    profile_upstream_read_us: std::sync::atomic::AtomicU64,
    profile_client_write_us: std::sync::atomic::AtomicU64,
}
/// One spawned VM worker process plus its connection pools and stats.
struct Worker {
    /// Unix socket path of the worker's vsock mux (`--vsock-mux`).
    sock_path: PathBuf,
    /// Socket for SCM_RIGHTS fd handoff; None when handoff is disabled or
    /// the worker's handoff socket never appeared.
    handoff_sock_path: Option<PathBuf>,
    /// The child process; swapped out in `respawn_worker_in_place`.
    proc: Mutex<Child>,
    /// Time from spawn start to `Command::spawn` returning.
    process_spawn_ms: u128,
    /// Time from spawn start to the mux socket appearing on disk.
    socket_ready_ms: u128,
    /// Restore duration parsed from the worker log, when reported.
    restore_us: Option<u128>,
    /// Idle pooled connections to `sock_path`.
    pool: Mutex<VecDeque<UnixStream>>,
    /// Idle pooled connections to the handoff socket.
    handoff_pool: Mutex<VecDeque<UnixStream>>,
    /// Requests served by this worker.
    hits: std::sync::atomic::AtomicU64,
    /// Consecutive/total raw-proxy failures (incremented on recycle).
    raw_failures: std::sync::atomic::AtomicUsize,
    // Restart bookkeeping — not read/written in this chunk; presumably used
    // by the restart-policy supervisor elsewhere; verify.
    restart_count: std::sync::atomic::AtomicU32,
    restart_window_start_ms: std::sync::atomic::AtomicU64,
    last_exit_status: std::sync::atomic::AtomicI32,
}
/// Cap on idle pooled connections kept per worker (both pools).
const POOL_MAX_PER_WORKER: usize = 32;
impl Worker {
    /// Take an idle pooled connection to the mux socket, if any.
    fn pop_conn(&self) -> Option<UnixStream> {
        self.pool.lock().unwrap().pop_front()
    }
    /// Open a fresh connection to the mux socket with 60s read/write timeouts.
    fn open_conn(&self) -> io::Result<UnixStream> {
        let s = UnixStream::connect(&self.sock_path)?;
        s.set_read_timeout(Some(Duration::from_secs(60)))?;
        s.set_write_timeout(Some(Duration::from_secs(60)))?;
        Ok(s)
    }
    /// Take an idle pooled handoff connection, if any.
    fn pop_handoff_conn(&self) -> Option<UnixStream> {
        self.handoff_pool.lock().unwrap().pop_front()
    }
    /// Return a handoff connection to the pool, dropping it when full.
    fn push_handoff_conn(&self, s: UnixStream) {
        let mut p = self.handoff_pool.lock().unwrap();
        if p.len() < POOL_MAX_PER_WORKER {
            p.push_back(s);
        }
    }
    /// Open a fresh connection to the handoff socket (write timeout only —
    /// the router never reads from handoff connections).
    fn open_handoff_conn(&self) -> io::Result<UnixStream> {
        let path = self.handoff_sock_path.as_ref().ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::Unsupported,
                "worker has no handoff socket path",
            )
        })?;
        let s = UnixStream::connect(path)?;
        s.set_write_timeout(Some(Duration::from_secs(60)))?;
        Ok(s)
    }
}
/// A named snapshot and its set of live workers.
struct Snapshot {
    /// Parsed metadata.json for this snapshot.
    meta: SnapshotMeta,
    /// Workers serving this snapshot; round-robin via `next_idx`.
    workers: Vec<Arc<Worker>>,
    /// Monotonic round-robin counter (wraps via modulo in `pick_worker`).
    next_idx: std::sync::atomic::AtomicUsize,
    // Health status string — not updated in this chunk; presumably written
    // by a health-check loop elsewhere and read by /_health; verify.
    health_status: Mutex<String>,
}
impl Snapshot {
    /// Round-robin selection of a worker; `None` when none are running.
    fn pick_worker(&self) -> Option<Arc<Worker>> {
        let count = self.workers.len();
        if count == 0 {
            return None;
        }
        let ticket = self
            .next_idx
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        self.workers.get(ticket % count).cloned()
    }
}
/// All discovered snapshots, keyed by snapshot name.
type Snapshots = HashMap<String, Arc<Snapshot>>;
/// True when the named environment variable holds a common "on" value.
/// Unset variables and any other value count as false.
fn env_truthy(name: &str) -> bool {
    match std::env::var(name) {
        Ok(v) => matches!(v.as_str(), "1" | "true" | "yes" | "on"),
        Err(_) => false,
    }
}
/// Cached check of SUPERMACHINE_ROUTER_LOG_EPOCH; evaluated once per process.
fn log_epoch_enabled() -> bool {
    static ENABLED: OnceLock<bool> = OnceLock::new();
    if let Some(cached) = ENABLED.get() {
        return *cached;
    }
    *ENABLED.get_or_init(|| env_truthy("SUPERMACHINE_ROUTER_LOG_EPOCH"))
}
/// Print one router log line, prefixed with epoch milliseconds when the
/// SUPERMACHINE_ROUTER_LOG_EPOCH toggle is on.
fn log(msg: &str) {
    if !log_epoch_enabled() {
        println!("[router] {msg}");
        return;
    }
    let now_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_millis());
    println!("[router epoch_ms={now_ms}] {msg}");
}
/// Scan a worker log file, newest line first, for a line of the form
/// `restored in <N> ...` and return N (microseconds).
/// Returns None when the file is unreadable or no such line parses.
fn parse_restore_us(log_path: &Path) -> Option<u128> {
    let text = fs::read_to_string(log_path).ok()?;
    for line in text.lines().rev() {
        let Some(rest) = line.trim().strip_prefix("restored in ") else {
            continue;
        };
        let first_token = rest.split_ascii_whitespace().next()?;
        if let Ok(micros) = first_token.parse::<u128>() {
            return Some(micros);
        }
    }
    None
}
fn access_log_enabled() -> bool {
static ENABLED: OnceLock<bool> = OnceLock::new();
*ENABLED.get_or_init(|| {
matches!(
env::var("SUPERMACHINE_ROUTER_ACCESS_LOG").as_deref(),
Ok("1") | Ok("true") | Ok("yes") | Ok("on")
)
})
}
fn profile_enabled() -> bool {
static ENABLED: OnceLock<bool> = OnceLock::new();
*ENABLED.get_or_init(|| {
matches!(
env::var("SUPERMACHINE_ROUTER_PROFILE").as_deref(),
Ok("1") | Ok("true") | Ok("yes") | Ok("on")
)
})
}
/// Raw tunnel mode is on by default; SUPERMACHINE_ROUTER_RAW_TUNNEL set to an
/// explicit "off" value disables it. Cached for the process lifetime.
fn raw_tunnel_enabled() -> bool {
    static ENABLED: OnceLock<bool> = OnceLock::new();
    *ENABLED.get_or_init(|| {
        let explicitly_off = matches!(
            env::var("SUPERMACHINE_ROUTER_RAW_TUNNEL").as_deref(),
            Ok("0") | Ok("false") | Ok("no") | Ok("off")
        );
        !explicitly_off
    })
}
/// SCM_RIGHTS handoff is on by default; SUPERMACHINE_ROUTER_HANDOFF set to an
/// explicit "off" value disables it. Cached for the process lifetime.
fn scm_handoff_enabled() -> bool {
    static ENABLED: OnceLock<bool> = OnceLock::new();
    *ENABLED.get_or_init(|| {
        match env::var("SUPERMACHINE_ROUTER_HANDOFF").as_deref() {
            Ok("0") | Ok("false") | Ok("no") | Ok("off") => false,
            _ => true,
        }
    })
}
/// Default per-snapshot worker count: twice the available parallelism,
/// clamped to [4, 16]; 4 when parallelism cannot be determined.
fn default_workers() -> usize {
    match std::thread::available_parallelism() {
        Ok(n) => n.get().saturating_mul(2).clamp(4, 16),
        Err(_) => 4,
    }
}
/// First 16 hex characters of the SHA-256 digest of a file's contents.
fn file_sha16(path: &Path) -> io::Result<String> {
    let digest = sha256_hex(&fs::read(path)?);
    Ok(digest[..16].to_owned())
}
/// Parse one metadata.json into a `SnapshotMeta`.
/// Returns None when the file is unreadable, not a JSON object, or declares
/// a runtime other than supermachine/supermachine-worker.
fn read_metadata(meta_path: &Path) -> Option<SnapshotMeta> {
    let text = fs::read_to_string(meta_path).ok()?;
    let v: Value = serde_json::from_str(&text).ok()?;
    let obj = v.as_object()?;
    // Small typed accessors over the JSON object.
    let pathbuf = |k: &str| obj.get(k).and_then(|x| x.as_str()).map(PathBuf::from);
    // String accessor that treats "" the same as absent.
    let string = |k: &str| {
        obj.get(k)
            .and_then(|x| x.as_str())
            .map(str::to_owned)
            .filter(|s| !s.is_empty())
    };
    let u32_or = |k: &str, d: u32| {
        obj.get(k)
            .and_then(|x| x.as_u64())
            .map(|n| n as u32)
            .unwrap_or(d)
    };
    let u64_opt = |k: &str| obj.get(k).and_then(|x| x.as_u64());
    let i32_opt = |k: &str| obj.get(k).and_then(|x| x.as_i64()).map(|n| n as i32);
    let mut m = SnapshotMeta::default();
    m.name = string("name").unwrap_or_default();
    m.init_cpio = pathbuf("init_cpio");
    m.delta_squashfs = pathbuf("delta_squashfs");
    m.rootfs_squashfs = pathbuf("rootfs_squashfs");
    m.snapshot_base = pathbuf("snapshot_base");
    m.kernel = pathbuf("kernel");
    // layers: array of path strings; non-string entries are silently skipped.
    if let Some(arr) = obj.get("layers").and_then(|x| x.as_array()) {
        for x in arr {
            if let Some(s) = x.as_str() {
                m.layers.push(PathBuf::from(s));
            }
        }
    }
    // volumes: array of {host_file, guest_path}; entries missing either key
    // are silently skipped.
    if let Some(arr) = obj.get("volumes").and_then(|x| x.as_array()) {
        for x in arr {
            let host = x.get("host_file").and_then(|h| h.as_str());
            let guest = x.get("guest_path").and_then(|g| g.as_str());
            if let (Some(h), Some(g)) = (host, guest) {
                m.volumes.push((h.to_owned(), g.to_owned()));
            }
        }
    }
    // NOTE(review): unlike the `string` closure, this path preserves an
    // explicit empty string — confirm "" vs absent is intentional here.
    m.restart_policy = obj
        .get("restart_policy")
        .and_then(|v| v.as_str())
        .map(|s| s.to_owned())
        .unwrap_or_else(|| "no".to_owned());
    m.health_cmd = obj
        .get("health_cmd")
        .and_then(|v| v.as_str())
        .map(|s| s.to_owned())
        .unwrap_or_default();
    m.health_interval_secs = obj
        .get("health_interval_secs")
        .and_then(|v| v.as_u64())
        .map(|n| n as u32)
        .unwrap_or(0);
    m.memory_mib = u32_or("memory_mib", 512);
    m.vcpus = u32_or("vcpus", 1);
    m.egress_policy = string("egress_policy");
    m.egress_bps = u64_opt("egress_bps");
    m.cpu_nice = i32_opt("cpu_nice");
    m.cpu_affinity = i32_opt("cpu_affinity");
    m.cpu_qos = string("cpu_qos");
    m.balloon_target_pages = u64_opt("balloon_target_pages").map(|n| n as u32);
    // tsi_token must be exactly 64 hex chars; normalized to lowercase.
    // Anything else is silently dropped.
    m.tsi_token = string("tsi_token")
        .filter(|s| s.len() == 64 && s.bytes().all(|b| b.is_ascii_hexdigit()))
        .map(|s| s.to_ascii_lowercase());
    m.auth = obj.get("auth").cloned();
    m.ttl_seconds = u64_opt("ttl_seconds");
    m.baked_at = string("baked_at");
    m.runtime_sha16 = string("runtime_sha16");
    // Reject snapshots baked for another runtime; an absent field means ours.
    let raw_runtime = obj.get("runtime").and_then(|x| x.as_str());
    if !metadata_runtime_is_known(raw_runtime) {
        log(&format!(
            "skip metadata {}: unsupported runtime={:?} (only supermachine)",
            meta_path.display(),
            raw_runtime,
        ));
        return None;
    }
    Some(m)
}
/// Walk `snapshots_dir` and return the usable snapshots keyed by name.
/// Directory entries without metadata.json, with unparseable/unsupported
/// metadata, or whose snapshot_base file is missing are skipped.
fn discover(snapshots_dir: &Path) -> HashMap<String, SnapshotMeta> {
    let mut out = HashMap::new();
    let Ok(dir) = fs::read_dir(snapshots_dir) else {
        return out;
    };
    for entry in dir.flatten() {
        let meta_path = entry.path().join("metadata.json");
        if !meta_path.exists() {
            continue;
        }
        let Some(m) = read_metadata(&meta_path) else {
            continue;
        };
        let Some(snap_base) = m.snapshot_base.clone() else {
            continue;
        };
        if !snap_base.exists() {
            log(&format!(
                "skip {}: snapshot file missing ({:?})",
                m.name, snap_base
            ));
            continue;
        }
        out.insert(m.name.clone(), m);
    }
    out
}
/// Spawn one VM worker process for snapshot `name` (instance index `idx`) and
/// wait up to 30s for its vsock-mux Unix socket to appear.
///
/// Verifies the baked runtime hash against the current worker binary (unless
/// SUPERMACHINE_ALLOW_STALE_SNAPSHOTS=1), builds the worker command line from
/// the metadata, redirects worker output to a per-worker log file, and
/// returns a ready `Worker` with spawn/ready/restore timings recorded.
fn spawn_worker(
    worker_bin: &Path,
    socks_dir: &Path,
    name: &str,
    idx: usize,
    meta: &SnapshotMeta,
) -> io::Result<Arc<Worker>> {
    let spawn_start = Instant::now();
    // Refuse to restore a snapshot baked against a different worker binary,
    // unless explicitly overridden via env.
    if std::env::var("SUPERMACHINE_ALLOW_STALE_SNAPSHOTS").as_deref() != Ok("1") {
        let current = file_sha16(worker_bin)?;
        match meta.runtime_sha16.as_deref() {
            Some(expected) if expected == current => {}
            Some(expected) => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "snapshot {name} runtime mismatch: metadata runtime_sha16={expected}, current={current}; rebake with tools/supermachine-push or set SUPERMACHINE_ALLOW_STALE_SNAPSHOTS=1"
                    ),
                ));
            }
            None => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "snapshot {name} missing runtime_sha16; rebake with tools/supermachine-push or set SUPERMACHINE_ALLOW_STALE_SNAPSHOTS=1"
                    ),
                ));
            }
        }
    }
    // Remove any stale socket so the "socket exists" readiness check below
    // only fires once the new process creates it.
    let sock_path = socks_dir.join(format!("{}-w{}.sock", name, idx));
    if sock_path.exists() {
        let _ = fs::remove_file(&sock_path);
    }
    let log_path = PathBuf::from(format!("/tmp/supermachine-router-{}-w{}.log", name, idx));
    let mut cmd = Command::new(worker_bin);
    let base = meta.snapshot_base.as_ref().ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("snapshot {name} missing snapshot_base"),
        )
    })?;
    // Block devices: layered (layers + delta) takes precedence; otherwise a
    // single rootfs squashfs if present.
    if !meta.layers.is_empty() && meta.delta_squashfs.is_some() {
        for l in &meta.layers {
            cmd.arg("--virtio-blk").arg(l);
        }
        if let Some(d) = &meta.delta_squashfs {
            cmd.arg("--virtio-blk").arg(d);
        }
    } else if let Some(r) = &meta.rootfs_squashfs {
        cmd.arg("--virtio-blk").arg(r);
    }
    for (host, guest) in &meta.volumes {
        cmd.arg("--volume").arg(format!("{host}:{guest}"));
    }
    cmd.arg("--memory").arg(meta.memory_mib.to_string());
    cmd.arg("--vcpus").arg(meta.vcpus.to_string());
    cmd.arg("--restore-from").arg(base);
    cmd.arg("--cow-restore");
    cmd.arg("--vsock-mux").arg(&sock_path);
    // Optional SCM_RIGHTS handoff socket: <sock>.handoff
    if scm_handoff_enabled() {
        let mut p = sock_path.clone();
        p.set_extension("handoff");
        if p.exists() {
            let _ = fs::remove_file(&p);
        }
        cmd.arg("--vsock-mux-handoff").arg(&p);
    }
    // Exec socket: "<sock file name>-exec" alongside the mux socket.
    {
        let mut p = sock_path.clone();
        let mut name = p.file_name().unwrap().to_owned();
        name.push("-exec");
        p.set_file_name(name);
        if p.exists() {
            let _ = fs::remove_file(&p);
        }
        cmd.arg("--vsock-exec").arg(&p);
    }
    if let Some(pol) = &meta.egress_policy {
        cmd.arg("--egress-policy").arg(pol);
    }
    if let Some(hex) = &meta.tsi_token {
        cmd.arg("--tsi-token").arg(hex);
    }
    cmd.env("SUPERMACHINE_ENABLE_BALLOON", "1");
    if let Some(p) = meta.balloon_target_pages {
        cmd.env("SUPERMACHINE_BALLOON_TARGET_PAGES", p.to_string());
    }
    // Egress policy is passed both as a flag (above) and as env — presumably
    // for older/newer worker compatibility; verify.
    if let Some(pol) = &meta.egress_policy {
        cmd.env("SUPERMACHINE_TSI_POLICY", pol);
    }
    if let Some(bps) = meta.egress_bps {
        cmd.env("SUPERMACHINE_TSI_EGRESS_BPS", bps.to_string());
    }
    // Console logging in the guest is suppressed unless explicitly enabled.
    if !matches!(
        env::var("SUPERMACHINE_WORKER_CONSOLE_LOG").as_deref(),
        Ok("1") | Ok("true") | Ok("yes") | Ok("on")
    ) {
        cmd.env("SUPERMACHINE_CONSOLE_LOG", "0");
    }
    // Both stdout and stderr of the worker go to the same log file.
    let logf = fs::File::create(&log_path)?;
    let logf2 = logf.try_clone()?;
    cmd.stdout(Stdio::from(logf));
    cmd.stderr(Stdio::from(logf2));
    let proc = cmd.spawn()?;
    let process_spawn_ms = spawn_start.elapsed().as_millis();
    // Poll (50ms) for the mux socket to appear; up to 30s.
    let deadline = Instant::now() + Duration::from_secs(30);
    while Instant::now() < deadline {
        if sock_path.exists() {
            let socket_ready_ms = spawn_start.elapsed().as_millis();
            let restore_us = parse_restore_us(&log_path);
            // The handoff socket gets a shorter (2s) grace period; missing it
            // just disables the fast path for this worker.
            let handoff_sock_path = if scm_handoff_enabled() {
                let mut p = sock_path.clone();
                p.set_extension("handoff");
                let hdeadline = Instant::now() + Duration::from_secs(2);
                while !p.exists() && Instant::now() < hdeadline {
                    std::thread::sleep(Duration::from_millis(20));
                }
                if p.exists() {
                    Some(p)
                } else {
                    log(&format!(
                        "warning: snapshot {name}-w{idx} handoff socket did not appear; falling back to in-router proxy"
                    ));
                    None
                }
            } else {
                None
            };
            return Ok(Arc::new(Worker {
                sock_path,
                handoff_sock_path,
                proc: Mutex::new(proc),
                process_spawn_ms,
                socket_ready_ms,
                restore_us,
                pool: Mutex::new(VecDeque::with_capacity(POOL_MAX_PER_WORKER)),
                handoff_pool: Mutex::new(VecDeque::with_capacity(POOL_MAX_PER_WORKER)),
                hits: std::sync::atomic::AtomicU64::new(0),
                raw_failures: std::sync::atomic::AtomicUsize::new(0),
                restart_count: std::sync::atomic::AtomicU32::new(0),
                restart_window_start_ms: std::sync::atomic::AtomicU64::new(0),
                last_exit_status: std::sync::atomic::AtomicI32::new(0),
            }));
        }
        std::thread::sleep(Duration::from_millis(50));
    }
    // NOTE(review): on timeout the spawned child is not killed here — confirm
    // a supervisor elsewhere reaps it.
    Err(io::Error::new(
        io::ErrorKind::TimedOut,
        format!("worker {name}-w{idx}: vsock-mux sock did not appear"),
    ))
}
/// Replace a worker's process in place: spawn a fresh worker with the same
/// name/index, swap its `Child` into `worker`, and discard pooled connections
/// that still point at the old process's sockets.
///
/// Fix: the original drained only `pool`, leaving `handoff_pool` full of
/// stale connections to the dead process's handoff socket; clear both so the
/// next handoff does not burn a failed sendmsg per stale entry.
fn respawn_worker_in_place(
    worker_bin: &Path,
    socks_dir: &Path,
    name: &str,
    idx: usize,
    meta: &SnapshotMeta,
    worker: &Worker,
) -> io::Result<()> {
    let new = spawn_worker(worker_bin, socks_dir, name, idx, meta)?;
    // Swap the Child handles: `worker` now owns the fresh process and the
    // temporary Arc (dropped at end of scope) owns the old one.
    {
        let mut new_lock = new.proc.lock().unwrap();
        let mut old_lock = worker.proc.lock().unwrap();
        std::mem::swap(&mut *old_lock, &mut *new_lock);
    }
    // Drop every idle connection to the old process; dropping closes them.
    worker.pool.lock().unwrap().clear();
    worker.handoff_pool.lock().unwrap().clear();
    Ok(())
}
/// Spawn `workers_per_snapshot` workers for one snapshot concurrently and
/// return those that came up; failures are logged and skipped.
fn spawn_snapshot_workers(
    worker_bin: &Path,
    socks_dir: &Path,
    name: &str,
    meta: &SnapshotMeta,
    workers_per_snapshot: usize,
) -> Vec<Arc<Worker>> {
    // One spawner thread per worker; each owns clones of its inputs.
    let mut handles = Vec::with_capacity(workers_per_snapshot);
    for i in 0..workers_per_snapshot {
        let worker_bin = worker_bin.to_path_buf();
        let socks_dir = socks_dir.to_path_buf();
        let name = name.to_owned();
        let meta = meta.clone();
        handles.push(std::thread::spawn(move || {
            spawn_worker(&worker_bin, &socks_dir, &name, i, &meta)
        }));
    }
    let mut workers = Vec::new();
    for (i, handle) in handles.into_iter().enumerate() {
        match handle.join().unwrap() {
            Ok(w) => {
                log(&format!(
                    "snapshot {name} (supermachine) worker {i} ready at {:?} spawn={}ms socket_ready={}ms restore={}us",
                    w.sock_path,
                    w.process_spawn_ms,
                    w.socket_ready_ms,
                    w.restore_us
                        .map(|n| n.to_string())
                        .unwrap_or_else(|| "N/A".to_owned())
                ));
                workers.push(w);
            }
            Err(e) => log(&format!("snapshot {name} worker {i}: {e}")),
        }
    }
    workers
}
/// Read one HTTP/1.x request (headers plus Content-Length body) from `br`.
///
/// Returns Ok(None) on a clean EOF before any bytes arrive (keep-alive peer
/// closed), otherwise Ok(Some((raw request bytes, method, path,
/// connection_close))). Only Content-Length bodies are handled; chunked
/// transfer encoding is not supported by this parser.
fn read_http_request<S: Read>(
    br: &mut BufReader<&mut S>,
) -> io::Result<Option<(Vec<u8>, String, String, bool)>> {
    let mut head: Vec<u8> = Vec::with_capacity(2048);
    // Accumulate until CRLFCRLF; `sep` is the offset just past the terminator.
    let sep = loop {
        let buf = br.fill_buf()?;
        if buf.is_empty() {
            if head.is_empty() {
                // Clean EOF between requests.
                return Ok(None);
            }
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "EOF mid-request-headers",
            ));
        }
        // Re-scan the last 3 bytes of what we already have so a terminator
        // straddling two read chunks is still found.
        let scan_start = head.len().saturating_sub(3);
        head.extend_from_slice(buf);
        let consumed = buf.len();
        br.consume(consumed);
        if let Some(p) = head[scan_start..].windows(4).position(|w| w == b"\r\n\r\n") {
            break scan_start + p + 4;
        }
        // 64 KiB header cap.
        if head.len() > 64 * 1024 {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "header too big"));
        }
    };
    // Bytes already read past the terminator are the start of the body.
    let body_already = head.split_off(sep);
    let head_str = std::str::from_utf8(&head)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "non-utf8 headers"))?;
    let mut lines = head_str.split("\r\n");
    let req_line = lines.next().unwrap_or("");
    let mut parts = req_line.split(' ');
    let method = parts.next().unwrap_or("").to_owned();
    let path = parts.next().unwrap_or("").to_owned();
    if method.is_empty() || path.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "bad request line",
        ));
    }
    let mut content_length: usize = 0;
    let mut connection_close = false;
    for line in lines {
        if line.is_empty() {
            break;
        }
        let (k, v) = match line.split_once(':') {
            Some((k, v)) => (k, v.trim()),
            None => continue,
        };
        let kl = k.trim().to_ascii_lowercase();
        if kl == "content-length" {
            // NOTE(review): an unparseable Content-Length silently becomes 0.
            content_length = v.parse().unwrap_or(0);
        } else if kl == "connection" {
            if v.eq_ignore_ascii_case("close") {
                connection_close = true;
            }
        }
    }
    // Complete the body to exactly content_length bytes.
    let mut body = body_already;
    if body.len() < content_length {
        let need = content_length - body.len();
        let mut rest = vec![0u8; need];
        br.read_exact(&mut rest)?;
        body.extend_from_slice(&rest);
    } else if body.len() > content_length {
        // NOTE(review): excess bytes were already consumed from the reader,
        // so a pipelined follow-up request in the same chunk is dropped here
        // — confirm clients never pipeline.
        body.truncate(content_length);
    }
    let mut raw = head;
    raw.extend_from_slice(&body);
    Ok(Some((raw, method, path, connection_close)))
}
// Default snapshot name for requests that match no explicit route; expected
// to be set once during startup (elsewhere). Reads before it is set yield None.
static DEFAULT_SNAPSHOT: std::sync::OnceLock<Option<String>> = std::sync::OnceLock::new();
/// The configured default snapshot name, if one was set.
fn default_snapshot_name() -> Option<&'static str> {
    match DEFAULT_SNAPSHOT.get() {
        Some(configured) => configured.as_deref(),
        None => None,
    }
}
/// Resolve which snapshot serves a request and the path to forward upstream.
/// Precedence: X-Supermachine-Snapshot header, then a "/<snapshot>/..." path
/// prefix (stripped from the forwarded path), then the configured default,
/// then the sole snapshot when exactly one exists. `method` is currently
/// unused but kept in the signature.
fn route_request(
    snapshots: &Snapshots,
    method: &str,
    path: &str,
    raw_headers: &[u8],
) -> Option<(String, String)> {
    let _ = method;
    let head_str = std::str::from_utf8(raw_headers).unwrap_or("");
    // 1) explicit header override (only honored for known snapshots)
    for line in head_str.split("\r\n").skip(1) {
        let Some((key, value)) = line.split_once(':') else {
            continue;
        };
        if !key.trim().eq_ignore_ascii_case("x-supermachine-snapshot") {
            continue;
        }
        let candidate = value.trim();
        if snapshots.contains_key(candidate) {
            return Some((candidate.to_owned(), path.to_owned()));
        }
    }
    // 2) leading path segment names a snapshot; forward the remainder
    if let Some(rest) = path.strip_prefix('/') {
        let (first, tail) = match rest.split_once('/') {
            Some((seg, remainder)) => (seg, format!("/{}", remainder)),
            None => (rest, "/".to_owned()),
        };
        if snapshots.contains_key(first) {
            return Some((first.to_owned(), tail));
        }
    }
    // 3) configured default snapshot
    if let Some(name) = default_snapshot_name() {
        if snapshots.contains_key(name) {
            return Some((name.to_owned(), path.to_owned()));
        }
    }
    // 4) unambiguous: exactly one snapshot registered
    if snapshots.len() == 1 {
        let only = snapshots.keys().next().unwrap();
        return Some((only.clone(), path.to_owned()));
    }
    None
}
/// Rewrite an HTTP response head so it carries exactly one Connection header
/// reflecting `close`, preserving the status line, all other headers, and the
/// body. A response without a complete CRLFCRLF header terminator is returned
/// unchanged.
fn response_with_connection_header(response: &[u8], close: bool) -> Vec<u8> {
    let head_end = match response.windows(4).position(|w| w == b"\r\n\r\n") {
        Some(i) => i + 4,
        None => return response.to_vec(),
    };
    let head = &response[..head_end - 4];
    let body = &response[head_end..];
    let mut out = Vec::with_capacity(response.len() + 32);
    let mut kept_any = false;
    for raw_line in head.split(|&b| b == b'\n') {
        let line = raw_line.strip_suffix(b"\r").unwrap_or(raw_line);
        if line.is_empty() {
            continue;
        }
        let is_connection = line.len() >= "connection:".len()
            && line[.."connection:".len()].eq_ignore_ascii_case(b"connection:");
        // Strip any pre-existing Connection header (never the first line,
        // which is the status line).
        if kept_any && is_connection {
            continue;
        }
        kept_any = true;
        out.extend_from_slice(line);
        out.extend_from_slice(b"\r\n");
    }
    let verdict: &[u8] = if close {
        b"Connection: close\r\n\r\n"
    } else {
        b"Connection: keep-alive\r\n\r\n"
    };
    out.extend_from_slice(verdict);
    out.extend_from_slice(body);
    out
}
/// Write one full request to the upstream socket and read one full response.
/// An immediate EOF (empty response) is reported as UnexpectedEof.
fn proxy_raw_http(conn: &mut UnixStream, request: &[u8]) -> io::Result<Vec<u8>> {
    conn.write_all(request)?;
    let resp = read_raw_http_response(conn)?;
    if !resp.is_empty() {
        return Ok(resp);
    }
    Err(io::Error::new(
        io::ErrorKind::UnexpectedEof,
        "empty upstream response",
    ))
}
/// Like `proxy_raw_http`, but also returns (write µs, read µs) timings for
/// the profiling counters.
fn proxy_raw_http_profiled(
    conn: &mut UnixStream,
    request: &[u8],
) -> io::Result<(Vec<u8>, u64, u64)> {
    let write_started = Instant::now();
    conn.write_all(request)?;
    let write_us = write_started.elapsed().as_micros() as u64;
    let read_started = Instant::now();
    let resp = read_raw_http_response(conn)?;
    let read_us = read_started.elapsed().as_micros() as u64;
    if resp.is_empty() {
        Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "empty upstream response",
        ))
    } else {
        Ok((resp, write_us, read_us))
    }
}
/// Whether the upstream response permits reusing the connection: complete
/// headers, no `Connection: close`, and either HTTP/1.1+ (keep-alive by
/// default) or an explicit `Connection: keep-alive`.
///
/// BUG FIX: the original compared the raw header value without trimming, so
/// the ubiquitous form "Connection: close" (space after the colon) was never
/// recognized and closed connections were treated as reusable. The value is
/// now compared with surrounding ASCII whitespace stripped.
fn raw_response_allows_keepalive(response: &[u8]) -> bool {
    let Some(head_end) = response.windows(4).position(|w| w == b"\r\n\r\n") else {
        return false;
    };
    let head = &response[..head_end];
    let mut lines = head.split(|&b| b == b'\n');
    let status = lines
        .next()
        .and_then(|line| std::str::from_utf8(line).ok())
        .unwrap_or("");
    let http11_or_newer = status.starts_with("HTTP/1.1") || status.starts_with("HTTP/2");
    let mut saw_keep_alive = false;
    let mut saw_close = false;
    for line in lines {
        let line = line.strip_suffix(b"\r").unwrap_or(line);
        if line.len() >= "connection:".len()
            && line[.."connection:".len()].eq_ignore_ascii_case(b"connection:")
        {
            // Trim ASCII whitespace around the value ("close", " close", ...).
            let raw = &line["connection:".len()..];
            let start = raw
                .iter()
                .position(|b| !b.is_ascii_whitespace())
                .unwrap_or(raw.len());
            let end = raw
                .iter()
                .rposition(|b| !b.is_ascii_whitespace())
                .map_or(start, |i| i + 1);
            let value = &raw[start..end];
            if value.eq_ignore_ascii_case(b"close") {
                saw_close = true;
            }
            if value.eq_ignore_ascii_case(b"keep-alive") {
                saw_keep_alive = true;
            }
        }
    }
    !saw_close && (http11_or_newer || saw_keep_alive)
}
/// Whether a proxy error should count toward recycling the worker.
/// Would-block, timeout, and interrupt are transient and do not.
fn raw_error_recycles_worker(e: &io::Error) -> bool {
    match e.kind() {
        io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => false,
        _ => true,
    }
}
/// Errors that indicate a stale/broken pooled connection, for which the
/// request should be retried once on a freshly opened connection.
fn raw_error_retries_fresh(e: &io::Error) -> bool {
    use std::io::ErrorKind::*;
    matches!(
        e.kind(),
        UnexpectedEof | BrokenPipe | ConnectionReset | ConnectionAborted | NotConnected
    )
}
/// True when the snapshot requires no authentication: no auth object at all,
/// an auth object without a "type" string, or an explicit type of "none".
fn auth_policy_is_none(meta: &SnapshotMeta) -> bool {
    let declared = meta
        .auth
        .as_ref()
        .and_then(|v| v.get("type"))
        .and_then(|v| v.as_str());
    match declared {
        Some(kind) => kind == "none",
        None => true,
    }
}
/// Read one HTTP response from the upstream socket.
///
/// Behavior: bytes are accumulated until the CRLFCRLF header terminator
/// (EOF before that returns whatever arrived, possibly empty). If a
/// Content-Length header is present the body is read to exactly that length
/// (truncating any extra buffered bytes); otherwise the stream is drained to
/// EOF. Header section is capped at 128 KiB.
fn read_raw_http_response(conn: &mut UnixStream) -> io::Result<Vec<u8>> {
    let mut buf = Vec::with_capacity(16 * 1024);
    let mut chunk = [0u8; 8192];
    let head_end = loop {
        let n = conn.read(&mut chunk)?;
        if n == 0 {
            // EOF before headers completed: hand back what we have.
            return Ok(buf);
        }
        buf.extend_from_slice(&chunk[..n]);
        if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
            break pos + 4;
        }
        if buf.len() > 128 * 1024 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "upstream response headers too large",
            ));
        }
    };
    // Last syntactically valid Content-Length header wins.
    let mut content_length: Option<usize> = None;
    let key_len = "content-length:".len();
    for line in buf[..head_end - 4].split(|&b| b == b'\n').skip(1) {
        let line = line.strip_suffix(b"\r").unwrap_or(line);
        if line.len() >= key_len && line[..key_len].eq_ignore_ascii_case(b"content-length:") {
            if let Ok(text) = std::str::from_utf8(&line[key_len..]) {
                content_length = text.trim().parse().ok();
            }
        }
    }
    let total = match content_length {
        None => {
            // No declared length: drain until the peer closes.
            conn.read_to_end(&mut buf)?;
            return Ok(buf);
        }
        Some(len) => head_end + len,
    };
    while buf.len() < total {
        let n = conn.read(&mut chunk)?;
        if n == 0 {
            break;
        }
        buf.extend_from_slice(&chunk[..n]);
    }
    // No-op when the body came up short; trims over-read bytes otherwise.
    buf.truncate(total);
    Ok(buf)
}
/// Entry point for one accepted TCP connection: tune socket options, then
/// either hand the socket to a worker via SCM_RIGHTS (fast path) or serve it
/// with the in-router HTTP proxy loop.
fn handle_conn(tcp: TcpStream, snapshots: Arc<Snapshots>, metrics: Arc<Metrics>) {
    pin_to_p_core();
    let _ = tcp.set_nodelay(true);
    let _ = tcp.set_read_timeout(Some(Duration::from_secs(120)));
    let _ = tcp.set_write_timeout(Some(Duration::from_secs(120)));
    match handoff_target(&snapshots) {
        Some(target) => try_scm_handoff(tcp, &snapshots, &metrics, &target),
        None => handle_stream(tcp, snapshots, metrics, ws_bidi_pump_tcp),
    }
}
/// Pick a worker eligible for the SCM_RIGHTS handoff fast path. The fast path
/// applies only when handoff is enabled, exactly one snapshot exists, that
/// snapshot requires no auth, and the chosen worker has a handoff socket.
fn handoff_target(snapshots: &Snapshots) -> Option<Arc<Worker>> {
    if !scm_handoff_enabled() || snapshots.len() != 1 {
        return None;
    }
    let snap = snapshots.values().next()?;
    if !auth_policy_is_none(&snap.meta) {
        return None;
    }
    snap.pick_worker()
        .filter(|w| w.handoff_sock_path.is_some())
}
/// Serve one connection via the SCM_RIGHTS fast path: parse the first request
/// in the router, answer /_health and /_metrics locally, enforce auth, then
/// pass the client's TCP fd (plus the already-read bytes) to the worker over
/// its handoff socket and close our copy of the fd.
fn try_scm_handoff(
    mut tcp: TcpStream,
    snapshots: &Snapshots,
    metrics: &Metrics,
    worker: &Arc<Worker>,
) {
    let mut br = BufReader::new(&mut tcp);
    let parsed = match read_http_request(&mut br) {
        Ok(Some(p)) => p,
        Ok(None) => return, Err(_) => return,
    };
    let (raw_req, method, path, _conn_close) = parsed;
    // Router-served endpoints: respond locally and keep the connection.
    if method == "GET" && (path == "/_health" || path == "/_metrics") {
        let body = if path == "/_health" {
            build_health_body(snapshots)
        } else {
            build_metrics_body(snapshots, metrics)
        };
        let ctype = if path == "/_metrics" {
            "text/plain; version=0.0.4"
        } else {
            "application/json"
        };
        // NOTE(review): bytes the client pipelined after this request are
        // captured here but then discarded (`let _ = buffered`), and the
        // response forces close=true — confirm dropping pipelined data on
        // this path is intentional.
        let buffered = br.buffer().to_vec();
        drop(br);
        let _ = write_keepalive_status(&mut tcp, 200, ctype, &body, true);
        let _ = buffered; return;
    }
    let head_end = raw_req
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .map(|i| i + 4)
        .unwrap_or(raw_req.len());
    let raw_headers = &raw_req[..head_end];
    // Single-snapshot invariant is guaranteed by handoff_target().
    let snap = match snapshots.values().next() {
        Some(s) => s.clone(),
        None => return,
    };
    if let Some(why) = check_auth(&snap.meta, raw_headers) {
        metrics
            .auth_failures_total
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let atype = snap
            .meta
            .auth
            .as_ref()
            .and_then(|v| v.get("type"))
            .and_then(|v| v.as_str())
            .unwrap_or("");
        let _ = write_auth_failure(&mut tcp, atype, &why);
        return;
    }
    // The worker must see everything the router already consumed: the parsed
    // request plus any extra bytes still sitting in the BufReader.
    let mut prefix = raw_req;
    let buffered = br.buffer();
    if !buffered.is_empty() {
        prefix.extend_from_slice(buffered);
    }
    drop(br);
    let mut conn = match worker.pop_handoff_conn() {
        Some(c) => c,
        None => match worker.open_handoff_conn() {
            Ok(c) => c,
            Err(e) => {
                metrics
                    .upstream_errors_total
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                log(&format!("handoff: open handoff conn: {e}"));
                let _ = write_simple_status(&mut tcp, 502, "handoff: open conn failed");
                return;
            }
        },
    };
    // A pooled connection may be stale (worker restarted); retry exactly once
    // on a freshly opened connection before giving up.
    let tcp_fd = tcp.as_raw_fd();
    if let Err(e) = send_scm_handoff(&mut conn, &prefix, tcp_fd) {
        log(&format!("handoff: send err, retrying: {e}"));
        drop(conn);
        let mut fresh = match worker.open_handoff_conn() {
            Ok(c) => c,
            Err(e2) => {
                metrics
                    .upstream_errors_total
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                log(&format!("handoff: reopen failed: {e2}"));
                let _ = write_simple_status(&mut tcp, 502, "handoff: reopen failed");
                return;
            }
        };
        if let Err(e3) = send_scm_handoff(&mut fresh, &prefix, tcp_fd) {
            metrics
                .upstream_errors_total
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            log(&format!("handoff: send retry failed: {e3}"));
            let _ = write_simple_status(&mut tcp, 502, "handoff: send failed");
            return;
        }
        worker.push_handoff_conn(fresh);
    } else {
        worker.push_handoff_conn(conn);
    }
    metrics
        .requests_total
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    worker
        .hits
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    // The worker now owns a duplicate of the client fd (via SCM_RIGHTS);
    // close our copy without running TcpStream's drop (which would shut the
    // socket down for the worker too on some platforms).
    let our_fd = tcp.into_raw_fd();
    unsafe { libc::close(our_fd) };
}
/// Transfer `prefix` (the bytes the router already consumed) plus the client
/// TCP fd to a worker over its handoff socket.
///
/// Wire format: a 4-byte big-endian length, then the prefix bytes; the fd is
/// attached as an SCM_RIGHTS control message on the first sendmsg carrying
/// prefix data, and exactly once.
fn send_scm_handoff(conn: &mut UnixStream, prefix: &[u8], tcp_fd: i32) -> io::Result<()> {
    // The receiver needs at least one data byte to carry the control message.
    if prefix.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "handoff prefix must be non-empty",
        ));
    }
    let prefix_len = u32::try_from(prefix.len()).map_err(|_| {
        io::Error::new(io::ErrorKind::InvalidInput, "handoff prefix too large")
    })?;
    conn.write_all(&prefix_len.to_be_bytes())?;
    let mut sent = 0usize;
    // The fd rides on the first successful sendmsg only.
    let mut cmsg_attached = false;
    while sent < prefix.len() {
        let chunk = &prefix[sent..];
        let mut iov = libc::iovec {
            iov_base: chunk.as_ptr() as *mut libc::c_void,
            iov_len: chunk.len(),
        };
        let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
        msg.msg_iov = &mut iov as *mut libc::iovec;
        msg.msg_iovlen = 1;
        // Control buffer sized for exactly one int-sized fd payload. It must
        // outlive the sendmsg call below, which it does (same iteration).
        let cmsg_space = unsafe { libc::CMSG_SPACE(std::mem::size_of::<libc::c_int>() as u32) };
        let mut cmsg_buf = vec![0u8; cmsg_space as usize];
        if !cmsg_attached {
            msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void;
            msg.msg_controllen = cmsg_space as _;
            // SAFETY: cmsg_buf is zeroed, sized via CMSG_SPACE, and msg points
            // at it; CMSG_FIRSTHDR/CMSG_DATA therefore stay in bounds.
            unsafe {
                let cmsg = libc::CMSG_FIRSTHDR(&msg);
                if cmsg.is_null() {
                    return Err(io::Error::new(
                        io::ErrorKind::Other,
                        "CMSG_FIRSTHDR returned NULL",
                    ));
                }
                (*cmsg).cmsg_level = libc::SOL_SOCKET;
                (*cmsg).cmsg_type = libc::SCM_RIGHTS;
                (*cmsg).cmsg_len = libc::CMSG_LEN(std::mem::size_of::<libc::c_int>() as u32) as _;
                // write_unaligned: CMSG_DATA gives no alignment guarantee.
                std::ptr::write_unaligned(libc::CMSG_DATA(cmsg) as *mut libc::c_int, tcp_fd);
            }
        }
        // SAFETY: msg, iov, and cmsg_buf are all live and correctly formed.
        let n = unsafe { libc::sendmsg(conn.as_raw_fd(), &msg, 0) };
        if n < 0 {
            let err = io::Error::last_os_error();
            // EINTR before anything was sent: retry with the cmsg still attached.
            if err.kind() == io::ErrorKind::Interrupted {
                continue;
            }
            return Err(err);
        }
        if n == 0 {
            return Err(io::Error::new(io::ErrorKind::WriteZero, "sendmsg 0"));
        }
        sent += n as usize;
        // Any successful (even partial) send delivered the fd; never resend it.
        cmsg_attached = true;
    }
    Ok(())
}
fn handle_stream<S: Read + Write, F: FnOnce(S, UnixStream)>(
mut stream: S,
snapshots: Arc<Snapshots>,
metrics: Arc<Metrics>,
on_upgrade: F,
) {
let profiling = profile_enabled();
let mut br = BufReader::new(&mut stream);
let mut sticky_raw_conn: Option<(String, Arc<Worker>, UnixStream)> = None;
loop {
let read_t0 = if profiling {
Some(Instant::now())
} else {
None
};
let parsed = match read_http_request(&mut br) {
Ok(Some(p)) => p,
Ok(None) => return, Err(_) => return,
};
if let Some(t0) = read_t0 {
metrics.profile_read_request_us.fetch_add(
t0.elapsed().as_micros() as u64,
std::sync::atomic::Ordering::Relaxed,
);
}
let (raw_req, method, path, conn_close) = parsed;
drop(br);
let head_end = raw_req
.windows(4)
.position(|w| w == b"\r\n\r\n")
.map(|i| i + 4)
.unwrap_or(raw_req.len());
let raw_headers = &raw_req[..head_end];
if method == "GET" && (path == "/_health" || path == "/_metrics") {
let body = if path == "/_health" {
build_health_body(&snapshots)
} else {
build_metrics_body(&snapshots, &metrics)
};
let _ = write_keepalive_status(
&mut stream,
200,
if path == "/_metrics" {
"text/plain; version=0.0.4"
} else {
"application/json"
},
&body,
conn_close,
);
if conn_close {
return;
}
br = BufReader::new(&mut stream);
continue;
}
let route_auth_t0 = if profiling {
Some(Instant::now())
} else {
None
};
let (name, rewritten_path) = match route_request(&snapshots, &method, &path, raw_headers) {
Some(x) => x,
None => {
let body = format!(
"no snapshot route for {method} {path} (known: {:?})",
snapshots.keys().collect::<Vec<_>>()
);
let _ = write_simple_status(&mut stream, 404, &body);
return;
}
};
let snap = snapshots.get(&name).unwrap().clone();
if let Some(why) = check_auth(&snap.meta, raw_headers) {
metrics
.auth_failures_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let atype = snap
.meta
.auth
.as_ref()
.and_then(|v| v.get("type"))
.and_then(|v| v.as_str())
.unwrap_or("");
let _ = write_auth_failure(&mut stream, atype, &why);
return;
}
metrics
.requests_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let req_to_send = if rewritten_path != path {
rewrite_request_path(&raw_req, &path, &rewritten_path)
} else {
raw_req
};
if raw_tunnel_enabled()
&& snapshots.len() == 1
&& rewritten_path == path
&& auth_policy_is_none(&snap.meta)
{
if let Some(t0) = route_auth_t0 {
metrics.profile_route_auth_us.fetch_add(
t0.elapsed().as_micros() as u64,
std::sync::atomic::Ordering::Relaxed,
);
}
if profiling {
metrics
.profile_requests
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
let worker = match snap.pick_worker() {
Some(w) => w,
None => {
metrics
.upstream_errors_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let _ = write_simple_status(&mut stream, 502, "no live workers");
return;
}
};
let open_t0 = if profiling {
Some(Instant::now())
} else {
None
};
let mut conn = match worker.open_conn() {
Ok(c) => c,
Err(e) => {
metrics
.upstream_errors_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let _ = write_simple_status(
&mut stream,
502,
&format!("upstream {name}: open: {e}"),
);
return;
}
};
if let Some(t0) = open_t0 {
metrics.profile_open_us.fetch_add(
t0.elapsed().as_micros() as u64,
std::sync::atomic::Ordering::Relaxed,
);
}
let write_t0 = if profiling {
Some(Instant::now())
} else {
None
};
if conn.write_all(&req_to_send).is_err() {
metrics
.upstream_errors_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let _ = write_simple_status(&mut stream, 502, "upstream write failed");
return;
}
if let Some(t0) = write_t0 {
metrics.profile_upstream_write_us.fetch_add(
t0.elapsed().as_micros() as u64,
std::sync::atomic::Ordering::Relaxed,
);
}
worker
.hits
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
on_upgrade(stream, conn);
return;
}
let worker = if let Some((sticky_name, sticky_worker, _)) = sticky_raw_conn.as_ref() {
if sticky_name == &name {
sticky_worker.clone()
} else {
sticky_raw_conn.take();
match snap.pick_worker() {
Some(w) => w,
None => {
metrics
.upstream_errors_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let _ = write_simple_status(&mut stream, 502, "no live workers");
return;
}
}
}
} else {
match snap.pick_worker() {
Some(w) => w,
None => {
metrics
.upstream_errors_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let _ = write_simple_status(&mut stream, 502, "no live workers");
return;
}
}
};
if let Some(t0) = route_auth_t0 {
metrics.profile_route_auth_us.fetch_add(
t0.elapsed().as_micros() as u64,
std::sync::atomic::Ordering::Relaxed,
);
}
if profiling {
metrics
.profile_requests
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
let raw_dispatch_wait_us = 0;
let open_start = Instant::now();
let sticky_conn = match sticky_raw_conn.take() {
Some((sticky_name, sticky_worker, c))
if sticky_name == name && Arc::ptr_eq(&sticky_worker, &worker) =>
{
Some(c)
}
other => {
sticky_raw_conn = other;
None
}
};
let (mut conn, mut pooled_conn, mut open_conn_us) = match sticky_conn {
Some(c) => (c, true, 0),
None => match worker.pop_conn() {
Some(c) => (c, true, 0),
None => match worker.open_conn() {
Ok(c) => (c, false, open_start.elapsed().as_micros()),
Err(e) => {
let _ = write_simple_status(
&mut stream,
502,
&format!("upstream {name}: open: {e}"),
);
return;
}
},
},
};
if profiling {
metrics
.profile_open_us
.fetch_add(open_conn_us as u64, std::sync::atomic::Ordering::Relaxed);
}
let t0 = Instant::now();
let mut upstream_write_us = 0u64;
let mut upstream_read_us = 0u64;
let resp = if profiling {
match proxy_raw_http_profiled(&mut conn, &req_to_send) {
Ok((r, w_us, r_us)) => {
upstream_write_us = w_us;
upstream_read_us = r_us;
Ok(r)
}
Err(e) => Err(e),
}
} else {
proxy_raw_http(&mut conn, &req_to_send)
};
let resp = match resp {
Ok(r) => r,
Err(e) => {
if raw_error_retries_fresh(&e) {
drop(conn);
let retry_open_start = Instant::now();
match worker.open_conn() {
Ok(mut fresh) => match if profiling {
proxy_raw_http_profiled(&mut fresh, &req_to_send).map(
|(r, w_us, r_us)| {
upstream_write_us = w_us;
upstream_read_us = r_us;
r
},
)
} else {
proxy_raw_http(&mut fresh, &req_to_send)
} {
Ok(r) => {
conn = fresh;
pooled_conn = false;
open_conn_us = retry_open_start.elapsed().as_micros();
r
}
Err(e2) => {
if raw_error_recycles_worker(&e2) {
let failures = worker
.raw_failures
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+ 1;
if failures >= 3 {
worker
.raw_failures
.store(0, std::sync::atomic::Ordering::Relaxed);
if let Ok(mut p) = worker.proc.lock() {
let _ = p.kill();
}
}
}
metrics
.upstream_errors_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
log(&format!(
"snap={name} {method} {path} -> 502 (upstream retry: {e2})"
));
let _ = write_simple_status(
&mut stream,
502,
&format!("upstream {name}: {e2}"),
);
return;
}
},
Err(e2) => {
metrics
.upstream_errors_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let _ = write_simple_status(
&mut stream,
502,
&format!("upstream {name}: reopen after {e}: {e2}"),
);
return;
}
}
} else {
if raw_error_recycles_worker(&e) {
let failures = worker
.raw_failures
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+ 1;
if failures >= 3 {
worker
.raw_failures
.store(0, std::sync::atomic::Ordering::Relaxed);
if let Ok(mut p) = worker.proc.lock() {
let _ = p.kill();
}
}
}
metrics
.upstream_errors_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
log(&format!(
"snap={name} {method} {path} -> 502 (upstream: {e})"
));
let _ = write_simple_status(&mut stream, 502, &format!("upstream {name}: {e}"));
return;
}
}
};
let elapsed = t0.elapsed();
if profiling {
metrics
.profile_upstream_write_us
.fetch_add(upstream_write_us, std::sync::atomic::Ordering::Relaxed);
metrics
.profile_upstream_read_us
.fetch_add(upstream_read_us, std::sync::atomic::Ordering::Relaxed);
}
let elapsed_us = elapsed.as_micros() as u64;
metrics
.latency_us_sum
.fetch_add(elapsed_us, std::sync::atomic::Ordering::Relaxed);
metrics
.latency_us_count
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
worker
.hits
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
worker
.raw_failures
.store(0, std::sync::atomic::Ordering::Relaxed);
let resp_is_upgrade = resp.starts_with(b"HTTP/1.1 101");
let raw_upstream_keepalive = raw_response_allows_keepalive(&resp);
let reusable_raw_conn = !conn_close && !resp_is_upgrade && raw_upstream_keepalive;
let client_write_t0 = if profiling {
Some(Instant::now())
} else {
None
};
if !resp_is_upgrade {
if conn_close || !raw_upstream_keepalive {
let client_resp = response_with_connection_header(&resp, conn_close);
if stream.write_all(&client_resp).is_err() {
return;
}
} else if stream.write_all(&resp).is_err() {
return;
}
} else if stream.write_all(&resp).is_err() {
return;
}
if let Some(t0) = client_write_t0 {
metrics.profile_client_write_us.fetch_add(
t0.elapsed().as_micros() as u64,
std::sync::atomic::Ordering::Relaxed,
);
}
if access_log_enabled() {
log(&format!(
"snap={name} {method} {path} -> {}B upstream={:.1}ms raw_wait={}us open={}us pooled={} worker_ready={}ms restore={}us",
resp.len(),
elapsed.as_secs_f64() * 1000.0,
raw_dispatch_wait_us,
open_conn_us,
pooled_conn,
worker.socket_ready_ms,
worker
.restore_us
.map(|n| n.to_string())
.unwrap_or_else(|| "N/A".to_owned())
));
}
if resp_is_upgrade {
on_upgrade(stream, conn);
return;
}
if reusable_raw_conn {
sticky_raw_conn = Some((name.clone(), worker.clone(), conn));
}
if conn_close {
return;
}
br = BufReader::new(&mut stream);
}
}
// Bidirectional byte pump between a TLS client stream and a worker unix
// socket, used after an HTTP 101 Upgrade (WebSocket) response.
//
// rustls streams cannot be split into independent read/write halves the way
// plain sockets can, so both fds are switched to non-blocking mode and
// multiplexed on one thread with poll(2). Each readable side is drained
// fully into the opposite side. Returns (dropping both connections) on EOF
// or on any read/write error in either direction.
fn ws_bidi_pump_tls(
mut client: rustls::StreamOwned<rustls::ServerConnection, TcpStream>,
worker: UnixStream,
) {
use std::io::{ErrorKind, Read, Write};
use std::os::fd::AsRawFd;
// Non-blocking so the inner drain loops terminate with WouldBlock instead
// of stalling the other direction.
let _ = client.sock.set_nonblocking(true);
let _ = worker.set_nonblocking(true);
// Poll the raw TCP fd underneath the TLS wrapper: readiness is a property
// of the socket, not of the rustls session.
let cfd = client.sock.as_raw_fd();
let wfd = worker.as_raw_fd();
let mut worker_w = worker;
let mut buf_c2w = [0u8; 8192];
let mut buf_w2c = [0u8; 8192];
loop {
let mut pfds = [
libc::pollfd {
fd: cfd,
events: libc::POLLIN,
revents: 0,
},
libc::pollfd {
fd: wfd,
events: libc::POLLIN,
revents: 0,
},
];
// 1000ms timeout: an idle connection just re-polls.
let rc = unsafe { libc::poll(pfds.as_mut_ptr(), 2, 1000) };
if rc < 0 {
let err = std::io::Error::last_os_error();
// EINTR is benign; retry the poll.
if err.raw_os_error() == Some(libc::EINTR) {
continue;
}
return;
}
if rc == 0 {
continue; }
// Drain on POLLERR/POLLHUP too, so a peer close surfaces as Ok(0) or
// Err(_) from read() below rather than being silently skipped.
let drain_mask = libc::POLLIN | libc::POLLERR | libc::POLLHUP;
if pfds[0].revents & drain_mask != 0 {
// Client -> worker: read decrypted bytes, forward raw.
loop {
match client.read(&mut buf_c2w) {
Ok(0) => return,
Ok(n) => {
if worker_w.write_all(&buf_c2w[..n]).is_err() {
return;
}
}
Err(e) if e.kind() == ErrorKind::WouldBlock => break,
Err(_) => return,
}
}
}
if pfds[1].revents & drain_mask != 0 {
// Worker -> client: raw bytes in, encrypted out.
loop {
match worker_w.read(&mut buf_w2c) {
Ok(0) => return,
Ok(n) => {
if client.write_all(&buf_w2c[..n]).is_err() {
return;
}
}
Err(e) if e.kind() == ErrorKind::WouldBlock => break,
Err(_) => return,
}
}
}
}
}
// Bidirectional byte pump between a plain-TCP client and a worker unix
// socket, used after an HTTP 101 Upgrade response. Unlike the TLS variant,
// both streams can be cloned, so each direction gets its own blocking copy
// loop: client->worker on a spawned thread, worker->client on this one.
// Either side closing (or erroring) shuts down its peer and ends the pump.
fn ws_bidi_pump_tcp(client: TcpStream, worker: UnixStream) {
    // Duplicate both handles so the two directions can run independently.
    let client_read = match client.try_clone() {
        Ok(s) => s,
        Err(_) => return,
    };
    let worker_read = match worker.try_clone() {
        Ok(s) => s,
        Err(_) => return,
    };
    // client -> worker on a dedicated thread.
    let _handle = std::thread::Builder::new()
        .name("ws-c2w".into())
        .spawn(move || {
            let mut src = client_read;
            let mut dst = worker;
            let mut chunk = [0u8; 8192];
            loop {
                let got = match src.read(&mut chunk) {
                    Ok(n) if n > 0 => n,
                    // EOF or read error: stop pumping this direction.
                    _ => break,
                };
                if dst.write_all(&chunk[..got]).is_err() {
                    break;
                }
            }
            let _ = dst.shutdown(std::net::Shutdown::Both);
        });
    // worker -> client on the current thread.
    let mut src = worker_read;
    let mut dst = client;
    let mut chunk = [0u8; 8192];
    loop {
        let got = match src.read(&mut chunk) {
            Ok(n) if n > 0 => n,
            _ => break,
        };
        if dst.write_all(&chunk[..got]).is_err() {
            break;
        }
    }
    let _ = dst.shutdown(std::net::Shutdown::Both);
}
// Build the JSON body for the /health endpoint: overall status plus the
// snapshot and live-worker counts. A worker counts as live while its unix
// socket still exists on disk.
fn build_health_body(snapshots: &Snapshots) -> String {
    let live_workers: usize = snapshots
        .values()
        .map(|s| s.workers.iter().filter(|w| w.sock_path.exists()).count())
        .sum();
    // "degraded" means the router is up but has nothing to route to.
    let status = if live_workers == 0 { "degraded" } else { "ok" };
    format!(
        "{{\"status\":\"{status}\",\"snapshots\":{},\"live_workers\":{}}}",
        snapshots.len(),
        live_workers
    )
}
// Render the Prometheus text-exposition body for /metrics: router-wide
// counters, optional hot-path profiling gauges, and per-snapshot /
// per-worker gauges. Output is one metric or comment per line with a
// trailing newline.
fn build_metrics_body(snapshots: &Snapshots, m: &Metrics) -> String {
    use std::sync::atomic::Ordering::Relaxed;
    let mut lines: Vec<String> = Vec::new();
    // Router-wide counters, each with its HELP/TYPE header.
    let counters = [
        (
            "supermachine_router_requests_total",
            "Total HTTP requests served",
            m.requests_total.load(Relaxed),
        ),
        (
            "supermachine_router_auth_failures_total",
            "Total 401 responses",
            m.auth_failures_total.load(Relaxed),
        ),
        (
            "supermachine_router_upstream_errors_total",
            "Total 502 responses",
            m.upstream_errors_total.load(Relaxed),
        ),
        (
            "supermachine_router_request_duration_us_sum",
            "Sum of request latencies in microseconds",
            m.latency_us_sum.load(Relaxed),
        ),
    ];
    for (name, help, value) in counters {
        lines.push(format!("# HELP {name} {help}"));
        lines.push(format!("# TYPE {name} counter"));
        lines.push(format!("{name} {value}"));
    }
    // The _count companion is emitted without its own HELP/TYPE block,
    // matching the original exposition exactly.
    lines.push(format!(
        "supermachine_router_request_duration_us_count {}",
        m.latency_us_count.load(Relaxed)
    ));
    if profile_enabled() {
        // Guard against divide-by-zero before any profiled request lands.
        let n = m.profile_requests.load(Relaxed).max(1);
        lines.push(
            "# HELP supermachine_router_profile_avg_us Average profiled hot-path stage duration in microseconds"
                .to_owned(),
        );
        lines.push("# TYPE supermachine_router_profile_avg_us gauge".to_owned());
        let stages = [
            ("read_request", m.profile_read_request_us.load(Relaxed)),
            ("route_auth", m.profile_route_auth_us.load(Relaxed)),
            ("open", m.profile_open_us.load(Relaxed)),
            ("upstream_write", m.profile_upstream_write_us.load(Relaxed)),
            ("upstream_read", m.profile_upstream_read_us.load(Relaxed)),
            ("client_write", m.profile_client_write_us.load(Relaxed)),
        ];
        for (stage, total) in stages {
            lines.push(format!(
                "supermachine_router_profile_avg_us{{stage=\"{stage}\"}} {:.3}",
                total as f64 / n as f64
            ));
        }
        lines.push(format!(
            "supermachine_router_profile_requests_total {}",
            m.profile_requests.load(Relaxed)
        ));
    }
    lines.push(
        "# HELP supermachine_router_workers_live Live worker count per snapshot".to_owned(),
    );
    lines.push("# TYPE supermachine_router_workers_live gauge".to_owned());
    for (name, snap) in snapshots {
        // A worker is live while its unix socket is present on disk.
        let live = snap.workers.iter().filter(|w| w.sock_path.exists()).count();
        lines.push(format!(
            "supermachine_router_workers_live{{snapshot=\"{name}\"}} {live}"
        ));
        for (i, w) in snap.workers.iter().enumerate() {
            lines.push(format!(
                "supermachine_router_worker_hits_total{{snapshot=\"{name}\",worker=\"{i}\"}} {}",
                w.hits.load(Relaxed)
            ));
        }
    }
    let mut body = lines.join("\n");
    body.push('\n');
    body
}
// Check the request's Authorization header against the snapshot's auth
// config. Returns None when the request is authorized (including when no
// auth is configured) and Some(reason) when it must be rejected with 401.
// Secrets are stored hashed: tokens/passwords are compared as
// "sha256:<hex>" digests, never in the clear.
fn check_auth(meta: &SnapshotMeta, raw_headers: &[u8]) -> Option<String> {
    // No auth config, or an explicit type "none", means open access.
    let cfg = meta.auth.as_ref()?;
    let atype = cfg.get("type").and_then(|v| v.as_str()).unwrap_or("none");
    if atype == "none" {
        return None;
    }
    // Find the Authorization header in the raw header block. Header names
    // compare case-insensitively; the first CRLF-delimited line is the
    // request line and is skipped.
    let auth_val = std::str::from_utf8(raw_headers)
        .unwrap_or("")
        .split("\r\n")
        .skip(1)
        .filter_map(|line| line.split_once(':'))
        .find(|(k, _)| k.trim().eq_ignore_ascii_case("authorization"))
        .map(|(_, v)| v.trim());
    let val = match auth_val {
        Some(v) => v,
        None => return Some(format!("missing {atype} auth header")),
    };
    match atype {
        "bearer" => {
            // Hash the presented token and look for it in cfg["tokens"].
            let presented = val.strip_prefix("Bearer ").unwrap_or("");
            let needle = format!("sha256:{}", sha256_hex(presented.as_bytes()));
            let ok = cfg
                .get("tokens")
                .and_then(|v| v.as_array())
                .map(|arr| arr.iter().any(|t| t.as_str() == Some(needle.as_str())))
                .unwrap_or(false);
            if ok {
                None
            } else {
                Some("invalid token".into())
            }
        }
        "basic" => {
            // Decode "Basic base64(user:password)" and compare the hashed
            // password against cfg["users"][user].
            let encoded = val.strip_prefix("Basic ").unwrap_or("");
            let decoded = match base64_decode(encoded) {
                Some(d) => d,
                None => return Some("malformed Basic".into()),
            };
            let text = String::from_utf8_lossy(&decoded);
            let (user, pw) = match text.split_once(':') {
                Some(x) => x,
                None => return Some("malformed Basic".into()),
            };
            let pw_digest = format!("sha256:{}", sha256_hex(pw.as_bytes()));
            let ok = cfg
                .get("users")
                .and_then(|v| v.as_object())
                .and_then(|u| u.get(user))
                .and_then(|v| v.as_str())
                == Some(pw_digest.as_str());
            if ok {
                None
            } else {
                Some("invalid credentials".into())
            }
        }
        _ => Some(format!("unknown auth type {atype:?}")),
    }
}
// Hex-encode the SHA-256 digest of `data` (lowercase, always 64 chars).
//
// Fixes: removed a vestigial WARNED AtomicBool that was loaded and stored
// but never emitted any warning (pure dead code), and replaced the
// per-byte `format!` (one temporary String per byte) with `write!` into
// the output buffer.
fn sha256_hex(data: &[u8]) -> String {
    use std::fmt::Write as _;
    let digest = sha256(data);
    let mut out = String::with_capacity(64);
    for b in digest {
        // Writing to a String is infallible.
        let _ = write!(out, "{b:02x}");
    }
    out
}
// Self-contained SHA-256 (FIPS 180-4), kept dependency-free so the router
// needs no crypto crate just to hash auth tokens. Returns the 32-byte
// digest of `msg`.
fn sha256(msg: &[u8]) -> [u8; 32] {
    // Round constants: first 32 bits of the fractional parts of the cube
    // roots of the first 64 primes.
    const K: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];
    // Initial hash state: fractional parts of the square roots of the
    // first 8 primes.
    let mut state: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];
    // Padding: append 0x80, zero-fill until length ≡ 56 (mod 64), then the
    // original bit length as a big-endian u64.
    let bit_len = (msg.len() as u64) * 8;
    let mut data = msg.to_vec();
    data.push(0x80);
    while data.len() % 64 != 56 {
        data.push(0);
    }
    data.extend_from_slice(&bit_len.to_be_bytes());
    for block in data.chunks_exact(64) {
        // Message schedule: 16 big-endian words, expanded to 64.
        let mut w = [0u32; 64];
        for (i, word) in block.chunks_exact(4).enumerate() {
            w[i] = u32::from_be_bytes([word[0], word[1], word[2], word[3]]);
        }
        for i in 16..64 {
            let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3);
            let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10);
            w[i] = w[i - 16]
                .wrapping_add(s0)
                .wrapping_add(w[i - 7])
                .wrapping_add(s1);
        }
        // Compression: 64 rounds over the working variables a..hh.
        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut hh] = state;
        for i in 0..64 {
            let big_s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let ch = (e & f) ^ (!e & g);
            let t1 = hh
                .wrapping_add(big_s1)
                .wrapping_add(ch)
                .wrapping_add(K[i])
                .wrapping_add(w[i]);
            let big_s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let maj = (a & b) ^ (a & c) ^ (b & c);
            let t2 = big_s0.wrapping_add(maj);
            hh = g;
            g = f;
            f = e;
            e = d.wrapping_add(t1);
            d = c;
            c = b;
            b = a;
            a = t1.wrapping_add(t2);
        }
        // Fold the round output back into the chaining state.
        for (s, v) in state.iter_mut().zip([a, b, c, d, e, f, g, hh]) {
            *s = s.wrapping_add(v);
        }
    }
    // Serialize the final state big-endian.
    let mut out = [0u8; 32];
    for (i, v) in state.iter().enumerate() {
        out[i * 4..i * 4 + 4].copy_from_slice(&v.to_be_bytes());
    }
    out
}
// Decode standard-alphabet base64 (RFC 4648), returning None on any
// invalid character. Padding ('=') and ASCII whitespace are ignored.
//
// Fix: previously only '\n' was tolerated between groups, so input wrapped
// with CRLF line endings (or containing spaces/tabs) was rejected; all
// ASCII whitespace is now skipped.
fn base64_decode(s: &str) -> Option<Vec<u8>> {
    // Map one base64 character to its 6-bit value.
    fn val(c: u8) -> Option<u8> {
        match c {
            b'A'..=b'Z' => Some(c - b'A'),
            b'a'..=b'z' => Some(c - b'a' + 26),
            b'0'..=b'9' => Some(c - b'0' + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }
    let bytes: Vec<u8> = s
        .bytes()
        .filter(|&b| b != b'=' && !b.is_ascii_whitespace())
        .collect();
    let mut out = Vec::with_capacity(bytes.len() * 3 / 4);
    let mut i = 0;
    // Process 4-character groups; the final group may be 2 or 3 chars
    // (1 or 2 output bytes) after padding was stripped.
    while i + 1 < bytes.len() {
        let b0 = val(bytes[i])?;
        let b1 = val(bytes[i + 1])?;
        out.push((b0 << 2) | (b1 >> 4));
        if i + 2 < bytes.len() {
            let b2 = val(bytes[i + 2])?;
            out.push(((b1 & 0x0f) << 4) | (b2 >> 2));
            if i + 3 < bytes.len() {
                let b3 = val(bytes[i + 3])?;
                out.push(((b2 & 0x03) << 6) | b3);
            }
        }
        i += 4;
    }
    Some(out)
}
// Write a complete HTTP/1.1 response with an explicit Connection header
// (keep-alive unless `close` is set).
//
// Fix: the status line previously hardcoded the reason phrase "OK" for
// every code (emitting e.g. "HTTP/1.1 503 OK"); the phrase now matches the
// code, with "OK" retained as the fallback so the common 200 path is
// byte-identical to before.
fn write_keepalive_status<S: Write>(
    stream: &mut S,
    code: u16,
    content_type: &str,
    body: &str,
    close: bool,
) -> io::Result<()> {
    let reason = match code {
        200 => "OK",
        204 => "No Content",
        400 => "Bad Request",
        401 => "Unauthorized",
        404 => "Not Found",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        _ => "OK",
    };
    let conn_hdr = if close { "close" } else { "keep-alive" };
    let resp = format!(
        "HTTP/1.1 {code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: {conn_hdr}\r\n\r\n{body}",
        body.len()
    );
    stream.write_all(resp.as_bytes())
}
// Write a 401 Unauthorized response. Basic auth gets a WWW-Authenticate
// challenge (so interactive clients can prompt for credentials); other
// auth types send no challenge header. `why` becomes the plain-text body.
fn write_auth_failure<S: Write>(stream: &mut S, atype: &str, why: &str) -> io::Result<()> {
    let challenge = match atype {
        "basic" => "WWW-Authenticate: Basic realm=\"supermachine\"\r\n",
        _ => "",
    };
    let body = format!("{why}\n");
    let resp = format!(
        "HTTP/1.1 401 Unauthorized\r\n{challenge}Content-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
        body.len()
    );
    stream.write_all(resp.as_bytes())
}
// Replace the path in an HTTP request line when it equals `old_path`,
// leaving the headers/body untouched. Any request whose first line is not
// valid UTF-8, is not "METHOD SP PATH SP VERSION", or whose path differs
// from `old_path` is returned as an unmodified copy.
fn rewrite_request_path(req: &[u8], old_path: &str, new_path: &str) -> Vec<u8> {
    // End of the request line; without a CRLF the whole buffer is the line.
    let line_end = req
        .windows(2)
        .position(|w| w == b"\r\n")
        .unwrap_or(req.len());
    let Ok(line_str) = std::str::from_utf8(&req[..line_end]) else {
        return req.to_vec();
    };
    let mut fields: Vec<&str> = line_str.splitn(3, ' ').collect();
    if fields.len() < 3 || fields[1] != old_path {
        return req.to_vec();
    }
    fields[1] = new_path;
    let rewritten = fields.join(" ");
    let mut out = Vec::with_capacity(req.len() + rewritten.len());
    out.extend_from_slice(rewritten.as_bytes());
    out.extend_from_slice(&req[line_end..]);
    out
}
// Write a minimal plain-text HTTP/1.1 error response that closes the
// connection. Only the codes this router actually emits get a proper
// reason phrase; anything else falls back to a generic "Error".
fn write_simple_status<S: Write>(stream: &mut S, code: u16, body: &str) -> io::Result<()> {
    let reason = match code {
        404 => "Not Found",
        502 => "Bad Gateway",
        _ => "Error",
    };
    stream.write_all(
        format!(
            "HTTP/1.1 {code} {reason}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
            body.len()
        )
        .as_bytes(),
    )
}
// Resolved router configuration, built from environment variables and
// command-line flags by parse_args().
struct Args {
// Directory holding baked snapshot subdirectories (each with metadata.json).
snapshots_dir: PathBuf,
// Directory where per-worker unix sockets are created.
socks_dir: PathBuf,
// Plain-HTTP listen port (bound on 127.0.0.1).
http_port: u16,
// Workers to spawn per snapshot; 0 means "use default_workers()".
workers_per_snapshot: usize,
// TLS listen port; 0 disables the TLS listener entirely.
tls_port: u16,
// Bind address for the TLS listener.
tls_bind: String,
// Default TLS certificate chain / private key (PEM files).
tls_cert: Option<PathBuf>,
tls_key: Option<PathBuf>,
// Optional directory of per-SNI-host certs (<host>/fullchain.pem + privkey.pem).
cert_dir: Option<PathBuf>,
// Path to the supermachine-worker binary to spawn.
worker_bin: PathBuf,
// Fallback snapshot name; stored into DEFAULT_SNAPSHOT at startup.
default_snapshot: Option<String>,
}
// Build the router configuration: defaults come from environment
// variables, then command-line flags override them. Exits the process on
// unknown flags (code 2) or --help (code 0).
//
// Fixes: the worker-binary search looped over the candidate list
// ["supermachine-worker", "supermachine-worker"] — the same name twice,
// an obvious copy-paste bug — now a single check; and the usage text was
// missing the --default-snapshot flag that the parser accepts.
fn parse_args() -> Args {
    let mut snapshots_dir = env::var_os("SUPERMACHINE_SNAPSHOTS")
        .map(PathBuf::from)
        .unwrap_or_else(|| {
            let h = env::var_os("HOME").unwrap_or_default();
            PathBuf::from(h).join(".local/supermachine-snapshots")
        });
    let mut http_port: u16 = 8080;
    let mut socks_dir = PathBuf::from("/tmp/supermachine-router-socks");
    // 0 is a sentinel meaning "auto"; main() substitutes default_workers().
    let mut workers_per_snapshot: usize = env::var("SUPERMACHINE_WORKERS_PER_SNAPSHOT")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(0);
    let mut tls_port: u16 = 0;
    let mut tls_bind: String = "127.0.0.1".to_owned();
    let mut tls_cert: Option<PathBuf> = None;
    let mut tls_key: Option<PathBuf> = None;
    let mut cert_dir: Option<PathBuf> = env::var_os("SUPERMACHINE_CERT_DIR").map(PathBuf::from);
    let mut default_snapshot: Option<String> = env::var("SUPERMACHINE_DEFAULT_SNAPSHOT").ok();
    // Worker binary: explicit env var wins; otherwise prefer a sibling of
    // the current executable; finally fall back to PATH lookup at spawn.
    let mut worker_bin: PathBuf = env::var_os("SUPERMACHINE_WORKER_BIN")
        .map(PathBuf::from)
        .unwrap_or_else(|| {
            let exe = std::env::current_exe().unwrap_or_default();
            if let Some(d) = exe.parent() {
                let p = d.join("supermachine-worker");
                if p.is_file() {
                    return p;
                }
            }
            PathBuf::from("supermachine-worker")
        });
    let mut args = env::args().skip(1);
    while let Some(a) = args.next() {
        match a.as_str() {
            "--snapshots-dir" => {
                snapshots_dir = PathBuf::from(args.next().expect("--snapshots-dir <path>"))
            }
            "--http-port" => {
                http_port = args.next().expect("--http-port <n>").parse().expect("u16")
            }
            "--socks-dir" => socks_dir = PathBuf::from(args.next().expect("--socks-dir <path>")),
            "--workers-per-snapshot" => {
                workers_per_snapshot = args
                    .next()
                    .expect("--workers-per-snapshot <n>")
                    .parse()
                    .expect("usize")
            }
            "--worker-bin" => {
                worker_bin = PathBuf::from(args.next().expect("--worker-bin <path>"))
            }
            "--tls-port" => tls_port = args.next().expect("--tls-port <n>").parse().expect("u16"),
            "--tls-bind" => tls_bind = args.next().expect("--tls-bind <addr>"),
            "--tls-cert" => tls_cert = Some(PathBuf::from(args.next().expect("--tls-cert <p>"))),
            "--tls-key" => tls_key = Some(PathBuf::from(args.next().expect("--tls-key <p>"))),
            "--cert-dir" => cert_dir = Some(PathBuf::from(args.next().expect("--cert-dir <p>"))),
            "--default-snapshot" => {
                default_snapshot = Some(args.next().expect("--default-snapshot <name>"));
            }
            "--help" | "-h" => {
                eprintln!(
                    "usage: supermachine-router [--snapshots-dir DIR] \
[--http-port N] [--socks-dir DIR] [--workers-per-snapshot N] [--worker-bin PATH] \
[--tls-port N] [--tls-bind ADDR] [--tls-cert PEM] [--tls-key PEM] [--cert-dir DIR] \
[--default-snapshot NAME]"
                );
                std::process::exit(0);
            }
            other => {
                eprintln!("unknown arg: {other}");
                std::process::exit(2);
            }
        }
    }
    Args {
        snapshots_dir,
        socks_dir,
        http_port,
        workers_per_snapshot,
        tls_port,
        tls_bind,
        tls_cert,
        tls_key,
        cert_dir,
        worker_bin,
        default_snapshot,
    }
}
// Process entry point. Startup sequence:
//   1. parse config, validate snapshot dir and worker binary,
//   2. discover snapshots and spawn their workers,
//   3. install a ctrl-c handler that kills workers,
//   4. start watchdog, per-snapshot health-check, and TTL-GC threads,
//   5. optionally start the TLS accept loop on its own thread,
//   6. run the plain-HTTP accept loop on this thread forever.
fn main() {
let router_start = Instant::now();
if env_truthy("SUPERMACHINE_ROUTER_TRACE") {
log("main start");
}
let a = parse_args();
// Publish the default snapshot for request routing; OnceLock::set is a
// no-op (Err) if something already set it.
let _ = DEFAULT_SNAPSHOT.set(a.default_snapshot.clone());
if !a.snapshots_dir.exists() {
eprintln!("snapshots dir {:?} missing", a.snapshots_dir);
std::process::exit(2);
}
fs::create_dir_all(&a.socks_dir).expect("mkdir socks_dir");
let metas = discover(&a.snapshots_dir);
if metas.is_empty() {
log(&format!("no snapshots in {:?}", a.snapshots_dir));
std::process::exit(1);
}
if !a.worker_bin.exists() {
eprintln!(
"missing supermachine-worker at {:?} — set --worker-bin or SUPERMACHINE_WORKER_BIN",
a.worker_bin
);
std::process::exit(2);
}
log(&format!(
"discovered {} snapshots: {:?}",
metas.len(),
metas.keys().collect::<Vec<_>>()
));
// Spawn workers for every discovered snapshot; snapshots whose workers
// all fail to start are skipped rather than aborting the router.
let mut snapshots: Snapshots = HashMap::new();
for (name, meta) in &metas {
// 0 means "auto-size" — substitute the computed default.
let workers_per_snapshot = if a.workers_per_snapshot == 0 {
default_workers()
} else {
a.workers_per_snapshot
};
let workers = spawn_snapshot_workers(
&a.worker_bin,
&a.socks_dir,
name,
meta,
workers_per_snapshot,
);
if workers.is_empty() {
log(&format!("FATAL: zero live workers for {name}; skipping"));
continue;
}
log(&format!("snapshot {name}: {} workers ready", workers.len()));
// "—" marks snapshots with no health command configured.
let initial_health = if meta.health_cmd.is_empty() {
"—".to_owned()
} else {
"starting".to_owned()
};
snapshots.insert(
name.clone(),
Arc::new(Snapshot {
meta: meta.clone(),
workers,
next_idx: std::sync::atomic::AtomicUsize::new(0),
health_status: Mutex::new(initial_health),
}),
);
}
if snapshots.is_empty() {
eprintln!("FATAL: no workers started");
std::process::exit(1);
}
let snapshots = Arc::new(snapshots);
let metrics = Arc::new(Metrics::default());
{
// Kill every worker process on ctrl-c, then exit.
let snaps = snapshots.clone();
ctrlc(move || {
log("shutting down workers");
for s in snaps.values() {
for w in &s.workers {
let mut p = w.proc.lock().unwrap();
let _ = p.kill();
}
}
std::process::exit(0);
});
}
{
// Watchdog: reaps exited workers and respawns them per restart policy.
let snaps = snapshots.clone();
let worker_bin = a.worker_bin.clone();
let socks_dir = a.socks_dir.clone();
thread::Builder::new()
.name("watchdog".into())
.spawn(move || watchdog_loop(snaps, worker_bin, socks_dir))
.ok();
}
// One health-check thread per snapshot that configured a probe command.
for (name, snap) in snapshots.iter() {
if snap.meta.health_cmd.is_empty() || snap.meta.health_interval_secs == 0 {
continue;
}
let snap_clone = snap.clone();
let socks_dir = a.socks_dir.clone();
let snap_name = name.clone();
thread::Builder::new()
.name(format!("health:{snap_name}"))
.spawn(move || health_check_loop(snap_name, snap_clone, socks_dir))
.ok();
}
{
// TTL garbage collector: removes expired snapshot directories.
let snaps = snapshots.clone();
let snapshots_dir = a.snapshots_dir.clone();
thread::Builder::new()
.name("ttl-gc".into())
.spawn(move || gc_loop(snaps, snapshots_dir))
.ok();
}
// Optional TLS listener on its own accept thread; TLS setup failure is
// logged and the router continues HTTP-only.
if a.tls_port != 0 {
match build_tls_config(&a) {
Ok(cfg) => {
let snaps = snapshots.clone();
let mtx = metrics.clone();
let bind = format!("{}:{}", a.tls_bind, a.tls_port);
let listener =
TcpListener::bind(&bind).unwrap_or_else(|e| panic!("bind {bind}: {e}"));
log(&format!("listening on https://{}/ (TLS via rustls)", bind));
let cfg = Arc::new(cfg);
thread::Builder::new()
.name("tls-accept".into())
.spawn(move || {
for tcp in listener.incoming() {
let tcp = match tcp {
Ok(t) => t,
Err(_) => continue,
};
// Low-latency proxying; 120s timeouts bound stuck peers.
let _ = tcp.set_nodelay(true);
let _ = tcp.set_read_timeout(Some(Duration::from_secs(120)));
let _ = tcp.set_write_timeout(Some(Duration::from_secs(120)));
let snaps = snaps.clone();
let mtx = mtx.clone();
let cfg = cfg.clone();
// Small (256 KiB) stacks keep many connection threads cheap.
thread::Builder::new()
.name("tls-conn".into())
.stack_size(256 * 1024)
.spawn(move || handle_tls_conn(tcp, cfg, snaps, mtx))
.ok();
}
})
.ok();
}
Err(e) => log(&format!("TLS disabled: {e}")),
}
}
// Plain-HTTP accept loop (loopback only) runs on the main thread.
let listener = TcpListener::bind(("127.0.0.1", a.http_port))
.unwrap_or_else(|e| panic!("bind 127.0.0.1:{}: {e}", a.http_port));
log(&format!(
"listening on http://127.0.0.1:{}/ (snapshots: {:?}) startup={}ms",
a.http_port,
snapshots.keys().collect::<Vec<_>>(),
router_start.elapsed().as_millis()
));
for tcp in listener.incoming() {
let tcp = match tcp {
Ok(t) => t,
Err(e) => {
log(&format!("accept: {e}"));
continue;
}
};
let snaps = snapshots.clone();
let mtx = metrics.clone();
thread::Builder::new()
.name("router-conn".into())
.stack_size(256 * 1024)
.spawn(move || handle_conn(tcp, snaps, mtx))
.ok();
}
}
// Assemble the rustls server configuration: a mandatory default identity
// (--tls-cert/--tls-key), plus either per-SNI resolution from --cert-dir
// or a single static cert when no cert dir was given.
fn build_tls_config(a: &Args) -> Result<rustls::ServerConfig, String> {
    use std::sync::Arc;
    let cert_path = a
        .tls_cert
        .as_ref()
        .ok_or("--tls-cert required for --tls-port")?;
    let key_path = a
        .tls_key
        .as_ref()
        .ok_or("--tls-key required for --tls-port")?;
    let certs = load_certs(cert_path)?;
    let key = load_key(key_path)?;
    let signing = rustls::crypto::aws_lc_rs::sign::any_supported_type(&key)
        .map_err(|e| format!("sign: {e:?}"))?;
    let default_certified = Arc::new(rustls::sign::CertifiedKey::new(certs, signing));
    // With a cert dir, resolve per-hostname certificates at handshake time,
    // falling back to the default identity; otherwise always serve it.
    let resolver: Arc<dyn rustls::server::ResolvesServerCert> = match &a.cert_dir {
        Some(dir) => Arc::new(SniResolver {
            default: default_certified,
            cert_dir: dir.clone(),
        }),
        None => Arc::new(SingleCertResolver {
            cert: default_certified,
        }),
    };
    Ok(rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_cert_resolver(resolver))
}
// Read every certificate from a PEM file (leaf plus any chain certs).
fn load_certs(path: &Path) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, String> {
    let file = fs::File::open(path).map_err(|e| format!("cert {path:?}: {e}"))?;
    let mut reader = io::BufReader::new(file);
    rustls_pemfile::certs(&mut reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| format!("parse cert {path:?}: {e}"))
}
// Read the first private key from a PEM file; errors distinguish an
// unreadable file, a parse failure, and a file with no key at all.
fn load_key(path: &Path) -> Result<rustls::pki_types::PrivateKeyDer<'static>, String> {
    let file = fs::File::open(path).map_err(|e| format!("key {path:?}: {e}"))?;
    let mut reader = io::BufReader::new(file);
    rustls_pemfile::private_key(&mut reader)
        .map_err(|e| format!("parse key {path:?}: {e}"))?
        .ok_or_else(|| format!("no private key in {path:?}"))
}
// Certificate resolver that picks a per-hostname identity from
// cert_dir/<sni-host>/{fullchain.pem,privkey.pem}, falling back to a
// default identity when no host-specific certificate is available.
struct SniResolver {
// Identity served when SNI is absent or no per-host cert loads.
default: Arc<rustls::sign::CertifiedKey>,
// Root directory of per-host certificate subdirectories.
cert_dir: PathBuf,
}
// Manual Debug: CertifiedKey does not implement Debug, so only the cert
// directory is shown.
impl std::fmt::Debug for SniResolver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "SniResolver({:?})", self.cert_dir)
}
}
impl rustls::server::ResolvesServerCert for SniResolver {
    // Resolve the identity for one handshake. Every failure short of a
    // key-type problem falls back to the default identity; certificates
    // are re-read from disk on each handshake, so renewed certs are picked
    // up without a restart.
    fn resolve(
        &self,
        ch: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        let fallback = || Some(self.default.clone());
        let Some(host) = ch.server_name() else {
            return fallback();
        };
        let host_dir = self.cert_dir.join(host);
        let cert_path = host_dir.join("fullchain.pem");
        let key_path = host_dir.join("privkey.pem");
        if !cert_path.exists() || !key_path.exists() {
            return fallback();
        }
        let Ok(certs) = load_certs(&cert_path) else {
            return fallback();
        };
        let Ok(key_der) = load_key(&key_path) else {
            return fallback();
        };
        let signing = rustls::crypto::aws_lc_rs::sign::any_supported_type(&key_der).ok()?;
        Some(Arc::new(rustls::sign::CertifiedKey::new(certs, signing)))
    }
}
// Certificate resolver that serves one fixed identity for every handshake
// (used when no --cert-dir is configured).
struct SingleCertResolver {
// The sole identity, shared across handshakes.
cert: Arc<rustls::sign::CertifiedKey>,
}
// Manual Debug: CertifiedKey does not implement Debug, so only the type
// name is shown.
impl std::fmt::Debug for SingleCertResolver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "SingleCertResolver")
}
}
impl rustls::server::ResolvesServerCert for SingleCertResolver {
fn resolve(
&self,
_ch: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
Some(self.cert.clone())
}
}
// Wrap an accepted TCP connection in a rustls server session and hand it
// to the shared request loop, with the TLS-aware WebSocket pump for
// upgraded connections.
fn handle_tls_conn(
    tcp: TcpStream,
    cfg: Arc<rustls::ServerConfig>,
    snapshots: Arc<Snapshots>,
    metrics: Arc<Metrics>,
) {
    // Raise this thread's scheduling class (no-op off macOS).
    pin_to_p_core();
    let server_conn = match rustls::ServerConnection::new(cfg) {
        Ok(c) => c,
        Err(e) => {
            log(&format!("tls: ServerConnection: {e}"));
            return;
        }
    };
    handle_stream(
        rustls::StreamOwned::new(server_conn, tcp),
        snapshots,
        metrics,
        ws_bidi_pump_tls,
    );
}
// Periodically run the snapshot's configured health command inside the
// worker and record "healthy"/"unhealthy" in snap.health_status. Runs
// forever on its own thread; probes always target worker 0's exec socket.
fn health_check_loop(name: String, snap: Arc<Snapshot>, socks_dir: PathBuf) {
    let interval = Duration::from_secs(snap.meta.health_interval_secs as u64);
    let cmd = snap.meta.health_cmd.clone();
    let exec_sock = socks_dir.join(format!("{name}-w0.sock-exec"));
    log(&format!(
        "health: snapshot {name} -> '{cmd}' every {}s",
        snap.meta.health_interval_secs
    ));
    loop {
        std::thread::sleep(interval);
        // Skip the probe while the worker socket is missing (worker down or
        // restarting) rather than reporting a spurious failure.
        if !exec_sock.exists() {
            continue;
        }
        let status = match run_health_probe(&exec_sock, &cmd) {
            Ok(0) => "healthy",
            Ok(code) => {
                log(&format!("health: {name} unhealthy (exit {code})"));
                "unhealthy"
            }
            Err(e) => {
                log(&format!("health: {name} check failed: {e}"));
                "unhealthy"
            }
        };
        // A poisoned mutex just skips the update for this round.
        if let Ok(mut guard) = snap.health_status.lock() {
            *guard = status.to_owned();
        }
    }
}
// Execute `cmd` (via /bin/sh -c) through the worker's exec socket and
// return the command's exit code.
//
// Wire format (both directions): 1-byte frame kind, u32 big-endian payload
// length, then the payload. The request payload is JSON {"argv", "tty"};
// the worker streams STDOUT/STDERR frames (discarded here) until an EXIT
// frame (4-byte big-endian exit code) or an ERROR frame (UTF-8 message).
fn run_health_probe(exec_sock: &Path, cmd: &str) -> io::Result<i32> {
use std::io::{Read, Write};
use std::os::unix::net::UnixStream;
const FRAME_REQUEST: u8 = 0xff;
const FRAME_STDOUT: u8 = 1;
const FRAME_STDERR: u8 = 2;
const FRAME_EXIT: u8 = 5;
const FRAME_ERROR: u8 = 6;
let mut sock = UnixStream::connect(exec_sock)?;
// Bound both directions so a wedged worker cannot hang the health thread.
sock.set_read_timeout(Some(Duration::from_secs(10)))?;
sock.set_write_timeout(Some(Duration::from_secs(2)))?;
let req = serde_json::json!({
"argv": ["/bin/sh", "-c", cmd],
"tty": false,
});
let body = serde_json::to_vec(&req)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("encode: {e}")))?;
// Frame header: kind byte + big-endian payload length.
let mut hdr = [0u8; 5];
hdr[0] = FRAME_REQUEST;
hdr[1..5].copy_from_slice(&(body.len() as u32).to_be_bytes());
sock.write_all(&hdr)?;
sock.write_all(&body)?;
// Read frames until the worker reports an exit code or an error.
loop {
let mut h = [0u8; 5];
if sock.read_exact(&mut h).is_err() {
return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "EOF before EXIT"));
}
let kind = h[0];
let len = u32::from_be_bytes([h[1], h[2], h[3], h[4]]) as usize;
// Cap payloads at 1 MiB to avoid unbounded allocation from a bad peer.
if len > 1 << 20 {
return Err(io::Error::new(io::ErrorKind::InvalidData, "frame too large"));
}
let mut payload = vec![0u8; len];
if !payload.is_empty() {
sock.read_exact(&mut payload)?;
}
match kind {
// Probe output is irrelevant; only the exit code matters.
FRAME_STDOUT | FRAME_STDERR => continue,
FRAME_EXIT => {
if payload.len() == 4 {
let code = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
return Ok(code as i32);
}
// Malformed EXIT payload is treated as success (code 0).
return Ok(0);
}
FRAME_ERROR => {
let msg = String::from_utf8_lossy(&payload).into_owned();
return Err(io::Error::new(io::ErrorKind::Other, msg));
}
// Unknown frame kinds are skipped for forward compatibility.
_ => continue,
}
}
}
// Background thread: every 2 seconds, reap workers that have exited and
// respawn them according to the snapshot's restart_policy ("no",
// "on-failure", "always"). Restarts are rate-limited to MAX_RESTARTS per
// WINDOW_MS sliding window per worker; past that the worker is abandoned
// until the window resets.
fn watchdog_loop(
snapshots: Arc<Snapshots>,
worker_bin: PathBuf,
socks_dir: PathBuf,
) {
loop {
std::thread::sleep(Duration::from_secs(2));
for (name, snap) in snapshots.iter() {
let policy = snap.meta.restart_policy.as_str();
for (i, w) in snap.workers.iter().enumerate() {
// Non-blocking reap; the lock is scoped so it is released
// before the (potentially slow) respawn below.
let exit_status: Option<i32> = {
let mut p = w.proc.lock().unwrap();
match p.try_wait() {
Ok(Some(st)) => Some(st.code().unwrap_or(-1)),
_ => None,
}
};
if let Some(code) = exit_status {
use std::sync::atomic::Ordering;
w.last_exit_status.store(code, Ordering::SeqCst);
} else {
// Still running — nothing to do for this worker.
continue;
}
let last_code = w.last_exit_status.load(std::sync::atomic::Ordering::SeqCst);
let should_restart = match policy {
"no" => false,
// on-failure: restart only non-zero exits.
"on-failure" => last_code != 0,
"always" => true,
other => {
log(&format!(
"watchdog: snapshot {name} unknown restart_policy {other:?}, treating as 'no'"
));
false
}
};
if !should_restart {
continue;
}
const MAX_RESTARTS: u32 = 5;
const WINDOW_MS: u64 = 60_000;
let now_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let win_start = w
.restart_window_start_ms
.load(std::sync::atomic::Ordering::SeqCst);
// Start a fresh rate-limit window when none is open or the
// current one has expired; otherwise count within it.
let count = if win_start == 0 || now_ms.saturating_sub(win_start) > WINDOW_MS {
w.restart_window_start_ms
.store(now_ms, std::sync::atomic::Ordering::SeqCst);
w.restart_count
.store(0, std::sync::atomic::Ordering::SeqCst);
0
} else {
w.restart_count
.load(std::sync::atomic::Ordering::SeqCst)
};
if count >= MAX_RESTARTS {
log(&format!(
"watchdog: snapshot {name}-w{i} exceeded {MAX_RESTARTS} restarts/{}s; giving up",
WINDOW_MS / 1000
));
continue;
}
log(&format!(
"watchdog: respawning {name}-w{i} (policy={policy}, exit={last_code}, attempt {}/{MAX_RESTARTS})",
count + 1
));
// Only successful respawns consume a slot in the window.
match respawn_worker_in_place(
&worker_bin, &socks_dir, name, i, &snap.meta, w,
) {
Ok(()) => {
w.restart_count
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
Err(e) => log(&format!("watchdog: respawn {name}-w{i} failed: {e}")),
}
}
}
}
}
/// Garbage-collector daemon: every 60 seconds, scan `snapshots_dir` and
/// delete any snapshot directory whose TTL has expired, judged by the
/// `ttl_seconds` and `baked_at` fields of its `metadata.json`.
///
/// Entries with missing/unparseable metadata, no positive TTL, or an
/// unparseable timestamp are skipped and re-examined on the next cycle.
fn gc_loop(_snapshots: Arc<Snapshots>, snapshots_dir: PathBuf) {
    loop {
        std::thread::sleep(Duration::from_secs(60));
        let dir = match fs::read_dir(&snapshots_dir) {
            Ok(d) => d,
            // Directory missing or unreadable: retry on the next cycle.
            Err(_) => continue,
        };
        for entry in dir.flatten() {
            let mp = entry.path().join("metadata.json");
            if !mp.exists() {
                continue;
            }
            let m = match read_metadata(&mp) {
                Some(m) => m,
                None => continue,
            };
            // TTL of 0 (or absent) means "never expire".
            let ttl = match m.ttl_seconds {
                Some(t) if t > 0 => t,
                _ => continue,
            };
            let baked = match m.baked_at {
                Some(s) => s,
                None => continue,
            };
            let age = match parse_iso8601_age_seconds(&baked) {
                Some(a) => a,
                None => continue,
            };
            if age < ttl {
                continue;
            }
            // Expired: remove the whole snapshot directory. Surface removal
            // failures instead of silently swallowing them (and do not claim
            // "expired" success when the delete failed) — the entry will be
            // retried on the next GC cycle.
            match fs::remove_dir_all(entry.path()) {
                Ok(()) => log(&format!(
                    "gc: expired {} (age {}s > ttl {}s)",
                    m.name, age, ttl
                )),
                Err(e) => log(&format!(
                    "gc: failed to remove expired {}: {e}",
                    m.name
                )),
            }
        }
    }
}
/// Parses a UTC ISO-8601 timestamp like `2024-05-01T12:30:00Z` (fractional
/// seconds are ignored) and returns its age in whole seconds relative to now.
///
/// Returns `Some(0)` for timestamps in the future and `None` for strings
/// that are malformed or contain out-of-range fields (month, day, hour,
/// minute, second) or a year outside 0..=9999 (guards against i64 overflow
/// in the day arithmetic). Date-to-epoch conversion uses Howard Hinnant's
/// `days_from_civil` algorithm (epoch shift constant 719468 = days from
/// 0000-03-01 to 1970-01-01).
fn parse_iso8601_age_seconds(s: &str) -> Option<u64> {
    let s = s.trim_end_matches('Z');
    let (date, time) = s.split_once('T')?;
    let mut d = date.split('-');
    let y: i64 = d.next()?.parse().ok()?;
    let mo: i64 = d.next()?.parse().ok()?;
    let da: i64 = d.next()?.parse().ok()?;
    let mut t = time.split(':');
    let h: i64 = t.next()?.parse().ok()?;
    let mi: i64 = t.next()?.parse().ok()?;
    // Drop any fractional-second suffix before parsing.
    let se: i64 = t.next().and_then(|x| x.split('.').next()?.parse().ok())?;
    // Reject out-of-range fields rather than computing a bogus age.
    // se == 60 is allowed for leap seconds.
    if !(0..=9999).contains(&y)
        || !(1..=12).contains(&mo)
        || !(1..=31).contains(&da)
        || !(0..=23).contains(&h)
        || !(0..=59).contains(&mi)
        || !(0..=60).contains(&se)
    {
        return None;
    }
    // days_from_civil: shift the year so the "year" starts on March 1,
    // putting the leap day last.
    let yp = if mo <= 2 { y - 1 } else { y };
    let era = (if yp >= 0 { yp } else { yp - 399 }) / 400;
    let yoe = (yp - era * 400) as u64; // year-of-era, 0..=399
    let doy = ((153 * (if mo > 2 { mo - 3 } else { mo + 9 }) + 2) / 5 + da - 1) as u64;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; // day-of-era
    let days = era * 146097 + doe as i64 - 719468; // days since Unix epoch
    let total = days * 86400 + h * 3600 + mi * 60 + se;
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .ok()?
        .as_secs() as i64;
    // Clamp future timestamps to age 0.
    Some((now - total).max(0) as u64)
}
/// Installs SIGINT/SIGTERM handlers and invokes `f` exactly once after the
/// first signal arrives.
///
/// The C handler is async-signal-safe: it only sets an atomic flag. A
/// background thread polls the flag every 100ms, calls `f` once, and exits;
/// later signals set the flag again but have no further effect.
///
/// Fix over the previous version: removed the unused `armed`/`armed_c`/
/// `armed_term`/`f_term` clones and the `Arc` wrapping of `f` — the single
/// poller thread already guarantees at-most-once invocation via its
/// `return`, so the extra guard state was dead code.
fn ctrlc<F: Fn() + Send + Sync + 'static>(f: F) {
    static SIGNALED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
    extern "C" fn handler(_: libc::c_int) {
        SIGNALED.store(true, std::sync::atomic::Ordering::SeqCst);
    }
    // SAFETY: we install a handler that performs only an atomic store,
    // which is async-signal-safe; the zeroed sigaction leaves sa_mask and
    // sa_flags empty, matching the plain (non-SA_SIGINFO) handler signature.
    unsafe {
        let mut sa: libc::sigaction = std::mem::zeroed();
        sa.sa_sigaction = handler as usize;
        let _ = libc::sigaction(libc::SIGINT, &sa, std::ptr::null_mut());
        let _ = libc::sigaction(libc::SIGTERM, &sa, std::ptr::null_mut());
    }
    std::thread::spawn(move || loop {
        if SIGNALED.load(std::sync::atomic::Ordering::SeqCst) {
            f();
            return;
        }
        std::thread::sleep(Duration::from_millis(100));
    });
}