use super::*;
/// The serialised JSON-RPC request must carry the protocol version, the
/// `embed` method, the caller's texts under `params`, and the request id.
///
/// The output is parsed back into a `serde_json::Value` and checked field by
/// field. A substring check such as `"id":1` would also accept `"id":10`, and
/// substring checks cannot tell where `texts` sits in the document.
#[test]
fn request_serialises_correctly() {
    let texts = vec!["hello".to_string(), "world".to_string()];
    let req = RpcRequest {
        jsonrpc: JSONRPC_VERSION,
        method: METHOD_EMBED,
        params: EmbedParams { texts: &texts },
        id: 1,
    };
    let s = serde_json::to_string(&req).unwrap();
    let v: serde_json::Value =
        serde_json::from_str(&s).expect("serialised request must be valid JSON");
    assert_eq!(v["jsonrpc"], "2.0", "must have jsonrpc 2.0");
    assert_eq!(v["method"], "embed", "must have embed method");
    assert_eq!(
        v["params"]["texts"],
        serde_json::json!(["hello", "world"]),
        "must include texts"
    );
    assert_eq!(v["id"], serde_json::json!(1), "must have id");
}
/// A JSON-RPC `error` object (code -32603, message "ort failed") must be
/// surfaced by `decode_response` as `EmbedderError::ModelError`.
#[test]
fn error_response_maps_to_model_error() {
    let payload = r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"ort failed"},"id":1}"#;
    let outcome = decode_response(payload, 1);
    let is_model_error = matches!(&outcome, Err(EmbedderError::ModelError(_)));
    assert!(is_model_error, "got: {outcome:?}");
}
/// A well-formed `result` payload decodes into one vector per input text, in
/// wire order, carrying the values from the wire.
///
/// Every decoded value is checked, not just the first, so reordered, swapped
/// or truncated vectors are caught. The values pass through JSON number
/// parsing, so they are compared with a small tolerance rather than `==`.
#[test]
fn success_response_decoded() {
    let json = r#"{"jsonrpc":"2.0","result":{"embeddings":[[0.1,0.2],[0.3,0.4]]},"id":1}"#;
    let result = decode_response(json, 2).unwrap();
    let expected: [[f32; 2]; 2] = [[0.1, 0.2], [0.3, 0.4]];
    assert_eq!(result.len(), expected.len());
    for (i, row) in expected.iter().enumerate() {
        for (j, &want) in row.iter().enumerate() {
            let got = result[i][j];
            assert!(
                (got - want).abs() < 1e-6,
                "embedding [{i}][{j}]: expected {want}, got {got}"
            );
        }
    }
}
/// When the server returns fewer vectors (2) than texts were sent (3),
/// decoding must fail with `DimensionMismatch` carrying both counts.
#[test]
fn count_mismatch_returns_dimension_error() {
    let payload = r#"{"jsonrpc":"2.0","result":{"embeddings":[[0.1],[0.2]]},"id":1}"#;
    let outcome = decode_response(payload, 3);
    let reports_both_counts = matches!(
        &outcome,
        Err(EmbedderError::DimensionMismatch { sent: 3, got: 2 })
    );
    assert!(reports_both_counts, "got: {outcome:?}");
}
/// A numeric `id` field is extracted as `Some(id)`.
#[test]
fn extract_response_id_numeric() {
    let payload = r#"{"jsonrpc":"2.0","result":{"embeddings":[]},"id":42}"#;
    let id = extract_response_id(payload);
    assert_eq!(id, Some(42));
}
/// A `null` id (as in this -32700 parse-error response) yields `None`.
#[test]
fn extract_response_id_null_returns_none() {
    let payload = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"parse error"},"id":null}"#;
    let id = extract_response_id(payload);
    assert_eq!(id, None);
}
/// A string-valued `id` is not treated as a numeric id and yields `None`.
#[test]
fn extract_response_id_string_returns_none() {
    let payload = r#"{"jsonrpc":"2.0","result":{"embeddings":[]},"id":"abc"}"#;
    let id = extract_response_id(payload);
    assert_eq!(id, None);
}
/// A `read_line` against a peer that never writes must be cut off by
/// `tokio::time::timeout` rather than hanging.
///
/// This exercises the timeout primitive over an in-memory duplex pipe; it
/// does not drive the embed call itself. The test always waits out the full
/// deadline, because that is the behaviour under test. The deadline is
/// therefore kept short so it does not slow the suite.
#[tokio::test]
async fn embed_call_stalled_reader_times_out() {
    use tokio::io::AsyncBufReadExt;
    use tokio::io::duplex;
    let deadline = Duration::from_millis(100);
    // Keep the write half alive. Dropping it would make `read_line` return
    // EOF (Ok(0)) immediately instead of stalling.
    let (_tx, rx) = duplex(1024);
    let mut buf = String::new();
    let mut reader = tokio::io::BufReader::new(rx);
    let result = tokio::time::timeout(deadline, reader.read_line(&mut buf)).await;
    assert!(
        result.is_err(),
        "a read_line on a never-writing reader must time out under a {deadline:?} deadline; \
         got: {result:?}"
    );
    assert!(
        buf.is_empty(),
        "nothing was written, so nothing may be read; got: {buf:?}"
    );
}
/// Regression test for #763.
///
/// A read timeout in `reader_task` must fail the request pending at that
/// moment with `EmbedderError::Stdio`. The task itself must keep running,
/// and it must not deliver the late ("stale") response for the failed
/// request to the next request.
///
/// Sequence:
/// 1. Register request A (id 1) and write nothing, so the reader times out.
/// 2. Wait for A to be failed with `Err(Stdio)`.
/// 3. Write A's late response (id 1), then register B (id 2) and write B's response.
/// 4. B must succeed with its own embeddings (0.5…/0.7…), not A's (0.1…/0.3…).
#[tokio::test]
async fn reader_task_survives_timeout_and_serves_next_request() {
    use tokio::io::{AsyncWriteExt, duplex};
    use tokio::sync::oneshot;
    let short_timeout = Duration::from_millis(50);
    let (mut writer, reader_end) = duplex(4096);
    let reader = tokio::io::BufReader::new(reader_end);
    let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
    let pending_clone = Arc::clone(&pending);
    let handle = tokio::spawn(reader_task(reader, pending_clone, short_timeout));

    // Request A: registered, but nothing is written for it before the timeout.
    let (tx_a, rx_a) = oneshot::channel();
    pending.lock().await.insert(
        1,
        PendingRequest {
            sent: 2,
            reply: tx_a,
        },
    );
    // Await A's outcome rather than sleeping a fixed multiple of the timeout
    // and then polling with `try_recv`. A fixed sleep races the reader task's
    // timer and can fail spuriously on a loaded machine. The outer deadline
    // only bounds the wait if the reader never fails A.
    let result_a = tokio::time::timeout(Duration::from_secs(2), rx_a)
        .await
        .expect("rx_a timed out — reader task never failed request A after its read timeout");
    assert!(
        matches!(result_a, Ok(Err(EmbedderError::Stdio(_)))),
        "request A after timeout must receive Err(Stdio): got {result_a:?}"
    );

    // A's response arrives late; it must not be delivered to B.
    let stale_a =
        b"{\"jsonrpc\":\"2.0\",\"result\":{\"embeddings\":[[0.1,0.2],[0.3,0.4]]},\"id\":1}\n";
    writer.write_all(stale_a).await.unwrap();
    writer.flush().await.unwrap();

    // Request B: registered after A failed, answered with its own data.
    let (tx_b, rx_b) = oneshot::channel();
    pending.lock().await.insert(
        2,
        PendingRequest {
            sent: 2,
            reply: tx_b,
        },
    );
    let real_b =
        b"{\"jsonrpc\":\"2.0\",\"result\":{\"embeddings\":[[0.5,0.6],[0.7,0.8]]},\"id\":2}\n";
    writer.write_all(real_b).await.unwrap();
    writer.flush().await.unwrap();

    let result_b = tokio::time::timeout(Duration::from_secs(2), rx_b)
        .await
        .expect("rx_b timed out — reader task may have exited instead of continuing")
        .expect("rx_b channel closed unexpectedly");
    let embeddings_b = match result_b {
        Ok(embeddings) => embeddings,
        Err(err) => panic!(
            "request B must succeed after reader task survived timeout (#763): \
             got Err({err:?})"
        ),
    };
    assert_eq!(
        embeddings_b.len(),
        2,
        "request B must return 2 embedding vectors"
    );
    assert!(
        (embeddings_b[0][0] - 0.5_f32).abs() < 1e-6,
        "request B must receive its OWN embeddings (0.5…), not A's stale \
         embeddings (0.1…) — misattribution bug would put 0.1 here. \
         Got: {:?}",
        embeddings_b[0]
    );
    assert!(
        (embeddings_b[1][0] - 0.7_f32).abs() < 1e-6,
        "request B second vector must be B's own data (0.7…), not A's (0.3…). \
         Got: {:?}",
        embeddings_b[1]
    );

    // Close the write half so the reader side sees EOF. The bounded wait keeps
    // the test from hanging if `reader_task` does not exit on EOF.
    drop(writer);
    let _ = tokio::time::timeout(Duration::from_secs(1), handle).await;
}
/// Each execution provider's stall hint names its own backend and never the
/// other one. The CPU hint names neither CUDA nor CoreML, but is non-empty.
#[test]
fn timeout_stall_hint_is_provider_aware() {
    use crate::embedder::ExecutionProvider;

    /// Asserts that `hint` mentions every entry of `required` and then none of
    /// `forbidden`; `label` identifies the provider in failure messages.
    fn check_hint(label: &str, hint: &str, required: &[&str], forbidden: &[&str]) {
        for &needle in required {
            assert!(
                hint.contains(needle),
                "{label} provider must mention {needle}; got: {hint:?}"
            );
        }
        for &needle in forbidden {
            assert!(
                !hint.contains(needle),
                "{label} hint must not mention {needle}; got: {hint:?}"
            );
        }
    }

    let cuda = super::timeout_stall_hint(ExecutionProvider::Cuda);
    check_hint("CUDA", &cuda, &["CUDA"], &["CoreML"]);

    let coreml = super::timeout_stall_hint(ExecutionProvider::CoreML);
    check_hint("CoreML", &coreml, &["CoreML"], &["CUDA"]);

    let coreml_ane = super::timeout_stall_hint(ExecutionProvider::CoreMLAne);
    check_hint("CoreMLAne", &coreml_ane, &["CoreML"], &["CUDA"]);

    let cpu = super::timeout_stall_hint(ExecutionProvider::Cpu);
    check_hint("CPU", &cpu, &[], &["CUDA", "CoreML"]);
    assert!(!cpu.is_empty(), "CPU hint must not be empty");
}