Skip to main content

ember_core/
shard.rs

1//! Shard: an independent partition of the keyspace.
2//!
3//! Each shard runs as its own tokio task, owning a `Keyspace` with no
4//! internal locking. Commands arrive over an mpsc channel and responses
5//! go back on a per-request oneshot. A background tick drives active
6//! expiration of TTL'd keys.
7
8use std::path::PathBuf;
9use std::time::Duration;
10
11use bytes::Bytes;
12use ember_persistence::aof::{AofRecord, AofWriter, FsyncPolicy};
13use ember_persistence::recovery::{self, RecoveredValue};
14use ember_persistence::snapshot::{self, SnapEntry, SnapValue, SnapshotWriter};
15use tokio::sync::{mpsc, oneshot};
16use tracing::{info, warn};
17
18use crate::error::ShardError;
19use crate::expiry;
20use crate::keyspace::{
21    IncrError, Keyspace, KeyspaceStats, SetResult, ShardConfig, TtlResult, WriteError,
22};
23use crate::types::sorted_set::ZAddFlags;
24use crate::types::Value;
25
/// Interval between active-expiration passes on a shard.
///
/// 100ms corresponds to Redis's default `hz=10` and keeps the CPU cost
/// of the background sweep negligible.
const EXPIRY_TICK: Duration = Duration::from_millis(100);
29
/// Interval between AOF fsyncs when the shard runs with the `EverySec`
/// fsync policy.
const FSYNC_INTERVAL: Duration = Duration::from_secs(1);
32
33/// Optional persistence configuration for a shard.
34#[derive(Debug, Clone)]
35pub struct ShardPersistenceConfig {
36    /// Directory where AOF and snapshot files live.
37    pub data_dir: PathBuf,
38    /// Whether to write an AOF log of mutations.
39    pub append_only: bool,
40    /// When to fsync the AOF file.
41    pub fsync_policy: FsyncPolicy,
42}
43
44/// A protocol-agnostic command sent to a shard.
45#[derive(Debug)]
46pub enum ShardRequest {
47    Get {
48        key: String,
49    },
50    Set {
51        key: String,
52        value: Bytes,
53        expire: Option<Duration>,
54        /// Only set the key if it does not already exist.
55        nx: bool,
56        /// Only set the key if it already exists.
57        xx: bool,
58    },
59    Incr {
60        key: String,
61    },
62    Decr {
63        key: String,
64    },
65    Del {
66        key: String,
67    },
68    Exists {
69        key: String,
70    },
71    Expire {
72        key: String,
73        seconds: u64,
74    },
75    Ttl {
76        key: String,
77    },
78    Persist {
79        key: String,
80    },
81    Pttl {
82        key: String,
83    },
84    Pexpire {
85        key: String,
86        milliseconds: u64,
87    },
88    LPush {
89        key: String,
90        values: Vec<Bytes>,
91    },
92    RPush {
93        key: String,
94        values: Vec<Bytes>,
95    },
96    LPop {
97        key: String,
98    },
99    RPop {
100        key: String,
101    },
102    LRange {
103        key: String,
104        start: i64,
105        stop: i64,
106    },
107    LLen {
108        key: String,
109    },
110    Type {
111        key: String,
112    },
113    ZAdd {
114        key: String,
115        members: Vec<(f64, String)>,
116        nx: bool,
117        xx: bool,
118        gt: bool,
119        lt: bool,
120        ch: bool,
121    },
122    ZRem {
123        key: String,
124        members: Vec<String>,
125    },
126    ZScore {
127        key: String,
128        member: String,
129    },
130    ZRank {
131        key: String,
132        member: String,
133    },
134    ZCard {
135        key: String,
136    },
137    ZRange {
138        key: String,
139        start: i64,
140        stop: i64,
141        with_scores: bool,
142    },
143    HSet {
144        key: String,
145        fields: Vec<(String, Bytes)>,
146    },
147    HGet {
148        key: String,
149        field: String,
150    },
151    HGetAll {
152        key: String,
153    },
154    HDel {
155        key: String,
156        fields: Vec<String>,
157    },
158    HExists {
159        key: String,
160        field: String,
161    },
162    HLen {
163        key: String,
164    },
165    HIncrBy {
166        key: String,
167        field: String,
168        delta: i64,
169    },
170    HKeys {
171        key: String,
172    },
173    HVals {
174        key: String,
175    },
176    HMGet {
177        key: String,
178        fields: Vec<String>,
179    },
180    SAdd {
181        key: String,
182        members: Vec<String>,
183    },
184    SRem {
185        key: String,
186        members: Vec<String>,
187    },
188    SMembers {
189        key: String,
190    },
191    SIsMember {
192        key: String,
193        member: String,
194    },
195    SCard {
196        key: String,
197    },
198    /// Returns the key count for this shard.
199    DbSize,
200    /// Returns keyspace stats for this shard.
201    Stats,
202    /// Triggers a snapshot write.
203    Snapshot,
204    /// Triggers an AOF rewrite (snapshot + truncate AOF).
205    RewriteAof,
206    /// Clears all keys from the keyspace.
207    FlushDb,
208    /// Scans keys in the keyspace.
209    Scan {
210        cursor: u64,
211        count: usize,
212        pattern: Option<String>,
213    },
214}
215
216/// The shard's response to a request.
217#[derive(Debug)]
218pub enum ShardResponse {
219    /// A value (or None for a cache miss).
220    Value(Option<Value>),
221    /// Simple acknowledgement (e.g. SET).
222    Ok,
223    /// Integer result (e.g. INCR, DECR).
224    Integer(i64),
225    /// Boolean result (e.g. DEL, EXISTS, EXPIRE).
226    Bool(bool),
227    /// TTL query result.
228    Ttl(TtlResult),
229    /// Memory limit reached and eviction policy is NoEviction.
230    OutOfMemory,
231    /// Key count for a shard (DBSIZE).
232    KeyCount(usize),
233    /// Full stats for a shard (INFO).
234    Stats(KeyspaceStats),
235    /// Integer length result (e.g. LPUSH, RPUSH, LLEN).
236    Len(usize),
237    /// Array of bulk values (e.g. LRANGE).
238    Array(Vec<Bytes>),
239    /// The type name of a stored value.
240    TypeName(&'static str),
241    /// ZADD result: count for the client + actually applied members for AOF.
242    ZAddLen {
243        count: usize,
244        applied: Vec<(f64, String)>,
245    },
246    /// ZREM result: count for the client + actually removed members for AOF.
247    ZRemLen { count: usize, removed: Vec<String> },
248    /// Float score result (e.g. ZSCORE).
249    Score(Option<f64>),
250    /// Rank result (e.g. ZRANK).
251    Rank(Option<usize>),
252    /// Scored array of (member, score) pairs (e.g. ZRANGE).
253    ScoredArray(Vec<(String, f64)>),
254    /// Command used against a key holding the wrong kind of value.
255    WrongType,
256    /// An error message.
257    Err(String),
258    /// Scan result: next cursor and list of keys.
259    Scan { cursor: u64, keys: Vec<String> },
260    /// HGETALL result: all field-value pairs.
261    HashFields(Vec<(String, Bytes)>),
262    /// HDEL result: removed count + field names for AOF.
263    HDelLen { count: usize, removed: Vec<String> },
264    /// Array of strings (e.g. HKEYS).
265    StringArray(Vec<String>),
266    /// HMGET result: array of optional values.
267    OptionalArray(Vec<Option<Bytes>>),
268}
269
270/// A request bundled with its reply channel.
271#[derive(Debug)]
272pub struct ShardMessage {
273    pub request: ShardRequest,
274    pub reply: oneshot::Sender<ShardResponse>,
275}
276
277/// A cloneable handle for sending commands to a shard task.
278///
279/// Wraps the mpsc sender so callers don't need to manage oneshot
280/// channels directly.
281#[derive(Debug, Clone)]
282pub struct ShardHandle {
283    tx: mpsc::Sender<ShardMessage>,
284}
285
286impl ShardHandle {
287    /// Sends a request and waits for the response.
288    ///
289    /// Returns `ShardError::Unavailable` if the shard task has stopped.
290    pub async fn send(&self, request: ShardRequest) -> Result<ShardResponse, ShardError> {
291        let rx = self.dispatch(request).await?;
292        rx.await.map_err(|_| ShardError::Unavailable)
293    }
294
295    /// Sends a request and returns the reply channel without waiting
296    /// for the response. Used by `Engine::broadcast` to fan out to
297    /// all shards before collecting results.
298    pub(crate) async fn dispatch(
299        &self,
300        request: ShardRequest,
301    ) -> Result<oneshot::Receiver<ShardResponse>, ShardError> {
302        let (reply_tx, reply_rx) = oneshot::channel();
303        let msg = ShardMessage {
304            request,
305            reply: reply_tx,
306        };
307        self.tx
308            .send(msg)
309            .await
310            .map_err(|_| ShardError::Unavailable)?;
311        Ok(reply_rx)
312    }
313}
314
315/// Spawns a shard task and returns the handle for communicating with it.
316///
317/// `buffer` controls the mpsc channel capacity — higher values absorb
318/// burst traffic at the cost of memory.
319pub fn spawn_shard(
320    buffer: usize,
321    config: ShardConfig,
322    persistence: Option<ShardPersistenceConfig>,
323) -> ShardHandle {
324    let (tx, rx) = mpsc::channel(buffer);
325    tokio::spawn(run_shard(rx, config, persistence));
326    ShardHandle { tx }
327}
328
329/// The shard's main loop. Processes messages and runs periodic
330/// active expiration until the channel closes.
331async fn run_shard(
332    mut rx: mpsc::Receiver<ShardMessage>,
333    config: ShardConfig,
334    persistence: Option<ShardPersistenceConfig>,
335) {
336    let shard_id = config.shard_id;
337    let mut keyspace = Keyspace::with_config(config);
338
339    // -- recovery --
340    if let Some(ref pcfg) = persistence {
341        let result = recovery::recover_shard(&pcfg.data_dir, shard_id);
342        let count = result.entries.len();
343        for entry in result.entries {
344            let value = match entry.value {
345                RecoveredValue::String(data) => Value::String(data),
346                RecoveredValue::List(deque) => Value::List(deque),
347                RecoveredValue::SortedSet(members) => {
348                    let mut ss = crate::types::sorted_set::SortedSet::new();
349                    for (score, member) in members {
350                        ss.add(member, score);
351                    }
352                    Value::SortedSet(ss)
353                }
354                RecoveredValue::Hash(map) => Value::Hash(map),
355                RecoveredValue::Set(set) => Value::Set(set),
356            };
357            keyspace.restore(entry.key, value, entry.ttl);
358        }
359        if count > 0 {
360            info!(
361                shard_id,
362                recovered_keys = count,
363                snapshot = result.loaded_snapshot,
364                aof = result.replayed_aof,
365                "recovered shard state"
366            );
367        }
368    }
369
370    // -- AOF writer --
371    let mut aof_writer: Option<AofWriter> = match &persistence {
372        Some(pcfg) if pcfg.append_only => {
373            let path = ember_persistence::aof::aof_path(&pcfg.data_dir, shard_id);
374            match AofWriter::open(path) {
375                Ok(w) => Some(w),
376                Err(e) => {
377                    warn!(shard_id, "failed to open AOF writer: {e}");
378                    None
379                }
380            }
381        }
382        _ => None,
383    };
384
385    let fsync_policy = persistence
386        .as_ref()
387        .map(|p| p.fsync_policy)
388        .unwrap_or(FsyncPolicy::No);
389
390    // -- tickers --
391    let mut expiry_tick = tokio::time::interval(EXPIRY_TICK);
392    expiry_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
393
394    let mut fsync_tick = tokio::time::interval(FSYNC_INTERVAL);
395    fsync_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
396
397    loop {
398        tokio::select! {
399            msg = rx.recv() => {
400                match msg {
401                    Some(msg) => {
402                        let request_kind = describe_request(&msg.request);
403                        let response = dispatch(&mut keyspace, &msg.request);
404
405                        // write AOF record for successful mutations
406                        if let Some(ref mut writer) = aof_writer {
407                            if let Some(record) = to_aof_record(&msg.request, &response) {
408                                if let Err(e) = writer.write_record(&record) {
409                                    warn!(shard_id, "aof write failed: {e}");
410                                }
411                                if fsync_policy == FsyncPolicy::Always {
412                                    if let Err(e) = writer.sync() {
413                                        warn!(shard_id, "aof sync failed: {e}");
414                                    }
415                                }
416                            }
417                        }
418
419                        // handle snapshot/rewrite (these need mutable access
420                        // to both keyspace and aof_writer)
421                        match request_kind {
422                            RequestKind::Snapshot => {
423                                let resp = handle_snapshot(
424                                    &keyspace, &persistence, shard_id,
425                                );
426                                let _ = msg.reply.send(resp);
427                                continue;
428                            }
429                            RequestKind::RewriteAof => {
430                                let resp = handle_rewrite(
431                                    &keyspace,
432                                    &persistence,
433                                    &mut aof_writer,
434                                    shard_id,
435                                );
436                                let _ = msg.reply.send(resp);
437                                continue;
438                            }
439                            RequestKind::Other => {}
440                        }
441
442                        let _ = msg.reply.send(response);
443                    }
444                    None => break, // channel closed, shard shutting down
445                }
446            }
447            _ = expiry_tick.tick() => {
448                expiry::run_expiration_cycle(&mut keyspace);
449            }
450            _ = fsync_tick.tick(), if fsync_policy == FsyncPolicy::EverySec => {
451                if let Some(ref mut writer) = aof_writer {
452                    if let Err(e) = writer.sync() {
453                        warn!(shard_id, "periodic aof sync failed: {e}");
454                    }
455                }
456            }
457        }
458    }
459
460    // flush AOF on clean shutdown
461    if let Some(ref mut writer) = aof_writer {
462        let _ = writer.sync();
463    }
464}
465
/// Coarse classification of a request.
///
/// Lets the main loop tell snapshot and AOF-rewrite requests apart from
/// everything else without holding on to a borrow of the request.
enum RequestKind {
    /// The request is `ShardRequest::Snapshot`.
    Snapshot,
    /// The request is `ShardRequest::RewriteAof`.
    RewriteAof,
    /// Any other request.
    Other,
}
473
474fn describe_request(req: &ShardRequest) -> RequestKind {
475    match req {
476        ShardRequest::Snapshot => RequestKind::Snapshot,
477        ShardRequest::RewriteAof => RequestKind::RewriteAof,
478        _ => RequestKind::Other,
479    }
480}
481
482/// Executes a single request against the keyspace.
483fn dispatch(ks: &mut Keyspace, req: &ShardRequest) -> ShardResponse {
484    match req {
485        ShardRequest::Get { key } => match ks.get(key) {
486            Ok(val) => ShardResponse::Value(val),
487            Err(_) => ShardResponse::WrongType,
488        },
489        ShardRequest::Set {
490            key,
491            value,
492            expire,
493            nx,
494            xx,
495        } => {
496            // NX: only set if key does NOT already exist
497            if *nx && ks.exists(key) {
498                return ShardResponse::Value(None);
499            }
500            // XX: only set if key DOES already exist
501            if *xx && !ks.exists(key) {
502                return ShardResponse::Value(None);
503            }
504            match ks.set(key.clone(), value.clone(), *expire) {
505                SetResult::Ok => ShardResponse::Ok,
506                SetResult::OutOfMemory => ShardResponse::OutOfMemory,
507            }
508        }
509        ShardRequest::Incr { key } => match ks.incr(key) {
510            Ok(val) => ShardResponse::Integer(val),
511            Err(IncrError::WrongType) => ShardResponse::WrongType,
512            Err(IncrError::OutOfMemory) => ShardResponse::OutOfMemory,
513            Err(e) => ShardResponse::Err(e.to_string()),
514        },
515        ShardRequest::Decr { key } => match ks.decr(key) {
516            Ok(val) => ShardResponse::Integer(val),
517            Err(IncrError::WrongType) => ShardResponse::WrongType,
518            Err(IncrError::OutOfMemory) => ShardResponse::OutOfMemory,
519            Err(e) => ShardResponse::Err(e.to_string()),
520        },
521        ShardRequest::Del { key } => ShardResponse::Bool(ks.del(key)),
522        ShardRequest::Exists { key } => ShardResponse::Bool(ks.exists(key)),
523        ShardRequest::Expire { key, seconds } => ShardResponse::Bool(ks.expire(key, *seconds)),
524        ShardRequest::Ttl { key } => ShardResponse::Ttl(ks.ttl(key)),
525        ShardRequest::Persist { key } => ShardResponse::Bool(ks.persist(key)),
526        ShardRequest::Pttl { key } => ShardResponse::Ttl(ks.pttl(key)),
527        ShardRequest::Pexpire { key, milliseconds } => {
528            ShardResponse::Bool(ks.pexpire(key, *milliseconds))
529        }
530        ShardRequest::LPush { key, values } => match ks.lpush(key, values) {
531            Ok(len) => ShardResponse::Len(len),
532            Err(WriteError::WrongType) => ShardResponse::WrongType,
533            Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
534        },
535        ShardRequest::RPush { key, values } => match ks.rpush(key, values) {
536            Ok(len) => ShardResponse::Len(len),
537            Err(WriteError::WrongType) => ShardResponse::WrongType,
538            Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
539        },
540        ShardRequest::LPop { key } => match ks.lpop(key) {
541            Ok(val) => ShardResponse::Value(val.map(Value::String)),
542            Err(_) => ShardResponse::WrongType,
543        },
544        ShardRequest::RPop { key } => match ks.rpop(key) {
545            Ok(val) => ShardResponse::Value(val.map(Value::String)),
546            Err(_) => ShardResponse::WrongType,
547        },
548        ShardRequest::LRange { key, start, stop } => match ks.lrange(key, *start, *stop) {
549            Ok(items) => ShardResponse::Array(items),
550            Err(_) => ShardResponse::WrongType,
551        },
552        ShardRequest::LLen { key } => match ks.llen(key) {
553            Ok(len) => ShardResponse::Len(len),
554            Err(_) => ShardResponse::WrongType,
555        },
556        ShardRequest::Type { key } => ShardResponse::TypeName(ks.value_type(key)),
557        ShardRequest::ZAdd {
558            key,
559            members,
560            nx,
561            xx,
562            gt,
563            lt,
564            ch,
565        } => {
566            let flags = ZAddFlags {
567                nx: *nx,
568                xx: *xx,
569                gt: *gt,
570                lt: *lt,
571                ch: *ch,
572            };
573            match ks.zadd(key, members, &flags) {
574                Ok(result) => ShardResponse::ZAddLen {
575                    count: result.count,
576                    applied: result.applied,
577                },
578                Err(WriteError::WrongType) => ShardResponse::WrongType,
579                Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
580            }
581        }
582        ShardRequest::ZRem { key, members } => match ks.zrem(key, members) {
583            Ok(removed) => ShardResponse::ZRemLen {
584                count: removed.len(),
585                removed,
586            },
587            Err(_) => ShardResponse::WrongType,
588        },
589        ShardRequest::ZScore { key, member } => match ks.zscore(key, member) {
590            Ok(score) => ShardResponse::Score(score),
591            Err(_) => ShardResponse::WrongType,
592        },
593        ShardRequest::ZRank { key, member } => match ks.zrank(key, member) {
594            Ok(rank) => ShardResponse::Rank(rank),
595            Err(_) => ShardResponse::WrongType,
596        },
597        ShardRequest::ZCard { key } => match ks.zcard(key) {
598            Ok(len) => ShardResponse::Len(len),
599            Err(_) => ShardResponse::WrongType,
600        },
601        ShardRequest::ZRange {
602            key, start, stop, ..
603        } => match ks.zrange(key, *start, *stop) {
604            Ok(items) => ShardResponse::ScoredArray(items),
605            Err(_) => ShardResponse::WrongType,
606        },
607        ShardRequest::DbSize => ShardResponse::KeyCount(ks.len()),
608        ShardRequest::Stats => ShardResponse::Stats(ks.stats()),
609        ShardRequest::FlushDb => {
610            ks.clear();
611            ShardResponse::Ok
612        }
613        ShardRequest::Scan {
614            cursor,
615            count,
616            pattern,
617        } => {
618            let (next_cursor, keys) = ks.scan_keys(*cursor, *count, pattern.as_deref());
619            ShardResponse::Scan {
620                cursor: next_cursor,
621                keys,
622            }
623        }
624        ShardRequest::HSet { key, fields } => match ks.hset(key, fields) {
625            Ok(count) => ShardResponse::Len(count),
626            Err(WriteError::WrongType) => ShardResponse::WrongType,
627            Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
628        },
629        ShardRequest::HGet { key, field } => match ks.hget(key, field) {
630            Ok(val) => ShardResponse::Value(val.map(Value::String)),
631            Err(_) => ShardResponse::WrongType,
632        },
633        ShardRequest::HGetAll { key } => match ks.hgetall(key) {
634            Ok(fields) => ShardResponse::HashFields(fields),
635            Err(_) => ShardResponse::WrongType,
636        },
637        ShardRequest::HDel { key, fields } => match ks.hdel(key, fields) {
638            Ok(removed) => ShardResponse::HDelLen {
639                count: removed.len(),
640                removed,
641            },
642            Err(_) => ShardResponse::WrongType,
643        },
644        ShardRequest::HExists { key, field } => match ks.hexists(key, field) {
645            Ok(exists) => ShardResponse::Bool(exists),
646            Err(_) => ShardResponse::WrongType,
647        },
648        ShardRequest::HLen { key } => match ks.hlen(key) {
649            Ok(len) => ShardResponse::Len(len),
650            Err(_) => ShardResponse::WrongType,
651        },
652        ShardRequest::HIncrBy { key, field, delta } => match ks.hincrby(key, field, *delta) {
653            Ok(val) => ShardResponse::Integer(val),
654            Err(IncrError::WrongType) => ShardResponse::WrongType,
655            Err(IncrError::OutOfMemory) => ShardResponse::OutOfMemory,
656            Err(e) => ShardResponse::Err(e.to_string()),
657        },
658        ShardRequest::HKeys { key } => match ks.hkeys(key) {
659            Ok(keys) => ShardResponse::StringArray(keys),
660            Err(_) => ShardResponse::WrongType,
661        },
662        ShardRequest::HVals { key } => match ks.hvals(key) {
663            Ok(vals) => ShardResponse::Array(vals),
664            Err(_) => ShardResponse::WrongType,
665        },
666        ShardRequest::HMGet { key, fields } => match ks.hmget(key, fields) {
667            Ok(vals) => ShardResponse::OptionalArray(vals),
668            Err(_) => ShardResponse::WrongType,
669        },
670        ShardRequest::SAdd { key, members } => match ks.sadd(key, members) {
671            Ok(count) => ShardResponse::Len(count),
672            Err(WriteError::WrongType) => ShardResponse::WrongType,
673            Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
674        },
675        ShardRequest::SRem { key, members } => match ks.srem(key, members) {
676            Ok(count) => ShardResponse::Len(count),
677            Err(_) => ShardResponse::WrongType,
678        },
679        ShardRequest::SMembers { key } => match ks.smembers(key) {
680            Ok(members) => ShardResponse::StringArray(members),
681            Err(_) => ShardResponse::WrongType,
682        },
683        ShardRequest::SIsMember { key, member } => match ks.sismember(key, member) {
684            Ok(exists) => ShardResponse::Bool(exists),
685            Err(_) => ShardResponse::WrongType,
686        },
687        ShardRequest::SCard { key } => match ks.scard(key) {
688            Ok(count) => ShardResponse::Len(count),
689            Err(_) => ShardResponse::WrongType,
690        },
691        // snapshot/rewrite are handled in the main loop, not here
692        ShardRequest::Snapshot | ShardRequest::RewriteAof => ShardResponse::Ok,
693    }
694}
695
696/// Converts a successful mutation request+response pair into an AOF record.
697/// Returns None for non-mutation requests or failed mutations.
698fn to_aof_record(req: &ShardRequest, resp: &ShardResponse) -> Option<AofRecord> {
699    match (req, resp) {
700        (
701            ShardRequest::Set {
702                key, value, expire, ..
703            },
704            ShardResponse::Ok,
705        ) => {
706            let expire_ms = expire.map(|d| d.as_millis() as i64).unwrap_or(-1);
707            Some(AofRecord::Set {
708                key: key.clone(),
709                value: value.clone(),
710                expire_ms,
711            })
712        }
713        (ShardRequest::Del { key }, ShardResponse::Bool(true)) => {
714            Some(AofRecord::Del { key: key.clone() })
715        }
716        (ShardRequest::Expire { key, seconds }, ShardResponse::Bool(true)) => {
717            Some(AofRecord::Expire {
718                key: key.clone(),
719                seconds: *seconds,
720            })
721        }
722        (ShardRequest::LPush { key, values }, ShardResponse::Len(_)) => Some(AofRecord::LPush {
723            key: key.clone(),
724            values: values.clone(),
725        }),
726        (ShardRequest::RPush { key, values }, ShardResponse::Len(_)) => Some(AofRecord::RPush {
727            key: key.clone(),
728            values: values.clone(),
729        }),
730        (ShardRequest::LPop { key }, ShardResponse::Value(Some(_))) => {
731            Some(AofRecord::LPop { key: key.clone() })
732        }
733        (ShardRequest::RPop { key }, ShardResponse::Value(Some(_))) => {
734            Some(AofRecord::RPop { key: key.clone() })
735        }
736        (ShardRequest::ZAdd { key, .. }, ShardResponse::ZAddLen { applied, .. })
737            if !applied.is_empty() =>
738        {
739            Some(AofRecord::ZAdd {
740                key: key.clone(),
741                members: applied.clone(),
742            })
743        }
744        (ShardRequest::ZRem { key, .. }, ShardResponse::ZRemLen { removed, .. })
745            if !removed.is_empty() =>
746        {
747            Some(AofRecord::ZRem {
748                key: key.clone(),
749                members: removed.clone(),
750            })
751        }
752        (ShardRequest::Incr { key }, ShardResponse::Integer(_)) => {
753            Some(AofRecord::Incr { key: key.clone() })
754        }
755        (ShardRequest::Decr { key }, ShardResponse::Integer(_)) => {
756            Some(AofRecord::Decr { key: key.clone() })
757        }
758        (ShardRequest::Persist { key }, ShardResponse::Bool(true)) => {
759            Some(AofRecord::Persist { key: key.clone() })
760        }
761        (ShardRequest::Pexpire { key, milliseconds }, ShardResponse::Bool(true)) => {
762            Some(AofRecord::Pexpire {
763                key: key.clone(),
764                milliseconds: *milliseconds,
765            })
766        }
767        // Hash commands
768        (ShardRequest::HSet { key, fields }, ShardResponse::Len(_)) => Some(AofRecord::HSet {
769            key: key.clone(),
770            fields: fields.clone(),
771        }),
772        (ShardRequest::HDel { key, .. }, ShardResponse::HDelLen { removed, .. })
773            if !removed.is_empty() =>
774        {
775            Some(AofRecord::HDel {
776                key: key.clone(),
777                fields: removed.clone(),
778            })
779        }
780        (ShardRequest::HIncrBy { key, field, delta }, ShardResponse::Integer(_)) => {
781            Some(AofRecord::HIncrBy {
782                key: key.clone(),
783                field: field.clone(),
784                delta: *delta,
785            })
786        }
787        // Set commands
788        (ShardRequest::SAdd { key, members }, ShardResponse::Len(count)) if *count > 0 => {
789            Some(AofRecord::SAdd {
790                key: key.clone(),
791                members: members.clone(),
792            })
793        }
794        (ShardRequest::SRem { key, members }, ShardResponse::Len(count)) if *count > 0 => {
795            Some(AofRecord::SRem {
796                key: key.clone(),
797                members: members.clone(),
798            })
799        }
800        _ => None,
801    }
802}
803
804/// Writes a snapshot of the current keyspace.
805fn handle_snapshot(
806    keyspace: &Keyspace,
807    persistence: &Option<ShardPersistenceConfig>,
808    shard_id: u16,
809) -> ShardResponse {
810    let pcfg = match persistence {
811        Some(p) => p,
812        None => return ShardResponse::Err("persistence not configured".into()),
813    };
814
815    let path = snapshot::snapshot_path(&pcfg.data_dir, shard_id);
816    match write_snapshot(keyspace, &path, shard_id) {
817        Ok(count) => {
818            info!(shard_id, entries = count, "snapshot written");
819            ShardResponse::Ok
820        }
821        Err(e) => {
822            warn!(shard_id, "snapshot failed: {e}");
823            ShardResponse::Err(format!("snapshot failed: {e}"))
824        }
825    }
826}
827
828/// Writes a snapshot and then truncates the AOF.
829fn handle_rewrite(
830    keyspace: &Keyspace,
831    persistence: &Option<ShardPersistenceConfig>,
832    aof_writer: &mut Option<AofWriter>,
833    shard_id: u16,
834) -> ShardResponse {
835    let pcfg = match persistence {
836        Some(p) => p,
837        None => return ShardResponse::Err("persistence not configured".into()),
838    };
839
840    let path = snapshot::snapshot_path(&pcfg.data_dir, shard_id);
841    match write_snapshot(keyspace, &path, shard_id) {
842        Ok(count) => {
843            // truncate AOF after successful snapshot
844            if let Some(ref mut writer) = aof_writer {
845                if let Err(e) = writer.truncate() {
846                    warn!(shard_id, "aof truncate after rewrite failed: {e}");
847                }
848            }
849            info!(shard_id, entries = count, "aof rewrite complete");
850            ShardResponse::Ok
851        }
852        Err(e) => {
853            warn!(shard_id, "aof rewrite failed: {e}");
854            ShardResponse::Err(format!("rewrite failed: {e}"))
855        }
856    }
857}
858
859/// Iterates the keyspace and writes all live entries to a snapshot file.
860fn write_snapshot(
861    keyspace: &Keyspace,
862    path: &std::path::Path,
863    shard_id: u16,
864) -> Result<u32, ember_persistence::format::FormatError> {
865    let mut writer = SnapshotWriter::create(path, shard_id)?;
866    let mut count = 0u32;
867
868    for (key, value, ttl_ms) in keyspace.iter_entries() {
869        let snap_value = match value {
870            Value::String(data) => SnapValue::String(data.clone()),
871            Value::List(deque) => SnapValue::List(deque.clone()),
872            Value::SortedSet(ss) => {
873                let members: Vec<(f64, String)> = ss
874                    .iter()
875                    .map(|(member, score)| (score, member.to_owned()))
876                    .collect();
877                SnapValue::SortedSet(members)
878            }
879            Value::Hash(map) => SnapValue::Hash(map.clone()),
880            Value::Set(set) => SnapValue::Set(set.clone()),
881        };
882        writer.write_entry(&SnapEntry {
883            key: key.to_owned(),
884            value: snap_value,
885            expire_ms: ttl_ms,
886        })?;
887        count += 1;
888    }
889
890    writer.finish()?;
891    Ok(count)
892}
893
894#[cfg(test)]
895mod tests {
896    use super::*;
897
898    #[test]
899    fn dispatch_set_and_get() {
900        let mut ks = Keyspace::new();
901
902        let resp = dispatch(
903            &mut ks,
904            &ShardRequest::Set {
905                key: "k".into(),
906                value: Bytes::from("v"),
907                expire: None,
908                nx: false,
909                xx: false,
910            },
911        );
912        assert!(matches!(resp, ShardResponse::Ok));
913
914        let resp = dispatch(&mut ks, &ShardRequest::Get { key: "k".into() });
915        match resp {
916            ShardResponse::Value(Some(Value::String(data))) => {
917                assert_eq!(data, Bytes::from("v"));
918            }
919            other => panic!("expected Value(Some(String)), got {other:?}"),
920        }
921    }
922
923    #[test]
924    fn dispatch_get_missing() {
925        let mut ks = Keyspace::new();
926        let resp = dispatch(&mut ks, &ShardRequest::Get { key: "nope".into() });
927        assert!(matches!(resp, ShardResponse::Value(None)));
928    }
929
930    #[test]
931    fn dispatch_del() {
932        let mut ks = Keyspace::new();
933        ks.set("key".into(), Bytes::from("val"), None);
934
935        let resp = dispatch(&mut ks, &ShardRequest::Del { key: "key".into() });
936        assert!(matches!(resp, ShardResponse::Bool(true)));
937
938        let resp = dispatch(&mut ks, &ShardRequest::Del { key: "key".into() });
939        assert!(matches!(resp, ShardResponse::Bool(false)));
940    }
941
942    #[test]
943    fn dispatch_exists() {
944        let mut ks = Keyspace::new();
945        ks.set("yes".into(), Bytes::from("here"), None);
946
947        let resp = dispatch(&mut ks, &ShardRequest::Exists { key: "yes".into() });
948        assert!(matches!(resp, ShardResponse::Bool(true)));
949
950        let resp = dispatch(&mut ks, &ShardRequest::Exists { key: "no".into() });
951        assert!(matches!(resp, ShardResponse::Bool(false)));
952    }
953
954    #[test]
955    fn dispatch_expire_and_ttl() {
956        let mut ks = Keyspace::new();
957        ks.set("key".into(), Bytes::from("val"), None);
958
959        let resp = dispatch(
960            &mut ks,
961            &ShardRequest::Expire {
962                key: "key".into(),
963                seconds: 60,
964            },
965        );
966        assert!(matches!(resp, ShardResponse::Bool(true)));
967
968        let resp = dispatch(&mut ks, &ShardRequest::Ttl { key: "key".into() });
969        match resp {
970            ShardResponse::Ttl(TtlResult::Seconds(s)) => assert!((58..=60).contains(&s)),
971            other => panic!("expected Ttl(Seconds), got {other:?}"),
972        }
973    }
974
975    #[test]
976    fn dispatch_ttl_missing() {
977        let mut ks = Keyspace::new();
978        let resp = dispatch(&mut ks, &ShardRequest::Ttl { key: "gone".into() });
979        assert!(matches!(resp, ShardResponse::Ttl(TtlResult::NotFound)));
980    }
981
982    #[tokio::test]
983    async fn shard_round_trip() {
984        let handle = spawn_shard(16, ShardConfig::default(), None);
985
986        let resp = handle
987            .send(ShardRequest::Set {
988                key: "hello".into(),
989                value: Bytes::from("world"),
990                expire: None,
991                nx: false,
992                xx: false,
993            })
994            .await
995            .unwrap();
996        assert!(matches!(resp, ShardResponse::Ok));
997
998        let resp = handle
999            .send(ShardRequest::Get {
1000                key: "hello".into(),
1001            })
1002            .await
1003            .unwrap();
1004        match resp {
1005            ShardResponse::Value(Some(Value::String(data))) => {
1006                assert_eq!(data, Bytes::from("world"));
1007            }
1008            other => panic!("expected Value(Some(String)), got {other:?}"),
1009        }
1010    }
1011
1012    #[tokio::test]
1013    async fn expired_key_through_shard() {
1014        let handle = spawn_shard(16, ShardConfig::default(), None);
1015
1016        handle
1017            .send(ShardRequest::Set {
1018                key: "temp".into(),
1019                value: Bytes::from("gone"),
1020                expire: Some(Duration::from_millis(10)),
1021                nx: false,
1022                xx: false,
1023            })
1024            .await
1025            .unwrap();
1026
1027        tokio::time::sleep(Duration::from_millis(30)).await;
1028
1029        let resp = handle
1030            .send(ShardRequest::Get { key: "temp".into() })
1031            .await
1032            .unwrap();
1033        assert!(matches!(resp, ShardResponse::Value(None)));
1034    }
1035
1036    #[tokio::test]
1037    async fn active_expiration_cleans_up_without_access() {
1038        let handle = spawn_shard(16, ShardConfig::default(), None);
1039
1040        // set a key with a short TTL
1041        handle
1042            .send(ShardRequest::Set {
1043                key: "ephemeral".into(),
1044                value: Bytes::from("temp"),
1045                expire: Some(Duration::from_millis(10)),
1046                nx: false,
1047                xx: false,
1048            })
1049            .await
1050            .unwrap();
1051
1052        // also set a persistent key
1053        handle
1054            .send(ShardRequest::Set {
1055                key: "persistent".into(),
1056                value: Bytes::from("stays"),
1057                expire: None,
1058                nx: false,
1059                xx: false,
1060            })
1061            .await
1062            .unwrap();
1063
1064        // wait long enough for the TTL to expire AND for the background
1065        // tick to fire (100ms interval + some slack)
1066        tokio::time::sleep(Duration::from_millis(250)).await;
1067
1068        // the ephemeral key should be gone even though we never accessed it
1069        let resp = handle
1070            .send(ShardRequest::Exists {
1071                key: "ephemeral".into(),
1072            })
1073            .await
1074            .unwrap();
1075        assert!(matches!(resp, ShardResponse::Bool(false)));
1076
1077        // the persistent key should still be there
1078        let resp = handle
1079            .send(ShardRequest::Exists {
1080                key: "persistent".into(),
1081            })
1082            .await
1083            .unwrap();
1084        assert!(matches!(resp, ShardResponse::Bool(true)));
1085    }
1086
1087    #[tokio::test]
1088    async fn shard_with_persistence_snapshot_and_recovery() {
1089        let dir = tempfile::tempdir().unwrap();
1090        let pcfg = ShardPersistenceConfig {
1091            data_dir: dir.path().to_owned(),
1092            append_only: true,
1093            fsync_policy: FsyncPolicy::Always,
1094        };
1095        let config = ShardConfig {
1096            shard_id: 0,
1097            ..ShardConfig::default()
1098        };
1099
1100        // write some keys then trigger a snapshot
1101        {
1102            let handle = spawn_shard(16, config.clone(), Some(pcfg.clone()));
1103            handle
1104                .send(ShardRequest::Set {
1105                    key: "a".into(),
1106                    value: Bytes::from("1"),
1107                    expire: None,
1108                    nx: false,
1109                    xx: false,
1110                })
1111                .await
1112                .unwrap();
1113            handle
1114                .send(ShardRequest::Set {
1115                    key: "b".into(),
1116                    value: Bytes::from("2"),
1117                    expire: Some(Duration::from_secs(300)),
1118                    nx: false,
1119                    xx: false,
1120                })
1121                .await
1122                .unwrap();
1123            handle.send(ShardRequest::Snapshot).await.unwrap();
1124            // write one more key that goes only to AOF
1125            handle
1126                .send(ShardRequest::Set {
1127                    key: "c".into(),
1128                    value: Bytes::from("3"),
1129                    expire: None,
1130                    nx: false,
1131                    xx: false,
1132                })
1133                .await
1134                .unwrap();
1135            // drop handle to shut down shard
1136        }
1137
1138        // give it a moment to flush
1139        tokio::time::sleep(Duration::from_millis(50)).await;
1140
1141        // start a new shard with the same config — should recover
1142        {
1143            let handle = spawn_shard(16, config, Some(pcfg));
1144            // give it a moment to recover
1145            tokio::time::sleep(Duration::from_millis(50)).await;
1146
1147            let resp = handle
1148                .send(ShardRequest::Get { key: "a".into() })
1149                .await
1150                .unwrap();
1151            match resp {
1152                ShardResponse::Value(Some(Value::String(data))) => {
1153                    assert_eq!(data, Bytes::from("1"));
1154                }
1155                other => panic!("expected a=1, got {other:?}"),
1156            }
1157
1158            let resp = handle
1159                .send(ShardRequest::Get { key: "b".into() })
1160                .await
1161                .unwrap();
1162            assert!(matches!(resp, ShardResponse::Value(Some(_))));
1163
1164            let resp = handle
1165                .send(ShardRequest::Get { key: "c".into() })
1166                .await
1167                .unwrap();
1168            match resp {
1169                ShardResponse::Value(Some(Value::String(data))) => {
1170                    assert_eq!(data, Bytes::from("3"));
1171                }
1172                other => panic!("expected c=3, got {other:?}"),
1173            }
1174        }
1175    }
1176
1177    #[test]
1178    fn to_aof_record_for_set() {
1179        let req = ShardRequest::Set {
1180            key: "k".into(),
1181            value: Bytes::from("v"),
1182            expire: Some(Duration::from_secs(60)),
1183            nx: false,
1184            xx: false,
1185        };
1186        let resp = ShardResponse::Ok;
1187        let record = to_aof_record(&req, &resp).unwrap();
1188        match record {
1189            AofRecord::Set { key, expire_ms, .. } => {
1190                assert_eq!(key, "k");
1191                assert_eq!(expire_ms, 60_000);
1192            }
1193            other => panic!("expected Set, got {other:?}"),
1194        }
1195    }
1196
1197    #[test]
1198    fn to_aof_record_skips_failed_set() {
1199        let req = ShardRequest::Set {
1200            key: "k".into(),
1201            value: Bytes::from("v"),
1202            expire: None,
1203            nx: false,
1204            xx: false,
1205        };
1206        let resp = ShardResponse::OutOfMemory;
1207        assert!(to_aof_record(&req, &resp).is_none());
1208    }
1209
1210    #[test]
1211    fn to_aof_record_for_del() {
1212        let req = ShardRequest::Del { key: "k".into() };
1213        let resp = ShardResponse::Bool(true);
1214        let record = to_aof_record(&req, &resp).unwrap();
1215        assert!(matches!(record, AofRecord::Del { .. }));
1216    }
1217
1218    #[test]
1219    fn to_aof_record_skips_failed_del() {
1220        let req = ShardRequest::Del { key: "k".into() };
1221        let resp = ShardResponse::Bool(false);
1222        assert!(to_aof_record(&req, &resp).is_none());
1223    }
1224
1225    #[test]
1226    fn dispatch_incr_new_key() {
1227        let mut ks = Keyspace::new();
1228        let resp = dispatch(&mut ks, &ShardRequest::Incr { key: "c".into() });
1229        assert!(matches!(resp, ShardResponse::Integer(1)));
1230    }
1231
1232    #[test]
1233    fn dispatch_decr_existing() {
1234        let mut ks = Keyspace::new();
1235        ks.set("n".into(), Bytes::from("10"), None);
1236        let resp = dispatch(&mut ks, &ShardRequest::Decr { key: "n".into() });
1237        assert!(matches!(resp, ShardResponse::Integer(9)));
1238    }
1239
1240    #[test]
1241    fn dispatch_incr_non_integer() {
1242        let mut ks = Keyspace::new();
1243        ks.set("s".into(), Bytes::from("hello"), None);
1244        let resp = dispatch(&mut ks, &ShardRequest::Incr { key: "s".into() });
1245        assert!(matches!(resp, ShardResponse::Err(_)));
1246    }
1247
1248    #[test]
1249    fn to_aof_record_for_incr() {
1250        let req = ShardRequest::Incr { key: "c".into() };
1251        let resp = ShardResponse::Integer(1);
1252        let record = to_aof_record(&req, &resp).unwrap();
1253        assert!(matches!(record, AofRecord::Incr { .. }));
1254    }
1255
1256    #[test]
1257    fn to_aof_record_for_decr() {
1258        let req = ShardRequest::Decr { key: "c".into() };
1259        let resp = ShardResponse::Integer(-1);
1260        let record = to_aof_record(&req, &resp).unwrap();
1261        assert!(matches!(record, AofRecord::Decr { .. }));
1262    }
1263
1264    #[test]
1265    fn dispatch_persist_removes_ttl() {
1266        let mut ks = Keyspace::new();
1267        ks.set(
1268            "key".into(),
1269            Bytes::from("val"),
1270            Some(Duration::from_secs(60)),
1271        );
1272
1273        let resp = dispatch(&mut ks, &ShardRequest::Persist { key: "key".into() });
1274        assert!(matches!(resp, ShardResponse::Bool(true)));
1275
1276        let resp = dispatch(&mut ks, &ShardRequest::Ttl { key: "key".into() });
1277        assert!(matches!(resp, ShardResponse::Ttl(TtlResult::NoExpiry)));
1278    }
1279
1280    #[test]
1281    fn dispatch_persist_missing_key() {
1282        let mut ks = Keyspace::new();
1283        let resp = dispatch(&mut ks, &ShardRequest::Persist { key: "nope".into() });
1284        assert!(matches!(resp, ShardResponse::Bool(false)));
1285    }
1286
1287    #[test]
1288    fn dispatch_pttl() {
1289        let mut ks = Keyspace::new();
1290        ks.set(
1291            "key".into(),
1292            Bytes::from("val"),
1293            Some(Duration::from_secs(60)),
1294        );
1295
1296        let resp = dispatch(&mut ks, &ShardRequest::Pttl { key: "key".into() });
1297        match resp {
1298            ShardResponse::Ttl(TtlResult::Milliseconds(ms)) => {
1299                assert!(ms > 59_000 && ms <= 60_000);
1300            }
1301            other => panic!("expected Ttl(Milliseconds), got {other:?}"),
1302        }
1303    }
1304
1305    #[test]
1306    fn dispatch_pttl_missing() {
1307        let mut ks = Keyspace::new();
1308        let resp = dispatch(&mut ks, &ShardRequest::Pttl { key: "nope".into() });
1309        assert!(matches!(resp, ShardResponse::Ttl(TtlResult::NotFound)));
1310    }
1311
1312    #[test]
1313    fn dispatch_pexpire() {
1314        let mut ks = Keyspace::new();
1315        ks.set("key".into(), Bytes::from("val"), None);
1316
1317        let resp = dispatch(
1318            &mut ks,
1319            &ShardRequest::Pexpire {
1320                key: "key".into(),
1321                milliseconds: 5000,
1322            },
1323        );
1324        assert!(matches!(resp, ShardResponse::Bool(true)));
1325
1326        let resp = dispatch(&mut ks, &ShardRequest::Pttl { key: "key".into() });
1327        match resp {
1328            ShardResponse::Ttl(TtlResult::Milliseconds(ms)) => {
1329                assert!(ms > 4000 && ms <= 5000);
1330            }
1331            other => panic!("expected Ttl(Milliseconds), got {other:?}"),
1332        }
1333    }
1334
1335    #[test]
1336    fn to_aof_record_for_persist() {
1337        let req = ShardRequest::Persist { key: "k".into() };
1338        let resp = ShardResponse::Bool(true);
1339        let record = to_aof_record(&req, &resp).unwrap();
1340        assert!(matches!(record, AofRecord::Persist { .. }));
1341    }
1342
1343    #[test]
1344    fn to_aof_record_skips_failed_persist() {
1345        let req = ShardRequest::Persist { key: "k".into() };
1346        let resp = ShardResponse::Bool(false);
1347        assert!(to_aof_record(&req, &resp).is_none());
1348    }
1349
1350    #[test]
1351    fn to_aof_record_for_pexpire() {
1352        let req = ShardRequest::Pexpire {
1353            key: "k".into(),
1354            milliseconds: 5000,
1355        };
1356        let resp = ShardResponse::Bool(true);
1357        let record = to_aof_record(&req, &resp).unwrap();
1358        match record {
1359            AofRecord::Pexpire { key, milliseconds } => {
1360                assert_eq!(key, "k");
1361                assert_eq!(milliseconds, 5000);
1362            }
1363            other => panic!("expected Pexpire, got {other:?}"),
1364        }
1365    }
1366
1367    #[test]
1368    fn to_aof_record_skips_failed_pexpire() {
1369        let req = ShardRequest::Pexpire {
1370            key: "k".into(),
1371            milliseconds: 5000,
1372        };
1373        let resp = ShardResponse::Bool(false);
1374        assert!(to_aof_record(&req, &resp).is_none());
1375    }
1376
1377    #[test]
1378    fn dispatch_set_nx_when_key_missing() {
1379        let mut ks = Keyspace::new();
1380        let resp = dispatch(
1381            &mut ks,
1382            &ShardRequest::Set {
1383                key: "k".into(),
1384                value: Bytes::from("v"),
1385                expire: None,
1386                nx: true,
1387                xx: false,
1388            },
1389        );
1390        assert!(matches!(resp, ShardResponse::Ok));
1391        assert!(ks.exists("k"));
1392    }
1393
1394    #[test]
1395    fn dispatch_set_nx_when_key_exists() {
1396        let mut ks = Keyspace::new();
1397        ks.set("k".into(), Bytes::from("old"), None);
1398
1399        let resp = dispatch(
1400            &mut ks,
1401            &ShardRequest::Set {
1402                key: "k".into(),
1403                value: Bytes::from("new"),
1404                expire: None,
1405                nx: true,
1406                xx: false,
1407            },
1408        );
1409        // NX should block — returns nil
1410        assert!(matches!(resp, ShardResponse::Value(None)));
1411        // original value should remain
1412        match ks.get("k").unwrap() {
1413            Some(Value::String(data)) => assert_eq!(data, Bytes::from("old")),
1414            other => panic!("expected old value, got {other:?}"),
1415        }
1416    }
1417
1418    #[test]
1419    fn dispatch_set_xx_when_key_exists() {
1420        let mut ks = Keyspace::new();
1421        ks.set("k".into(), Bytes::from("old"), None);
1422
1423        let resp = dispatch(
1424            &mut ks,
1425            &ShardRequest::Set {
1426                key: "k".into(),
1427                value: Bytes::from("new"),
1428                expire: None,
1429                nx: false,
1430                xx: true,
1431            },
1432        );
1433        assert!(matches!(resp, ShardResponse::Ok));
1434        match ks.get("k").unwrap() {
1435            Some(Value::String(data)) => assert_eq!(data, Bytes::from("new")),
1436            other => panic!("expected new value, got {other:?}"),
1437        }
1438    }
1439
1440    #[test]
1441    fn dispatch_set_xx_when_key_missing() {
1442        let mut ks = Keyspace::new();
1443        let resp = dispatch(
1444            &mut ks,
1445            &ShardRequest::Set {
1446                key: "k".into(),
1447                value: Bytes::from("v"),
1448                expire: None,
1449                nx: false,
1450                xx: true,
1451            },
1452        );
1453        // XX should block — returns nil
1454        assert!(matches!(resp, ShardResponse::Value(None)));
1455        assert!(!ks.exists("k"));
1456    }
1457
1458    #[test]
1459    fn to_aof_record_skips_nx_blocked_set() {
1460        let req = ShardRequest::Set {
1461            key: "k".into(),
1462            value: Bytes::from("v"),
1463            expire: None,
1464            nx: true,
1465            xx: false,
1466        };
1467        // when NX blocks, the shard returns Value(None), not Ok
1468        let resp = ShardResponse::Value(None);
1469        assert!(to_aof_record(&req, &resp).is_none());
1470    }
1471
1472    #[test]
1473    fn dispatch_flushdb_clears_all_keys() {
1474        let mut ks = Keyspace::new();
1475        ks.set("a".into(), Bytes::from("1"), None);
1476        ks.set("b".into(), Bytes::from("2"), None);
1477
1478        assert_eq!(ks.len(), 2);
1479
1480        let resp = dispatch(&mut ks, &ShardRequest::FlushDb);
1481        assert!(matches!(resp, ShardResponse::Ok));
1482        assert_eq!(ks.len(), 0);
1483    }
1484
1485    #[test]
1486    fn dispatch_scan_returns_keys() {
1487        let mut ks = Keyspace::new();
1488        ks.set("user:1".into(), Bytes::from("a"), None);
1489        ks.set("user:2".into(), Bytes::from("b"), None);
1490        ks.set("item:1".into(), Bytes::from("c"), None);
1491
1492        let resp = dispatch(
1493            &mut ks,
1494            &ShardRequest::Scan {
1495                cursor: 0,
1496                count: 10,
1497                pattern: None,
1498            },
1499        );
1500
1501        match resp {
1502            ShardResponse::Scan { cursor, keys } => {
1503                assert_eq!(cursor, 0); // complete in one pass
1504                assert_eq!(keys.len(), 3);
1505            }
1506            _ => panic!("expected Scan response"),
1507        }
1508    }
1509
1510    #[test]
1511    fn dispatch_scan_with_pattern() {
1512        let mut ks = Keyspace::new();
1513        ks.set("user:1".into(), Bytes::from("a"), None);
1514        ks.set("user:2".into(), Bytes::from("b"), None);
1515        ks.set("item:1".into(), Bytes::from("c"), None);
1516
1517        let resp = dispatch(
1518            &mut ks,
1519            &ShardRequest::Scan {
1520                cursor: 0,
1521                count: 10,
1522                pattern: Some("user:*".into()),
1523            },
1524        );
1525
1526        match resp {
1527            ShardResponse::Scan { cursor, keys } => {
1528                assert_eq!(cursor, 0);
1529                assert_eq!(keys.len(), 2);
1530                for k in &keys {
1531                    assert!(k.starts_with("user:"));
1532                }
1533            }
1534            _ => panic!("expected Scan response"),
1535        }
1536    }
1537
1538    #[test]
1539    fn to_aof_record_for_hset() {
1540        let req = ShardRequest::HSet {
1541            key: "h".into(),
1542            fields: vec![("f1".into(), Bytes::from("v1"))],
1543        };
1544        let resp = ShardResponse::Len(1);
1545        let record = to_aof_record(&req, &resp).unwrap();
1546        match record {
1547            AofRecord::HSet { key, fields } => {
1548                assert_eq!(key, "h");
1549                assert_eq!(fields.len(), 1);
1550            }
1551            _ => panic!("expected HSet record"),
1552        }
1553    }
1554
1555    #[test]
1556    fn to_aof_record_for_hdel() {
1557        let req = ShardRequest::HDel {
1558            key: "h".into(),
1559            fields: vec!["f1".into(), "f2".into()],
1560        };
1561        let resp = ShardResponse::HDelLen {
1562            count: 2,
1563            removed: vec!["f1".into(), "f2".into()],
1564        };
1565        let record = to_aof_record(&req, &resp).unwrap();
1566        match record {
1567            AofRecord::HDel { key, fields } => {
1568                assert_eq!(key, "h");
1569                assert_eq!(fields.len(), 2);
1570            }
1571            _ => panic!("expected HDel record"),
1572        }
1573    }
1574
1575    #[test]
1576    fn to_aof_record_skips_hdel_when_none_removed() {
1577        let req = ShardRequest::HDel {
1578            key: "h".into(),
1579            fields: vec!["f1".into()],
1580        };
1581        let resp = ShardResponse::HDelLen {
1582            count: 0,
1583            removed: vec![],
1584        };
1585        assert!(to_aof_record(&req, &resp).is_none());
1586    }
1587
1588    #[test]
1589    fn to_aof_record_for_hincrby() {
1590        let req = ShardRequest::HIncrBy {
1591            key: "h".into(),
1592            field: "counter".into(),
1593            delta: 5,
1594        };
1595        let resp = ShardResponse::Integer(10);
1596        let record = to_aof_record(&req, &resp).unwrap();
1597        match record {
1598            AofRecord::HIncrBy { key, field, delta } => {
1599                assert_eq!(key, "h");
1600                assert_eq!(field, "counter");
1601                assert_eq!(delta, 5);
1602            }
1603            _ => panic!("expected HIncrBy record"),
1604        }
1605    }
1606
1607    #[test]
1608    fn to_aof_record_for_sadd() {
1609        let req = ShardRequest::SAdd {
1610            key: "s".into(),
1611            members: vec!["m1".into(), "m2".into()],
1612        };
1613        let resp = ShardResponse::Len(2);
1614        let record = to_aof_record(&req, &resp).unwrap();
1615        match record {
1616            AofRecord::SAdd { key, members } => {
1617                assert_eq!(key, "s");
1618                assert_eq!(members.len(), 2);
1619            }
1620            _ => panic!("expected SAdd record"),
1621        }
1622    }
1623
1624    #[test]
1625    fn to_aof_record_skips_sadd_when_none_added() {
1626        let req = ShardRequest::SAdd {
1627            key: "s".into(),
1628            members: vec!["m1".into()],
1629        };
1630        let resp = ShardResponse::Len(0);
1631        assert!(to_aof_record(&req, &resp).is_none());
1632    }
1633
1634    #[test]
1635    fn to_aof_record_for_srem() {
1636        let req = ShardRequest::SRem {
1637            key: "s".into(),
1638            members: vec!["m1".into()],
1639        };
1640        let resp = ShardResponse::Len(1);
1641        let record = to_aof_record(&req, &resp).unwrap();
1642        match record {
1643            AofRecord::SRem { key, members } => {
1644                assert_eq!(key, "s");
1645                assert_eq!(members.len(), 1);
1646            }
1647            _ => panic!("expected SRem record"),
1648        }
1649    }
1650
1651    #[test]
1652    fn to_aof_record_skips_srem_when_none_removed() {
1653        let req = ShardRequest::SRem {
1654            key: "s".into(),
1655            members: vec!["m1".into()],
1656        };
1657        let resp = ShardResponse::Len(0);
1658        assert!(to_aof_record(&req, &resp).is_none());
1659    }
1660}