mod aof;
mod blocking;
mod persistence;
use std::collections::{HashMap, VecDeque};
use std::path::PathBuf;
use std::time::Duration;
use bytes::Bytes;
use ember_persistence::aof::{AofRecord, AofWriter, FsyncPolicy};
use ember_persistence::recovery::{self, RecoveredValue};
use ember_persistence::snapshot::{self, SnapEntry, SnapValue, SnapshotWriter};
use smallvec::{smallvec, SmallVec};
use tokio::sync::{broadcast, mpsc, oneshot};
use tracing::{debug, error, info, warn};
use crate::dropper::DropHandle;
use crate::error::ShardError;
use crate::expiry;
use crate::keyspace::{
IncrError, IncrFloatError, Keyspace, KeyspaceStats, LsetError, SetResult, ShardConfig,
TtlResult, WriteError,
};
use crate::types::sorted_set::{ScoreBound, ZAddFlags};
use crate::types::Value;
/// Interval between active expiration cycles run by each shard task.
const EXPIRY_TICK: Duration = Duration::from_millis(100);
/// Interval between AOF fsyncs when the fsync policy is `FsyncPolicy::EverySec`.
const FSYNC_INTERVAL: Duration = Duration::from_secs(1);
/// One AOF record broadcast to replication subscribers.
///
/// `process_single` emits one event per record produced by a request, after
/// the same records have been handed to the AOF writer (when one exists).
#[derive(Debug, Clone)]
pub struct ReplicationEvent {
/// Id of the shard that produced the record.
pub shard_id: u16,
/// Per-shard sequence number. It is incremented before each send, so the
/// first event emitted by a shard carries offset 1.
pub offset: u64,
/// The record itself, identical to the one written to the AOF.
pub record: AofRecord,
}
/// Per-shard persistence settings consumed by `run_shard`.
#[derive(Debug, Clone)]
pub struct ShardPersistenceConfig {
/// Directory used as the recovery source at startup and to locate the
/// shard's AOF file.
pub data_dir: PathBuf,
/// When true, the shard opens an AOF writer and appends every write to it.
pub append_only: bool,
/// When AOF data is synced: `Always` after each request that produced
/// records, `EverySec` on a one-second tick, `No` never by the shard itself.
pub fsync_policy: FsyncPolicy,
/// When set, recovery and the AOF writer use their encrypted variants.
#[cfg(feature = "encryption")]
pub encryption_key: Option<ember_persistence::encryption::EncryptionKey>,
}
/// A single command for one shard.
///
/// Requests reach the shard task wrapped in a [`ShardMessage`] and are executed
/// one at a time by `process_single`. Variants classified as writes by
/// `ShardRequest::is_write` are rejected while the shard's AOF is flagged as
/// disk-full.
#[derive(Debug)]
pub enum ShardRequest {
Get {
key: String,
},
Set {
key: String,
value: Bytes,
expire: Option<Duration>,
nx: bool,
xx: bool,
},
Incr {
key: String,
},
Decr {
key: String,
},
IncrBy {
key: String,
delta: i64,
},
/// `delta` is negated before use; `i64::MIN` is rejected as an overflow error.
DecrBy {
key: String,
delta: i64,
},
IncrByFloat {
key: String,
delta: f64,
},
Append {
key: String,
value: Bytes,
},
Strlen {
key: String,
},
GetRange {
key: String,
start: i64,
end: i64,
},
SetRange {
key: String,
offset: usize,
value: Bytes,
},
Keys {
pattern: String,
},
Rename {
key: String,
newkey: String,
},
Copy {
source: String,
destination: String,
replace: bool,
},
ObjectEncoding {
key: String,
},
Del {
key: String,
},
Unlink {
key: String,
},
Exists {
key: String,
},
RandomKey,
Touch {
key: String,
},
Sort {
key: String,
desc: bool,
alpha: bool,
limit: Option<(i64, i64)>,
},
Expire {
key: String,
seconds: u64,
},
Ttl {
key: String,
},
Persist {
key: String,
},
Pttl {
key: String,
},
Pexpire {
key: String,
milliseconds: u64,
},
LPush {
key: String,
values: Vec<Bytes>,
},
RPush {
key: String,
values: Vec<Bytes>,
},
LPop {
key: String,
},
RPop {
key: String,
},
/// Blocking left pop. The request and `waiter` are handed to the blocking
/// module; `waiter` presumably receives `(key, value)` once an element is
/// popped — TODO confirm against `blocking::handle_blocking_pop`.
BLPop {
key: String,
waiter: mpsc::Sender<(String, Bytes)>,
},
/// Blocking right pop; same handling as `BLPop`.
BRPop {
key: String,
waiter: mpsc::Sender<(String, Bytes)>,
},
LRange {
key: String,
start: i64,
stop: i64,
},
LLen {
key: String,
},
LIndex {
key: String,
index: i64,
},
LSet {
key: String,
index: i64,
value: Bytes,
},
LTrim {
key: String,
start: i64,
stop: i64,
},
LInsert {
key: String,
before: bool,
pivot: Bytes,
value: Bytes,
},
LRem {
key: String,
count: i64,
value: Bytes,
},
LPos {
key: String,
element: Bytes,
rank: i64,
count: usize,
maxlen: usize,
},
Type {
key: String,
},
ZAdd {
key: String,
members: Vec<(f64, String)>,
nx: bool,
xx: bool,
gt: bool,
lt: bool,
ch: bool,
},
ZRem {
key: String,
members: Vec<String>,
},
ZScore {
key: String,
member: String,
},
ZRank {
key: String,
member: String,
},
ZRevRank {
key: String,
member: String,
},
ZCard {
key: String,
},
ZRange {
key: String,
start: i64,
stop: i64,
with_scores: bool,
},
ZRevRange {
key: String,
start: i64,
stop: i64,
with_scores: bool,
},
ZCount {
key: String,
min: ScoreBound,
max: ScoreBound,
},
ZIncrBy {
key: String,
increment: f64,
member: String,
},
ZRangeByScore {
key: String,
min: ScoreBound,
max: ScoreBound,
offset: usize,
count: Option<usize>,
},
ZRevRangeByScore {
key: String,
min: ScoreBound,
max: ScoreBound,
offset: usize,
count: Option<usize>,
},
ZPopMin {
key: String,
count: usize,
},
ZPopMax {
key: String,
count: usize,
},
HSet {
key: String,
fields: Vec<(String, Bytes)>,
},
HGet {
key: String,
field: String,
},
HGetAll {
key: String,
},
HDel {
key: String,
fields: Vec<String>,
},
HExists {
key: String,
field: String,
},
HLen {
key: String,
},
HIncrBy {
key: String,
field: String,
delta: i64,
},
HKeys {
key: String,
},
HVals {
key: String,
},
HMGet {
key: String,
fields: Vec<String>,
},
SAdd {
key: String,
members: Vec<String>,
},
SRem {
key: String,
members: Vec<String>,
},
SMembers {
key: String,
},
SIsMember {
key: String,
member: String,
},
SCard {
key: String,
},
SUnion {
keys: Vec<String>,
},
SInter {
keys: Vec<String>,
},
SDiff {
keys: Vec<String>,
},
SUnionStore {
dest: String,
keys: Vec<String>,
},
SInterStore {
dest: String,
keys: Vec<String>,
},
SDiffStore {
dest: String,
keys: Vec<String>,
},
SRandMember {
key: String,
count: i64,
},
SPop {
key: String,
count: usize,
},
SMisMember {
key: String,
members: Vec<String>,
},
DbSize,
Stats,
KeyVersion {
key: String,
},
/// Answered by `persistence::handle_snapshot` in `process_single`.
Snapshot,
/// Answered by `persistence::handle_serialize_snapshot` in `process_single`.
SerializeSnapshot,
/// Answered by `persistence::handle_rewrite` in `process_single`.
RewriteAof,
/// Clears the keyspace synchronously.
FlushDb,
/// Clears the keyspace; the removed entries are passed to the drop handle
/// when one is configured.
FlushDbAsync,
Scan {
cursor: u64,
count: usize,
pattern: Option<String>,
},
SScan {
key: String,
cursor: u64,
count: usize,
pattern: Option<String>,
},
HScan {
key: String,
cursor: u64,
count: usize,
pattern: Option<String>,
},
ZScan {
key: String,
cursor: u64,
count: usize,
pattern: Option<String>,
},
CountKeysInSlot {
slot: u16,
},
GetKeysInSlot {
slot: u16,
count: usize,
},
DumpKey {
key: String,
},
/// Recreates `key` from serialized `data` (presumably the output of `DumpKey`).
RestoreKey {
key: String,
ttl_ms: u64,
data: bytes::Bytes,
replace: bool,
},
/// Vector-set commands, compiled only with the `vector` feature.
#[cfg(feature = "vector")]
VAdd {
key: String,
element: String,
vector: Vec<f32>,
metric: u8,
quantization: u8,
connectivity: u32,
expansion_add: u32,
},
#[cfg(feature = "vector")]
VAddBatch {
key: String,
entries: Vec<(String, Vec<f32>)>,
dim: usize,
metric: u8,
quantization: u8,
connectivity: u32,
expansion_add: u32,
},
#[cfg(feature = "vector")]
VSim {
key: String,
query: Vec<f32>,
count: usize,
ef_search: usize,
},
#[cfg(feature = "vector")]
VRem {
key: String,
element: String,
},
#[cfg(feature = "vector")]
VGet {
key: String,
element: String,
},
#[cfg(feature = "vector")]
VCard {
key: String,
},
#[cfg(feature = "vector")]
VDim {
key: String,
},
#[cfg(feature = "vector")]
VInfo {
key: String,
},
/// Protobuf-typed values, compiled only with the `protobuf` feature.
#[cfg(feature = "protobuf")]
ProtoSet {
key: String,
type_name: String,
data: Bytes,
expire: Option<Duration>,
nx: bool,
xx: bool,
},
#[cfg(feature = "protobuf")]
ProtoGet {
key: String,
},
#[cfg(feature = "protobuf")]
ProtoType {
key: String,
},
#[cfg(feature = "protobuf")]
ProtoRegisterAof {
name: String,
descriptor: Bytes,
},
#[cfg(feature = "protobuf")]
ProtoSetField {
key: String,
field_path: String,
value: String,
},
#[cfg(feature = "protobuf")]
ProtoDelField {
key: String,
field_path: String,
},
}
impl ShardRequest {
    /// Returns `true` when the request mutates shard state.
    ///
    /// `process_single` uses this to reject writes while the AOF is flagged as
    /// disk-full. Every variant that changes stored data must therefore be
    /// listed here; a missing entry silently bypasses that protection.
    fn is_write(&self) -> bool {
        // NOTE(review): presumably allowed so the trailing `_` arm stays
        // warning-free under every combination of the `vector`/`protobuf`
        // features — confirm and drop if no longer needed.
        #[allow(unreachable_patterns)]
        match self {
            ShardRequest::Set { .. }
            | ShardRequest::Incr { .. }
            | ShardRequest::Decr { .. }
            | ShardRequest::IncrBy { .. }
            | ShardRequest::DecrBy { .. }
            | ShardRequest::IncrByFloat { .. }
            | ShardRequest::Append { .. }
            // SETRANGE overwrites/extends a string value in place.
            | ShardRequest::SetRange { .. }
            | ShardRequest::Del { .. }
            | ShardRequest::Unlink { .. }
            | ShardRequest::Rename { .. }
            | ShardRequest::Copy { .. }
            | ShardRequest::Expire { .. }
            | ShardRequest::Persist { .. }
            | ShardRequest::Pexpire { .. }
            | ShardRequest::LPush { .. }
            | ShardRequest::RPush { .. }
            | ShardRequest::LPop { .. }
            | ShardRequest::RPop { .. }
            | ShardRequest::LSet { .. }
            | ShardRequest::LTrim { .. }
            | ShardRequest::LInsert { .. }
            | ShardRequest::LRem { .. }
            | ShardRequest::BLPop { .. }
            | ShardRequest::BRPop { .. }
            | ShardRequest::ZAdd { .. }
            | ShardRequest::ZRem { .. }
            | ShardRequest::ZIncrBy { .. }
            | ShardRequest::ZPopMin { .. }
            | ShardRequest::ZPopMax { .. }
            | ShardRequest::HSet { .. }
            | ShardRequest::HDel { .. }
            | ShardRequest::HIncrBy { .. }
            | ShardRequest::SAdd { .. }
            | ShardRequest::SRem { .. }
            | ShardRequest::SPop { .. }
            | ShardRequest::SUnionStore { .. }
            | ShardRequest::SInterStore { .. }
            | ShardRequest::SDiffStore { .. }
            | ShardRequest::FlushDb
            | ShardRequest::FlushDbAsync
            | ShardRequest::RestoreKey { .. } => true,
            #[cfg(feature = "protobuf")]
            ShardRequest::ProtoSet { .. }
            | ShardRequest::ProtoRegisterAof { .. }
            | ShardRequest::ProtoSetField { .. }
            | ShardRequest::ProtoDelField { .. } => true,
            #[cfg(feature = "vector")]
            ShardRequest::VAdd { .. }
            | ShardRequest::VAddBatch { .. }
            | ShardRequest::VRem { .. } => true,
            _ => false,
        }
    }
}
/// Result of executing a [`ShardRequest`], sent back on the request's reply channel.
#[derive(Debug)]
pub enum ShardResponse {
/// Optional value. `Value(None)` is also returned by `Set` when its
/// `nx`/`xx` condition blocks the write.
Value(Option<Value>),
Ok,
Integer(i64),
Bool(bool),
Ttl(TtlResult),
/// The keyspace refused a write with an out-of-memory error.
OutOfMemory,
KeyCount(usize),
Stats(KeyspaceStats),
Len(usize),
Array(Vec<Bytes>),
TypeName(&'static str),
EncodingName(Option<&'static str>),
/// ZADD result; `applied` lists the members actually written (presumably
/// used by `aof::to_aof_records`, which receives the response mutably).
ZAddLen {
count: usize,
applied: Vec<(f64, String)>,
},
ZRemLen { count: usize, removed: Vec<String> },
Score(Option<f64>),
Rank(Option<usize>),
ScoredArray(Vec<(String, f64)>),
ZIncrByResult { new_score: f64, member: String },
ZPopResult(Vec<(String, f64)>),
BulkString(String),
WrongType,
/// Complete error text, including its prefix (e.g. `"ERR no such key"`).
Err(String),
Scan { cursor: u64, keys: Vec<String> },
CollectionScan { cursor: u64, items: Vec<Bytes> },
HashFields(Vec<(String, Bytes)>),
HDelLen { count: usize, removed: Vec<String> },
/// List of strings; `RandomKey` uses it with zero or one element.
StringArray(Vec<String>),
IntegerArray(Vec<i64>),
BoolArray(Vec<bool>),
SetStoreResult { count: usize, members: Vec<String> },
KeyDump { data: Vec<u8>, ttl_ms: i64 },
SnapshotData { shard_id: u16, data: Vec<u8> },
OptionalArray(Vec<Option<Bytes>>),
#[cfg(feature = "vector")]
VAddResult {
element: String,
vector: Vec<f32>,
added: bool,
},
#[cfg(feature = "vector")]
VAddBatchResult {
added_count: usize,
applied: Vec<(String, Vec<f32>)>,
},
#[cfg(feature = "vector")]
VSimResult(Vec<(String, f32)>),
#[cfg(feature = "vector")]
VectorData(Option<Vec<f32>>),
#[cfg(feature = "vector")]
VectorInfo(Option<Vec<(String, String)>>),
#[cfg(feature = "protobuf")]
ProtoValue(Option<(String, Bytes, Option<Duration>)>),
#[cfg(feature = "protobuf")]
ProtoTypeName(Option<String>),
#[cfg(feature = "protobuf")]
ProtoFieldUpdated {
type_name: String,
data: Bytes,
expire: Option<Duration>,
},
Version(Option<u64>),
}
/// Envelope carried on a shard's request channel.
#[derive(Debug)]
pub enum ShardMessage {
/// One request answered through a one-shot channel.
Single {
request: ShardRequest,
reply: oneshot::Sender<ShardResponse>,
},
/// One request answered through a caller-owned mpsc sender that can be reused
/// across requests. The shard replies with `try_send`, so the response is
/// dropped if that channel is full or closed.
SingleReusable {
request: ShardRequest,
reply: mpsc::Sender<ShardResponse>,
},
/// Several requests processed consecutively in order, each with its own reply channel.
Batch(Vec<(ShardRequest, oneshot::Sender<ShardResponse>)>),
}
/// Cloneable sending side of a shard's request channel.
///
/// Its methods report `ShardError::Unavailable` once the shard task is no
/// longer receiving messages.
#[derive(Debug, Clone)]
pub struct ShardHandle {
tx: mpsc::Sender<ShardMessage>,
}
impl ShardHandle {
    /// Sends `request` and waits for the shard's response.
    ///
    /// Fails with `ShardError::Unavailable` if the request cannot be enqueued
    /// or if the reply sender is dropped before answering.
    pub async fn send(&self, request: ShardRequest) -> Result<ShardResponse, ShardError> {
        self.dispatch(request)
            .await?
            .await
            .map_err(|_| ShardError::Unavailable)
    }

    /// Enqueues `request` and returns the receiver for its response without
    /// waiting for the response itself (it does wait for channel capacity).
    pub async fn dispatch(
        &self,
        request: ShardRequest,
    ) -> Result<oneshot::Receiver<ShardResponse>, ShardError> {
        let (reply, pending) = oneshot::channel();
        self.tx
            .send(ShardMessage::Single { request, reply })
            .await
            .map(|()| pending)
            .map_err(|_| ShardError::Unavailable)
    }

    /// Enqueues `request`; its response will be delivered on `reply`.
    ///
    /// The shard replies with `try_send`, so the response is dropped if
    /// `reply` is full or closed at that moment.
    pub async fn dispatch_reusable(
        &self,
        request: ShardRequest,
        reply: mpsc::Sender<ShardResponse>,
    ) -> Result<(), ShardError> {
        let message = ShardMessage::SingleReusable { request, reply };
        match self.tx.send(message).await {
            Ok(()) => Ok(()),
            Err(_) => Err(ShardError::Unavailable),
        }
    }

    /// Enqueues `requests` as a single batch message so the shard processes
    /// them back to back. Receivers are returned in request order.
    ///
    /// A single request takes the plain [`ShardHandle::dispatch`] path; an
    /// empty input still sends an (empty) batch message.
    pub async fn dispatch_batch(
        &self,
        requests: Vec<ShardRequest>,
    ) -> Result<Vec<oneshot::Receiver<ShardResponse>>, ShardError> {
        if requests.len() == 1 {
            let only = requests.into_iter().next().expect("len == 1");
            return Ok(vec![self.dispatch(only).await?]);
        }

        let (entries, receivers): (Vec<_>, Vec<_>) = requests
            .into_iter()
            .map(|request| {
                let (tx, rx) = oneshot::channel();
                ((request, tx), rx)
            })
            .unzip();

        self.tx
            .send(ShardMessage::Batch(entries))
            .await
            .map_err(|_| ShardError::Unavailable)?;
        Ok(receivers)
    }
}
/// A shard whose request channel exists but whose task has not started yet.
///
/// Produced by `prepare_shard` and consumed by `run_prepared`. The split lets
/// the returned `ShardHandle` exist before the task is started (`spawn_shard`
/// starts it with `tokio::spawn`).
pub struct PreparedShard {
rx: mpsc::Receiver<ShardMessage>,
config: ShardConfig,
persistence: Option<ShardPersistenceConfig>,
drop_handle: Option<DropHandle>,
replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
#[cfg(feature = "protobuf")]
schema_registry: Option<crate::schema::SharedSchemaRegistry>,
}
/// Creates a shard's request channel without starting its task.
///
/// `buffer` is the channel capacity. Returns the handle used to send requests
/// together with the [`PreparedShard`] to pass to `run_prepared`.
pub fn prepare_shard(
    buffer: usize,
    config: ShardConfig,
    persistence: Option<ShardPersistenceConfig>,
    drop_handle: Option<DropHandle>,
    replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    #[cfg(feature = "protobuf")] schema_registry: Option<crate::schema::SharedSchemaRegistry>,
) -> (ShardHandle, PreparedShard) {
    let (request_tx, request_rx) = mpsc::channel(buffer);
    let handle = ShardHandle { tx: request_tx };
    let shard = PreparedShard {
        rx: request_rx,
        config,
        persistence,
        drop_handle,
        replication_tx,
        #[cfg(feature = "protobuf")]
        schema_registry,
    };
    (handle, shard)
}
pub async fn run_prepared(prepared: PreparedShard) {
run_shard(
prepared.rx,
prepared.config,
prepared.persistence,
prepared.drop_handle,
prepared.replication_tx,
#[cfg(feature = "protobuf")]
prepared.schema_registry,
)
.await
}
/// Creates a shard and immediately spawns its task on the current Tokio runtime.
///
/// Shorthand for `prepare_shard` followed by `tokio::spawn(run_prepared(..))`.
/// The task is detached; it exits once every returned handle is dropped.
pub fn spawn_shard(
    buffer: usize,
    config: ShardConfig,
    persistence: Option<ShardPersistenceConfig>,
    drop_handle: Option<DropHandle>,
    replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    #[cfg(feature = "protobuf")] schema_registry: Option<crate::schema::SharedSchemaRegistry>,
) -> ShardHandle {
    let (handle, prepared) = prepare_shard(
        buffer,
        config,
        persistence,
        drop_handle,
        replication_tx,
        #[cfg(feature = "protobuf")]
        schema_registry,
    );
    let shard_task = run_prepared(prepared);
    tokio::spawn(shard_task);
    handle
}
/// Body of a shard task: recovers persisted state, then serves requests until
/// the request channel closes.
///
/// Startup (only when `persistence` is `Some`):
/// 1. Recovered entries are read from `data_dir`. The encrypted reader is used
///    when the `encryption` feature is on and a key is configured. Entries are
///    converted to in-memory values and restored with their TTLs. Vector sets
///    whose index cannot be created are logged and skipped.
/// 2. With the `protobuf` feature and a registry, recovered schemas are
///    restored into the registry. A poisoned registry lock is logged.
///
/// If `append_only` is set, an AOF writer is opened. Failure to open it is
/// logged, and the shard keeps running without an AOF.
///
/// The loop then multiplexes:
/// - request messages: after each wakeup the channel is drained with
///   `try_recv`, so a burst is handled without re-entering `select!`;
/// - an active-expiry tick every `EXPIRY_TICK`;
/// - an AOF fsync tick every `FSYNC_INTERVAL`, enabled only for
///   `FsyncPolicy::EverySec`. A successful sync after earlier failures resets
///   the error counter and clears `disk_full`.
///
/// When every sender is gone, the loop exits and the AOF gets a final sync.
/// A failure there is logged instead of being silently dropped.
async fn run_shard(
    mut rx: mpsc::Receiver<ShardMessage>,
    config: ShardConfig,
    persistence: Option<ShardPersistenceConfig>,
    drop_handle: Option<DropHandle>,
    replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    #[cfg(feature = "protobuf")] schema_registry: Option<crate::schema::SharedSchemaRegistry>,
) {
    let shard_id = config.shard_id;
    let mut keyspace = Keyspace::with_config(config);
    if let Some(handle) = drop_handle.clone() {
        keyspace.set_drop_handle(handle);
    }

    if let Some(ref pcfg) = persistence {
        #[cfg(feature = "encryption")]
        let result = if let Some(ref key) = pcfg.encryption_key {
            recovery::recover_shard_encrypted(&pcfg.data_dir, shard_id, key.clone())
        } else {
            recovery::recover_shard(&pcfg.data_dir, shard_id)
        };
        #[cfg(not(feature = "encryption"))]
        let result = recovery::recover_shard(&pcfg.data_dir, shard_id);

        let count = result.entries.len();
        for entry in result.entries {
            let value = match entry.value {
                RecoveredValue::String(data) => Value::String(data),
                RecoveredValue::List(deque) => Value::List(deque),
                RecoveredValue::SortedSet(members) => {
                    let mut ss = crate::types::sorted_set::SortedSet::new();
                    for (score, member) in members {
                        ss.add(&member, score);
                    }
                    Value::SortedSet(Box::new(ss))
                }
                RecoveredValue::Hash(map) => {
                    Value::Hash(Box::new(crate::types::hash::HashValue::from(map)))
                }
                RecoveredValue::Set(set) => Value::Set(Box::new(set)),
                #[cfg(feature = "vector")]
                RecoveredValue::Vector {
                    metric,
                    quantization,
                    connectivity,
                    expansion_add,
                    elements,
                } => {
                    use crate::types::vector::{DistanceMetric, QuantizationType, VectorSet};
                    // The index dimension comes from the first element; an
                    // empty element list passes a dimension of 0.
                    let dim = elements.first().map(|(_, v)| v.len()).unwrap_or(0);
                    match VectorSet::new(
                        dim,
                        DistanceMetric::from_u8(metric),
                        QuantizationType::from_u8(quantization),
                        connectivity as usize,
                        expansion_add as usize,
                    ) {
                        Ok(mut vs) => {
                            for (element, vector) in elements {
                                if let Err(e) = vs.add(element, &vector) {
                                    warn!("vector recovery: failed to add element: {e}");
                                }
                            }
                            Value::Vector(vs)
                        }
                        Err(e) => {
                            warn!("vector recovery: failed to create index: {e}");
                            // Skip this key entirely rather than restoring a broken index.
                            continue;
                        }
                    }
                }
                #[cfg(feature = "protobuf")]
                RecoveredValue::Proto { type_name, data } => Value::Proto { type_name, data },
            };
            keyspace.restore(entry.key, value, entry.ttl);
        }
        if count > 0 {
            info!(
                shard_id,
                recovered_keys = count,
                snapshot = result.loaded_snapshot,
                aof = result.replayed_aof,
                "recovered shard state"
            );
        }

        #[cfg(feature = "protobuf")]
        if let Some(ref registry) = schema_registry {
            if !result.schemas.is_empty() {
                match registry.write() {
                    Ok(mut reg) => {
                        let schema_count = result.schemas.len();
                        for (name, descriptor) in result.schemas {
                            reg.restore(name, descriptor);
                        }
                        info!(
                            shard_id,
                            schemas = schema_count,
                            "restored schemas from AOF"
                        );
                    }
                    Err(_) => {
                        // Previously skipped silently; surface it so missing
                        // schemas after restart are diagnosable.
                        warn!(
                            shard_id,
                            "schema registry lock poisoned; recovered schemas were not restored"
                        );
                    }
                }
            }
        }
    }

    let mut aof_writer: Option<AofWriter> = match &persistence {
        Some(pcfg) if pcfg.append_only => {
            let path = ember_persistence::aof::aof_path(&pcfg.data_dir, shard_id);
            #[cfg(feature = "encryption")]
            let result = if let Some(ref key) = pcfg.encryption_key {
                AofWriter::open_encrypted(path, key.clone())
            } else {
                AofWriter::open(path)
            };
            #[cfg(not(feature = "encryption"))]
            let result = AofWriter::open(path);
            match result {
                Ok(w) => Some(w),
                Err(e) => {
                    warn!(shard_id, "failed to open AOF writer: {e}");
                    None
                }
            }
        }
        _ => None,
    };
    let fsync_policy = persistence
        .as_ref()
        .map(|p| p.fsync_policy)
        .unwrap_or(FsyncPolicy::No);

    let mut replication_offset: u64 = 0;
    let mut lpop_waiters: HashMap<String, VecDeque<mpsc::Sender<(String, Bytes)>>> = HashMap::new();
    let mut rpop_waiters: HashMap<String, VecDeque<mpsc::Sender<(String, Bytes)>>> = HashMap::new();
    let mut aof_errors: u32 = 0;
    let mut disk_full: bool = false;

    let mut expiry_tick = tokio::time::interval(EXPIRY_TICK);
    expiry_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    let mut fsync_tick = tokio::time::interval(FSYNC_INTERVAL);
    fsync_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

    loop {
        tokio::select! {
            msg = rx.recv() => {
                match msg {
                    Some(msg) => {
                        let mut ctx = ProcessCtx {
                            keyspace: &mut keyspace,
                            aof_writer: &mut aof_writer,
                            fsync_policy,
                            persistence: &persistence,
                            drop_handle: &drop_handle,
                            shard_id,
                            replication_tx: &replication_tx,
                            replication_offset: &mut replication_offset,
                            lpop_waiters: &mut lpop_waiters,
                            rpop_waiters: &mut rpop_waiters,
                            aof_errors: &mut aof_errors,
                            disk_full: &mut disk_full,
                            #[cfg(feature = "protobuf")]
                            schema_registry: &schema_registry,
                        };
                        process_message(msg, &mut ctx);
                        // Drain whatever else is already queued before yielding to select!.
                        while let Ok(msg) = rx.try_recv() {
                            process_message(msg, &mut ctx);
                        }
                    }
                    // All senders dropped: shut the shard down.
                    None => break,
                }
            }
            _ = expiry_tick.tick() => {
                expiry::run_expiration_cycle(&mut keyspace);
            }
            _ = fsync_tick.tick(), if fsync_policy == FsyncPolicy::EverySec => {
                if let Some(ref mut writer) = aof_writer {
                    if let Err(e) = writer.sync() {
                        if aof::log_aof_error(shard_id, &mut aof_errors, "sync", &e) {
                            disk_full = true;
                        }
                    } else if aof_errors > 0 {
                        let missed = aof_errors;
                        aof_errors = 0;
                        if disk_full {
                            disk_full = false;
                            info!(shard_id, missed_errors = missed, "aof sync recovered, accepting writes again");
                        } else {
                            info!(shard_id, missed_errors = missed, "aof sync recovered");
                        }
                    }
                }
            }
        }
    }

    if let Some(ref mut writer) = aof_writer {
        // Best effort: the shard is stopping regardless, but a failed final
        // sync means recently acknowledged writes may not be durable.
        if let Err(e) = writer.sync() {
            warn!(shard_id, "failed to sync AOF on shutdown: {e}");
        }
    }
}
/// Mutable shard state borrowed by `process_message`/`process_single`.
///
/// `run_shard` builds one per channel wakeup and reuses it while draining the
/// already-queued messages.
struct ProcessCtx<'a> {
keyspace: &'a mut Keyspace,
/// `None` when the AOF is disabled or could not be opened.
aof_writer: &'a mut Option<AofWriter>,
fsync_policy: FsyncPolicy,
persistence: &'a Option<ShardPersistenceConfig>,
/// Receives the entries removed by `FlushDbAsync` via `defer_entries`.
drop_handle: &'a Option<DropHandle>,
shard_id: u16,
replication_tx: &'a Option<broadcast::Sender<ReplicationEvent>>,
/// Offset of the last replication event sent; incremented before each send.
replication_offset: &'a mut u64,
/// Waiters registered by `BLPop`, keyed by list key (managed by the blocking module).
lpop_waiters: &'a mut HashMap<String, VecDeque<mpsc::Sender<(String, Bytes)>>>,
/// Waiters registered by `BRPop`, keyed by list key (managed by the blocking module).
rpop_waiters: &'a mut HashMap<String, VecDeque<mpsc::Sender<(String, Bytes)>>>,
/// AOF error count since the last recovery (presumably incremented inside
/// `aof::log_aof_error`); reset to 0 once a write batch or sync succeeds.
aof_errors: &'a mut u32,
/// Set when `aof::log_aof_error` reports a disk-full condition; while set
/// (and an AOF writer exists) write requests are rejected.
disk_full: &'a mut bool,
#[cfg(feature = "protobuf")]
schema_registry: &'a Option<crate::schema::SharedSchemaRegistry>,
}
/// Unpacks a channel message and runs each contained request in order.
fn process_message(msg: ShardMessage, ctx: &mut ProcessCtx<'_>) {
    let (request, reply) = match msg {
        ShardMessage::Single { request, reply } => (request, ReplySender::Oneshot(reply)),
        ShardMessage::SingleReusable { request, reply } => {
            (request, ReplySender::Reusable(reply))
        }
        ShardMessage::Batch(entries) => {
            for (request, reply) in entries {
                process_single(request, ReplySender::Oneshot(reply), ctx);
            }
            return;
        }
    };
    process_single(request, reply, ctx);
}
/// Destination for a request's response.
enum ReplySender {
/// One-shot channel from `ShardMessage::Single` or `ShardMessage::Batch`.
Oneshot(oneshot::Sender<ShardResponse>),
/// Caller-owned mpsc channel from `ShardMessage::SingleReusable`; replies use `try_send`.
Reusable(mpsc::Sender<ShardResponse>),
}
impl ReplySender {
    /// Delivers `response` without ever blocking the shard task.
    ///
    /// A dropped one-shot receiver is ignored. For the reusable channel a full
    /// or closed queue drops the response and logs at debug level.
    fn send(self, response: ShardResponse) {
        let tx = match self {
            ReplySender::Oneshot(once) => {
                // The caller may have stopped waiting; nothing else to do then.
                let _ = once.send(response);
                return;
            }
            ReplySender::Reusable(tx) => tx,
        };
        if let Err(e) = tx.try_send(response) {
            debug!("reusable reply channel full or closed: {e}");
        }
    }
}
/// Executes one request against the shard and replies through `reply`.
///
/// Order of operations:
/// 1. While `disk_full` is set and an AOF writer exists, requests classified by
///    `is_write` get an error reply and nothing else happens.
/// 2. `BLPop`/`BRPop` are handed, together with `reply`, to the blocking module.
/// 3. Everything else runs through `dispatch`. A `Len` response to
///    `LPush`/`RPush` wakes blocked poppers on that key.
/// 4. The request/response pair is converted into AOF records. The records are
///    written to the AOF (and synced under `FsyncPolicy::Always`), then
///    broadcast to replication subscribers.
/// 5. Snapshot, serialize-snapshot, AOF-rewrite and async-flush requests reply
///    with the result of their dedicated handler instead of the `dispatch` response.
fn process_single(mut request: ShardRequest, reply: ReplySender, ctx: &mut ProcessCtx<'_>) {
let fsync_policy = ctx.fsync_policy;
let shard_id = ctx.shard_id;
// Refuse mutations while the AOF is out of space (presumably so memory does
// not get ahead of what is persisted).
if *ctx.disk_full && ctx.aof_writer.is_some() && request.is_write() {
reply.send(ShardResponse::Err(
"ERR disk full, write rejected — free disk space to resume writes".into(),
));
return;
}
// Blocking pops take ownership of the reply; they never reach `dispatch`.
match request {
ShardRequest::BLPop { key, waiter } => {
blocking::handle_blocking_pop(&key, waiter, true, reply, ctx);
return;
}
ShardRequest::BRPop { key, waiter } => {
blocking::handle_blocking_pop(&key, waiter, false, reply, ctx);
return;
}
_ => {}
}
// Classify before `request` is consumed by `aof::to_aof_records` below.
let request_kind = describe_request(&request);
let mut response = dispatch(
ctx.keyspace,
&mut request,
#[cfg(feature = "protobuf")]
ctx.schema_registry,
);
// A successful push may satisfy clients blocked on this key.
if let ShardRequest::LPush { ref key, .. } | ShardRequest::RPush { ref key, .. } = request {
if matches!(response, ShardResponse::Len(_)) {
blocking::wake_blocked_waiters(key, ctx);
}
}
let records = aof::to_aof_records(request, &mut response);
if let Some(ref mut writer) = *ctx.aof_writer {
let mut batch_ok = true;
for record in &records {
if let Err(e) = writer.write_record(record) {
if aof::log_aof_error(shard_id, ctx.aof_errors, "write", &e) {
*ctx.disk_full = true;
}
batch_ok = false;
}
}
if !records.is_empty() && fsync_policy == FsyncPolicy::Always {
if let Err(e) = writer.sync() {
if aof::log_aof_error(shard_id, ctx.aof_errors, "sync", &e) {
*ctx.disk_full = true;
}
batch_ok = false;
}
}
// Clears the error state when this request's records all succeeded.
// `batch_ok` starts true, so a request that produced no records
// (e.g. a read) also clears `aof_errors` and `disk_full`.
// NOTE(review): within this file, that and the EverySec fsync tick are
// the only places `disk_full` is cleared, and writes are rejected
// while it is set — confirm this is the intended recovery path.
if batch_ok && *ctx.aof_errors > 0 {
let missed = *ctx.aof_errors;
*ctx.aof_errors = 0;
*ctx.disk_full = false;
info!(shard_id, missed_errors = missed, "aof writes recovered");
}
}
// Replication receives the same records, each with a fresh offset, whether
// or not the AOF write succeeded.
if let Some(ref tx) = *ctx.replication_tx {
for record in records {
*ctx.replication_offset += 1;
let _ = tx.send(ReplicationEvent {
shard_id,
offset: *ctx.replication_offset,
record,
});
}
}
match request_kind {
RequestKind::Snapshot => {
let resp = persistence::handle_snapshot(ctx.keyspace, ctx.persistence, shard_id);
reply.send(resp);
return;
}
RequestKind::SerializeSnapshot => {
let resp = persistence::handle_serialize_snapshot(ctx.keyspace, shard_id);
reply.send(resp);
return;
}
RequestKind::RewriteAof => {
let resp = persistence::handle_rewrite(
ctx.keyspace,
ctx.persistence,
ctx.aof_writer,
shard_id,
#[cfg(feature = "protobuf")]
ctx.schema_registry,
);
reply.send(resp);
return;
}
RequestKind::FlushDbAsync => {
// Detach the old entries and hand them to the drop handle (if any).
let old_entries = ctx.keyspace.flush_async();
if let Some(ref handle) = *ctx.drop_handle {
handle.defer_entries(old_entries);
}
reply.send(ShardResponse::Ok);
return;
}
RequestKind::Other => {}
}
reply.send(response);
}
/// Requests that `process_single` completes with a dedicated handler instead
/// of replying with the `dispatch` response.
enum RequestKind {
/// Handled by `persistence::handle_snapshot`.
Snapshot,
/// Handled by `persistence::handle_serialize_snapshot`.
SerializeSnapshot,
/// Handled by `persistence::handle_rewrite`.
RewriteAof,
/// Keyspace flushed via `flush_async`; old entries go to the drop handle.
FlushDbAsync,
/// Reply with the `dispatch` response as-is.
Other,
}
/// Classifies `req` up front.
///
/// `process_single` still needs to know which special handler to run after
/// `req` has been consumed by AOF record generation.
fn describe_request(req: &ShardRequest) -> RequestKind {
    if matches!(req, ShardRequest::Snapshot) {
        RequestKind::Snapshot
    } else if matches!(req, ShardRequest::SerializeSnapshot) {
        RequestKind::SerializeSnapshot
    } else if matches!(req, ShardRequest::RewriteAof) {
        RequestKind::RewriteAof
    } else if matches!(req, ShardRequest::FlushDbAsync) {
        RequestKind::FlushDbAsync
    } else {
        RequestKind::Other
    }
}
/// Converts the outcome of an integer increment into a response.
///
/// Type mismatches and out-of-memory map to their dedicated variants. Any
/// other `IncrError` is reported through its `Display` text.
fn incr_result(result: Result<i64, IncrError>) -> ShardResponse {
    let err = match result {
        Ok(n) => return ShardResponse::Integer(n),
        Err(err) => err,
    };
    match err {
        IncrError::WrongType => ShardResponse::WrongType,
        IncrError::OutOfMemory => ShardResponse::OutOfMemory,
        other => ShardResponse::Err(other.to_string()),
    }
}
/// Converts a length-returning write into `Len`, or into the matching error response.
fn write_result_len(result: Result<usize, WriteError>) -> ShardResponse {
    result.map_or_else(
        |err| match err {
            WriteError::WrongType => ShardResponse::WrongType,
            WriteError::OutOfMemory => ShardResponse::OutOfMemory,
        },
        ShardResponse::Len,
    )
}
/// Converts a set-store result (stored count plus members) into a response,
/// or into the matching error response.
fn store_set_response(result: Result<(usize, Vec<String>), WriteError>) -> ShardResponse {
    let (count, members) = match result {
        Ok(stored) => stored,
        Err(WriteError::WrongType) => return ShardResponse::WrongType,
        Err(WriteError::OutOfMemory) => return ShardResponse::OutOfMemory,
    };
    ShardResponse::SetStoreResult { count, members }
}
fn dispatch(
ks: &mut Keyspace,
req: &mut ShardRequest,
#[cfg(feature = "protobuf")] schema_registry: &Option<crate::schema::SharedSchemaRegistry>,
) -> ShardResponse {
match req {
ShardRequest::Get { key } => match ks.get_string(key) {
Ok(val) => ShardResponse::Value(val.map(Value::String)),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::Set {
key,
value,
expire,
nx,
xx,
} => match ks.set(key.clone(), value.clone(), *expire, *nx, *xx) {
SetResult::Ok => ShardResponse::Ok,
SetResult::Blocked => ShardResponse::Value(None),
SetResult::OutOfMemory => ShardResponse::OutOfMemory,
},
ShardRequest::Incr { key } => incr_result(ks.incr(key)),
ShardRequest::Decr { key } => incr_result(ks.decr(key)),
ShardRequest::IncrBy { key, delta } => incr_result(ks.incr_by(key, *delta)),
ShardRequest::DecrBy { key, delta } => match delta.checked_neg() {
Some(neg) => incr_result(ks.incr_by(key, neg)),
None => ShardResponse::Err("ERR increment or decrement would overflow".into()),
},
ShardRequest::IncrByFloat { key, delta } => match ks.incr_by_float(key, *delta) {
Ok(val) => ShardResponse::BulkString(val),
Err(IncrFloatError::WrongType) => ShardResponse::WrongType,
Err(IncrFloatError::OutOfMemory) => ShardResponse::OutOfMemory,
Err(e) => ShardResponse::Err(e.to_string()),
},
ShardRequest::Append { key, value } => write_result_len(ks.append(key, value)),
ShardRequest::Strlen { key } => match ks.strlen(key) {
Ok(len) => ShardResponse::Len(len),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::GetRange { key, start, end } => match ks.getrange(key, *start, *end) {
Ok(data) => ShardResponse::Value(Some(Value::String(data))),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::SetRange { key, offset, value } => {
write_result_len(ks.setrange(key, *offset, value))
}
ShardRequest::Keys { pattern } => {
let keys = ks.keys(pattern);
ShardResponse::StringArray(keys)
}
ShardRequest::Rename { key, newkey } => {
use crate::keyspace::RenameError;
match ks.rename(key, newkey) {
Ok(()) => ShardResponse::Ok,
Err(RenameError::NoSuchKey) => ShardResponse::Err("ERR no such key".into()),
}
}
ShardRequest::Copy {
source,
destination,
replace,
} => {
use crate::keyspace::CopyError;
match ks.copy(source, destination, *replace) {
Ok(copied) => ShardResponse::Bool(copied),
Err(CopyError::NoSuchKey) => ShardResponse::Err("ERR no such key".into()),
Err(CopyError::OutOfMemory) => ShardResponse::OutOfMemory,
}
}
ShardRequest::ObjectEncoding { key } => ShardResponse::EncodingName(ks.encoding(key)),
ShardRequest::Del { key } => ShardResponse::Bool(ks.del(key)),
ShardRequest::Unlink { key } => ShardResponse::Bool(ks.unlink(key)),
ShardRequest::Exists { key } => ShardResponse::Bool(ks.exists(key)),
ShardRequest::RandomKey => match ks.random_key() {
Some(k) => ShardResponse::StringArray(vec![k]),
None => ShardResponse::StringArray(vec![]),
},
ShardRequest::Touch { key } => ShardResponse::Bool(ks.touch(key)),
ShardRequest::Sort {
key,
desc,
alpha,
limit,
} => match ks.sort(key, *desc, *alpha, *limit) {
Ok(items) => ShardResponse::Array(items),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::Expire { key, seconds } => ShardResponse::Bool(ks.expire(key, *seconds)),
ShardRequest::Ttl { key } => ShardResponse::Ttl(ks.ttl(key)),
ShardRequest::Persist { key } => ShardResponse::Bool(ks.persist(key)),
ShardRequest::Pttl { key } => ShardResponse::Ttl(ks.pttl(key)),
ShardRequest::Pexpire { key, milliseconds } => {
ShardResponse::Bool(ks.pexpire(key, *milliseconds))
}
ShardRequest::LPush { key, values } => write_result_len(ks.lpush(key, values)),
ShardRequest::RPush { key, values } => write_result_len(ks.rpush(key, values)),
ShardRequest::LPop { key } => match ks.lpop(key) {
Ok(val) => ShardResponse::Value(val.map(Value::String)),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::RPop { key } => match ks.rpop(key) {
Ok(val) => ShardResponse::Value(val.map(Value::String)),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::LRange { key, start, stop } => match ks.lrange(key, *start, *stop) {
Ok(items) => ShardResponse::Array(items),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::LLen { key } => match ks.llen(key) {
Ok(len) => ShardResponse::Len(len),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::LIndex { key, index } => match ks.lindex(key, *index) {
Ok(val) => ShardResponse::Value(val.map(Value::String)),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::LSet { key, index, value } => match ks.lset(key, *index, value.clone()) {
Ok(()) => ShardResponse::Ok,
Err(e) => match e {
LsetError::WrongType => ShardResponse::WrongType,
LsetError::NoSuchKey => ShardResponse::Err("ERR no such key".into()),
LsetError::IndexOutOfRange => ShardResponse::Err("ERR index out of range".into()),
},
},
ShardRequest::LTrim { key, start, stop } => match ks.ltrim(key, *start, *stop) {
Ok(()) => ShardResponse::Ok,
Err(_) => ShardResponse::WrongType,
},
ShardRequest::LInsert {
key,
before,
pivot,
value,
} => match ks.linsert(key, *before, pivot, value.clone()) {
Ok(n) => ShardResponse::Integer(n),
Err(WriteError::WrongType) => ShardResponse::WrongType,
Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
},
ShardRequest::LRem { key, count, value } => match ks.lrem(key, *count, value) {
Ok(n) => ShardResponse::Len(n),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::LPos {
key,
element,
rank,
count,
maxlen,
} => match ks.lpos(key, element, *rank, *count, *maxlen) {
Ok(positions) => ShardResponse::IntegerArray(positions),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::Type { key } => ShardResponse::TypeName(ks.value_type(key)),
ShardRequest::ZAdd {
key,
members,
nx,
xx,
gt,
lt,
ch,
} => {
let flags = ZAddFlags {
nx: *nx,
xx: *xx,
gt: *gt,
lt: *lt,
ch: *ch,
};
match ks.zadd(key, members, &flags) {
Ok(result) => ShardResponse::ZAddLen {
count: result.count,
applied: result.applied,
},
Err(WriteError::WrongType) => ShardResponse::WrongType,
Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
}
}
ShardRequest::ZRem { key, members } => match ks.zrem(key, members) {
Ok(removed) => ShardResponse::ZRemLen {
count: removed.len(),
removed,
},
Err(_) => ShardResponse::WrongType,
},
ShardRequest::ZScore { key, member } => match ks.zscore(key, member) {
Ok(score) => ShardResponse::Score(score),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::ZRank { key, member } => match ks.zrank(key, member) {
Ok(rank) => ShardResponse::Rank(rank),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::ZCard { key } => match ks.zcard(key) {
Ok(len) => ShardResponse::Len(len),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::ZRevRank { key, member } => match ks.zrevrank(key, member) {
Ok(rank) => ShardResponse::Rank(rank),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::ZRange {
key, start, stop, ..
} => match ks.zrange(key, *start, *stop) {
Ok(items) => ShardResponse::ScoredArray(items),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::ZRevRange {
key, start, stop, ..
} => match ks.zrevrange(key, *start, *stop) {
Ok(items) => ShardResponse::ScoredArray(items),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::ZCount { key, min, max } => match ks.zcount(key, *min, *max) {
Ok(count) => ShardResponse::Len(count),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::ZIncrBy {
key,
increment,
member,
} => match ks.zincrby(key, *increment, member) {
Ok(new_score) => ShardResponse::ZIncrByResult {
new_score,
member: member.clone(),
},
Err(WriteError::WrongType) => ShardResponse::WrongType,
Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
},
ShardRequest::ZRangeByScore {
key,
min,
max,
offset,
count,
} => match ks.zrangebyscore(key, *min, *max, *offset, *count) {
Ok(items) => ShardResponse::ScoredArray(items),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::ZRevRangeByScore {
key,
min,
max,
offset,
count,
} => match ks.zrevrangebyscore(key, *min, *max, *offset, *count) {
Ok(items) => ShardResponse::ScoredArray(items),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::ZPopMin { key, count } => match ks.zpopmin(key, *count) {
Ok(items) => ShardResponse::ZPopResult(items),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::ZPopMax { key, count } => match ks.zpopmax(key, *count) {
Ok(items) => ShardResponse::ZPopResult(items),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::DbSize => ShardResponse::KeyCount(ks.len()),
ShardRequest::Stats => ShardResponse::Stats(ks.stats()),
ShardRequest::KeyVersion { ref key } => ShardResponse::Version(ks.key_version(key)),
ShardRequest::FlushDb => {
ks.clear();
ShardResponse::Ok
}
ShardRequest::Scan {
cursor,
count,
pattern,
} => {
let (next_cursor, keys) = ks.scan_keys(*cursor, *count, pattern.as_deref());
ShardResponse::Scan {
cursor: next_cursor,
keys,
}
}
ShardRequest::HSet { key, fields } => write_result_len(ks.hset(key, fields)),
ShardRequest::HGet { key, field } => match ks.hget(key, field) {
Ok(val) => ShardResponse::Value(val.map(Value::String)),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::HGetAll { key } => match ks.hgetall(key) {
Ok(fields) => ShardResponse::HashFields(fields),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::HDel { key, fields } => match ks.hdel(key, fields) {
Ok(removed) => ShardResponse::HDelLen {
count: removed.len(),
removed,
},
Err(_) => ShardResponse::WrongType,
},
ShardRequest::HExists { key, field } => match ks.hexists(key, field) {
Ok(exists) => ShardResponse::Bool(exists),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::HLen { key } => match ks.hlen(key) {
Ok(len) => ShardResponse::Len(len),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::HIncrBy { key, field, delta } => incr_result(ks.hincrby(key, field, *delta)),
ShardRequest::HKeys { key } => match ks.hkeys(key) {
Ok(keys) => ShardResponse::StringArray(keys),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::HVals { key } => match ks.hvals(key) {
Ok(vals) => ShardResponse::Array(vals),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::HMGet { key, fields } => match ks.hmget(key, fields) {
Ok(vals) => ShardResponse::OptionalArray(vals),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::SAdd { key, members } => write_result_len(ks.sadd(key, members)),
ShardRequest::SRem { key, members } => match ks.srem(key, members) {
Ok(count) => ShardResponse::Len(count),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::SMembers { key } => match ks.smembers(key) {
Ok(members) => ShardResponse::StringArray(members),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::SIsMember { key, member } => match ks.sismember(key, member) {
Ok(exists) => ShardResponse::Bool(exists),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::SCard { key } => match ks.scard(key) {
Ok(count) => ShardResponse::Len(count),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::SUnion { keys } => match ks.sunion(keys) {
Ok(members) => ShardResponse::StringArray(members),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::SInter { keys } => match ks.sinter(keys) {
Ok(members) => ShardResponse::StringArray(members),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::SDiff { keys } => match ks.sdiff(keys) {
Ok(members) => ShardResponse::StringArray(members),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::SUnionStore { dest, keys } => store_set_response(ks.sunionstore(dest, keys)),
ShardRequest::SInterStore { dest, keys } => store_set_response(ks.sinterstore(dest, keys)),
ShardRequest::SDiffStore { dest, keys } => store_set_response(ks.sdiffstore(dest, keys)),
ShardRequest::SRandMember { key, count } => match ks.srandmember(key, *count) {
Ok(members) => ShardResponse::StringArray(members),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::SPop { key, count } => match ks.spop(key, *count) {
Ok(members) => ShardResponse::StringArray(members),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::SMisMember { key, members } => match ks.smismember(key, members) {
Ok(results) => ShardResponse::BoolArray(results),
Err(_) => ShardResponse::WrongType,
},
ShardRequest::SScan {
key,
cursor,
count,
pattern,
} => match ks.scan_set(key, *cursor, *count, pattern.as_deref()) {
Ok((next, members)) => {
let items = members.into_iter().map(Bytes::from).collect();
ShardResponse::CollectionScan {
cursor: next,
items,
}
}
Err(_) => ShardResponse::WrongType,
},
ShardRequest::HScan {
key,
cursor,
count,
pattern,
} => match ks.scan_hash(key, *cursor, *count, pattern.as_deref()) {
Ok((next, fields)) => {
let mut items = Vec::with_capacity(fields.len() * 2);
for (field, value) in fields {
items.push(Bytes::from(field));
items.push(value);
}
ShardResponse::CollectionScan {
cursor: next,
items,
}
}
Err(_) => ShardResponse::WrongType,
},
ShardRequest::ZScan {
key,
cursor,
count,
pattern,
} => match ks.scan_sorted_set(key, *cursor, *count, pattern.as_deref()) {
Ok((next, members)) => {
let mut items = Vec::with_capacity(members.len() * 2);
for (member, score) in members {
items.push(Bytes::from(member));
items.push(Bytes::from(score.to_string()));
}
ShardResponse::CollectionScan {
cursor: next,
items,
}
}
Err(_) => ShardResponse::WrongType,
},
ShardRequest::CountKeysInSlot { slot } => {
ShardResponse::KeyCount(ks.count_keys_in_slot(*slot))
}
ShardRequest::GetKeysInSlot { slot, count } => {
ShardResponse::StringArray(ks.get_keys_in_slot(*slot, *count))
}
ShardRequest::DumpKey { key } => match ks.dump(key) {
Some((value, ttl_ms)) => {
let snap = persistence::value_to_snap(value);
match snapshot::serialize_snap_value(&snap) {
Ok(data) => ShardResponse::KeyDump { data, ttl_ms },
Err(e) => ShardResponse::Err(format!("ERR snapshot serialization failed: {e}")),
}
}
None => ShardResponse::Value(None),
},
ShardRequest::RestoreKey {
key,
ttl_ms,
data,
replace,
} => match snapshot::deserialize_snap_value(data) {
Ok(snap) => {
let exists = ks.exists(key);
if exists && !*replace {
ShardResponse::Err("ERR Target key name already exists".into())
} else {
let value = persistence::snap_to_value(snap);
let ttl = if *ttl_ms == 0 {
None
} else {
Some(Duration::from_millis(*ttl_ms))
};
ks.restore(key.clone(), value, ttl);
ShardResponse::Ok
}
}
Err(e) => ShardResponse::Err(format!("ERR DUMP payload corrupted: {e}")),
},
#[cfg(feature = "vector")]
ShardRequest::VAdd {
key,
element,
vector,
metric,
quantization,
connectivity,
expansion_add,
} => {
use crate::types::vector::{DistanceMetric, QuantizationType};
match ks.vadd(
key,
element.clone(),
vector.clone(),
DistanceMetric::from_u8(*metric),
QuantizationType::from_u8(*quantization),
*connectivity as usize,
*expansion_add as usize,
) {
Ok(result) => ShardResponse::VAddResult {
element: result.element,
vector: result.vector,
added: result.added,
},
Err(crate::keyspace::VectorWriteError::WrongType) => ShardResponse::WrongType,
Err(crate::keyspace::VectorWriteError::OutOfMemory) => ShardResponse::OutOfMemory,
Err(crate::keyspace::VectorWriteError::IndexError(e))
| Err(crate::keyspace::VectorWriteError::PartialBatch { message: e, .. }) => {
ShardResponse::Err(format!("ERR vector index: {e}"))
}
}
}
#[cfg(feature = "vector")]
ShardRequest::VAddBatch {
key,
entries,
metric,
quantization,
connectivity,
expansion_add,
..
} => {
use crate::types::vector::{DistanceMetric, QuantizationType};
let owned_entries = std::mem::take(entries);
match ks.vadd_batch(
key,
owned_entries,
DistanceMetric::from_u8(*metric),
QuantizationType::from_u8(*quantization),
*connectivity as usize,
*expansion_add as usize,
) {
Ok(result) => ShardResponse::VAddBatchResult {
added_count: result.added_count,
applied: result.applied,
},
Err(crate::keyspace::VectorWriteError::WrongType) => ShardResponse::WrongType,
Err(crate::keyspace::VectorWriteError::OutOfMemory) => ShardResponse::OutOfMemory,
Err(crate::keyspace::VectorWriteError::IndexError(e)) => {
ShardResponse::Err(format!("ERR vector index: {e}"))
}
Err(crate::keyspace::VectorWriteError::PartialBatch { applied, .. }) => {
ShardResponse::VAddBatchResult {
added_count: applied.len(),
applied,
}
}
}
}
#[cfg(feature = "vector")]
ShardRequest::VSim {
key,
query,
count,
ef_search,
} => match ks.vsim(key, query, *count, *ef_search) {
Ok(results) => ShardResponse::VSimResult(
results
.into_iter()
.map(|r| (r.element, r.distance))
.collect(),
),
Err(_) => ShardResponse::WrongType,
},
#[cfg(feature = "vector")]
ShardRequest::VRem { key, element } => match ks.vrem(key, element) {
Ok(removed) => ShardResponse::Bool(removed),
Err(_) => ShardResponse::WrongType,
},
#[cfg(feature = "vector")]
ShardRequest::VGet { key, element } => match ks.vget(key, element) {
Ok(data) => ShardResponse::VectorData(data),
Err(_) => ShardResponse::WrongType,
},
#[cfg(feature = "vector")]
ShardRequest::VCard { key } => match ks.vcard(key) {
Ok(count) => ShardResponse::Integer(count as i64),
Err(_) => ShardResponse::WrongType,
},
#[cfg(feature = "vector")]
ShardRequest::VDim { key } => match ks.vdim(key) {
Ok(dim) => ShardResponse::Integer(dim as i64),
Err(_) => ShardResponse::WrongType,
},
#[cfg(feature = "vector")]
ShardRequest::VInfo { key } => match ks.vinfo(key) {
Ok(Some(info)) => {
let fields = vec![
("dim".to_owned(), info.dim.to_string()),
("count".to_owned(), info.count.to_string()),
("metric".to_owned(), info.metric.to_string()),
("quantization".to_owned(), info.quantization.to_string()),
("connectivity".to_owned(), info.connectivity.to_string()),
("expansion_add".to_owned(), info.expansion_add.to_string()),
];
ShardResponse::VectorInfo(Some(fields))
}
Ok(None) => ShardResponse::VectorInfo(None),
Err(_) => ShardResponse::WrongType,
},
#[cfg(feature = "protobuf")]
ShardRequest::ProtoSet {
key,
type_name,
data,
expire,
nx,
xx,
} => {
if *nx && ks.exists(key) {
return ShardResponse::Value(None);
}
if *xx && !ks.exists(key) {
return ShardResponse::Value(None);
}
match ks.proto_set(key.clone(), type_name.clone(), data.clone(), *expire) {
SetResult::Ok | SetResult::Blocked => ShardResponse::Ok,
SetResult::OutOfMemory => ShardResponse::OutOfMemory,
}
}
#[cfg(feature = "protobuf")]
ShardRequest::ProtoGet { key } => match ks.proto_get(key) {
Ok(val) => ShardResponse::ProtoValue(val),
Err(_) => ShardResponse::WrongType,
},
#[cfg(feature = "protobuf")]
ShardRequest::ProtoType { key } => match ks.proto_type(key) {
Ok(name) => ShardResponse::ProtoTypeName(name),
Err(_) => ShardResponse::WrongType,
},
#[cfg(feature = "protobuf")]
ShardRequest::ProtoRegisterAof { .. } => ShardResponse::Ok,
#[cfg(feature = "protobuf")]
ShardRequest::ProtoSetField {
key,
field_path,
value,
} => dispatch_proto_field_op(ks, schema_registry, key, |reg, type_name, data, ttl| {
let new_data = reg.set_field(type_name, data, field_path, value)?;
Ok(ShardResponse::ProtoFieldUpdated {
type_name: type_name.to_owned(),
data: new_data,
expire: ttl,
})
}),
#[cfg(feature = "protobuf")]
ShardRequest::ProtoDelField { key, field_path } => {
dispatch_proto_field_op(ks, schema_registry, key, |reg, type_name, data, ttl| {
let new_data = reg.clear_field(type_name, data, field_path)?;
Ok(ShardResponse::ProtoFieldUpdated {
type_name: type_name.to_owned(),
data: new_data,
expire: ttl,
})
})
}
ShardRequest::Snapshot
| ShardRequest::SerializeSnapshot
| ShardRequest::RewriteAof
| ShardRequest::FlushDbAsync
| ShardRequest::BLPop { .. }
| ShardRequest::BRPop { .. } => ShardResponse::Ok,
}
}
/// Applies a read-modify-write operation to one field of a stored protobuf value.
///
/// Reads `key` via `Keyspace::proto_get`, then calls `mutate` with the schema
/// registry (under its read lock), the stored type name, the encoded bytes and the
/// remaining TTL. When `mutate` returns `ShardResponse::ProtoFieldUpdated`, the new
/// bytes are written back with `Keyspace::proto_set`, carrying over `expire`.
///
/// Returns:
/// - `ShardResponse::Err` when no schema registry is configured, the registry lock
///   is poisoned, or `mutate` fails (the `SchemaError` text is forwarded);
/// - `ShardResponse::Value(None)` when the key does not exist;
/// - `ShardResponse::WrongType` when `proto_get` reports an error for the key;
/// - `ShardResponse::OutOfMemory` when the write-back is rejected for memory;
/// - otherwise the response produced by `mutate`.
#[cfg(feature = "protobuf")]
fn dispatch_proto_field_op<F>(
    ks: &mut Keyspace,
    schema_registry: &Option<crate::schema::SharedSchemaRegistry>,
    key: &str,
    mutate: F,
) -> ShardResponse
where
    F: FnOnce(
        &crate::schema::SchemaRegistry,
        &str,
        &[u8],
        Option<Duration>,
    ) -> Result<ShardResponse, crate::schema::SchemaError>,
{
    let registry = match schema_registry {
        Some(r) => r,
        None => return ShardResponse::Err("protobuf support is not enabled".into()),
    };
    let (type_name, data, remaining_ttl) = match ks.proto_get(key) {
        Ok(Some(tuple)) => tuple,
        Ok(None) => return ShardResponse::Value(None),
        Err(_) => return ShardResponse::WrongType,
    };
    // Keep the registry read guard confined to the mutation itself so the lock
    // is released before the keyspace is written.
    let resp = {
        let reg = match registry.read() {
            Ok(r) => r,
            Err(_) => return ShardResponse::Err("schema registry lock poisoned".into()),
        };
        match mutate(&reg, &type_name, &data, remaining_ttl) {
            Ok(r) => r,
            Err(e) => return ShardResponse::Err(e.to_string()),
        }
    };
    if let ShardResponse::ProtoFieldUpdated {
        ref type_name,
        ref data,
        expire,
    } = resp
    {
        // Previously the write-back result was discarded, so an OOM rejection still
        // reported the field as updated. Surface it instead. `Blocked` is treated
        // as success, consistent with the ProtoSet dispatch arm.
        let stored = ks.proto_set(key.to_owned(), type_name.clone(), data.clone(), expire);
        if matches!(stored, SetResult::OutOfMemory) {
            return ShardResponse::OutOfMemory;
        }
    }
    resp
}
/// Unit tests for the shard request dispatcher. Each test drives `dispatch`
/// through `test_dispatch` against a fresh `Keyspace` and checks both the
/// returned `ShardResponse` and any resulting keyspace state.
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `dispatch` on an owned request. When the `protobuf` feature is
    /// enabled, no schema registry is supplied.
    fn test_dispatch(ks: &mut Keyspace, mut req: ShardRequest) -> ShardResponse {
        dispatch(
            ks,
            &mut req,
            #[cfg(feature = "protobuf")]
            &None,
        )
    }

    #[test]
    fn dispatch_set_and_get() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Set {
                key: "k".into(),
                value: Bytes::from("v"),
                expire: None,
                nx: false,
                xx: false,
            },
        );
        assert!(matches!(resp, ShardResponse::Ok));
        let resp = test_dispatch(&mut ks, ShardRequest::Get { key: "k".into() });
        match resp {
            ShardResponse::Value(Some(Value::String(data))) => {
                assert_eq!(data, Bytes::from("v"));
            }
            other => panic!("expected Value(Some(String)), got {other:?}"),
        }
    }

    #[test]
    fn dispatch_get_missing() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(&mut ks, ShardRequest::Get { key: "nope".into() });
        assert!(matches!(resp, ShardResponse::Value(None)));
    }

    #[test]
    fn dispatch_del() {
        let mut ks = Keyspace::new();
        ks.set("key".into(), Bytes::from("val"), None, false, false);
        let resp = test_dispatch(&mut ks, ShardRequest::Del { key: "key".into() });
        assert!(matches!(resp, ShardResponse::Bool(true)));
        // A second delete finds nothing to remove.
        let resp = test_dispatch(&mut ks, ShardRequest::Del { key: "key".into() });
        assert!(matches!(resp, ShardResponse::Bool(false)));
    }

    #[test]
    fn dispatch_exists() {
        let mut ks = Keyspace::new();
        ks.set("yes".into(), Bytes::from("here"), None, false, false);
        let resp = test_dispatch(&mut ks, ShardRequest::Exists { key: "yes".into() });
        assert!(matches!(resp, ShardResponse::Bool(true)));
        let resp = test_dispatch(&mut ks, ShardRequest::Exists { key: "no".into() });
        assert!(matches!(resp, ShardResponse::Bool(false)));
    }

    #[test]
    fn dispatch_expire_and_ttl() {
        let mut ks = Keyspace::new();
        ks.set("key".into(), Bytes::from("val"), None, false, false);
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Expire {
                key: "key".into(),
                seconds: 60,
            },
        );
        assert!(matches!(resp, ShardResponse::Bool(true)));
        let resp = test_dispatch(&mut ks, ShardRequest::Ttl { key: "key".into() });
        match resp {
            // Tolerate a little elapsed time between EXPIRE and TTL.
            ShardResponse::Ttl(TtlResult::Seconds(s)) => assert!((58..=60).contains(&s)),
            other => panic!("expected Ttl(Seconds), got {other:?}"),
        }
    }

    #[test]
    fn dispatch_ttl_missing() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(&mut ks, ShardRequest::Ttl { key: "gone".into() });
        assert!(matches!(resp, ShardResponse::Ttl(TtlResult::NotFound)));
    }

    #[test]
    fn dispatch_incr_new_key() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(&mut ks, ShardRequest::Incr { key: "c".into() });
        assert!(matches!(resp, ShardResponse::Integer(1)));
    }

    #[test]
    fn dispatch_decr_existing() {
        let mut ks = Keyspace::new();
        ks.set("n".into(), Bytes::from("10"), None, false, false);
        let resp = test_dispatch(&mut ks, ShardRequest::Decr { key: "n".into() });
        assert!(matches!(resp, ShardResponse::Integer(9)));
    }

    #[test]
    fn dispatch_incr_non_integer() {
        let mut ks = Keyspace::new();
        ks.set("s".into(), Bytes::from("hello"), None, false, false);
        let resp = test_dispatch(&mut ks, ShardRequest::Incr { key: "s".into() });
        assert!(matches!(resp, ShardResponse::Err(_)));
    }

    #[test]
    fn dispatch_incrby() {
        let mut ks = Keyspace::new();
        ks.set("n".into(), Bytes::from("10"), None, false, false);
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::IncrBy {
                key: "n".into(),
                delta: 5,
            },
        );
        assert!(matches!(resp, ShardResponse::Integer(15)));
    }

    #[test]
    fn dispatch_decrby() {
        let mut ks = Keyspace::new();
        ks.set("n".into(), Bytes::from("10"), None, false, false);
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::DecrBy {
                key: "n".into(),
                delta: 3,
            },
        );
        assert!(matches!(resp, ShardResponse::Integer(7)));
    }

    #[test]
    fn dispatch_incrby_new_key() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::IncrBy {
                key: "new".into(),
                delta: 42,
            },
        );
        assert!(matches!(resp, ShardResponse::Integer(42)));
    }

    #[test]
    fn dispatch_incrbyfloat() {
        let mut ks = Keyspace::new();
        ks.set("n".into(), Bytes::from("10.5"), None, false, false);
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::IncrByFloat {
                key: "n".into(),
                delta: 2.3,
            },
        );
        match resp {
            ShardResponse::BulkString(val) => {
                // Compare with a tolerance; the value is a formatted float.
                let f: f64 = val.parse().unwrap();
                assert!((f - 12.8).abs() < 0.001);
            }
            other => panic!("expected BulkString, got {other:?}"),
        }
    }

    #[test]
    fn dispatch_append() {
        let mut ks = Keyspace::new();
        ks.set("k".into(), Bytes::from("hello"), None, false, false);
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Append {
                key: "k".into(),
                value: Bytes::from(" world"),
            },
        );
        assert!(matches!(resp, ShardResponse::Len(11)));
    }

    #[test]
    fn dispatch_strlen() {
        let mut ks = Keyspace::new();
        ks.set("k".into(), Bytes::from("hello"), None, false, false);
        let resp = test_dispatch(&mut ks, ShardRequest::Strlen { key: "k".into() });
        assert!(matches!(resp, ShardResponse::Len(5)));
    }

    #[test]
    fn dispatch_strlen_missing() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(&mut ks, ShardRequest::Strlen { key: "nope".into() });
        assert!(matches!(resp, ShardResponse::Len(0)));
    }

    #[test]
    fn dispatch_incrbyfloat_new_key() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::IncrByFloat {
                key: "new".into(),
                delta: 2.72,
            },
        );
        match resp {
            ShardResponse::BulkString(val) => {
                let f: f64 = val.parse().unwrap();
                assert!((f - 2.72).abs() < 0.001);
            }
            other => panic!("expected BulkString, got {other:?}"),
        }
    }

    #[test]
    fn dispatch_persist_removes_ttl() {
        let mut ks = Keyspace::new();
        ks.set(
            "key".into(),
            Bytes::from("val"),
            Some(Duration::from_secs(60)),
            false,
            false,
        );
        let resp = test_dispatch(&mut ks, ShardRequest::Persist { key: "key".into() });
        assert!(matches!(resp, ShardResponse::Bool(true)));
        let resp = test_dispatch(&mut ks, ShardRequest::Ttl { key: "key".into() });
        assert!(matches!(resp, ShardResponse::Ttl(TtlResult::NoExpiry)));
    }

    #[test]
    fn dispatch_persist_missing_key() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(&mut ks, ShardRequest::Persist { key: "nope".into() });
        assert!(matches!(resp, ShardResponse::Bool(false)));
    }

    #[test]
    fn dispatch_pttl() {
        let mut ks = Keyspace::new();
        ks.set(
            "key".into(),
            Bytes::from("val"),
            Some(Duration::from_secs(60)),
            false,
            false,
        );
        let resp = test_dispatch(&mut ks, ShardRequest::Pttl { key: "key".into() });
        match resp {
            ShardResponse::Ttl(TtlResult::Milliseconds(ms)) => {
                assert!(ms > 59_000 && ms <= 60_000);
            }
            other => panic!("expected Ttl(Milliseconds), got {other:?}"),
        }
    }

    #[test]
    fn dispatch_pttl_missing() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(&mut ks, ShardRequest::Pttl { key: "nope".into() });
        assert!(matches!(resp, ShardResponse::Ttl(TtlResult::NotFound)));
    }

    #[test]
    fn dispatch_pexpire() {
        let mut ks = Keyspace::new();
        ks.set("key".into(), Bytes::from("val"), None, false, false);
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Pexpire {
                key: "key".into(),
                milliseconds: 5000,
            },
        );
        assert!(matches!(resp, ShardResponse::Bool(true)));
        let resp = test_dispatch(&mut ks, ShardRequest::Pttl { key: "key".into() });
        match resp {
            ShardResponse::Ttl(TtlResult::Milliseconds(ms)) => {
                assert!(ms > 4000 && ms <= 5000);
            }
            other => panic!("expected Ttl(Milliseconds), got {other:?}"),
        }
    }

    #[test]
    fn dispatch_set_nx_when_key_missing() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Set {
                key: "k".into(),
                value: Bytes::from("v"),
                expire: None,
                nx: true,
                xx: false,
            },
        );
        assert!(matches!(resp, ShardResponse::Ok));
        assert!(ks.exists("k"));
    }

    #[test]
    fn dispatch_set_nx_when_key_exists() {
        let mut ks = Keyspace::new();
        ks.set("k".into(), Bytes::from("old"), None, false, false);
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Set {
                key: "k".into(),
                value: Bytes::from("new"),
                expire: None,
                nx: true,
                xx: false,
            },
        );
        // NX on an existing key is a no-op reported as a nil reply.
        assert!(matches!(resp, ShardResponse::Value(None)));
        match ks.get("k").unwrap() {
            Some(Value::String(data)) => assert_eq!(data, Bytes::from("old")),
            other => panic!("expected old value, got {other:?}"),
        }
    }

    #[test]
    fn dispatch_set_xx_when_key_exists() {
        let mut ks = Keyspace::new();
        ks.set("k".into(), Bytes::from("old"), None, false, false);
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Set {
                key: "k".into(),
                value: Bytes::from("new"),
                expire: None,
                nx: false,
                xx: true,
            },
        );
        assert!(matches!(resp, ShardResponse::Ok));
        match ks.get("k").unwrap() {
            Some(Value::String(data)) => assert_eq!(data, Bytes::from("new")),
            other => panic!("expected new value, got {other:?}"),
        }
    }

    #[test]
    fn dispatch_set_xx_when_key_missing() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Set {
                key: "k".into(),
                value: Bytes::from("v"),
                expire: None,
                nx: false,
                xx: true,
            },
        );
        // XX on a missing key is a no-op reported as a nil reply.
        assert!(matches!(resp, ShardResponse::Value(None)));
        assert!(!ks.exists("k"));
    }

    #[test]
    fn dispatch_flushdb_clears_all_keys() {
        let mut ks = Keyspace::new();
        ks.set("a".into(), Bytes::from("1"), None, false, false);
        ks.set("b".into(), Bytes::from("2"), None, false, false);
        assert_eq!(ks.len(), 2);
        let resp = test_dispatch(&mut ks, ShardRequest::FlushDb);
        assert!(matches!(resp, ShardResponse::Ok));
        assert_eq!(ks.len(), 0);
    }

    #[test]
    fn dispatch_scan_returns_keys() {
        let mut ks = Keyspace::new();
        ks.set("user:1".into(), Bytes::from("a"), None, false, false);
        ks.set("user:2".into(), Bytes::from("b"), None, false, false);
        ks.set("item:1".into(), Bytes::from("c"), None, false, false);
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Scan {
                cursor: 0,
                count: 10,
                pattern: None,
            },
        );
        match resp {
            ShardResponse::Scan { cursor, keys } => {
                // count exceeds the number of keys, so one pass is expected to
                // return everything and a terminal cursor of 0.
                assert_eq!(cursor, 0);
                assert_eq!(keys.len(), 3);
            }
            other => panic!("expected Scan response, got {other:?}"),
        }
    }

    #[test]
    fn dispatch_scan_with_pattern() {
        let mut ks = Keyspace::new();
        ks.set("user:1".into(), Bytes::from("a"), None, false, false);
        ks.set("user:2".into(), Bytes::from("b"), None, false, false);
        ks.set("item:1".into(), Bytes::from("c"), None, false, false);
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Scan {
                cursor: 0,
                count: 10,
                pattern: Some("user:*".into()),
            },
        );
        match resp {
            ShardResponse::Scan { cursor, keys } => {
                assert_eq!(cursor, 0);
                assert_eq!(keys.len(), 2);
                for k in &keys {
                    assert!(k.starts_with("user:"));
                }
            }
            other => panic!("expected Scan response, got {other:?}"),
        }
    }

    #[test]
    fn dispatch_keys() {
        let mut ks = Keyspace::new();
        ks.set("user:1".into(), Bytes::from("a"), None, false, false);
        ks.set("user:2".into(), Bytes::from("b"), None, false, false);
        ks.set("item:1".into(), Bytes::from("c"), None, false, false);
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Keys {
                pattern: "user:*".into(),
            },
        );
        match resp {
            ShardResponse::StringArray(mut keys) => {
                // Key order is not guaranteed; sort before comparing.
                keys.sort();
                assert_eq!(keys, vec!["user:1", "user:2"]);
            }
            other => panic!("expected StringArray, got {other:?}"),
        }
    }

    #[test]
    fn dispatch_rename() {
        let mut ks = Keyspace::new();
        ks.set("old".into(), Bytes::from("value"), None, false, false);
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Rename {
                key: "old".into(),
                newkey: "new".into(),
            },
        );
        assert!(matches!(resp, ShardResponse::Ok));
        assert!(!ks.exists("old"));
        assert!(ks.exists("new"));
    }

    #[test]
    fn dispatch_rename_missing_key() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::Rename {
                key: "missing".into(),
                newkey: "new".into(),
            },
        );
        assert!(matches!(resp, ShardResponse::Err(_)));
    }

    #[test]
    fn dump_key_returns_serialized_value() {
        let mut ks = Keyspace::new();
        ks.set(
            "greeting".into(),
            Bytes::from("hello"),
            Some(Duration::from_secs(60)),
            false,
            false,
        );
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::DumpKey {
                key: "greeting".into(),
            },
        );
        match resp {
            ShardResponse::KeyDump { data, ttl_ms } => {
                assert!(!data.is_empty());
                assert!(ttl_ms > 0);
                // The payload must decode back to the original string value.
                let snap = snapshot::deserialize_snap_value(&data).unwrap();
                assert!(matches!(snap, SnapValue::String(ref b) if b == &Bytes::from("hello")));
            }
            other => panic!("expected KeyDump, got {other:?}"),
        }
    }

    #[test]
    fn dump_key_missing_returns_none() {
        let mut ks = Keyspace::new();
        let resp = test_dispatch(&mut ks, ShardRequest::DumpKey { key: "nope".into() });
        assert!(matches!(resp, ShardResponse::Value(None)));
    }

    #[test]
    fn restore_key_inserts_value() {
        let mut ks = Keyspace::new();
        let snap = SnapValue::String(Bytes::from("restored"));
        let data = snapshot::serialize_snap_value(&snap).unwrap();
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::RestoreKey {
                key: "mykey".into(),
                ttl_ms: 0,
                data: Bytes::from(data),
                replace: false,
            },
        );
        assert!(matches!(resp, ShardResponse::Ok));
        assert_eq!(
            ks.get("mykey").unwrap(),
            Some(Value::String(Bytes::from("restored")))
        );
    }

    #[test]
    fn restore_key_with_ttl() {
        let mut ks = Keyspace::new();
        let snap = SnapValue::String(Bytes::from("temp"));
        let data = snapshot::serialize_snap_value(&snap).unwrap();
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::RestoreKey {
                key: "ttlkey".into(),
                ttl_ms: 30_000,
                data: Bytes::from(data),
                replace: false,
            },
        );
        assert!(matches!(resp, ShardResponse::Ok));
        match ks.pttl("ttlkey") {
            TtlResult::Milliseconds(ms) => assert!(ms > 29_000 && ms <= 30_000),
            other => panic!("expected Milliseconds, got {other:?}"),
        }
    }

    #[test]
    fn restore_key_rejects_duplicate_without_replace() {
        let mut ks = Keyspace::new();
        ks.set("existing".into(), Bytes::from("old"), None, false, false);
        let snap = SnapValue::String(Bytes::from("new"));
        let data = snapshot::serialize_snap_value(&snap).unwrap();
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::RestoreKey {
                key: "existing".into(),
                ttl_ms: 0,
                data: Bytes::from(data),
                replace: false,
            },
        );
        assert!(matches!(resp, ShardResponse::Err(_)));
        // The existing value must be left untouched.
        assert_eq!(
            ks.get("existing").unwrap(),
            Some(Value::String(Bytes::from("old")))
        );
    }

    #[test]
    fn restore_key_replace_overwrites() {
        let mut ks = Keyspace::new();
        ks.set("existing".into(), Bytes::from("old"), None, false, false);
        let snap = SnapValue::String(Bytes::from("new"));
        let data = snapshot::serialize_snap_value(&snap).unwrap();
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::RestoreKey {
                key: "existing".into(),
                ttl_ms: 0,
                data: Bytes::from(data),
                replace: true,
            },
        );
        assert!(matches!(resp, ShardResponse::Ok));
        assert_eq!(
            ks.get("existing").unwrap(),
            Some(Value::String(Bytes::from("new")))
        );
    }

    #[test]
    fn dump_and_restore_hash_roundtrip() {
        let mut ks = Keyspace::new();
        ks.hset(
            "myhash",
            &[
                ("f1".into(), Bytes::from("v1")),
                ("f2".into(), Bytes::from("v2")),
            ],
        )
        .unwrap();
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::DumpKey {
                key: "myhash".into(),
            },
        );
        let (data, _ttl) = match resp {
            ShardResponse::KeyDump { data, ttl_ms } => (data, ttl_ms),
            other => panic!("expected KeyDump, got {other:?}"),
        };
        // Restore under a different name so both copies coexist.
        let resp = test_dispatch(
            &mut ks,
            ShardRequest::RestoreKey {
                key: "myhash2".into(),
                ttl_ms: 0,
                data: Bytes::from(data),
                replace: false,
            },
        );
        assert!(matches!(resp, ShardResponse::Ok));
        assert_eq!(ks.hget("myhash2", "f1").unwrap(), Some(Bytes::from("v1")));
        assert_eq!(ks.hget("myhash2", "f2").unwrap(), Some(Bytes::from("v2")));
    }

    #[test]
    fn is_write_classifies_correctly() {
        // Mutating requests.
        assert!(ShardRequest::Set {
            key: "k".into(),
            value: Bytes::from("v"),
            expire: None,
            nx: false,
            xx: false,
        }
        .is_write());
        assert!(ShardRequest::Del { key: "k".into() }.is_write());
        assert!(ShardRequest::Incr { key: "k".into() }.is_write());
        assert!(ShardRequest::LPush {
            key: "k".into(),
            values: vec![],
        }
        .is_write());
        assert!(ShardRequest::HSet {
            key: "k".into(),
            fields: vec![],
        }
        .is_write());
        assert!(ShardRequest::SAdd {
            key: "k".into(),
            members: vec![],
        }
        .is_write());
        assert!(ShardRequest::FlushDb.is_write());
        // Read-only requests.
        assert!(!ShardRequest::Get { key: "k".into() }.is_write());
        assert!(!ShardRequest::Exists { key: "k".into() }.is_write());
        assert!(!ShardRequest::Ttl { key: "k".into() }.is_write());
        assert!(!ShardRequest::DbSize.is_write());
        assert!(!ShardRequest::Stats.is_write());
        assert!(!ShardRequest::LLen { key: "k".into() }.is_write());
        assert!(!ShardRequest::HGet {
            key: "k".into(),
            field: "f".into(),
        }
        .is_write());
        assert!(!ShardRequest::SMembers { key: "k".into() }.is_write());
    }
}