use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::RwLock;
use std::time::Duration;
use async_trait::async_trait;
use tokio::process::Command;
use super::trainer::{LoraModelId, TrainedModel};
/// Errors returned by `ModelApplicator` implementations.
#[derive(Debug)]
pub enum ApplicatorError {
    /// No model with the given id exists in the rollback history.
    ModelNotFound(LoraModelId),
    /// The LoRA adapter file was missing at the given path.
    AdapterNotFound(PathBuf),
    /// llama-server failed to spawn or did not become ready in time.
    ServerStartFailed(String),
    /// The running server could not be stopped (e.g. unreadable PID file).
    ServerStopFailed(String),
    /// Underlying filesystem or process I/O failure.
    Io(std::io::Error),
    /// Rollback was requested but no history exists.
    /// NOTE(review): not constructed anywhere in this file — possibly used by
    /// other applicator implementations; confirm before removing.
    NoHistory,
    /// Catch-all for miscellaneous failures (e.g. missing base model).
    Other(String),
}
impl std::fmt::Display for ApplicatorError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ModelNotFound(id) => write!(f, "Model not found: {}", id),
Self::AdapterNotFound(p) => write!(f, "Adapter not found: {}", p.display()),
Self::ServerStartFailed(msg) => write!(f, "Server start failed: {}", msg),
Self::ServerStopFailed(msg) => write!(f, "Server stop failed: {}", msg),
Self::Io(e) => write!(f, "IO error: {}", e),
Self::NoHistory => write!(f, "No model history for rollback"),
Self::Other(msg) => write!(f, "{}", msg),
}
}
}
impl std::error::Error for ApplicatorError {
    /// Expose the wrapped I/O error as the cause; every other variant is a
    /// leaf with no underlying source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Io(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
impl From<std::io::Error> for ApplicatorError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
/// Strategy for making a trained LoRA model the live serving model.
#[async_trait]
pub trait ModelApplicator: Send + Sync {
    /// Make `model` the active model, archiving the previously active one.
    async fn apply(&self, model: &TrainedModel) -> Result<(), ApplicatorError>;
    /// Re-activate a previously applied model identified by `to`.
    async fn rollback(&self, to: &LoraModelId) -> Result<(), ApplicatorError>;
    /// The currently applied model, if any.
    fn current(&self) -> Option<TrainedModel>;
    /// Id of the most recently superseded model (top of the history stack).
    fn previous_model_id(&self) -> Option<LoraModelId>;
}
/// Configuration for launching a local llama-server process.
#[derive(Debug, Clone)]
pub struct LlamaServerConfig {
    /// Path to the base GGUF model file (must exist before starting).
    pub base_model_path: PathBuf,
    /// Interface the server binds to.
    pub host: String,
    /// TCP port the server listens on.
    pub port: u16,
    /// Number of layers to offload to the GPU (`-ngl`).
    pub n_gpu_layers: u32,
    /// Context window size in tokens (`-c`).
    pub ctx_size: u32,
    /// Number of parallel request slots (`-np`).
    pub parallel: u32,
    /// File where the spawned server's PID is recorded for later shutdown.
    pub pid_file: PathBuf,
    /// File receiving the server's stdout/stderr.
    pub log_file: PathBuf,
    /// Executable name or path of the llama-server binary.
    pub server_path: String,
}
impl Default for LlamaServerConfig {
    /// Local-serving defaults: loopback on port 8080, full GPU offload, and
    /// PID/log files under the platform data directory (falling back to the
    /// current directory when none is available).
    fn default() -> Self {
        let state_dir = dirs::data_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join("swarm-engine");
        let pid_file = state_dir.join("llama-server.pid");
        let log_file = state_dir.join("llama-server.log");
        Self {
            base_model_path: PathBuf::new(),
            host: String::from("127.0.0.1"),
            port: 8080,
            n_gpu_layers: 99,
            ctx_size: 4096,
            parallel: 4,
            pid_file,
            log_file,
            server_path: String::from("llama-server"),
        }
    }
}
impl LlamaServerConfig {
    /// Set the path of the base GGUF model.
    pub fn base_model(self, path: impl Into<PathBuf>) -> Self {
        Self { base_model_path: path.into(), ..self }
    }
    /// Set the interface the server binds to.
    pub fn host(self, host: impl Into<String>) -> Self {
        Self { host: host.into(), ..self }
    }
    /// Set the listening port.
    pub fn port(self, port: u16) -> Self {
        Self { port, ..self }
    }
    /// Set the number of GPU-offloaded layers.
    pub fn n_gpu_layers(self, n: u32) -> Self {
        Self { n_gpu_layers: n, ..self }
    }
    /// Set the number of parallel request slots.
    pub fn parallel(self, n: u32) -> Self {
        Self { parallel: n, ..self }
    }
    /// Set the context window size in tokens.
    pub fn ctx_size(self, size: u32) -> Self {
        Self { ctx_size: size, ..self }
    }
}
/// Applies LoRA adapters by restarting a managed llama-server process with
/// the new adapter, keeping an in-memory history of prior models for rollback.
pub struct LlamaServerApplicator {
    // Launch parameters for the managed server process.
    config: LlamaServerConfig,
    // The model currently being served, if any.
    current_model: RwLock<Option<TrainedModel>>,
    // Stack of previously served models (most recent last).
    history: RwLock<Vec<TrainedModel>>,
}
impl LlamaServerApplicator {
    /// Build an applicator around `config` with no current model and an
    /// empty rollback history.
    pub fn new(config: LlamaServerConfig) -> Self {
        Self {
            config,
            current_model: RwLock::new(None),
            history: RwLock::new(Vec::new()),
        }
    }

    /// Stop a previously started llama-server, if its PID file exists.
    ///
    /// Best-effort: a failed `kill` (e.g. the process already exited) is only
    /// logged, and the PID file is removed regardless.
    async fn stop_server(&self) -> Result<(), ApplicatorError> {
        if !self.config.pid_file.exists() {
            return Ok(());
        }
        let pid_str = tokio::fs::read_to_string(&self.config.pid_file).await?;
        let pid: u32 = pid_str
            .trim()
            .parse()
            .map_err(|_| ApplicatorError::ServerStopFailed("Invalid PID".to_string()))?;
        if pid == 0 {
            // Defensive: `kill 0` signals the caller's entire process group,
            // which would take down this process too. Never allow it.
            tracing::warn!("PID file contained 0; removing stale file without killing");
            let _ = tokio::fs::remove_file(&self.config.pid_file).await;
            return Ok(());
        }
        let status = Command::new("kill").arg(pid.to_string()).status().await?;
        if !status.success() {
            tracing::debug!(pid, "Process already stopped or kill failed");
        }
        // Give the server a moment to exit and release its port.
        tokio::time::sleep(Duration::from_millis(500)).await;
        let _ = tokio::fs::remove_file(&self.config.pid_file).await;
        Ok(())
    }

    /// Launch llama-server with the configured base model and an optional
    /// LoRA adapter, record its PID, and wait until it accepts connections.
    ///
    /// # Errors
    /// Fails when the base model or adapter is missing, the process cannot
    /// be spawned, or the server never becomes ready.
    async fn start_server(&self, lora_path: Option<&Path>) -> Result<(), ApplicatorError> {
        if !self.config.base_model_path.exists() {
            return Err(ApplicatorError::Other(format!(
                "Base model not found: {}",
                self.config.base_model_path.display()
            )));
        }
        if let Some(lora) = lora_path {
            if !lora.exists() {
                return Err(ApplicatorError::AdapterNotFound(lora.to_path_buf()));
            }
        }
        if let Some(parent) = self.config.pid_file.parent() {
            tokio::fs::create_dir_all(parent).await?;
        }
        let mut cmd = Command::new(&self.config.server_path);
        // Paths are passed as `OsStr` via `arg` so non-UTF-8 paths work;
        // the previous `to_str().unwrap()` would panic on them.
        cmd.arg("-m").arg(&self.config.base_model_path);
        cmd.args([
            "--host",
            &self.config.host,
            "--port",
            &self.config.port.to_string(),
            "-ngl",
            &self.config.n_gpu_layers.to_string(),
            "-c",
            &self.config.ctx_size.to_string(),
            "-np",
            &self.config.parallel.to_string(),
            "--cont-batching",
        ]);
        if let Some(lora) = lora_path {
            cmd.arg("--lora").arg(lora);
        }
        // Both stdout and stderr go to the same log file.
        let log = std::fs::File::create(&self.config.log_file)?;
        let log_err = log.try_clone()?;
        cmd.stdout(Stdio::from(log));
        cmd.stderr(Stdio::from(log_err));
        let child = cmd
            .spawn()
            .map_err(|e| ApplicatorError::ServerStartFailed(e.to_string()))?;
        // `id()` is None only when the child has already been reaped. The old
        // fallback of writing 0 would later make `stop_server` run `kill 0`
        // (signalling our whole process group), so treat it as a failure.
        let pid = child.id().ok_or_else(|| {
            ApplicatorError::ServerStartFailed("spawned server exited immediately".to_string())
        })?;
        tokio::fs::write(&self.config.pid_file, pid.to_string()).await?;
        tracing::info!(
            pid,
            endpoint = format!("http://{}:{}", self.config.host, self.config.port),
            lora = ?lora_path,
            "llama-server started"
        );
        self.wait_for_ready().await?;
        Ok(())
    }

    /// Poll the server's TCP port (up to ~15 s) until a connection succeeds.
    async fn wait_for_ready(&self) -> Result<(), ApplicatorError> {
        let max_attempts = 30;
        let delay = Duration::from_millis(500);
        let addr = format!("{}:{}", self.config.host, self.config.port);
        for attempt in 1..=max_attempts {
            tokio::time::sleep(delay).await;
            match tokio::net::TcpStream::connect(&addr).await {
                Ok(_) => {
                    tracing::debug!(attempt, "llama-server is ready");
                    return Ok(());
                }
                Err(_) => {
                    tracing::trace!(attempt, "Waiting for llama-server...");
                }
            }
        }
        Err(ApplicatorError::ServerStartFailed(
            "Timeout waiting for server to be ready".to_string(),
        ))
    }

    /// Look up a model by id in the rollback history (first match wins).
    fn find_in_history(&self, id: &LoraModelId) -> Option<TrainedModel> {
        let history = self.history.read().unwrap();
        history.iter().find(|m| &m.id == id).cloned()
    }
}
#[async_trait]
impl ModelApplicator for LlamaServerApplicator {
    /// Restart the server with `model`'s adapter, archiving the previously
    /// served model for rollback. If the new adapter fails to start, a
    /// best-effort restore of the previous model is attempted and the
    /// original error is returned.
    async fn apply(&self, model: &TrainedModel) -> Result<(), ApplicatorError> {
        tracing::info!(
            model_id = %model.id,
            adapter = %model.adapter_path.display(),
            "Applying model"
        );
        // Snapshot the current model before tearing the server down so it
        // can be restored if the new adapter fails to start.
        let previous_model = self.current();
        self.stop_server().await?;
        if let Err(e) = self.start_server(Some(&model.adapter_path)).await {
            if let Some(ref prev) = previous_model {
                tracing::warn!(
                    model_id = %prev.id,
                    "Apply failed, attempting to restore previous model"
                );
                // Best-effort: the restore result is deliberately ignored so
                // the caller sees the original apply error.
                let _ = self.start_server(Some(&prev.adapter_path)).await;
            }
            return Err(e);
        }
        // Only archive the previous model and switch `current_model` once
        // the new server is confirmed up.
        if let Some(prev) = previous_model {
            self.history.write().unwrap().push(prev);
        }
        *self.current_model.write().unwrap() = Some(model.clone());
        tracing::info!(model_id = %model.id, "Model applied successfully");
        Ok(())
    }
    /// Restart the server with a model taken from history.
    ///
    /// NOTE(review): the model that was current before the rollback is not
    /// pushed onto history, and the rollback target remains in history —
    /// confirm this is the intended semantics.
    async fn rollback(&self, to: &LoraModelId) -> Result<(), ApplicatorError> {
        tracing::info!(target_id = %to, "Rolling back model");
        let model = self
            .find_in_history(to)
            .ok_or_else(|| ApplicatorError::ModelNotFound(to.clone()))?;
        self.stop_server().await?;
        self.start_server(Some(&model.adapter_path)).await?;
        *self.current_model.write().unwrap() = Some(model);
        tracing::info!(target_id = %to, "Rollback completed");
        Ok(())
    }
    /// Clone of the currently served model, if any.
    fn current(&self) -> Option<TrainedModel> {
        self.current_model.read().unwrap().clone()
    }
    /// Id of the most recently archived model (top of the history stack).
    fn previous_model_id(&self) -> Option<LoraModelId> {
        let history = self.history.read().unwrap();
        history.last().map(|m| m.id.clone())
    }
}
/// In-memory applicator that tracks model state without managing any server
/// process; useful for tests and dry runs.
pub struct NoOpApplicator {
    // The model most recently "applied", if any.
    current_model: RwLock<Option<TrainedModel>>,
    // Previously applied models (most recent last).
    history: RwLock<Vec<TrainedModel>>,
}
impl NoOpApplicator {
pub fn new() -> Self {
Self {
current_model: RwLock::new(None),
history: RwLock::new(Vec::new()),
}
}
}
impl Default for NoOpApplicator {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ModelApplicator for NoOpApplicator {
    /// Record `model` as current, archiving whichever model it displaces.
    async fn apply(&self, model: &TrainedModel) -> Result<(), ApplicatorError> {
        let displaced = self.current_model.write().unwrap().replace(model.clone());
        if let Some(prev) = displaced {
            self.history.write().unwrap().push(prev);
        }
        Ok(())
    }
    /// Make a model from history current again, without touching history.
    async fn rollback(&self, to: &LoraModelId) -> Result<(), ApplicatorError> {
        // Scope the read lock so it is released before taking the write lock.
        let target = {
            let history = self.history.read().unwrap();
            history.iter().find(|m| &m.id == to).cloned()
        };
        match target {
            Some(model) => {
                *self.current_model.write().unwrap() = Some(model);
                Ok(())
            }
            None => Err(ApplicatorError::ModelNotFound(to.clone())),
        }
    }
    /// Clone of the currently recorded model, if any.
    fn current(&self) -> Option<TrainedModel> {
        self.current_model.read().unwrap().as_ref().cloned()
    }
    /// Id of the most recently archived model (top of the history stack).
    fn previous_model_id(&self) -> Option<LoraModelId> {
        self.history.read().unwrap().last().map(|m| m.id.clone())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal `TrainedModel` fixture for the tests below.
    fn sample_model(id: &str, adapter: &str) -> TrainedModel {
        TrainedModel {
            id: LoraModelId::parse(id),
            base_model: "test-base".to_string(),
            adapter_path: PathBuf::from(adapter),
            learn_model_name: "test-learn".to_string(),
            episode_ids: vec![],
            sample_count: 10,
            created_at: 0,
            metrics: None,
        }
    }

    #[tokio::test]
    async fn test_noop_applicator_apply() {
        let app = NoOpApplicator::new();
        assert!(app.current().is_none());
        app.apply(&sample_model("model-1", "/path/to/adapter1"))
            .await
            .unwrap();
        assert_eq!(app.current().unwrap().id.as_str(), "model-1");
    }

    #[tokio::test]
    async fn test_noop_applicator_history() {
        let app = NoOpApplicator::new();
        app.apply(&sample_model("model-1", "/path/to/adapter1"))
            .await
            .unwrap();
        // First apply displaces nothing, so history stays empty.
        assert!(app.previous_model_id().is_none());
        app.apply(&sample_model("model-2", "/path/to/adapter2"))
            .await
            .unwrap();
        assert_eq!(app.previous_model_id().unwrap().as_str(), "model-1");
        assert_eq!(app.current().unwrap().id.as_str(), "model-2");
    }

    #[tokio::test]
    async fn test_noop_applicator_rollback() {
        let app = NoOpApplicator::new();
        for (id, adapter) in [("model-1", "/path/to/adapter1"), ("model-2", "/path/to/adapter2")] {
            app.apply(&sample_model(id, adapter)).await.unwrap();
        }
        app.rollback(&LoraModelId::parse("model-1")).await.unwrap();
        assert_eq!(app.current().unwrap().id.as_str(), "model-1");
    }

    #[tokio::test]
    async fn test_rollback_not_found() {
        let app = NoOpApplicator::new();
        app.apply(&sample_model("model-1", "/path/to/adapter"))
            .await
            .unwrap();
        let outcome = app.rollback(&LoraModelId::parse("nonexistent")).await;
        assert!(matches!(outcome, Err(ApplicatorError::ModelNotFound(_))));
    }

    #[test]
    fn test_llama_server_config_builder() {
        let config = LlamaServerConfig::default()
            .base_model("/path/to/model.gguf")
            .host("0.0.0.0")
            .port(8081)
            .n_gpu_layers(50)
            .parallel(8);
        assert_eq!(config.base_model_path, PathBuf::from("/path/to/model.gguf"));
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8081);
        assert_eq!(config.n_gpu_layers, 50);
        assert_eq!(config.parallel, 8);
    }
}