1use serde::{Deserialize, Serialize};
9use std::path::PathBuf;
10
/// Protocol version number stamped into every [`FramedMessage`] created with
/// [`FramedMessage::new`].
pub const PROTOCOL_VERSION: u32 = 1;
14
15pub fn default_socket_path() -> PathBuf {
17 let user = dotenvy::var("USER").unwrap_or_else(|_| "unknown".into());
18 let safe_user = sanitize_socket_user(&user);
19 PathBuf::from(format!("/tmp/semantic-daemon-{safe_user}.sock"))
20}
21
/// Reduces a user name to a short token that is safe to embed in a socket
/// file name.
///
/// Only ASCII letters, ASCII digits, `-` and `_` are kept; everything else
/// (path separators, dots, whitespace, punctuation, non-ASCII characters) is
/// dropped. At most 64 characters are retained. Because the surviving
/// characters are all ASCII, the result is also at most 64 bytes, which keeps
/// the final socket path comfortably inside the platform's `sun_path` limit.
///
/// Returns `"unknown"` when nothing survives the filter.
fn sanitize_socket_user(user: &str) -> String {
    /// Upper bound on the number of characters (and therefore bytes) kept.
    const MAX_LEN: usize = 64;

    let safe_user: String = user
        .chars()
        // ASCII-only on purpose: `char::is_alphanumeric` would also accept
        // multi-byte Unicode letters and let the byte length grow up to 4x.
        .filter(|c| c.is_ascii_alphanumeric() || matches!(*c, '-' | '_'))
        .take(MAX_LEN)
        .collect();

    if safe_user.is_empty() {
        "unknown".to_string()
    } else {
        safe_user
    }
}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub enum Request {
39 Health,
41
42 Embed {
44 texts: Vec<String>,
45 model: String,
46 dims: Option<usize>,
47 },
48
49 Rerank {
51 query: String,
52 documents: Vec<String>,
53 model: String,
54 },
55
56 Status,
58
59 SubmitEmbeddingJob {
61 db_path: String,
62 index_path: String,
63 two_tier: bool,
64 fast_model: Option<String>,
65 quality_model: Option<String>,
66 },
67
68 EmbeddingJobStatus { db_path: String },
70
71 CancelEmbeddingJob {
73 db_path: String,
74 model_id: Option<String>,
75 },
76
77 Shutdown,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub enum Response {
84 Health(HealthStatus),
86
87 Embed(EmbedResponse),
89
90 Rerank(RerankResponse),
92
93 Status(StatusResponse),
95
96 JobSubmitted { job_id: String, message: String },
98
99 JobStatus(EmbeddingJobInfo),
101
102 JobCancelled { cancelled: usize, message: String },
104
105 Shutdown { message: String },
107
108 Error(ErrorResponse),
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct HealthStatus {
115 pub uptime_secs: u64,
117 pub version: u32,
119 pub ready: bool,
121 pub memory_bytes: u64,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct EmbedResponse {
128 pub embeddings: Vec<Vec<f32>>,
130 pub model: String,
132 pub elapsed_ms: u64,
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct RerankResponse {
139 pub scores: Vec<f32>,
141 pub model: String,
143 pub elapsed_ms: u64,
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct StatusResponse {
150 pub uptime_secs: u64,
152 pub version: u32,
154 pub embedders: Vec<ModelInfo>,
156 pub rerankers: Vec<ModelInfo>,
158 pub memory_bytes: u64,
160 pub total_requests: u64,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct ModelInfo {
167 pub id: String,
169 pub name: String,
171 pub dimension: Option<usize>,
173 pub loaded: bool,
175 pub memory_bytes: u64,
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct ErrorResponse {
182 pub code: ErrorCode,
184 pub message: String,
186 pub retryable: bool,
188 pub retry_after_ms: Option<u64>,
190}
191
192#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
194pub enum ErrorCode {
195 Internal,
197 ModelNotFound,
199 InvalidInput,
201 Overloaded,
203 Timeout,
205 ModelLoadFailed,
207 VersionMismatch,
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct EmbeddingJobInfo {
214 pub jobs: Vec<EmbeddingJobDetail>,
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct EmbeddingJobDetail {
220 pub job_id: i64,
221 pub model_id: String,
222 pub status: String,
223 pub total_docs: i64,
224 pub completed_docs: i64,
225 pub error_message: Option<String>,
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct FramedMessage<T> {
231 pub version: u32,
233 pub request_id: String,
235 pub payload: T,
237}
238
239impl<T> FramedMessage<T> {
240 pub fn new(request_id: impl Into<String>, payload: T) -> Self {
241 Self {
242 version: PROTOCOL_VERSION,
243 request_id: request_id.into(),
244 payload,
245 }
246 }
247}
248
249pub fn encode_message<T: Serialize>(msg: &FramedMessage<T>) -> Result<Vec<u8>, EncodeError> {
251 let payload = rmp_serde::to_vec(msg)?;
252 let len = u32::try_from(payload.len())
253 .map_err(|_| EncodeError::Message("payload exceeds maximum size of 4GB".to_string()))?;
254 let mut buf = Vec::with_capacity(4 + payload.len());
255 buf.extend_from_slice(&len.to_be_bytes());
256 buf.extend_from_slice(&payload);
257 Ok(buf)
258}
259
260pub fn decode_message<T: for<'de> Deserialize<'de>>(
262 data: &[u8],
263) -> Result<FramedMessage<T>, DecodeError> {
264 rmp_serde::from_slice(data).map_err(DecodeError::from)
265}
266
267#[derive(Debug, thiserror::Error)]
268pub enum EncodeError {
269 #[error("encode error: {0}")]
270 Message(String),
271 #[error("encode error: {0}")]
272 MessagePack(#[from] rmp_serde::encode::Error),
273}
274
275#[derive(Debug, thiserror::Error)]
276pub enum DecodeError {
277 #[error("decode error: {0}")]
278 Message(String),
279 #[error("decode error: {0}")]
280 MessagePack(#[from] rmp_serde::decode::Error),
281}
282
283#[cfg(test)]
284mod tests {
285 use super::{
286 DecodeError, EmbedResponse, EncodeError, ErrorCode, ErrorResponse, FramedMessage,
287 HealthStatus, PROTOCOL_VERSION, Request, RerankResponse, Response, decode_message,
288 default_socket_path, encode_message, sanitize_socket_user,
289 };
290 use serde::de::DeserializeOwned;
291 use std::error::Error;
292 use std::fmt::Debug;
293
294 type TestResult = Result<(), Box<dyn Error>>;
295
296 fn test_error(message: impl Into<String>) -> Box<dyn Error> {
297 std::io::Error::other(message.into()).into()
298 }
299
300 fn ensure(condition: bool, message: impl Into<String>) -> TestResult {
301 if condition {
302 Ok(())
303 } else {
304 Err(test_error(message))
305 }
306 }
307
308 fn ensure_eq<T>(actual: T, expected: T, message: impl Into<String>) -> TestResult
309 where
310 T: Debug + PartialEq,
311 {
312 if actual == expected {
313 Ok(())
314 } else {
315 Err(test_error(format!(
316 "{}: expected {expected:?}, got {actual:?}",
317 message.into()
318 )))
319 }
320 }
321
322 fn decode_framed<T>(encoded: &[u8]) -> Result<FramedMessage<T>, Box<dyn Error>>
323 where
324 T: DeserializeOwned,
325 {
326 let payload = encoded
327 .get(4..)
328 .ok_or_else(|| test_error("encoded frame should include a 4-byte length prefix"))?;
329 decode_message(payload).map_err(|err| test_error(err.to_string()))
330 }
331
332 #[test]
333 fn test_encode_decode_health_request() -> TestResult {
334 let msg = FramedMessage::new("req-1", Request::Health);
335 let encoded = encode_message(&msg)?;
336
337 let decoded: FramedMessage<Request> = decode_framed(&encoded)?;
338 ensure_eq(decoded.version, PROTOCOL_VERSION, "protocol version")?;
339 ensure_eq(decoded.request_id, "req-1".to_string(), "request id")?;
340 ensure(matches!(decoded.payload, Request::Health), "health payload")
341 }
342
343 #[test]
344 fn test_protocol_error_display_strings_are_preserved() -> TestResult {
345 let encode = EncodeError::Message("bad payload".to_string());
346 ensure_eq(
347 encode.to_string(),
348 "encode error: bad payload".to_string(),
349 "encode",
350 )?;
351 ensure(encode.source().is_none(), "encode")?;
352
353 let decode = DecodeError::Message("bad frame".to_string());
354 ensure_eq(
355 decode.to_string(),
356 "decode error: bad frame".to_string(),
357 "decode",
358 )?;
359 ensure(decode.source().is_none(), "decode")?;
360 Ok(())
361 }
362
363 #[test]
364 fn test_encode_decode_embed_request() -> TestResult {
365 let msg = FramedMessage::new(
366 "req-2",
367 Request::Embed {
368 texts: vec!["hello".to_string(), "world".to_string()],
369 model: "all-MiniLM-L6-v2".to_string(),
370 dims: None,
371 },
372 );
373 let encoded = encode_message(&msg)?;
374 let decoded: FramedMessage<Request> = decode_framed(&encoded)?;
375
376 let Request::Embed { texts, model, dims } = decoded.payload else {
377 return Err(test_error("expected embed request payload"));
378 };
379 ensure_eq(
380 texts,
381 vec!["hello".to_string(), "world".to_string()],
382 "embed texts",
383 )?;
384 ensure_eq(model, "all-MiniLM-L6-v2".to_string(), "embed model")?;
385 ensure(dims.is_none(), "embed dims should be absent")
386 }
387
388 #[test]
389 fn test_encode_decode_rerank_request() -> TestResult {
390 let msg = FramedMessage::new(
391 "req-3",
392 Request::Rerank {
393 query: "test query".to_string(),
394 documents: vec!["doc1".to_string(), "doc2".to_string()],
395 model: "ms-marco-MiniLM-L-6-v2".to_string(),
396 },
397 );
398 let encoded = encode_message(&msg)?;
399 let decoded: FramedMessage<Request> = decode_framed(&encoded)?;
400
401 let Request::Rerank {
402 query,
403 documents,
404 model,
405 } = decoded.payload
406 else {
407 return Err(test_error("expected rerank request payload"));
408 };
409 ensure_eq(query, "test query".to_string(), "rerank query")?;
410 ensure_eq(
411 documents,
412 vec!["doc1".to_string(), "doc2".to_string()],
413 "rerank documents",
414 )?;
415 ensure_eq(model, "ms-marco-MiniLM-L-6-v2".to_string(), "rerank model")
416 }
417
418 #[test]
419 fn test_encode_decode_health_response() -> TestResult {
420 let msg = FramedMessage::new(
421 "resp-1",
422 Response::Health(HealthStatus {
423 uptime_secs: 120,
424 version: PROTOCOL_VERSION,
425 ready: true,
426 memory_bytes: 100_000_000,
427 }),
428 );
429 let encoded = encode_message(&msg)?;
430 let decoded: FramedMessage<Response> = decode_framed(&encoded)?;
431
432 let Response::Health(status) = decoded.payload else {
433 return Err(test_error("expected health response payload"));
434 };
435 ensure_eq(status.uptime_secs, 120, "health uptime")?;
436 ensure(status.ready, "health response should be ready")
437 }
438
439 #[test]
440 fn test_encode_decode_error_response() -> TestResult {
441 let msg = FramedMessage::new(
442 "resp-err",
443 Response::Error(ErrorResponse {
444 code: ErrorCode::Overloaded,
445 message: "too many requests".to_string(),
446 retryable: true,
447 retry_after_ms: Some(1000),
448 }),
449 );
450 let encoded = encode_message(&msg)?;
451 let decoded: FramedMessage<Response> = decode_framed(&encoded)?;
452
453 let Response::Error(err) = decoded.payload else {
454 return Err(test_error("expected error response payload"));
455 };
456 ensure_eq(err.code, ErrorCode::Overloaded, "error code")?;
457 ensure(err.retryable, "error should be retryable")?;
458 ensure_eq(err.retry_after_ms, Some(1000), "retry delay")
459 }
460
461 #[test]
462 fn test_default_socket_path() -> TestResult {
463 let path = default_socket_path();
464 let path_str = path.to_string_lossy();
465 ensure(
466 path_str.starts_with("/tmp/semantic-daemon-"),
467 "socket path prefix",
468 )?;
469 ensure(path_str.ends_with(".sock"), "socket path suffix")
470 }
471
472 #[test]
473 fn test_socket_user_sanitization() -> TestResult {
474 ensure_eq(
475 sanitize_socket_user("../bad user!"),
476 "baduser".to_string(),
477 "path traversal and punctuation should be removed",
478 )?;
479 ensure_eq(
480 sanitize_socket_user(""),
481 "unknown".to_string(),
482 "empty user fallback",
483 )?;
484 ensure_eq(
485 sanitize_socket_user("a".repeat(80).as_str()).len(),
486 64,
487 "socket user length cap",
488 )
489 }
490
491 #[test]
492 fn test_wire_compatibility_embed_response() -> TestResult {
493 let msg = FramedMessage::new(
494 "resp-embed",
495 Response::Embed(EmbedResponse {
496 embeddings: vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
497 model: "minilm-384".to_string(),
498 elapsed_ms: 15,
499 }),
500 );
501 let encoded = encode_message(&msg)?;
502 let decoded: FramedMessage<Response> = decode_framed(&encoded)?;
503
504 let Response::Embed(resp) = decoded.payload else {
505 return Err(test_error("expected embed response payload"));
506 };
507 ensure_eq(resp.embeddings.len(), 2, "embedding count")?;
508 let first = resp
509 .embeddings
510 .first()
511 .ok_or_else(|| test_error("first embedding should exist"))?;
512 ensure_eq(first.clone(), vec![0.1, 0.2, 0.3], "first embedding")?;
513 ensure_eq(resp.model, "minilm-384".to_string(), "embedding model")
514 }
515
516 #[test]
517 fn test_wire_compatibility_rerank_response() -> TestResult {
518 let msg = FramedMessage::new(
519 "resp-rerank",
520 Response::Rerank(RerankResponse {
521 scores: vec![0.95, 0.72, 0.31],
522 model: "ms-marco".to_string(),
523 elapsed_ms: 8,
524 }),
525 );
526 let encoded = encode_message(&msg)?;
527 let decoded: FramedMessage<Response> = decode_framed(&encoded)?;
528
529 let Response::Rerank(resp) = decoded.payload else {
530 return Err(test_error("expected rerank response payload"));
531 };
532 ensure_eq(resp.scores, vec![0.95, 0.72, 0.31], "rerank scores")
533 }
534}