use parking_lot::RwLock;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::Duration;
/// Sleep between two compaction attempts in a shard's background compaction loop.
const COMPACTION_INTERVAL: Duration = Duration::from_secs(60);
/// Sleep between two snapshot writes in a shard's background snapshot loop.
const SNAPSHOT_INTERVAL: Duration = Duration::from_secs(300);
use bytes::Bytes;
use skeg_core::{Durability, VLog};
use skeg_vector::{DiskVamanaIndex, FlatIndex, QuantKind};
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::oneshot;
use tracing::error;
use xxhash_rust::xxh3::xxh3_64;
/// The vector indexes owned by one shard, keyed by index name.
type VindexSet = HashMap<String, VectorBackend>;
/// Third argument handed to `DiskVamanaIndex::create_empty` for new disk indexes.
// NOTE(review): presumably the Vamana search-list size (L) — confirm against skeg_vector.
const VAMANA_L_SEARCH: usize = 100;
/// Floor for the delta size that triggers `consolidate()` on a disk index; the
/// effective threshold is `max(main_len / 20, DISK_CONSOLIDATE_MIN)`.
const DISK_CONSOLIDATE_MIN: usize = 4096;
/// Storage engine behind one named vector index.
enum VectorBackend {
/// Index backed by `FlatIndex`; never written to the on-disk vindex registry.
Flat(FlatIndex),
/// Index backed by `DiskVamanaIndex`; listed in the registry and reopened at shard start.
Disk(DiskVamanaIndex),
}
impl VectorBackend {
    /// Dimensionality reported by the wrapped index.
    fn dim(&self) -> usize {
        match self {
            Self::Flat(flat) => flat.dim(),
            Self::Disk(disk) => disk.dim(),
        }
    }

    /// Number of vectors reported by the wrapped index.
    fn len(&self) -> usize {
        match self {
            Self::Flat(flat) => flat.len(),
            Self::Disk(disk) => disk.len(),
        }
    }

    /// Wire code reported in the "kind" column of a vindex listing: 0 for
    /// `Flat`, 1 for `Disk`.
    // NOTE(review): this returns the same values as `backend_byte`; confirm
    // whether the quantisation kind was meant to be reported here instead.
    fn kind_byte(&self) -> u8 {
        u8::from(matches!(self, Self::Disk(_)))
    }

    /// Wire code reported in the "backend" column of a vindex listing: 0 for
    /// `Flat`, 1 for `Disk`.
    fn backend_byte(&self) -> u8 {
        u8::from(matches!(self, Self::Disk(_)))
    }

    /// Stores `vector` under `id`.
    ///
    /// For disk indexes the delta is consolidated once it holds at least
    /// `max(main_len / 20, DISK_CONSOLIDATE_MIN)` entries.
    ///
    /// # Errors
    /// Propagates I/O errors from the disk index insert or consolidation.
    fn insert(&mut self, id: u64, vector: &[f32]) -> std::io::Result<()> {
        let disk = match self {
            Self::Flat(flat) => {
                flat.insert(id, vector);
                return Ok(());
            }
            Self::Disk(disk) => disk,
        };
        disk.insert(id, vector)?;
        let threshold = std::cmp::max(disk.main_len() / 20, DISK_CONSOLIDATE_MIN);
        if disk.delta_len() >= threshold {
            disk.consolidate()?;
        }
        Ok(())
    }

    /// Removes `id`, forwarding the wrapped index's boolean result.
    ///
    /// # Errors
    /// Propagates I/O errors from the disk index.
    fn delete(&mut self, id: u64) -> std::io::Result<bool> {
        match self {
            Self::Disk(disk) => disk.delete(id),
            Self::Flat(flat) => Ok(flat.delete(id)),
        }
    }

    /// Fetches the vector stored under `id`, if any.
    ///
    /// # Errors
    /// Propagates I/O errors from the disk index.
    fn get(&self, id: u64) -> std::io::Result<Option<Vec<f32>>> {
        match self {
            Self::Disk(disk) => disk.get(id),
            Self::Flat(flat) => Ok(flat.get(id)),
        }
    }

    /// Runs a top-`k` search for `query`.
    ///
    /// `l_search` is forwarded to the disk index only; the flat index ignores it.
    ///
    /// # Errors
    /// Propagates I/O errors from the disk index.
    fn search(
        &mut self,
        query: &[f32],
        k: usize,
        l_search: u32,
    ) -> std::io::Result<Vec<(u64, f32)>> {
        match self {
            Self::Disk(disk) => disk.search_with_l(query, k, l_search as usize),
            Self::Flat(flat) => Ok(flat.search(query, k)),
        }
    }
}
/// Capacity of the bounded mpsc channel that feeds each shard thread.
const SHARD_INBOX_CAPACITY: usize = 4096;
/// Number of permits in each shard's in-flight semaphore; one permit is held
/// per request task while it runs.
const MAX_INFLIGHT_PER_SHARD: usize = 1024;
/// Maps `key` to a shard index in `0..n_shards` using the xxh3-64 hash of the
/// key modulo the shard count.
///
/// `n_shards` must be at least 1 (checked with `debug_assert!`).
#[must_use]
// The remainder is always below `n_shards`, which itself is a `usize`.
#[allow(clippy::cast_possible_truncation)]
pub fn shard_for(key: &[u8], n_shards: usize) -> usize {
    debug_assert!(n_shards >= 1);
    let hash = xxh3_64(key);
    let modulus = n_shards as u64;
    (hash % modulus) as usize
}
/// Error returned by the public `ShardSet` operations.
#[derive(Debug, thiserror::Error)]
pub enum ShardError {
/// The request could not be delivered, the reply channel was dropped, or the
/// shard answered with a response variant the caller did not expect.
#[error("shard unavailable")]
Unavailable,
/// The shard reported a failure; the payload is the shard's error message.
/// Also used for argument validation failures in `vindex_create`.
#[error("storage error: {0}")]
Storage(String),
}
/// A request sent to a shard thread.
enum ShardReq {
/// Look up the value stored under a key.
Get(Bytes),
/// Store `(key, value)` with the given durability.
Set(Bytes, Bytes, Durability),
/// Delete a key with the given durability.
Del(Bytes, Durability),
/// Look up several keys; each entry is `(position in the caller's key list, key)`.
MgetBatch(Vec<(usize, Bytes)>),
/// Report the shard's `VLog::cache_stats()` values.
Stats,
/// Create a vector index on this shard.
VindexCreate {
name: String,
dim: usize,
// Only used when building a flat index (`FlatIndex::new(dim, kind)`).
kind: QuantKind,
// `true` selects `DiskVamanaIndex`, `false` selects `FlatIndex`.
disk: bool,
},
/// Remove a vector index from this shard.
VindexDrop {
name: String,
},
/// List this shard's vector indexes.
VindexList,
/// Insert a vector under `id` into the named index.
Vset {
name: String,
id: u64,
vector: Vec<f32>,
},
/// Fetch the vector stored under `id` in the named index.
Vget {
name: String,
id: u64,
},
/// Delete `id` from the named index.
Vdel {
name: String,
id: u64,
},
/// Top-`k` search in the named index; `l_search` is forwarded to disk indexes.
Vsearch {
name: String,
query: Vec<f32>,
k: usize,
l_search: u32,
},
}
/// A shard thread's answer to one `ShardReq`.
enum ShardResp {
/// Answer to `Get`: the value, or `None` when `VLog::get` returned none.
Value(Option<Bytes>),
/// Success with no payload (`Set`, `VindexCreate`, `VindexDrop`, `Vset`).
Done,
/// Answer to `Del` / `Vdel`: the boolean returned by the underlying delete.
Existed(bool),
/// Answer to `MgetBatch`: `(position from the request, value)` pairs.
MgetBatch(Vec<(usize, Option<Bytes>)>),
/// Answer to `Stats`: `(bytes, evictions, n_keys, budget)` as returned by `VLog::cache_stats()`.
Stats(u64, u64, u64, u64),
/// Answer to `VindexList`: `(name, dim, kind byte, backend byte, vector count)` rows sorted by name.
VindexList(Vec<(String, u32, u8, u8, u64)>),
/// Answer to `Vget`.
Vector(Option<Vec<f32>>),
/// Answer to `Vsearch`: `(id, score)` pairs as returned by the index.
Vsearch(Vec<(u64, f32)>),
/// The request failed; the payload is a human-readable message.
Err(String),
}
/// Envelope placed on a shard's inbox: the request plus the one-shot channel
/// on which the shard sends its response.
struct ShardMsg {
req: ShardReq,
reply: oneshot::Sender<ShardResp>,
}
/// File name, inside a shard directory, of the registry listing disk-backed vector indexes.
const VINDEX_REGISTRY: &str = "vindexes.registry";
/// Serialises `entries` (index name, dimension) to the vindex registry in `dir`.
///
/// Layout (all integers little-endian): `u32` entry count, then per entry a
/// `u16` name length, the name bytes and a `u32` dimension.
///
/// The data is written to `<registry>.tmp`, synced, and then renamed over the
/// registry so a reader never observes a partially written file.
///
/// # Errors
/// Returns `ErrorKind::InvalidInput` — before touching the filesystem — when
/// the entry count, a name length or a dimension does not fit its on-disk
/// field (these used to be truncated silently, producing an unreadable
/// registry). Otherwise propagates I/O errors from writing, syncing or
/// renaming the file.
fn write_registry(dir: &Path, entries: &[(&str, usize)]) -> std::io::Result<()> {
    use std::io::Write;

    let invalid =
        |what: &str| std::io::Error::new(std::io::ErrorKind::InvalidInput, what.to_owned());

    let count = u32::try_from(entries.len())
        .map_err(|_| invalid("too many vindex registry entries"))?;
    let payload: usize = entries.iter().map(|(name, _)| 2 + name.len() + 4).sum();
    let mut buf = Vec::with_capacity(4 + payload);
    buf.extend_from_slice(&count.to_le_bytes());
    for (name, dim) in entries {
        let name_len = u16::try_from(name.len())
            .map_err(|_| invalid("vindex name too long for the registry"))?;
        let dim = u32::try_from(*dim)
            .map_err(|_| invalid("vindex dimension too large for the registry"))?;
        buf.extend_from_slice(&name_len.to_le_bytes());
        buf.extend_from_slice(name.as_bytes());
        buf.extend_from_slice(&dim.to_le_bytes());
    }

    let tmp = dir.join(format!("{VINDEX_REGISTRY}.tmp"));
    {
        let mut file = std::fs::File::create(&tmp)?;
        file.write_all(&buf)?;
        // Make the bytes durable before the rename publishes them.
        file.sync_all()?;
    }
    std::fs::rename(&tmp, dir.join(VINDEX_REGISTRY))
}
/// Parses the vindex registry in `dir` into `(name, dimension)` pairs.
///
/// Expected layout (little-endian): `u32` entry count, then per entry a `u16`
/// name length, the name bytes and a `u32` dimension.
///
/// The parser is deliberately forgiving: a missing or unreadable file, or a
/// file shorter than the 4-byte header, yields an empty list, and parsing
/// stops at the first truncated entry, returning what was read so far.
/// Invalid UTF-8 in a name is replaced lossily.
fn read_registry(dir: &Path) -> Vec<(String, usize)> {
    /// Smallest possible entry: a `u16` length, an empty name and a `u32` dimension.
    const MIN_ENTRY_LEN: usize = 2 + 4;

    let Ok(bytes) = std::fs::read(dir.join(VINDEX_REGISTRY)) else {
        return Vec::new();
    };
    let [c0, c1, c2, c3, body @ ..] = bytes.as_slice() else {
        return Vec::new();
    };
    let count = u32::from_le_bytes([*c0, *c1, *c2, *c3]) as usize;

    // The header is untrusted: never reserve room for more entries than the
    // file could possibly contain, so a corrupt count cannot trigger a huge
    // allocation.
    let mut out = Vec::with_capacity(count.min(body.len() / MIN_ENTRY_LEN));
    let mut rest = body;
    for _ in 0..count {
        let [lo, hi, tail @ ..] = rest else {
            break;
        };
        let name_len = usize::from(u16::from_le_bytes([*lo, *hi]));
        if tail.len() < name_len + 4 {
            break;
        }
        let (name_bytes, tail) = tail.split_at(name_len);
        let (dim_bytes, tail) = tail.split_at(4);
        let dim =
            u32::from_le_bytes([dim_bytes[0], dim_bytes[1], dim_bytes[2], dim_bytes[3]]) as usize;
        out.push((String::from_utf8_lossy(name_bytes).into_owned(), dim));
        rest = tail;
    }
    out
}
fn persist_registry(dir: &Path, vindexes: &RwLock<VindexSet>) {
let vs = vindexes.read();
let entries: Vec<(&str, usize)> = vs
.iter()
.filter_map(|(name, b)| match b {
VectorBackend::Disk(i) => Some((name.as_str(), i.dim())),
VectorBackend::Flat(_) => None,
})
.collect();
if let Err(e) = write_registry(dir, &entries) {
error!("vindex registry write failed: {e}");
}
}
fn recover_vindexes(
shard_id: usize,
dir: &Path,
tier: QuantKind,
mmap_tier: bool,
mmap_graph: bool,
) -> VindexSet {
let mut set = VindexSet::new();
for (name, _dim) in read_registry(dir) {
let vdir = dir.join(format!("vindex-{name}"));
match DiskVamanaIndex::open_with_tier_full(&vdir, tier, mmap_tier, mmap_graph) {
Ok(idx) => {
set.insert(name, VectorBackend::Disk(idx));
}
Err(e) => error!("shard {shard_id}: recovering vindex '{name}' failed: {e}"),
}
}
set
}
#[allow(clippy::needless_pass_by_value)]
fn run_shard(
shard_id: usize,
dir: PathBuf,
mut rx: Receiver<ShardMsg>,
read_only: bool,
tier: QuantKind,
workers: usize,
mmap_tier: bool,
mmap_graph: bool,
) {
skeg_platform::pin_current_thread_to_performance_core();
let rt = match tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
{
Ok(rt) => rt,
Err(e) => {
error!("shard {shard_id}: runtime build failed: {e}");
return;
}
};
rt.block_on(async move {
let vlog = match VLog::open(&dir).await {
Ok(v) => v,
Err(e) => {
error!("shard {shard_id}: VLog::open failed: {e}");
while let Some(msg) = rx.recv().await {
let _ = msg
.reply
.send(ShardResp::Err("shard storage unavailable".to_owned()));
}
return;
}
};
let inflight = std::sync::Arc::new(tokio::sync::Semaphore::new(MAX_INFLIGHT_PER_SHARD));
let vindexes: Arc<RwLock<VindexSet>> = Arc::new(RwLock::new(recover_vindexes(
shard_id, &dir, tier, mmap_tier, mmap_graph,
)));
let local = tokio::task::LocalSet::new();
local
.run_until(async {
if !read_only {
let cvlog = vlog.clone();
tokio::task::spawn_local(async move {
loop {
tokio::time::sleep(COMPACTION_INTERVAL).await;
match cvlog.maybe_compact().await {
Ok(Some(_seg_id)) => {
skeg_telemetry::tick_counter(
skeg_telemetry::Counter::CompactionRunsTotal,
);
}
Ok(None) => { }
Err(e) => {
error!("shard {shard_id}: compaction failed: {e}");
}
}
}
});
let svlog = vlog.clone();
tokio::task::spawn_local(async move {
loop {
tokio::time::sleep(SNAPSHOT_INTERVAL).await;
if let Err(e) = svlog.write_snapshot().await {
error!("shard {shard_id}: snapshot failed: {e}");
}
}
});
}
while let Some(msg) = rx.recv().await {
let permit = inflight
.clone()
.acquire_owned()
.await
.expect("inflight semaphore is never closed");
let vlog = vlog.clone();
let vindexes = vindexes.clone();
let dir = dir.clone();
let shard_id_u16 = shard_id as u16;
tokio::task::spawn_local(async move {
let op_kind = telemetry_op(&msg.req);
let t0 = std::time::Instant::now();
let resp =
process(&vlog, &vindexes, &dir, msg.req, read_only, workers).await;
if let Some(op) = op_kind {
skeg_telemetry::record_op(op, shard_id_u16, t0.elapsed());
}
let _ = msg.reply.send(resp);
drop(permit); });
}
let _ = vlog.flush().await;
})
.await;
});
}
/// Reports whether `req` modifies shard state (key/value writes, vector
/// writes, or index creation/removal). Such requests are rejected in
/// read-only mode.
fn is_mutation(req: &ShardReq) -> bool {
    match req {
        ShardReq::Set(..)
        | ShardReq::Del(..)
        | ShardReq::VindexCreate { .. }
        | ShardReq::VindexDrop { .. }
        | ShardReq::Vset { .. }
        | ShardReq::Vdel { .. } => true,
        ShardReq::Get(..)
        | ShardReq::MgetBatch(..)
        | ShardReq::Stats
        | ShardReq::VindexList
        | ShardReq::Vget { .. }
        | ShardReq::Vsearch { .. } => false,
    }
}
/// Picks the telemetry operation label for `req`, or `None` for request kinds
/// that are not timed (`Vget`, index management and `Stats`).
///
/// `MgetBatch` is recorded under the same label as `Get`.
#[allow(clippy::too_many_lines)]
#[inline]
fn telemetry_op(req: &ShardReq) -> Option<skeg_telemetry::Op> {
    use skeg_telemetry::Op;
    let op = match req {
        ShardReq::Vget { .. }
        | ShardReq::VindexCreate { .. }
        | ShardReq::VindexList
        | ShardReq::VindexDrop { .. }
        | ShardReq::Stats => return None,
        ShardReq::Get(_) | ShardReq::MgetBatch(_) => Op::Get,
        ShardReq::Set(..) => Op::Set,
        ShardReq::Del(..) => Op::Del,
        ShardReq::Vset { .. } => Op::VSet,
        ShardReq::Vdel { .. } => Op::VDel,
        ShardReq::Vsearch { .. } => Op::VSearch,
    };
    Some(op)
}
async fn process(
vlog: &VLog,
vindexes: &Arc<RwLock<VindexSet>>,
dir: &Path,
req: ShardReq,
read_only: bool,
workers: usize,
) -> ShardResp {
if read_only && is_mutation(&req) {
return ShardResp::Err("server is in serve mode (read-only)".to_owned());
}
if workers > 0
&& let ShardReq::Vsearch {
name,
query,
k,
l_search,
} = req
{
let vindexes_arc = Arc::clone(vindexes);
let join = tokio::task::spawn_blocking(move || -> ShardResp {
let mut vs = vindexes_arc.write();
match vs.get_mut(&name) {
None => ShardResp::Err(format!("vindex '{name}' not found")),
Some(idx) if idx.dim() != query.len() => ShardResp::Err(format!(
"vindex '{name}' dim {} but query has {}",
idx.dim(),
query.len()
)),
Some(idx) => match idx.search(&query, k, l_search) {
Ok(hits) => ShardResp::Vsearch(hits),
Err(e) => ShardResp::Err(format!("vsearch failed: {e}")),
},
}
});
return match join.await {
Ok(resp) => resp,
Err(_) => ShardResp::Err("vsearch worker task failed".to_owned()),
};
}
let vindexes: &RwLock<VindexSet> = vindexes;
match req {
ShardReq::VindexCreate {
name,
dim,
kind,
disk,
} => {
use std::collections::hash_map::Entry;
let result = match vindexes.write().entry(name) {
Entry::Occupied(e) => Err(format!("vindex '{}' already exists", e.key())),
Entry::Vacant(e) => {
if disk {
let vdir = dir.join(format!("vindex-{}", e.key()));
match DiskVamanaIndex::create_empty(&vdir, dim, VAMANA_L_SEARCH) {
Ok(idx) => {
e.insert(VectorBackend::Disk(idx));
Ok(true)
}
Err(err) => Err(format!("vindex disk create failed: {err}")),
}
} else {
e.insert(VectorBackend::Flat(FlatIndex::new(dim, kind)));
Ok(false)
}
}
};
match result {
Ok(created_disk) => {
if created_disk {
persist_registry(dir, vindexes);
}
ShardResp::Done
}
Err(e) => ShardResp::Err(e),
}
}
ShardReq::VindexList => {
let vs = vindexes.read();
let mut rows: Vec<(String, u32, u8, u8, u64)> = vs
.iter()
.map(|(name, backend)| {
(
name.clone(),
backend.dim() as u32,
backend.kind_byte(),
backend.backend_byte(),
backend.len() as u64,
)
})
.collect();
rows.sort_by(|a, b| a.0.cmp(&b.0));
ShardResp::VindexList(rows)
}
ShardReq::VindexDrop { name } => {
let removed = vindexes.write().remove(&name);
match removed {
Some(backend) => {
let was_disk = matches!(backend, VectorBackend::Disk(_));
drop(backend); if was_disk {
let _ = std::fs::remove_dir_all(dir.join(format!("vindex-{name}")));
persist_registry(dir, vindexes);
}
ShardResp::Done
}
None => ShardResp::Err(format!("vindex '{name}' not found")),
}
}
ShardReq::Vset { name, id, vector } => {
let mut vs = vindexes.write();
match vs.get_mut(&name) {
None => ShardResp::Err(format!("vindex '{name}' not found")),
Some(idx) if idx.dim() != vector.len() => ShardResp::Err(format!(
"vindex '{name}' dim {} but vector has {}",
idx.dim(),
vector.len()
)),
Some(idx) => match idx.insert(id, &vector) {
Ok(()) => ShardResp::Done,
Err(e) => ShardResp::Err(format!("vset failed: {e}")),
},
}
}
ShardReq::Vget { name, id } => {
let vs = vindexes.read();
match vs.get(&name) {
None => ShardResp::Err(format!("vindex '{name}' not found")),
Some(idx) => match idx.get(id) {
Ok(v) => ShardResp::Vector(v),
Err(e) => ShardResp::Err(format!("vget failed: {e}")),
},
}
}
ShardReq::Vdel { name, id } => {
let mut vs = vindexes.write();
match vs.get_mut(&name) {
None => ShardResp::Err(format!("vindex '{name}' not found")),
Some(idx) => match idx.delete(id) {
Ok(existed) => ShardResp::Existed(existed),
Err(e) => ShardResp::Err(format!("vdel failed: {e}")),
},
}
}
ShardReq::Vsearch {
name,
query,
k,
l_search,
} => {
let mut vs = vindexes.write();
match vs.get_mut(&name) {
None => ShardResp::Err(format!("vindex '{name}' not found")),
Some(idx) if idx.dim() != query.len() => ShardResp::Err(format!(
"vindex '{name}' dim {} but query has {}",
idx.dim(),
query.len()
)),
Some(idx) => match idx.search(&query, k, l_search) {
Ok(hits) => ShardResp::Vsearch(hits),
Err(e) => ShardResp::Err(format!("vsearch failed: {e}")),
},
}
}
ShardReq::Get(key) => match vlog.get(&key).await {
Ok(v) => ShardResp::Value(v),
Err(e) => ShardResp::Err(e.to_string()),
},
ShardReq::Set(key, val, dur) => match vlog.set(&key, &val, dur).await {
Ok(()) => ShardResp::Done,
Err(e) => ShardResp::Err(e.to_string()),
},
ShardReq::Del(key, dur) => match vlog.del(&key, dur).await {
Ok(b) => ShardResp::Existed(b),
Err(e) => ShardResp::Err(e.to_string()),
},
ShardReq::MgetBatch(items) => {
let mut out = Vec::with_capacity(items.len());
for (idx, key) in items {
match vlog.get(&key).await {
Ok(v) => out.push((idx, v)),
Err(e) => return ShardResp::Err(e.to_string()),
}
}
ShardResp::MgetBatch(out)
}
ShardReq::Stats => {
let (bytes, evictions, n_keys, budget) = vlog.cache_stats();
skeg_telemetry::set_gauge(
skeg_telemetry::Gauge::VlogLiveBytes,
bytes as u64,
);
ShardResp::Stats(bytes, evictions, n_keys, budget)
}
}
}
/// State shared by every clone of a `ShardSet`.
struct ShardSetInner {
// One inbox sender per shard, indexed by shard id.
senders: Vec<Sender<ShardMsg>>,
// Join handles of the shard threads; joined when the last clone is dropped.
handles: Vec<JoinHandle<()>>,
// Number of shards.
n: usize,
}
impl Drop for ShardSetInner {
    /// Drops every inbox sender — which ends each shard's receive loop — and
    /// then blocks until all shard threads have exited. A panicked shard
    /// thread is ignored.
    fn drop(&mut self) {
        drop(std::mem::take(&mut self.senders));
        for handle in std::mem::take(&mut self.handles) {
            let _ = handle.join();
        }
    }
}
/// Handle to a set of shard threads. Cloning shares the same shards through
/// an `Arc`.
#[derive(Clone)]
pub struct ShardSet {
inner: Arc<ShardSetInner>,
}
impl ShardSet {
/// Opens `n_shards` shards under `base_dir` in read-write mode with the
/// `QuantKind::Int8` tier. Delegates to [`ShardSet::open_mode`].
pub fn open(base_dir: &Path, n_shards: usize) -> std::io::Result<Self> {
Self::open_mode(base_dir, n_shards, false, QuantKind::Int8)
}
/// Like [`ShardSet::open`], with an explicit `read_only` flag and `tier`.
/// Passes `workers = 0` to [`ShardSet::open_mode_with_workers`].
pub fn open_mode(
base_dir: &Path,
n_shards: usize,
read_only: bool,
tier: QuantKind,
) -> std::io::Result<Self> {
Self::open_mode_with_workers(base_dir, n_shards, read_only, tier, 0)
}
/// Like [`ShardSet::open_mode`], with an explicit `workers` value (a value
/// greater than zero moves vector searches onto tokio's blocking pool).
/// Passes `mmap_tier = false` to [`ShardSet::open_mode_full`].
pub fn open_mode_with_workers(
base_dir: &Path,
n_shards: usize,
read_only: bool,
tier: QuantKind,
workers: usize,
) -> std::io::Result<Self> {
Self::open_mode_full(base_dir, n_shards, read_only, tier, workers, false)
}
/// Like [`ShardSet::open_mode_with_workers`], with an explicit `mmap_tier`
/// flag. Passes `mmap_graph = false` to [`ShardSet::open_mode_full_mmap`].
pub fn open_mode_full(
base_dir: &Path,
n_shards: usize,
read_only: bool,
tier: QuantKind,
workers: usize,
mmap_tier: bool,
) -> std::io::Result<Self> {
Self::open_mode_full_mmap(
base_dir, n_shards, read_only, tier, workers, mmap_tier, false,
)
}
/// Spawns one OS thread per shard and returns a handle to the set.
///
/// Shard `id` runs `run_shard` on a thread named `skeg-shard-{id}` with the
/// directory `base_dir/shard-{id}` and a bounded inbox of
/// `SHARD_INBOX_CAPACITY` messages.
///
/// # Errors
/// Returns the I/O error from the first thread that fails to spawn.
///
/// # Panics
/// Panics if `n_shards` is zero.
pub fn open_mode_full_mmap(
    base_dir: &Path,
    n_shards: usize,
    read_only: bool,
    tier: QuantKind,
    workers: usize,
    mmap_tier: bool,
    mmap_graph: bool,
) -> std::io::Result<Self> {
    assert!(n_shards >= 1, "n_shards must be >= 1");
    let spawned = (0..n_shards)
        .map(|id| -> std::io::Result<_> {
            let dir = base_dir.join(format!("shard-{id}"));
            let (tx, rx) = tokio::sync::mpsc::channel::<ShardMsg>(SHARD_INBOX_CAPACITY);
            let handle = std::thread::Builder::new()
                .name(format!("skeg-shard-{id}"))
                .spawn(move || {
                    run_shard(id, dir, rx, read_only, tier, workers, mmap_tier, mmap_graph)
                })?;
            Ok((tx, handle))
        })
        .collect::<std::io::Result<Vec<_>>>()?;
    let (senders, handles): (Vec<_>, Vec<_>) = spawned.into_iter().unzip();
    let inner = ShardSetInner {
        senders,
        handles,
        n: n_shards,
    };
    Ok(Self {
        inner: Arc::new(inner),
    })
}
/// Number of shards in this set.
#[must_use]
pub fn n_shards(&self) -> usize {
self.inner.n
}
/// Sends `req` to shard `shard` and waits for its response.
///
/// # Errors
/// `ShardError::Unavailable` when the shard's inbox is closed or the shard
/// drops the reply channel without answering.
///
/// # Panics
/// Panics if `shard` is not a valid shard index.
async fn call(&self, shard: usize, req: ShardReq) -> Result<ShardResp, ShardError> {
    let (reply_tx, reply_rx) = oneshot::channel();
    let envelope = ShardMsg {
        req,
        reply: reply_tx,
    };
    if self.inner.senders[shard].send(envelope).await.is_err() {
        return Err(ShardError::Unavailable);
    }
    match reply_rx.await {
        Ok(resp) => Ok(resp),
        Err(_) => Err(ShardError::Unavailable),
    }
}
/// Looks up `key` on the shard selected by `shard_for`.
///
/// # Errors
/// `ShardError::Storage` with the shard's message when the lookup fails;
/// `ShardError::Unavailable` when the shard cannot be reached or answers
/// with an unexpected response.
pub async fn get(&self, key: &[u8]) -> Result<Option<Bytes>, ShardError> {
    let target = shard_for(key, self.inner.n);
    let request = ShardReq::Get(Bytes::copy_from_slice(key));
    let reply = self.call(target, request).await?;
    match reply {
        ShardResp::Err(msg) => Err(ShardError::Storage(msg)),
        ShardResp::Value(found) => Ok(found),
        _ => Err(ShardError::Unavailable),
    }
}
/// Stores `value` under `key` on the shard selected by `shard_for`, using
/// the requested `durability`.
///
/// # Errors
/// `ShardError::Storage` with the shard's message when the write fails
/// (including read-only mode); `ShardError::Unavailable` when the shard
/// cannot be reached or answers with an unexpected response.
pub async fn set(
    &self,
    key: &[u8],
    value: &[u8],
    durability: Durability,
) -> Result<(), ShardError> {
    let target = shard_for(key, self.inner.n);
    let owned_key = Bytes::copy_from_slice(key);
    let owned_value = Bytes::copy_from_slice(value);
    let reply = self
        .call(target, ShardReq::Set(owned_key, owned_value, durability))
        .await?;
    match reply {
        ShardResp::Err(msg) => Err(ShardError::Storage(msg)),
        ShardResp::Done => Ok(()),
        _ => Err(ShardError::Unavailable),
    }
}
/// Deletes `key` on the shard selected by `shard_for` and returns the
/// boolean reported by the shard's delete.
///
/// # Errors
/// `ShardError::Storage` with the shard's message when the delete fails;
/// `ShardError::Unavailable` when the shard cannot be reached or answers
/// with an unexpected response.
pub async fn del(&self, key: &[u8], durability: Durability) -> Result<bool, ShardError> {
    let target = shard_for(key, self.inner.n);
    let request = ShardReq::Del(Bytes::copy_from_slice(key), durability);
    let reply = self.call(target, request).await?;
    match reply {
        ShardResp::Err(msg) => Err(ShardError::Storage(msg)),
        ShardResp::Existed(existed) => Ok(existed),
        _ => Err(ShardError::Unavailable),
    }
}
/// Looks up many keys at once, returning values in the same order as `keys`.
///
/// Keys are grouped by owning shard; one batch is sent to every shard that
/// owns at least one key, and all batches are dispatched before any reply
/// is awaited.
///
/// # Errors
/// `ShardError::Storage` when any shard reports a failure;
/// `ShardError::Unavailable` when a shard cannot be reached or answers with
/// an unexpected response.
pub async fn mget(&self, keys: &[Bytes]) -> Result<Vec<Option<Bytes>>, ShardError> {
    let shard_count = self.inner.n;
    let mut per_shard: Vec<Vec<(usize, Bytes)>> =
        (0..shard_count).map(|_| Vec::new()).collect();
    for (pos, key) in keys.iter().enumerate() {
        per_shard[shard_for(key, shard_count)].push((pos, key.clone()));
    }

    let mut replies = Vec::new();
    for (shard, batch) in per_shard.into_iter().enumerate() {
        if batch.is_empty() {
            continue;
        }
        let (reply_tx, reply_rx) = oneshot::channel();
        let envelope = ShardMsg {
            req: ShardReq::MgetBatch(batch),
            reply: reply_tx,
        };
        if self.inner.senders[shard].send(envelope).await.is_err() {
            return Err(ShardError::Unavailable);
        }
        replies.push(reply_rx);
    }

    let mut values: Vec<Option<Bytes>> = vec![None; keys.len()];
    for reply in replies {
        let items = match reply.await {
            Ok(ShardResp::MgetBatch(items)) => items,
            Ok(ShardResp::Err(msg)) => return Err(ShardError::Storage(msg)),
            Ok(_) | Err(_) => return Err(ShardError::Unavailable),
        };
        // Each item carries the position of its key in the caller's list.
        for (pos, value) in items {
            values[pos] = value;
        }
    }
    Ok(values)
}
/// Sums the cache statistics of every shard into one `ServerStats`.
///
/// Shards are queried one after another.
///
/// # Errors
/// `ShardError::Storage` when a shard reports a failure;
/// `ShardError::Unavailable` when a shard cannot be reached or answers with
/// an unexpected response.
pub async fn stats(&self) -> Result<skeg_proto::ServerStats, ShardError> {
    let mut totals = skeg_proto::ServerStats::default();
    for shard in 0..self.inner.n {
        let (bytes, evictions, n_keys, budget) =
            match self.call(shard, ShardReq::Stats).await? {
                ShardResp::Stats(bytes, evictions, n_keys, budget) => {
                    (bytes, evictions, n_keys, budget)
                }
                ShardResp::Err(msg) => return Err(ShardError::Storage(msg)),
                _ => return Err(ShardError::Unavailable),
            };
        totals.cache_bytes += bytes;
        totals.cache_evictions += evictions;
        totals.n_keys += n_keys;
        totals.cache_budget += budget;
    }
    Ok(totals)
}
/// Returns one `ShardStats` row per shard, in shard-id order.
///
/// Shards are queried one after another.
///
/// # Errors
/// `ShardError::Storage` when a shard reports a failure;
/// `ShardError::Unavailable` when a shard cannot be reached or answers with
/// an unexpected response.
pub async fn stats_per_shard(&self) -> Result<Vec<skeg_proto::ShardStats>, ShardError> {
    let shard_count = self.inner.n;
    let mut rows = Vec::with_capacity(shard_count);
    for shard in 0..shard_count {
        let row = match self.call(shard, ShardReq::Stats).await? {
            ShardResp::Stats(bytes, evictions, n_keys, budget) => skeg_proto::ShardStats {
                shard_id: shard as u32,
                cache_bytes: bytes,
                cache_evictions: evictions,
                n_keys,
                cache_budget: budget,
            },
            ShardResp::Err(msg) => return Err(ShardError::Storage(msg)),
            _ => return Err(ShardError::Unavailable),
        };
        rows.push(row);
    }
    Ok(rows)
}
/// Sends one request produced by `make_req` to every shard and waits for all
/// of them to answer `Done`.
///
/// All requests are dispatched before any reply is awaited. Every reply is
/// collected even after a shard reports an error; the first error message
/// (in shard order) is the one returned.
///
/// # Errors
/// `ShardError::Storage` with the first shard error;
/// `ShardError::Unavailable` when a shard cannot be reached, drops its reply
/// channel, or answers with an unexpected response.
async fn broadcast(&self, mut make_req: impl FnMut() -> ShardReq) -> Result<(), ShardError> {
    let mut replies = Vec::with_capacity(self.inner.n);
    for shard_tx in &self.inner.senders {
        let (reply_tx, reply_rx) = oneshot::channel();
        let envelope = ShardMsg {
            req: make_req(),
            reply: reply_tx,
        };
        if shard_tx.send(envelope).await.is_err() {
            return Err(ShardError::Unavailable);
        }
        replies.push(reply_rx);
    }

    let mut failure: Option<String> = None;
    for reply in replies {
        let Ok(resp) = reply.await else {
            return Err(ShardError::Unavailable);
        };
        match resp {
            ShardResp::Done => {}
            ShardResp::Err(msg) => {
                if failure.is_none() {
                    failure = Some(msg);
                }
            }
            _ => return Err(ShardError::Unavailable),
        }
    }
    failure.map_or(Ok(()), |msg| Err(ShardError::Storage(msg)))
}
/// Creates the vector index `name` on every shard.
///
/// * `dim` — vector dimensionality; must be non-zero.
/// * `kind` — 0 = `QuantKind::F32`, 1 = `QuantKind::Int8`, 2 = `QuantKind::Binary`.
/// * `backend` — 0 = flat index, 1 = disk index.
///
/// # Errors
/// `ShardError::Storage` for a zero `dim`, an unknown `kind` or `backend`
/// code, or when any shard rejects the creation;
/// `ShardError::Unavailable` when a shard cannot be reached.
pub async fn vindex_create(
    &self,
    name: &str,
    dim: u32,
    kind: u8,
    backend: u8,
) -> Result<(), ShardError> {
    if dim == 0 {
        let msg = "vindex dim must be positive".to_owned();
        return Err(ShardError::Storage(msg));
    }
    let quant = match kind {
        0 => QuantKind::F32,
        1 => QuantKind::Int8,
        2 => QuantKind::Binary,
        other => return Err(ShardError::Storage(format!("unknown vindex kind {other}"))),
    };
    let on_disk = if backend > 1 {
        return Err(ShardError::Storage(format!(
            "unknown vindex backend {backend}"
        )));
    } else {
        backend == 1
    };
    let dims = dim as usize;
    let index_name = name.to_owned();
    let make_req = || ShardReq::VindexCreate {
        name: index_name.clone(),
        dim: dims,
        kind: quant,
        disk: on_disk,
    };
    self.broadcast(make_req).await
}
/// Drops the vector index `name` on every shard via `broadcast`.
///
/// # Errors
/// `ShardError::Storage` with the first shard error (for example when the
/// index does not exist on a shard); `ShardError::Unavailable` when a shard
/// cannot be reached.
pub async fn vindex_drop(&self, name: &str) -> Result<(), ShardError> {
let name = name.to_owned();
self.broadcast(|| ShardReq::VindexDrop { name: name.clone() })
.await
}
/// Lists all vector indexes as `(name, dim, kind, backend, vector count)`
/// rows sorted by name.
///
/// Rows from different shards are merged by name: `dim`, `kind` and
/// `backend` come from the first shard that reports the index, and the
/// vector counts are summed (saturating).
///
/// # Errors
/// `ShardError::Storage` when a shard reports a failure;
/// `ShardError::Unavailable` when a shard cannot be reached or answers with
/// an unexpected response.
pub async fn vindex_list(&self) -> Result<Vec<(String, u32, u8, u8, u64)>, ShardError> {
    use std::collections::BTreeMap;
    let mut merged: BTreeMap<String, (u32, u8, u8, u64)> = BTreeMap::new();
    for shard in 0..self.inner.n {
        let rows = match self.call(shard, ShardReq::VindexList).await? {
            ShardResp::VindexList(rows) => rows,
            ShardResp::Err(msg) => return Err(ShardError::Storage(msg)),
            _ => return Err(ShardError::Unavailable),
        };
        for (name, dim, kind, backend, n_vectors) in rows {
            merged
                .entry(name)
                .and_modify(|row| row.3 = row.3.saturating_add(n_vectors))
                .or_insert((dim, kind, backend, n_vectors));
        }
    }
    let mut listing = Vec::with_capacity(merged.len());
    for (name, (dim, kind, backend, total)) in merged {
        listing.push((name, dim, kind, backend, total));
    }
    Ok(listing)
}
/// Stores `vector` under `id` in the index `name`.
///
/// The owning shard is chosen by hashing the little-endian bytes of `id`.
///
/// # Errors
/// `ShardError::Storage` with the shard's message (missing index, dimension
/// mismatch, insert failure); `ShardError::Unavailable` when the shard
/// cannot be reached or answers with an unexpected response.
pub async fn vset(&self, name: &str, id: u64, vector: Vec<f32>) -> Result<(), ShardError> {
    let target = shard_for(&id.to_le_bytes(), self.inner.n);
    let index_name = name.to_owned();
    let request = ShardReq::Vset {
        name: index_name,
        id,
        vector,
    };
    let reply = self.call(target, request).await?;
    match reply {
        ShardResp::Err(msg) => Err(ShardError::Storage(msg)),
        ShardResp::Done => Ok(()),
        _ => Err(ShardError::Unavailable),
    }
}
/// Fetches the vector stored under `id` in the index `name`, if any.
///
/// The owning shard is chosen by hashing the little-endian bytes of `id`.
///
/// # Errors
/// `ShardError::Storage` with the shard's message (missing index, read
/// failure); `ShardError::Unavailable` when the shard cannot be reached or
/// answers with an unexpected response.
pub async fn vget(&self, name: &str, id: u64) -> Result<Option<Vec<f32>>, ShardError> {
    let target = shard_for(&id.to_le_bytes(), self.inner.n);
    let index_name = name.to_owned();
    let request = ShardReq::Vget {
        name: index_name,
        id,
    };
    let reply = self.call(target, request).await?;
    match reply {
        ShardResp::Err(msg) => Err(ShardError::Storage(msg)),
        ShardResp::Vector(found) => Ok(found),
        _ => Err(ShardError::Unavailable),
    }
}
/// Deletes `id` from the index `name` and returns the boolean reported by
/// the index's delete.
///
/// The owning shard is chosen by hashing the little-endian bytes of `id`.
///
/// # Errors
/// `ShardError::Storage` with the shard's message (missing index, delete
/// failure); `ShardError::Unavailable` when the shard cannot be reached or
/// answers with an unexpected response.
pub async fn vdel(&self, name: &str, id: u64) -> Result<bool, ShardError> {
    let target = shard_for(&id.to_le_bytes(), self.inner.n);
    let index_name = name.to_owned();
    let request = ShardReq::Vdel {
        name: index_name,
        id,
    };
    let reply = self.call(target, request).await?;
    match reply {
        ShardResp::Err(msg) => Err(ShardError::Storage(msg)),
        ShardResp::Existed(existed) => Ok(existed),
        _ => Err(ShardError::Unavailable),
    }
}
pub async fn vsearch(
&self,
name: &str,
query: Vec<f32>,
k: usize,
l_search: u32,
) -> Result<Vec<(u64, f32)>, ShardError> {
let mut pending = Vec::with_capacity(self.inner.n);
for sender in &self.inner.senders {
let (tx, rx) = oneshot::channel();
let req = ShardReq::Vsearch {
name: name.to_owned(),
query: query.clone(),
k,
l_search,
};
sender
.send(ShardMsg { req, reply: tx })
.await
.map_err(|_| ShardError::Unavailable)?;
pending.push(rx);
}
let mut merged: Vec<(u64, f32)> = Vec::new();
let mut first_err = None;
for rx in pending {
match rx.await.map_err(|_| ShardError::Unavailable)? {
ShardResp::Vsearch(hits) => merged.extend(hits),
ShardResp::Err(e) => {
first_err.get_or_insert(e);
}
_ => return Err(ShardError::Unavailable),
}
}
if merged.is_empty()
&& let Some(e) = first_err
{
return Err(ShardError::Storage(e));
}
merged.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
merged.truncate(k);
Ok(merged)
}
}
// Unit and integration tests for shard routing and the `ShardSet` API.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
/// `shard_for` returns the same in-range index every time for a given key.
#[test]
fn test_shard_routing_deterministic() {
for n in [1usize, 2, 4, 7, 16] {
for key in [b"alpha".as_slice(), b"beta", b"", b"\x00\xFF\x01"] {
let a = shard_for(key, n);
let b = shard_for(key, n);
assert_eq!(a, b, "same key must route to same shard");
assert!(a < n, "shard index in range");
}
}
}
/// One million sequential keys spread over 4 shards within ±10% of an even split.
#[test]
fn test_shard_routing_distribution() {
let n = 4usize;
let mut counts = vec![0usize; n];
let total = 1_000_000usize;
for i in 0..total {
let key = format!("key_{i}");
counts[shard_for(key.as_bytes(), n)] += 1;
}
let expected = total / n;
for (s, &c) in counts.iter().enumerate() {
let lo = expected * 9 / 10;
let hi = expected * 11 / 10;
assert!(
c >= lo && c <= hi,
"shard {s} got {c}, expected ~{expected} (±10%)"
);
}
}
/// Values written through a 4-shard set can be read back through it.
#[tokio::test]
async fn test_cross_shard_set_get() {
let dir = TempDir::new().unwrap();
let shards = ShardSet::open(dir.path(), 4).unwrap();
for i in 0u32..50 {
let key = format!("ck{i}");
shards
.set(
key.as_bytes(),
format!("v{i}").as_bytes(),
Durability::Kernel,
)
.await
.unwrap();
}
for i in 0u32..50 {
let key = format!("ck{i}");
let got = shards.get(key.as_bytes()).await.unwrap();
assert_eq!(got.as_deref(), Some(format!("v{i}").as_bytes()));
}
}
/// Each key is stored only in the shard directory that `shard_for` selects.
#[tokio::test]
async fn test_shard_isolation() {
let dir = TempDir::new().unwrap();
let base = dir.path().to_owned();
{
let shards = ShardSet::open(&base, 2).unwrap();
for i in 0u32..40 {
let key = format!("iso{i}");
shards
.set(key.as_bytes(), b"x", Durability::Kernel)
.await
.unwrap();
}
}
// The ShardSet was dropped above (its Drop joins the shard threads), so the
// shard directories can be reopened directly.
let s0 = VLog::open(&base.join("shard-0")).await.unwrap();
let s1 = VLog::open(&base.join("shard-1")).await.unwrap();
for i in 0u32..40 {
let key = format!("iso{i}");
let in0 = s0.get(key.as_bytes()).await.unwrap().is_some();
let in1 = s1.get(key.as_bytes()).await.unwrap().is_some();
let expect = shard_for(key.as_bytes(), 2);
assert_eq!(in0, expect == 0, "key {key} shard-0 membership");
assert_eq!(in1, expect == 1, "key {key} shard-1 membership");
assert!(in0 ^ in1, "key {key} must live in exactly one shard");
}
}
/// `mget` returns values in request order, with `None` for a missing key.
#[tokio::test]
async fn test_mget_cross_shard() {
let dir = TempDir::new().unwrap();
let shards = ShardSet::open(dir.path(), 4).unwrap();
shards.set(b"a", b"va", Durability::Kernel).await.unwrap();
shards.set(b"b", b"vb", Durability::Kernel).await.unwrap();
shards.set(b"c", b"vc", Durability::Kernel).await.unwrap();
let keys = [
Bytes::from_static(b"a"),
Bytes::from_static(b"missing"),
Bytes::from_static(b"c"),
Bytes::from_static(b"b"),
];
let res = shards.mget(&keys).await.unwrap();
assert_eq!(res[0].as_deref(), Some(b"va".as_slice()));
assert!(res[1].is_none());
assert_eq!(res[2].as_deref(), Some(b"vc".as_slice()));
assert_eq!(res[3].as_deref(), Some(b"vb".as_slice()));
}
/// 100 concurrent tasks each write a key and read it back through cloned handles.
#[tokio::test]
async fn test_n_clients_concurrent() {
let dir = TempDir::new().unwrap();
let shards = ShardSet::open(dir.path(), 4).unwrap();
let mut handles = Vec::new();
for i in 0u32..100 {
let shards = shards.clone();
handles.push(tokio::spawn(async move {
let key = format!("cc{i}");
let val = format!("vv{i}");
shards
.set(key.as_bytes(), val.as_bytes(), Durability::Kernel)
.await
.unwrap();
let got = shards.get(key.as_bytes()).await.unwrap();
assert_eq!(got.as_deref(), Some(val.as_bytes()));
}));
}
for h in handles {
h.await.unwrap();
}
}
/// Reading a 1 MiB value 2000 times must take less than 8x the time of
/// reading a 4 KiB value 2000 times.
// NOTE(review): this is a wall-clock comparison and may be flaky on loaded machines.
#[tokio::test]
async fn test_get_latency_flat_across_payload_size() {
use std::time::Instant;
let dir = TempDir::new().unwrap();
let shards = ShardSet::open(dir.path(), 1).unwrap();
shards
.set(b"small", &vec![1u8; 4096], Durability::Kernel)
.await
.unwrap();
shards
.set(b"large", &vec![1u8; 1024 * 1024], Durability::Kernel)
.await
.unwrap();
// Warm-up reads before timing.
for _ in 0..64 {
let _ = shards.get(b"small").await.unwrap();
let _ = shards.get(b"large").await.unwrap();
}
let n = 2000;
let t0 = Instant::now();
for _ in 0..n {
let _ = shards.get(b"small").await.unwrap();
}
let small = t0.elapsed();
let t1 = Instant::now();
for _ in 0..n {
let _ = shards.get(b"large").await.unwrap();
}
let large = t1.elapsed();
assert!(
large < small * 8,
"GET latency scales with payload size: small={small:?} large={large:?} \
- the zero-copy read path regressed (a memcpy was introduced)"
);
}
/// `del` reports `true` for an existing key, `false` on a repeat, and the key is gone afterwards.
#[tokio::test]
async fn test_del_routes_correctly() {
let dir = TempDir::new().unwrap();
let shards = ShardSet::open(dir.path(), 4).unwrap();
shards.set(b"dk", b"v", Durability::Kernel).await.unwrap();
assert!(shards.del(b"dk", Durability::Kernel).await.unwrap());
assert!(!shards.del(b"dk", Durability::Kernel).await.unwrap());
assert!(shards.get(b"dk").await.unwrap().is_none());
}
/// Deterministic 64-dimensional test vector derived from `seed` with a
/// xorshift generator; components lie in `[-1.0, 1.0)`.
#[allow(clippy::cast_precision_loss)]
fn tvec(seed: u64) -> Vec<f32> {
// Force the state to be odd so it is never zero.
let mut s = (seed << 1) | 1;
(0..64)
.map(|_| {
s ^= s << 13;
s ^= s >> 7;
s ^= s << 17;
((s & 0xFFFF) as f32 / 32768.0) - 1.0
})
.collect()
}
/// A disk-backed index created and filled in one `ShardSet` is searchable and
/// readable after the set is dropped and reopened on the same directory.
#[tokio::test]
async fn test_vindex_disk_survives_restart() {
let dir = TempDir::new().unwrap();
let base = dir.path().to_owned();
{
let shards = ShardSet::open(&base, 4).unwrap();
shards.vindex_create("persist", 64, 0, 1).await.unwrap();
for id in 0u64..150 {
shards.vset("persist", id, tvec(id + 1)).await.unwrap();
}
}
let shards = ShardSet::open(&base, 4).unwrap();
// Vector id 88 was stored as tvec(89), so querying with tvec(89) must rank it first.
let hits = shards.vsearch("persist", tvec(89), 5, 0).await.unwrap();
assert_eq!(hits[0].0, 88, "disk VINDEX must be recovered after restart");
assert!(shards.vget("persist", 42).await.unwrap().is_some());
}
}
}