use std::{
path::PathBuf,
thread::sleep,
time::{Duration, Instant},
};
use crate::{
client::props::PropsResponse,
error::{LmcppError, LmcppResult},
server::{
ipc::{ServerClient, ServerClientExt, error::ClientError},
process::guard::ServerProcessGuard,
types::start_args::ServerArgs,
},
};
/// Result of probing a server with [`LmcppServer::server_status`].
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ServerStatus {
    /// `/health` and `/props` both answered; holds the lowercased file name of the model
    /// path reported by `/props`.
    RunningModel(String),
    /// `/health` answered with a remote 503 status.
    Loading,
    /// The server could not be reached within the budget, or its `/props` answer was
    /// unusable; holds a human-readable reason.
    ErrorOrOffline(String),
}
/// Wall-clock budget for the start-up loop when the model is already local
/// (no download is expected).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct LoadBudget(pub std::time::Duration);

impl Default for LoadBudget {
    /// 45 seconds.
    fn default() -> Self {
        let forty_five_seconds = Duration::from_secs(45);
        Self(forty_five_seconds)
    }
}

impl From<Duration> for LoadBudget {
    /// Wraps a raw duration as a load budget.
    fn from(budget: Duration) -> Self {
        Self(budget)
    }
}
/// Wall-clock budget for the start-up loop when the model has to be fetched first
/// (a Hugging Face repo or a model URL was configured).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct DownloadBudget(pub std::time::Duration);

impl Default for DownloadBudget {
    /// Ten minutes (600 seconds).
    fn default() -> Self {
        let ten_minutes = Duration::from_secs(600);
        Self(ten_minutes)
    }
}

impl From<Duration> for DownloadBudget {
    /// Wraps a raw duration as a download budget.
    fn from(budget: Duration) -> Self {
        Self(budget)
    }
}
/// Pause between successive polls of the server while it starts up.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct RetryDelay(pub std::time::Duration);

impl Default for RetryDelay {
    /// 100 milliseconds.
    fn default() -> Self {
        let hundred_millis = Duration::from_millis(100);
        Self(hundred_millis)
    }
}

impl From<Duration> for RetryDelay {
    /// Wraps a raw duration as a retry delay.
    fn from(delay: Duration) -> Self {
        Self(delay)
    }
}
/// A spawned server process together with the client used to query it.
///
/// Dropping the value stops the process; a failure to stop is logged, not propagated.
#[derive(Debug)]
pub struct LmcppServer {
    /// Owns the spawned process; `stop()` and `Drop` delegate to it.
    pub guard: ServerProcessGuard,
    /// IPC client used for the `/health` and `/props` requests.
    pub client: Box<dyn ServerClient>,
    /// Model name derived from the configured source (model file name, the part of a
    /// Hugging Face repo after `/`, or the last segment of a model URL).
    pub model_name: String,
    /// Pidfile location: `<bin_dir>/<pid_id>.pid`, where `pid_id` comes from the client.
    pub pidfile_path: PathBuf,
}
impl LmcppServer {
pub fn new(
bin_path: PathBuf,
bin_dir: PathBuf,
mut server_args: ServerArgs,
load_budget: LoadBudget,
download_budget: DownloadBudget,
retry_delay: RetryDelay,
client: Box<dyn ServerClient>,
) -> LmcppResult<Self> {
debug_assert!(
server_args.host.is_some(),
"Host should be set in LmcppServerBuilder."
);
let pid_id = client.pid_id();
debug_assert!(!pid_id.is_empty(), "PID ID must not be empty");
let pidfile = format!("{pid_id}.pid");
debug_assert!(
pidfile.len() <= 240,
"PID ID must not exceed 240 characters"
);
debug_assert!(sanitize_filename::is_sanitized(&pidfile));
let pidfile_path = bin_dir.join(pidfile);
let model_name = if let Some(path) = &server_args.model {
path.0
.file_name()
.and_then(|s| s.to_str())
.map(|s| s.to_owned())
.ok_or_else(|| LmcppError::InvalidConfig {
field: "model",
reason: format!(
"Model path `{}` has no filename component",
path.0.display()
),
})?
} else if let Some(repo) = &server_args.hf_repo {
repo.0
.split_once('/')
.map(|(_, model)| model.to_owned())
.ok_or_else(|| LmcppError::InvalidConfig {
field: "hf_repo",
reason: format!(
"Hugging Face repo `{}` is missing the `user/model` slash",
repo.0
),
})?
} else if let Some(url) = &server_args.model_url {
url.0
.path_segments()
.and_then(|segments| segments.last())
.filter(|s| !s.is_empty())
.map(|s| s.to_owned())
.ok_or_else(|| LmcppError::InvalidConfig {
field: "model_url",
reason: format!("URL `{}` has no filename component", url.0),
})?
} else {
return Err(LmcppError::InvalidConfig {
field: "model",
reason: "No model source specified (model, hf_repo, or model_url)".to_string(),
});
};
let expect_download = server_args.hf_repo.is_some() || server_args.model_url.is_some();
let alias = server_args
.model_id
.clone()
.unwrap_or_else(|| model_name.clone());
server_args.alias = Some(alias);
match Self::server_status(&client, Duration::from_millis(500), retry_delay.0) {
ServerStatus::Loading => {
crate::error!(
"The client at that address is already loading a model. This shouldn't happen. Attempting to kill it before starting LmcppServer with correct model."
);
crate::server::process::kill::kill_by_client(&pidfile_path, &client.host())?;
}
ServerStatus::RunningModel(running_model_name) => {
crate::error!(
"The client at that address is already running a model. This shouldn't happen. Expected: {}, got: {} Attempting to kill it before starting LmcppServer with correct model.",
model_name,
running_model_name
);
crate::server::process::kill::kill_by_client(&pidfile_path, &client.host())?;
}
ServerStatus::ErrorOrOffline(_) => (), };
let mut guard = ServerProcessGuard::new(&bin_path, &bin_dir, &pidfile_path, &server_args)?;
match Self::start_up_loop(
&client,
&mut guard,
download_budget,
load_budget,
retry_delay,
&model_name,
expect_download,
) {
Ok(()) => (),
Err(e) => {
crate::error!("Failed to start LmcppServer: {e}");
return Err(e);
}
}
let server = Self {
guard,
client,
model_name,
pidfile_path,
};
crate::trace!("Started LmcppServer: {server}");
Ok(server)
}
fn start_up_loop(
client: &Box<dyn ServerClient>,
guard: &mut ServerProcessGuard,
download_budget: DownloadBudget,
load_budget: LoadBudget,
retry_delay: RetryDelay,
model_name: &str,
expect_download: bool,
) -> LmcppResult<()> {
let overall_budget = if expect_download {
download_budget.0
} else {
load_budget.0
};
let retry_delay = retry_delay.0;
let deadline = Instant::now() + overall_budget;
loop {
match Self::server_status(&client, Duration::from_secs(3), retry_delay) {
ServerStatus::RunningModel(running_model_name)
if model_ids_match(&running_model_name, &model_name) =>
{
return Ok(());
}
ServerStatus::RunningModel(other) => {
guard.stop()?;
return Err(LmcppError::ServerLaunch(format!(
"Server started with wrong model. Expected {}, got {}",
model_name, other
)));
}
ServerStatus::Loading => (),
ServerStatus::ErrorOrOffline(msg) => {
if !expect_download {
guard.stop()?;
return Err(LmcppError::ServerLaunch(format!(
"Server failed to start: {msg}"
)));
}
}
}
if Instant::now() >= deadline {
guard.stop()?;
return Err(LmcppError::ServerLaunch(format!(
"Timed out after {overall_budget:?} waiting for model to load"
)));
}
sleep(retry_delay);
}
}
pub fn stop(&self) -> LmcppResult<()> {
self.guard.stop()?;
Ok(())
}
pub fn status(&self) -> ServerStatus {
Self::server_status(
&self.client,
Duration::from_millis(1000),
Duration::from_millis(100),
)
}
pub fn server_status(
client: &Box<dyn ServerClient>,
total_budget: Duration,
retry_delay: Duration,
) -> ServerStatus {
debug_assert!(
retry_delay > Duration::ZERO,
"Retry delay must be greater than zero"
);
debug_assert!(
total_budget >= retry_delay,
"Total budget must be greater than or equal to retry delay"
);
let deadline = Instant::now() + total_budget;
loop {
match client.get::<serde_json::Value>("/health") {
Ok(_) => break,
Err(e) => {
if let ClientError::Remote {
code: 503,
message: _,
} = &e
{
return ServerStatus::Loading;
}
if Instant::now() >= deadline {
return ServerStatus::ErrorOrOffline(format!(
"Health check failed after {:?}: {e:?}",
total_budget
));
}
crate::trace!("health check error ({e:?}); retrying in {retry_delay:?}");
sleep(retry_delay);
}
}
}
loop {
match client.get::<PropsResponse>("/props") {
Ok(props) => {
let path = match &props.model_path {
Some(p) => p,
None => {
return ServerStatus::ErrorOrOffline(
"No model path in /props response".to_string(),
);
}
};
let file_osstr = match path.file_name() {
Some(f) => f,
None => {
return ServerStatus::ErrorOrOffline(format!(
"Model path `{}` has no filename component",
path.display()
));
}
};
let file_str = match file_osstr.to_str() {
Some(s) => s,
None => {
return ServerStatus::ErrorOrOffline(format!(
"Model path `{}` is not valid UTF-8",
path.display()
));
}
};
let model_name = file_str.to_ascii_lowercase();
return ServerStatus::RunningModel(model_name);
}
Err(e) => {
if Instant::now() >= deadline {
return ServerStatus::ErrorOrOffline(format!(
"Health check failed after {:?}: {e:?}",
total_budget
));
}
crate::trace!("health check not ready ({e:?}); retrying in {retry_delay:?}");
sleep(retry_delay);
}
}
}
}
pub fn pid(&self) -> u32 {
self.guard.pid()
}
#[cfg(test)]
pub fn dummy() -> Self {
Self {
guard: ServerProcessGuard::dummy(),
client: Box::new(super::ipc::uds::UdsClient::dummy()),
model_name: "dummy_model".into(),
pidfile_path: PathBuf::from("/tmp/dummy.pid"),
}
}
}
impl Drop for LmcppServer {
    /// Stops the server process. `drop` cannot return an error, so a failure to stop is
    /// logged instead.
    fn drop(&mut self) {
        match self.stop() {
            Ok(()) => {}
            Err(e) => {
                crate::error!("Failed to stop LmcppServer during drop: {e}");
            }
        }
    }
}
impl PartialEq for LmcppServer {
    /// Two servers are equal when their clients share a pid id and they were started for
    /// the same model name. The guard and pidfile path are not compared.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        let same_pid_id = self.client.pid_id() == other.client.pid_id();
        same_pid_id && self.model_name == other.model_name
    }
}
impl Eq for LmcppServer {}
impl std::hash::Hash for LmcppServer {
    /// Hashes exactly the fields `PartialEq` compares (client pid id, then model name), so
    /// `Hash` stays consistent with `Eq`.
    #[inline]
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let pid_id = self.client.pid_id();
        std::hash::Hash::hash(&pid_id, state);
        std::hash::Hash::hash(&self.model_name, state);
    }
}
impl std::fmt::Display for LmcppServer {
    /// Formats as `LmcppServer { <client>, model: <model_name> }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let client = &self.client;
        let model = &self.model_name;
        write!(f, "LmcppServer {{ {client}, model: {model} }}")
    }
}
/// Fuzzy comparison of two model identifiers (file names, repo names, aliases).
///
/// Both inputs are canonicalised first:
/// - the last `.`-suffix is dropped (`foo.gguf` → `foo`, `v0.1` → `v0`);
/// - ASCII letters and digits are kept (letters lowercased), every other character becomes
///   `_`, and runs of `_` collapse to one;
/// - every `gguf` sequence is removed;
/// - leading and trailing `_` are trimmed.
///
/// The ids match when the canonical forms are equal, or when some window of the shorter
/// form, at least 75 % of its length (rounded up), occurs in the longer form.
///
/// An id that canonicalises to nothing (e.g. `"gguf"` or `"___"`) carries no identifying
/// information and never matches.
///
/// # Panics
/// In debug builds, if either input is empty.
pub fn model_ids_match(a: &str, b: &str) -> bool {
    debug_assert!(
        !a.is_empty() && !b.is_empty(),
        "Model identifiers must not be empty"
    );
    let canonicalise = |s: &str| {
        let stem = s.rsplit_once('.').map_or(s, |(prefix, _)| prefix);
        let mut out = String::with_capacity(stem.len());
        // Whether `out` currently ends with `_`; used to collapse runs of separators.
        let mut last_us = false;
        for ch in stem.chars().map(|c| c.to_ascii_lowercase()) {
            // Drop "gguf" wherever it appears, then re-derive the separator state.
            if out.ends_with("ggu") && ch == 'f' {
                out.truncate(out.len() - 3);
                last_us = out.ends_with('_');
                continue;
            }
            let mapped = if ch.is_ascii_alphanumeric() { ch } else { '_' };
            if mapped == '_' && last_us {
                continue;
            }
            last_us = mapped == '_';
            out.push(mapped);
        }
        out.trim_matches('_').to_string()
    };

    let ca = canonicalise(a);
    let cb = canonicalise(b);
    // Without this guard, the window length below would be 0, and the empty string is
    // contained in every string, so an empty id would match anything.
    if ca.is_empty() || cb.is_empty() {
        return false;
    }
    if ca == cb {
        return true;
    }
    let (short, long) = if ca.len() <= cb.len() {
        (&ca, &cb)
    } else {
        (&cb, &ca)
    };
    // ceil(75 % of the shorter length); in 1..=short.len() because `short` is non-empty.
    let window = (short.len() * 75 + 99) / 100;
    // Checking only the minimum window length suffices: any longer matching substring
    // contains a matching window of exactly this length. Canonical ids are pure ASCII,
    // so byte offsets are always char boundaries.
    (0..=short.len() - window).any(|start| long.contains(&short[start..start + window]))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::Hasher;

    /// Table of (left id, right id, expected) triples covering case, separator,
    /// extension, quantisation-suffix and `gguf` variations, plus non-matching pairs.
    #[test]
    fn model_ids_match_cases() {
        let table: &[(&str, &str, bool)] = &[
            ("Gemma-3B-It-Q4_K_M.gguf", "gemma_3b_it_q4-k-m", true),
            (
                "google_gemma-3-1b-it-qat-GGUF:q4_k_m",
                "google_gemma-3-1b-it-qat-Q4_K_M.gguf",
                true,
            ),
            ("alpaca.gguf", "gemma_3b.gguf", false),
            (
                "Llama-3-8B-Instruct:q4_k_m",
                "llama-3_8b-instruct.Q4_K_M.gguf",
                true,
            ),
            ("Llama-3-8B-Instruct", "llama-3-8b-instruct.gguf", true),
            (
                "Mixtral-8x22B-Instruct-v0.1:q4_k_m",
                "mixtral-8x22b-instruct-v0_1.Q4_K_M.gguf",
                true,
            ),
            (
                "Mixtral-8x22B-Instruct-v0.1",
                "mixtral-8x22b-instruct-v0_1.Q4_K_M.gguf",
                true,
            ),
            (
                "Qwen2-72B-Instruct:q4_k_m",
                "qwen2-72b-instruct.q4_k_m.GGUF",
                true,
            ),
            (
                "Phi-3-mini-4k-instruct",
                "phi-3-mini-4k-instruct.Q8_0.gguf",
                true,
            ),
            (
                "Llama-3-8B-Instruct",
                "mixtral-8x22b-instruct-v0_1.Q4_K_M.gguf",
                false,
            ),
        ];
        for &(lhs, rhs, expected) in table {
            assert_eq!(model_ids_match(lhs, rhs), expected, "({lhs}, {rhs})");
        }
    }

    /// Servers that differ only in model name must compare unequal and hash differently.
    #[test]
    fn lmcpp_server_inequality_and_hash() {
        let fingerprint = |server: &LmcppServer| {
            let mut hasher = DefaultHasher::new();
            std::hash::Hash::hash(server, &mut hasher);
            hasher.finish()
        };
        let a = LmcppServer::dummy();
        let mut b = LmcppServer::dummy();
        b.model_name = "different".into();
        assert_ne!(a, b);
        assert_ne!(fingerprint(&a), fingerprint(&b));
    }
}