use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use arc_swap::ArcSwap;
use quiver_cluster::{ShardMap, merge_top_k};
/// Period for re-fetching the shard map from the coordinator.
///
/// NOTE(review): the consumer is not visible in this file; presumably a background
/// task calls `Cluster::refresh_from` on this interval — confirm.
pub(crate) const MAP_REFRESH_INTERVAL: Duration = Duration::from_millis(2_000);
use quiver_embed::{
DistanceMetric, Filter, FilterableField, IndexKind, IndexSpec, VectorEncryption,
};
use serde_json::{Value, json};
use tokio::sync::RwLock;
use crate::error::Error;
use crate::{CollectionInfo, MatchOut, PointIn, PointOut};
/// Router that forwards collection, point and search operations to a set of shards
/// over HTTP.
///
/// The shard map can be swapped at runtime; the per-collection score ordering and
/// per-shard leader URLs are cached behind async `RwLock`s.
pub(crate) struct Cluster {
    /// Shared HTTP client used for every outgoing request.
    http: reqwest::Client,
    /// Bearer token attached to outgoing requests, when configured.
    shard_key: Option<String>,
    /// Currently installed shard topology; replaced as a whole by a newer version.
    map: ArcSwap<ShardMap>,
    /// Monotonic counter used to rotate the starting read target between searches.
    read_rr: AtomicUsize,
    /// Collection name -> `true` when higher scores rank first.
    ordering: RwLock<HashMap<String, bool>>,
    /// Shard id -> URL of the node that most recently accepted a write for it.
    leaders: RwLock<HashMap<u64, String>>,
}
/// Maximum rounds over a shard's write candidates while they answer "not leader".
const WRITE_LEADER_ATTEMPTS: usize = 60;
/// Pause between those rounds; together with the attempt count this gives roughly
/// 3 s of backoff (plus request time) before writes to a shard are reported unavailable.
const WRITE_LEADER_BACKOFF: Duration = Duration::from_millis(50);
/// Why a single write attempt against one candidate node did not succeed.
enum WriteOutcome {
    /// The node rejected the write or its response was unusable; do not retry.
    Fatal(Error),
    /// The node answered `421 Misdirected Request`: it is not the shard's leader.
    NotLeader,
    /// The request could not be sent or no response was received.
    Unreachable(Error),
}
impl Cluster {
    /// Builds a router over `shard_urls`, attaching read replicas parsed from
    /// `replica_specs`.
    ///
    /// Every replica spec must have the form `"<shard_index>=<url>"`; whitespace on
    /// either side of the `=` is ignored. When `shard_key` is set it is sent as a
    /// bearer token on every outgoing request.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Config`] when `ShardMap::from_urls` rejects `shard_urls`, when a
    /// replica spec has no `=` or a non-numeric shard index, or when
    /// `ShardMap::add_replica` rejects the entry.
    pub(crate) fn new(
        shard_urls: Vec<String>,
        replica_specs: Vec<String>,
        shard_key: Option<String>,
    ) -> Result<Self, Error> {
        let mut map = ShardMap::from_urls(shard_urls).map_err(|e| Error::Config(e.to_string()))?;
        for spec in &replica_specs {
            let (index, url) = spec.split_once('=').ok_or_else(|| {
                Error::Config(format!(
                    "QUIVER_CLUSTER_REPLICAS entry {spec:?} must be \"<shard_index>=<url>\""
                ))
            })?;
            let index: u64 = index.trim().parse().map_err(|_| {
                Error::Config(format!("replica entry {spec:?} has a non-numeric shard id"))
            })?;
            // The index is trimmed above; trim the URL too so "0 = http://host" does not
            // register a replica URL with a leading space.
            map.add_replica(index, url.trim())
                .map_err(|e| Error::Config(e.to_string()))?;
        }
        Ok(Self {
            map: ArcSwap::from_pointee(map),
            http: reqwest::Client::new(),
            shard_key,
            ordering: RwLock::new(HashMap::new()),
            read_rr: AtomicUsize::new(0),
            leaders: RwLock::new(HashMap::new()),
        })
    }

    /// Number of shards in the currently installed shard map.
    pub(crate) fn shard_count(&self) -> usize {
        self.map.load().len()
    }

    /// Returns an owned snapshot of the currently installed shard map.
    pub(crate) fn current_map(&self) -> ShardMap {
        ShardMap::clone(&self.map.load())
    }

    /// Fetches the shard map from `GET <coordinator_url>/cluster/map` and installs it
    /// if its version is newer than the installed map's.
    ///
    /// Returns `Ok(true)` when the fetched map was installed, and `Ok(false)` when the
    /// installed map already has the same or a newer version.
    ///
    /// # Errors
    ///
    /// Propagates request failures from `send`, and returns [`Error::Internal`] if the
    /// response does not deserialize into a `ShardMap`.
    pub(crate) async fn refresh_from(&self, coordinator_url: &str) -> Result<bool, Error> {
        let url = format!("{}/cluster/map", coordinator_url.trim_end_matches('/'));
        let body = self.send(reqwest::Method::GET, url, None).await?;
        let fetched: ShardMap = serde_json::from_value(body)
            .map_err(|e| Error::Internal(format!("coordinator map: {e}")))?;
        let fetched = std::sync::Arc::new(fetched);
        // Compare and install in one atomic step. A separate load/compare/store lets
        // two concurrent refreshes interleave so an older map overwrites a newer one.
        // `rcu` may re-run the closure under contention; the flag then reflects the
        // final, successful attempt.
        let mut installed = false;
        self.map.rcu(|current| {
            installed = fetched.version() > current.version();
            if installed {
                std::sync::Arc::clone(&fetched)
            } else {
                std::sync::Arc::clone(current)
            }
        });
        Ok(installed)
    }

    /// Attaches the configured shard key as a bearer token, if there is one.
    fn auth(&self, rb: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        match &self.shard_key {
            Some(k) => rb.bearer_auth(k),
            None => rb,
        }
    }

    /// Sends one request (authenticated via `auth`, with `body` as JSON when given)
    /// and returns the raw response without inspecting its status.
    ///
    /// # Errors
    ///
    /// [`Error::Internal`] if the request could not be sent or no response arrived.
    async fn send_raw(
        &self,
        method: reqwest::Method,
        url: &str,
        body: Option<Value>,
    ) -> Result<reqwest::Response, Error> {
        let mut rb = self.http.request(method, url);
        if let Some(b) = body {
            rb = rb.json(&b);
        }
        self.auth(rb)
            .send()
            .await
            .map_err(|e| Error::Internal(format!("shard {url} unreachable: {e}")))
    }

    /// Decodes a response into JSON. A non-2xx status becomes an error that carries the
    /// body text; an empty 2xx body becomes `Value::Null`.
    ///
    /// # Errors
    ///
    /// [`Error::Internal`] for a non-success status, a body that cannot be read, or a
    /// body that is not valid JSON.
    async fn decode_response(url: &str, resp: reqwest::Response) -> Result<Value, Error> {
        let status = resp.status();
        if !status.is_success() {
            // Best effort: the body only enriches the error message.
            let text = resp.text().await.unwrap_or_default();
            return Err(Error::Internal(format!(
                "shard {url} returned {status}: {text}"
            )));
        }
        // On success the body *is* the result, so a failed read must not be mistaken for
        // an empty (null) response.
        let text = resp
            .text()
            .await
            .map_err(|e| Error::Internal(format!("shard {url} bad response: {e}")))?;
        if text.is_empty() {
            return Ok(Value::Null);
        }
        serde_json::from_str(&text)
            .map_err(|e| Error::Internal(format!("shard {url} bad response: {e}")))
    }

    /// Sends a request and decodes its JSON response (`Value::Null` for an empty body).
    ///
    /// # Errors
    ///
    /// [`Error::Internal`] if the target is unreachable, answers with a non-2xx status,
    /// or returns a body that cannot be read or parsed.
    async fn send(
        &self,
        method: reqwest::Method,
        url: String,
        body: Option<Value>,
    ) -> Result<Value, Error> {
        let resp = self.send_raw(method, &url, body).await?;
        Self::decode_response(&url, resp).await
    }

    /// Posts `body` to `<target>/v1/collections/<collection>/query` for each of the
    /// shard's read targets, in the order given by `shard.read_order(nth)`, and returns
    /// the first successful response.
    ///
    /// # Errors
    ///
    /// The last target's error if every target fails, or [`Error::Internal`] if the
    /// shard has no read targets.
    async fn shard_query(
        &self,
        shard: &quiver_cluster::Shard,
        nth: usize,
        collection: &str,
        body: &Value,
    ) -> Result<Value, Error> {
        let targets = shard.read_order(nth);
        let mut last_err = None;
        for (i, target) in targets.iter().enumerate() {
            let url = format!("{target}/v1/collections/{collection}/query");
            match self
                .send(reqwest::Method::POST, url, Some(body.clone()))
                .await
            {
                Ok(v) => return Ok(v),
                Err(e) => {
                    if i + 1 < targets.len() {
                        tracing::warn!(target, error = %e, "shard read target failed; trying next");
                    }
                    last_err = Some(e);
                }
            }
        }
        Err(last_err.unwrap_or_else(|| Error::Internal("shard has no read targets".into())))
    }

    /// Sends the same write to every shard in the current map, one shard at a time,
    /// via `write_to_shard`, and returns the last shard's response.
    ///
    /// Stops at the first failing shard. Shards written before the failure keep the
    /// change; there is no rollback.
    async fn broadcast(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<Value>,
    ) -> Result<Value, Error> {
        let map = self.map.load();
        let mut last = Value::Null;
        for shard in map.shards() {
            last = self
                .write_to_shard(shard, method.clone(), path, body.clone())
                .await?;
        }
        Ok(last)
    }

    /// Attempts one write against a single node and classifies any failure for
    /// `write_to_shard`'s retry logic.
    async fn try_write(
        &self,
        method: reqwest::Method,
        url: String,
        body: Option<Value>,
    ) -> Result<Value, WriteOutcome> {
        let resp = self
            .send_raw(method, &url, body)
            .await
            .map_err(WriteOutcome::Unreachable)?;
        // A 421 Misdirected Request is treated as "this node is not the leader".
        if resp.status() == reqwest::StatusCode::MISDIRECTED_REQUEST {
            return Err(WriteOutcome::NotLeader);
        }
        Self::decode_response(&url, resp)
            .await
            .map_err(WriteOutcome::Fatal)
    }

    /// Delivers one write (`method` on `<node><path>`) to the leader of `shard`.
    ///
    /// Candidates are the cached leader for this shard (if any) followed by the shard's
    /// `read_order(0)` targets. Each round tries every candidate once:
    /// - a success is returned, and that URL is cached as the leader;
    /// - a fatal response is returned immediately;
    /// - any "not leader" answer schedules another round after `WRITE_LEADER_BACKOFF`,
    ///   for up to `WRITE_LEADER_ATTEMPTS` rounds.
    ///
    /// If a round sees no "not leader" answer, the last unreachable error is returned.
    async fn write_to_shard(
        &self,
        shard: &quiver_cluster::Shard,
        method: reqwest::Method,
        path: &str,
        body: Option<Value>,
    ) -> Result<Value, Error> {
        let target = |url: &str| format!("{url}{path}");
        for _ in 0..WRITE_LEADER_ATTEMPTS {
            let cached = self.leaders.read().await.get(&shard.id).cloned();
            let mut candidates: Vec<String> = cached.into_iter().collect();
            for v in shard.read_order(0) {
                if !candidates.iter().any(|u| u == v) {
                    candidates.push(v.to_owned());
                }
            }
            let mut saw_not_leader = false;
            let mut unreachable: Option<Error> = None;
            for url in &candidates {
                match self
                    .try_write(method.clone(), target(url), body.clone())
                    .await
                {
                    Ok(v) => return self.cache_leader(shard.id, url, v).await,
                    Err(WriteOutcome::NotLeader) => saw_not_leader = true,
                    Err(WriteOutcome::Unreachable(e)) => unreachable = Some(e),
                    Err(WriteOutcome::Fatal(e)) => return Err(e),
                }
            }
            if saw_not_leader {
                tokio::time::sleep(WRITE_LEADER_BACKOFF).await;
            } else {
                return Err(unreachable.unwrap_or_else(|| {
                    Error::Internal(format!("shard {} has no write target", shard.id))
                }));
            }
        }
        Err(Error::Internal(format!(
            "shard {} has no Raft leader (writes unavailable after retries)",
            shard.id
        )))
    }

    /// Records `url` as the leader of shard `id` and passes `body` through.
    async fn cache_leader(&self, id: u64, url: &str, body: Value) -> Result<Value, Error> {
        self.leaders.write().await.insert(id, url.to_owned());
        Ok(body)
    }

    /// Creates collection `name` on every shard and caches its score ordering.
    ///
    /// `pq_subspaces` and `filterable` are only included in the request when present.
    /// Creation goes shard by shard through `broadcast`, so a failure part-way leaves
    /// the collection on the shards that already accepted it. The returned info
    /// reports `count: 0`.
    #[allow(clippy::too_many_arguments)]
    pub(crate) async fn create_collection(
        &self,
        name: String,
        dim: u32,
        metric: DistanceMetric,
        index: IndexSpec,
        filterable: Vec<FilterableField>,
        multivector: bool,
        vector_encryption: VectorEncryption,
    ) -> Result<CollectionInfo, Error> {
        let mut body = json!({
            "name": name,
            "dim": dim,
            "metric": metric_wire(metric),
            "index": index_wire(index.kind),
            "multivector": multivector,
            "vector_encryption": encryption_wire(vector_encryption),
        });
        if let Some(pq) = index.pq_subspaces {
            body["pq_subspaces"] = json!(pq);
        }
        if !filterable.is_empty() {
            body["filterable"] = json!(
                filterable
                    .iter()
                    .map(|f| json!({ "path": f.path, "type": field_type_wire(f.field_type) }))
                    .collect::<Vec<_>>()
            );
        }
        self.broadcast(reqwest::Method::POST, "/v1/collections", Some(body))
            .await?;
        self.ordering
            .write()
            .await
            .insert(name.clone(), higher_is_better(metric));
        Ok(CollectionInfo {
            name,
            dim,
            metric,
            count: 0,
            index,
            filterable,
            multivector,
            vector_encryption,
        })
    }

    /// Deletes collection `name` on every shard and forgets its cached ordering.
    ///
    /// Returns `Ok(true)` once every shard accepted the request; the shards' response
    /// bodies are not inspected.
    pub(crate) async fn drop_collection(&self, name: &str) -> Result<bool, Error> {
        self.broadcast(
            reqwest::Method::DELETE,
            &format!("/v1/collections/{name}"),
            None,
        )
        .await?;
        self.ordering.write().await.remove(name);
        Ok(true)
    }

    /// Upserts `points` through each shard's `points` endpoint.
    pub(crate) async fn upsert(
        &self,
        collection: &str,
        points: Vec<PointIn>,
    ) -> Result<u64, Error> {
        self.upsert_to(collection, points, "points").await
    }

    /// Upserts `points` through each shard's `points:bulk` endpoint.
    pub(crate) async fn upsert_bulk(
        &self,
        collection: &str,
        points: Vec<PointIn>,
    ) -> Result<u64, Error> {
        self.upsert_to(collection, points, "points:bulk").await
    }

    /// Sends each point to its owning shard (`ShardMap::partition`) and also to any
    /// shard returned by `ShardMap::partition_to_donors` (presumably donors of an
    /// in-progress migration — confirm).
    ///
    /// Returns the sum of the `upserted` counts reported by the owning shards. Donor
    /// writes are not counted.
    async fn upsert_to(
        &self,
        collection: &str,
        points: Vec<PointIn>,
        endpoint: &str,
    ) -> Result<u64, Error> {
        let map = self.map.load();
        let mut total = 0u64;
        for (shard, group) in map.partition(&points, |p| p.id.as_str()) {
            total += self
                .post_points(collection, endpoint, shard, &group)
                .await?;
        }
        for (donor, group) in map.partition_to_donors(&points, |p| p.id.as_str()) {
            self.post_points(collection, endpoint, donor, &group)
                .await?;
        }
        Ok(total)
    }

    /// POSTs `group` as `{"points": [...]}` to one shard. Returns the response's
    /// `upserted` count, or 0 when that field is absent.
    async fn post_points(
        &self,
        collection: &str,
        endpoint: &str,
        shard: &quiver_cluster::Shard,
        group: &[&PointIn],
    ) -> Result<u64, Error> {
        let dtos: Vec<Value> = group
            .iter()
            .map(|p| json!({ "id": p.id, "vector": p.vector, "payload": p.payload }))
            .collect();
        let path = format!("/v1/collections/{collection}/{endpoint}");
        let resp = self
            .write_to_shard(
                shard,
                reqwest::Method::POST,
                &path,
                Some(json!({ "points": dtos })),
            )
            .await?;
        Ok(resp.get("upserted").and_then(Value::as_u64).unwrap_or(0))
    }

    /// Deletes `ids` from their owning shards and from any donor shards.
    ///
    /// Returns the sum of the `deleted` counts reported by the owning shards.
    pub(crate) async fn delete_points(
        &self,
        collection: &str,
        ids: Vec<String>,
    ) -> Result<u64, Error> {
        let map = self.map.load();
        let mut total = 0u64;
        for (shard, group) in map.partition(&ids, |id| id.as_str()) {
            total += self.delete_group(collection, shard, &group).await?;
        }
        for (donor, group) in map.partition_to_donors(&ids, |id| id.as_str()) {
            self.delete_group(collection, donor, &group).await?;
        }
        Ok(total)
    }

    /// Sends a DELETE with `{"ids": [...]}` to one shard. Returns the response's
    /// `deleted` count, or 0 when that field is absent.
    async fn delete_group(
        &self,
        collection: &str,
        shard: &quiver_cluster::Shard,
        group: &[&String],
    ) -> Result<u64, Error> {
        let path = format!("/v1/collections/{collection}/points");
        let resp = self
            .write_to_shard(
                shard,
                reqwest::Method::DELETE,
                &path,
                Some(json!({ "ids": group })),
            )
            .await?;
        Ok(resp.get("deleted").and_then(Value::as_u64).unwrap_or(0))
    }

    /// Appends `id` to `points_prefix` (a `.../points/` URL built by `get_points`) as a
    /// single percent-encoded path segment. Ids containing `/`, `?`, `#` or `%` then
    /// address that point rather than a different route.
    ///
    /// # Errors
    ///
    /// [`Error::Internal`] if `points_prefix` is not an absolute URL that can carry a
    /// path.
    fn point_url(points_prefix: &str, id: &str) -> Result<reqwest::Url, Error> {
        let mut url = reqwest::Url::parse(points_prefix).map_err(|e| {
            Error::Internal(format!("shard url {points_prefix} is invalid: {e}"))
        })?;
        url.path_segments_mut()
            .map_err(|()| {
                Error::Internal(format!("shard url {points_prefix} cannot carry a path"))
            })?
            .pop_if_empty()
            .push(id);
        Ok(url)
    }

    /// Fetches each of `ids` from the primary of its donor shard (if any) or its
    /// owning shard. Ids that answer `404 Not Found` are skipped, as are responses
    /// without a string `id`.
    ///
    /// # Errors
    ///
    /// The first request or decode failure other than a 404.
    pub(crate) async fn get_points(
        &self,
        collection: &str,
        ids: Vec<String>,
        with_vector: bool,
    ) -> Result<Vec<PointOut>, Error> {
        let map = self.map.load();
        let mut out = Vec::new();
        for id in &ids {
            let shard = map.donor_for(id).unwrap_or_else(|| map.shard_for(id));
            let prefix = format!(
                "{}/v1/collections/{collection}/points/",
                shard.primary_url
            );
            let url = Self::point_url(&prefix, id)?;
            let resp = self
                .send_raw(reqwest::Method::GET, url.as_str(), None)
                .await?;
            // Check the status itself. Searching an error message for "404" also
            // matched URLs or ids containing "404" and silently hid real failures.
            if resp.status() == reqwest::StatusCode::NOT_FOUND {
                continue;
            }
            let body = Self::decode_response(url.as_str(), resp).await?;
            if let Some(p) = point_from_json(&body, with_vector) {
                out.push(p);
            }
        }
        Ok(out)
    }

    /// Runs a top-`k` search on every active shard and merges the results.
    ///
    /// Shards are queried one after another through `shard_query`. The read target is
    /// rotated with a round-robin counter so repeated searches spread across replicas.
    /// A match id returned by more than one shard is kept once (first occurrence wins)
    /// before merging by score in the collection's ordering.
    ///
    /// # Errors
    ///
    /// [`Error::BadRequest`] if `filter` cannot be serialized. Otherwise, the first
    /// failure from the metric lookup or from any shard.
    #[allow(clippy::too_many_arguments)]
    pub(crate) async fn search(
        &self,
        collection: &str,
        vector: Vec<f32>,
        k: usize,
        filter: Option<Filter>,
        ef_search: usize,
        with_payload: bool,
        with_vector: bool,
    ) -> Result<Vec<MatchOut>, Error> {
        let higher = self.higher_is_better(collection).await?;
        let mut body = json!({
            "vector": vector,
            "k": k,
            "ef_search": ef_search,
            "with_payload": with_payload,
            "with_vector": with_vector,
        });
        if let Some(f) = &filter {
            body["filter"] =
                serde_json::to_value(f).map_err(|e| Error::BadRequest(e.to_string()))?;
        }
        let map = self.map.load();
        let base = self.read_rr.fetch_add(1, Ordering::Relaxed);
        let active = map.active_shards();
        let mut per_shard: Vec<Vec<MatchOut>> = Vec::with_capacity(active.len());
        for shard in &active {
            let resp = self
                .shard_query(
                    shard,
                    base.wrapping_add(shard.id as usize),
                    collection,
                    &body,
                )
                .await?;
            per_shard.push(matches_from_json(&resp, with_vector));
        }
        let mut seen = std::collections::HashSet::new();
        let deduped: Vec<MatchOut> = per_shard
            .into_iter()
            .flatten()
            .filter(|m| seen.insert(m.id.clone()))
            .collect();
        Ok(merge_top_k(vec![deduped], k, |m| m.score, higher))
    }

    /// Returns whether higher scores rank first for `collection`.
    ///
    /// Uses the cache when possible. Otherwise it reads the collection's `metric` from
    /// the first shard's primary: `"cosine"` or `"dot"` yield `true`; any other value,
    /// including a missing field, yields `false`. The answer is then cached.
    async fn higher_is_better(&self, collection: &str) -> Result<bool, Error> {
        if let Some(h) = self.ordering.read().await.get(collection).copied() {
            return Ok(h);
        }
        let map = self.map.load();
        let shard = map
            .shards()
            .first()
            .ok_or_else(|| Error::Internal("no shards".into()))?;
        let url = format!("{}/v1/collections/{collection}", shard.primary_url);
        let info = self.send(reqwest::Method::GET, url, None).await?;
        let metric = info.get("metric").and_then(Value::as_str).unwrap_or("l2");
        let higher = matches!(metric, "cosine" | "dot");
        self.ordering
            .write()
            .await
            .insert(collection.to_owned(), higher);
        Ok(higher)
    }
}
/// Lower-case spelling of a [`DistanceMetric`] as sent in a shard's
/// collection-creation request.
fn metric_wire(m: DistanceMetric) -> &'static str {
    match m {
        DistanceMetric::Cosine => "cosine",
        DistanceMetric::Dot => "dot",
        DistanceMetric::L2 => "l2",
    }
}
/// Whether a larger score means a closer match under metric `m`.
///
/// Cosine and dot product are similarities (bigger is better); L2 is a distance
/// (smaller is better).
fn higher_is_better(m: DistanceMetric) -> bool {
    match m {
        DistanceMetric::L2 => false,
        DistanceMetric::Cosine | DistanceMetric::Dot => true,
    }
}
/// Spelling of an [`IndexKind`] as sent in a shard's collection-creation request.
///
/// Kinds without an explicit arm here are sent as `"hnsw"`.
/// NOTE(review): this silently substitutes HNSW for any newer kind — confirm that is
/// intended rather than an error.
fn index_wire(k: IndexKind) -> &'static str {
    match k {
        IndexKind::Colbert => "colbert",
        IndexKind::DiskVamana => "disk_vamana",
        IndexKind::Ivf => "ivf",
        IndexKind::Vamana => "vamana",
        IndexKind::Hnsw => "hnsw",
        _ => "hnsw",
    }
}
/// Spelling of a [`VectorEncryption`] mode as sent in a shard's collection-creation
/// request.
fn encryption_wire(e: VectorEncryption) -> &'static str {
    match e {
        VectorEncryption::ClientSide => "client_side",
        VectorEncryption::Dcpe => "dcpe",
        VectorEncryption::None => "none",
    }
}
/// Spelling of a filterable field's type as sent in a shard's collection-creation
/// request.
///
/// Types other than `Numeric` and `Keyword` are sent as `"keyword"`.
/// NOTE(review): confirm that fallback is acceptable for any additional field types.
fn field_type_wire(t: quiver_embed::FieldType) -> &'static str {
    match t {
        quiver_embed::FieldType::Numeric => "numeric",
        quiver_embed::FieldType::Keyword => "keyword",
        _ => "keyword",
    }
}
/// Converts a shard search response (`{"matches": [...]}`) into [`MatchOut`]s.
///
/// A missing or non-array `matches` field yields an empty list. Entries without a
/// string `id` or a numeric `score` are skipped. The vector is only extracted when
/// `with_vector` is set, and non-numeric components are dropped from it.
fn matches_from_json(resp: &Value, with_vector: bool) -> Vec<MatchOut> {
    let entries = match resp.get("matches").and_then(Value::as_array) {
        Some(entries) => entries,
        None => return Vec::new(),
    };
    entries
        .iter()
        .filter_map(|entry| {
            let id = entry.get("id")?.as_str()?.to_owned();
            let score = entry.get("score")?.as_f64()? as f32;
            let payload = entry.get("payload").cloned();
            let vector = if with_vector {
                entry
                    .get("vector")
                    .and_then(Value::as_array)
                    .map(|components| {
                        components
                            .iter()
                            .filter_map(Value::as_f64)
                            .map(|c| c as f32)
                            .collect()
                    })
            } else {
                None
            };
            Some(MatchOut {
                id,
                score,
                payload,
                vector,
            })
        })
        .collect()
}
/// Converts a shard's single-point response into a [`PointOut`].
///
/// Returns `None` when the response has no string `id`. A missing `payload` becomes
/// `Value::Null`. The vector is only extracted when `with_vector` is set, and
/// non-numeric components are dropped from it.
fn point_from_json(resp: &Value, with_vector: bool) -> Option<PointOut> {
    let id = resp.get("id")?.as_str()?.to_owned();
    let payload = resp.get("payload").cloned().unwrap_or(Value::Null);
    let vector = with_vector
        .then(|| resp.get("vector").and_then(Value::as_array))
        .flatten()
        .map(|components| {
            components
                .iter()
                .filter_map(Value::as_f64)
                .map(|c| c as f32)
                .collect()
        });
    Some(PointOut {
        id,
        payload,
        vector,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two primary shard URLs shared by every test.
    fn shards() -> Vec<String> {
        ["http://s0:6333", "http://s1:6333"]
            .iter()
            .map(|&url| url.to_owned())
            .collect()
    }

    #[test]
    fn new_accepts_well_formed_replicas() {
        let replicas: Vec<String> = vec!["0=http://s0b:6333".into(), "1=http://s1b:6333".into()];
        let cluster = Cluster::new(shards(), replicas, None).unwrap();
        assert_eq!(cluster.shard_count(), 2);

        let loaded = cluster.map.load();
        assert_eq!(loaded.shards()[0].replica_urls, ["http://s0b:6333"]);
        assert_eq!(loaded.shards()[1].replica_urls, ["http://s1b:6333"]);

        // Running without any replicas is also a valid configuration.
        assert!(Cluster::new(shards(), Vec::new(), None).is_ok());
    }

    #[test]
    fn new_rejects_malformed_replica_specs() {
        // Missing '=', a non-numeric index, and an index with no matching shard.
        for spec in ["http://nope", "x=http://nope", "9=http://nope"] {
            match Cluster::new(shards(), vec![spec.to_owned()], None) {
                Err(Error::Config(_)) => {}
                Err(e) => panic!("expected a Config error, got {e:?}"),
                Ok(_) => panic!("expected a Config error for {spec:?}, built a router"),
            }
        }
    }
}