use std::sync::{Arc, Mutex};
use leindex_embed::protocol::{
BatchId, EmbedRequest, EmbedResponse, ErrorKind, Frame, MsgType, Request, Response, WorkerError,
};
/// In-process stand-in for the embedding worker used by these tests.
///
/// It answers with error frames while its shared failure counter is non-zero.
/// After that it answers embed requests with all-zero vectors.
struct MockWorker {
    /// Width of every embedding returned on success.
    dimension: usize,
    /// Number of upcoming `process` calls that should produce an error frame.
    failures_remaining: Arc<Mutex<usize>>,
}
impl MockWorker {
    /// Builds a worker with a given failure budget and embedding width.
    ///
    /// The first `failures` requests get an `ErrorKind::Inference` error frame.
    /// Every later request gets a zero-filled embedding response.
    fn new(failures: usize, dimension: usize) -> Self {
        let failures_remaining = Arc::new(Mutex::new(failures));
        Self {
            dimension,
            failures_remaining,
        }
    }

    /// Handles one request frame and returns the reply frame, which reuses the
    /// request's batch id.
    ///
    /// While the failure budget is non-zero, one unit is consumed and an error
    /// frame is returned without decoding the payload. Otherwise the payload is
    /// decoded and an `EmbedResponse` of `count * dimension` zeros is returned.
    /// `count` is the number of texts for embed requests and 0 for any other
    /// request.
    ///
    /// # Panics
    /// Panics if the counter mutex is poisoned, the payload cannot be decoded,
    /// or a reply frame cannot be built.
    fn process(&self, frame: Frame) -> Frame {
        let batch_id = frame.header.batch_id;

        // Consume one unit of the failure budget, if any. `Some(n)` carries the
        // budget left after this failure.
        let failure_left = {
            let mut budget = self.failures_remaining.lock().unwrap();
            if *budget == 0 {
                None
            } else {
                *budget -= 1;
                Some(*budget)
            }
        };

        if let Some(remaining) = failure_left {
            let err = WorkerError {
                kind: ErrorKind::Inference,
                message: format!("simulated worker failure ({} remaining)", remaining),
            };
            return leindex_embed::protocol::error_frame(batch_id, err)
                .expect("error frame construction should not fail");
        }

        let request: Request = frame.decode_payload().expect("decode should work");
        let count = match request {
            Request::Embed(req) => req.texts.len(),
            _ => 0,
        };
        let dim = self.dimension;
        let zeros = EmbedResponse::new(vec![0.0f32; count * dim], count, dim);
        leindex_embed::protocol::embed_response_frame(batch_id, zeros)
            .expect("response frame construction should not fail")
    }
}
/// `EmbedResponse` keeps all embeddings in one flat row-major buffer.
/// It exposes per-row slices through `get_embedding` and a nested copy
/// through `into_vectors`.
#[test]
fn test_embed_response_is_flat_row_major_no_nested_vec() {
    let (rows, cols) = (2, 3);
    let response = EmbedResponse::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], rows, cols);

    // Flat storage: rows * cols scalars, with the shape recorded alongside.
    assert_eq!(response.vectors.len(), rows * cols);
    assert_eq!(response.count, rows);
    assert_eq!(response.dimension, cols);

    // Row views slice the flat buffer in order.
    let first_row = response.get_embedding(0).unwrap();
    let second_row = response.get_embedding(1).unwrap();
    assert_eq!(first_row, &[1.0, 2.0, 3.0]);
    assert_eq!(second_row, &[4.0, 5.0, 6.0]);

    // The nested conversion yields one Vec per row.
    let nested = response.into_vectors();
    assert_eq!(nested.len(), rows);
    assert_eq!(nested[0], vec![1.0, 2.0, 3.0]);
    assert_eq!(nested[1], vec![4.0, 5.0, 6.0]);
}
/// Copying each row view out of a flat `EmbedResponse` rebuilds the per-text
/// embeddings in their original order.
#[test]
fn test_flat_write_into_destination_buffer() {
    let (count, dim) = (3, 4);
    let response = EmbedResponse::new(
        vec![0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
        count,
        dim,
    );

    let destination: Vec<Vec<f32>> = (0..response.count)
        .map(|row| response.get_embedding(row).unwrap().to_vec())
        .collect();

    let expected: [Vec<f32>; 3] = [
        vec![0.1, 0.2, 0.3, 0.4],
        vec![0.5, 0.6, 0.7, 0.8],
        vec![0.9, 1.0, 1.1, 1.2],
    ];
    assert_eq!(destination.len(), 3);
    assert_eq!(destination, expected);
}
/// A worker with a one-failure budget rejects the first batch and serves the
/// next (retried) batch.
#[test]
fn test_worker_failure_triggers_retry() {
    let mock = MockWorker::new(1, 4);
    let frame_for = |id: u64| {
        let request = EmbedRequest {
            texts: vec!["test".to_string()],
            expected_dim: 4,
        };
        leindex_embed::protocol::embed_request_frame(BatchId::new(id), request)
            .expect("frame construction")
    };

    let first_reply = mock.process(frame_for(1));
    assert_eq!(first_reply.header.msg_type, MsgType::Error);

    let retry_reply = mock.process(frame_for(2));
    assert_eq!(retry_reply.header.msg_type, MsgType::EmbedResponse);
}
/// With a retry budget of one, a request is attempted at most twice.
///
/// A worker that fails once recovers on the retry. A worker that fails twice
/// exhausts the budget and leaves the last reply as an error.
#[test]
fn test_retry_once_semantics() {
    let max_retries = 1;

    // Sends the same request up to `1 + max_retries` times. Returns the number
    // of attempts made and the message type of the last reply.
    let run_with_retry = |worker: &MockWorker| {
        let mut attempts = 0;
        for attempt in 0..=max_retries {
            let request = EmbedRequest {
                texts: vec!["retry me".to_string()],
                expected_dim: 4,
            };
            let frame =
                leindex_embed::protocol::embed_request_frame(BatchId::new(attempt + 1), request)
                    .expect("frame construction");
            attempts += 1;
            let reply = worker.process(frame);
            if reply.header.msg_type == MsgType::EmbedResponse || attempt == max_retries {
                return (attempts, reply.header.msg_type);
            }
        }
        unreachable!("the final attempt always returns")
    };

    let (attempts, outcome) = run_with_retry(&MockWorker::new(1, 4));
    assert_eq!(attempts, 2, "should have retried once");
    assert_eq!(outcome, MsgType::EmbedResponse, "the retry should succeed");

    let (attempts, outcome) = run_with_retry(&MockWorker::new(2, 4));
    assert_eq!(attempts, 2, "retry budget must cap attempts at one retry");
    assert_eq!(outcome, MsgType::Error, "exhausted retries leave the error");
}
/// With a two-failure budget, the initial attempt and the single retry both
/// fail. That is presumably the point where the client degrades to TF-IDF.
/// The worker only answers successfully on the following batch.
#[test]
fn test_second_failure_triggers_tfidf_fallback() {
    let mock = MockWorker::new(2, 4);
    let request = EmbedRequest {
        texts: vec!["test text".to_string()],
        expected_dim: 4,
    };

    let expectations = [
        (1, MsgType::Error),
        (2, MsgType::Error),
        (3, MsgType::EmbedResponse),
    ];
    for (batch, expected) in expectations {
        let frame =
            leindex_embed::protocol::embed_request_frame(BatchId::new(batch), request.clone())
                .expect("frame construction");
        assert_eq!(mock.process(frame).header.msg_type, expected);
    }
}
/// Batches around a failed one keep their worker-produced responses. Only the
/// failed batch is represented by a fallback embedding.
///
/// NOTE(review): batch 2's failure is modelled only by a local zero vector
/// here; no worker is driven for it.
#[test]
fn test_fallback_only_affects_failed_batch() {
    let dim = 4;

    // Builds the request frame for `batch`, then a successful one-row response
    // frame filled with `fill`.
    let successful_batch = |batch: u64, fill: f32| {
        let request = EmbedRequest {
            texts: vec![format!("batch {} text", batch)],
            expected_dim: dim,
        };
        let _request_frame =
            leindex_embed::protocol::embed_request_frame(BatchId::new(batch), request)
                .expect("frame construction");
        let response = EmbedResponse::new(vec![fill; dim], 1, dim);
        leindex_embed::protocol::embed_response_frame(BatchId::new(batch), response).unwrap()
    };

    let before = successful_batch(1, 1.0);
    assert_eq!(before.header.msg_type, MsgType::EmbedResponse);

    let fallback_embedding = vec![0.0f32; dim];
    assert_eq!(fallback_embedding.len(), dim);

    let after = successful_batch(3, 3.0);
    assert_eq!(after.header.msg_type, MsgType::EmbedResponse);
}
/// The degradation warning must carry enough context to diagnose a fallback:
/// the batch, the error kind and message, and the fact that TF-IDF was used
/// after the retry.
#[test]
fn test_fallback_warning_contains_batch_context() {
    let worker_error = WorkerError {
        kind: ErrorKind::Inference,
        message: "ONNX inference failed: session crashed".to_string(),
    };
    let batch_id = BatchId::new(42);
    let warning = format!(
        "ONNX worker fallback for batch {}: {} (retry exhausted, degrading to TF-IDF)",
        batch_id, worker_error
    );

    let required = [
        ("batch-42", "warning must identify the affected batch"),
        ("Inference", "warning must name the error kind"),
        ("session crashed", "warning must include the worker error message"),
        ("TF-IDF", "warning must mention the fallback path"),
        ("retry exhausted", "warning must indicate retry was attempted"),
    ];
    for (needle, reason) in required {
        assert!(warning.contains(needle), "{}", reason);
    }
}
/// A warning built from the `Debug` form of each `ErrorKind` names the
/// variant.
#[test]
fn test_fallback_warning_includes_error_kind() {
    let cases = [
        (ErrorKind::OnnxRuntime, "OnnxRuntime"),
        (ErrorKind::ModelNotFound, "ModelNotFound"),
        (ErrorKind::Tokenizer, "Tokenizer"),
        (ErrorKind::Inference, "Inference"),
        (ErrorKind::InvalidRequest, "InvalidRequest"),
        (ErrorKind::Internal, "Internal"),
    ];

    for (kind, variant_name) in cases {
        let err = WorkerError {
            kind,
            message: "test error".to_string(),
        };
        let warning = format!("ONNX worker fallback: {:?}", err.kind);
        assert!(
            warning.contains(variant_name),
            "warning for {:?} should contain '{}'",
            err.kind,
            variant_name
        );
    }
}
/// A `WorkerError` seen by the client exposes both its kind (through `Debug`)
/// and its message. The client can therefore report it instead of crashing.
#[test]
fn test_client_handles_worker_error_gracefully() {
    let err = WorkerError {
        kind: ErrorKind::Inference,
        message: "simulated crash".to_string(),
    };
    let client_error_msg = format!("worker error: {:?}", err.kind);
    assert!(client_error_msg.contains("Inference"));
    assert!(err.message.contains("simulated crash"));
}
/// A worker failure must not stop request handling.
///
/// The failed request is served with a fallback vector of the right width.
/// The next request on the same worker is still answered with a decodable
/// embedding response.
#[test]
fn test_main_daemon_survives_worker_failure() {
    let dim = 4;
    let worker = MockWorker::new(1, dim);
    let frame_for = |id: u64, text: &str| {
        let request = EmbedRequest {
            texts: vec![text.to_string()],
            expected_dim: dim,
        };
        leindex_embed::protocol::embed_request_frame(BatchId::new(id), request)
            .expect("frame construction")
    };

    // The first request hits the worker's failure, so the caller substitutes a
    // fallback vector.
    let failed = worker.process(frame_for(1, "test text"));
    assert_eq!(failed.header.msg_type, MsgType::Error);
    let fallback_result: Vec<f32> = vec![0.0; dim];
    assert_eq!(fallback_result.len(), dim);

    // The next request is still served normally.
    let next = worker.process(frame_for(2, "another request"));
    assert_eq!(next.header.msg_type, MsgType::EmbedResponse);
    let response: Response = next.decode_payload().expect("decode should work");
    match response {
        Response::Embed(embed) => {
            assert_eq!(embed.count, 1);
            assert_eq!(embed.dimension, dim);
        }
        _ => panic!("expected Embed response after a worker failure"),
    }
}
/// A worker fails both the initial attempt and the retry, which triggers a
/// fallback. A freshly created worker then serves the next request normally.
#[test]
fn test_fresh_worker_after_fallback_episode() {
    let request_frame = |id: u64, text: &str| {
        let request = EmbedRequest {
            texts: vec![text.to_string()],
            expected_dim: 4,
        };
        leindex_embed::protocol::embed_request_frame(BatchId::new(id), request)
            .expect("frame construction")
    };

    // A worker whose failure budget outlasts the test: the attempt and the
    // retry both error.
    let broken = MockWorker::new(100, 4);
    for id in [1, 2] {
        let reply = broken.process(request_frame(id, "first request"));
        assert_eq!(reply.header.msg_type, MsgType::Error);
    }

    // A replacement worker with no failure budget.
    let fresh = MockWorker::new(0, 4);
    let reply = fresh.process(request_frame(3, "second request after recovery"));
    assert_eq!(
        reply.header.msg_type,
        MsgType::EmbedResponse,
        "fresh worker should succeed after fallback episode"
    );

    let decoded: Response = reply.decode_payload().expect("decode should work");
    match decoded {
        Response::Embed(embed) => {
            assert_eq!(embed.count, 1);
            assert_eq!(embed.dimension, 4);
        }
        _ => panic!("expected Embed response from fresh worker"),
    }
}
/// Fallback and recovery can repeat. In each cycle a broken worker errors,
/// then a replacement worker succeeds.
#[test]
fn test_multiple_fallback_recovery_cycles() {
    let frame_for = |id: u64, text: String| {
        let request = EmbedRequest {
            texts: vec![text],
            expected_dim: 4,
        };
        leindex_embed::protocol::embed_request_frame(BatchId::new(id), request)
            .expect("frame construction")
    };

    for cycle in 0u64..3 {
        let base_id = cycle * 10;

        let failing = MockWorker::new(100, 4);
        let reply = failing.process(frame_for(base_id, format!("cycle {} request", cycle)));
        assert_eq!(reply.header.msg_type, MsgType::Error);

        let healthy = MockWorker::new(0, 4);
        let reply = healthy.process(frame_for(base_id + 1, format!("cycle {} recovery", cycle)));
        assert_eq!(reply.header.msg_type, MsgType::EmbedResponse);
    }
}
/// Shape checks for fallback results; only compiled with the `onnx` feature.
#[cfg(feature = "onnx")]
mod client_fallback_tests {
    use leindex_embed::protocol::EmbedResponse;

    /// Builds a one-row response from `values` and checks it reports one row
    /// of width `dim`.
    fn assert_single_row(values: Vec<f32>, dim: usize) {
        let response = EmbedResponse::new(values, 1, dim);
        assert_eq!(response.count, 1);
        assert_eq!(response.dimension, dim);
    }

    #[test]
    fn test_fallback_result_success() {
        assert_single_row(vec![1.0, 2.0, 3.0, 4.0], 4);
    }

    #[test]
    fn test_fallback_result_after_retry() {
        assert_single_row(vec![5.0, 6.0, 7.0, 8.0], 4);
    }

    #[test]
    fn test_fallback_result_degraded() {
        let width = 4;
        let zeros = vec![0.0f32; width];
        assert_eq!(zeros.len(), width);
    }
}