1use std::io::{Read, Write};
8use std::os::unix::net::UnixStream;
9use std::path::PathBuf;
10use std::process::{Child, Command, Stdio};
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::time::{Duration, Instant};
14
15use fs2::FileExt;
16use parking_lot::Mutex;
17use tracing::{debug, info, warn};
18
19use super::daemon_spawn_guard_lock_path;
20use super::protocol::{
21 EmbeddingJobInfo, ErrorCode, FramedMessage, HealthStatus, PROTOCOL_VERSION, Request, Response,
22 decode_message, default_socket_path, encode_message,
23};
24use super::worker::EmbeddingJobConfig;
25use crate::search::daemon_client::{DaemonClient, DaemonError};
26
27fn connection_not_established() -> DaemonError {
28 DaemonError::Unavailable("connection not established".to_string())
29}
30
31fn unexpected_response(response: Response) -> DaemonError {
32 DaemonError::Failed(format!("unexpected response: {response:?}"))
33}
34
/// Connection settings for [`UdsDaemonClient`].
#[derive(Debug, Clone)]
pub struct DaemonClientConfig {
    /// Path of the daemon's Unix domain socket.
    pub socket_path: PathBuf,
    /// How long to wait for a freshly spawned daemon to accept connections.
    /// The spawn path waits for `max(connect_timeout, 5s)`, so values below 5 s have no effect.
    pub connect_timeout: Duration,
    /// Read and write timeout applied to the socket for every request.
    pub request_timeout: Duration,
    /// Whether `connect` should launch a daemon process when none is listening.
    pub auto_spawn: bool,
    /// Binary to launch when auto-spawning; `None` falls back to the current executable.
    pub daemon_binary: Option<PathBuf>,
}
49
50impl Default for DaemonClientConfig {
51 fn default() -> Self {
52 Self {
53 socket_path: default_socket_path(),
54 connect_timeout: Duration::from_secs(2),
55 request_timeout: Duration::from_secs(30),
56 auto_spawn: true,
57 daemon_binary: None, }
59 }
60}
61
62impl DaemonClientConfig {
63 pub fn from_env() -> Self {
65 let mut cfg = Self::default();
66
67 if let Ok(path) = dotenvy::var("CASS_DAEMON_SOCKET") {
68 cfg.socket_path = PathBuf::from(path);
69 }
70
71 if let Ok(val) = dotenvy::var("CASS_DAEMON_CONNECT_TIMEOUT_MS")
72 && let Ok(ms) = val.parse::<u64>()
73 {
74 cfg.connect_timeout = Duration::from_millis(ms);
75 }
76
77 if let Ok(val) = dotenvy::var("CASS_DAEMON_REQUEST_TIMEOUT_MS")
78 && let Ok(ms) = val.parse::<u64>()
79 {
80 cfg.request_timeout = Duration::from_millis(ms);
81 }
82
83 if let Ok(val) = dotenvy::var("CASS_DAEMON_AUTO_SPAWN") {
84 cfg.auto_spawn = val.eq_ignore_ascii_case("true") || val == "1";
85 }
86
87 if let Ok(path) = dotenvy::var("CASS_DAEMON_BINARY") {
88 cfg.daemon_binary = Some(PathBuf::from(path));
89 }
90
91 cfg
92 }
93}
94
/// Blocking client for the daemon's embed, rerank and embedding-job requests over a Unix
/// domain socket.
///
/// A single connection is cached behind a mutex. Each request holds that mutex across its
/// write and read, so concurrent requests through one client are serialized.
pub struct UdsDaemonClient {
    config: DaemonClientConfig,
    /// Cached connection; `None` until connected, and reset to `None` after I/O failures.
    connection: Mutex<Option<UnixStream>>,
    /// Set on successful connect; cleared on send/receive failures, failed health probes and shutdown.
    available: AtomicBool,
    /// Source of the numeric suffix in `cass-<n>` request ids.
    request_counter: AtomicU64,
    /// When the most recent `Health` response was received; used to rate-limit health probes.
    last_health_check: Mutex<Option<Instant>>,
}
103
104impl UdsDaemonClient {
105 pub fn new(config: DaemonClientConfig) -> Self {
107 Self {
108 config,
109 connection: Mutex::new(None),
110 available: AtomicBool::new(false),
111 request_counter: AtomicU64::new(0),
112 last_health_check: Mutex::new(None),
113 }
114 }
115
116 pub fn with_defaults() -> Self {
118 Self::new(DaemonClientConfig::from_env())
119 }
120
121 pub fn connect(&self) -> Result<(), DaemonError> {
123 if let Ok(stream) = self.try_connect() {
125 *self.connection.lock() = Some(stream);
126 self.available.store(true, Ordering::SeqCst);
127 debug!(socket = %self.config.socket_path.display(), "Connected to existing daemon");
128 return Ok(());
129 }
130
131 if self.config.auto_spawn {
133 info!("Daemon not running, attempting to spawn");
134 self.spawn_daemon()?;
135
136 for attempt in 0..10 {
138 std::thread::sleep(Duration::from_millis(100 * (attempt + 1)));
139 if let Ok(stream) = self.try_connect() {
140 *self.connection.lock() = Some(stream);
141 self.available.store(true, Ordering::SeqCst);
142 info!(
143 socket = %self.config.socket_path.display(),
144 attempts = attempt + 1,
145 "Connected to newly spawned daemon"
146 );
147 return Ok(());
148 }
149 }
150
151 return Err(DaemonError::Unavailable(
152 "daemon failed to start within timeout".to_string(),
153 ));
154 }
155
156 Err(DaemonError::Unavailable(format!(
157 "daemon not running at {}",
158 self.config.socket_path.display()
159 )))
160 }
161
162 fn try_connect(&self) -> std::io::Result<UnixStream> {
164 let stream = UnixStream::connect(&self.config.socket_path)?;
165 stream.set_read_timeout(Some(self.config.request_timeout))?;
166 stream.set_write_timeout(Some(self.config.request_timeout))?;
167 Ok(stream)
168 }
169
170 fn spawn_daemon(&self) -> Result<(), DaemonError> {
172 let binary = self
173 .config
174 .daemon_binary
175 .clone()
176 .or_else(|| std::env::current_exe().ok())
177 .ok_or_else(|| {
178 DaemonError::Unavailable("cannot determine daemon binary path".to_string())
179 })?;
180
181 let lock_path = daemon_spawn_guard_lock_path(&self.config.socket_path);
183
184 let lock_file = match std::fs::OpenOptions::new()
185 .read(true)
186 .write(true)
187 .create_new(true)
188 .open(&lock_path)
189 {
190 Ok(file) => file,
191 Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
192 if std::fs::symlink_metadata(&lock_path)
194 .map(|m| m.file_type().is_symlink())
195 .unwrap_or(false)
196 {
197 return Err(DaemonError::Unavailable(
198 "refusing to open a symlink spawn lock".to_string(),
199 ));
200 }
201 std::fs::OpenOptions::new()
202 .read(true)
203 .write(true)
204 .open(&lock_path)
205 .map_err(|e| {
206 DaemonError::Unavailable(format!("failed to open spawn lock: {}", e))
207 })?
208 }
209 Err(e) => {
210 return Err(DaemonError::Unavailable(format!(
211 "failed to create spawn lock: {}",
212 e
213 )));
214 }
215 };
216
217 lock_file.lock_exclusive().map_err(|e| {
220 DaemonError::Unavailable(format!("failed to acquire spawn lock: {}", e))
221 })?;
222
223 if UnixStream::connect(&self.config.socket_path).is_ok() {
225 debug!("Daemon already running, skipping spawn");
226 return Ok(());
227 }
228
229 remove_stale_daemon_socket(&self.config.socket_path)?;
230
231 let result = Command::new(&binary)
233 .arg("daemon")
234 .arg("--socket")
235 .arg(&self.config.socket_path)
236 .stdin(Stdio::null())
237 .stdout(Stdio::null())
238 .stderr(Stdio::null())
239 .spawn();
240
241 match result {
242 Ok(mut child) => {
243 info!(
244 pid = child.id(),
245 binary = %binary.display(),
246 socket = %self.config.socket_path.display(),
247 "Spawned daemon process"
248 );
249 self.wait_for_spawned_daemon_ready(&mut child)?;
250 std::thread::spawn(move || {
256 let _ = child.wait();
257 });
258 Ok(())
259 }
260 Err(e) => Err(DaemonError::Unavailable(format!(
261 "failed to spawn daemon: {}",
262 e
263 ))),
264 }
265 }
266
267 fn wait_for_spawned_daemon_ready(&self, child: &mut Child) -> Result<(), DaemonError> {
268 let ready_timeout = self.config.connect_timeout.max(Duration::from_secs(5));
269 let started = Instant::now();
270 while started.elapsed() < ready_timeout {
271 if UnixStream::connect(&self.config.socket_path).is_ok() {
272 return Ok(());
273 }
274 match child.try_wait() {
275 Ok(Some(status)) => {
276 return Err(DaemonError::Unavailable(format!(
277 "spawned daemon exited before becoming ready: {}",
278 status
279 )));
280 }
281 Ok(None) => {}
282 Err(error) => {
283 warn!(
284 error = %error,
285 socket = %self.config.socket_path.display(),
286 "failed to poll spawned daemon status while waiting for readiness"
287 );
288 break;
289 }
290 }
291 std::thread::sleep(Duration::from_millis(50));
292 }
293 Ok(())
294 }
295
296 fn get_connection_locked(
298 &self,
299 ) -> Result<parking_lot::MutexGuard<'_, Option<UnixStream>>, DaemonError> {
300 let conn = self.connection.lock();
302 let is_valid = conn.as_ref().is_some_and(|s| s.peer_addr().is_ok());
303
304 if is_valid {
305 return Ok(conn);
306 }
307
308 drop(conn);
310
311 self.available.store(false, Ordering::SeqCst);
313 self.connect()?;
314
315 let conn = self.connection.lock();
316 if conn.is_some() {
317 Ok(conn)
318 } else {
319 Err(connection_not_established())
320 }
321 }
322
323 fn send_request(&self, request: Request) -> Result<Response, DaemonError> {
325 let request_id = format!(
326 "cass-{}",
327 self.request_counter.fetch_add(1, Ordering::Relaxed)
328 );
329 let msg = FramedMessage::new(&request_id, request);
330
331 let encoded = encode_message(&msg)
332 .map_err(|e| DaemonError::Failed(format!("failed to encode request: {}", e)))?;
333
334 let mut stream_guard = self.get_connection_locked()?;
335 let stream = stream_guard
336 .as_mut()
337 .ok_or_else(connection_not_established)?;
338
339 if let Err(e) = stream.write_all(&encoded) {
341 *stream_guard = None;
342 self.available.store(false, Ordering::SeqCst);
343 return Err(DaemonError::Unavailable(format!(
344 "failed to send request: {}",
345 e
346 )));
347 }
348
349 let mut len_buf = [0u8; 4];
351 if let Err(e) = stream.read_exact(&mut len_buf) {
352 *stream_guard = None;
353 self.available.store(false, Ordering::SeqCst);
354 if e.kind() == std::io::ErrorKind::TimedOut {
355 return Err(DaemonError::Timeout("response timeout".to_string()));
356 } else {
357 return Err(DaemonError::Unavailable(format!(
358 "failed to read response length: {}",
359 e
360 )));
361 }
362 }
363
364 let len = u32::from_be_bytes(len_buf) as usize;
365 const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024;
367 if len > MAX_RESPONSE_SIZE {
368 *stream_guard = None;
369 warn!(
370 response_size = len,
371 max_size = MAX_RESPONSE_SIZE,
372 "Rejecting oversized daemon response"
373 );
374 return Err(DaemonError::Failed(format!(
375 "response too large: {} bytes (max {})",
376 len, MAX_RESPONSE_SIZE
377 )));
378 }
379
380 let mut payload = vec![0u8; len];
382 if let Err(e) = stream.read_exact(&mut payload) {
383 *stream_guard = None;
384 self.available.store(false, Ordering::SeqCst);
385 if e.kind() == std::io::ErrorKind::TimedOut {
386 return Err(DaemonError::Timeout("response timeout".to_string()));
387 } else {
388 return Err(DaemonError::Unavailable(format!(
389 "failed to read response: {}",
390 e
391 )));
392 }
393 }
394
395 drop(stream_guard);
397
398 let response: FramedMessage<Response> = decode_message(&payload)
400 .map_err(|e| DaemonError::Failed(format!("failed to decode response: {}", e)))?;
401
402 if response.version != PROTOCOL_VERSION {
404 return Err(DaemonError::Failed(format!(
405 "protocol version mismatch: expected {}, got {}",
406 PROTOCOL_VERSION, response.version
407 )));
408 }
409
410 match response.payload {
412 Response::Error(err) => {
413 let daemon_err = match err.code {
414 ErrorCode::Overloaded => DaemonError::Overloaded {
415 retry_after: err.retry_after_ms.map(Duration::from_millis),
416 message: err.message,
417 },
418 ErrorCode::Timeout => DaemonError::Timeout(err.message),
419 ErrorCode::InvalidInput => DaemonError::InvalidInput(err.message),
420 _ => DaemonError::Failed(err.message),
421 };
422 Err(daemon_err)
423 }
424 other => Ok(other),
425 }
426 }
427
428 pub fn health(&self) -> Result<HealthStatus, DaemonError> {
430 match self.send_request(Request::Health)? {
431 Response::Health(status) => {
432 *self.last_health_check.lock() = Some(Instant::now());
433 Ok(status)
434 }
435 other => Err(unexpected_response(other)),
436 }
437 }
438
439 pub fn shutdown(&self) -> Result<(), DaemonError> {
441 match self.send_request(Request::Shutdown)? {
442 Response::Shutdown { .. } => {
443 self.available.store(false, Ordering::SeqCst);
444 *self.connection.lock() = None;
445 Ok(())
446 }
447 other => Err(unexpected_response(other)),
448 }
449 }
450
451 pub fn submit_embedding_job(&self, config: EmbeddingJobConfig) -> Result<String, DaemonError> {
453 let response = self.send_request(Request::SubmitEmbeddingJob {
454 db_path: config.db_path,
455 index_path: config.index_path,
456 two_tier: config.two_tier,
457 fast_model: config.fast_model,
458 quality_model: config.quality_model,
459 })?;
460 match response {
461 Response::JobSubmitted { job_id, .. } => Ok(job_id),
462 other => Err(unexpected_response(other)),
463 }
464 }
465
466 pub fn embedding_job_status(&self, db_path: &str) -> Result<EmbeddingJobInfo, DaemonError> {
468 let response = self.send_request(Request::EmbeddingJobStatus {
469 db_path: db_path.to_string(),
470 })?;
471 match response {
472 Response::JobStatus(info) => Ok(info),
473 other => Err(unexpected_response(other)),
474 }
475 }
476
477 pub fn cancel_embedding_job(
479 &self,
480 db_path: &str,
481 model_id: Option<&str>,
482 ) -> Result<usize, DaemonError> {
483 let response = self.send_request(Request::CancelEmbeddingJob {
484 db_path: db_path.to_string(),
485 model_id: model_id.map(|s| s.to_string()),
486 })?;
487 match response {
488 Response::JobCancelled { cancelled, .. } => Ok(cancelled),
489 other => Err(unexpected_response(other)),
490 }
491 }
492}
493
494impl DaemonClient for UdsDaemonClient {
495 fn id(&self) -> &str {
496 "uds-daemon"
497 }
498
499 fn is_available(&self) -> bool {
500 if !self.available.load(Ordering::SeqCst) {
502 return false;
503 }
504
505 if let Some(last) = *self.last_health_check.lock()
507 && last.elapsed() < Duration::from_secs(5)
508 {
509 return true;
510 }
511
512 match self.health() {
514 Ok(status) => status.ready,
515 Err(_) => {
516 self.available.store(false, Ordering::SeqCst);
517 false
518 }
519 }
520 }
521
522 fn embed(&self, text: &str, request_id: &str) -> Result<Vec<f32>, DaemonError> {
523 debug!(
524 request_id = request_id,
525 text_len = text.len(),
526 "Daemon embed request"
527 );
528
529 let response = self.send_request(Request::Embed {
530 texts: vec![text.to_string()],
531 model: "default".to_string(),
532 dims: None,
533 })?;
534
535 match response {
536 Response::Embed(embed) => {
537 if embed.embeddings.is_empty() {
538 return Err(DaemonError::Failed("no embeddings returned".to_string()));
539 }
540 debug!(
541 request_id = request_id,
542 elapsed_ms = embed.elapsed_ms,
543 dimension = embed.embeddings[0].len(),
544 "Daemon embed completed"
545 );
546 embed
548 .embeddings
549 .into_iter()
550 .next()
551 .ok_or_else(|| DaemonError::Failed("embedding unexpectedly empty".to_string()))
552 }
553 other => Err(unexpected_response(other)),
554 }
555 }
556
557 fn embed_batch(&self, texts: &[&str], request_id: &str) -> Result<Vec<Vec<f32>>, DaemonError> {
558 debug!(
559 request_id = request_id,
560 batch_size = texts.len(),
561 "Daemon embed batch request"
562 );
563
564 let response = self.send_request(Request::Embed {
565 texts: texts.iter().map(|s| s.to_string()).collect(),
566 model: "default".to_string(),
567 dims: None,
568 })?;
569
570 match response {
571 Response::Embed(embed) => {
572 if embed.embeddings.len() != texts.len() {
573 return Err(DaemonError::Failed(format!(
574 "embedding count mismatch: expected {}, got {}",
575 texts.len(),
576 embed.embeddings.len()
577 )));
578 }
579 debug!(
580 request_id = request_id,
581 elapsed_ms = embed.elapsed_ms,
582 batch_size = texts.len(),
583 "Daemon embed batch completed"
584 );
585 Ok(embed.embeddings)
586 }
587 other => Err(unexpected_response(other)),
588 }
589 }
590
591 fn rerank(
592 &self,
593 query: &str,
594 documents: &[&str],
595 request_id: &str,
596 ) -> Result<Vec<f32>, DaemonError> {
597 debug!(
598 request_id = request_id,
599 query_len = query.len(),
600 doc_count = documents.len(),
601 "Daemon rerank request"
602 );
603
604 let response = self.send_request(Request::Rerank {
605 query: query.to_string(),
606 documents: documents.iter().map(|s| s.to_string()).collect(),
607 model: "default".to_string(),
608 })?;
609
610 match response {
611 Response::Rerank(rerank) => {
612 if rerank.scores.len() != documents.len() {
613 return Err(DaemonError::Failed(format!(
614 "score count mismatch: expected {}, got {}",
615 documents.len(),
616 rerank.scores.len()
617 )));
618 }
619 debug!(
620 request_id = request_id,
621 elapsed_ms = rerank.elapsed_ms,
622 doc_count = documents.len(),
623 "Daemon rerank completed"
624 );
625 Ok(rerank.scores)
626 }
627 other => Err(unexpected_response(other)),
628 }
629 }
630}
631
632fn remove_stale_daemon_socket(socket_path: &std::path::Path) -> Result<(), DaemonError> {
633 use std::os::unix::fs::FileTypeExt;
634
635 match std::fs::symlink_metadata(socket_path) {
636 Ok(metadata) if metadata.file_type().is_socket() || metadata.file_type().is_symlink() => {
637 std::fs::remove_file(socket_path).map_err(|error| {
638 DaemonError::Unavailable(format!(
639 "failed to remove stale daemon socket {}: {}",
640 socket_path.display(),
641 error
642 ))
643 })
644 }
645 Ok(metadata) => Err(DaemonError::Unavailable(format!(
646 "refusing to remove non-socket daemon path {} (file type: {:?})",
647 socket_path.display(),
648 metadata.file_type()
649 ))),
650 Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(()),
651 Err(error) => Err(DaemonError::Unavailable(format!(
652 "failed to inspect daemon socket path {}: {}",
653 socket_path.display(),
654 error
655 ))),
656 }
657}
658
659pub fn connect_or_spawn() -> Result<Arc<UdsDaemonClient>, DaemonError> {
661 let client = UdsDaemonClient::with_defaults();
662 client.connect()?;
663 Ok(Arc::new(client))
664}
665
666pub fn try_connect() -> Option<Arc<UdsDaemonClient>> {
668 let mut config = DaemonClientConfig::from_env();
669 config.auto_spawn = false;
670 let client = UdsDaemonClient::new(config);
671 match client.connect() {
672 Ok(()) => Some(Arc::new(client)),
673 Err(_) => None,
674 }
675}
676
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_defaults() {
        let DaemonClientConfig {
            auto_spawn,
            connect_timeout,
            request_timeout,
            ..
        } = DaemonClientConfig::default();
        assert!(auto_spawn);
        assert_eq!(connect_timeout, Duration::from_secs(2));
        assert_eq!(request_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_default_socket_path() {
        let socket_path = DaemonClientConfig::default().socket_path;
        let rendered = socket_path.to_string_lossy();
        assert!(rendered.starts_with("/tmp/semantic-daemon-"));
        assert!(rendered.ends_with(".sock"));
    }

    #[test]
    fn test_client_not_available_initially() {
        let client = UdsDaemonClient::new(DaemonClientConfig {
            socket_path: PathBuf::from("/tmp/nonexistent-test-socket.sock"),
            auto_spawn: false,
            ..Default::default()
        });
        assert!(!client.is_available());
    }

    #[test]
    fn test_request_counter_increments() {
        let client = UdsDaemonClient::with_defaults();
        let counter = &client.request_counter;
        let before = counter.fetch_add(1, Ordering::Relaxed);
        let after = counter.fetch_add(1, Ordering::Relaxed);
        assert_eq!(after, before + 1);
    }

    #[test]
    fn connection_not_established_error_text_is_stable() {
        let rendered = connection_not_established().to_string();
        assert_eq!(rendered, "daemon unavailable: connection not established");
    }

    #[test]
    fn unexpected_response_error_text_is_stable() {
        let response = Response::Shutdown {
            message: "bye".to_string(),
        };
        let rendered = unexpected_response(response).to_string();
        assert_eq!(
            rendered,
            "daemon failed: unexpected response: Shutdown { message: \"bye\" }"
        );
    }

    #[test]
    fn test_spawn_guard_lock_path_is_distinct_from_run_lock() {
        let socket = PathBuf::from("/tmp/cass-semantic.sock");
        let spawn_guard = crate::daemon::daemon_spawn_guard_lock_path(&socket);
        let run_lock = crate::daemon::daemon_run_lock_path(&socket);
        assert_ne!(spawn_guard, run_lock);
        assert_eq!(
            spawn_guard,
            PathBuf::from("/tmp/cass-semantic.spawn-guard.lock")
        );
    }

    #[test]
    fn stale_socket_cleanup_refuses_to_remove_regular_file() {
        let dir = tempfile::tempdir().expect("tempdir");
        let socket_path = dir.path().join("cass-daemon.sock");
        std::fs::write(&socket_path, b"not a socket").expect("write regular file");

        let message = remove_stale_daemon_socket(&socket_path)
            .expect_err("regular files must not be removed as stale sockets")
            .to_string();

        assert!(
            socket_path.exists(),
            "regular file at daemon socket path must be preserved"
        );
        assert!(
            message.contains("refusing to remove non-socket daemon path"),
            "error should explain the protected path type; got {message:?}"
        );
    }

    #[test]
    fn stale_socket_cleanup_removes_public_socket_symlink() {
        let dir = tempfile::tempdir().expect("tempdir");
        let socket_path = dir.path().join("cass-daemon.sock");
        let dangling_target = dir.path().join(".cass-daemon.sock.runtime/daemon.sock");
        std::os::unix::fs::symlink(&dangling_target, &socket_path)
            .expect("create stale daemon public symlink");

        remove_stale_daemon_socket(&socket_path).expect("stale public symlink is removable");

        assert!(
            !socket_path.exists(),
            "stale daemon public symlink should be removed before auto-spawn"
        );
    }
}