#![allow(clippy::unwrap_used, clippy::expect_used)]
use std::time::Duration;
use quiver_server::{Config, serve};
use serde_json::{Value, json};
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
/// Handle to one server instance started by `boot`.
struct Node {
/// Base URL (`http://<addr>`) of the instance's REST listener.
rest: String,
/// Base URL (`http://<addr>`) of the instance's gRPC listener.
grpc: String,
/// Join handle of the tokio task running `serve`; the tests abort it to take a node down.
handle: JoinHandle<()>,
}
async fn wait_ready(http: &reqwest::Client, base: &str) {
for _ in 0..300 {
if let Ok(r) = http.get(format!("{base}/healthz")).send().await
&& r.status().is_success()
{
return;
}
tokio::time::sleep(Duration::from_millis(20)).await;
}
panic!("server {base} did not become ready");
}
/// Starts one server instance inside the current tokio runtime.
///
/// Two listeners are bound to `127.0.0.1:0` (so the OS picks free ports), a
/// `Config` is assembled from the arguments with `insecure: true` and
/// defaults for everything else, and `serve` is spawned as a background task.
///
/// * `data_dir` - storage directory handed to the server.
/// * `leader_grpc` - forwarded as `Config::leader_url`.
/// * `shards` - forwarded as `Config::cluster_shards`.
/// * `replicas` - forwarded as `Config::cluster_replicas`.
///
/// Returns a [`Node`] carrying the `http://` URLs of both listeners and the
/// handle of the spawned task.
///
/// # Panics
/// Panics if a listener cannot be bound or its local address cannot be read.
async fn boot(
    data_dir: std::path::PathBuf,
    leader_grpc: Option<String>,
    shards: Vec<String>,
    replicas: Vec<String>,
) -> Node {
    let rest_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let grpc_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let rest_addr = rest_listener.local_addr().unwrap();
    let grpc_addr = grpc_listener.local_addr().unwrap();

    // Capture the public URLs up front; the listeners move into the task below.
    let rest = format!("http://{rest_addr}");
    let grpc = format!("http://{grpc_addr}");

    let config = Config {
        data_dir,
        rest_addr,
        grpc_addr,
        insecure: true,
        leader_url: leader_grpc,
        cluster_shards: shards,
        cluster_replicas: replicas,
        ..Default::default()
    };

    let handle = tokio::spawn(async move {
        // Whatever `serve` yields is deliberately discarded.
        let _ = serve(config, rest_listener, grpc_listener).await;
    });

    Node { rest, grpc, handle }
}
/// Builds the deterministic 8-component vector for seed `i`.
///
/// Component `j` (for `j` in `0..8`) equals `((i * 7 + j * 13) % 91) / 9.0`.
fn vec_for(i: u32) -> Vec<f32> {
    let mut components = Vec::with_capacity(8);
    for j in 0..8u32 {
        let bucket = (i * 7 + j * 13) % 91;
        components.push(bucket as f32 / 9.0);
    }
    components
}
/// Creates collection `c` (dimension 8, metric `l2`) by POSTing to
/// `{base}/v1/collections`.
///
/// # Panics
/// Panics if the request fails or the response carries an error status.
async fn create(http: &reqwest::Client, base: &str) {
    let url = format!("{base}/v1/collections");
    let spec = json!({"name": "c", "dim": 8, "metric": "l2"});
    let response = http.post(url).json(&spec).send().await.unwrap();
    response.error_for_status().unwrap();
}
/// Upserts `n` points into collection `c` with a single POST to
/// `{base}/v1/collections/c/points`.
///
/// Point `i` (for `i` in `0..n`) gets id `p{i}`, vector `vec_for(i)` and
/// payload `{"i": i}`.
///
/// # Panics
/// Panics if the request fails or the response carries an error status.
async fn upsert_all(http: &reqwest::Client, base: &str, n: u32) {
    let mut points: Vec<Value> = Vec::new();
    for i in 0..n {
        points.push(json!({"id": format!("p{i}"), "vector": vec_for(i), "payload": {"i": i}}));
    }

    let url = format!("{base}/v1/collections/c/points");
    let body = json!({ "points": points });
    let response = http.post(url).json(&body).send().await.unwrap();
    response.error_for_status().unwrap();
}
/// Queries collection `c` on `base` for the `k` nearest matches to `q` and
/// returns their `score` fields in response order.
///
/// The HTTP status is not inspected: any JSON body without a `matches` array
/// simply produces an empty vector, which lets callers poll a node that is
/// not ready to answer yet.
///
/// # Panics
/// Panics if the request fails, the body is not JSON, or a match has no
/// numeric `score`.
async fn top_scores(http: &reqwest::Client, base: &str, q: &[f32], k: usize) -> Vec<f32> {
    let url = format!("{base}/v1/collections/c/query");
    let request = json!({"vector": q, "k": k, "ef_search": 256, "with_payload": false, "with_vector": false});
    let response = http.post(url).json(&request).send().await.unwrap();
    let body: Value = response.json().await.unwrap();

    let Some(matches) = body["matches"].as_array() else {
        return Vec::new();
    };

    let mut scores = Vec::with_capacity(matches.len());
    for hit in matches {
        scores.push(hit["score"].as_f64().unwrap() as f32);
    }
    scores
}
async fn count(http: &reqwest::Client, base: &str) -> u64 {
let resp: Value = http
.get(format!("{base}/v1/collections/c"))
.send()
.await
.unwrap()
.json()
.await
.unwrap();
resp["count"].as_u64().unwrap_or(0)
}
/// Reports whether `got` and `want` have the same length and every pair of
/// corresponding elements differs by less than `1e-4`.
///
/// A pair involving NaN never counts as close.
fn close(got: &[f32], want: &[f32]) -> bool {
    if got.len() != want.len() {
        return false;
    }
    for (g, w) in got.iter().zip(want) {
        // Kept as a positive `<` test so that NaN differences are rejected.
        let within_tolerance = (g - w).abs() < 1e-4;
        if !within_tolerance {
            return false;
        }
    }
    true
}
/// Seeds passed to `vec_for` to build the query vectors the tests probe with.
const QUERIES: [u32; 4] = [0, 17, 63, 119];
/// Waits until `replica` answers every probe query with the same top-10
/// scores as `primary`.
///
/// Each attempt walks `QUERIES` in order, asking the replica first and the
/// primary second. The attempt fails at the first query where the replica
/// returns nothing or the scores are not `close`. After a failed attempt the
/// function sleeps 20 ms and retries, up to 300 attempts in total.
///
/// # Panics
/// Panics if no attempt succeeds.
async fn wait_caught_up(http: &reqwest::Client, replica: &str, primary: &str) {
    'attempt: for _ in 0..300 {
        for qi in QUERIES {
            let q = vec_for(qi);
            let got = top_scores(http, replica, &q, 10).await;
            let want = top_scores(http, primary, &q, 10).await;
            if got.is_empty() || !close(&got, &want) {
                // Not in sync yet: back off, then start over with a new attempt.
                tokio::time::sleep(Duration::from_millis(20)).await;
                continue 'attempt;
            }
        }
        // Every probe query matched.
        return;
    }
    panic!("replica {replica} did not catch up to its primary {primary}");
}
/// Boots two primaries, one replica per primary, a router that lists both
/// primaries as shards and both replicas, and a standalone baseline node.
///
/// After writing the same 120 points through the router and into the
/// baseline, the test checks that:
/// * both primaries hold some points and together hold exactly 120;
/// * router top-10 scores match the baseline for every probe query;
/// * a direct write to a replica is answered with `403 Forbidden`;
/// * with both primary tasks aborted, the router still returns 10 hits
///   matching the baseline.
#[tokio::test]
async fn cluster_replicas_serve_reads_matching_single_node() {
    let mut dirs = Vec::with_capacity(6);
    for _ in 0..6 {
        dirs.push(tempfile::tempdir().unwrap());
    }
    let dir = |slot: usize| dirs[slot].path().to_path_buf();
    let http = reqwest::Client::new();

    // Primaries: no leader, no cluster configuration.
    let primary0 = boot(dir(0), None, Vec::new(), Vec::new()).await;
    let primary1 = boot(dir(1), None, Vec::new(), Vec::new()).await;
    wait_ready(&http, &primary0.rest).await;
    wait_ready(&http, &primary1.rest).await;

    // Replicas: each is given its primary's gRPC URL as leader.
    let replica0 = boot(dir(2), Some(primary0.grpc.clone()), Vec::new(), Vec::new()).await;
    let replica1 = boot(dir(3), Some(primary1.grpc.clone()), Vec::new(), Vec::new()).await;
    wait_ready(&http, &replica0.rest).await;
    wait_ready(&http, &replica1.rest).await;

    // Router: primaries as shards; replica entries use the `<n>=<url>` form
    // (presumably `<n>` is the index of the shard the replica belongs to).
    let shard_urls = vec![primary0.rest.clone(), primary1.rest.clone()];
    let replica_urls = vec![
        format!("0={}", replica0.rest),
        format!("1={}", replica1.rest),
    ];
    let router = boot(dir(4), None, shard_urls, replica_urls).await;
    let baseline = boot(dir(5), None, Vec::new(), Vec::new()).await;
    wait_ready(&http, &router.rest).await;
    wait_ready(&http, &baseline.rest).await;

    // Same collection and same data on the cluster and on the single node.
    create(&http, &router.rest).await;
    create(&http, &baseline.rest).await;
    upsert_all(&http, &router.rest, 120).await;
    upsert_all(&http, &baseline.rest, 120).await;

    let c0 = count(&http, &primary0.rest).await;
    let c1 = count(&http, &primary1.rest).await;
    assert!(c0 > 0 && c1 > 0, "write did not shard: p0={c0} p1={c1}");
    assert_eq!(c0 + c1, 120, "points lost or duplicated across shards");

    wait_caught_up(&http, &replica0.rest, &primary0.rest).await;
    wait_caught_up(&http, &replica1.rest, &primary1.rest).await;

    // Whole cluster up: router results must equal the baseline's.
    for qi in QUERIES {
        let q = vec_for(qi);
        let got = top_scores(&http, &router.rest, &q, 10).await;
        let want = top_scores(&http, &baseline.rest, &q, 10).await;
        assert_eq!(got.len(), 10, "router returned {} hits", got.len());
        assert!(close(&got, &want), "router != baseline for q{qi}");
    }

    // A write sent straight to a replica must be refused.
    let write_url = format!("{}/v1/collections/c/points", replica0.rest);
    let write_body = json!({"points": [{"id": "x", "vector": vec_for(1)}]});
    let denied = http.post(write_url).json(&write_body).send().await.unwrap();
    assert_eq!(denied.status(), reqwest::StatusCode::FORBIDDEN);

    // Take both primaries down; reads through the router must keep working.
    primary0.handle.abort();
    primary1.handle.abort();
    for qi in QUERIES {
        let q = vec_for(qi);
        let got = top_scores(&http, &router.rest, &q, 10).await;
        assert_eq!(
            got.len(),
            10,
            "replica-only router returned {} hits",
            got.len()
        );
        let want = top_scores(&http, &baseline.rest, &q, 10).await;
        assert!(
            close(&got, &want),
            "replica-served top-k != baseline for q{qi}"
        );
    }
}
/// Boots one primary, one replica of it, a router over that single shard and
/// a standalone baseline node, then writes the same 80 points through the
/// router and into the baseline.
///
/// Once the replica has caught up its task is aborted. Every probe query is
/// then issued four times through the router, and each answer must contain
/// 10 hits whose scores match the baseline.
#[tokio::test]
async fn router_tolerates_a_down_replica() {
    let mut dirs = Vec::with_capacity(4);
    for _ in 0..4 {
        dirs.push(tempfile::tempdir().unwrap());
    }
    let dir = |slot: usize| dirs[slot].path().to_path_buf();
    let http = reqwest::Client::new();

    let primary = boot(dir(0), None, Vec::new(), Vec::new()).await;
    wait_ready(&http, &primary.rest).await;

    // The replica is given the primary's gRPC URL as leader.
    let leader = Some(primary.grpc.clone());
    let replica = boot(dir(1), leader, Vec::new(), Vec::new()).await;
    wait_ready(&http, &replica.rest).await;

    let shard_urls = vec![primary.rest.clone()];
    let replica_urls = vec![format!("0={}", replica.rest)];
    let router = boot(dir(2), None, shard_urls, replica_urls).await;
    let baseline = boot(dir(3), None, Vec::new(), Vec::new()).await;
    wait_ready(&http, &router.rest).await;
    wait_ready(&http, &baseline.rest).await;

    create(&http, &router.rest).await;
    create(&http, &baseline.rest).await;
    upsert_all(&http, &router.rest, 80).await;
    upsert_all(&http, &baseline.rest, 80).await;
    wait_caught_up(&http, &replica.rest, &primary.rest).await;

    // Take the replica down; the router must keep answering correctly.
    replica.handle.abort();

    for qi in QUERIES {
        let q = vec_for(qi);
        // Each query is repeated four times.
        let mut round = 0;
        while round < 4 {
            let got = top_scores(&http, &router.rest, &q, 10).await;
            let want = top_scores(&http, &baseline.rest, &q, 10).await;
            assert_eq!(
                got.len(),
                10,
                "router returned {} hits after replica down",
                got.len()
            );
            assert!(
                close(&got, &want),
                "router != baseline after replica down for q{qi}"
            );
            round += 1;
        }
    }
}