use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use libgrite_ipc::{
messages::{ArchivedIpcRequest, IpcRequest, IpcResponse},
IpcCommand, Notification, IPC_SCHEMA_VERSION,
};
use nng::{options::Options, Message, Protocol, Socket};
use tokio::sync::{mpsc, oneshot, RwLock};
use tracing::{debug, error, info, warn};
use crate::error::DaemonError;
use crate::worker::{Worker, WorkerMessage};
/// Per-worker bookkeeping held by the supervisor: the sender half of the
/// worker's message channel plus the identity the worker was created with.
struct WorkerHandle {
    // Channel used to deliver Command / Heartbeat / Shutdown messages.
    tx: mpsc::Sender<WorkerMessage>,
    // Repository root this worker was created for.
    repo_root: PathBuf,
    // Actor identity this worker acts as.
    actor_id: String,
    // Data directory handed to the worker at creation time.
    data_dir: PathBuf,
}
/// HashMap key identifying a worker: exactly one worker exists per
/// (repo_root, actor_id) pair.
#[derive(Hash, Eq, PartialEq, Clone)]
struct WorkerKey {
    // Stored as String (built via lossy path conversion in `create_worker`)
    // so the key is hashable and comparable.
    repo_root: String,
    actor_id: String,
}
/// Daemon supervisor: owns the IPC endpoint, the worker pool, the
/// notification channel, and the idle-shutdown bookkeeping. Consumed by
/// `run`, which drives the whole request/reply loop.
pub struct Supervisor {
    // Random per-process identifier (UUID v4 generated in `new`).
    daemon_id: String,
    // Host identifier from `get_host_id` (env var / hostname file / UUID).
    host_id: String,
    // nng REP endpoint; the PUB endpoint is derived as "<endpoint>-pub".
    ipc_endpoint: String,
    // Live workers keyed by (repo_root, actor_id); shared with the
    // heartbeat task spawned in `run`.
    workers: Arc<RwLock<HashMap<WorkerKey, WorkerHandle>>>,
    // Receiver drained by the notification-forwarding task in `run`.
    notify_rx: mpsc::Receiver<Notification>,
    // Sender cloned into every worker so it can publish notifications.
    notify_tx: mpsc::Sender<Notification>,
    // Shutdown broadcast sender; `None` until `run` installs it.
    shutdown_tx: Option<tokio::sync::broadcast::Sender<()>>,
    // `None` disables idle-based automatic shutdown.
    idle_timeout: Option<Duration>,
    // Milliseconds elapsed since `start_instant` at the last activity.
    last_activity_ms: Arc<AtomicU64>,
    // Reference point for all activity/idle timestamps.
    start_instant: Instant,
}
impl Supervisor {
/// Builds a supervisor bound (later, in `run`) to `ipc_endpoint`.
///
/// `idle_timeout` of `None` disables idle-based shutdown. The notification
/// channel (capacity 1000) is created here: the sender side is cloned into
/// each worker, the receiver side is drained by the PUB-socket task that
/// `run` spawns.
pub fn new(ipc_endpoint: String, idle_timeout: Option<Duration>) -> Self {
    let (notify_tx, notify_rx) = mpsc::channel(1000);
    Self {
        daemon_id: uuid::Uuid::new_v4().to_string(),
        host_id: get_host_id(),
        ipc_endpoint,
        workers: Arc::new(RwLock::new(HashMap::new())),
        notify_tx,
        notify_rx,
        shutdown_tx: None,
        idle_timeout,
        last_activity_ms: Arc::new(AtomicU64::new(0)),
        start_instant: Instant::now(),
    }
}
/// Records "now" (milliseconds since supervisor start) as the moment of
/// last activity, resetting the idle-shutdown clock.
fn touch_activity(&self) {
    let now_ms = self.start_instant.elapsed().as_millis() as u64;
    self.last_activity_ms.store(now_ms, Ordering::Relaxed);
}
/// Returns `true` when an idle timeout is configured and at least that much
/// time has passed since the last recorded activity. Always `false` when no
/// timeout is set.
fn is_idle_timeout(&self) -> bool {
    self.idle_timeout.map_or(false, |timeout| {
        let now_ms = self.start_instant.elapsed().as_millis() as u64;
        let last_ms = self.last_activity_ms.load(Ordering::Relaxed);
        // saturating_sub guards against the (unlikely) case where the last
        // recorded activity is ahead of the current reading.
        now_ms.saturating_sub(last_ms) >= timeout.as_millis() as u64
    })
}
/// Main supervisor loop.
///
/// Binds the REP socket for request/reply IPC and a best-effort PUB socket
/// for notifications, spawns two background tasks (worker heartbeat + idle
/// watchdog, and notification forwarding), then serves requests until a
/// shutdown signal is broadcast. On exit, asks all workers to shut down.
///
/// # Errors
/// Returns `DaemonError` if socket creation, option setting, or binding the
/// REP endpoint fails; everything after startup is handled in-loop.
pub async fn run(mut self) -> Result<(), DaemonError> {
    info!(
        daemon_id = %self.daemon_id,
        endpoint = %self.ipc_endpoint,
        idle_timeout_secs = ?self.idle_timeout.map(|d| d.as_secs()),
        "Supervisor starting"
    );
    // Count startup itself as activity so the idle timeout is measured
    // from now rather than from Instant zero.
    self.touch_activity();
    let (shutdown_tx, _) = tokio::sync::broadcast::channel::<()>(1);
    self.shutdown_tx = Some(shutdown_tx.clone());
    let rep_socket = Socket::new(Protocol::Rep0)?;
    // 100 ms receive timeout so the blocking recv in the main loop wakes
    // regularly and the select! can observe the shutdown broadcast.
    rep_socket
        .set_opt::<nng::options::RecvTimeout>(Some(Duration::from_millis(100)))
        .map_err(|e| DaemonError::BindFailed(e.to_string()))?;
    rep_socket
        .listen(&self.ipc_endpoint)
        .map_err(|e| DaemonError::BindFailed(format!("Failed to bind to {}: {}", self.ipc_endpoint, e)))?;
    info!("Listening on {}", self.ipc_endpoint);
    // Notifications are published on a sibling endpoint ("<endpoint>-pub").
    let pub_endpoint = format!("{}-pub", self.ipc_endpoint);
    let pub_socket = Socket::new(Protocol::Pub0)?;
    // NOTE(review): a PUB-socket bind failure is silently ignored here, so
    // notifications would be dropped with no log line — confirm this
    // best-effort behavior is intentional.
    let _ = pub_socket.listen(&pub_endpoint);
    // --- Heartbeat + idle-watchdog task -----------------------------------
    // Every 10 s: ping all workers, and (when configured) broadcast
    // shutdown once no activity has been recorded for `idle_timeout`.
    let workers_clone = self.workers.clone();
    let last_activity_ms = self.last_activity_ms.clone();
    let idle_timeout = self.idle_timeout;
    let start_instant = self.start_instant;
    let idle_shutdown_tx = shutdown_tx.clone();
    let mut heartbeat_shutdown = shutdown_tx.subscribe();
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_secs(10));
        loop {
            tokio::select! {
                _ = interval.tick() => {
                    let workers = workers_clone.read().await;
                    for handle in workers.values() {
                        // Best-effort: a dead worker just misses the ping.
                        let _ = handle.tx.send(WorkerMessage::Heartbeat).await;
                    }
                    if let Some(timeout) = idle_timeout {
                        let last_ms = last_activity_ms.load(Ordering::Relaxed);
                        let now_ms = start_instant.elapsed().as_millis() as u64;
                        let idle_ms = now_ms.saturating_sub(last_ms);
                        if idle_ms >= timeout.as_millis() as u64 {
                            info!("Idle timeout reached ({} ms), shutting down", idle_ms);
                            let _ = idle_shutdown_tx.send(());
                            break;
                        }
                    }
                }
                _ = heartbeat_shutdown.recv() => {
                    break;
                }
            }
        }
    });
    // --- Notification-forwarding task -------------------------------------
    // Move the notification receiver out of `self` (replacing it with a
    // throwaway capacity-1 channel) so the task below can own it.
    let mut notify_rx = std::mem::replace(
        &mut self.notify_rx,
        mpsc::channel(1).1, );
    let mut pub_shutdown = shutdown_tx.subscribe();
    tokio::spawn(async move {
        loop {
            tokio::select! {
                Some(notification) = notify_rx.recv() => {
                    // Serialize with rkyv and publish; send failures are
                    // tolerated (pub/sub is fire-and-forget).
                    if let Ok(bytes) = rkyv::to_bytes::<rkyv::rancor::Error>(&notification) {
                        let msg = Message::from(bytes.as_slice());
                        let _ = pub_socket.send(msg);
                    }
                }
                _ = pub_shutdown.recv() => {
                    break;
                }
            }
        }
    });
    // --- Main request loop -------------------------------------------------
    // nng's recv is blocking, so it runs on the blocking thread pool; the
    // 100 ms RecvTimeout set above keeps this loop responsive to shutdown.
    let mut shutdown_rx = shutdown_tx.subscribe();
    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                info!("Shutdown signal received");
                break;
            }
            result = tokio::task::spawn_blocking({
                let socket = rep_socket.clone();
                move || socket.recv()
            }) => {
                match result {
                    Ok(Ok(msg)) => {
                        self.touch_activity();
                        let response = self.handle_request(&msg).await;
                        // REP sockets require a reply per request; if
                        // serialization fails the peer will time out.
                        if let Ok(bytes) = rkyv::to_bytes::<rkyv::rancor::Error>(&response) {
                            let reply = Message::from(bytes.as_slice());
                            if let Err(e) = rep_socket.send(reply) {
                                warn!("Failed to send response: {:?}", e);
                            }
                        }
                    }
                    // Expected every 100 ms while no request is pending.
                    Ok(Err(nng::Error::TimedOut)) => {
                        continue;
                    }
                    Ok(Err(e)) => {
                        warn!("Receive error: {}", e);
                    }
                    Err(e) => {
                        error!("Task join error: {}", e);
                    }
                }
            }
        }
    }
    self.shutdown_workers().await;
    info!("Supervisor stopped");
    Ok(())
}
async fn handle_request(&self, msg: &Message) -> IpcResponse {
let archived = match rkyv::access::<ArchivedIpcRequest, rkyv::rancor::Error>(msg) {
Ok(a) => a,
Err(e) => {
return IpcResponse::error(
"unknown".to_string(),
"deserialization".to_string(),
format!("Failed to deserialize request: {}", e),
);
}
};
let version: u32 = archived.ipc_schema_version.into();
if version != IPC_SCHEMA_VERSION {
return IpcResponse::error(
archived.request_id.to_string(),
"version_mismatch".to_string(),
format!("Expected version {}, got {}", IPC_SCHEMA_VERSION, version),
);
}
let request: IpcRequest = match rkyv::deserialize::<IpcRequest, rkyv::rancor::Error>(archived) {
Ok(r) => r,
Err(e) => {
return IpcResponse::error(
archived.request_id.to_string(),
"deserialization".to_string(),
format!("Failed to deserialize request: {}", e),
);
}
};
debug!(
request_id = %request.request_id,
repo = %request.repo_root,
actor = %request.actor_id,
"Handling request"
);
if matches!(request.command, IpcCommand::DaemonStop) {
if let Some(ref tx) = self.shutdown_tx {
let _ = tx.send(());
}
return IpcResponse::success(
request.request_id,
Some(serde_json::json!({"stopping": true}).to_string()),
);
}
self.route_to_worker(request).await
}
/// Routes `request` to the worker owning its (repo_root, actor_id) pair,
/// creating the worker on demand.
///
/// Errors are returned to the client as `IpcResponse::error` with codes:
/// - `worker_creation_failed` — the worker could not be constructed;
/// - `worker_unavailable` — the worker's channel is closed (the stale
///   handle is evicted so a later request can respawn the worker);
/// - `worker_error` — the worker dropped its response channel;
/// - `timeout` — no answer within 30 s.
async fn route_to_worker(&self, request: IpcRequest) -> IpcResponse {
    let key = WorkerKey {
        repo_root: request.repo_root.clone(),
        actor_id: request.actor_id.clone(),
    };
    // Fast path: under the read lock, clone the sender of a live worker.
    let tx = {
        let workers = self.workers.read().await;
        workers.get(&key).map(|h| h.tx.clone())
    };
    let tx = match tx {
        Some(tx) => tx,
        None => {
            match self
                .create_worker(
                    PathBuf::from(&request.repo_root),
                    request.actor_id.clone(),
                    PathBuf::from(&request.data_dir),
                )
                .await
            {
                Ok(tx) => tx,
                Err(e) => {
                    return IpcResponse::error(
                        request.request_id,
                        "worker_creation_failed".to_string(),
                        e.to_string(),
                    );
                }
            }
        }
    };
    let (response_tx, response_rx) = oneshot::channel();
    let msg = WorkerMessage::Command {
        request_id: request.request_id.clone(),
        command: request.command,
        response_tx,
    };
    if tx.send(msg).await.is_err() {
        // Bug fix: the worker's receiver is gone (its task exited), but the
        // dead handle previously stayed in the map, so every later request
        // for this key failed with "worker_unavailable" forever. Evict the
        // stale entry — guarded by `same_channel` so a freshly respawned
        // worker registered concurrently is not removed by mistake.
        let mut workers = self.workers.write().await;
        let is_same_dead_worker = workers
            .get(&key)
            .map(|h| h.tx.same_channel(&tx))
            .unwrap_or(false);
        if is_same_dead_worker {
            workers.remove(&key);
        }
        return IpcResponse::error(
            request.request_id,
            "worker_unavailable".to_string(),
            "Worker channel closed".to_string(),
        );
    }
    // Bound worker latency so a hung worker cannot stall the IPC loop.
    match tokio::time::timeout(Duration::from_secs(30), response_rx).await {
        Ok(Ok(response)) => response,
        Ok(Err(_)) => IpcResponse::error(
            request.request_id,
            "worker_error".to_string(),
            "Worker response channel dropped".to_string(),
        ),
        Err(_) => IpcResponse::error(
            request.request_id,
            "timeout".to_string(),
            "Worker timed out".to_string(),
        ),
    }
}
/// Creates (or returns the existing) worker for the given identity and
/// registers it in the worker map.
///
/// The write lock is held across the existence check, worker construction,
/// task spawn, and insertion. Previously the check ran under a read lock
/// that was released before the insert, so two concurrent requests for the
/// same key could both spawn a worker and the second insert would orphan
/// the first worker task (bug fix: classic check-then-act race).
///
/// # Errors
/// Propagates `DaemonError` from `Worker::new`.
async fn create_worker(
    &self,
    repo_root: PathBuf,
    actor_id: String,
    data_dir: PathBuf,
) -> Result<mpsc::Sender<WorkerMessage>, DaemonError> {
    let key = WorkerKey {
        repo_root: repo_root.to_string_lossy().to_string(),
        actor_id: actor_id.clone(),
    };
    let mut workers = self.workers.write().await;
    if let Some(handle) = workers.get(&key) {
        // Another caller already created this worker; reuse its channel.
        return Ok(handle.tx.clone());
    }
    let (tx, rx) = mpsc::channel(100);
    let worker = Worker::new(
        repo_root.clone(),
        actor_id.clone(),
        data_dir.clone(),
        rx,
        self.notify_tx.clone(),
        self.host_id.clone(),
        self.ipc_endpoint.clone(),
    )?;
    // Each worker gets its own blocking thread with a current-thread
    // runtime, so blocking work inside the worker cannot stall the
    // supervisor's async tasks.
    tokio::task::spawn_blocking(move || {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("current-thread runtime construction should not fail");
        rt.block_on(worker.run());
    });
    workers.insert(
        key,
        WorkerHandle {
            tx: tx.clone(),
            repo_root,
            actor_id,
            data_dir,
        },
    );
    Ok(tx)
}
/// Asks every registered worker to shut down, then waits a short grace
/// period for them to finish.
async fn shutdown_workers(&self) {
    let workers = self.workers.read().await;
    for (_key, handle) in workers.iter() {
        // Best-effort: a closed channel means the worker is already gone.
        let _ = handle.tx.send(WorkerMessage::Shutdown).await;
    }
    // Grace period for workers to drain and exit before the process ends.
    tokio::time::sleep(Duration::from_millis(500)).await;
}
}
/// Best-effort stable host identifier.
///
/// Tries, in order: the `HOSTNAME` environment variable, `/etc/hostname`,
/// and finally a random UUID. Values are trimmed, and empty results are
/// treated as missing so the fallback chain continues (bug fix: previously
/// an empty or whitespace-only `HOSTNAME` was returned verbatim and the
/// env value was never trimmed).
fn get_host_id() -> String {
    let non_empty = |s: String| {
        let trimmed = s.trim().to_string();
        (!trimmed.is_empty()).then_some(trimmed)
    };
    std::env::var("HOSTNAME")
        .ok()
        .and_then(non_empty)
        .or_else(|| {
            std::fs::read_to_string("/etc/hostname")
                .ok()
                .and_then(non_empty)
        })
        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
}