use crate::config::ModelConfig;
use crate::switcher::SleepLevel;
use anyhow::Result;
use dashmap::DashMap;
use std::collections::HashMap;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::{Mutex, Notify};
use tracing::{debug, info, warn};
#[cfg(unix)]
fn kill_process_group(pid: u32) {
unsafe {
libc::kill(-(pid as libc::pid_t), libc::SIGKILL);
}
}
/// Lifecycle state of one managed vLLM server process.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProcessState {
    /// Nothing has been spawned yet, or the previous process was stopped or found dead.
    NotStarted,
    /// A process has been spawned and has not yet passed its health check.
    Starting,
    /// The server passed its health check. `sleeping` is `Some(level)` while it
    /// has been put to sleep at `level`, and `None` while it is awake.
    Running { sleeping: Option<SleepLevel> },
    /// Startup failed; `reason` is a human-readable explanation.
    Failed { reason: String },
}
/// Per-model runtime bookkeeping, always accessed behind an `Arc<Mutex<_>>`.
struct ManagedProcess {
    /// Copy of the model's configuration. The orchestrator in this file reads its
    /// own `configs` map instead, so this field is currently unread; the lint is
    /// silenced rather than dropping the field.
    #[allow(dead_code)]
    config: ModelConfig,
    /// Current lifecycle state.
    state: ProcessState,
    /// Handle to the spawned server, while one is tracked.
    child: Option<Child>,
    /// Used to wake tasks that are waiting for an in-progress startup
    /// (`ProcessState::Starting`) to finish.
    ready_notify: Arc<Notify>,
}
/// Removes ANSI / ECMA-48 escape sequences from a line of process output.
///
/// Recognised forms:
/// - CSI: `ESC [`, then parameter/intermediate bytes, then one final byte in
///   `@`..=`~`. Examples: SGR colours `ESC[31m`, cursor moves `ESC[A`, keys `ESC[3~`.
/// - OSC: `ESC ]` ... terminated by BEL or ST (`ESC \`). Examples: titles, hyperlinks.
/// - nF escapes: `ESC`, intermediate bytes (` `..=`/`), then one final byte.
///   Example: `ESC ( B`.
/// - Any other two-character `ESC x` escape.
///
/// A truncated sequence at the end of the input is dropped.
fn strip_ansi(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars().peekable();
    while let Some(c) = chars.next() {
        if c != '\x1b' {
            out.push(c);
            continue;
        }
        match chars.next() {
            Some('[') => {
                // Parameter (0x30-0x3F) and intermediate (0x20-0x2F) bytes continue
                // until the final byte. Stopping only at letters would swallow text
                // after sequences ending in `~`, `@`, etc.
                for c2 in chars.by_ref() {
                    if ('\x40'..='\x7e').contains(&c2) {
                        break;
                    }
                }
            }
            Some(']') => loop {
                match chars.next() {
                    None | Some('\x07') => break,
                    Some('\x1b') => {
                        // String terminator is `ESC \`.
                        if chars.peek() == Some(&'\\') {
                            chars.next();
                        }
                        break;
                    }
                    Some(_) => {}
                }
            },
            Some(' '..='/') => {
                for c2 in chars.by_ref() {
                    if ('\x30'..='\x7e').contains(&c2) {
                        break;
                    }
                }
            }
            // Two-character escapes such as `ESC 7` / `ESC c`, or a lone trailing ESC.
            Some(_) | None => {}
        }
    }
    out
}
/// Spawns, health-checks, sleeps, wakes and stops one vLLM server per configured model.
pub struct Orchestrator {
    /// Static per-model configuration, keyed by model name.
    configs: HashMap<String, ModelConfig>,
    /// Runtime state for every model in `configs`; filled once at construction.
    processes: DashMap<String, Arc<Mutex<ManagedProcess>>>,
    /// Held around process startup so that only one model is launched at a time.
    operation_lock: Mutex<()>,
    /// Timeout applied to each `/health` probe.
    health_timeout: Duration,
    /// Maximum time a freshly spawned server has to become healthy.
    startup_timeout: Duration,
    /// Executable used to launch servers.
    vllm_command: String,
}
impl Orchestrator {
/// Creates an orchestrator that launches servers with the `vllm` executable
/// (resolved the usual way by `Command`, i.e. via `PATH`).
pub fn new(configs: HashMap<String, ModelConfig>) -> Self {
    let default_command = String::from("vllm");
    Self::with_command(configs, default_command)
}
pub fn with_command(configs: HashMap<String, ModelConfig>, vllm_command: String) -> Self {
let processes = DashMap::new();
for (name, config) in &configs {
processes.insert(
name.clone(),
Arc::new(Mutex::new(ManagedProcess {
config: config.clone(),
state: ProcessState::NotStarted,
child: None,
ready_notify: Arc::new(Notify::new()),
})),
);
}
Self {
configs,
processes,
operation_lock: Mutex::new(()),
health_timeout: Duration::from_secs(5),
startup_timeout: Duration::from_secs(300), vllm_command,
}
}
pub async fn process_state(&self, model: &str) -> Option<ProcessState> {
let process = self.processes.get(model)?;
let guard = process.lock().await;
Some(guard.state.clone())
}
/// Names of all configured models, in unspecified (hash-map iteration) order.
pub fn registered_models(&self) -> Vec<String> {
    self.configs.keys().map(String::clone).collect()
}
/// The raw `sleep_level` configured for `model`, or `None` if the model is unknown.
pub fn sleep_level_for(&self, model: &str) -> Option<u8> {
    let config = self.configs.get(model)?;
    Some(config.sleep_level)
}
pub async fn ensure_running(&self, model: &str) -> Result<(), OrchestratorError> {
let process = self
.processes
.get(model)
.ok_or_else(|| OrchestratorError::ModelNotFound(model.to_string()))?;
{
let guard = process.lock().await;
match &guard.state {
ProcessState::Running { sleeping: None } => {
return Ok(());
}
ProcessState::Running { sleeping: Some(_) } => {
return Ok(());
}
ProcessState::Starting => {
let notify = guard.ready_notify.clone();
drop(guard);
notify.notified().await;
return Ok(());
}
ProcessState::Failed { reason } => {
return Err(OrchestratorError::ProcessFailed {
model: model.to_string(),
reason: reason.clone(),
});
}
ProcessState::NotStarted => {
}
}
}
let _op_guard = self.operation_lock.lock().await;
{
let guard = process.lock().await;
if matches!(
guard.state,
ProcessState::Running { .. } | ProcessState::Starting
) {
return Ok(());
}
}
self.start_process_internal(model, &process).await
}
async fn start_process_internal(
&self,
model: &str,
process: &Arc<Mutex<ManagedProcess>>,
) -> Result<(), OrchestratorError> {
let config = self
.configs
.get(model)
.ok_or_else(|| OrchestratorError::ModelNotFound(model.to_string()))?;
info!(model = %model, port = config.port, "Starting vLLM process");
{
let mut guard = process.lock().await;
guard.state = ProcessState::Starting;
}
let args = config.vllm_args();
debug!(model = %model, args = ?args, "vLLM command args");
let mut child = Command::new(&self.vllm_command)
.args(&args)
.env("VLLM_SERVER_DEV_MODE", "1") .env("NO_COLOR", "1") .stdout(Stdio::piped())
.stderr(Stdio::piped())
.process_group(0)
.spawn()
.map_err(|e| OrchestratorError::SpawnFailed {
model: model.to_string(),
reason: e.to_string(),
})?;
{
let model_name = model.to_string();
if let Some(stdout) = child.stdout.take() {
let name = model_name.clone();
tokio::spawn(async move {
let reader = BufReader::new(stdout);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
let clean = strip_ansi(&line);
debug!(target: "vllm", model = %name, stream = "stdout", "{}", clean);
}
});
}
if let Some(stderr) = child.stderr.take() {
let name = model_name;
tokio::spawn(async move {
let reader = BufReader::new(stderr);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
let clean = strip_ansi(&line);
debug!(target: "vllm", model = %name, stream = "stderr", "{}", clean);
}
});
}
}
{
let mut guard = process.lock().await;
guard.child = Some(child);
}
let health_url = format!("http://localhost:{}/health", config.port);
let start = std::time::Instant::now();
loop {
if start.elapsed() > self.startup_timeout {
let mut guard = process.lock().await;
guard.state = ProcessState::Failed {
reason: "Startup timeout".to_string(),
};
if let Some(ref mut child) = guard.child {
let _ = child.kill().await;
}
return Err(OrchestratorError::StartupTimeout {
model: model.to_string(),
});
}
match self.check_health(&health_url).await {
Ok(true) => {
info!(model = %model, "vLLM process is ready");
let mut guard = process.lock().await;
guard.state = ProcessState::Running { sleeping: None };
guard.ready_notify.notify_waiters();
return Ok(());
}
Ok(false) => {
debug!(model = %model, "Health check returned unhealthy, retrying...");
}
Err(e) => {
debug!(model = %model, error = %e, "Health check failed, retrying...");
}
}
{
let mut guard = process.lock().await;
if let Some(ref mut child) = guard.child {
match child.try_wait() {
Ok(Some(status)) => {
let reason = format!("Process exited with status: {}", status);
guard.state = ProcessState::Failed {
reason: reason.clone(),
};
return Err(OrchestratorError::ProcessFailed {
model: model.to_string(),
reason,
});
}
Ok(None) => {
}
Err(e) => {
warn!(model = %model, error = %e, "Failed to check process status");
}
}
}
}
tokio::time::sleep(Duration::from_millis(500)).await;
}
}
/// Sends one `GET` to `url` and reports whether the reply had a 2xx status.
///
/// # Errors
/// Returns a description if the URL is invalid, the request cannot be built or
/// sent, or no response arrives within `health_timeout`.
async fn check_health(&self, url: &str) -> Result<bool, String> {
    use http_body_util::Empty;
    use hyper_util::client::legacy::Client;
    use hyper_util::rt::TokioExecutor;

    let uri = url
        .parse::<hyper::Uri>()
        .map_err(|e| format!("Invalid URL: {}", e))?;
    let request = hyper::Request::builder()
        .method("GET")
        .uri(uri)
        .body(Empty::<bytes::Bytes>::new())
        .map_err(|e| format!("Failed to build request: {}", e))?;
    let client: Client<_, Empty<bytes::Bytes>> =
        Client::builder(TokioExecutor::new()).build_http();
    match tokio::time::timeout(self.health_timeout, client.request(request)).await {
        Err(_) => Err("Health check timeout".to_string()),
        Ok(Err(e)) => Err(format!("Request failed: {}", e)),
        Ok(Ok(response)) => Ok(response.status().is_success()),
    }
}
async fn check_process_alive(&self, model: &str) {
let Some(process) = self.processes.get(model) else {
return;
};
let mut guard = process.lock().await;
if let Some(ref mut child) = guard.child {
match child.try_wait() {
Ok(Some(status)) => {
warn!(
model = %model,
status = %status,
"Process found dead, resetting to NotStarted"
);
guard.child = None;
guard.state = ProcessState::NotStarted;
}
Ok(None) => {
}
Err(e) => {
warn!(model = %model, error = %e, "Failed to check process status");
}
}
}
}
pub async fn wake_model(&self, model: &str) -> Result<(), OrchestratorError> {
self.check_process_alive(model).await;
self.ensure_running(model).await?;
let process = self
.processes
.get(model)
.ok_or_else(|| OrchestratorError::ModelNotFound(model.to_string()))?;
let config = self
.configs
.get(model)
.ok_or_else(|| OrchestratorError::ModelNotFound(model.to_string()))?;
let actual_sleep_level = {
let guard = process.lock().await;
match &guard.state {
ProcessState::Running { sleeping: None } => return Ok(()),
ProcessState::Running {
sleeping: Some(level),
} => *level,
_ => {
return Ok(());
}
}
};
info!(model = %model, "Waking model");
let base_url = format!("http://localhost:{}", config.port);
self.post_request(
&format!("{}/wake_up", base_url),
None,
Duration::from_secs(30),
)
.await
.map_err(|e| OrchestratorError::WakeFailed {
model: model.to_string(),
reason: e,
})?;
if actual_sleep_level == SleepLevel::L2 {
debug!(model = %model, "L2 sleep: reloading weights");
self.post_request(
&format!("{}/collective_rpc", base_url),
Some(r#"{"method": "reload_weights"}"#),
Duration::from_secs(60),
)
.await
.map_err(|e| OrchestratorError::WakeFailed {
model: model.to_string(),
reason: e,
})?;
self.post_request(
&format!("{}/reset_prefix_cache", base_url),
None,
Duration::from_secs(30),
)
.await
.map_err(|e| {
warn!(model = %model, error = %e, "Failed to reset prefix cache");
})
.ok();
}
{
let mut guard = process.lock().await;
guard.state = ProcessState::Running { sleeping: None };
}
info!(model = %model, "Model is now awake");
Ok(())
}
pub async fn sleep_model(
&self,
model: &str,
level: SleepLevel,
) -> Result<(), OrchestratorError> {
self.check_process_alive(model).await;
let process = self
.processes
.get(model)
.ok_or_else(|| OrchestratorError::ModelNotFound(model.to_string()))?;
let config = self
.configs
.get(model)
.ok_or_else(|| OrchestratorError::ModelNotFound(model.to_string()))?;
{
let guard = process.lock().await;
match &guard.state {
ProcessState::Running {
sleeping: Some(_), ..
} => return Ok(()),
ProcessState::Running { sleeping: None } => {
}
_ => return Ok(()), }
}
info!(model = %model, level = ?level, "Putting model to sleep");
if level == SleepLevel::Stop {
let mut guard = process.lock().await;
if let Some(ref mut child) = guard.child {
info!(model = %model, "Stopping vLLM process group");
if let Some(pid) = child.id() {
kill_process_group(pid);
} else {
let _ = child.kill().await;
}
let _ = child.wait().await;
}
guard.child = None;
guard.state = ProcessState::NotStarted;
info!(model = %model, "vLLM process stopped");
return Ok(());
}
let level_num = match level {
SleepLevel::L1 => 1,
SleepLevel::L2 => 2,
SleepLevel::Stop => unreachable!(),
};
let url = format!("http://localhost:{}/sleep?level={}", config.port, level_num);
self.post_request(&url, None, Duration::from_secs(120))
.await
.map_err(|e| OrchestratorError::SleepFailed {
model: model.to_string(),
reason: e,
})?;
{
let mut guard = process.lock().await;
guard.state = ProcessState::Running {
sleeping: Some(level),
};
}
info!(model = %model, "Model is now sleeping");
Ok(())
}
pub async fn is_ready(&self, model: &str) -> bool {
let Some(process) = self.processes.get(model) else {
return false;
};
let guard = process.lock().await;
matches!(guard.state, ProcessState::Running { sleeping: None })
}
async fn post_request(
&self,
url: &str,
body: Option<&str>,
timeout: Duration,
) -> Result<(), String> {
use http_body_util::Full;
use hyper::Request;
let client: hyper_util::client::legacy::Client<_, Full<bytes::Bytes>> =
hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
.build_http();
let uri: hyper::Uri = url.parse().map_err(|e| format!("Invalid URL: {}", e))?;
let has_body = body.is_some();
let body_bytes = body.map(|b| b.as_bytes().to_vec()).unwrap_or_default();
let request_body = Full::new(bytes::Bytes::from(body_bytes));
let mut req_builder = Request::builder().method("POST").uri(uri);
if has_body {
req_builder = req_builder.header("Content-Type", "application/json");
}
let request = req_builder
.body(request_body)
.map_err(|e| format!("Failed to build request: {}", e))?;
let response = tokio::time::timeout(timeout, client.request(request))
.await
.map_err(|_| "Request timeout".to_string())?
.map_err(|e| format!("Request failed: {}", e))?;
if response.status().is_success() {
Ok(())
} else {
Err(format!("Request failed with status: {}", response.status()))
}
}
}
/// Errors returned by [`Orchestrator`] operations.
#[derive(Debug, thiserror::Error)]
pub enum OrchestratorError {
    /// No configuration or process is registered under this model name.
    #[error("model not found: {0}")]
    ModelNotFound(String),
    /// The server executable could not be spawned.
    #[error("failed to spawn process for {model}: {reason}")]
    SpawnFailed { model: String, reason: String },
    /// The server did not pass its health check within the startup timeout.
    #[error("startup timeout for {model}")]
    StartupTimeout { model: String },
    /// The process is in (or just entered) the failed state.
    #[error("process failed for {model}: {reason}")]
    ProcessFailed { model: String, reason: String },
    /// A wake-up HTTP request to the server failed.
    #[error("failed to wake {model}: {reason}")]
    WakeFailed { model: String, reason: String },
    /// The sleep HTTP request to the server failed.
    #[error("failed to sleep {model}: {reason}")]
    SleepFailed { model: String, reason: String },
}
impl Drop for Orchestrator {
fn drop(&mut self) {
for entry in self.processes.iter() {
if let Ok(mut guard) = entry.value().try_lock()
&& let Some(ref mut child) = guard.child
{
let _ = child.start_kill();
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal configuration for the single model used by these tests.
    fn model_a_config() -> ModelConfig {
        ModelConfig {
            model_path: "test/model".to_string(),
            port: 8001,
            extra_args: vec![],
            sleep_level: 1,
        }
    }

    #[test]
    fn test_orchestrator_creation() {
        let configs = HashMap::from([("model-a".to_string(), model_a_config())]);
        let orchestrator = Orchestrator::new(configs);
        assert_eq!(orchestrator.registered_models(), vec!["model-a"]);
    }

    #[test]
    fn test_strip_ansi() {
        let cases = [
            ("hello", "hello"),
            ("\x1b[31mred\x1b[0m", "red"),
            ("\x1b[1;32mgreen bold\x1b[0m text", "green bold text"),
        ];
        for (input, expected) in cases {
            assert_eq!(strip_ansi(input), expected, "input: {input:?}");
        }
    }
}