use serde::{Deserialize, Serialize};
/// Protocol version stamped into every [`FramedMessage`] built with
/// [`FramedMessage::new`]. Presumably bumped on wire-incompatible changes — confirm.
pub const PROTOCOL_VERSION: u32 = 1;
/// Returns the default Unix-domain socket path: `/tmp/semantic-daemon-<user>.sock`.
///
/// `<user>` comes from the `USER` environment variable and is sanitized by
/// [`sanitize_socket_user`]. If `USER` is unset or not valid Unicode, the name
/// `unknown` is used instead.
pub fn default_socket_path() -> std::path::PathBuf {
    let user = std::env::var("USER").unwrap_or_else(|_| "unknown".into());
    std::path::PathBuf::from(format!(
        "/tmp/semantic-daemon-{}.sock",
        sanitize_socket_user(&user)
    ))
}

/// Reduces `user` to a file-name-safe path component of at most 64 bytes.
///
/// Only alphanumeric characters, `-`, and `_` are kept, so separators and dots
/// cannot escape `/tmp`. The cap is measured in UTF-8 bytes, not characters.
/// Unix socket paths must fit in `sockaddr_un.sun_path` (108 bytes on Linux,
/// 104 on macOS). 64 multi-byte characters could exceed that, whereas
/// 21 (prefix) + 64 + 5 (".sock") = 90 bytes always fits.
/// Returns `unknown` if no character survives.
///
/// NOTE(review): names differing only in filtered characters (e.g. `a.b` and
/// `ab`) map to the same socket path.
fn sanitize_socket_user(user: &str) -> String {
    const MAX_BYTES: usize = 64;
    let mut safe = String::with_capacity(user.len().min(MAX_BYTES));
    for c in user
        .chars()
        .filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_')
    {
        // Stop before a character that would overflow the byte budget, so a
        // multi-byte character is never split.
        if safe.len() + c.len_utf8() > MAX_BYTES {
            break;
        }
        safe.push(c);
    }
    if safe.is_empty() {
        "unknown".to_string()
    } else {
        safe
    }
}
/// A client request, sent as the `payload` of a [`FramedMessage`].
///
/// Wire-format note: [`encode_message`] uses `rmp_serde::to_vec`, which encodes
/// struct fields positionally (as arrays, not maps). Field order inside each
/// variant is therefore presumably part of the wire format. NOTE(review):
/// renaming or reordering variants may also change the encoding. Confirm
/// rmp_serde's enum representation before editing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Request {
/// Health probe; presumably answered with `Response::Health` — confirm in the server.
Health,
/// Embed each of `texts` with the embedding model named by `model`.
Embed {
texts: Vec<String>,
model: String,
/// Optional output dimension; its exact semantics (e.g. truncation) are not
/// visible in this file — TODO confirm.
dims: Option<usize>,
},
/// Score `documents` against `query` with the reranker named by `model`.
/// Presumably answered with one score per document — confirm.
Rerank {
query: String,
documents: Vec<String>,
model: String,
},
/// Request daemon-wide status; presumably answered with `Response::Status`.
Status,
/// Submit a background embedding job for the database at `db_path`.
/// NOTE(review): the meaning of `index_path`, `two_tier`, and the optional
/// model overrides is not visible in this file — confirm with the job runner.
SubmitEmbeddingJob {
db_path: String,
index_path: String,
two_tier: bool,
fast_model: Option<String>,
quality_model: Option<String>,
},
/// Query embedding-job status for `db_path`; presumably answered with
/// `Response::JobStatus`.
EmbeddingJobStatus { db_path: String },
/// Cancel embedding jobs for `db_path`. Presumably `model_id: None` cancels
/// jobs for every model — TODO confirm.
CancelEmbeddingJob {
db_path: String,
model_id: Option<String>,
},
/// Ask the daemon to shut down.
Shutdown,
}
/// A daemon response, sent as the `payload` of a [`FramedMessage`].
///
/// NOTE(review): which variant answers which [`Request`] is decided by the
/// server and is not visible here. The pairing implied by the names should be
/// confirmed. Variant and field order may be wire-significant (see
/// `rmp_serde::to_vec` in [`encode_message`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Response {
Health(HealthStatus),
Embed(EmbedResponse),
Rerank(RerankResponse),
Status(StatusResponse),
/// Acknowledges a submitted embedding job.
/// NOTE(review): `job_id` is a `String` here but `EmbeddingJobDetail::job_id`
/// is an `i64` — confirm how the two relate.
JobSubmitted { job_id: String, message: String },
JobStatus(EmbeddingJobInfo),
/// `cancelled` is presumably the number of jobs cancelled (per field name).
JobCancelled { cancelled: usize, message: String },
Shutdown { message: String },
/// Any failure; details are in [`ErrorResponse`].
Error(ErrorResponse),
}
/// Payload of `Response::Health`.
///
/// Field order may be wire-significant (structs are encoded positionally by
/// `rmp_serde::to_vec`); do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthStatus {
/// Daemon uptime in seconds (per field name).
pub uptime_secs: u64,
/// Protocol version reported by the daemon (the tests set it to [`PROTOCOL_VERSION`]).
pub version: u32,
/// Whether the daemon reports itself ready; the exact criteria are not visible here.
pub ready: bool,
/// Memory usage in bytes. What is measured (RSS, model memory, ...) is not
/// visible here — confirm.
pub memory_bytes: u64,
}
/// Payload of `Response::Embed`.
///
/// Field order may be wire-significant; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedResponse {
/// Embedding vectors; presumably one per input text, in input order — confirm.
pub embeddings: Vec<Vec<f32>>,
/// Model identifier reported by the server.
pub model: String,
/// Processing time in milliseconds (per field name).
pub elapsed_ms: u64,
}
/// Payload of `Response::Rerank`.
///
/// Field order may be wire-significant; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResponse {
/// Relevance scores; presumably one per document, in request order — confirm.
pub scores: Vec<f32>,
/// Reranker identifier reported by the server.
pub model: String,
/// Processing time in milliseconds (per field name).
pub elapsed_ms: u64,
}
/// Payload of `Response::Status`.
///
/// Field order may be wire-significant; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatusResponse {
/// Daemon uptime in seconds (per field name).
pub uptime_secs: u64,
/// Protocol version reported by the daemon.
pub version: u32,
/// Embedding models known to the daemon.
pub embedders: Vec<ModelInfo>,
/// Reranking models known to the daemon.
pub rerankers: Vec<ModelInfo>,
/// Memory usage in bytes; what is measured is not visible here — confirm.
pub memory_bytes: u64,
/// Presumably the number of requests served since start — confirm.
pub total_requests: u64,
}
/// Description of one embedder or reranker listed in [`StatusResponse`].
///
/// Field order may be wire-significant; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
/// Identifier; presumably the value clients pass as `model` in requests — confirm.
pub id: String,
/// Human-readable model name.
pub name: String,
/// Embedding dimension. `None` presumably when not applicable (e.g. rerankers)
/// or not yet known — confirm.
pub dimension: Option<usize>,
/// Whether the model is currently loaded (per field name).
pub loaded: bool,
/// Memory attributed to this model, in bytes (per field name).
pub memory_bytes: u64,
}
/// Payload of `Response::Error`.
///
/// Field order may be wire-significant; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
/// Machine-readable error category.
pub code: ErrorCode,
/// Human-readable description.
pub message: String,
/// Whether the client may retry the same request.
pub retryable: bool,
/// Suggested delay before retrying, in milliseconds (per field name); `None`
/// when no hint is given.
pub retry_after_ms: Option<u64>,
}
/// Machine-readable error category carried in [`ErrorResponse::code`].
///
/// NOTE(review): the server decides the precise conditions for each code, and
/// they are not visible in this file. The notes below are inferred from the
/// variant names. Reordering or renaming variants may change the wire encoding.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ErrorCode {
Internal,
ModelNotFound,
InvalidInput,
/// Presumably sent with `retryable: true` and a `retry_after_ms` hint, as in the tests.
Overloaded,
Timeout,
ModelLoadFailed,
/// Presumably reported when a frame's `version` differs from [`PROTOCOL_VERSION`] — confirm.
VersionMismatch,
}
/// Payload of `Response::JobStatus`: one entry per embedding job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingJobInfo {
pub jobs: Vec<EmbeddingJobDetail>,
}
/// State of a single embedding job, as listed in [`EmbeddingJobInfo`].
///
/// Field order may be wire-significant; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingJobDetail {
/// Numeric job identifier. NOTE(review): `Response::JobSubmitted` reports
/// `job_id` as a `String` — confirm how the two relate.
pub job_id: i64,
/// Model the job embeds with.
pub model_id: String,
/// Job state as a free-form string; the set of valid values is not defined in this file.
pub status: String,
/// Progress counters (per field names): documents in scope and documents done.
pub total_docs: i64,
pub completed_docs: i64,
/// Presumably set when the job failed — confirm.
pub error_message: Option<String>,
}
/// Envelope around every protocol message.
///
/// [`encode_message`] serializes the whole envelope (not just `payload`) with
/// MessagePack. `rmp_serde::to_vec` writes struct fields positionally, so keep
/// the field order below stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FramedMessage<T> {
/// Protocol version of the sender; [`FramedMessage::new`] sets it to [`PROTOCOL_VERSION`].
pub version: u32,
/// Caller-chosen identifier; presumably echoed back to correlate responses
/// with requests — confirm.
pub request_id: String,
/// The message body, e.g. a [`Request`] or [`Response`].
pub payload: T,
}
impl<T> FramedMessage<T> {
pub fn new(request_id: impl Into<String>, payload: T) -> Self {
Self {
version: PROTOCOL_VERSION,
request_id: request_id.into(),
payload,
}
}
}
/// Serializes `msg` as MessagePack and prepends a 4-byte big-endian length prefix.
///
/// Frame layout: `[len: u32 BE][payload: len bytes]`, where `len` counts only the
/// payload. [`decode_message`] expects just the payload, so strip the first 4 bytes
/// before decoding.
///
/// # Errors
///
/// Returns [`EncodeError`] if `rmp_serde` fails to serialize `msg`. It also
/// returns [`EncodeError`] if the payload is longer than `u32::MAX` bytes,
/// which the prefix cannot describe.
pub fn encode_message<T: Serialize>(msg: &FramedMessage<T>) -> Result<Vec<u8>, EncodeError> {
    const PREFIX_LEN: usize = 4;
    // Reserve room for the prefix and serialize directly after it. Large payloads
    // (e.g. embedding batches) are then not copied into a second buffer.
    let mut buf = vec![0u8; PREFIX_LEN];
    rmp_serde::encode::write(&mut buf, msg).map_err(|e| EncodeError(e.to_string()))?;
    let len = u32::try_from(buf.len() - PREFIX_LEN)
        .map_err(|_| EncodeError("payload exceeds maximum size of 4GB".to_string()))?;
    buf[..PREFIX_LEN].copy_from_slice(&len.to_be_bytes());
    Ok(buf)
}
/// Decodes a MessagePack-encoded [`FramedMessage`] from `data`.
///
/// `data` must be the frame body only, without the 4-byte big-endian length
/// prefix that [`encode_message`] prepends. The `version` field is returned
/// as-is and is not validated here.
///
/// # Errors
///
/// Returns [`DecodeError`] carrying the `rmp_serde` error text when `data` is
/// not a valid encoding of `FramedMessage<T>`.
pub fn decode_message<T: for<'de> Deserialize<'de>>(
    data: &[u8],
) -> Result<FramedMessage<T>, DecodeError> {
    let decoded: Result<FramedMessage<T>, _> = rmp_serde::from_slice(data);
    match decoded {
        Ok(message) => Ok(message),
        Err(err) => Err(DecodeError(err.to_string())),
    }
}
/// Error returned by [`encode_message`].
///
/// Returned when MessagePack serialization fails or the payload is too large
/// for the 4-byte length prefix. Displays as `encode error: <message>` and has
/// no `source()`.
#[derive(Debug, Clone, thiserror::Error)]
#[error("encode error: {0}")]
pub struct EncodeError(pub String);
/// Error returned by [`decode_message`] when the bytes are not a valid MessagePack
/// encoding of the requested `FramedMessage<T>`.
///
/// Displays as `decode error: <message>` and has no `source()`.
#[derive(Debug, Clone, thiserror::Error)]
#[error("decode error: {0}")]
pub struct DecodeError(pub String);
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `msg` and checks that the 4-byte big-endian length prefix equals
    /// the length of the body that follows. Then decodes the body back into a frame.
    fn roundtrip<T, U>(msg: &FramedMessage<T>) -> FramedMessage<U>
    where
        T: Serialize,
        U: for<'de> Deserialize<'de>,
    {
        let encoded = encode_message(msg).expect("encoding should succeed");
        assert!(encoded.len() >= 4, "frame is shorter than its length prefix");
        let (prefix, body) = encoded.split_at(4);
        let declared = u32::from_be_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]);
        assert_eq!(
            usize::try_from(declared).expect("u32 fits in usize"),
            body.len(),
            "length prefix must equal the payload length"
        );
        decode_message(body).expect("decoding should succeed")
    }

    #[test]
    fn test_encode_decode_health_request() {
        let msg = FramedMessage::new("req-1", Request::Health);
        let decoded: FramedMessage<Request> = roundtrip(&msg);
        assert_eq!(decoded.version, PROTOCOL_VERSION);
        assert_eq!(decoded.request_id, "req-1");
        assert!(matches!(decoded.payload, Request::Health));
    }

    #[test]
    fn test_protocol_error_display_strings_are_preserved() {
        let encode = EncodeError("bad payload".to_string());
        let decode = DecodeError("bad frame".to_string());
        let cases: &[(&str, &dyn std::error::Error, &str)] = &[
            ("encode", &encode, "encode error: bad payload"),
            ("decode", &decode, "decode error: bad frame"),
        ];
        for (label, error, expected_display) in cases {
            assert_eq!(error.to_string(), *expected_display, "{label}");
            assert!(error.source().is_none(), "{label}");
        }
    }

    #[test]
    fn test_encode_decode_embed_request() {
        let msg = FramedMessage::new(
            "req-2",
            Request::Embed {
                texts: vec!["hello".to_string(), "world".to_string()],
                model: "all-MiniLM-L6-v2".to_string(),
                dims: None,
            },
        );
        let decoded: FramedMessage<Request> = roundtrip(&msg);
        match decoded.payload {
            Request::Embed { texts, model, dims } => {
                assert_eq!(texts, vec!["hello", "world"]);
                assert_eq!(model, "all-MiniLM-L6-v2");
                assert_eq!(dims, None);
            }
            other => panic!("expected Request::Embed, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_rerank_request() {
        let msg = FramedMessage::new(
            "req-3",
            Request::Rerank {
                query: "test query".to_string(),
                documents: vec!["doc1".to_string(), "doc2".to_string()],
                model: "ms-marco-MiniLM-L-6-v2".to_string(),
            },
        );
        let decoded: FramedMessage<Request> = roundtrip(&msg);
        match decoded.payload {
            Request::Rerank {
                query,
                documents,
                model,
            } => {
                assert_eq!(query, "test query");
                assert_eq!(documents, vec!["doc1", "doc2"]);
                assert_eq!(model, "ms-marco-MiniLM-L-6-v2");
            }
            other => panic!("expected Request::Rerank, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_health_response() {
        let msg = FramedMessage::new(
            "resp-1",
            Response::Health(HealthStatus {
                uptime_secs: 120,
                version: PROTOCOL_VERSION,
                ready: true,
                memory_bytes: 100_000_000,
            }),
        );
        let decoded: FramedMessage<Response> = roundtrip(&msg);
        match decoded.payload {
            Response::Health(status) => {
                assert_eq!(status.uptime_secs, 120);
                assert_eq!(status.version, PROTOCOL_VERSION);
                assert!(status.ready);
                assert_eq!(status.memory_bytes, 100_000_000);
            }
            other => panic!("expected Response::Health, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_error_response() {
        let msg = FramedMessage::new(
            "resp-err",
            Response::Error(ErrorResponse {
                code: ErrorCode::Overloaded,
                message: "too many requests".to_string(),
                retryable: true,
                retry_after_ms: Some(1000),
            }),
        );
        let decoded: FramedMessage<Response> = roundtrip(&msg);
        match decoded.payload {
            Response::Error(err) => {
                assert_eq!(err.code, ErrorCode::Overloaded);
                assert_eq!(err.message, "too many requests");
                assert!(err.retryable);
                assert_eq!(err.retry_after_ms, Some(1000));
            }
            other => panic!("expected Response::Error, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_decode_embedding_job_messages() {
        let submit = FramedMessage::new(
            "req-job",
            Request::SubmitEmbeddingJob {
                db_path: "/data/index.db".to_string(),
                index_path: "/data/index.vec".to_string(),
                two_tier: true,
                fast_model: Some("fast".to_string()),
                quality_model: None,
            },
        );
        let decoded: FramedMessage<Request> = roundtrip(&submit);
        match decoded.payload {
            Request::SubmitEmbeddingJob {
                db_path,
                index_path,
                two_tier,
                fast_model,
                quality_model,
            } => {
                assert_eq!(db_path, "/data/index.db");
                assert_eq!(index_path, "/data/index.vec");
                assert!(two_tier);
                assert_eq!(fast_model.as_deref(), Some("fast"));
                assert_eq!(quality_model, None);
            }
            other => panic!("expected Request::SubmitEmbeddingJob, got {other:?}"),
        }

        let status = FramedMessage::new(
            "resp-job",
            Response::JobStatus(EmbeddingJobInfo {
                jobs: vec![EmbeddingJobDetail {
                    job_id: 7,
                    model_id: "fast".to_string(),
                    status: "running".to_string(),
                    total_docs: 100,
                    completed_docs: 42,
                    error_message: None,
                }],
            }),
        );
        let decoded: FramedMessage<Response> = roundtrip(&status);
        match decoded.payload {
            Response::JobStatus(info) => {
                assert_eq!(info.jobs.len(), 1);
                let job = &info.jobs[0];
                assert_eq!(job.job_id, 7);
                assert_eq!(job.model_id, "fast");
                assert_eq!(job.status, "running");
                assert_eq!(job.total_docs, 100);
                assert_eq!(job.completed_docs, 42);
                assert_eq!(job.error_message, None);
            }
            other => panic!("expected Response::JobStatus, got {other:?}"),
        }
    }

    #[test]
    fn test_decode_rejects_malformed_input() {
        assert!(decode_message::<Request>(&[]).is_err());
        // 0xc1 is the one marker byte MessagePack reserves as "never used".
        assert!(decode_message::<Request>(&[0xc1]).is_err());
    }

    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        let path_str = path.to_string_lossy();
        assert!(path_str.starts_with("/tmp/semantic-daemon-"));
        assert!(path_str.ends_with(".sock"));
    }

    #[test]
    fn test_wire_compatibility_embed_response() {
        let msg = FramedMessage::new(
            "resp-embed",
            Response::Embed(EmbedResponse {
                embeddings: vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
                model: "minilm-384".to_string(),
                elapsed_ms: 15,
            }),
        );
        let decoded: FramedMessage<Response> = roundtrip(&msg);
        match decoded.payload {
            Response::Embed(resp) => {
                assert_eq!(resp.embeddings.len(), 2);
                assert_eq!(resp.embeddings[0], vec![0.1, 0.2, 0.3]);
                assert_eq!(resp.model, "minilm-384");
                assert_eq!(resp.elapsed_ms, 15);
            }
            other => panic!("expected Response::Embed, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_compatibility_rerank_response() {
        let msg = FramedMessage::new(
            "resp-rerank",
            Response::Rerank(RerankResponse {
                scores: vec![0.95, 0.72, 0.31],
                model: "ms-marco".to_string(),
                elapsed_ms: 8,
            }),
        );
        let decoded: FramedMessage<Response> = roundtrip(&msg);
        match decoded.payload {
            Response::Rerank(resp) => {
                assert_eq!(resp.scores, vec![0.95, 0.72, 0.31]);
                assert_eq!(resp.model, "ms-marco");
                assert_eq!(resp.elapsed_ms, 8);
            }
            other => panic!("expected Response::Rerank, got {other:?}"),
        }
    }
}