Skip to main content

coding_agent_search/daemon/
protocol.rs

//! Wire-compatible protocol for the semantic model daemon.
//!
//! This protocol is designed to be wire-compatible with xf's daemon implementation,
//! allowing both tools to share a daemon if both are installed.
//!
//! Messages are serialized with MessagePack and exchanged over Unix domain sockets.
//! Each frame is a 4-byte big-endian length prefix followed by the MessagePack payload.
7
8use serde::{Deserialize, Serialize};
9
/// Version of the wire protocol, stamped into every frame built by
/// [`FramedMessage::new`].
///
/// cass and xf can only share one daemon when both are built with the same value.
pub const PROTOCOL_VERSION: u32 = 1;
13
/// Returns the default daemon socket path, shared between cass and xf.
///
/// The path has the form `/tmp/semantic-daemon-<user>.sock`, where `<user>` is
/// taken from the `USER` environment variable. To prevent path traversal, the
/// name keeps only alphanumeric characters, `-` and `_`. It is capped at 64
/// retained characters. If `USER` is unset or not valid Unicode, or nothing
/// survives sanitization, `unknown` is used.
pub fn default_socket_path() -> std::path::PathBuf {
    const FALLBACK: &str = "unknown";
    const MAX_USER_CHARS: usize = 64;

    let raw_user = match std::env::var("USER") {
        Ok(value) => value,
        Err(_) => FALLBACK.to_string(),
    };

    // Keep at most MAX_USER_CHARS characters from the allowed set; drop the rest.
    let mut user = String::new();
    let mut kept = 0usize;
    for ch in raw_user.chars() {
        if kept == MAX_USER_CHARS {
            break;
        }
        if ch.is_alphanumeric() || ch == '-' || ch == '_' {
            user.push(ch);
            kept += 1;
        }
    }
    if user.is_empty() {
        user.push_str(FALLBACK);
    }

    std::path::PathBuf::from(format!("/tmp/semantic-daemon-{}.sock", user))
}
30
31/// Request types for the daemon protocol.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub enum Request {
34    /// Health check - returns daemon status.
35    Health,
36
37    /// Generate embeddings for texts.
38    Embed {
39        texts: Vec<String>,
40        model: String,
41        dims: Option<usize>,
42    },
43
44    /// Rerank documents against a query.
45    Rerank {
46        query: String,
47        documents: Vec<String>,
48        model: String,
49    },
50
51    /// Get daemon status and loaded models.
52    Status,
53
54    /// Submit a background embedding job.
55    SubmitEmbeddingJob {
56        db_path: String,
57        index_path: String,
58        two_tier: bool,
59        fast_model: Option<String>,
60        quality_model: Option<String>,
61    },
62
63    /// Query embedding job status.
64    EmbeddingJobStatus { db_path: String },
65
66    /// Cancel embedding jobs.
67    CancelEmbeddingJob {
68        db_path: String,
69        model_id: Option<String>,
70    },
71
72    /// Request graceful shutdown.
73    Shutdown,
74}
75
76/// Response types from the daemon.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub enum Response {
79    /// Health check response.
80    Health(HealthStatus),
81
82    /// Embedding response with vectors.
83    Embed(EmbedResponse),
84
85    /// Rerank response with scores.
86    Rerank(RerankResponse),
87
88    /// Status response with daemon info.
89    Status(StatusResponse),
90
91    /// Embedding job submitted.
92    JobSubmitted { job_id: String, message: String },
93
94    /// Embedding job status.
95    JobStatus(EmbeddingJobInfo),
96
97    /// Embedding jobs cancelled.
98    JobCancelled { cancelled: usize, message: String },
99
100    /// Shutdown acknowledgement.
101    Shutdown { message: String },
102
103    /// Error response.
104    Error(ErrorResponse),
105}
106
107/// Health status of the daemon.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct HealthStatus {
110    /// Daemon uptime in seconds.
111    pub uptime_secs: u64,
112    /// Protocol version.
113    pub version: u32,
114    /// Whether models are loaded and ready.
115    pub ready: bool,
116    /// Current memory usage in bytes (approximate).
117    pub memory_bytes: u64,
118}
119
120/// Response containing embeddings.
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct EmbedResponse {
123    /// Embeddings as Vec<Vec<f32>>.
124    pub embeddings: Vec<Vec<f32>>,
125    /// Model ID used.
126    pub model: String,
127    /// Processing time in milliseconds.
128    pub elapsed_ms: u64,
129}
130
131/// Response containing rerank scores.
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct RerankResponse {
134    /// Scores for each document (same order as input).
135    pub scores: Vec<f32>,
136    /// Model ID used.
137    pub model: String,
138    /// Processing time in milliseconds.
139    pub elapsed_ms: u64,
140}
141
142/// Daemon status response.
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct StatusResponse {
145    /// Daemon uptime in seconds.
146    pub uptime_secs: u64,
147    /// Protocol version.
148    pub version: u32,
149    /// Loaded embedder models.
150    pub embedders: Vec<ModelInfo>,
151    /// Loaded reranker models.
152    pub rerankers: Vec<ModelInfo>,
153    /// Current memory usage in bytes.
154    pub memory_bytes: u64,
155    /// Total requests served.
156    pub total_requests: u64,
157}
158
159/// Information about a loaded model.
160#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct ModelInfo {
162    /// Model ID.
163    pub id: String,
164    /// Model name/path.
165    pub name: String,
166    /// Output dimension (for embedders).
167    pub dimension: Option<usize>,
168    /// Whether the model is currently loaded.
169    pub loaded: bool,
170    /// Approximate memory usage in bytes.
171    pub memory_bytes: u64,
172}
173
174/// Error response from daemon.
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct ErrorResponse {
177    /// Error code for programmatic handling.
178    pub code: ErrorCode,
179    /// Human-readable error message.
180    pub message: String,
181    /// Whether the request can be retried.
182    pub retryable: bool,
183    /// Suggested retry delay in milliseconds (if retryable).
184    pub retry_after_ms: Option<u64>,
185}
186
187/// Error codes for daemon errors.
188#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
189pub enum ErrorCode {
190    /// Unknown or internal error.
191    Internal,
192    /// Model not found or not loaded.
193    ModelNotFound,
194    /// Invalid request parameters.
195    InvalidInput,
196    /// Daemon is overloaded, try again later.
197    Overloaded,
198    /// Request timed out.
199    Timeout,
200    /// Model loading failed.
201    ModelLoadFailed,
202    /// Protocol version mismatch.
203    VersionMismatch,
204}
205
206/// Status information for embedding jobs.
207#[derive(Debug, Clone, Serialize, Deserialize)]
208pub struct EmbeddingJobInfo {
209    pub jobs: Vec<EmbeddingJobDetail>,
210}
211
212/// Detail for a single embedding job.
213#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct EmbeddingJobDetail {
215    pub job_id: i64,
216    pub model_id: String,
217    pub status: String,
218    pub total_docs: i64,
219    pub completed_docs: i64,
220    pub error_message: Option<String>,
221}
222
223/// Framed message wrapper for length-prefixed protocol.
224#[derive(Debug, Clone, Serialize, Deserialize)]
225pub struct FramedMessage<T> {
226    /// Protocol version.
227    pub version: u32,
228    /// Request ID for correlation.
229    pub request_id: String,
230    /// Payload.
231    pub payload: T,
232}
233
234impl<T> FramedMessage<T> {
235    pub fn new(request_id: impl Into<String>, payload: T) -> Self {
236        Self {
237            version: PROTOCOL_VERSION,
238            request_id: request_id.into(),
239            payload,
240        }
241    }
242}
243
244/// Encode a message to MessagePack bytes with length prefix.
245pub fn encode_message<T: Serialize>(msg: &FramedMessage<T>) -> Result<Vec<u8>, EncodeError> {
246    let payload = rmp_serde::to_vec(msg).map_err(|e| EncodeError(e.to_string()))?;
247    let len = u32::try_from(payload.len())
248        .map_err(|_| EncodeError("payload exceeds maximum size of 4GB".to_string()))?;
249    let mut buf = Vec::with_capacity(4 + payload.len());
250    buf.extend_from_slice(&len.to_be_bytes());
251    buf.extend_from_slice(&payload);
252    Ok(buf)
253}
254
255/// Decode a message from MessagePack bytes (without length prefix).
256pub fn decode_message<T: for<'de> Deserialize<'de>>(
257    data: &[u8],
258) -> Result<FramedMessage<T>, DecodeError> {
259    rmp_serde::from_slice(data).map_err(|e| DecodeError(e.to_string()))
260}
261
262#[derive(Debug, Clone, thiserror::Error)]
263#[error("encode error: {0}")]
264pub struct EncodeError(pub String);
265
266#[derive(Debug, Clone, thiserror::Error)]
267#[error("decode error: {0}")]
268pub struct DecodeError(pub String);
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `msg`, drops the 4-byte length prefix, and decodes the payload again.
    fn roundtrip<T>(msg: &FramedMessage<T>) -> FramedMessage<T>
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let frame = encode_message(msg).unwrap();
        decode_message(&frame[4..]).unwrap()
    }

    #[test]
    fn test_encode_decode_health_request() {
        let decoded = roundtrip(&FramedMessage::new("req-1", Request::Health));

        assert_eq!(decoded.version, PROTOCOL_VERSION);
        assert_eq!(decoded.request_id, "req-1");
        assert!(matches!(decoded.payload, Request::Health));
    }

    #[test]
    fn test_protocol_error_display_strings_are_preserved() {
        let encode_err = EncodeError("bad payload".to_string());
        let decode_err = DecodeError("bad frame".to_string());
        let expectations: &[(&str, &dyn std::error::Error, &str)] = &[
            ("encode", &encode_err, "encode error: bad payload"),
            ("decode", &decode_err, "decode error: bad frame"),
        ];

        for &(label, err, want) in expectations {
            assert_eq!(err.to_string(), want, "{label}");
            assert!(err.source().is_none(), "{label}");
        }
    }

    #[test]
    fn test_encode_decode_embed_request() {
        let request = Request::Embed {
            texts: vec!["hello".to_string(), "world".to_string()],
            model: "all-MiniLM-L6-v2".to_string(),
            dims: None,
        };
        let decoded = roundtrip(&FramedMessage::new("req-2", request));

        match decoded.payload {
            Request::Embed { texts, model, dims } => {
                assert_eq!(texts, vec!["hello", "world"]);
                assert_eq!(model, "all-MiniLM-L6-v2");
                assert!(dims.is_none());
            }
            other => panic!("expected Request::Embed, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_rerank_request() {
        let request = Request::Rerank {
            query: "test query".to_string(),
            documents: vec!["doc1".to_string(), "doc2".to_string()],
            model: "ms-marco-MiniLM-L-6-v2".to_string(),
        };
        let decoded = roundtrip(&FramedMessage::new("req-3", request));

        match decoded.payload {
            Request::Rerank {
                query,
                documents,
                model,
            } => {
                assert_eq!(query, "test query");
                assert_eq!(documents, vec!["doc1", "doc2"]);
                assert_eq!(model, "ms-marco-MiniLM-L-6-v2");
            }
            other => panic!("expected Request::Rerank, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_health_response() {
        let health = HealthStatus {
            uptime_secs: 120,
            version: PROTOCOL_VERSION,
            ready: true,
            memory_bytes: 100_000_000,
        };
        let decoded = roundtrip(&FramedMessage::new("resp-1", Response::Health(health)));

        match decoded.payload {
            Response::Health(status) => {
                assert_eq!(status.uptime_secs, 120);
                assert!(status.ready);
            }
            other => panic!("expected Response::Health, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_error_response() {
        let failure = ErrorResponse {
            code: ErrorCode::Overloaded,
            message: "too many requests".to_string(),
            retryable: true,
            retry_after_ms: Some(1000),
        };
        let decoded = roundtrip(&FramedMessage::new("resp-err", Response::Error(failure)));

        match decoded.payload {
            Response::Error(err) => {
                assert_eq!(err.code, ErrorCode::Overloaded);
                assert!(err.retryable);
                assert_eq!(err.retry_after_ms, Some(1000));
            }
            other => panic!("expected Response::Error, got {other:?}"),
        }
    }

    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        let rendered = path.to_string_lossy();

        assert!(rendered.starts_with("/tmp/semantic-daemon-"));
        assert!(rendered.ends_with(".sock"));
    }

    #[test]
    fn test_wire_compatibility_embed_response() {
        // Embedding responses must survive a full encode/decode cycle.
        let body = EmbedResponse {
            embeddings: vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
            model: "minilm-384".to_string(),
            elapsed_ms: 15,
        };
        let decoded = roundtrip(&FramedMessage::new("resp-embed", Response::Embed(body)));

        match decoded.payload {
            Response::Embed(resp) => {
                assert_eq!(resp.embeddings.len(), 2);
                assert_eq!(resp.embeddings[0], vec![0.1, 0.2, 0.3]);
                assert_eq!(resp.model, "minilm-384");
            }
            other => panic!("expected Response::Embed, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_compatibility_rerank_response() {
        let body = RerankResponse {
            scores: vec![0.95, 0.72, 0.31],
            model: "ms-marco".to_string(),
            elapsed_ms: 8,
        };
        let decoded = roundtrip(&FramedMessage::new("resp-rerank", Response::Rerank(body)));

        match decoded.payload {
            Response::Rerank(resp) => assert_eq!(resp.scores, vec![0.95, 0.72, 0.31]),
            other => panic!("expected Response::Rerank, got {other:?}"),
        }
    }
}