use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Protocol version that [`FramedMessage::new`] stamps into every frame it builds.
pub const PROTOCOL_VERSION: u32 = 1;
/// Returns the default daemon socket path, `/tmp/semantic-daemon-<user>.sock`.
///
/// `<user>` is the `USER` value looked up through `dotenvy::var` and passed through
/// [`sanitize_socket_user`]. When the lookup fails for any reason, the literal
/// `"unknown"` is used instead.
pub fn default_socket_path() -> PathBuf {
    let user = match dotenvy::var("USER") {
        Ok(name) => name,
        Err(_) => String::from("unknown"),
    };
    // Only the sanitized form is ever interpolated into the path.
    let safe_user = sanitize_socket_user(&user);
    let socket_file = format!("/tmp/semantic-daemon-{safe_user}.sock");
    PathBuf::from(socket_file)
}
/// Reduces a user name to characters that are safe to embed in the socket file name.
///
/// Keeps only alphanumeric characters, `-` and `_` (in their original order) and caps
/// the result at 64 **bytes** of UTF-8. Returns `"unknown"` when nothing survives the
/// filter.
///
/// The cap is measured in bytes rather than characters because `is_alphanumeric`
/// accepts multi-byte Unicode letters and digits: 64 such characters could occupy up to
/// 256 bytes, which would push the resulting socket path past the length Unix domain
/// sockets accept (`sun_path` is typically around 104-108 bytes). For ASCII input the
/// result is identical to a 64-character cap.
fn sanitize_socket_user(user: &str) -> String {
    const MAX_USER_BYTES: usize = 64;

    let mut safe_user = String::with_capacity(user.len().min(MAX_USER_BYTES));
    for c in user.chars() {
        if !(c.is_alphanumeric() || c == '-' || c == '_') {
            continue;
        }
        // Stop at the first character that would exceed the byte budget so the result
        // is always a prefix of the filtered name and never splits a character.
        if safe_user.len() + c.len_utf8() > MAX_USER_BYTES {
            break;
        }
        safe_user.push(c);
    }

    if safe_user.is_empty() {
        "unknown".to_string()
    } else {
        safe_user
    }
}
/// Request payloads of the daemon protocol, carried inside a [`FramedMessage`].
///
/// The code that handles these requests is not part of this file, so the per-variant
/// notes below are inferred from names and types and should be confirmed against the
/// handler. Variant and field names/order feed the derived serde representation and may
/// be part of the wire format (depends on the serializer configuration; confirm before
/// renaming or reordering).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Request {
/// Health probe; presumably answered with [`Response::Health`].
Health,
/// Ask for embeddings of `texts`; presumably answered with [`Response::Embed`].
Embed {
/// Input strings to embed.
texts: Vec<String>,
/// Identifier of the embedding model to use.
model: String,
/// Optional dimension count; how it is applied is not visible here (TODO confirm).
dims: Option<usize>,
},
/// Ask for relevance scores of `documents` against `query`; presumably answered with
/// [`Response::Rerank`].
Rerank {
/// Query text the documents are scored against.
query: String,
/// Candidate documents to score.
documents: Vec<String>,
/// Identifier of the reranking model to use.
model: String,
},
/// Ask for daemon status; presumably answered with [`Response::Status`].
Status,
/// Submit an embedding job; presumably answered with [`Response::JobSubmitted`].
SubmitEmbeddingJob {
/// Path of the database the job refers to (exact role not visible here).
db_path: String,
/// Path of the index the job refers to (exact role not visible here).
index_path: String,
/// Flag selecting a two-tier mode; presumably related to `fast_model` and
/// `quality_model` (TODO confirm).
two_tier: bool,
/// Optional model identifier for the "fast" tier.
fast_model: Option<String>,
/// Optional model identifier for the "quality" tier.
quality_model: Option<String>,
},
/// Query embedding jobs for `db_path`; presumably answered with [`Response::JobStatus`].
EmbeddingJobStatus { db_path: String },
/// Cancel embedding jobs; presumably answered with [`Response::JobCancelled`].
CancelEmbeddingJob {
/// Path of the database whose jobs should be cancelled.
db_path: String,
/// Optional model filter; the meaning of `None` is not visible here (TODO confirm).
model_id: Option<String>,
},
/// Ask the daemon to shut down; presumably answered with [`Response::Shutdown`].
Shutdown,
}
/// Response payloads of the daemon protocol, carried inside a [`FramedMessage`].
///
/// Which request produces which variant is decided by code outside this file; the
/// pairings suggested by the variant names are assumptions to confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Response {
/// Health report.
Health(HealthStatus),
/// Embedding vectors plus the model and timing that produced them.
Embed(EmbedResponse),
/// Rerank scores plus the model and timing that produced them.
Rerank(RerankResponse),
/// Daemon status report.
Status(StatusResponse),
/// An embedding job was submitted; carries a job identifier and a free-form message.
JobSubmitted { job_id: String, message: String },
/// Details of embedding jobs.
JobStatus(EmbeddingJobInfo),
/// Outcome of a cancel request: a `cancelled` count and a free-form message.
JobCancelled { cancelled: usize, message: String },
/// Shutdown notice with a free-form message.
Shutdown { message: String },
/// The request failed; see [`ErrorResponse`].
Error(ErrorResponse),
}
/// Payload of [`Response::Health`].
///
/// Field meanings are inferred from names; the code that fills them is not in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthStatus {
/// Uptime in seconds (presumably of the daemon process).
pub uptime_secs: u64,
/// Version number; this file's tests populate it with [`PROTOCOL_VERSION`].
pub version: u32,
/// Readiness flag; the exact readiness criteria are not visible here.
pub ready: bool,
/// Memory usage in bytes (how it is measured is not visible here).
pub memory_bytes: u64,
}
/// Payload of [`Response::Embed`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedResponse {
/// One `f32` vector per embedding; presumably in the same order as the request's
/// `texts` (TODO confirm in the handler).
pub embeddings: Vec<Vec<f32>>,
/// Identifier of the model that produced the embeddings.
pub model: String,
/// Elapsed time in milliseconds (what interval is measured is not visible here).
pub elapsed_ms: u64,
}
/// Payload of [`Response::Rerank`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResponse {
/// Relevance scores; presumably one per request document, in request order
/// (TODO confirm in the handler).
pub scores: Vec<f32>,
/// Identifier of the model that produced the scores.
pub model: String,
/// Elapsed time in milliseconds (what interval is measured is not visible here).
pub elapsed_ms: u64,
}
/// Payload of [`Response::Status`].
///
/// Field meanings are inferred from names; the code that fills them is not in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatusResponse {
/// Uptime in seconds (presumably of the daemon process).
pub uptime_secs: u64,
/// Version number (presumably [`PROTOCOL_VERSION`]; confirm at the construction site).
pub version: u32,
/// Embedding models known to the daemon.
pub embedders: Vec<ModelInfo>,
/// Reranking models known to the daemon.
pub rerankers: Vec<ModelInfo>,
/// Memory usage in bytes (how it is measured is not visible here).
pub memory_bytes: u64,
/// Request counter (what counts as a request is not visible here).
pub total_requests: u64,
}
/// Description of one model, as listed in [`StatusResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
/// Model identifier.
pub id: String,
/// Model name (how it differs from `id` is not visible here).
pub name: String,
/// Optional dimension; presumably the embedding size, absent where not applicable
/// (TODO confirm).
pub dimension: Option<usize>,
/// Whether the model is loaded.
pub loaded: bool,
/// Memory attributed to this model, in bytes.
pub memory_bytes: u64,
}
/// Payload of [`Response::Error`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
/// Machine-readable error category.
pub code: ErrorCode,
/// Human-readable description of the failure.
pub message: String,
/// Whether the request may be retried (retry policy itself is not defined here).
pub retryable: bool,
/// Optional delay in milliseconds; presumably a hint for when to retry (TODO confirm).
pub retry_after_ms: Option<u64>,
}
/// Error categories carried by [`ErrorResponse::code`].
///
/// The code that emits these values is not part of this file; the variant notes are
/// inferred from the names and should be confirmed at the emit sites.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ErrorCode {
/// Unclassified internal failure.
Internal,
/// The requested model is not known.
ModelNotFound,
/// The request content was rejected.
InvalidInput,
/// Too much load to accept the request.
Overloaded,
/// The operation timed out.
Timeout,
/// A model could not be loaded.
ModelLoadFailed,
/// Version disagreement between peers (presumably on [`PROTOCOL_VERSION`]).
VersionMismatch,
}
/// Payload of [`Response::JobStatus`]: a list of embedding job records.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingJobInfo {
/// One entry per reported job.
pub jobs: Vec<EmbeddingJobDetail>,
}
/// One embedding job record inside [`EmbeddingJobInfo`].
///
/// Field meanings are inferred from names; the code that fills them is not in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingJobDetail {
/// Numeric job identifier.
/// NOTE(review): [`Response::JobSubmitted`] carries `job_id` as a `String` while this
/// field is an `i64`; confirm how the two relate.
pub job_id: i64,
/// Identifier of the model the job uses.
pub model_id: String,
/// Job state as free-form text; the set of possible values is not defined here.
pub status: String,
/// Total number of documents in the job.
pub total_docs: i64,
/// Number of documents completed so far.
pub completed_docs: i64,
/// Optional error text (presumably set when the job failed).
pub error_message: Option<String>,
}
/// Envelope that wraps a payload for transport; [`encode_message`] serializes it and
/// [`decode_message`] parses it back.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FramedMessage<T> {
/// Protocol version; [`FramedMessage::new`] sets it to [`PROTOCOL_VERSION`].
pub version: u32,
/// Caller-chosen identifier for the message (presumably used to match responses to
/// requests; confirm in the transport code).
pub request_id: String,
/// The wrapped value, e.g. a [`Request`] or [`Response`].
pub payload: T,
}
impl<T> FramedMessage<T> {
pub fn new(request_id: impl Into<String>, payload: T) -> Self {
Self {
version: PROTOCOL_VERSION,
request_id: request_id.into(),
payload,
}
}
}
/// Serializes `msg` with `rmp_serde` and prepends a 4-byte big-endian length prefix.
///
/// The returned buffer is `[len: u32 BE][serialized message]`, where `len` is the byte
/// length of the serialized message only (the prefix itself is not counted).
///
/// # Errors
///
/// * [`EncodeError::MessagePack`] if `rmp_serde` fails to serialize `msg`.
/// * [`EncodeError::Message`] if the serialized message is too large for its length to
///   fit in a `u32`.
pub fn encode_message<T: Serialize>(msg: &FramedMessage<T>) -> Result<Vec<u8>, EncodeError> {
    let body = rmp_serde::to_vec(msg)?;
    let Ok(body_len) = u32::try_from(body.len()) else {
        return Err(EncodeError::Message(
            "payload exceeds maximum size of 4GB".to_string(),
        ));
    };

    let mut frame = Vec::with_capacity(4 + body.len());
    frame.extend_from_slice(&body_len.to_be_bytes());
    frame.extend_from_slice(&body);
    Ok(frame)
}
/// Parses a [`FramedMessage`] from its `rmp_serde` serialization.
///
/// `data` must contain the serialized message only: this function neither reads nor
/// strips the 4-byte length prefix that [`encode_message`] prepends, so callers have to
/// remove it first. The frame's `version` field is not checked here.
///
/// # Errors
///
/// Returns [`DecodeError::MessagePack`] if `rmp_serde` cannot deserialize `data`.
pub fn decode_message<T: for<'de> Deserialize<'de>>(
    data: &[u8],
) -> Result<FramedMessage<T>, DecodeError> {
    // `?` converts the rmp_serde error through the `#[from]` impl on `DecodeError`.
    let frame: FramedMessage<T> = rmp_serde::from_slice(data)?;
    Ok(frame)
}
/// Failure modes of [`encode_message`].
///
/// Both variants display as `encode error: <detail>`.
#[derive(Debug, thiserror::Error)]
pub enum EncodeError {
/// Free-form failure description; [`encode_message`] uses it when the serialized
/// message length does not fit in the `u32` length prefix.
#[error("encode error: {0}")]
Message(String),
/// Error reported by the `rmp_serde` encoder, converted automatically via `#[from]`.
#[error("encode error: {0}")]
MessagePack(#[from] rmp_serde::encode::Error),
}
/// Failure modes of [`decode_message`].
///
/// Both variants display as `decode error: <detail>`.
#[derive(Debug, thiserror::Error)]
pub enum DecodeError {
/// Free-form failure description. [`decode_message`] does not construct it; in this
/// file only the tests do (presumably other callers use it for framing errors).
#[error("decode error: {0}")]
Message(String),
/// Error reported by the `rmp_serde` decoder, converted automatically via `#[from]`.
#[error("decode error: {0}")]
MessagePack(#[from] rmp_serde::decode::Error),
}
#[cfg(test)]
mod tests {
    use super::{
        DecodeError, EmbedResponse, EncodeError, ErrorCode, ErrorResponse, FramedMessage,
        HealthStatus, PROTOCOL_VERSION, Request, RerankResponse, Response, decode_message,
        default_socket_path, encode_message, sanitize_socket_user,
    };
    use serde::de::DeserializeOwned;
    use std::error::Error;
    use std::fmt::Debug;

    /// Tests report failures as `Err` values so checks can be chained with `?`.
    type TestResult = Result<(), Box<dyn Error>>;

    /// Wraps a message in a boxed error suitable for [`TestResult`].
    fn test_error(message: impl Into<String>) -> Box<dyn Error> {
        std::io::Error::other(message.into()).into()
    }

    /// Returns `Ok(())` when `condition` holds, otherwise an error carrying `message`.
    fn ensure(condition: bool, message: impl Into<String>) -> TestResult {
        if condition {
            Ok(())
        } else {
            Err(test_error(message))
        }
    }

    /// Returns `Ok(())` when `actual == expected`, otherwise an error that shows both
    /// values prefixed with `message`.
    fn ensure_eq<T>(actual: T, expected: T, message: impl Into<String>) -> TestResult
    where
        T: Debug + PartialEq,
    {
        if actual == expected {
            Ok(())
        } else {
            Err(test_error(format!(
                "{}: expected {expected:?}, got {actual:?}",
                message.into()
            )))
        }
    }

    /// Decodes a buffer produced by `encode_message`.
    ///
    /// Besides stripping the 4-byte prefix, this checks that the prefix is the
    /// big-endian length of the remaining bytes, so every round-trip test also covers
    /// the framing written by `encode_message`.
    fn decode_framed<T>(encoded: &[u8]) -> Result<FramedMessage<T>, Box<dyn Error>>
    where
        T: DeserializeOwned,
    {
        let prefix = encoded
            .get(..4)
            .and_then(|bytes| <[u8; 4]>::try_from(bytes).ok())
            .ok_or_else(|| test_error("encoded frame should include a 4-byte length prefix"))?;
        let payload = encoded
            .get(4..)
            .ok_or_else(|| test_error("encoded frame should include a 4-byte length prefix"))?;
        let declared_len = usize::try_from(u32::from_be_bytes(prefix))
            .map_err(|err| test_error(err.to_string()))?;
        ensure_eq(
            declared_len,
            payload.len(),
            "length prefix should equal payload size",
        )?;
        decode_message(payload).map_err(|err| test_error(err.to_string()))
    }

    #[test]
    fn test_encode_decode_health_request() -> TestResult {
        let msg = FramedMessage::new("req-1", Request::Health);
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Request> = decode_framed(&encoded)?;
        ensure_eq(decoded.version, PROTOCOL_VERSION, "protocol version")?;
        ensure_eq(decoded.request_id, "req-1".to_string(), "request id")?;
        ensure(matches!(decoded.payload, Request::Health), "health payload")
    }

    // Pins the Display output of the string-carrying error variants and checks that
    // they expose no underlying source error.
    #[test]
    fn test_protocol_error_display_strings_are_preserved() -> TestResult {
        let encode = EncodeError::Message("bad payload".to_string());
        ensure_eq(
            encode.to_string(),
            "encode error: bad payload".to_string(),
            "encode",
        )?;
        ensure(encode.source().is_none(), "encode")?;
        let decode = DecodeError::Message("bad frame".to_string());
        ensure_eq(
            decode.to_string(),
            "decode error: bad frame".to_string(),
            "decode",
        )?;
        ensure(decode.source().is_none(), "decode")?;
        Ok(())
    }

    #[test]
    fn test_encode_decode_embed_request() -> TestResult {
        let msg = FramedMessage::new(
            "req-2",
            Request::Embed {
                texts: vec!["hello".to_string(), "world".to_string()],
                model: "all-MiniLM-L6-v2".to_string(),
                dims: None,
            },
        );
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Request> = decode_framed(&encoded)?;
        let Request::Embed { texts, model, dims } = decoded.payload else {
            return Err(test_error("expected embed request payload"));
        };
        ensure_eq(
            texts,
            vec!["hello".to_string(), "world".to_string()],
            "embed texts",
        )?;
        ensure_eq(model, "all-MiniLM-L6-v2".to_string(), "embed model")?;
        ensure(dims.is_none(), "embed dims should be absent")
    }

    #[test]
    fn test_encode_decode_rerank_request() -> TestResult {
        let msg = FramedMessage::new(
            "req-3",
            Request::Rerank {
                query: "test query".to_string(),
                documents: vec!["doc1".to_string(), "doc2".to_string()],
                model: "ms-marco-MiniLM-L-6-v2".to_string(),
            },
        );
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Request> = decode_framed(&encoded)?;
        let Request::Rerank {
            query,
            documents,
            model,
        } = decoded.payload
        else {
            return Err(test_error("expected rerank request payload"));
        };
        ensure_eq(query, "test query".to_string(), "rerank query")?;
        ensure_eq(
            documents,
            vec!["doc1".to_string(), "doc2".to_string()],
            "rerank documents",
        )?;
        ensure_eq(model, "ms-marco-MiniLM-L-6-v2".to_string(), "rerank model")
    }

    #[test]
    fn test_encode_decode_health_response() -> TestResult {
        let msg = FramedMessage::new(
            "resp-1",
            Response::Health(HealthStatus {
                uptime_secs: 120,
                version: PROTOCOL_VERSION,
                ready: true,
                memory_bytes: 100_000_000,
            }),
        );
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Response> = decode_framed(&encoded)?;
        let Response::Health(status) = decoded.payload else {
            return Err(test_error("expected health response payload"));
        };
        ensure_eq(status.uptime_secs, 120, "health uptime")?;
        ensure(status.ready, "health response should be ready")
    }

    #[test]
    fn test_encode_decode_error_response() -> TestResult {
        let msg = FramedMessage::new(
            "resp-err",
            Response::Error(ErrorResponse {
                code: ErrorCode::Overloaded,
                message: "too many requests".to_string(),
                retryable: true,
                retry_after_ms: Some(1000),
            }),
        );
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Response> = decode_framed(&encoded)?;
        let Response::Error(err) = decoded.payload else {
            return Err(test_error("expected error response payload"));
        };
        ensure_eq(err.code, ErrorCode::Overloaded, "error code")?;
        ensure(err.retryable, "error should be retryable")?;
        ensure_eq(err.retry_after_ms, Some(1000), "retry delay")
    }

    // Only the fixed prefix and suffix are checked because the middle part depends on
    // the environment the test runs in.
    #[test]
    fn test_default_socket_path() -> TestResult {
        let path = default_socket_path();
        let path_str = path.to_string_lossy();
        ensure(
            path_str.starts_with("/tmp/semantic-daemon-"),
            "socket path prefix",
        )?;
        ensure(path_str.ends_with(".sock"), "socket path suffix")
    }

    #[test]
    fn test_socket_user_sanitization() -> TestResult {
        ensure_eq(
            sanitize_socket_user("../bad user!"),
            "baduser".to_string(),
            "path traversal and punctuation should be removed",
        )?;
        ensure_eq(
            sanitize_socket_user(""),
            "unknown".to_string(),
            "empty user fallback",
        )?;
        ensure_eq(
            sanitize_socket_user("a".repeat(80).as_str()).len(),
            64,
            "socket user length cap",
        )
    }

    #[test]
    fn test_wire_compatibility_embed_response() -> TestResult {
        let msg = FramedMessage::new(
            "resp-embed",
            Response::Embed(EmbedResponse {
                embeddings: vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
                model: "minilm-384".to_string(),
                elapsed_ms: 15,
            }),
        );
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Response> = decode_framed(&encoded)?;
        let Response::Embed(resp) = decoded.payload else {
            return Err(test_error("expected embed response payload"));
        };
        ensure_eq(resp.embeddings.len(), 2, "embedding count")?;
        let first = resp
            .embeddings
            .first()
            .ok_or_else(|| test_error("first embedding should exist"))?;
        ensure_eq(first.clone(), vec![0.1, 0.2, 0.3], "first embedding")?;
        ensure_eq(resp.model, "minilm-384".to_string(), "embedding model")
    }

    #[test]
    fn test_wire_compatibility_rerank_response() -> TestResult {
        let msg = FramedMessage::new(
            "resp-rerank",
            Response::Rerank(RerankResponse {
                scores: vec![0.95, 0.72, 0.31],
                model: "ms-marco".to_string(),
                elapsed_ms: 8,
            }),
        );
        let encoded = encode_message(&msg)?;
        let decoded: FramedMessage<Response> = decode_framed(&encoded)?;
        let Response::Rerank(resp) = decoded.payload else {
            return Err(test_error("expected rerank response payload"));
        };
        ensure_eq(resp.scores, vec![0.95, 0.72, 0.31], "rerank scores")
    }
}