use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
mod builder;
mod config;
mod dispatch;
mod generation;
mod handlers;
mod prompt;
pub use builder::DaemonBuilder;
pub use config::{DaemonConfig, EvictionPolicy, HttpConfig, ModelDefaultsConfig, ResourceConfig};
pub(crate) use prompt::resolve_chat_stop_sequences;
use super::models::ModelManager;
use super::protocol::*;
use super::store::{DaemonStore, StorageBackend};
use crate::memory_monitor::{MemoryMonitor, MemoryPressure, RecoveryManager};
/// Shared daemon state: configuration, model manager, request bookkeeping,
/// optional memory monitoring, storage backend and model providers.
///
/// Built by [`Daemon::new`]. Field order matters: struct fields are dropped in
/// declaration order, so reordering them changes teardown order.
pub struct Daemon {
    /// Configuration the daemon was constructed with (read e.g. for resource limits).
    pub config: DaemonConfig,
    /// Model manager, shared via `Arc`.
    pub models: Arc<ModelManager>,
    /// Instant captured when the daemon was constructed (presumably for uptime reporting).
    pub start_time: Instant,
    /// Shutdown flag; starts `false` and is read by [`Daemon::is_shutdown`].
    pub shutdown: Arc<AtomicBool>,
    /// Count of in-flight requests; starts at 0.
    /// NOTE(review): incremented/decremented outside this file — confirm.
    pub active_requests: Arc<AtomicU32>,
    /// Running request counter; starts at 0. Unlike `active_requests` it is not `Arc`-wrapped.
    pub total_requests: AtomicU64,
    /// Request id -> cancellation flag. Entries are added by `register_cancellation`
    /// and flagged by [`Daemon::cancel_request`].
    pub cancellations: Arc<DashMap<String, Arc<AtomicBool>>>,
    /// Present only when `config.resources.enable_memory_monitoring` was set at construction.
    pub memory_monitor: Option<Arc<MemoryMonitor>>,
    /// Recovery manager, attached to `memory_monitor` when one exists.
    pub recovery_manager: RecoveryManager,
    /// Storage backend: the default persistent store, or a temporary-directory
    /// store if opening the default one failed.
    pub store: Arc<dyn StorageBackend>,
    /// Model providers, tried in order by [`Daemon::resolve_model_spec`].
    pub providers: Vec<Box<dyn super::provider::ModelProvider>>,
}
impl Daemon {
    /// Creates a daemon from `config`.
    ///
    /// - Starts a [`MemoryMonitor`] only when `config.resources.enable_memory_monitoring`
    ///   is set, and attaches it to the [`RecoveryManager`].
    /// - Opens the default persistent store. If that fails, it falls back to a store
    ///   inside a freshly created temporary directory. That directory is kept on disk
    ///   (it is not cleaned up on exit) and its location is printed to stderr.
    /// - Registers the Ollama and Hugging Face providers whose constructors succeed,
    ///   in that order. This is also the order [`Daemon::resolve_model_spec`] tries them.
    ///
    /// # Panics
    ///
    /// Panics if the default store cannot be opened and the temporary fallback
    /// directory or store cannot be created either.
    pub fn new(config: DaemonConfig) -> Self {
        let memory_monitor = if config.resources.enable_memory_monitoring {
            let monitor = MemoryMonitor::new(config.resources.memory_config.clone());
            monitor.start();
            Some(monitor)
        } else {
            None
        };
        let recovery_manager = if let Some(ref monitor) = memory_monitor {
            RecoveryManager::new().with_monitor(Arc::clone(monitor))
        } else {
            RecoveryManager::new()
        };
        let store: Arc<dyn StorageBackend> = match DaemonStore::open_default() {
            Ok(s) => Arc::new(s),
            Err(e) => {
                // Report the original failure first so it is visible even if the
                // fallback below panics.
                eprintln!(
                    "Warning: Failed to open persistent store: {}. Using temporary fallback store.",
                    e
                );
                let tmp = tempfile::tempdir().expect("Failed to create temp dir");
                // Detach the directory from its `TempDir` guard. Dropping the guard at
                // the end of this arm would delete the directory, database file
                // included, while the store is still using it.
                #[allow(deprecated)] // newer tempfile releases rename `into_path` to `keep`
                let dir = tmp.into_path();
                let db_path = dir.join("mullama.db");
                eprintln!("Warning: Temporary store location: {}", db_path.display());
                Arc::new(DaemonStore::open(&db_path).expect("Failed to create temp store"))
            }
        };
        let mut providers: Vec<Box<dyn super::provider::ModelProvider>> = Vec::new();
        if let Ok(ollama) = super::ollama::OllamaClient::new() {
            providers.push(Box::new(ollama));
        }
        if let Ok(hf) = crate::hf::HfDownloader::new() {
            providers.push(Box::new(hf));
        }
        Self {
            config,
            models: Arc::new(ModelManager::new()),
            start_time: Instant::now(),
            shutdown: Arc::new(AtomicBool::new(false)),
            active_requests: Arc::new(AtomicU32::new(0)),
            total_requests: AtomicU64::new(0),
            cancellations: Arc::new(DashMap::new()),
            memory_monitor,
            recovery_manager,
            store,
            providers,
        }
    }

    /// Checks `max_tokens` against the configured per-request limit.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidRequest` error response when `max_tokens` is 0 or exceeds
    /// `config.resources.max_tokens_per_request`.
    // The error is a ready-made protocol `Response`. Presumably callers send it back
    // as-is, so it is not boxed.
    #[allow(clippy::result_large_err)]
    fn validate_max_tokens(&self, max_tokens: u32) -> Result<(), Response> {
        if max_tokens == 0 {
            return Err(Response::error(
                ErrorCode::InvalidRequest,
                "max_tokens must be greater than 0",
            ));
        }
        if max_tokens > self.config.resources.max_tokens_per_request {
            return Err(Response::error(
                ErrorCode::InvalidRequest,
                format!(
                    "max_tokens {} exceeds server limit {}",
                    max_tokens, self.config.resources.max_tokens_per_request
                ),
            ));
        }
        Ok(())
    }

    /// Registers a fresh (unset) cancellation flag for `request_id` and returns it.
    ///
    /// Any flag already registered under the same id is replaced.
    /// NOTE(review): entries are never removed here. Confirm that callers remove them
    /// once the request finishes; otherwise the map grows without bound.
    fn register_cancellation(&self, request_id: &str) -> Arc<AtomicBool> {
        let flag = Arc::new(AtomicBool::new(false));
        self.cancellations
            .insert(request_id.to_string(), Arc::clone(&flag));
        flag
    }

    /// Sets the cancellation flag of `request_id`.
    ///
    /// Returns `true` if a flag was registered for that id, `false` otherwise.
    pub fn cancel_request(&self, request_id: &str) -> bool {
        if let Some(flag) = self.cancellations.get(request_id) {
            flag.store(true, Ordering::SeqCst);
            true
        } else {
            false
        }
    }

    /// Resolves `spec` using the first provider that both supports it and resolves it
    /// successfully. Providers are tried in registration order.
    ///
    /// # Errors
    ///
    /// Returns `OperationFailed` when no provider resolves `spec`. If some providers
    /// supported the spec but failed, their names and errors are appended to the
    /// message, so "unsupported" can be told apart from "tried and failed".
    pub async fn resolve_model_spec(
        &self,
        spec: &str,
    ) -> Result<super::provider::ResolvedModelPath, crate::MullamaError> {
        let mut failures: Vec<String> = Vec::new();
        for provider in &self.providers {
            if !provider.supports(spec) {
                continue;
            }
            match provider.resolve(spec).await {
                Ok(resolved) => return Ok(resolved),
                Err(e) => {
                    tracing::warn!(
                        "Provider '{}' failed for '{}': {}",
                        provider.name(),
                        spec,
                        e
                    );
                    failures.push(format!("{}: {}", provider.name(), e));
                }
            }
        }
        let message = if failures.is_empty() {
            format!("No provider can resolve: {}", spec)
        } else {
            format!("No provider can resolve: {} ({})", spec, failures.join("; "))
        };
        Err(crate::MullamaError::OperationFailed(message))
    }

    /// Current memory pressure, or `MemoryPressure::Normal` when monitoring is disabled.
    pub fn memory_pressure(&self) -> MemoryPressure {
        self.memory_monitor
            .as_ref()
            .map(|m| m.pressure())
            .unwrap_or(MemoryPressure::Normal)
    }

    /// Latest memory statistics, or `None` when monitoring is disabled.
    pub fn memory_stats(&self) -> Option<crate::memory_monitor::MemoryStats> {
        self.memory_monitor.as_ref().map(|m| m.stats())
    }

    /// Whether the recovery manager reports that memory recovery is needed.
    pub fn needs_memory_recovery(&self) -> bool {
        self.recovery_manager.needs_recovery()
    }

    /// Logs current GPU and system usage (as percentages) when pressure is above
    /// `Normal`. This is a no-op when monitoring is disabled.
    // Kept as a diagnostic helper even though nothing calls it at the moment.
    #[allow(dead_code)]
    fn log_memory_pressure(&self) {
        if let Some(monitor) = &self.memory_monitor {
            let pressure = monitor.pressure();
            let stats = monitor.stats();
            match pressure {
                MemoryPressure::Warning => {
                    tracing::warn!(
                        gpu_usage = stats.gpu_usage() * 100.0,
                        system_usage = stats.system_usage() * 100.0,
                        "Memory pressure elevated"
                    );
                }
                MemoryPressure::Critical => {
                    tracing::error!(
                        gpu_usage = stats.gpu_usage() * 100.0,
                        system_usage = stats.system_usage() * 100.0,
                        "Memory pressure CRITICAL"
                    );
                }
                MemoryPressure::Emergency => {
                    tracing::error!(
                        gpu_usage = stats.gpu_usage() * 100.0,
                        system_usage = stats.system_usage() * 100.0,
                        "Memory EMERGENCY - recovery needed"
                    );
                }
                MemoryPressure::Normal => {}
            }
        }
    }

    /// Whether shutdown has been requested through the shared `shutdown` flag.
    pub fn is_shutdown(&self) -> bool {
        self.shutdown.load(Ordering::SeqCst)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Daemon with memory monitoring disabled, so no monitor is started during tests.
    fn test_daemon() -> Daemon {
        let config = DaemonConfig {
            resources: ResourceConfig {
                enable_memory_monitoring: false,
                ..ResourceConfig::default()
            },
            ..DaemonConfig::default()
        };
        Daemon::new(config)
    }

    /// Plain chat message with the given role and text and no name or tool data.
    fn message(role: &str, text: &str) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: text.to_string().into(),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn merge_stop_sequences_deduplicates_and_filters_empty() {
        let merged = super::prompt::merge_stop_sequences(
            vec!["</s>".to_string(), "".to_string()],
            vec!["<|eot_id|>".to_string(), "</s>".to_string()],
        );
        assert_eq!(merged, vec!["</s>", "<|eot_id|>"]);
    }

    #[test]
    fn find_stop_in_recent_window_detects_cross_token_boundary() {
        let generated = "hello<|eot_id|>";
        let previous_len = "hello<|eo".len();
        let stop_sequences = vec!["<|eot_id|>".to_string()];
        let pos =
            super::prompt::find_stop_in_recent_window(generated, previous_len, &stop_sequences, 10);
        assert_eq!(pos, Some("hello".len()));
    }

    #[test]
    fn apply_default_system_prompt_only_when_missing() {
        let daemon = test_daemon();
        let messages = vec![message("user", "hello")];

        let with_system =
            daemon.apply_default_system_prompt(messages.clone(), Some("You are helpful."));
        assert_eq!(with_system.len(), 2);
        assert_eq!(with_system[0].role, "system");
        assert_eq!(with_system[0].content.text(), "You are helpful.");
        // The original user message must follow the injected prompt unchanged.
        assert_eq!(with_system[1].role, "user");
        assert_eq!(with_system[1].content.text(), "hello");

        let with_existing = vec![message("system", "existing"), messages[0].clone()];
        let unchanged = daemon.apply_default_system_prompt(with_existing.clone(), Some("ignored"));
        assert_eq!(unchanged.len(), with_existing.len());
        assert_eq!(unchanged[0].content.text(), "existing");
        assert_eq!(unchanged[1].role, "user");
    }

    #[test]
    fn validate_max_tokens_enforces_bounds() {
        let daemon = test_daemon();
        let limit = daemon.config.resources.max_tokens_per_request;
        assert!(daemon.validate_max_tokens(0).is_err());
        if limit > 0 {
            // The limit itself is inclusive.
            assert!(daemon.validate_max_tokens(limit).is_ok());
        }
        if let Some(over) = limit.checked_add(1) {
            assert!(daemon.validate_max_tokens(over).is_err());
        }
    }

    #[test]
    fn cancel_request_only_flags_registered_requests() {
        let daemon = test_daemon();
        assert!(!daemon.cancel_request("missing"));

        let flag = daemon.register_cancellation("req-1");
        assert!(!flag.load(std::sync::atomic::Ordering::SeqCst));
        assert!(daemon.cancel_request("req-1"));
        assert!(flag.load(std::sync::atomic::Ordering::SeqCst));
    }

    #[test]
    fn defaults_without_memory_monitoring() {
        let daemon = test_daemon();
        assert!(!daemon.is_shutdown());
        assert!(daemon.memory_monitor.is_none());
        assert!(daemon.memory_stats().is_none());
        assert!(matches!(daemon.memory_pressure(), MemoryPressure::Normal));
    }
}