use std::{
error::Error as StdError,
path::{Path, PathBuf},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Duration, Instant},
};
use tokio::{
process::{Child, Command},
sync::Mutex,
time::sleep,
};
use crate::{
config::{EngineConfig, ModelKind},
error::EngineError,
health::HealthState,
};
/// Directory prefixes under which the canonicalized `llama-server` binary must live.
/// Enforced by `canonicalize_bin_path` with a plain string-prefix comparison, hence the
/// trailing slashes.
const ALLOWED_BIN_PREFIXES: &[&str] = &["/usr/local/bin/", "/opt/gradatum/bin/"];
/// Seconds `shutdown` waits for the child to exit after SIGTERM before forcing a kill.
const SHUTDOWN_GRACE_SECS: u64 = 5;
/// Initial delay, in milliseconds, before `supervise_loop` respawns an exited child.
const BACKOFF_INIT_MS: u64 = 500;
/// Upper bound, in milliseconds, of the doubling restart backoff.
const BACKOFF_MAX_MS: u64 = 30_000;
/// Allow-list of flags accepted in `EngineConfig::extra_args`.
///
/// `validate_extra_args` compares the key of every dash-prefixed token (the text before
/// the first `=`) against this list and rejects the configuration on any miss.
/// Flags that `build_child_args` already emits from dedicated config fields
/// (`--model`, `--port`, `--host`, `--n-gpu-layers`, `--threads`, `--ctx-size`,
/// `--parallel`, `--embedding`, `--mmproj`) are not listed here.
const ALLOWED_EXTRA_FLAGS: &[&str] = &[
"--flash-attn",
"-fa",
"--no-mmap",
"--mlock",
"--no-kv-offload",
"-nkvo",
"--cont-batching",
"--no-cont-batching",
"--batch-size",
"-b",
"--ubatch-size",
"-ub",
"--threads-http",
"--keep",
"--defrag-thold",
"--cache-type-k",
"-ctk",
"--cache-type-v",
"-ctv",
"--numa",
"--log-disable",
"--log-prefix",
"--log-timestamps",
"--rope-scaling",
"--rope-scale",
"--rope-freq-base",
"--rope-freq-scale",
"--yarn-orig-ctx",
"--yarn-ext-factor",
"--yarn-attn-factor",
"--yarn-beta-slow",
"--yarn-beta-fast",
"--seed",
"-s",
"--poll",
"--swa-full",
"--cache-reuse",
"--reasoning",
"--reasoning-format",
"--reasoning-budget",
"--temp",
"--temperature",
"--top-k",
"--top-p",
"--min-p",
"--presence-penalty",
"--repeat-penalty",
"--n-predict",
"-n",
];
/// Environment-variable name prefixes that `inject_allowed_env` forwards to the child.
/// Judging by the constant's name these are GPU/driver-related variables; any variable
/// whose name starts with one of these prefixes is passed through unchanged.
const GPU_ENV_PREFIXES: &[&str] = &[
"GGML_", "VK_", "HIP_", "ROCR_", "ROCM_", "HSA_", "CUDA_", "NVIDIA_", "MESA_", "RADV_",
];
/// Exact environment-variable names that `inject_allowed_env` always forwards to the child.
const ENV_PASSTHROUGH: &[&str] = &["PATH", "HOME", "LD_LIBRARY_PATH"];
/// Assembles the argument vector passed to `llama-server`.
///
/// The fixed `flag value` pairs come first (model, port, loopback host, GPU layers,
/// threads, context size, parallel slots), followed by `--embedding` for embedding
/// models, `--mmproj <path>` when a projector path is configured, and finally the
/// user-supplied `extra_args` in their original order.
fn build_child_args(cfg: &EngineConfig) -> Vec<String> {
    // Always-present options, in the order they appear on the command line.
    let fixed_pairs: Vec<(&str, String)> = vec![
        ("--model", cfg.model_path.clone()),
        ("--port", cfg.child_port.to_string()),
        ("--host", String::from("127.0.0.1")),
        ("--n-gpu-layers", cfg.gpu_layers.to_string()),
        ("--threads", cfg.n_threads.to_string()),
        ("--ctx-size", cfg.context_len.to_string()),
        ("--parallel", cfg.parallel.to_string()),
    ];

    let mut argv: Vec<String> = Vec::new();
    for (flag, value) in fixed_pairs {
        argv.push(flag.to_owned());
        argv.push(value);
    }

    if matches!(cfg.model_kind, ModelKind::Embed) {
        argv.push(String::from("--embedding"));
    }

    if let Some(mmproj) = cfg.mmproj_path.as_ref() {
        argv.push(String::from("--mmproj"));
        argv.push(mmproj.to_string_lossy().into_owned());
    }

    // Extra arguments go last so they follow every supervisor-controlled option.
    for extra in &cfg.extra_args {
        argv.push(extra.clone());
    }
    argv
}
/// Readiness outcome of the supervised child process.
///
/// Fieldless, so it derives `Clone`/`Copy`/`Eq` in addition to `Debug`/`PartialEq`;
/// callers can pass it by value and compare it freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChildState {
    /// The child answered its health check with a success status.
    Ready,
    /// The child has been spawned but has not reported ready yet.
    Starting,
    /// The child did not become ready before the startup timeout elapsed.
    StartupTimeout,
}
/// Spawns, health-checks, restarts and shuts down a single `llama-server` child process.
pub struct LlamaServerSupervisor {
/// Handle of the tracked child; `None` until `spawn_child` succeeds and again after
/// `shutdown` completes.
child: Mutex<Option<Child>>,
/// Set by `shutdown`; `supervise_loop` checks it to stop restarting the child.
shutdown_requested: AtomicBool,
/// Loopback port the child listens on (copied from `config.child_port`).
child_port: u16,
/// Engine configuration the supervisor was built from.
config: EngineConfig,
/// Canonicalized, prefix-validated path of the `llama-server` binary, resolved once
/// at construction and used for every spawn.
canonical_bin: PathBuf,
/// HTTP client built with `config.timeout_secs` as its request timeout.
pub client: reqwest::Client,
}
impl LlamaServerSupervisor {
/// Validates `config` and builds a supervisor that has not spawned any child yet.
///
/// Checks run in this order: the binary path is canonicalized and matched against the
/// allowed prefixes, `child_port` must be above 1024 and differ from `port`, and
/// `extra_args` must pass the allow-list.
///
/// # Errors
/// Returns the error of the first failing check, or `EngineError::ModelLoad` when the
/// HTTP client cannot be constructed.
pub fn new(config: EngineConfig) -> Result<Arc<Self>, EngineError> {
    let canonical_bin = canonicalize_bin_path(&config.llama_server_bin)?;

    let child_port = config.child_port;
    if child_port <= 1024 {
        let reason = format!(
            "child_port {} invalide — doit être > 1024 (SP-P0-4)",
            child_port
        );
        return Err(EngineError::ModelLoad(reason));
    }
    if child_port == config.port {
        let reason = format!(
            "child_port {} doit être différent de port {} — collision de port",
            child_port, config.port
        );
        return Err(EngineError::ModelLoad(reason));
    }

    validate_extra_args(&config.extra_args)?;

    let request_timeout = Duration::from_secs(config.timeout_secs);
    let client = match reqwest::Client::builder().timeout(request_timeout).build() {
        Ok(built) => built,
        Err(e) => {
            return Err(EngineError::ModelLoad(format!(
                "construction client HTTP : {e}"
            )))
        }
    };

    let supervisor = Self {
        child: Mutex::new(None),
        shutdown_requested: AtomicBool::new(false),
        child_port,
        config,
        canonical_bin,
        client,
    };
    Ok(Arc::new(supervisor))
}
/// Spawns `llama-server` and stores its handle as the tracked child.
///
/// The child runs in its own process group, is killed if its handle is dropped,
/// inherits stdout/stderr, and receives only the allow-listed environment variables.
///
/// # Errors
/// Returns `EngineError::ModelLoad` when the process cannot be spawned.
pub async fn spawn_child(&self) -> Result<(), EngineError> {
    let bin = &self.canonical_bin;

    let mut command = Command::new(bin);
    // Start from an empty environment, then add back only what is allow-listed.
    command.env_clear();
    inject_allowed_env(&mut command);
    command.args(build_child_args(&self.config));
    command.process_group(0);
    command.kill_on_drop(true);
    command.stdout(std::process::Stdio::inherit());
    command.stderr(std::process::Stdio::inherit());

    let spawned = match command.spawn() {
        Ok(handle) => handle,
        Err(e) => {
            return Err(EngineError::ModelLoad(format!(
                "spawn llama-server échoué (bin={bin:?}) : {e}"
            )))
        }
    };

    let pid = spawned.id().unwrap_or(0);
    tracing::info!(
        child_port = self.child_port,
        model = %self.config.model_path,
        "llama-server spawné (PID={})",
        pid
    );

    let mut slot = self.child.lock().await;
    *slot = Some(spawned);
    Ok(())
}
/// Polls the child's `/health` endpoint until it answers with a success status or
/// `startup_timeout_secs` elapses.
///
/// Each probe uses a dedicated client with a 2-second timeout (falling back to the
/// supervisor's client if that cannot be built) and probes are spaced 500 ms apart.
/// On success the shared `health` state is marked ready.
///
/// Returns `ChildState::Ready` or `ChildState::StartupTimeout`.
pub async fn wait_ready(&self, health: &HealthState) -> ChildState {
    let deadline = Instant::now() + Duration::from_secs(self.config.startup_timeout_secs);
    let health_url = format!("http://127.0.0.1:{}/health", self.child_port);
    let poll_client = match reqwest::Client::builder()
        .timeout(Duration::from_secs(2))
        .build()
    {
        Ok(short_timeout_client) => short_timeout_client,
        Err(_) => self.client.clone(),
    };

    loop {
        if Instant::now() >= deadline {
            break;
        }

        match poll_client.get(&health_url).send().await {
            Ok(resp) => {
                let status = resp.status();
                if status.is_success() {
                    tracing::info!(
                        child_port = self.child_port,
                        "llama-server prêt (HTTP {})",
                        status
                    );
                    health.set_ready();
                    return ChildState::Ready;
                }
                tracing::debug!(
                    status = %status,
                    "llama-server /health non-prêt, attente..."
                );
            }
            Err(e) => {
                // A refused connection is the expected state while the server boots,
                // so it is logged at debug level only.
                if is_connection_refused(&e) {
                    tracing::debug!(
                        "llama-server /health : connection refused, démarrage en cours"
                    );
                } else {
                    tracing::warn!(error = %e, "llama-server /health : erreur poll");
                }
            }
        }

        sleep(Duration::from_millis(500)).await;
    }

    tracing::error!(
        timeout_secs = self.config.startup_timeout_secs,
        "llama-server n'a pas répondu dans le timeout de démarrage"
    );
    ChildState::StartupTimeout
}
/// Watches the child and restarts it when it exits, until shutdown is requested, the
/// restart budget is exhausted, or a respawn fails.
///
/// * `health` – shared health state; marked unhealthy when supervision gives up, and
///   marked ready again (through `wait_ready`) after a successful restart.
/// * `initial_ready_at` – instant at which the first child became ready, if known.
///   It seeds the "stable uptime" check so that a first crash after a long uptime
///   resets the budget instead of being counted as flapping.
///
/// Restart policy: every exit consumes one unit of `child_restart_max` and is followed
/// by a backoff sleep (starting at `BACKOFF_INIT_MS`, doubled after each successful
/// restart, capped at `BACKOFF_MAX_MS`). When the child had been ready for at least
/// `min_stable_uptime_secs` before exiting, budget and backoff are reset first.
///
/// The exit of the child is detected by polling `try_wait()` while holding the `child`
/// mutex only briefly. Awaiting `wait()` under the lock would keep the mutex held for
/// the child's whole lifetime and prevent `shutdown` from ever acquiring it to send
/// SIGTERM.
///
/// NOTE(review): `wait_ready` only probes HTTP, so a child that dies during startup is
/// noticed through the startup timeout rather than immediately.
pub async fn supervise_loop(
    self: Arc<Self>,
    health: Arc<HealthState>,
    initial_ready_at: Option<Instant>,
) {
    // Delay between two non-blocking exit checks of the child.
    const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(200);

    let mut restart_budget = self.config.child_restart_max;
    let mut backoff_ms = BACKOFF_INIT_MS;
    let mut last_ready_at: Option<Instant> = initial_ready_at;

    'supervise: loop {
        if self.shutdown_requested.load(Ordering::Relaxed) {
            tracing::info!("supervise_loop : shutdown demandé, sortie");
            break;
        }

        let exit_status = loop {
            {
                let mut guard = self.child.lock().await;
                // `shutdown` sets the flag before taking the lock and clears the slot
                // before releasing it, so re-check here to exit quietly in that case.
                if self.shutdown_requested.load(Ordering::Relaxed) {
                    break 'supervise;
                }
                match guard.as_mut() {
                    None => {
                        tracing::warn!("supervise_loop : pas d'enfant à surveiller");
                        break 'supervise;
                    }
                    Some(child) => match child.try_wait() {
                        Ok(Some(status)) => break Ok(status),
                        Ok(None) => {}
                        Err(e) => break Err(e),
                    },
                }
                // Guard dropped here: the lock is never held across the sleep below.
            }
            sleep(EXIT_POLL_INTERVAL).await;
        };

        if self.shutdown_requested.load(Ordering::Relaxed) {
            break;
        }

        match exit_status {
            Ok(status) => {
                tracing::warn!(status = ?status, "llama-server s'est arrêté");
            }
            Err(e) => {
                tracing::error!(error = %e, "erreur wait() sur llama-server");
            }
        }

        // A child that stayed ready long enough is not considered to be flapping.
        let uptime_stable = last_ready_at
            .map(|t| t.elapsed() >= Duration::from_secs(self.config.min_stable_uptime_secs))
            .unwrap_or(false);
        if uptime_stable {
            tracing::info!(
                min_stable_secs = self.config.min_stable_uptime_secs,
                "uptime stable avant crash → reset budget + backoff"
            );
            restart_budget = self.config.child_restart_max;
            backoff_ms = BACKOFF_INIT_MS;
        }

        if restart_budget == 0 {
            tracing::error!(
                max_restarts = self.config.child_restart_max,
                "llama-server : budget restart épuisé — moteur unhealthy (fallback gateway)"
            );
            health.set_unhealthy();
            break;
        }
        restart_budget -= 1;

        let current_backoff = backoff_ms;
        tracing::warn!(
            backoff_ms = current_backoff,
            restarts_remaining = restart_budget,
            "llama-server : restart dans {}ms",
            current_backoff
        );
        sleep(Duration::from_millis(current_backoff)).await;

        if self.shutdown_requested.load(Ordering::Relaxed) {
            tracing::info!("supervise_loop : shutdown pendant backoff, pas de respawn");
            break;
        }

        match self.spawn_child().await {
            Ok(()) => {
                tracing::info!(restarts_remaining = restart_budget, "llama-server restarté");
                let child_state = self.wait_ready(&health).await;
                if child_state == ChildState::StartupTimeout {
                    tracing::error!("llama-server : timeout redémarrage — unhealthy");
                    health.set_unhealthy();
                    break;
                }
                last_ready_at = Some(Instant::now());
                backoff_ms = (backoff_ms * 2).min(BACKOFF_MAX_MS);
            }
            Err(e) => {
                tracing::error!(error = %e, "llama-server : re-spawn échoué — unhealthy");
                health.set_unhealthy();
                break;
            }
        }
    }
}
pub async fn shutdown(&self) {
self.shutdown_requested.store(true, Ordering::Relaxed);
let mut guard = self.child.lock().await;
let child = match guard.as_mut() {
None => return,
Some(c) => c,
};
let pid = child.id().unwrap_or(0);
if pid > 0 {
use nix::{
sys::signal::{killpg, Signal},
unistd::Pid,
};
let pgid = Pid::from_raw(pid as i32);
if let Err(e) = killpg(pgid, Signal::SIGTERM) {
tracing::warn!(pid, error = %e, "SIGTERM vers process group de llama-server échoué");
} else {
tracing::info!(pid, "SIGTERM envoyé à llama-server");
}
}
let grace = Duration::from_secs(SHUTDOWN_GRACE_SECS);
let result = tokio::time::timeout(grace, child.wait()).await;
match result {
Ok(Ok(status)) => {
tracing::info!(status = ?status, "llama-server terminé proprement");
}
Ok(Err(e)) => {
tracing::warn!(error = %e, "wait() après SIGTERM échoué — SIGKILL");
let _ = child.kill().await;
}
Err(_timeout) => {
tracing::warn!(
grace_secs = SHUTDOWN_GRACE_SECS,
"llama-server n'a pas répondu au SIGTERM — SIGKILL"
);
let _ = child.kill().await;
let _ = child.wait().await;
}
}
*guard = None;
}
/// Returns the loopback base URL (scheme, host and port, no trailing slash) of the
/// child server.
pub fn child_base_url(&self) -> String {
    let port = self.child_port;
    format!("http://127.0.0.1:{}", port)
}
}
/// Resolves `bin` to its canonical absolute path and checks that it lives under one of
/// `ALLOWED_BIN_PREFIXES`.
///
/// Canonicalization resolves symlinks and `..` components against the filesystem, so
/// the prefix check applies to the real location of the binary.
///
/// # Errors
/// Returns `EngineError::ModelLoad` when the path cannot be canonicalized (for example
/// because it does not exist), is not valid UTF-8, or falls outside the allowed
/// prefixes.
pub fn canonicalize_bin_path(bin: &Path) -> Result<PathBuf, EngineError> {
    let canonical = match bin.canonicalize() {
        Ok(resolved) => resolved,
        Err(e) => {
            return Err(EngineError::ModelLoad(format!(
                "llama_server_bin canonicalize échoué ({bin:?}) : {e}"
            )))
        }
    };

    let canonical_str = match canonical.to_str() {
        Some(text) => text,
        None => {
            return Err(EngineError::ModelLoad(
                "llama_server_bin chemin non-UTF8".into(),
            ))
        }
    };

    let mut under_allowed_prefix = false;
    for prefix in ALLOWED_BIN_PREFIXES {
        if canonical_str.starts_with(*prefix) {
            under_allowed_prefix = true;
            break;
        }
    }

    if !under_allowed_prefix {
        return Err(EngineError::ModelLoad(format!(
            "llama_server_bin hors préfixe autorisé ({canonical_str:?}) — \
             préfixes acceptés : {ALLOWED_BIN_PREFIXES:?} (SP-P0-4)"
        )));
    }
    Ok(canonical)
}
/// Checks that `bin` canonicalizes to a path under an allowed prefix, discarding the
/// resolved path.
///
/// # Errors
/// Propagates the error of `canonicalize_bin_path`.
pub fn validate_bin_path(bin: &Path) -> Result<(), EngineError> {
    let _resolved = canonicalize_bin_path(bin)?;
    Ok(())
}
/// Checks every flag in `extra_args` against `ALLOWED_EXTRA_FLAGS`.
///
/// Tokens are classified as follows:
/// * a token without a leading `-` is a value and is not checked;
/// * a token that is a negative number (`-1`, `-0.5`, …) is also a value, so the
///   space-separated form `--seed -1` is accepted just like `--seed=-1`;
/// * any other dash-prefixed token is a flag; its key is the text before the first
///   `=` and must appear in the allow-list.
///
/// # Errors
/// Returns `EngineError::BadRequest` naming the first flag that is not allow-listed.
pub fn validate_extra_args(extra_args: &[String]) -> Result<(), EngineError> {
    /// True for a `-`-prefixed token whose remainder starts with a digit or `.` and
    /// parses as a number. The leading-character test keeps spellings such as `-inf`
    /// or `-nan` (which `f64` parsing accepts) classified as flags.
    fn is_negative_number(token: &str) -> bool {
        match token.strip_prefix('-') {
            Some(rest) => {
                rest.starts_with(|c: char| c.is_ascii_digit() || c == '.')
                    && rest.parse::<f64>().is_ok()
            }
            None => false,
        }
    }

    for arg in extra_args {
        if !arg.starts_with('-') || is_negative_number(arg) {
            continue;
        }
        // `--flag=value` form: only the part before `=` is the flag name.
        let key = arg.split('=').next().unwrap_or(arg.as_str());
        if !ALLOWED_EXTRA_FLAGS.contains(&key) {
            return Err(EngineError::BadRequest(format!(
                "extra_args : flag '{key}' non autorisé — \
                 seuls les flags de l'allow-list ALLOWED_EXTRA_FLAGS sont acceptés. \
                 Toute extension est une décision de sécurité explicite."
            )));
        }
    }
    Ok(())
}
/// Copies the allow-listed variables of the current process environment onto `cmd`.
///
/// A variable is forwarded when its name is listed in `ENV_PASSTHROUGH` or starts with
/// one of `GPU_ENV_PREFIXES`. Variables whose name is not valid UTF-8 are skipped.
/// This function only adds variables; clearing the inherited environment beforehand
/// is the caller's job.
pub fn inject_allowed_env(cmd: &mut Command) {
    let forwarded = std::env::vars_os().filter(|(name, _)| match name.to_str() {
        None => false,
        Some(name_str) => {
            let has_gpu_prefix = GPU_ENV_PREFIXES
                .iter()
                .any(|prefix| name_str.starts_with(*prefix));
            ENV_PASSTHROUGH.contains(&name_str) || has_gpu_prefix
        }
    });

    for (name, value) in forwarded {
        cmd.env(&name, &value);
    }
}
/// Reports whether `e` was ultimately caused by a refused TCP connection.
///
/// The whole `source()` chain is inspected, not just the first level: HTTP client
/// errors commonly wrap the underlying `std::io::Error` several layers deep, and the
/// outer `Display` text does not necessarily mention the cause. A typed check on
/// `io::ErrorKind::ConnectionRefused` is tried first at each level; the textual
/// needles are kept as a fallback for wrappers that only carry a message.
fn is_connection_refused(e: &reqwest::Error) -> bool {
    const NEEDLES: [&str; 3] = ["connection refused", "connexion refusée", "os error 111"];

    fn mentions_refusal(text: &str) -> bool {
        let lowered = text.to_lowercase();
        NEEDLES.iter().any(|needle| lowered.contains(*needle))
    }

    if mentions_refusal(&e.to_string()) {
        return true;
    }

    let mut cause = e.source();
    while let Some(err) = cause {
        if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
            if io_err.kind() == std::io::ErrorKind::ConnectionRefused {
                return true;
            }
        }
        if mentions_refusal(&err.to_string()) {
            return true;
        }
        cause = err.source();
    }
    false
}
// Unit tests for binary-path validation, the extra-args allow-list, environment
// filtering, port validation, restart-budget arithmetic and child argument building.
//
// NOTE: every test that canonicalizes "/usr/local/bin/llama-server" (directly or through
// `LlamaServerSupervisor::new` with `make_test_config`) needs that file to exist on the
// test host, because canonicalization resolves the path against the real filesystem.
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{ModelKind, RuntimeKind};
// A binary under /usr/local/bin/ is accepted (requires the file to exist on the host).
#[test]
fn validate_bin_path_accepts_usr_local_bin() {
let result = validate_bin_path(Path::new("/usr/local/bin/llama-server"));
assert!(
result.is_ok(),
"llama-server dans /usr/local/bin/ doit être accepté : {result:?}"
);
}
// A path outside the allowed prefixes is rejected (also fails if the file is absent).
#[test]
fn validate_bin_path_rejects_arbitrary_path() {
let result = validate_bin_path(Path::new("/tmp/malicious-llama-server"));
assert!(
result.is_err(),
"binaire dans /tmp doit être rejeté (SP-P0-4)"
);
}
// A missing binary is rejected even when its directory is allowed.
#[test]
fn validate_bin_path_rejects_nonexistent() {
let result = validate_bin_path(Path::new("/usr/local/bin/does-not-exist-engine-test"));
assert!(result.is_err(), "binaire absent doit être rejeté");
}
// `..` components are resolved before the prefix check, so traversal out of the
// allowed directory is rejected.
#[test]
fn validate_bin_path_rejects_sh_injection() {
let result = validate_bin_path(Path::new("/usr/local/bin/../../bin/sh"));
assert!(result.is_err(), "path traversal doit être rejeté (SP-P0-4)");
}
// The canonicalized path returned on success is absolute.
#[test]
fn canonicalize_bin_path_returns_canonical_pathbuf() {
let result = canonicalize_bin_path(Path::new("/usr/local/bin/llama-server"));
assert!(result.is_ok(), "doit retourner un PathBuf canonicalisé");
let p = result.unwrap();
assert!(
p.is_absolute(),
"le PathBuf canonicalisé doit être absolu : {p:?}"
);
}
// --- extra_args allow-list: accepted flags ---
#[test]
fn extra_args_accepts_flash_attn() {
assert!(
validate_extra_args(&["--flash-attn".into()]).is_ok(),
"--flash-attn doit être accepté"
);
}
#[test]
fn extra_args_accepts_log_disable() {
assert!(
validate_extra_args(&["--log-disable".into()]).is_ok(),
"--log-disable doit être accepté"
);
}
#[test]
fn extra_args_accepts_no_mmap() {
assert!(
validate_extra_args(&["--no-mmap".into()]).is_ok(),
"--no-mmap doit être accepté"
);
}
// Space-separated form: the value token has no leading dash and is not checked.
#[test]
fn extra_args_accepts_batch_size_with_value() {
assert!(
validate_extra_args(&["--batch-size".into(), "512".into()]).is_ok(),
"--batch-size 512 doit être accepté"
);
}
// `--flag=value` form: only the part before `=` is matched against the allow-list.
#[test]
fn extra_args_accepts_batch_size_equals_form() {
assert!(
validate_extra_args(&["--batch-size=512".into()]).is_ok(),
"--batch-size=512 doit être accepté"
);
}
// --- extra_args allow-list: rejected flags ---
// The error message must name the offending flag.
#[test]
fn extra_args_rejects_host_override() {
let result = validate_extra_args(&["--host".into(), "0.0.0.0".into()]);
assert!(result.is_err(), "--host doit être rejeté (allow-list)");
assert!(
result.unwrap_err().to_string().contains("--host"),
"message d'erreur doit citer --host"
);
}
#[test]
fn extra_args_rejects_api_key_file() {
let result = validate_extra_args(&["--api-key-file".into(), "/etc/passwd".into()]);
assert!(result.is_err(), "--api-key-file doit être rejeté");
}
#[test]
fn extra_args_rejects_model_url() {
let result =
validate_extra_args(&["--model-url".into(), "http://evil.example/evil.gguf".into()]);
assert!(result.is_err(), "--model-url doit être rejeté");
}
// Flags already driven by dedicated config fields must not be overridable.
#[test]
fn extra_args_rejects_n_gpu_layers() {
let result = validate_extra_args(&["--n-gpu-layers".into(), "35".into()]);
assert!(
result.is_err(),
"--n-gpu-layers doit être rejeté (géré par config gpu_layers)"
);
}
#[test]
fn extra_args_rejects_gpu_layers_alias() {
let result = validate_extra_args(&["--gpu-layers".into(), "35".into()]);
assert!(result.is_err(), "--gpu-layers doit être rejeté");
let result2 = validate_extra_args(&["-ngl".into(), "35".into()]);
assert!(result2.is_err(), "-ngl doit être rejeté");
}
#[test]
fn extra_args_rejects_model_short_flag() {
let result = validate_extra_args(&["-m".into(), "/tmp/evil.gguf".into()]);
assert!(result.is_err(), "-m doit être rejeté");
}
#[test]
fn extra_args_rejects_port_override() {
let result = validate_extra_args(&["--port=9999".into()]);
assert!(result.is_err(), "--port= doit être rejeté");
}
#[test]
fn extra_args_rejects_lora_path() {
let result = validate_extra_args(&["--lora".into(), "/tmp/evil.bin".into()]);
assert!(result.is_err(), "--lora doit être rejeté");
}
#[test]
fn extra_args_empty_always_ok() {
assert!(
validate_extra_args(&[]).is_ok(),
"extra_args vide toujours OK"
);
}
// --- the constructor applies the extra_args allow-list ---
#[test]
fn supervisor_rejects_extra_args_hors_allow_list() {
let mut config = make_test_config();
config.extra_args = vec!["--host".into(), "0.0.0.0".into()];
assert!(
LlamaServerSupervisor::new(config).is_err(),
"superviseur doit rejeter extra_args hors allow-list"
);
}
#[test]
fn supervisor_rejects_n_gpu_layers_in_extra_args() {
let mut config = make_test_config();
config.extra_args = vec!["--n-gpu-layers".into(), "35".into()];
assert!(
LlamaServerSupervisor::new(config).is_err(),
"superviseur doit rejeter --n-gpu-layers en extra_args (doublon config)"
);
}
#[test]
fn supervisor_accepts_extra_args_in_allow_list() {
let mut config = make_test_config();
config.extra_args = vec!["--log-disable".into()];
assert!(
LlamaServerSupervisor::new(config).is_ok(),
"superviseur doit accepter extra_args dans l'allow-list"
);
}
// --- environment filter ---
// These tests check the allow-list constants (and re-evaluate the same predicate as
// `inject_allowed_env`); they do not inspect the environment of a built command.
#[test]
fn inject_allowed_env_preserves_path() {
assert!(
ENV_PASSTHROUGH.contains(&"PATH"),
"PATH doit être dans ENV_PASSTHROUGH"
);
}
#[test]
fn inject_allowed_env_preserves_gpu_prefixes() {
let required_prefixes = ["VK_", "MESA_", "RADV_", "GGML_", "HIP_", "ROCR_"];
for prefix in &required_prefixes {
assert!(
GPU_ENV_PREFIXES.contains(prefix),
"préfixe GPU {prefix} doit être dans GPU_ENV_PREFIXES"
);
}
}
#[test]
fn inject_allowed_env_excludes_llama_arg_host() {
let key = "LLAMA_ARG_HOST";
let in_passthrough = ENV_PASSTHROUGH.contains(&key);
let in_gpu = GPU_ENV_PREFIXES.iter().any(|p| key.starts_with(p));
assert!(
!in_passthrough && !in_gpu,
"LLAMA_ARG_HOST ne doit pas être dans l'allow-list env"
);
}
#[test]
fn inject_allowed_env_excludes_hf_token() {
let key = "HF_TOKEN";
let pass =
ENV_PASSTHROUGH.contains(&key) || GPU_ENV_PREFIXES.iter().any(|p| key.starts_with(p));
assert!(!pass, "HF_TOKEN ne doit pas être dans l'allow-list env");
}
#[test]
fn inject_allowed_env_excludes_llama_api_key() {
let key = "LLAMA_API_KEY";
let pass =
ENV_PASSTHROUGH.contains(&key) || GPU_ENV_PREFIXES.iter().any(|p| key.starts_with(p));
assert!(
!pass,
"LLAMA_API_KEY ne doit pas être dans l'allow-list env"
);
}
// Calls `inject_allowed_env` on a real command as a smoke test; the resulting
// environment of `cmd` is not asserted.
#[test]
fn inject_allowed_env_injects_to_command() {
let test_key = "VK_ICD_FILENAMES";
let would_pass = ENV_PASSTHROUGH.contains(&test_key)
|| GPU_ENV_PREFIXES.iter().any(|p| test_key.starts_with(p));
assert!(
would_pass,
"VK_ICD_FILENAMES (préfixe VK_) doit passer le filtre GPU"
);
let mut cmd = Command::new("/usr/local/bin/llama-server");
cmd.env_clear();
inject_allowed_env(&mut cmd); }
// --- port validation in the constructor ---
#[test]
fn supervisor_rejects_privileged_port() {
let mut config = make_test_config();
config.child_port = 80;
assert!(
LlamaServerSupervisor::new(config).is_err(),
"child_port=80 doit être rejeté"
);
}
#[test]
fn supervisor_rejects_same_port_as_supervisor() {
let mut config = make_test_config();
config.child_port = config.port;
assert!(
LlamaServerSupervisor::new(config).is_err(),
"child_port == port doit être rejeté"
);
}
#[test]
fn supervisor_accepts_valid_port() {
let result = LlamaServerSupervisor::new(make_test_config());
assert!(
result.is_ok(),
"config valide doit être acceptée : {:?}",
result.err()
);
}
// --- restart budget / backoff ---
// The following tests replay the budget and backoff arithmetic on local variables;
// they do not run `supervise_loop` itself.
#[test]
fn restart_budget_exhausted_by_flapping() {
let config = make_test_config();
let max = config.child_restart_max;
let min_stable = config.min_stable_uptime_secs;
let mut budget = max;
let mut backoff = BACKOFF_INIT_MS;
for i in 0..max {
let uptime_stable = false; assert!(!uptime_stable);
assert_eq!(budget, max - i, "budget à l'itération {i}");
assert!(budget > 0);
budget -= 1;
backoff = (backoff * 2).min(BACKOFF_MAX_MS);
}
assert_eq!(budget, 0, "budget épuisé après {max} crashs flapping");
assert!(backoff > BACKOFF_INIT_MS, "backoff escaladé : {backoff}ms");
assert_eq!(min_stable, 30, "défaut min_stable_uptime_secs = 30s");
}
#[test]
fn restart_budget_resets_on_stable_uptime() {
let config = make_test_config();
let max = config.child_restart_max;
let min_stable_secs = config.min_stable_uptime_secs;
let mut budget = max - 1;
let mut backoff = BACKOFF_MAX_MS;
let elapsed = Duration::from_secs(min_stable_secs + 5);
let uptime_stable = elapsed >= Duration::from_secs(min_stable_secs);
assert!(uptime_stable);
if uptime_stable {
budget = max;
backoff = BACKOFF_INIT_MS;
}
assert_eq!(budget, max, "budget remis au max");
assert_eq!(backoff, BACKOFF_INIT_MS, "backoff remis à l'init");
}
// A restart budget of zero is a valid configuration at construction time.
#[test]
fn child_restart_max_zero_means_no_restart() {
let mut config = make_test_config();
config.child_restart_max = 0;
assert!(
LlamaServerSupervisor::new(config).is_ok(),
"child_restart_max=0 doit être accepté à la construction"
);
}
// With a seeded ready instant, a first crash after a long uptime counts as stable.
#[test]
fn initial_ready_at_seed_prevents_false_flapping() {
let config = make_test_config();
let max = config.child_restart_max;
let min_stable_secs = config.min_stable_uptime_secs;
let mut budget = max;
let mut backoff = BACKOFF_MAX_MS;
let simulated_elapsed = Duration::from_secs(min_stable_secs + 5);
let uptime_stable = simulated_elapsed >= Duration::from_secs(min_stable_secs);
assert!(
uptime_stable,
"premier crash après {simulated_elapsed:?} doit être classé stable"
);
if uptime_stable {
budget = max;
backoff = BACKOFF_INIT_MS;
}
assert_eq!(
budget, max,
"budget doit être remis au max après 1er crash stable (pas de faux flapping)"
);
assert_eq!(
backoff, BACKOFF_INIT_MS,
"backoff remis à l'init après uptime stable"
);
}
// Without a seed (`None`), the stable-uptime expression evaluates to false.
#[test]
fn without_seed_first_crash_is_flapping() {
let last_ready_at: Option<Instant> = None; let min_stable_secs = 30_u64;
let uptime_stable = last_ready_at
.map(|t| t.elapsed() >= Duration::from_secs(min_stable_secs))
.unwrap_or(false);
assert!(
!uptime_stable,
"sans seed, premier crash est classé flapping (budget décrémenté sans reset)"
);
}
// Smoke test: `is_connection_refused` must not panic on a real connection error.
// Its return value is deliberately ignored, and nothing is checked if the request
// unexpectedly succeeds.
#[test]
fn connection_refused_detection() {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let client = reqwest::Client::builder()
.timeout(Duration::from_millis(200))
.build()
.unwrap();
let result = client.get("http://127.0.0.1:1/health").send().await;
if let Err(e) = result {
let _ = is_connection_refused(&e);
}
});
}
// --- child argument construction ---
// A chat model gets the fixed flag/value pairs and neither --embedding nor --mmproj.
#[test]
fn build_child_args_base_chat() {
let cfg = EngineConfig::from_toml(
"[engine]\nmodel_path=\"/opt/gradatum/models/m.gguf\"\nmodel_kind=\"chat\"\nport=8080\nchild_port=8090\n",
)
.unwrap();
let args = build_child_args(&cfg);
assert!(args
.windows(2)
.any(|w| w[0] == "--model" && w[1] == "/opt/gradatum/models/m.gguf"));
assert!(args.windows(2).any(|w| w[0] == "--port" && w[1] == "8090"));
assert!(args
.windows(2)
.any(|w| w[0] == "--host" && w[1] == "127.0.0.1"));
assert!(!args.iter().any(|a| a == "--embedding"));
assert!(!args.iter().any(|a| a == "--mmproj"));
}
#[test]
fn build_child_args_embed_adds_embedding_flag() {
let cfg = EngineConfig::from_toml(
"[engine]\nmodel_path=\"/opt/gradatum/models/e.gguf\"\nmodel_kind=\"embed\"\nport=8080\nchild_port=8090\n",
)
.unwrap();
let args = build_child_args(&cfg);
assert!(
args.iter().any(|a| a == "--embedding"),
"embed → --embedding présent"
);
}
#[test]
fn build_child_args_injects_mmproj_when_set() {
let cfg = EngineConfig::from_toml(
"[engine]\nmodel_path=\"/opt/gradatum/models/v.gguf\"\nmodel_kind=\"chat\"\nport=8080\nchild_port=8090\nmmproj_path=\"/opt/gradatum/models/mmproj-F16.gguf\"\n",
)
.unwrap();
let args = build_child_args(&cfg);
assert!(
args.windows(2)
.any(|w| w[0] == "--mmproj" && w[1] == "/opt/gradatum/models/mmproj-F16.gguf"),
"mmproj_path Some → --mmproj <path> injecté"
);
}
// --- allow-list regression guards ---
// Flags from the later part of the allow-list, each with a representative value.
#[test]
fn extra_args_accepts_extended_flags() {
let ok = [
vec!["--swa-full".to_string()],
vec!["--cache-reuse".to_string(), "256".to_string()],
vec!["--reasoning-format".to_string(), "deepseek".to_string()],
vec!["--reasoning-budget".to_string(), "4000".to_string()],
vec!["--temp".to_string(), "0.7".to_string()],
vec!["--top-k".to_string(), "20".to_string()],
vec!["--top-p".to_string(), "0.8".to_string()],
vec!["--min-p".to_string(), "0.0".to_string()],
vec!["--presence-penalty".to_string(), "1.5".to_string()],
vec!["--repeat-penalty".to_string(), "1.1".to_string()],
vec!["--n-predict".to_string(), "512".to_string()],
];
for args in ok {
assert!(
validate_extra_args(&args).is_ok(),
"flag étendu doit être accepté : {args:?}"
);
}
}
// --mmproj is set through the dedicated `mmproj_path` field, never via extra_args.
#[test]
fn extra_args_still_rejects_mmproj() {
let result = validate_extra_args(&["--mmproj".into(), "/etc/passwd".into()]);
assert!(
result.is_err(),
"--mmproj doit rester rejeté (champ config dédié)"
);
}
// Flags that must stay outside the allow-list.
#[test]
fn extra_args_security_frontier_unchanged() {
for flag in [
"--host",
"--api-key-file",
"--model-url",
"--rpc",
"--ssl-key-file",
"--path",
] {
assert!(
validate_extra_args(&[flag.to_string(), "x".to_string()]).is_err(),
"{flag} doit rester rejeté (frontière sécu inchangée)"
);
}
}
// Baseline valid configuration shared by the constructor tests; individual tests
// mutate a single field to exercise one validation rule at a time.
fn make_test_config() -> EngineConfig {
EngineConfig {
model_path: "/opt/gradatum/models/test.gguf".into(),
model_kind: ModelKind::Chat,
runtime: RuntimeKind::LlamaServer,
warm_up: "eager".into(),
gpu_layers: 0,
n_threads: 4,
context_len: 4096,
port: 11435,
bind_addr: None, metrics_port: None,
timeout_secs: 30,
max_tokens: 512,
gradatum_url: "http://127.0.0.1:19090".into(),
llama_server_bin: PathBuf::from("/usr/local/bin/llama-server"),
child_port: 11436,
parallel: 2,
extra_args: vec![],
mmproj_path: None,
startup_timeout_secs: 60,
child_restart_max: 3,
min_stable_uptime_secs: 30,
body_limit_bytes: 32 * 1024 * 1024,
}
}
}