Skip to main content

coding_agent_search/daemon/
protocol.rs

//! Wire-compatible protocol for the semantic model daemon.
//!
//! This protocol is designed to be wire-compatible with xf's daemon implementation,
//! allowing both tools to share a daemon when both are installed.
//!
//! Messages are serialized with MessagePack and exchanged over Unix domain sockets.
//! Each frame is a 4-byte big-endian length prefix followed by a MessagePack-encoded
//! [`FramedMessage`] (see [`encode_message`] and [`decode_message`]).
7
8use serde::{Deserialize, Serialize};
9use std::path::PathBuf;
10
/// Protocol version for compatibility checks.
///
/// Stamped into every [`FramedMessage`] by [`FramedMessage::new`]. Both cass and xf
/// must use the same version to share a daemon, so bump this whenever the wire format
/// of [`Request`], [`Response`], or their payload types changes incompatibly.
pub const PROTOCOL_VERSION: u32 = 1;
14
15/// Default socket path (shared between cass and xf).
16pub fn default_socket_path() -> PathBuf {
17    let user = dotenvy::var("USER").unwrap_or_else(|_| "unknown".into());
18    let safe_user = sanitize_socket_user(&user);
19    PathBuf::from(format!("/tmp/semantic-daemon-{safe_user}.sock"))
20}
21
22fn sanitize_socket_user(user: &str) -> String {
23    let safe_user: String = user
24        .chars()
25        .filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_')
26        .take(64)
27        .collect();
28
29    if safe_user.is_empty() {
30        "unknown".to_string()
31    } else {
32        safe_user
33    }
34}
35
/// Request types for the daemon protocol.
///
/// Sent by a client as the `payload` of a [`FramedMessage`]. The reply pairings noted
/// on each variant follow the variant names and docs in this file; the dispatch itself
/// lives in the daemon, outside this file.
///
/// Wire compatibility: this enum is shared with xf's daemon. [`encode_message`] uses
/// `rmp_serde::to_vec`, which encodes struct-like fields positionally (as arrays), so
/// field order is part of the wire format. Variant identity (name or index, depending
/// on rmp_serde's enum encoding) is also on the wire. Do not rename or reorder
/// variants or fields without coordinating a [`PROTOCOL_VERSION`] bump.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Request {
    /// Health check - returns daemon status (replied to with [`Response::Health`]).
    Health,

    /// Generate embeddings for texts (replied to with [`Response::Embed`]).
    Embed {
        /// Texts to embed.
        texts: Vec<String>,
        /// Identifier of the embedding model to use.
        model: String,
        /// Optional requested output dimension. NOTE(review): the meaning of `None`
        /// (presumably "model's native dimension") is decided by the daemon — confirm.
        dims: Option<usize>,
    },

    /// Rerank documents against a query (replied to with [`Response::Rerank`], whose
    /// scores follow the order of `documents`).
    Rerank {
        /// Query the documents are scored against.
        query: String,
        /// Candidate documents to score.
        documents: Vec<String>,
        /// Identifier of the reranker model to use.
        model: String,
    },

    /// Get daemon status and loaded models (replied to with [`Response::Status`]).
    Status,

    /// Submit a background embedding job (replied to with [`Response::JobSubmitted`]).
    SubmitEmbeddingJob {
        /// Path to the database the job works on. Paths travel as `String`, not
        /// `PathBuf`, on the wire.
        db_path: String,
        /// Path of the index the job writes to (assumed from the name — confirm).
        index_path: String,
        /// Whether to run in two-tier mode. Presumably it uses `fast_model` and
        /// `quality_model` for the two tiers — confirm in the job runner.
        two_tier: bool,
        /// Optional override for the fast-tier model; `None` leaves the choice to
        /// the daemon.
        fast_model: Option<String>,
        /// Optional override for the quality-tier model; `None` leaves the choice to
        /// the daemon.
        quality_model: Option<String>,
    },

    /// Query embedding job status for `db_path` (replied to with
    /// [`Response::JobStatus`]).
    EmbeddingJobStatus {
        /// Database whose jobs are being queried.
        db_path: String,
    },

    /// Cancel embedding jobs (replied to with [`Response::JobCancelled`], which reports
    /// how many were cancelled).
    CancelEmbeddingJob {
        /// Database whose jobs should be cancelled.
        db_path: String,
        /// Optional model filter. NOTE(review): presumably `None` cancels every job for
        /// `db_path` — confirm in the daemon.
        model_id: Option<String>,
    },

    /// Request graceful shutdown (acknowledged with [`Response::Shutdown`]).
    Shutdown,
}
80
/// Response types from the daemon.
///
/// Sent as the `payload` of a [`FramedMessage`]. The frame's `request_id` correlates a
/// response with its request. The pairings noted below follow the variant names and
/// docs. NOTE(review): failures are presumably reported as [`Response::Error`] for any
/// request — confirm in the daemon's dispatcher.
///
/// Wire compatibility: this enum is shared with xf's daemon. [`encode_message`] uses
/// `rmp_serde::to_vec`, which encodes struct-like fields positionally, so field order,
/// together with variant identity, is part of the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Response {
    /// Health check response (reply to [`Request::Health`]).
    Health(HealthStatus),

    /// Embedding response with vectors (reply to [`Request::Embed`]).
    Embed(EmbedResponse),

    /// Rerank response with scores (reply to [`Request::Rerank`]).
    Rerank(RerankResponse),

    /// Status response with daemon info (reply to [`Request::Status`]).
    Status(StatusResponse),

    /// Embedding job submitted (reply to [`Request::SubmitEmbeddingJob`]).
    JobSubmitted {
        /// Identifier of the submitted job. NOTE(review): this is a `String`, while
        /// [`EmbeddingJobDetail::job_id`] is an `i64` — confirm how the two relate.
        job_id: String,
        /// Human-readable summary for the client.
        message: String,
    },

    /// Embedding job status (reply to [`Request::EmbeddingJobStatus`]).
    JobStatus(EmbeddingJobInfo),

    /// Embedding jobs cancelled (reply to [`Request::CancelEmbeddingJob`]).
    JobCancelled {
        /// Number of jobs that were cancelled.
        cancelled: usize,
        /// Human-readable summary for the client.
        message: String,
    },

    /// Shutdown acknowledgement (reply to [`Request::Shutdown`]).
    Shutdown {
        /// Human-readable acknowledgement text.
        message: String,
    },

    /// Error response.
    Error(ErrorResponse),
}
111
/// Health status of the daemon, carried by [`Response::Health`].
///
/// Encoded positionally by `rmp_serde::to_vec` (see [`encode_message`]), so field order
/// is part of the wire format shared with xf.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthStatus {
    /// Daemon uptime in seconds.
    pub uptime_secs: u64,
    /// Protocol version reported by the daemon (compare with [`PROTOCOL_VERSION`]).
    pub version: u32,
    /// Whether models are loaded and ready.
    pub ready: bool,
    /// Current memory usage in bytes (approximate).
    pub memory_bytes: u64,
}
124
/// Response containing embeddings, carried by [`Response::Embed`].
///
/// Encoded positionally by `rmp_serde::to_vec` (see [`encode_message`]), so field order
/// is part of the wire format shared with xf.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedResponse {
    /// One vector per embedded text. NOTE(review): presumably in the same order as
    /// `Request::Embed::texts` — confirm in the daemon.
    pub embeddings: Vec<Vec<f32>>,
    /// Model ID used.
    pub model: String,
    /// Processing time in milliseconds.
    pub elapsed_ms: u64,
}
135
/// Response containing rerank scores, carried by [`Response::Rerank`].
///
/// Encoded positionally by `rmp_serde::to_vec` (see [`encode_message`]), so field order
/// is part of the wire format shared with xf.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResponse {
    /// Scores for each document (same order as input).
    pub scores: Vec<f32>,
    /// Model ID used.
    pub model: String,
    /// Processing time in milliseconds.
    pub elapsed_ms: u64,
}
146
/// Daemon status response, carried by [`Response::Status`].
///
/// Encoded positionally by `rmp_serde::to_vec` (see [`encode_message`]), so field order
/// is part of the wire format shared with xf.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatusResponse {
    /// Daemon uptime in seconds.
    pub uptime_secs: u64,
    /// Protocol version reported by the daemon (compare with [`PROTOCOL_VERSION`]).
    pub version: u32,
    /// Known embedder models; each entry's `loaded` flag says whether it is resident.
    pub embedders: Vec<ModelInfo>,
    /// Known reranker models; each entry's `loaded` flag says whether it is resident.
    pub rerankers: Vec<ModelInfo>,
    /// Current memory usage in bytes.
    pub memory_bytes: u64,
    /// Total requests served.
    pub total_requests: u64,
}
163
/// Information about a model known to the daemon, listed in [`StatusResponse`].
///
/// Encoded positionally by `rmp_serde::to_vec` (see [`encode_message`]), so field order
/// is part of the wire format shared with xf.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model ID.
    pub id: String,
    /// Model name/path.
    pub name: String,
    /// Output dimension (for embedders); `None` where not applicable.
    pub dimension: Option<usize>,
    /// Whether the model is currently loaded.
    pub loaded: bool,
    /// Approximate memory usage in bytes.
    pub memory_bytes: u64,
}
178
/// Error response from daemon, carried by [`Response::Error`].
///
/// Encoded positionally by `rmp_serde::to_vec` (see [`encode_message`]), so field order
/// is part of the wire format shared with xf.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
    /// Error code for programmatic handling.
    pub code: ErrorCode,
    /// Human-readable error message.
    pub message: String,
    /// Whether the request can be retried.
    pub retryable: bool,
    /// Suggested retry delay in milliseconds (if retryable); `None` means no hint.
    pub retry_after_ms: Option<u64>,
}
191
/// Error codes for daemon errors, carried in [`ErrorResponse::code`].
///
/// Shared with xf's daemon. NOTE(review): depending on rmp_serde's enum encoding,
/// either the variant name or its declaration index goes on the wire, so keep names
/// and order stable and append new codes at the end.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ErrorCode {
    /// Unknown or internal error.
    Internal,
    /// Model not found or not loaded.
    ModelNotFound,
    /// Invalid request parameters.
    InvalidInput,
    /// Daemon is overloaded, try again later.
    Overloaded,
    /// Request timed out.
    Timeout,
    /// Model loading failed.
    ModelLoadFailed,
    /// Protocol version mismatch.
    VersionMismatch,
}
210
/// Status information for embedding jobs, carried by [`Response::JobStatus`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingJobInfo {
    /// One entry per job matching the queried database (may be empty).
    pub jobs: Vec<EmbeddingJobDetail>,
}
216
/// Detail for a single embedding job.
///
/// Encoded positionally by `rmp_serde::to_vec` (see [`encode_message`]), so field order
/// is part of the wire format shared with xf.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingJobDetail {
    /// Job identifier. NOTE(review): `Response::JobSubmitted` reports `job_id` as a
    /// `String` while this is an `i64` — confirm how clients should match them.
    pub job_id: i64,
    /// Model the job embeds with.
    pub model_id: String,
    /// Job state as a free-form string; the set of values is defined by the daemon's
    /// job store, not in this file.
    pub status: String,
    /// Total number of documents the job covers (signed, as stored by the job store).
    pub total_docs: i64,
    /// Number of documents processed so far.
    pub completed_docs: i64,
    /// Failure description, if the job failed.
    pub error_message: Option<String>,
}
227
/// Framed message wrapper for length-prefixed protocol.
///
/// Every [`Request`] and [`Response`] travels inside this envelope. Build one with
/// [`FramedMessage::new`], which stamps [`PROTOCOL_VERSION`]. On the wire the envelope
/// is MessagePack-encoded and preceded by a 4-byte big-endian length (see
/// [`encode_message`] / [`decode_message`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FramedMessage<T> {
    /// Protocol version of the sender. NOTE(review): [`decode_message`] does not check
    /// it, so callers must compare it with [`PROTOCOL_VERSION`] themselves (e.g. to
    /// answer with [`ErrorCode::VersionMismatch`]).
    pub version: u32,
    /// Request ID for correlation: a response carries the ID of the request it answers.
    pub request_id: String,
    /// Payload.
    pub payload: T,
}
238
239impl<T> FramedMessage<T> {
240    pub fn new(request_id: impl Into<String>, payload: T) -> Self {
241        Self {
242            version: PROTOCOL_VERSION,
243            request_id: request_id.into(),
244            payload,
245        }
246    }
247}
248
249/// Encode a message to MessagePack bytes with length prefix.
250pub fn encode_message<T: Serialize>(msg: &FramedMessage<T>) -> Result<Vec<u8>, EncodeError> {
251    let payload = rmp_serde::to_vec(msg)?;
252    let len = u32::try_from(payload.len())
253        .map_err(|_| EncodeError::Message("payload exceeds maximum size of 4GB".to_string()))?;
254    let mut buf = Vec::with_capacity(4 + payload.len());
255    buf.extend_from_slice(&len.to_be_bytes());
256    buf.extend_from_slice(&payload);
257    Ok(buf)
258}
259
260/// Decode a message from MessagePack bytes (without length prefix).
261pub fn decode_message<T: for<'de> Deserialize<'de>>(
262    data: &[u8],
263) -> Result<FramedMessage<T>, DecodeError> {
264    rmp_serde::from_slice(data).map_err(DecodeError::from)
265}
266
/// Errors produced by [`encode_message`].
///
/// Both variants render as `"encode error: <detail>"`.
#[derive(Debug, thiserror::Error)]
pub enum EncodeError {
    /// A framing problem detected by this module, e.g. a serialized body too large for
    /// the 4-byte length prefix. Has no `source()`.
    #[error("encode error: {0}")]
    Message(String),
    /// MessagePack serialization failed. Created automatically by `?` via `#[from]`.
    #[error("encode error: {0}")]
    MessagePack(#[from] rmp_serde::encode::Error),
}
274
/// Errors produced when decoding daemon messages.
///
/// Both variants render as `"decode error: <detail>"`.
#[derive(Debug, thiserror::Error)]
pub enum DecodeError {
    /// A free-form decoding problem with no `source()`. Not constructed by
    /// [`decode_message`] itself. NOTE(review): presumably used by frame readers
    /// elsewhere in the crate — confirm.
    #[error("decode error: {0}")]
    Message(String),
    /// MessagePack deserialization failed (the error [`decode_message`] returns).
    /// Created automatically by `?` via `#[from]`.
    #[error("decode error: {0}")]
    MessagePack(#[from] rmp_serde::decode::Error),
}
282
#[cfg(test)]
mod tests {
    use super::{
        DecodeError, EmbedResponse, EncodeError, ErrorCode, ErrorResponse, FramedMessage,
        HealthStatus, PROTOCOL_VERSION, Request, RerankResponse, Response, decode_message,
        default_socket_path, encode_message, sanitize_socket_user,
    };
    use serde::de::DeserializeOwned;
    use std::error::Error;
    use std::fmt::Debug;

    type TestResult = Result<(), Box<dyn Error>>;

    fn test_error(message: impl Into<String>) -> Box<dyn Error> {
        std::io::Error::other(message.into()).into()
    }

    fn ensure(condition: bool, message: impl Into<String>) -> TestResult {
        if condition {
            Ok(())
        } else {
            Err(test_error(message))
        }
    }

    fn ensure_eq<T>(actual: T, expected: T, message: impl Into<String>) -> TestResult
    where
        T: Debug + PartialEq,
    {
        if actual == expected {
            Ok(())
        } else {
            Err(test_error(format!(
                "{}: expected {expected:?}, got {actual:?}",
                message.into()
            )))
        }
    }

    /// Decode a frame produced by `encode_message`.
    ///
    /// Besides decoding the MessagePack body, this checks that the 4-byte big-endian
    /// length prefix equals the body length. Every round-trip test therefore also pins
    /// the framing format, not just the payload encoding.
    fn decode_framed<T>(encoded: &[u8]) -> Result<FramedMessage<T>, Box<dyn Error>>
    where
        T: DeserializeOwned,
    {
        let payload = encoded
            .get(4..)
            .ok_or_else(|| test_error("encoded frame should include a 4-byte length prefix"))?;
        let prefix: [u8; 4] = encoded
            .get(..4)
            .and_then(|bytes| <[u8; 4]>::try_from(bytes).ok())
            .ok_or_else(|| test_error("encoded frame should include a 4-byte length prefix"))?;
        let declared_len = usize::try_from(u32::from_be_bytes(prefix))?;
        ensure_eq(declared_len, payload.len(), "big-endian length prefix")?;
        decode_message(payload).map_err(|err| test_error(err.to_string()))
    }

    #[test]
    fn test_encode_decode_health_request() -> TestResult {
        let msg = FramedMessage::new("req-1", Request::Health);
        let encoded = encode_message(&msg)?;

        let decoded: FramedMessage<Request> = decode_framed(&encoded)?;
        ensure_eq(decoded.version, PROTOCOL_VERSION, "protocol version")?;
        ensure_eq(decoded.request_id, "req-1".to_string(), "request id")?;
        ensure(matches!(decoded.payload, Request::Health), "health payload")
    }

    #[test]
    fn test_protocol_error_display_strings_are_preserved() -> TestResult {
        let encode = EncodeError::Message("bad payload".to_string());
        ensure_eq(
            encode.to_string(),
            "encode error: bad payload".to_string(),
            "encode",
        )?;
        ensure(encode.source().is_none(), "encode")?;

        let decode = DecodeError::Message("bad frame".to_string());
        ensure_eq(
            decode.to_string(),
            "decode error: bad frame".to_string(),
            "decode",
        )?;
        ensure(decode.source().is_none(), "decode")?;
        Ok(())
    }

    #[test]
    fn test_decode_rejects_malformed_body() -> TestResult {
        let empty: Result<FramedMessage<Request>, DecodeError> = decode_message(&[]);
        ensure(empty.is_err(), "empty body should fail to decode")?;

        let encoded = encode_message(&FramedMessage::new("req-trunc", Request::Status))?;
        let truncated = encoded
            .get(4..encoded.len() - 1)
            .ok_or_else(|| test_error("encoded frame should have a body"))?;
        let result: Result<FramedMessage<Request>, DecodeError> = decode_message(truncated);
        ensure(result.is_err(), "truncated body should fail to decode")
    }

    #[test]
    fn test_encode_decode_embed_request() -> TestResult {
        let msg = FramedMessage::new(
            "req-2",
            Request::Embed {
                texts: vec!["hello".to_string(), "world".to_string()],
                model: "all-MiniLM-L6-v2".to_string(),
                dims: None,
            },
        );
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Request> = decode_framed(&encoded)?;

        let Request::Embed { texts, model, dims } = decoded.payload else {
            return Err(test_error("expected embed request payload"));
        };
        ensure_eq(
            texts,
            vec!["hello".to_string(), "world".to_string()],
            "embed texts",
        )?;
        ensure_eq(model, "all-MiniLM-L6-v2".to_string(), "embed model")?;
        ensure(dims.is_none(), "embed dims should be absent")
    }

    #[test]
    fn test_encode_decode_rerank_request() -> TestResult {
        let msg = FramedMessage::new(
            "req-3",
            Request::Rerank {
                query: "test query".to_string(),
                documents: vec!["doc1".to_string(), "doc2".to_string()],
                model: "ms-marco-MiniLM-L-6-v2".to_string(),
            },
        );
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Request> = decode_framed(&encoded)?;

        let Request::Rerank {
            query,
            documents,
            model,
        } = decoded.payload
        else {
            return Err(test_error("expected rerank request payload"));
        };
        ensure_eq(query, "test query".to_string(), "rerank query")?;
        ensure_eq(
            documents,
            vec!["doc1".to_string(), "doc2".to_string()],
            "rerank documents",
        )?;
        ensure_eq(model, "ms-marco-MiniLM-L-6-v2".to_string(), "rerank model")
    }

    #[test]
    fn test_encode_decode_health_response() -> TestResult {
        let msg = FramedMessage::new(
            "resp-1",
            Response::Health(HealthStatus {
                uptime_secs: 120,
                version: PROTOCOL_VERSION,
                ready: true,
                memory_bytes: 100_000_000,
            }),
        );
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Response> = decode_framed(&encoded)?;

        let Response::Health(status) = decoded.payload else {
            return Err(test_error("expected health response payload"));
        };
        ensure_eq(status.uptime_secs, 120, "health uptime")?;
        ensure(status.ready, "health response should be ready")
    }

    #[test]
    fn test_encode_decode_error_response() -> TestResult {
        let msg = FramedMessage::new(
            "resp-err",
            Response::Error(ErrorResponse {
                code: ErrorCode::Overloaded,
                message: "too many requests".to_string(),
                retryable: true,
                retry_after_ms: Some(1000),
            }),
        );
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Response> = decode_framed(&encoded)?;

        let Response::Error(err) = decoded.payload else {
            return Err(test_error("expected error response payload"));
        };
        ensure_eq(err.code, ErrorCode::Overloaded, "error code")?;
        ensure(err.retryable, "error should be retryable")?;
        ensure_eq(err.retry_after_ms, Some(1000), "retry delay")
    }

    #[test]
    fn test_default_socket_path() -> TestResult {
        let path = default_socket_path();
        let path_str = path.to_string_lossy();
        ensure(
            path_str.starts_with("/tmp/semantic-daemon-"),
            "socket path prefix",
        )?;
        ensure(path_str.ends_with(".sock"), "socket path suffix")
    }

    #[test]
    fn test_socket_user_sanitization() -> TestResult {
        ensure_eq(
            sanitize_socket_user("../bad user!"),
            "baduser".to_string(),
            "path traversal and punctuation should be removed",
        )?;
        ensure_eq(
            sanitize_socket_user(""),
            "unknown".to_string(),
            "empty user fallback",
        )?;
        ensure_eq(
            sanitize_socket_user("a".repeat(80).as_str()).len(),
            64,
            "socket user length cap",
        )
    }

    #[test]
    fn test_wire_compatibility_embed_response() -> TestResult {
        let msg = FramedMessage::new(
            "resp-embed",
            Response::Embed(EmbedResponse {
                embeddings: vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
                model: "minilm-384".to_string(),
                elapsed_ms: 15,
            }),
        );
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Response> = decode_framed(&encoded)?;

        let Response::Embed(resp) = decoded.payload else {
            return Err(test_error("expected embed response payload"));
        };
        ensure_eq(resp.embeddings.len(), 2, "embedding count")?;
        let first = resp
            .embeddings
            .first()
            .ok_or_else(|| test_error("first embedding should exist"))?;
        ensure_eq(first.clone(), vec![0.1, 0.2, 0.3], "first embedding")?;
        ensure_eq(resp.model, "minilm-384".to_string(), "embedding model")
    }

    #[test]
    fn test_wire_compatibility_rerank_response() -> TestResult {
        let msg = FramedMessage::new(
            "resp-rerank",
            Response::Rerank(RerankResponse {
                scores: vec![0.95, 0.72, 0.31],
                model: "ms-marco".to_string(),
                elapsed_ms: 8,
            }),
        );
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Response> = decode_framed(&encoded)?;

        let Response::Rerank(resp) = decoded.payload else {
            return Err(test_error("expected rerank response payload"));
        };
        ensure_eq(resp.scores, vec![0.95, 0.72, 0.31], "rerank scores")
    }
}