use std::io::{Read, Write};
use std::os::unix::net::UnixStream;
use std::path::PathBuf;
use std::process::{Child, Command, Stdio};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use fs2::FileExt;
use parking_lot::Mutex;
use tracing::{debug, info, warn};
use super::daemon_spawn_guard_lock_path;
use super::protocol::{
EmbeddingJobInfo, ErrorCode, FramedMessage, HealthStatus, PROTOCOL_VERSION, Request, Response,
decode_message, default_socket_path, encode_message,
};
use super::worker::EmbeddingJobConfig;
use crate::search::daemon_client::{DaemonClient, DaemonError};
fn connection_not_established() -> DaemonError {
DaemonError::Unavailable("connection not established".to_string())
}
fn unexpected_response(response: Response) -> DaemonError {
DaemonError::Failed(format!("unexpected response: {response:?}"))
}
#[derive(Debug, Clone)]
pub struct DaemonClientConfig {
pub socket_path: PathBuf,
pub connect_timeout: Duration,
pub request_timeout: Duration,
pub auto_spawn: bool,
pub daemon_binary: Option<PathBuf>,
}
impl Default for DaemonClientConfig {
fn default() -> Self {
Self {
socket_path: default_socket_path(),
connect_timeout: Duration::from_secs(2),
request_timeout: Duration::from_secs(30),
auto_spawn: true,
daemon_binary: None, }
}
}
impl DaemonClientConfig {
pub fn from_env() -> Self {
let mut cfg = Self::default();
if let Ok(path) = dotenvy::var("CASS_DAEMON_SOCKET") {
cfg.socket_path = PathBuf::from(path);
}
if let Ok(val) = dotenvy::var("CASS_DAEMON_CONNECT_TIMEOUT_MS")
&& let Ok(ms) = val.parse::<u64>()
{
cfg.connect_timeout = Duration::from_millis(ms);
}
if let Ok(val) = dotenvy::var("CASS_DAEMON_REQUEST_TIMEOUT_MS")
&& let Ok(ms) = val.parse::<u64>()
{
cfg.request_timeout = Duration::from_millis(ms);
}
if let Ok(val) = dotenvy::var("CASS_DAEMON_AUTO_SPAWN") {
cfg.auto_spawn = val.eq_ignore_ascii_case("true") || val == "1";
}
if let Ok(path) = dotenvy::var("CASS_DAEMON_BINARY") {
cfg.daemon_binary = Some(PathBuf::from(path));
}
cfg
}
}
pub struct UdsDaemonClient {
config: DaemonClientConfig,
connection: Mutex<Option<UnixStream>>,
available: AtomicBool,
request_counter: AtomicU64,
last_health_check: Mutex<Option<Instant>>,
}
impl UdsDaemonClient {
pub fn new(config: DaemonClientConfig) -> Self {
Self {
config,
connection: Mutex::new(None),
available: AtomicBool::new(false),
request_counter: AtomicU64::new(0),
last_health_check: Mutex::new(None),
}
}
pub fn with_defaults() -> Self {
Self::new(DaemonClientConfig::from_env())
}
pub fn connect(&self) -> Result<(), DaemonError> {
if let Ok(stream) = self.try_connect() {
*self.connection.lock() = Some(stream);
self.available.store(true, Ordering::SeqCst);
debug!(socket = %self.config.socket_path.display(), "Connected to existing daemon");
return Ok(());
}
if self.config.auto_spawn {
info!("Daemon not running, attempting to spawn");
self.spawn_daemon()?;
for attempt in 0..10 {
std::thread::sleep(Duration::from_millis(100 * (attempt + 1)));
if let Ok(stream) = self.try_connect() {
*self.connection.lock() = Some(stream);
self.available.store(true, Ordering::SeqCst);
info!(
socket = %self.config.socket_path.display(),
attempts = attempt + 1,
"Connected to newly spawned daemon"
);
return Ok(());
}
}
return Err(DaemonError::Unavailable(
"daemon failed to start within timeout".to_string(),
));
}
Err(DaemonError::Unavailable(format!(
"daemon not running at {}",
self.config.socket_path.display()
)))
}
fn try_connect(&self) -> std::io::Result<UnixStream> {
let stream = UnixStream::connect(&self.config.socket_path)?;
stream.set_read_timeout(Some(self.config.request_timeout))?;
stream.set_write_timeout(Some(self.config.request_timeout))?;
Ok(stream)
}
fn spawn_daemon(&self) -> Result<(), DaemonError> {
let binary = self
.config
.daemon_binary
.clone()
.or_else(|| std::env::current_exe().ok())
.ok_or_else(|| {
DaemonError::Unavailable("cannot determine daemon binary path".to_string())
})?;
let lock_path = daemon_spawn_guard_lock_path(&self.config.socket_path);
let lock_file = match std::fs::OpenOptions::new()
.read(true)
.write(true)
.create_new(true)
.open(&lock_path)
{
Ok(file) => file,
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
if std::fs::symlink_metadata(&lock_path)
.map(|m| m.file_type().is_symlink())
.unwrap_or(false)
{
return Err(DaemonError::Unavailable(
"refusing to open a symlink spawn lock".to_string(),
));
}
std::fs::OpenOptions::new()
.read(true)
.write(true)
.open(&lock_path)
.map_err(|e| {
DaemonError::Unavailable(format!("failed to open spawn lock: {}", e))
})?
}
Err(e) => {
return Err(DaemonError::Unavailable(format!(
"failed to create spawn lock: {}",
e
)));
}
};
lock_file.lock_exclusive().map_err(|e| {
DaemonError::Unavailable(format!("failed to acquire spawn lock: {}", e))
})?;
if UnixStream::connect(&self.config.socket_path).is_ok() {
debug!("Daemon already running, skipping spawn");
return Ok(());
}
remove_stale_daemon_socket(&self.config.socket_path)?;
let result = Command::new(&binary)
.arg("daemon")
.arg("--socket")
.arg(&self.config.socket_path)
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn();
match result {
Ok(mut child) => {
info!(
pid = child.id(),
binary = %binary.display(),
socket = %self.config.socket_path.display(),
"Spawned daemon process"
);
self.wait_for_spawned_daemon_ready(&mut child)?;
std::thread::spawn(move || {
let _ = child.wait();
});
Ok(())
}
Err(e) => Err(DaemonError::Unavailable(format!(
"failed to spawn daemon: {}",
e
))),
}
}
fn wait_for_spawned_daemon_ready(&self, child: &mut Child) -> Result<(), DaemonError> {
let ready_timeout = self.config.connect_timeout.max(Duration::from_secs(5));
let started = Instant::now();
while started.elapsed() < ready_timeout {
if UnixStream::connect(&self.config.socket_path).is_ok() {
return Ok(());
}
match child.try_wait() {
Ok(Some(status)) => {
return Err(DaemonError::Unavailable(format!(
"spawned daemon exited before becoming ready: {}",
status
)));
}
Ok(None) => {}
Err(error) => {
warn!(
error = %error,
socket = %self.config.socket_path.display(),
"failed to poll spawned daemon status while waiting for readiness"
);
break;
}
}
std::thread::sleep(Duration::from_millis(50));
}
Ok(())
}
fn get_connection_locked(
&self,
) -> Result<parking_lot::MutexGuard<'_, Option<UnixStream>>, DaemonError> {
let conn = self.connection.lock();
let is_valid = conn.as_ref().is_some_and(|s| s.peer_addr().is_ok());
if is_valid {
return Ok(conn);
}
drop(conn);
self.available.store(false, Ordering::SeqCst);
self.connect()?;
let conn = self.connection.lock();
if conn.is_some() {
Ok(conn)
} else {
Err(connection_not_established())
}
}
fn send_request(&self, request: Request) -> Result<Response, DaemonError> {
let request_id = format!(
"cass-{}",
self.request_counter.fetch_add(1, Ordering::Relaxed)
);
let msg = FramedMessage::new(&request_id, request);
let encoded = encode_message(&msg)
.map_err(|e| DaemonError::Failed(format!("failed to encode request: {}", e)))?;
let mut stream_guard = self.get_connection_locked()?;
let stream = stream_guard
.as_mut()
.ok_or_else(connection_not_established)?;
if let Err(e) = stream.write_all(&encoded) {
*stream_guard = None;
self.available.store(false, Ordering::SeqCst);
return Err(DaemonError::Unavailable(format!(
"failed to send request: {}",
e
)));
}
let mut len_buf = [0u8; 4];
if let Err(e) = stream.read_exact(&mut len_buf) {
*stream_guard = None;
self.available.store(false, Ordering::SeqCst);
if e.kind() == std::io::ErrorKind::TimedOut {
return Err(DaemonError::Timeout("response timeout".to_string()));
} else {
return Err(DaemonError::Unavailable(format!(
"failed to read response length: {}",
e
)));
}
}
let len = u32::from_be_bytes(len_buf) as usize;
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024;
if len > MAX_RESPONSE_SIZE {
*stream_guard = None;
warn!(
response_size = len,
max_size = MAX_RESPONSE_SIZE,
"Rejecting oversized daemon response"
);
return Err(DaemonError::Failed(format!(
"response too large: {} bytes (max {})",
len, MAX_RESPONSE_SIZE
)));
}
let mut payload = vec![0u8; len];
if let Err(e) = stream.read_exact(&mut payload) {
*stream_guard = None;
self.available.store(false, Ordering::SeqCst);
if e.kind() == std::io::ErrorKind::TimedOut {
return Err(DaemonError::Timeout("response timeout".to_string()));
} else {
return Err(DaemonError::Unavailable(format!(
"failed to read response: {}",
e
)));
}
}
drop(stream_guard);
let response: FramedMessage<Response> = decode_message(&payload)
.map_err(|e| DaemonError::Failed(format!("failed to decode response: {}", e)))?;
if response.version != PROTOCOL_VERSION {
return Err(DaemonError::Failed(format!(
"protocol version mismatch: expected {}, got {}",
PROTOCOL_VERSION, response.version
)));
}
match response.payload {
Response::Error(err) => {
let daemon_err = match err.code {
ErrorCode::Overloaded => DaemonError::Overloaded {
retry_after: err.retry_after_ms.map(Duration::from_millis),
message: err.message,
},
ErrorCode::Timeout => DaemonError::Timeout(err.message),
ErrorCode::InvalidInput => DaemonError::InvalidInput(err.message),
_ => DaemonError::Failed(err.message),
};
Err(daemon_err)
}
other => Ok(other),
}
}
pub fn health(&self) -> Result<HealthStatus, DaemonError> {
match self.send_request(Request::Health)? {
Response::Health(status) => {
*self.last_health_check.lock() = Some(Instant::now());
Ok(status)
}
other => Err(unexpected_response(other)),
}
}
pub fn shutdown(&self) -> Result<(), DaemonError> {
match self.send_request(Request::Shutdown)? {
Response::Shutdown { .. } => {
self.available.store(false, Ordering::SeqCst);
*self.connection.lock() = None;
Ok(())
}
other => Err(unexpected_response(other)),
}
}
pub fn submit_embedding_job(&self, config: EmbeddingJobConfig) -> Result<String, DaemonError> {
let response = self.send_request(Request::SubmitEmbeddingJob {
db_path: config.db_path,
index_path: config.index_path,
two_tier: config.two_tier,
fast_model: config.fast_model,
quality_model: config.quality_model,
})?;
match response {
Response::JobSubmitted { job_id, .. } => Ok(job_id),
other => Err(unexpected_response(other)),
}
}
pub fn embedding_job_status(&self, db_path: &str) -> Result<EmbeddingJobInfo, DaemonError> {
let response = self.send_request(Request::EmbeddingJobStatus {
db_path: db_path.to_string(),
})?;
match response {
Response::JobStatus(info) => Ok(info),
other => Err(unexpected_response(other)),
}
}
pub fn cancel_embedding_job(
&self,
db_path: &str,
model_id: Option<&str>,
) -> Result<usize, DaemonError> {
let response = self.send_request(Request::CancelEmbeddingJob {
db_path: db_path.to_string(),
model_id: model_id.map(|s| s.to_string()),
})?;
match response {
Response::JobCancelled { cancelled, .. } => Ok(cancelled),
other => Err(unexpected_response(other)),
}
}
}
impl DaemonClient for UdsDaemonClient {
fn id(&self) -> &str {
"uds-daemon"
}
fn is_available(&self) -> bool {
if !self.available.load(Ordering::SeqCst) {
return false;
}
if let Some(last) = *self.last_health_check.lock()
&& last.elapsed() < Duration::from_secs(5)
{
return true;
}
match self.health() {
Ok(status) => status.ready,
Err(_) => {
self.available.store(false, Ordering::SeqCst);
false
}
}
}
fn embed(&self, text: &str, request_id: &str) -> Result<Vec<f32>, DaemonError> {
debug!(
request_id = request_id,
text_len = text.len(),
"Daemon embed request"
);
let response = self.send_request(Request::Embed {
texts: vec![text.to_string()],
model: "default".to_string(),
dims: None,
})?;
match response {
Response::Embed(embed) => {
if embed.embeddings.is_empty() {
return Err(DaemonError::Failed("no embeddings returned".to_string()));
}
debug!(
request_id = request_id,
elapsed_ms = embed.elapsed_ms,
dimension = embed.embeddings[0].len(),
"Daemon embed completed"
);
embed
.embeddings
.into_iter()
.next()
.ok_or_else(|| DaemonError::Failed("embedding unexpectedly empty".to_string()))
}
other => Err(unexpected_response(other)),
}
}
fn embed_batch(&self, texts: &[&str], request_id: &str) -> Result<Vec<Vec<f32>>, DaemonError> {
debug!(
request_id = request_id,
batch_size = texts.len(),
"Daemon embed batch request"
);
let response = self.send_request(Request::Embed {
texts: texts.iter().map(|s| s.to_string()).collect(),
model: "default".to_string(),
dims: None,
})?;
match response {
Response::Embed(embed) => {
if embed.embeddings.len() != texts.len() {
return Err(DaemonError::Failed(format!(
"embedding count mismatch: expected {}, got {}",
texts.len(),
embed.embeddings.len()
)));
}
debug!(
request_id = request_id,
elapsed_ms = embed.elapsed_ms,
batch_size = texts.len(),
"Daemon embed batch completed"
);
Ok(embed.embeddings)
}
other => Err(unexpected_response(other)),
}
}
fn rerank(
&self,
query: &str,
documents: &[&str],
request_id: &str,
) -> Result<Vec<f32>, DaemonError> {
debug!(
request_id = request_id,
query_len = query.len(),
doc_count = documents.len(),
"Daemon rerank request"
);
let response = self.send_request(Request::Rerank {
query: query.to_string(),
documents: documents.iter().map(|s| s.to_string()).collect(),
model: "default".to_string(),
})?;
match response {
Response::Rerank(rerank) => {
if rerank.scores.len() != documents.len() {
return Err(DaemonError::Failed(format!(
"score count mismatch: expected {}, got {}",
documents.len(),
rerank.scores.len()
)));
}
debug!(
request_id = request_id,
elapsed_ms = rerank.elapsed_ms,
doc_count = documents.len(),
"Daemon rerank completed"
);
Ok(rerank.scores)
}
other => Err(unexpected_response(other)),
}
}
}
fn remove_stale_daemon_socket(socket_path: &std::path::Path) -> Result<(), DaemonError> {
use std::os::unix::fs::FileTypeExt;
match std::fs::symlink_metadata(socket_path) {
Ok(metadata) if metadata.file_type().is_socket() || metadata.file_type().is_symlink() => {
std::fs::remove_file(socket_path).map_err(|error| {
DaemonError::Unavailable(format!(
"failed to remove stale daemon socket {}: {}",
socket_path.display(),
error
))
})
}
Ok(metadata) => Err(DaemonError::Unavailable(format!(
"refusing to remove non-socket daemon path {} (file type: {:?})",
socket_path.display(),
metadata.file_type()
))),
Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(()),
Err(error) => Err(DaemonError::Unavailable(format!(
"failed to inspect daemon socket path {}: {}",
socket_path.display(),
error
))),
}
}
pub fn connect_or_spawn() -> Result<Arc<UdsDaemonClient>, DaemonError> {
let client = UdsDaemonClient::with_defaults();
client.connect()?;
Ok(Arc::new(client))
}
pub fn try_connect() -> Option<Arc<UdsDaemonClient>> {
let mut config = DaemonClientConfig::from_env();
config.auto_spawn = false;
let client = UdsDaemonClient::new(config);
match client.connect() {
Ok(()) => Some(Arc::new(client)),
Err(_) => None,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_defaults() {
let config = DaemonClientConfig::default();
assert!(config.auto_spawn);
assert_eq!(config.connect_timeout, Duration::from_secs(2));
assert_eq!(config.request_timeout, Duration::from_secs(30));
}
#[test]
fn test_default_socket_path() {
let config = DaemonClientConfig::default();
let path_str = config.socket_path.to_string_lossy();
assert!(path_str.starts_with("/tmp/semantic-daemon-"));
assert!(path_str.ends_with(".sock"));
}
#[test]
fn test_client_not_available_initially() {
let config = DaemonClientConfig {
auto_spawn: false,
socket_path: PathBuf::from("/tmp/nonexistent-test-socket.sock"),
..Default::default()
};
let client = UdsDaemonClient::new(config);
assert!(!client.is_available());
}
#[test]
fn test_request_counter_increments() {
let client = UdsDaemonClient::with_defaults();
let first = client.request_counter.fetch_add(1, Ordering::Relaxed);
let second = client.request_counter.fetch_add(1, Ordering::Relaxed);
assert_eq!(second, first + 1);
}
#[test]
fn connection_not_established_error_text_is_stable() {
assert_eq!(
connection_not_established().to_string(),
"daemon unavailable: connection not established"
);
}
#[test]
fn unexpected_response_error_text_is_stable() {
assert_eq!(
unexpected_response(Response::Shutdown {
message: "bye".to_string()
})
.to_string(),
"daemon failed: unexpected response: Shutdown { message: \"bye\" }"
);
}
#[test]
fn test_spawn_guard_lock_path_is_distinct_from_run_lock() {
let socket = PathBuf::from("/tmp/cass-semantic.sock");
assert_ne!(
crate::daemon::daemon_spawn_guard_lock_path(&socket),
crate::daemon::daemon_run_lock_path(&socket)
);
assert_eq!(
crate::daemon::daemon_spawn_guard_lock_path(&socket),
PathBuf::from("/tmp/cass-semantic.spawn-guard.lock")
);
}
#[test]
fn stale_socket_cleanup_refuses_to_remove_regular_file() {
let dir = tempfile::tempdir().expect("tempdir");
let socket_path = dir.path().join("cass-daemon.sock");
std::fs::write(&socket_path, b"not a socket").expect("write regular file");
let err = remove_stale_daemon_socket(&socket_path)
.expect_err("regular files must not be removed as stale sockets");
assert!(
socket_path.exists(),
"regular file at daemon socket path must be preserved"
);
let message = err.to_string();
assert!(
message.contains("refusing to remove non-socket daemon path"),
"error should explain the protected path type; got {message:?}"
);
}
#[test]
fn stale_socket_cleanup_removes_public_socket_symlink() {
let dir = tempfile::tempdir().expect("tempdir");
let socket_path = dir.path().join("cass-daemon.sock");
let stale_private_socket = dir.path().join(".cass-daemon.sock.runtime/daemon.sock");
std::os::unix::fs::symlink(&stale_private_socket, &socket_path)
.expect("create stale daemon public symlink");
remove_stale_daemon_socket(&socket_path).expect("stale public symlink is removable");
assert!(
!socket_path.exists(),
"stale daemon public symlink should be removed before auto-spawn"
);
}
}