use std::collections::HashMap;
use arc_swap::ArcSwap;
use quiver_cluster::{ShardMap, merge_top_k};
use quiver_embed::{
DistanceMetric, Filter, FilterableField, IndexKind, IndexSpec, VectorEncryption,
};
use serde_json::{Value, json};
use tokio::sync::RwLock;
use crate::error::Error;
use crate::{CollectionInfo, MatchOut, PointIn, PointOut};
/// Handle to a set of shards that are reached over HTTP.
///
/// Every request method loads the current shard map, builds a URL from a
/// shard's `url`, and talks to that shard through the shared HTTP client.
pub(crate) struct Cluster {
/// Current shard map; each routing method reads it with `load()`.
map: ArcSwap<ShardMap>,
/// HTTP client used for every shard request.
http: reqwest::Client,
/// When set, attached to each shard request as a bearer token.
shard_key: Option<String>,
/// Cached "higher is better" flag per collection name; `search` passes it
/// to `merge_top_k` when merging per-shard results.
ordering: RwLock<HashMap<String, bool>>,
}
impl Cluster {
pub(crate) fn new(shard_urls: Vec<String>, shard_key: Option<String>) -> Result<Self, Error> {
let map = ShardMap::from_urls(shard_urls).map_err(|e| Error::Config(e.to_string()))?;
Ok(Self {
map: ArcSwap::from_pointee(map),
http: reqwest::Client::new(),
shard_key,
ordering: RwLock::new(HashMap::new()),
})
}
/// Returns the length of the currently loaded shard map.
pub(crate) fn shard_count(&self) -> usize {
    let current = self.map.load();
    current.len()
}
/// Adds the configured shard key to `rb` as a bearer token.
///
/// Returns the builder untouched when no shard key is configured.
fn auth(&self, rb: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
    if let Some(key) = self.shard_key.as_ref() {
        rb.bearer_auth(key)
    } else {
        rb
    }
}
async fn send(
&self,
method: reqwest::Method,
url: String,
body: Option<Value>,
) -> Result<Value, Error> {
let mut rb = self.http.request(method, &url);
if let Some(b) = body {
rb = rb.json(&b);
}
let resp = self
.auth(rb)
.send()
.await
.map_err(|e| Error::Internal(format!("shard {url} unreachable: {e}")))?;
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
if !status.is_success() {
return Err(Error::Internal(format!(
"shard {url} returned {status}: {text}"
)));
}
if text.is_empty() {
return Ok(Value::Null);
}
serde_json::from_str(&text)
.map_err(|e| Error::Internal(format!("shard {url} bad response: {e}")))
}
/// Sends the same request to every shard, one after another.
///
/// The URL for each shard is the shard's `url` followed by `path`.
/// Returns the reply of the last shard contacted, or `Value::Null` when the
/// shard map yields no shards.
///
/// # Errors
/// Stops at the first shard whose request fails and returns that error;
/// shards later in the map are not contacted.
async fn broadcast(
    &self,
    method: reqwest::Method,
    path: &str,
    body: Option<Value>,
) -> Result<Value, Error> {
    let shard_map = self.map.load();
    let mut reply = Value::Null;
    for shard in shard_map.shards() {
        let url = format!("{}{path}", shard.url);
        let request = self.send(method.clone(), url, body.clone());
        reply = request.await?;
    }
    Ok(reply)
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn create_collection(
&self,
name: String,
dim: u32,
metric: DistanceMetric,
index: IndexSpec,
filterable: Vec<FilterableField>,
multivector: bool,
vector_encryption: VectorEncryption,
) -> Result<CollectionInfo, Error> {
let mut body = json!({
"name": name,
"dim": dim,
"metric": metric_wire(metric),
"index": index_wire(index.kind),
"multivector": multivector,
"vector_encryption": encryption_wire(vector_encryption),
});
if let Some(pq) = index.pq_subspaces {
body["pq_subspaces"] = json!(pq);
}
if !filterable.is_empty() {
body["filterable"] = json!(
filterable
.iter()
.map(|f| json!({ "path": f.path, "type": field_type_wire(f.field_type) }))
.collect::<Vec<_>>()
);
}
self.broadcast(reqwest::Method::POST, "/v1/collections", Some(body))
.await?;
self.ordering
.write()
.await
.insert(name.clone(), higher_is_better(metric));
Ok(CollectionInfo {
name,
dim,
metric,
count: 0,
index,
filterable,
multivector,
vector_encryption,
})
}
/// Deletes a collection on every shard and forgets its cached ordering.
///
/// Issues `DELETE /v1/collections/{name}` through [`Self::broadcast`].
/// Always returns `Ok(true)` once every shard has answered successfully.
///
/// # Errors
/// Returns the first error reported by [`Self::broadcast`]; the ordering
/// cache entry is kept in that case.
pub(crate) async fn drop_collection(&self, name: &str) -> Result<bool, Error> {
    let path = format!("/v1/collections/{name}");
    self.broadcast(reqwest::Method::DELETE, &path, None).await?;
    let mut cache = self.ordering.write().await;
    cache.remove(name);
    Ok(true)
}
/// Upserts `points` into `collection` through each shard's `points` endpoint.
///
/// Delegates to [`Self::upsert_to`] and returns its total.
pub(crate) async fn upsert(
    &self,
    collection: &str,
    points: Vec<PointIn>,
) -> Result<u64, Error> {
    let endpoint = "points";
    self.upsert_to(collection, points, endpoint).await
}
/// Upserts `points` into `collection` through each shard's `points:bulk`
/// endpoint.
///
/// Delegates to [`Self::upsert_to`] and returns its total.
pub(crate) async fn upsert_bulk(
    &self,
    collection: &str,
    points: Vec<PointIn>,
) -> Result<u64, Error> {
    let endpoint = "points:bulk";
    self.upsert_to(collection, points, endpoint).await
}
/// Splits `points` by shard and posts each group to that shard.
///
/// Points are grouped with `ShardMap::partition`, keyed by point id. Each
/// group is posted as `{"points": [...]}` to
/// `{shard}/v1/collections/{collection}/{endpoint}`, one shard at a time.
///
/// Returns the sum of the `upserted` counters in the shard replies; a reply
/// without an unsigned `upserted` field contributes 0.
///
/// # Errors
/// Returns the first shard error; groups not yet sent are skipped.
async fn upsert_to(
    &self,
    collection: &str,
    points: Vec<PointIn>,
    endpoint: &str,
) -> Result<u64, Error> {
    let shard_map = self.map.load();
    let mut upserted = 0u64;
    for (shard_idx, batch) in shard_map.partition(&points, |p| p.id.as_str()) {
        let mut dtos: Vec<Value> = Vec::new();
        for p in batch.iter() {
            dtos.push(json!({ "id": p.id, "vector": p.vector, "payload": p.payload }));
        }
        let url = format!(
            "{}/v1/collections/{collection}/{endpoint}",
            shard_map.shards()[shard_idx].url
        );
        let payload = json!({ "points": dtos });
        let reply = self.send(reqwest::Method::POST, url, Some(payload)).await?;
        upserted += reply
            .get("upserted")
            .and_then(Value::as_u64)
            .unwrap_or_default();
    }
    Ok(upserted)
}
/// Deletes the points with the given ids from `collection`.
///
/// Ids are grouped with `ShardMap::partition`. Each group is sent as
/// `{"ids": [...]}` in a `DELETE` to
/// `{shard}/v1/collections/{collection}/points`, one shard at a time.
///
/// Returns the sum of the `deleted` counters in the shard replies; a reply
/// without an unsigned `deleted` field contributes 0.
///
/// # Errors
/// Returns the first shard error; groups not yet sent are skipped.
pub(crate) async fn delete_points(
    &self,
    collection: &str,
    ids: Vec<String>,
) -> Result<u64, Error> {
    let shard_map = self.map.load();
    let mut deleted = 0u64;
    for (shard_idx, group) in shard_map.partition(&ids, |id| id.as_str()) {
        let shard_ids: Vec<&String> = group;
        let url = format!(
            "{}/v1/collections/{collection}/points",
            shard_map.shards()[shard_idx].url
        );
        let payload = json!({ "ids": shard_ids });
        let reply = self
            .send(reqwest::Method::DELETE, url, Some(payload))
            .await?;
        deleted += reply
            .get("deleted")
            .and_then(Value::as_u64)
            .unwrap_or_default();
    }
    Ok(deleted)
}
pub(crate) async fn get_points(
&self,
collection: &str,
ids: Vec<String>,
with_vector: bool,
) -> Result<Vec<PointOut>, Error> {
let map = self.map.load();
let mut out = Vec::new();
for id in &ids {
let shard = map.shard_for(id);
let url = format!("{}/v1/collections/{collection}/points/{id}", shard.url);
let resp = match self.send(reqwest::Method::GET, url, None).await {
Ok(v) => v,
Err(Error::Internal(msg)) if msg.contains("404") => continue,
Err(e) => return Err(e),
};
if let Some(p) = point_from_json(&resp, with_vector) {
out.push(p);
}
}
Ok(out)
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn search(
&self,
collection: &str,
vector: Vec<f32>,
k: usize,
filter: Option<Filter>,
ef_search: usize,
with_payload: bool,
with_vector: bool,
) -> Result<Vec<MatchOut>, Error> {
let higher = self.higher_is_better(collection).await?;
let mut body = json!({
"vector": vector,
"k": k,
"ef_search": ef_search,
"with_payload": with_payload,
"with_vector": with_vector,
});
if let Some(f) = &filter {
body["filter"] =
serde_json::to_value(f).map_err(|e| Error::BadRequest(e.to_string()))?;
}
let map = self.map.load();
let mut per_shard: Vec<Vec<MatchOut>> = Vec::with_capacity(map.len());
for shard in map.shards() {
let url = format!("{}/v1/collections/{collection}/query", shard.url);
let resp = self
.send(reqwest::Method::POST, url, Some(body.clone()))
.await?;
per_shard.push(matches_from_json(&resp, with_vector));
}
Ok(merge_top_k(per_shard, k, |m| m.score, higher))
}
/// Resolves whether higher scores are better for `collection`.
///
/// Answers from the ordering cache when possible. Otherwise asks the first
/// shard for `GET /v1/collections/{collection}` and reads the `metric` string
/// of the reply: `"cosine"` and `"dot"` yield `true`, any other value `false`.
/// An answer derived from a reported metric is cached.
///
/// When the reply carries no string `metric`, the result falls back to
/// `false` (the `"l2"` ordering) but is deliberately NOT cached, so a single
/// malformed reply cannot pin a guessed ordering for this collection.
///
/// # Errors
/// Returns [`Error::Internal`] with `"no shards"` when the shard map is
/// empty, or any error from the shard request.
async fn higher_is_better(&self, collection: &str) -> Result<bool, Error> {
    let cached = self.ordering.read().await.get(collection).copied();
    if let Some(h) = cached {
        return Ok(h);
    }
    let map = self.map.load();
    let shard = map
        .shards()
        .first()
        .ok_or_else(|| Error::Internal("no shards".into()))?;
    let url = format!("{}/v1/collections/{collection}", shard.url);
    let info = self.send(reqwest::Method::GET, url, None).await?;
    match info.get("metric").and_then(Value::as_str) {
        Some(metric) => {
            let higher = matches!(metric, "cosine" | "dot");
            self.ordering
                .write()
                .await
                .insert(collection.to_owned(), higher);
            Ok(higher)
        }
        // Metric not reported: keep the historical "l2" default for this
        // call only and retry the lookup on the next search.
        None => Ok(false),
    }
}
}
fn metric_wire(m: DistanceMetric) -> &'static str {
match m {
DistanceMetric::L2 => "l2",
DistanceMetric::Cosine => "cosine",
DistanceMetric::Dot => "dot",
}
}
/// Returns `false` for `L2` and `true` for every other metric.
fn higher_is_better(m: DistanceMetric) -> bool {
    match m {
        DistanceMetric::L2 => false,
        DistanceMetric::Cosine | DistanceMetric::Dot => true,
    }
}
fn index_wire(k: IndexKind) -> &'static str {
match k {
IndexKind::Hnsw => "hnsw",
IndexKind::Vamana => "vamana",
IndexKind::DiskVamana => "disk_vamana",
IndexKind::Ivf => "ivf",
IndexKind::Colbert => "colbert",
_ => "hnsw",
}
}
fn encryption_wire(e: VectorEncryption) -> &'static str {
match e {
VectorEncryption::None => "none",
VectorEncryption::Dcpe => "dcpe",
VectorEncryption::ClientSide => "client_side",
}
}
fn field_type_wire(t: quiver_embed::FieldType) -> &'static str {
match t {
quiver_embed::FieldType::Keyword => "keyword",
quiver_embed::FieldType::Numeric => "numeric",
_ => "keyword",
}
}
/// Decodes the `matches` array of a shard query reply.
///
/// Returns an empty list when `matches` is absent or not an array. An entry
/// is dropped when it lacks a string `id` or a numeric `score`. `payload` is
/// copied as-is when present. `vector` is decoded only when `with_vector` is
/// set and the entry has an array `vector`; non-numeric elements of that
/// array are left out.
fn matches_from_json(resp: &Value, with_vector: bool) -> Vec<MatchOut> {
    let entries = match resp.get("matches").and_then(Value::as_array) {
        Some(list) => list,
        None => return Vec::new(),
    };
    let mut found = Vec::new();
    for entry in entries {
        let id = match entry.get("id").and_then(Value::as_str) {
            Some(s) => s.to_owned(),
            None => continue,
        };
        let score = match entry.get("score").and_then(Value::as_f64) {
            Some(s) => s as f32,
            None => continue,
        };
        let vector = match (with_vector, entry.get("vector").and_then(Value::as_array)) {
            (true, Some(items)) => Some(
                items
                    .iter()
                    .filter_map(Value::as_f64)
                    .map(|f| f as f32)
                    .collect(),
            ),
            _ => None,
        };
        found.push(MatchOut {
            id,
            score,
            payload: entry.get("payload").cloned(),
            vector,
        });
    }
    found
}
/// Decodes a single-point shard reply.
///
/// Returns `None` when the reply has no string `id`. A missing `payload`
/// becomes `Value::Null`. `vector` is decoded only when `with_vector` is set
/// and the reply has an array `vector`; non-numeric elements of that array
/// are left out.
fn point_from_json(resp: &Value, with_vector: bool) -> Option<PointOut> {
    let id = resp.get("id").and_then(Value::as_str)?.to_owned();
    let payload = match resp.get("payload") {
        Some(p) => p.clone(),
        None => Value::Null,
    };
    let vector = match (with_vector, resp.get("vector").and_then(Value::as_array)) {
        (true, Some(items)) => Some(
            items
                .iter()
                .filter_map(Value::as_f64)
                .map(|f| f as f32)
                .collect(),
        ),
        _ => None,
    };
    Some(PointOut {
        id,
        payload,
        vector,
    })
}