1use serde::{Deserialize, Serialize};
9
/// Version of the wire protocol defined in this module.
///
/// [`FramedMessage::new`] stamps this value into every frame it builds.
/// [`decode_message`] does not check it. A receiver that wants to reject frames
/// from a different protocol version must compare [`FramedMessage::version`]
/// itself (presumably replying with [`ErrorCode::VersionMismatch`] — confirm
/// against the daemon implementation).
pub const PROTOCOL_VERSION: u32 = 1;
13
/// Returns the default per-user Unix socket path for the daemon:
/// `/tmp/semantic-daemon-<user>.sock`.
///
/// `<user>` is derived from the `USER` environment variable:
/// - Only alphanumeric characters, `-`, and `_` are kept, so separators such as
///   `/` or `..` cannot move the socket out of `/tmp`.
/// - The result is truncated to at most 64 **bytes**, never splitting a
///   character.
/// - `unknown` is used if `USER` is unset, is not valid Unicode, or is empty
///   after filtering.
///
/// The byte cap keeps the whole path at most 90 bytes. That is below the
/// AF_UNIX `sun_path` limit (104 bytes on macOS, 108 on Linux). A 64-character
/// cap would not guarantee this, because 64 multi-byte characters can take up
/// to 256 bytes.
///
/// NOTE(review): `/tmp` is shared between users, so another local account could
/// pre-create this path. Consider `$XDG_RUNTIME_DIR` if clients and daemon can
/// migrate together.
pub fn default_socket_path() -> std::path::PathBuf {
    // 21-byte prefix + 5-byte suffix + 64 bytes of user = 90 bytes max.
    const MAX_USER_BYTES: usize = 64;

    let user = std::env::var("USER").unwrap_or_else(|_| "unknown".into());

    let mut safe_user = String::with_capacity(MAX_USER_BYTES);
    for c in user
        .chars()
        .filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_')
    {
        // Budget in bytes, not chars: the kernel limit is on the encoded path.
        if safe_user.len() + c.len_utf8() > MAX_USER_BYTES {
            break;
        }
        safe_user.push(c);
    }
    if safe_user.is_empty() {
        safe_user.push_str("unknown");
    }
    std::path::PathBuf::from(format!("/tmp/semantic-daemon-{}.sock", safe_user))
}
30
/// A request sent to the daemon, carried as the `payload` of a
/// [`FramedMessage`].
///
/// Wire-format note: frames are encoded with `rmp_serde` (see
/// [`encode_message`]). Its compact encoding writes struct fields, including
/// those of struct variants, as MessagePack arrays without field names.
/// Reordering or removing variants or fields is therefore a breaking protocol
/// change. NOTE(review): if that is ever needed, bump [`PROTOCOL_VERSION`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Request {
    /// Liveness probe; presumably answered with [`Response::Health`].
    Health,

    /// Embed every string in `texts`; presumably answered with
    /// [`Response::Embed`].
    Embed {
        /// Input texts to embed.
        texts: Vec<String>,
        /// Name of the embedding model to use.
        model: String,
        /// Optional requested output dimension (`None` in the tests).
        /// NOTE(review): server handling of unsupported values is not defined here.
        dims: Option<usize>,
    },

    /// Score `documents` against `query`; presumably answered with
    /// [`Response::Rerank`].
    Rerank {
        /// Query the documents are scored against.
        query: String,
        /// Candidate documents to score.
        documents: Vec<String>,
        /// Name of the reranking model to use.
        model: String,
    },

    /// Daemon-wide status query; presumably answered with [`Response::Status`].
    Status,

    /// Submit a background embedding job; presumably answered with
    /// [`Response::JobSubmitted`].
    SubmitEmbeddingJob {
        /// Path of the database whose documents are embedded.
        db_path: String,
        /// Path where the resulting index is written (presumably).
        index_path: String,
        /// Whether to use both a fast and a quality model.
        two_tier: bool,
        /// Fast-tier model; `None` presumably means a server default — confirm.
        fast_model: Option<String>,
        /// Quality-tier model; `None` presumably means a server default — confirm.
        quality_model: Option<String>,
    },

    /// Query embedding jobs for the database at `db_path`; presumably answered
    /// with [`Response::JobStatus`].
    EmbeddingJobStatus { db_path: String },

    /// Cancel embedding jobs for a database; presumably answered with
    /// [`Response::JobCancelled`].
    CancelEmbeddingJob {
        /// Database whose jobs are cancelled.
        db_path: String,
        /// Presumably restricts cancellation to one model; `None` = all — confirm.
        model_id: Option<String>,
    },

    /// Ask the daemon to shut down; presumably answered with
    /// [`Response::Shutdown`].
    Shutdown,
}
75
/// A reply from the daemon, carried as the `payload` of a [`FramedMessage`].
///
/// Wire-format note: encoded with `rmp_serde` (see [`encode_message`]). Its
/// compact encoding writes struct fields as MessagePack arrays without field
/// names. This covers fields of struct variants and of the payload structs
/// below. Reordering or removing variants or fields is a breaking protocol
/// change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Response {
    /// Health report (presumably for [`Request::Health`]).
    Health(HealthStatus),

    /// Embeddings (presumably for [`Request::Embed`]).
    Embed(EmbedResponse),

    /// Rerank scores (presumably for [`Request::Rerank`]).
    Rerank(RerankResponse),

    /// Daemon-wide status (presumably for [`Request::Status`]).
    Status(StatusResponse),

    /// Acknowledgement of a submitted job (presumably for
    /// [`Request::SubmitEmbeddingJob`]).
    /// NOTE(review): `job_id` is a `String` here but
    /// [`EmbeddingJobDetail::job_id`] is an `i64`; confirm how they relate.
    JobSubmitted { job_id: String, message: String },

    /// Job list for a database (presumably for [`Request::EmbeddingJobStatus`]).
    JobStatus(EmbeddingJobInfo),

    /// Cancellation result (presumably for [`Request::CancelEmbeddingJob`]);
    /// `cancelled` is presumably the number of jobs cancelled.
    JobCancelled { cancelled: usize, message: String },

    /// Shutdown acknowledgement (presumably for [`Request::Shutdown`]).
    Shutdown { message: String },

    /// Failure reply; see [`ErrorResponse`].
    Error(ErrorResponse),
}
106
/// Payload of [`Response::Health`].
///
/// Field order is part of the wire format (positional MessagePack encoding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthStatus {
    /// Time the daemon has been running, in seconds (per the `_secs` suffix).
    pub uptime_secs: u64,
    /// Protocol version reported by the daemon (the tests use
    /// [`PROTOCOL_VERSION`]).
    pub version: u32,
    /// Whether the daemon reports itself ready to serve requests.
    pub ready: bool,
    /// Memory usage in bytes as reported by the daemon; how it is measured is
    /// not defined in this module.
    pub memory_bytes: u64,
}
119
/// Payload of [`Response::Embed`].
///
/// Field order is part of the wire format (positional MessagePack encoding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedResponse {
    /// Embedding vectors; presumably one per input text, in request order —
    /// confirm against the daemon.
    pub embeddings: Vec<Vec<f32>>,
    /// Model that produced the embeddings.
    pub model: String,
    /// Processing time in milliseconds (per the `_ms` suffix); what it covers
    /// is not defined here.
    pub elapsed_ms: u64,
}
130
/// Payload of [`Response::Rerank`].
///
/// Field order is part of the wire format (positional MessagePack encoding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResponse {
    /// Relevance scores; presumably one per request document, in request
    /// order — confirm against the daemon.
    pub scores: Vec<f32>,
    /// Model that produced the scores.
    pub model: String,
    /// Processing time in milliseconds (per the `_ms` suffix).
    pub elapsed_ms: u64,
}
141
/// Payload of [`Response::Status`].
///
/// Field order is part of the wire format (positional MessagePack encoding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatusResponse {
    /// Time the daemon has been running, in seconds (per the `_secs` suffix).
    pub uptime_secs: u64,
    /// Protocol version reported by the daemon.
    pub version: u32,
    /// Embedding models known to the daemon.
    pub embedders: Vec<ModelInfo>,
    /// Reranking models known to the daemon.
    pub rerankers: Vec<ModelInfo>,
    /// Memory usage in bytes as reported by the daemon.
    pub memory_bytes: u64,
    /// Request counter reported by the daemon; what it counts (e.g. since
    /// start-up) is not defined here.
    pub total_requests: u64,
}
158
/// Description of one model, as listed in [`StatusResponse`].
///
/// Field order is part of the wire format (positional MessagePack encoding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier.
    pub id: String,
    /// Model display name.
    pub name: String,
    /// Output dimension, if any; presumably `None` for rerankers — confirm.
    pub dimension: Option<usize>,
    /// Whether the model is currently loaded.
    pub loaded: bool,
    /// Memory attributed to this model, in bytes.
    pub memory_bytes: u64,
}
173
/// Payload of [`Response::Error`].
///
/// Field order is part of the wire format (positional MessagePack encoding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
    /// Machine-readable error category.
    pub code: ErrorCode,
    /// Human-readable description.
    pub message: String,
    /// Whether the daemon indicates that retrying the request may succeed.
    pub retryable: bool,
    /// Suggested delay before retrying, in milliseconds; `None` when the daemon
    /// gives no hint.
    pub retry_after_ms: Option<u64>,
}
186
/// Error category carried in [`ErrorResponse::code`].
///
/// It is `Copy + Eq`, so callers can compare codes directly. Variant meanings
/// below are inferred from their names — confirm against the daemon.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ErrorCode {
    /// Presumably an unexpected failure inside the daemon.
    Internal,
    /// Presumably the requested model name is unknown.
    ModelNotFound,
    /// Presumably the request contents were rejected.
    InvalidInput,
    /// Presumably the daemon is too busy (the tests pair it with
    /// `retryable: true` and a `retry_after_ms` hint).
    Overloaded,
    /// Presumably the request took too long.
    Timeout,
    /// Presumably a model could not be loaded.
    ModelLoadFailed,
    /// Presumably the frame's `version` differs from the daemon's
    /// [`PROTOCOL_VERSION`].
    VersionMismatch,
}
205
/// Payload of [`Response::JobStatus`]: the embedding jobs reported by the daemon.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingJobInfo {
    /// One entry per reported job.
    pub jobs: Vec<EmbeddingJobDetail>,
}
211
/// State of one embedding job, as listed in [`EmbeddingJobInfo`].
///
/// Field order is part of the wire format (positional MessagePack encoding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingJobDetail {
    /// Job identifier.
    /// NOTE(review): `Response::JobSubmitted` reports `job_id` as a `String`;
    /// confirm how the two correspond.
    pub job_id: i64,
    /// Model the job embeds with.
    pub model_id: String,
    /// Job state as free-form text; the allowed values are not defined in
    /// this module.
    pub status: String,
    /// Total number of documents the job covers.
    pub total_docs: i64,
    /// Number of documents processed so far.
    pub completed_docs: i64,
    /// Failure description; presumably set only for failed jobs.
    pub error_message: Option<String>,
}
222
/// Envelope around every request and response payload.
///
/// [`encode_message`] serializes it with `rmp_serde`'s compact encoding, which
/// writes these fields positionally (as a MessagePack array). The field order
/// below is therefore part of the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FramedMessage<T> {
    /// Protocol version of the sender; set to [`PROTOCOL_VERSION`] by
    /// [`FramedMessage::new`]. It is not validated by [`decode_message`].
    pub version: u32,
    /// Caller-chosen identifier; presumably echoed in the matching response so
    /// replies can be correlated — confirm against the daemon.
    pub request_id: String,
    /// The wrapped message, e.g. a [`Request`] or [`Response`].
    pub payload: T,
}
233
234impl<T> FramedMessage<T> {
235 pub fn new(request_id: impl Into<String>, payload: T) -> Self {
236 Self {
237 version: PROTOCOL_VERSION,
238 request_id: request_id.into(),
239 payload,
240 }
241 }
242}
243
244pub fn encode_message<T: Serialize>(msg: &FramedMessage<T>) -> Result<Vec<u8>, EncodeError> {
246 let payload = rmp_serde::to_vec(msg).map_err(|e| EncodeError(e.to_string()))?;
247 let len = u32::try_from(payload.len())
248 .map_err(|_| EncodeError("payload exceeds maximum size of 4GB".to_string()))?;
249 let mut buf = Vec::with_capacity(4 + payload.len());
250 buf.extend_from_slice(&len.to_be_bytes());
251 buf.extend_from_slice(&payload);
252 Ok(buf)
253}
254
255pub fn decode_message<T: for<'de> Deserialize<'de>>(
257 data: &[u8],
258) -> Result<FramedMessage<T>, DecodeError> {
259 rmp_serde::from_slice(data).map_err(|e| DecodeError(e.to_string()))
260}
261
/// Error returned by [`encode_message`].
///
/// It wraps a textual description: either the serializer's error message or
/// the oversize-payload message. It displays as `encode error: <description>`
/// and has no `source()`.
#[derive(Debug, Clone, thiserror::Error)]
#[error("encode error: {0}")]
pub struct EncodeError(pub String);
265
/// Error returned by [`decode_message`].
///
/// It wraps the deserializer's error message as text. It displays as
/// `decode error: <description>` and has no `source()`.
#[derive(Debug, Clone, thiserror::Error)]
#[error("decode error: {0}")]
pub struct DecodeError(pub String);
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `msg`, asserts that the 4-byte big-endian length prefix matches
    /// the payload that follows it, and decodes that payload back.
    ///
    /// All round-trip tests go through this helper, so the framing written by
    /// `encode_message` is verified rather than skipped with `&encoded[4..]`.
    fn roundtrip<T>(msg: &FramedMessage<T>) -> FramedMessage<T>
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let encoded = encode_message(msg).expect("encoding a small test message must succeed");
        assert!(encoded.len() >= 4, "frame is shorter than its length prefix");
        let (prefix, payload) = encoded.split_at(4);
        let declared = u32::from_be_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]);
        assert_eq!(
            declared as usize,
            payload.len(),
            "length prefix must equal the payload size"
        );
        decode_message(payload).expect("a freshly encoded payload must decode")
    }

    #[test]
    fn test_encode_decode_health_request() {
        let decoded = roundtrip(&FramedMessage::new("req-1", Request::Health));

        assert_eq!(decoded.version, PROTOCOL_VERSION);
        assert_eq!(decoded.request_id, "req-1");
        assert!(matches!(decoded.payload, Request::Health));
    }

    #[test]
    fn test_protocol_error_display_strings_are_preserved() {
        let encode = EncodeError("bad payload".to_string());
        let decode = DecodeError("bad frame".to_string());
        let cases: &[(&str, &dyn std::error::Error, &str)] = &[
            ("encode", &encode, "encode error: bad payload"),
            ("decode", &decode, "decode error: bad frame"),
        ];

        for (label, error, expected_display) in cases {
            assert_eq!(error.to_string(), *expected_display, "{label}");
            assert!(error.source().is_none(), "{label}");
        }
    }

    #[test]
    fn test_decode_rejects_malformed_payloads() {
        assert!(decode_message::<Request>(&[]).is_err());
        // 0xc1 is the one MessagePack marker that is never used.
        assert!(decode_message::<Request>(&[0xc1]).is_err());
    }

    #[test]
    fn test_encode_decode_embed_request() {
        let msg = FramedMessage::new(
            "req-2",
            Request::Embed {
                texts: vec!["hello".to_string(), "world".to_string()],
                model: "all-MiniLM-L6-v2".to_string(),
                dims: None,
            },
        );

        match roundtrip(&msg).payload {
            Request::Embed { texts, model, dims } => {
                assert_eq!(texts, vec!["hello", "world"]);
                assert_eq!(model, "all-MiniLM-L6-v2");
                assert!(dims.is_none());
            }
            other => panic!("expected Request::Embed, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_embed_request_with_dims() {
        let msg = FramedMessage::new(
            "req-2b",
            Request::Embed {
                texts: vec![],
                model: "minilm-384".to_string(),
                dims: Some(256),
            },
        );

        match roundtrip(&msg).payload {
            Request::Embed { texts, dims, .. } => {
                assert!(texts.is_empty());
                assert_eq!(dims, Some(256));
            }
            other => panic!("expected Request::Embed, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_rerank_request() {
        let msg = FramedMessage::new(
            "req-3",
            Request::Rerank {
                query: "test query".to_string(),
                documents: vec!["doc1".to_string(), "doc2".to_string()],
                model: "ms-marco-MiniLM-L-6-v2".to_string(),
            },
        );

        match roundtrip(&msg).payload {
            Request::Rerank {
                query,
                documents,
                model,
            } => {
                assert_eq!(query, "test query");
                assert_eq!(documents, vec!["doc1", "doc2"]);
                assert_eq!(model, "ms-marco-MiniLM-L-6-v2");
            }
            other => panic!("expected Request::Rerank, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_embedding_job_requests() {
        let submit = FramedMessage::new(
            "req-job-1",
            Request::SubmitEmbeddingJob {
                db_path: "/data/docs.db".to_string(),
                index_path: "/data/docs.idx".to_string(),
                two_tier: true,
                fast_model: Some("minilm-384".to_string()),
                quality_model: None,
            },
        );
        match roundtrip(&submit).payload {
            Request::SubmitEmbeddingJob {
                db_path,
                index_path,
                two_tier,
                fast_model,
                quality_model,
            } => {
                assert_eq!(db_path, "/data/docs.db");
                assert_eq!(index_path, "/data/docs.idx");
                assert!(two_tier);
                assert_eq!(fast_model.as_deref(), Some("minilm-384"));
                assert!(quality_model.is_none());
            }
            other => panic!("expected Request::SubmitEmbeddingJob, got {other:?}"),
        }

        let status = FramedMessage::new(
            "req-job-2",
            Request::EmbeddingJobStatus {
                db_path: "/data/docs.db".to_string(),
            },
        );
        match roundtrip(&status).payload {
            Request::EmbeddingJobStatus { db_path } => assert_eq!(db_path, "/data/docs.db"),
            other => panic!("expected Request::EmbeddingJobStatus, got {other:?}"),
        }

        let cancel = FramedMessage::new(
            "req-job-3",
            Request::CancelEmbeddingJob {
                db_path: "/data/docs.db".to_string(),
                model_id: Some("minilm-384".to_string()),
            },
        );
        match roundtrip(&cancel).payload {
            Request::CancelEmbeddingJob { db_path, model_id } => {
                assert_eq!(db_path, "/data/docs.db");
                assert_eq!(model_id.as_deref(), Some("minilm-384"));
            }
            other => panic!("expected Request::CancelEmbeddingJob, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_shutdown() {
        let request = roundtrip(&FramedMessage::new("req-bye", Request::Shutdown));
        assert!(matches!(request.payload, Request::Shutdown));

        let response = FramedMessage::new(
            "resp-bye",
            Response::Shutdown {
                message: "bye".to_string(),
            },
        );
        match roundtrip(&response).payload {
            Response::Shutdown { message } => assert_eq!(message, "bye"),
            other => panic!("expected Response::Shutdown, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_health_response() {
        let msg = FramedMessage::new(
            "resp-1",
            Response::Health(HealthStatus {
                uptime_secs: 120,
                version: PROTOCOL_VERSION,
                ready: true,
                memory_bytes: 100_000_000,
            }),
        );

        match roundtrip(&msg).payload {
            Response::Health(status) => {
                assert_eq!(status.uptime_secs, 120);
                assert_eq!(status.version, PROTOCOL_VERSION);
                assert!(status.ready);
                assert_eq!(status.memory_bytes, 100_000_000);
            }
            other => panic!("expected Response::Health, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_error_response() {
        let msg = FramedMessage::new(
            "resp-err",
            Response::Error(ErrorResponse {
                code: ErrorCode::Overloaded,
                message: "too many requests".to_string(),
                retryable: true,
                retry_after_ms: Some(1000),
            }),
        );

        match roundtrip(&msg).payload {
            Response::Error(err) => {
                assert_eq!(err.code, ErrorCode::Overloaded);
                assert_eq!(err.message, "too many requests");
                assert!(err.retryable);
                assert_eq!(err.retry_after_ms, Some(1000));
            }
            other => panic!("expected Response::Error, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_job_responses() {
        let submitted = FramedMessage::new(
            "resp-job-1",
            Response::JobSubmitted {
                job_id: "42".to_string(),
                message: "queued".to_string(),
            },
        );
        match roundtrip(&submitted).payload {
            Response::JobSubmitted { job_id, message } => {
                assert_eq!(job_id, "42");
                assert_eq!(message, "queued");
            }
            other => panic!("expected Response::JobSubmitted, got {other:?}"),
        }

        let status = FramedMessage::new(
            "resp-job-2",
            Response::JobStatus(EmbeddingJobInfo {
                jobs: vec![EmbeddingJobDetail {
                    job_id: 42,
                    model_id: "minilm-384".to_string(),
                    status: "running".to_string(),
                    total_docs: 100,
                    completed_docs: 40,
                    error_message: None,
                }],
            }),
        );
        match roundtrip(&status).payload {
            Response::JobStatus(info) => {
                assert_eq!(info.jobs.len(), 1);
                let job = &info.jobs[0];
                assert_eq!(job.job_id, 42);
                assert_eq!(job.model_id, "minilm-384");
                assert_eq!(job.status, "running");
                assert_eq!(job.total_docs, 100);
                assert_eq!(job.completed_docs, 40);
                assert!(job.error_message.is_none());
            }
            other => panic!("expected Response::JobStatus, got {other:?}"),
        }

        let cancelled = FramedMessage::new(
            "resp-job-3",
            Response::JobCancelled {
                cancelled: 2,
                message: "cancelled".to_string(),
            },
        );
        match roundtrip(&cancelled).payload {
            Response::JobCancelled { cancelled, message } => {
                assert_eq!(cancelled, 2);
                assert_eq!(message, "cancelled");
            }
            other => panic!("expected Response::JobCancelled, got {other:?}"),
        }
    }

    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        let path_str = path.to_string_lossy();
        assert!(path_str.starts_with("/tmp/semantic-daemon-"));
        assert!(path_str.ends_with(".sock"));
    }

    #[test]
    fn test_wire_compatibility_embed_response() {
        let msg = FramedMessage::new(
            "resp-embed",
            Response::Embed(EmbedResponse {
                embeddings: vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
                model: "minilm-384".to_string(),
                elapsed_ms: 15,
            }),
        );

        match roundtrip(&msg).payload {
            Response::Embed(resp) => {
                assert_eq!(resp.embeddings.len(), 2);
                assert_eq!(resp.embeddings[0], vec![0.1, 0.2, 0.3]);
                assert_eq!(resp.embeddings[1], vec![0.4, 0.5, 0.6]);
                assert_eq!(resp.model, "minilm-384");
                assert_eq!(resp.elapsed_ms, 15);
            }
            other => panic!("expected Response::Embed, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_compatibility_rerank_response() {
        let msg = FramedMessage::new(
            "resp-rerank",
            Response::Rerank(RerankResponse {
                scores: vec![0.95, 0.72, 0.31],
                model: "ms-marco".to_string(),
                elapsed_ms: 8,
            }),
        );

        match roundtrip(&msg).payload {
            Response::Rerank(resp) => {
                assert_eq!(resp.scores, vec![0.95, 0.72, 0.31]);
                assert_eq!(resp.model, "ms-marco");
                assert_eq!(resp.elapsed_ms, 8);
            }
            other => panic!("expected Response::Rerank, got {other:?}"),
        }
    }
}