use anyhow::{Context, Result};
use fs2::FileExt;
use std::fs::{File, OpenOptions};
use std::io::Write;
#[cfg(unix)]
use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::{broadcast, mpsc};
use tokio::time::interval;
use tokio_util::codec::Framed;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use crate::config::SqliteConfig;
use crate::daemon::protocol::{JsonRpcCodec, JsonRpcMessage};
use crate::daemon::session::SessionRegistry;
use crate::daemon::types::{JsonRpcRequest, JsonRpcResponse};
use crate::db::{get_database_url, SqliteStore};
use crate::embedding::EmbeddingService;
/// Errors surfaced while starting or running the daemon.
#[derive(Debug, Error)]
pub enum DaemonError {
    /// Another daemon instance holds the PID file lock; carries the PID file path.
    #[error("Daemon already running (PID file locked): {0}")]
    AlreadyRunning(PathBuf),
    /// I/O failure while creating, locking, or writing the PID file.
    #[error("PID file error: {0}")]
    PidFileError(#[from] std::io::Error),
    /// Failure while initializing daemon state (database connection or
    /// embedding service — both bubble up as `anyhow::Error`).
    #[error("Database error: {0}")]
    DatabaseError(#[from] anyhow::Error),
    /// Socket-related setup failure (bind, signal-handler registration).
    #[error("Socket error: {0}")]
    SocketError(String),
}
/// Runtime configuration for the daemon's socket server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Unix domain socket the server listens on.
    pub socket_path: PathBuf,
    /// PID/lock file used for single-instance enforcement.
    pub pid_path: PathBuf,
    /// Database URL/path handed to `SqliteStore::connect`.
    pub database_path: String,
    /// SQLite tuning options.
    pub sqlite_config: SqliteConfig,
    /// How long the server may sit with zero connected clients before it
    /// shuts itself down.
    pub idle_timeout: Duration,
}
impl ServerConfig {
    /// Build the default per-user configuration: socket and PID files under
    /// `/tmp` keyed by the current uid, database URL from the environment,
    /// SQLite options from the environment (falling back to defaults), and a
    /// five-minute idle timeout.
    ///
    /// # Errors
    /// Propagates the error from `get_database_url()`.
    pub fn default_for_user() -> Result<Self> {
        let uid = users::get_current_uid();
        let socket_path = PathBuf::from(format!("/tmp/maproom-{}.sock", uid));
        let pid_path = PathBuf::from(format!("/tmp/maproom-{}.pid", uid));
        let database_path = get_database_url()?;
        let sqlite_config = SqliteConfig::from_env().unwrap_or_default();
        Ok(Self {
            socket_path,
            pid_path,
            database_path,
            sqlite_config,
            idle_timeout: Duration::from_secs(300),
        })
    }
}
/// Shared state handed (via `Arc`) to every client connection handler.
pub struct DaemonState {
    /// SQLite-backed persistence layer.
    pub store: SqliteStore,
    /// Embedding service initialized from the environment.
    pub embedding_service: EmbeddingService,
    /// Registry of connected client sessions; its `active_count` drives the
    /// server's idle-timeout and graceful-shutdown logic.
    pub sessions: Arc<SessionRegistry>,
}
impl DaemonState {
    /// Connect to the SQLite store and bring up the embedding service, then
    /// bundle them together with an empty session registry.
    ///
    /// # Errors
    /// Returns `DaemonError::DatabaseError` (via `anyhow` context) if either
    /// the database connection or the embedding-service init fails.
    pub async fn new(config: &ServerConfig) -> Result<Self, DaemonError> {
        let store = SqliteStore::connect(&config.database_path)
            .await
            .context("Failed to connect to SQLite database")?;
        let embedding_service = EmbeddingService::from_env()
            .await
            .context("Failed to initialize embedding service")?;
        let sessions = Arc::new(SessionRegistry::new());
        Ok(Self {
            store,
            embedding_service,
            sessions,
        })
    }
}
/// RAII guard for the daemon's PID file: holds the open file handle (and the
/// exclusive advisory lock taken in `create`) for the daemon's lifetime, and
/// removes the file on drop.
pub struct PidFileGuard {
    // Path to remove when the guard drops.
    path: PathBuf,
    // Kept open so the exclusive lock stays held until the guard drops.
    _file: File,
}
impl PidFileGuard {
    /// Create (or reclaim) the PID file at `path`, take an exclusive advisory
    /// lock on it, and write the current process id.
    ///
    /// Liveness is determined by the file lock, not by the file's existence.
    /// The previous implementation used `create_new(true)`, so a stale PID
    /// file left behind by a crashed daemon (whose lock the OS had already
    /// released) permanently blocked startup until someone deleted the file
    /// by hand. Opening with `create(true)` and letting `try_lock_exclusive`
    /// arbitrate fixes that: a stale file is silently reclaimed, while a
    /// second live daemon still fails because the lock is held.
    ///
    /// # Errors
    /// - `DaemonError::AlreadyRunning` if another process holds the lock.
    /// - `DaemonError::PidFileError` for any other I/O failure.
    pub fn create(path: &Path) -> Result<Self, DaemonError> {
        let mut options = OpenOptions::new();
        // `create`, not `create_new`: existence alone must not block startup;
        // the advisory lock below is the real mutual exclusion.
        options.write(true).create(true);
        #[cfg(unix)]
        // Owner-only permissions; applies when the file is newly created.
        // (A pre-existing stale file keeps its old mode.)
        options.mode(0o600);
        let mut file = options.open(path).map_err(DaemonError::PidFileError)?;
        // Non-blocking: if another daemon holds the lock, report and bail.
        file.try_lock_exclusive()
            .map_err(|_| DaemonError::AlreadyRunning(path.to_path_buf()))?;
        // We own the lock, so any existing content is stale — truncate and
        // write our own PID.
        file.set_len(0)?;
        let pid = std::process::id();
        writeln!(file, "{}", pid)?;
        file.flush()?;
        info!(pid, path = %path.display(), "PID file created");
        Ok(Self {
            path: path.to_path_buf(),
            // Keep the handle alive so the lock is held for the guard's lifetime.
            _file: file,
        })
    }
}
impl Drop for PidFileGuard {
    /// Best-effort removal of the PID file when the daemon exits; failures
    /// are logged rather than allowed to panic inside drop.
    fn drop(&mut self) {
        match std::fs::remove_file(&self.path) {
            Ok(()) => info!(path = %self.path.display(), "PID file removed"),
            Err(e) => warn!(
                path = %self.path.display(),
                error = %e,
                "Failed to remove PID file"
            ),
        }
    }
}
/// Unix-socket JSON-RPC server with single-instance PID locking, an idle
/// shutdown timer, and broadcast-driven graceful shutdown.
pub struct SocketServer {
    // Immutable runtime configuration (paths, timeouts, SQLite options).
    config: ServerConfig,
    // Shared state cloned into each spawned client handler.
    state: Arc<DaemonState>,
    // Broadcast channel used by `shutdown()` to stop the accept loop in `run()`.
    shutdown_tx: broadcast::Sender<()>,
}
impl SocketServer {
    /// Build a server from `config`, initializing the database connection and
    /// embedding service up front (fails fast if either is unavailable).
    pub async fn new(config: ServerConfig) -> Result<Self, DaemonError> {
        let (shutdown_tx, _) = broadcast::channel(1);
        let state = Arc::new(DaemonState::new(&config).await?);
        Ok(Self {
            config,
            state,
            shutdown_tx,
        })
    }

    /// Run the accept loop until a shutdown signal arrives or the idle
    /// timeout elapses with no connected clients.
    ///
    /// Holds the PID-file guard for the duration of the call and removes the
    /// socket file on exit. Each accepted connection is served on its own
    /// spawned task.
    pub async fn run(&self) -> Result<(), DaemonError> {
        let _pid_guard = PidFileGuard::create(&self.config.pid_path)?;
        // Remove a stale socket left by a previous instance; the PID guard
        // above guarantees no live daemon is still serving on it.
        if self.config.socket_path.exists() {
            std::fs::remove_file(&self.config.socket_path)?;
        }
        let listener = UnixListener::bind(&self.config.socket_path)
            .map_err(|e| DaemonError::SocketError(format!("Failed to bind socket: {}", e)))?;
        #[cfg(unix)]
        {
            // Restrict the socket to the owning user.
            // NOTE(review): there is a brief window between bind() and this
            // chmod where the socket carries default permissions; closing it
            // completely would require adjusting the umask before binding.
            let metadata = std::fs::metadata(&self.config.socket_path)?;
            let mut permissions = metadata.permissions();
            permissions.set_mode(0o600);
            std::fs::set_permissions(&self.config.socket_path, permissions)?;
        }
        info!(
            socket_path = %self.config.socket_path.display(),
            idle_timeout_secs = %self.config.idle_timeout.as_secs(),
            "Socket server listening"
        );
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        // `idle_since` is Some(t) while no clients are connected; the server
        // starts idle so a daemon that never receives a client still exits.
        let mut idle_check = interval(Duration::from_secs(60));
        let mut idle_since: Option<Instant> = Some(Instant::now());
        loop {
            tokio::select! {
                accepted = listener.accept() => {
                    match accepted {
                        Ok((stream, _addr)) => {
                            idle_since = None;
                            let state = self.state.clone();
                            tokio::spawn(async move {
                                if let Err(e) = handle_client(stream, state).await {
                                    error!(error = %e, "Client handler error");
                                }
                            });
                        }
                        Err(e) => {
                            // Previously the `Ok((stream, _addr)) = ...` pattern
                            // silently discarded accept errors; log them so
                            // transient failures (e.g. fd exhaustion) are visible.
                            warn!(error = %e, "Failed to accept connection");
                        }
                    }
                }
                _ = idle_check.tick() => {
                    let active_count = self.state.sessions.active_count();
                    if active_count == 0 {
                        if idle_since.is_none() {
                            // First tick with no clients: start the idle timer.
                            idle_since = Some(Instant::now());
                            debug!("No active clients, idle timer started");
                        } else if let Some(since) = idle_since {
                            let idle_duration = since.elapsed();
                            if idle_duration >= self.config.idle_timeout {
                                info!(
                                    idle_secs = idle_duration.as_secs(),
                                    "Idle timeout reached, shutting down"
                                );
                                break;
                            }
                        }
                    } else {
                        if idle_since.is_some() {
                            debug!(active_count, "Clients connected, idle timer reset");
                        }
                        idle_since = None;
                    }
                }
                _ = shutdown_rx.recv() => {
                    info!("Shutdown signal received");
                    break;
                }
            }
        }
        self.graceful_shutdown().await?;
        if let Err(e) = std::fs::remove_file(&self.config.socket_path) {
            warn!(error = %e, "Failed to remove socket file");
        }
        Ok(())
    }

    /// Wait (up to 30 s, polling every 500 ms) for all client sessions to
    /// finish before the server exits; times out rather than hanging forever.
    async fn graceful_shutdown(&self) -> Result<(), DaemonError> {
        info!("Starting graceful shutdown");
        let shutdown_timeout = Duration::from_secs(30);
        let start = Instant::now();
        loop {
            let active_count = self.state.sessions.active_count();
            if active_count == 0 {
                info!("All sessions completed");
                break;
            }
            if start.elapsed() >= shutdown_timeout {
                warn!(
                    active_count,
                    "Shutdown timeout reached, {} sessions still active", active_count
                );
                break;
            }
            debug!(active_count, "Waiting for sessions to complete");
            tokio::time::sleep(Duration::from_millis(500)).await;
        }
        Ok(())
    }

    /// Ask the accept loop to stop; returns immediately. Safe to call when no
    /// `run()` is listening — the send result is deliberately discarded.
    pub fn shutdown(&self) {
        let _ = self.shutdown_tx.send(());
    }
}
/// Run `server` to completion, translating SIGTERM/SIGINT into a graceful
/// shutdown.
///
/// The server future is pinned and kept alive across the signal: after a
/// signal fires we continue polling `run()` so it can drain active sessions
/// and remove the socket file. The previous version returned straight out of
/// `select!` on the signal arm, which dropped the `run()` future and thereby
/// cancelled graceful shutdown mid-flight.
///
/// # Errors
/// Returns `DaemonError::SocketError` if a signal handler cannot be
/// installed, or whatever error `SocketServer::run` produces.
#[cfg(unix)]
pub async fn run_with_signal_handling(server: SocketServer) -> Result<(), DaemonError> {
    use tokio::signal::unix::{signal, SignalKind};
    let mut sigterm = signal(SignalKind::terminate())
        .map_err(|e| DaemonError::SocketError(format!("Failed to setup SIGTERM handler: {}", e)))?;
    let mut sigint = signal(SignalKind::interrupt())
        .map_err(|e| DaemonError::SocketError(format!("Failed to setup SIGINT handler: {}", e)))?;
    let run_future = server.run();
    tokio::pin!(run_future);
    tokio::select! {
        _ = sigterm.recv() => {
            info!("SIGTERM received, initiating graceful shutdown");
            server.shutdown();
        }
        _ = sigint.recv() => {
            info!("SIGINT received, initiating graceful shutdown");
            server.shutdown();
        }
        // Server exited on its own (idle timeout or error): we are done.
        result = &mut run_future => return result,
    }
    // A signal arrived: keep driving the server future so graceful shutdown
    // (session draining, socket-file cleanup) actually completes.
    run_future.await
}
/// Non-Unix variant: translate Ctrl+C into a graceful shutdown while keeping
/// the server future polled so shutdown completes (the previous version
/// dropped the `run()` future when `select!` returned on the signal arm,
/// cancelling graceful shutdown).
#[cfg(not(unix))]
pub async fn run_with_signal_handling(server: SocketServer) -> Result<(), DaemonError> {
    use tokio::signal;
    let run_future = server.run();
    tokio::pin!(run_future);
    tokio::select! {
        _ = signal::ctrl_c() => {
            info!("Ctrl+C received, initiating graceful shutdown");
            server.shutdown();
        }
        // Server exited on its own: propagate its result.
        result = &mut run_future => return result,
    }
    // Signal arrived: drive the server to completion so shutdown finishes.
    run_future.await
}
/// Per-connection task: frame the stream as JSON-RPC, register a session, and
/// pump messages until the client disconnects or an error occurs.
///
/// Requests are handled on spawned tasks so a slow request does not block the
/// connection; their responses come back through the session's mpsc channel
/// and are written here, the only place that owns the write half. Responses
/// therefore go out in completion order, not request order.
async fn handle_client(stream: UnixStream, state: Arc<DaemonState>) -> Result<()> {
    let mut framed = Framed::new(stream, JsonRpcCodec::new());
    let (response_tx, mut response_rx) = mpsc::unbounded_channel();
    let session_id = state.sessions.register(response_tx);
    // Guard unregisters the session on every exit path (EOF, decode error,
    // send failure), keeping the registry's active_count accurate.
    let _session_guard = SessionGuard {
        registry: state.sessions.clone(),
        session_id,
    };
    use futures::stream::StreamExt;
    use futures::SinkExt;
    loop {
        tokio::select! {
            message = framed.next() => {
                match message {
                    Some(Ok(JsonRpcMessage::Request(req))) => {
                        // Handle concurrently; the response is routed back via
                        // the session registry rather than written directly.
                        let state_clone = state.clone();
                        let sid = session_id;
                        tokio::spawn(async move {
                            let response = handle_request(req, &state_clone).await;
                            if let Err(e) = state_clone.sessions.send_to_session(&sid, response) {
                                warn!(error = %e, "Failed to send response");
                            }
                        });
                    }
                    Some(Ok(JsonRpcMessage::Response(_))) => {
                        // Clients are expected to send only requests.
                        warn!("Unexpected response from client (should be request)");
                    }
                    Some(Err(e)) => {
                        // Framing/JSON error: drop the connection.
                        error!(error = %e, "Failed to decode message");
                        break;
                    }
                    None => {
                        // Clean EOF: client closed the connection.
                        break;
                    }
                }
            }
            response = response_rx.recv() => {
                match response {
                    Some(msg) => {
                        if let Err(e) = framed.send(msg).await {
                            error!(error = %e, "Failed to send response to client");
                            break;
                        }
                    }
                    None => {
                        // All senders dropped — not expected while the registry
                        // holds this session's sender, but exit defensively.
                        break;
                    }
                }
            }
        }
    }
    Ok(())
}
/// RAII guard that unregisters a client session from the registry when the
/// connection handler exits, however it exits.
struct SessionGuard {
    // Registry the session was registered with.
    registry: Arc<SessionRegistry>,
    // Id returned by `SessionRegistry::register`.
    session_id: Uuid,
}
impl Drop for SessionGuard {
    // Unregister on drop so the registry's active count decreases even when
    // the connection handler exits via an error path.
    fn drop(&mut self) {
        self.registry.unregister(&self.session_id);
    }
}
/// Placeholder dispatcher: acknowledges the request by echoing its method
/// name in a successful JSON-RPC response. `_state` is unused until real
/// method routing is implemented. A request without an id gets a null id.
async fn handle_request(req: JsonRpcRequest, _state: &DaemonState) -> JsonRpcMessage {
    let payload = serde_json::json!({
        "method": req.method,
        "received": true
    });
    let id = req.id.unwrap_or(serde_json::Value::Null);
    JsonRpcMessage::Response(JsonRpcResponse {
        jsonrpc: "2.0".into(),
        result: Some(payload),
        error: None,
        id,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Fresh temp dir plus a PID-file path inside it; the dir must be kept
    /// alive by the caller or the path disappears.
    fn temp_pid_path() -> (TempDir, PathBuf) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.pid");
        (dir, path)
    }

    #[test]
    fn test_pid_file_creation() {
        let (_dir, pid_path) = temp_pid_path();
        let guard = PidFileGuard::create(&pid_path).unwrap();
        assert!(pid_path.exists());
        let content = std::fs::read_to_string(&pid_path).unwrap();
        let pid: u32 = content.trim().parse().unwrap();
        assert_eq!(pid, std::process::id());
        // Dropping the guard must remove the PID file.
        drop(guard);
        assert!(!pid_path.exists());
    }

    #[test]
    fn test_pid_file_prevents_second_daemon() {
        let (_dir, pid_path) = temp_pid_path();
        let _guard1 = PidFileGuard::create(&pid_path).unwrap();
        let result = PidFileGuard::create(&pid_path);
        assert!(matches!(result, Err(DaemonError::AlreadyRunning(_))));
    }

    // `PermissionsExt::mode` exists only on Unix (and the import at the top of
    // the file is cfg(unix)-gated), so this test must be gated too or the test
    // build fails on other targets.
    #[cfg(unix)]
    #[test]
    fn test_pid_file_permissions() {
        let (_dir, pid_path) = temp_pid_path();
        let _guard = PidFileGuard::create(&pid_path).unwrap();
        let metadata = std::fs::metadata(&pid_path).unwrap();
        let mode = metadata.permissions().mode();
        assert_eq!(mode & 0o777, 0o600);
    }

    #[tokio::test]
    #[ignore] // requires a real database / embedding environment
    async fn test_multiple_clients_concurrent() {
        use tokio::net::UnixStream;
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test.sock");
        let pid_path = temp_dir.path().join("test.pid");
        let db_path = temp_dir.path().join("test.db");
        let config = ServerConfig {
            socket_path: socket_path.clone(),
            pid_path,
            database_path: format!("sqlite://{}", db_path.display()),
            sqlite_config: SqliteConfig::default(),
            idle_timeout: Duration::from_secs(300),
        };
        let server = SocketServer::new(config).await.unwrap();
        let server_handle = {
            let server = Arc::new(server);
            let server_clone = server.clone();
            tokio::spawn(async move { server_clone.run().await })
        };
        tokio::time::sleep(Duration::from_millis(100)).await;
        // Five clients, each sending one request with a distinct id.
        let mut client_handles = vec![];
        for i in 0..5 {
            let socket_path = socket_path.clone();
            let handle = tokio::spawn(async move {
                let stream = UnixStream::connect(&socket_path).await.unwrap();
                let mut framed = Framed::new(stream, JsonRpcCodec::new());
                let request = JsonRpcMessage::Request(JsonRpcRequest {
                    jsonrpc: "2.0".into(),
                    method: format!("test_{}", i),
                    params: None,
                    id: Some(serde_json::json!(i)),
                });
                use futures::SinkExt;
                framed.send(request).await.unwrap();
                use futures::StreamExt;
                let response = framed.next().await.unwrap().unwrap();
                match response {
                    JsonRpcMessage::Response(resp) => {
                        assert_eq!(resp.id, serde_json::json!(i));
                        assert!(resp.result.is_some());
                    }
                    _ => panic!("Expected response"),
                }
            });
            client_handles.push(handle);
        }
        for handle in client_handles {
            handle.await.unwrap();
        }
        server_handle.abort();
    }

    #[tokio::test]
    async fn test_server_config_default_for_user() {
        let config = ServerConfig::default_for_user().unwrap();
        let uid = users::get_current_uid();
        assert_eq!(
            config.socket_path,
            PathBuf::from(format!("/tmp/maproom-{}.sock", uid))
        );
        assert_eq!(
            config.pid_path,
            PathBuf::from(format!("/tmp/maproom-{}.pid", uid))
        );
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
    }

    #[tokio::test]
    #[ignore] // requires a real database / embedding environment
    async fn test_session_cleanup_on_disconnect() {
        use tokio::net::UnixStream;
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test.sock");
        let pid_path = temp_dir.path().join("test.pid");
        let db_path = temp_dir.path().join("test.db");
        let config = ServerConfig {
            socket_path: socket_path.clone(),
            pid_path,
            database_path: format!("sqlite://{}", db_path.display()),
            sqlite_config: SqliteConfig::default(),
            idle_timeout: Duration::from_secs(300),
        };
        let server = Arc::new(SocketServer::new(config).await.unwrap());
        let server_clone = server.clone();
        let _server_handle = tokio::spawn(async move { server_clone.run().await });
        tokio::time::sleep(Duration::from_millis(100)).await;
        let stream = UnixStream::connect(&socket_path).await.unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(server.state.sessions.active_count(), 1);
        // Disconnecting must unregister the session (via SessionGuard).
        drop(stream);
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(server.state.sessions.active_count(), 0);
    }

    #[tokio::test]
    #[ignore] // requires a real database / embedding environment
    async fn test_idle_timeout_triggers() {
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test.sock");
        let pid_path = temp_dir.path().join("test.pid");
        let db_path = temp_dir.path().join("test.db");
        let config = ServerConfig {
            socket_path,
            pid_path,
            database_path: format!("sqlite://{}", db_path.display()),
            sqlite_config: SqliteConfig::default(),
            idle_timeout: Duration::from_millis(100),
        };
        let server = SocketServer::new(config).await.unwrap();
        let start_time = std::time::Instant::now();
        let handle = tokio::spawn(async move { server.run().await });
        tokio::time::sleep(Duration::from_millis(200)).await;
        // Idle checks run on a 60 s interval, so even with a 100 ms timeout the
        // server cannot have shut down this quickly.
        assert!(
            !handle.is_finished(),
            "Server should still be running (idle checks are every 60s)"
        );
        handle.abort();
        let elapsed = start_time.elapsed();
        assert!(
            elapsed < Duration::from_secs(1),
            "Test should complete quickly, got {:?}",
            elapsed
        );
    }

    #[tokio::test]
    #[ignore] // requires a real database / embedding environment
    async fn test_active_client_prevents_idle_timeout() {
        use tokio::net::UnixStream;
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test.sock");
        let pid_path = temp_dir.path().join("test.pid");
        let db_path = temp_dir.path().join("test.db");
        let config = ServerConfig {
            socket_path: socket_path.clone(),
            pid_path,
            database_path: format!("sqlite://{}", db_path.display()),
            sqlite_config: SqliteConfig::default(),
            idle_timeout: Duration::from_secs(2),
        };
        let server = Arc::new(SocketServer::new(config).await.unwrap());
        let server_clone = server.clone();
        let handle = tokio::spawn(async move { server_clone.run().await });
        tokio::time::sleep(Duration::from_millis(100)).await;
        let _stream = UnixStream::connect(&socket_path).await.unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(server.state.sessions.active_count(), 1);
        // Wait past the idle timeout with a client still connected.
        tokio::time::sleep(Duration::from_secs(3)).await;
        assert!(
            !handle.is_finished(),
            "Server should still be running with active client"
        );
        server.shutdown();
        let _ = tokio::time::timeout(Duration::from_secs(1), handle).await;
    }

    #[tokio::test]
    #[ignore] // requires a real database / embedding environment
    async fn test_graceful_shutdown_waits_for_sessions() {
        use tokio::net::UnixStream;
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test.sock");
        let pid_path = temp_dir.path().join("test.pid");
        let db_path = temp_dir.path().join("test.db");
        let config = ServerConfig {
            socket_path: socket_path.clone(),
            pid_path,
            database_path: format!("sqlite://{}", db_path.display()),
            sqlite_config: SqliteConfig::default(),
            idle_timeout: Duration::from_secs(300),
        };
        let server = Arc::new(SocketServer::new(config).await.unwrap());
        let server_clone = server.clone();
        let handle = tokio::spawn(async move { server_clone.run().await });
        tokio::time::sleep(Duration::from_millis(100)).await;
        let stream = UnixStream::connect(&socket_path).await.unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(server.state.sessions.active_count(), 1);
        let shutdown_start = std::time::Instant::now();
        server.shutdown();
        // Disconnect the client 200 ms after shutdown begins.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(stream);
        });
        let result = tokio::time::timeout(Duration::from_secs(5), handle).await;
        assert!(result.is_ok(), "Server should complete graceful shutdown");
        let shutdown_duration = shutdown_start.elapsed();
        assert!(
            shutdown_duration >= Duration::from_millis(200)
                && shutdown_duration < Duration::from_secs(2),
            "Graceful shutdown should wait for client disconnect, got {:?}",
            shutdown_duration
        );
    }
}