Skip to main content

ember_core/
shard.rs

1//! Shard: an independent partition of the keyspace.
2//!
3//! Each shard runs as its own tokio task, owning a `Keyspace` with no
4//! internal locking. Commands arrive over an mpsc channel and responses
5//! go back on a per-request oneshot. A background tick drives active
6//! expiration of TTL'd keys.
7
8use std::path::PathBuf;
9use std::time::Duration;
10
11use bytes::Bytes;
12use ember_persistence::aof::{AofRecord, AofWriter, FsyncPolicy};
13use ember_persistence::recovery::{self, RecoveredValue};
14use ember_persistence::snapshot::{self, SnapEntry, SnapValue, SnapshotWriter};
15use tokio::sync::{mpsc, oneshot};
16use tracing::{info, warn};
17
18use crate::error::ShardError;
19use crate::expiry;
20use crate::keyspace::{
21    IncrError, Keyspace, KeyspaceStats, SetResult, ShardConfig, TtlResult, WriteError,
22};
23use crate::types::sorted_set::ZAddFlags;
24use crate::types::Value;
25
/// Interval between active-expiration passes in the shard loop.
/// 100ms matches Redis's hz=10 default and keeps CPU overhead negligible.
const EXPIRY_TICK: Duration = Duration::from_millis(100);

/// Interval between periodic AOF fsyncs when the `EverySec` policy is active.
const FSYNC_INTERVAL: Duration = Duration::from_secs(1);
32
33/// Optional persistence configuration for a shard.
34#[derive(Debug, Clone)]
35pub struct ShardPersistenceConfig {
36    /// Directory where AOF and snapshot files live.
37    pub data_dir: PathBuf,
38    /// Whether to write an AOF log of mutations.
39    pub append_only: bool,
40    /// When to fsync the AOF file.
41    pub fsync_policy: FsyncPolicy,
42}
43
44/// A protocol-agnostic command sent to a shard.
45#[derive(Debug)]
46pub enum ShardRequest {
47    Get {
48        key: String,
49    },
50    Set {
51        key: String,
52        value: Bytes,
53        expire: Option<Duration>,
54        /// Only set the key if it does not already exist.
55        nx: bool,
56        /// Only set the key if it already exists.
57        xx: bool,
58    },
59    Incr {
60        key: String,
61    },
62    Decr {
63        key: String,
64    },
65    Del {
66        key: String,
67    },
68    Exists {
69        key: String,
70    },
71    Expire {
72        key: String,
73        seconds: u64,
74    },
75    Ttl {
76        key: String,
77    },
78    Persist {
79        key: String,
80    },
81    Pttl {
82        key: String,
83    },
84    Pexpire {
85        key: String,
86        milliseconds: u64,
87    },
88    LPush {
89        key: String,
90        values: Vec<Bytes>,
91    },
92    RPush {
93        key: String,
94        values: Vec<Bytes>,
95    },
96    LPop {
97        key: String,
98    },
99    RPop {
100        key: String,
101    },
102    LRange {
103        key: String,
104        start: i64,
105        stop: i64,
106    },
107    LLen {
108        key: String,
109    },
110    Type {
111        key: String,
112    },
113    ZAdd {
114        key: String,
115        members: Vec<(f64, String)>,
116        nx: bool,
117        xx: bool,
118        gt: bool,
119        lt: bool,
120        ch: bool,
121    },
122    ZRem {
123        key: String,
124        members: Vec<String>,
125    },
126    ZScore {
127        key: String,
128        member: String,
129    },
130    ZRank {
131        key: String,
132        member: String,
133    },
134    ZCard {
135        key: String,
136    },
137    ZRange {
138        key: String,
139        start: i64,
140        stop: i64,
141        with_scores: bool,
142    },
143    /// Returns the key count for this shard.
144    DbSize,
145    /// Returns keyspace stats for this shard.
146    Stats,
147    /// Triggers a snapshot write.
148    Snapshot,
149    /// Triggers an AOF rewrite (snapshot + truncate AOF).
150    RewriteAof,
151    /// Clears all keys from the keyspace.
152    FlushDb,
153    /// Scans keys in the keyspace.
154    Scan {
155        cursor: u64,
156        count: usize,
157        pattern: Option<String>,
158    },
159}
160
161/// The shard's response to a request.
162#[derive(Debug)]
163pub enum ShardResponse {
164    /// A value (or None for a cache miss).
165    Value(Option<Value>),
166    /// Simple acknowledgement (e.g. SET).
167    Ok,
168    /// Integer result (e.g. INCR, DECR).
169    Integer(i64),
170    /// Boolean result (e.g. DEL, EXISTS, EXPIRE).
171    Bool(bool),
172    /// TTL query result.
173    Ttl(TtlResult),
174    /// Memory limit reached and eviction policy is NoEviction.
175    OutOfMemory,
176    /// Key count for a shard (DBSIZE).
177    KeyCount(usize),
178    /// Full stats for a shard (INFO).
179    Stats(KeyspaceStats),
180    /// Integer length result (e.g. LPUSH, RPUSH, LLEN).
181    Len(usize),
182    /// Array of bulk values (e.g. LRANGE).
183    Array(Vec<Bytes>),
184    /// The type name of a stored value.
185    TypeName(&'static str),
186    /// ZADD result: count for the client + actually applied members for AOF.
187    ZAddLen {
188        count: usize,
189        applied: Vec<(f64, String)>,
190    },
191    /// ZREM result: count for the client + actually removed members for AOF.
192    ZRemLen { count: usize, removed: Vec<String> },
193    /// Float score result (e.g. ZSCORE).
194    Score(Option<f64>),
195    /// Rank result (e.g. ZRANK).
196    Rank(Option<usize>),
197    /// Scored array of (member, score) pairs (e.g. ZRANGE).
198    ScoredArray(Vec<(String, f64)>),
199    /// Command used against a key holding the wrong kind of value.
200    WrongType,
201    /// An error message.
202    Err(String),
203    /// Scan result: next cursor and list of keys.
204    Scan { cursor: u64, keys: Vec<String> },
205}
206
207/// A request bundled with its reply channel.
208#[derive(Debug)]
209pub struct ShardMessage {
210    pub request: ShardRequest,
211    pub reply: oneshot::Sender<ShardResponse>,
212}
213
214/// A cloneable handle for sending commands to a shard task.
215///
216/// Wraps the mpsc sender so callers don't need to manage oneshot
217/// channels directly.
218#[derive(Debug, Clone)]
219pub struct ShardHandle {
220    tx: mpsc::Sender<ShardMessage>,
221}
222
223impl ShardHandle {
224    /// Sends a request and waits for the response.
225    ///
226    /// Returns `ShardError::Unavailable` if the shard task has stopped.
227    pub async fn send(&self, request: ShardRequest) -> Result<ShardResponse, ShardError> {
228        let rx = self.dispatch(request).await?;
229        rx.await.map_err(|_| ShardError::Unavailable)
230    }
231
232    /// Sends a request and returns the reply channel without waiting
233    /// for the response. Used by `Engine::broadcast` to fan out to
234    /// all shards before collecting results.
235    pub(crate) async fn dispatch(
236        &self,
237        request: ShardRequest,
238    ) -> Result<oneshot::Receiver<ShardResponse>, ShardError> {
239        let (reply_tx, reply_rx) = oneshot::channel();
240        let msg = ShardMessage {
241            request,
242            reply: reply_tx,
243        };
244        self.tx
245            .send(msg)
246            .await
247            .map_err(|_| ShardError::Unavailable)?;
248        Ok(reply_rx)
249    }
250}
251
252/// Spawns a shard task and returns the handle for communicating with it.
253///
254/// `buffer` controls the mpsc channel capacity — higher values absorb
255/// burst traffic at the cost of memory.
256pub fn spawn_shard(
257    buffer: usize,
258    config: ShardConfig,
259    persistence: Option<ShardPersistenceConfig>,
260) -> ShardHandle {
261    let (tx, rx) = mpsc::channel(buffer);
262    tokio::spawn(run_shard(rx, config, persistence));
263    ShardHandle { tx }
264}
265
266/// The shard's main loop. Processes messages and runs periodic
267/// active expiration until the channel closes.
268async fn run_shard(
269    mut rx: mpsc::Receiver<ShardMessage>,
270    config: ShardConfig,
271    persistence: Option<ShardPersistenceConfig>,
272) {
273    let shard_id = config.shard_id;
274    let mut keyspace = Keyspace::with_config(config);
275
276    // -- recovery --
277    if let Some(ref pcfg) = persistence {
278        let result = recovery::recover_shard(&pcfg.data_dir, shard_id);
279        let count = result.entries.len();
280        for entry in result.entries {
281            let value = match entry.value {
282                RecoveredValue::String(data) => Value::String(data),
283                RecoveredValue::List(deque) => Value::List(deque),
284                RecoveredValue::SortedSet(members) => {
285                    let mut ss = crate::types::sorted_set::SortedSet::new();
286                    for (score, member) in members {
287                        ss.add(member, score);
288                    }
289                    Value::SortedSet(ss)
290                }
291            };
292            keyspace.restore(entry.key, value, entry.expires_at);
293        }
294        if count > 0 {
295            info!(
296                shard_id,
297                recovered_keys = count,
298                snapshot = result.loaded_snapshot,
299                aof = result.replayed_aof,
300                "recovered shard state"
301            );
302        }
303    }
304
305    // -- AOF writer --
306    let mut aof_writer: Option<AofWriter> = match &persistence {
307        Some(pcfg) if pcfg.append_only => {
308            let path = ember_persistence::aof::aof_path(&pcfg.data_dir, shard_id);
309            match AofWriter::open(path) {
310                Ok(w) => Some(w),
311                Err(e) => {
312                    warn!(shard_id, "failed to open AOF writer: {e}");
313                    None
314                }
315            }
316        }
317        _ => None,
318    };
319
320    let fsync_policy = persistence
321        .as_ref()
322        .map(|p| p.fsync_policy)
323        .unwrap_or(FsyncPolicy::No);
324
325    // -- tickers --
326    let mut expiry_tick = tokio::time::interval(EXPIRY_TICK);
327    expiry_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
328
329    let mut fsync_tick = tokio::time::interval(FSYNC_INTERVAL);
330    fsync_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
331
332    loop {
333        tokio::select! {
334            msg = rx.recv() => {
335                match msg {
336                    Some(msg) => {
337                        let request_kind = describe_request(&msg.request);
338                        let response = dispatch(&mut keyspace, &msg.request);
339
340                        // write AOF record for successful mutations
341                        if let Some(ref mut writer) = aof_writer {
342                            if let Some(record) = to_aof_record(&msg.request, &response) {
343                                if let Err(e) = writer.write_record(&record) {
344                                    warn!(shard_id, "aof write failed: {e}");
345                                }
346                                if fsync_policy == FsyncPolicy::Always {
347                                    if let Err(e) = writer.sync() {
348                                        warn!(shard_id, "aof sync failed: {e}");
349                                    }
350                                }
351                            }
352                        }
353
354                        // handle snapshot/rewrite (these need mutable access
355                        // to both keyspace and aof_writer)
356                        match request_kind {
357                            RequestKind::Snapshot => {
358                                let resp = handle_snapshot(
359                                    &keyspace, &persistence, shard_id,
360                                );
361                                let _ = msg.reply.send(resp);
362                                continue;
363                            }
364                            RequestKind::RewriteAof => {
365                                let resp = handle_rewrite(
366                                    &keyspace,
367                                    &persistence,
368                                    &mut aof_writer,
369                                    shard_id,
370                                );
371                                let _ = msg.reply.send(resp);
372                                continue;
373                            }
374                            RequestKind::Other => {}
375                        }
376
377                        let _ = msg.reply.send(response);
378                    }
379                    None => break, // channel closed, shard shutting down
380                }
381            }
382            _ = expiry_tick.tick() => {
383                expiry::run_expiration_cycle(&mut keyspace);
384            }
385            _ = fsync_tick.tick(), if fsync_policy == FsyncPolicy::EverySec => {
386                if let Some(ref mut writer) = aof_writer {
387                    if let Err(e) = writer.sync() {
388                        warn!(shard_id, "periodic aof sync failed: {e}");
389                    }
390                }
391            }
392        }
393    }
394
395    // flush AOF on clean shutdown
396    if let Some(ref mut writer) = aof_writer {
397        let _ = writer.sync();
398    }
399}
400
/// Coarse classification of a request, computed before dispatch so the
/// shard loop can tell snapshot/rewrite requests apart afterwards
/// without borrowing the request again.
enum RequestKind {
    /// The request was `ShardRequest::Snapshot`.
    Snapshot,
    /// The request was `ShardRequest::RewriteAof`.
    RewriteAof,
    /// Any other request.
    Other,
}
408
409fn describe_request(req: &ShardRequest) -> RequestKind {
410    match req {
411        ShardRequest::Snapshot => RequestKind::Snapshot,
412        ShardRequest::RewriteAof => RequestKind::RewriteAof,
413        _ => RequestKind::Other,
414    }
415}
416
417/// Executes a single request against the keyspace.
418fn dispatch(ks: &mut Keyspace, req: &ShardRequest) -> ShardResponse {
419    match req {
420        ShardRequest::Get { key } => match ks.get(key) {
421            Ok(val) => ShardResponse::Value(val),
422            Err(_) => ShardResponse::WrongType,
423        },
424        ShardRequest::Set {
425            key,
426            value,
427            expire,
428            nx,
429            xx,
430        } => {
431            // NX: only set if key does NOT already exist
432            if *nx && ks.exists(key) {
433                return ShardResponse::Value(None);
434            }
435            // XX: only set if key DOES already exist
436            if *xx && !ks.exists(key) {
437                return ShardResponse::Value(None);
438            }
439            match ks.set(key.clone(), value.clone(), *expire) {
440                SetResult::Ok => ShardResponse::Ok,
441                SetResult::OutOfMemory => ShardResponse::OutOfMemory,
442            }
443        }
444        ShardRequest::Incr { key } => match ks.incr(key) {
445            Ok(val) => ShardResponse::Integer(val),
446            Err(IncrError::WrongType) => ShardResponse::WrongType,
447            Err(IncrError::OutOfMemory) => ShardResponse::OutOfMemory,
448            Err(e) => ShardResponse::Err(e.to_string()),
449        },
450        ShardRequest::Decr { key } => match ks.decr(key) {
451            Ok(val) => ShardResponse::Integer(val),
452            Err(IncrError::WrongType) => ShardResponse::WrongType,
453            Err(IncrError::OutOfMemory) => ShardResponse::OutOfMemory,
454            Err(e) => ShardResponse::Err(e.to_string()),
455        },
456        ShardRequest::Del { key } => ShardResponse::Bool(ks.del(key)),
457        ShardRequest::Exists { key } => ShardResponse::Bool(ks.exists(key)),
458        ShardRequest::Expire { key, seconds } => ShardResponse::Bool(ks.expire(key, *seconds)),
459        ShardRequest::Ttl { key } => ShardResponse::Ttl(ks.ttl(key)),
460        ShardRequest::Persist { key } => ShardResponse::Bool(ks.persist(key)),
461        ShardRequest::Pttl { key } => ShardResponse::Ttl(ks.pttl(key)),
462        ShardRequest::Pexpire { key, milliseconds } => {
463            ShardResponse::Bool(ks.pexpire(key, *milliseconds))
464        }
465        ShardRequest::LPush { key, values } => match ks.lpush(key, values) {
466            Ok(len) => ShardResponse::Len(len),
467            Err(WriteError::WrongType) => ShardResponse::WrongType,
468            Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
469        },
470        ShardRequest::RPush { key, values } => match ks.rpush(key, values) {
471            Ok(len) => ShardResponse::Len(len),
472            Err(WriteError::WrongType) => ShardResponse::WrongType,
473            Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
474        },
475        ShardRequest::LPop { key } => match ks.lpop(key) {
476            Ok(val) => ShardResponse::Value(val.map(Value::String)),
477            Err(_) => ShardResponse::WrongType,
478        },
479        ShardRequest::RPop { key } => match ks.rpop(key) {
480            Ok(val) => ShardResponse::Value(val.map(Value::String)),
481            Err(_) => ShardResponse::WrongType,
482        },
483        ShardRequest::LRange { key, start, stop } => match ks.lrange(key, *start, *stop) {
484            Ok(items) => ShardResponse::Array(items),
485            Err(_) => ShardResponse::WrongType,
486        },
487        ShardRequest::LLen { key } => match ks.llen(key) {
488            Ok(len) => ShardResponse::Len(len),
489            Err(_) => ShardResponse::WrongType,
490        },
491        ShardRequest::Type { key } => ShardResponse::TypeName(ks.value_type(key)),
492        ShardRequest::ZAdd {
493            key,
494            members,
495            nx,
496            xx,
497            gt,
498            lt,
499            ch,
500        } => {
501            let flags = ZAddFlags {
502                nx: *nx,
503                xx: *xx,
504                gt: *gt,
505                lt: *lt,
506                ch: *ch,
507            };
508            match ks.zadd(key, members, &flags) {
509                Ok(result) => ShardResponse::ZAddLen {
510                    count: result.count,
511                    applied: result.applied,
512                },
513                Err(WriteError::WrongType) => ShardResponse::WrongType,
514                Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
515            }
516        }
517        ShardRequest::ZRem { key, members } => match ks.zrem(key, members) {
518            Ok(removed) => ShardResponse::ZRemLen {
519                count: removed.len(),
520                removed,
521            },
522            Err(_) => ShardResponse::WrongType,
523        },
524        ShardRequest::ZScore { key, member } => match ks.zscore(key, member) {
525            Ok(score) => ShardResponse::Score(score),
526            Err(_) => ShardResponse::WrongType,
527        },
528        ShardRequest::ZRank { key, member } => match ks.zrank(key, member) {
529            Ok(rank) => ShardResponse::Rank(rank),
530            Err(_) => ShardResponse::WrongType,
531        },
532        ShardRequest::ZCard { key } => match ks.zcard(key) {
533            Ok(len) => ShardResponse::Len(len),
534            Err(_) => ShardResponse::WrongType,
535        },
536        ShardRequest::ZRange {
537            key, start, stop, ..
538        } => match ks.zrange(key, *start, *stop) {
539            Ok(items) => ShardResponse::ScoredArray(items),
540            Err(_) => ShardResponse::WrongType,
541        },
542        ShardRequest::DbSize => ShardResponse::KeyCount(ks.len()),
543        ShardRequest::Stats => ShardResponse::Stats(ks.stats()),
544        ShardRequest::FlushDb => {
545            ks.clear();
546            ShardResponse::Ok
547        }
548        ShardRequest::Scan {
549            cursor,
550            count,
551            pattern,
552        } => {
553            let (next_cursor, keys) = ks.scan_keys(*cursor, *count, pattern.as_deref());
554            ShardResponse::Scan {
555                cursor: next_cursor,
556                keys,
557            }
558        }
559        // snapshot/rewrite are handled in the main loop, not here
560        ShardRequest::Snapshot | ShardRequest::RewriteAof => ShardResponse::Ok,
561    }
562}
563
564/// Converts a successful mutation request+response pair into an AOF record.
565/// Returns None for non-mutation requests or failed mutations.
566fn to_aof_record(req: &ShardRequest, resp: &ShardResponse) -> Option<AofRecord> {
567    match (req, resp) {
568        (
569            ShardRequest::Set {
570                key, value, expire, ..
571            },
572            ShardResponse::Ok,
573        ) => {
574            let expire_ms = expire.map(|d| d.as_millis() as i64).unwrap_or(-1);
575            Some(AofRecord::Set {
576                key: key.clone(),
577                value: value.clone(),
578                expire_ms,
579            })
580        }
581        (ShardRequest::Del { key }, ShardResponse::Bool(true)) => {
582            Some(AofRecord::Del { key: key.clone() })
583        }
584        (ShardRequest::Expire { key, seconds }, ShardResponse::Bool(true)) => {
585            Some(AofRecord::Expire {
586                key: key.clone(),
587                seconds: *seconds,
588            })
589        }
590        (ShardRequest::LPush { key, values }, ShardResponse::Len(_)) => Some(AofRecord::LPush {
591            key: key.clone(),
592            values: values.clone(),
593        }),
594        (ShardRequest::RPush { key, values }, ShardResponse::Len(_)) => Some(AofRecord::RPush {
595            key: key.clone(),
596            values: values.clone(),
597        }),
598        (ShardRequest::LPop { key }, ShardResponse::Value(Some(_))) => {
599            Some(AofRecord::LPop { key: key.clone() })
600        }
601        (ShardRequest::RPop { key }, ShardResponse::Value(Some(_))) => {
602            Some(AofRecord::RPop { key: key.clone() })
603        }
604        (ShardRequest::ZAdd { key, .. }, ShardResponse::ZAddLen { applied, .. })
605            if !applied.is_empty() =>
606        {
607            Some(AofRecord::ZAdd {
608                key: key.clone(),
609                members: applied.clone(),
610            })
611        }
612        (ShardRequest::ZRem { key, .. }, ShardResponse::ZRemLen { removed, .. })
613            if !removed.is_empty() =>
614        {
615            Some(AofRecord::ZRem {
616                key: key.clone(),
617                members: removed.clone(),
618            })
619        }
620        (ShardRequest::Incr { key }, ShardResponse::Integer(_)) => {
621            Some(AofRecord::Incr { key: key.clone() })
622        }
623        (ShardRequest::Decr { key }, ShardResponse::Integer(_)) => {
624            Some(AofRecord::Decr { key: key.clone() })
625        }
626        (ShardRequest::Persist { key }, ShardResponse::Bool(true)) => {
627            Some(AofRecord::Persist { key: key.clone() })
628        }
629        (ShardRequest::Pexpire { key, milliseconds }, ShardResponse::Bool(true)) => {
630            Some(AofRecord::Pexpire {
631                key: key.clone(),
632                milliseconds: *milliseconds,
633            })
634        }
635        _ => None,
636    }
637}
638
639/// Writes a snapshot of the current keyspace.
640fn handle_snapshot(
641    keyspace: &Keyspace,
642    persistence: &Option<ShardPersistenceConfig>,
643    shard_id: u16,
644) -> ShardResponse {
645    let pcfg = match persistence {
646        Some(p) => p,
647        None => return ShardResponse::Err("persistence not configured".into()),
648    };
649
650    let path = snapshot::snapshot_path(&pcfg.data_dir, shard_id);
651    match write_snapshot(keyspace, &path, shard_id) {
652        Ok(count) => {
653            info!(shard_id, entries = count, "snapshot written");
654            ShardResponse::Ok
655        }
656        Err(e) => {
657            warn!(shard_id, "snapshot failed: {e}");
658            ShardResponse::Err(format!("snapshot failed: {e}"))
659        }
660    }
661}
662
663/// Writes a snapshot and then truncates the AOF.
664fn handle_rewrite(
665    keyspace: &Keyspace,
666    persistence: &Option<ShardPersistenceConfig>,
667    aof_writer: &mut Option<AofWriter>,
668    shard_id: u16,
669) -> ShardResponse {
670    let pcfg = match persistence {
671        Some(p) => p,
672        None => return ShardResponse::Err("persistence not configured".into()),
673    };
674
675    let path = snapshot::snapshot_path(&pcfg.data_dir, shard_id);
676    match write_snapshot(keyspace, &path, shard_id) {
677        Ok(count) => {
678            // truncate AOF after successful snapshot
679            if let Some(ref mut writer) = aof_writer {
680                if let Err(e) = writer.truncate() {
681                    warn!(shard_id, "aof truncate after rewrite failed: {e}");
682                }
683            }
684            info!(shard_id, entries = count, "aof rewrite complete");
685            ShardResponse::Ok
686        }
687        Err(e) => {
688            warn!(shard_id, "aof rewrite failed: {e}");
689            ShardResponse::Err(format!("rewrite failed: {e}"))
690        }
691    }
692}
693
694/// Iterates the keyspace and writes all live entries to a snapshot file.
695fn write_snapshot(
696    keyspace: &Keyspace,
697    path: &std::path::Path,
698    shard_id: u16,
699) -> Result<u32, ember_persistence::format::FormatError> {
700    let mut writer = SnapshotWriter::create(path, shard_id)?;
701    let mut count = 0u32;
702
703    for (key, value, ttl_ms) in keyspace.iter_entries() {
704        let snap_value = match value {
705            Value::String(data) => SnapValue::String(data.clone()),
706            Value::List(deque) => SnapValue::List(deque.clone()),
707            Value::SortedSet(ss) => {
708                let members: Vec<(f64, String)> = ss
709                    .iter()
710                    .map(|(member, score)| (score, member.to_owned()))
711                    .collect();
712                SnapValue::SortedSet(members)
713            }
714        };
715        writer.write_entry(&SnapEntry {
716            key: key.to_owned(),
717            value: snap_value,
718            expire_ms: ttl_ms,
719        })?;
720        count += 1;
721    }
722
723    writer.finish()?;
724    Ok(count)
725}
726
727#[cfg(test)]
728mod tests {
729    use super::*;
730
731    #[test]
732    fn dispatch_set_and_get() {
733        let mut ks = Keyspace::new();
734
735        let resp = dispatch(
736            &mut ks,
737            &ShardRequest::Set {
738                key: "k".into(),
739                value: Bytes::from("v"),
740                expire: None,
741                nx: false,
742                xx: false,
743            },
744        );
745        assert!(matches!(resp, ShardResponse::Ok));
746
747        let resp = dispatch(&mut ks, &ShardRequest::Get { key: "k".into() });
748        match resp {
749            ShardResponse::Value(Some(Value::String(data))) => {
750                assert_eq!(data, Bytes::from("v"));
751            }
752            other => panic!("expected Value(Some(String)), got {other:?}"),
753        }
754    }
755
756    #[test]
757    fn dispatch_get_missing() {
758        let mut ks = Keyspace::new();
759        let resp = dispatch(&mut ks, &ShardRequest::Get { key: "nope".into() });
760        assert!(matches!(resp, ShardResponse::Value(None)));
761    }
762
763    #[test]
764    fn dispatch_del() {
765        let mut ks = Keyspace::new();
766        ks.set("key".into(), Bytes::from("val"), None);
767
768        let resp = dispatch(&mut ks, &ShardRequest::Del { key: "key".into() });
769        assert!(matches!(resp, ShardResponse::Bool(true)));
770
771        let resp = dispatch(&mut ks, &ShardRequest::Del { key: "key".into() });
772        assert!(matches!(resp, ShardResponse::Bool(false)));
773    }
774
775    #[test]
776    fn dispatch_exists() {
777        let mut ks = Keyspace::new();
778        ks.set("yes".into(), Bytes::from("here"), None);
779
780        let resp = dispatch(&mut ks, &ShardRequest::Exists { key: "yes".into() });
781        assert!(matches!(resp, ShardResponse::Bool(true)));
782
783        let resp = dispatch(&mut ks, &ShardRequest::Exists { key: "no".into() });
784        assert!(matches!(resp, ShardResponse::Bool(false)));
785    }
786
787    #[test]
788    fn dispatch_expire_and_ttl() {
789        let mut ks = Keyspace::new();
790        ks.set("key".into(), Bytes::from("val"), None);
791
792        let resp = dispatch(
793            &mut ks,
794            &ShardRequest::Expire {
795                key: "key".into(),
796                seconds: 60,
797            },
798        );
799        assert!(matches!(resp, ShardResponse::Bool(true)));
800
801        let resp = dispatch(&mut ks, &ShardRequest::Ttl { key: "key".into() });
802        match resp {
803            ShardResponse::Ttl(TtlResult::Seconds(s)) => assert!(s >= 58 && s <= 60),
804            other => panic!("expected Ttl(Seconds), got {other:?}"),
805        }
806    }
807
808    #[test]
809    fn dispatch_ttl_missing() {
810        let mut ks = Keyspace::new();
811        let resp = dispatch(&mut ks, &ShardRequest::Ttl { key: "gone".into() });
812        assert!(matches!(resp, ShardResponse::Ttl(TtlResult::NotFound)));
813    }
814
815    #[tokio::test]
816    async fn shard_round_trip() {
817        let handle = spawn_shard(16, ShardConfig::default(), None);
818
819        let resp = handle
820            .send(ShardRequest::Set {
821                key: "hello".into(),
822                value: Bytes::from("world"),
823                expire: None,
824                nx: false,
825                xx: false,
826            })
827            .await
828            .unwrap();
829        assert!(matches!(resp, ShardResponse::Ok));
830
831        let resp = handle
832            .send(ShardRequest::Get {
833                key: "hello".into(),
834            })
835            .await
836            .unwrap();
837        match resp {
838            ShardResponse::Value(Some(Value::String(data))) => {
839                assert_eq!(data, Bytes::from("world"));
840            }
841            other => panic!("expected Value(Some(String)), got {other:?}"),
842        }
843    }
844
845    #[tokio::test]
846    async fn expired_key_through_shard() {
847        let handle = spawn_shard(16, ShardConfig::default(), None);
848
849        handle
850            .send(ShardRequest::Set {
851                key: "temp".into(),
852                value: Bytes::from("gone"),
853                expire: Some(Duration::from_millis(10)),
854                nx: false,
855                xx: false,
856            })
857            .await
858            .unwrap();
859
860        tokio::time::sleep(Duration::from_millis(30)).await;
861
862        let resp = handle
863            .send(ShardRequest::Get { key: "temp".into() })
864            .await
865            .unwrap();
866        assert!(matches!(resp, ShardResponse::Value(None)));
867    }
868
869    #[tokio::test]
870    async fn active_expiration_cleans_up_without_access() {
871        let handle = spawn_shard(16, ShardConfig::default(), None);
872
873        // set a key with a short TTL
874        handle
875            .send(ShardRequest::Set {
876                key: "ephemeral".into(),
877                value: Bytes::from("temp"),
878                expire: Some(Duration::from_millis(10)),
879                nx: false,
880                xx: false,
881            })
882            .await
883            .unwrap();
884
885        // also set a persistent key
886        handle
887            .send(ShardRequest::Set {
888                key: "persistent".into(),
889                value: Bytes::from("stays"),
890                expire: None,
891                nx: false,
892                xx: false,
893            })
894            .await
895            .unwrap();
896
897        // wait long enough for the TTL to expire AND for the background
898        // tick to fire (100ms interval + some slack)
899        tokio::time::sleep(Duration::from_millis(250)).await;
900
901        // the ephemeral key should be gone even though we never accessed it
902        let resp = handle
903            .send(ShardRequest::Exists {
904                key: "ephemeral".into(),
905            })
906            .await
907            .unwrap();
908        assert!(matches!(resp, ShardResponse::Bool(false)));
909
910        // the persistent key should still be there
911        let resp = handle
912            .send(ShardRequest::Exists {
913                key: "persistent".into(),
914            })
915            .await
916            .unwrap();
917        assert!(matches!(resp, ShardResponse::Bool(true)));
918    }
919
920    #[tokio::test]
921    async fn shard_with_persistence_snapshot_and_recovery() {
922        let dir = tempfile::tempdir().unwrap();
923        let pcfg = ShardPersistenceConfig {
924            data_dir: dir.path().to_owned(),
925            append_only: true,
926            fsync_policy: FsyncPolicy::Always,
927        };
928        let config = ShardConfig {
929            shard_id: 0,
930            ..ShardConfig::default()
931        };
932
933        // write some keys then trigger a snapshot
934        {
935            let handle = spawn_shard(16, config.clone(), Some(pcfg.clone()));
936            handle
937                .send(ShardRequest::Set {
938                    key: "a".into(),
939                    value: Bytes::from("1"),
940                    expire: None,
941                    nx: false,
942                    xx: false,
943                })
944                .await
945                .unwrap();
946            handle
947                .send(ShardRequest::Set {
948                    key: "b".into(),
949                    value: Bytes::from("2"),
950                    expire: Some(Duration::from_secs(300)),
951                    nx: false,
952                    xx: false,
953                })
954                .await
955                .unwrap();
956            handle.send(ShardRequest::Snapshot).await.unwrap();
957            // write one more key that goes only to AOF
958            handle
959                .send(ShardRequest::Set {
960                    key: "c".into(),
961                    value: Bytes::from("3"),
962                    expire: None,
963                    nx: false,
964                    xx: false,
965                })
966                .await
967                .unwrap();
968            // drop handle to shut down shard
969        }
970
971        // give it a moment to flush
972        tokio::time::sleep(Duration::from_millis(50)).await;
973
974        // start a new shard with the same config — should recover
975        {
976            let handle = spawn_shard(16, config, Some(pcfg));
977            // give it a moment to recover
978            tokio::time::sleep(Duration::from_millis(50)).await;
979
980            let resp = handle
981                .send(ShardRequest::Get { key: "a".into() })
982                .await
983                .unwrap();
984            match resp {
985                ShardResponse::Value(Some(Value::String(data))) => {
986                    assert_eq!(data, Bytes::from("1"));
987                }
988                other => panic!("expected a=1, got {other:?}"),
989            }
990
991            let resp = handle
992                .send(ShardRequest::Get { key: "b".into() })
993                .await
994                .unwrap();
995            assert!(matches!(resp, ShardResponse::Value(Some(_))));
996
997            let resp = handle
998                .send(ShardRequest::Get { key: "c".into() })
999                .await
1000                .unwrap();
1001            match resp {
1002                ShardResponse::Value(Some(Value::String(data))) => {
1003                    assert_eq!(data, Bytes::from("3"));
1004                }
1005                other => panic!("expected c=3, got {other:?}"),
1006            }
1007        }
1008    }
1009
1010    #[test]
1011    fn to_aof_record_for_set() {
1012        let req = ShardRequest::Set {
1013            key: "k".into(),
1014            value: Bytes::from("v"),
1015            expire: Some(Duration::from_secs(60)),
1016            nx: false,
1017            xx: false,
1018        };
1019        let resp = ShardResponse::Ok;
1020        let record = to_aof_record(&req, &resp).unwrap();
1021        match record {
1022            AofRecord::Set { key, expire_ms, .. } => {
1023                assert_eq!(key, "k");
1024                assert_eq!(expire_ms, 60_000);
1025            }
1026            other => panic!("expected Set, got {other:?}"),
1027        }
1028    }
1029
1030    #[test]
1031    fn to_aof_record_skips_failed_set() {
1032        let req = ShardRequest::Set {
1033            key: "k".into(),
1034            value: Bytes::from("v"),
1035            expire: None,
1036            nx: false,
1037            xx: false,
1038        };
1039        let resp = ShardResponse::OutOfMemory;
1040        assert!(to_aof_record(&req, &resp).is_none());
1041    }
1042
1043    #[test]
1044    fn to_aof_record_for_del() {
1045        let req = ShardRequest::Del { key: "k".into() };
1046        let resp = ShardResponse::Bool(true);
1047        let record = to_aof_record(&req, &resp).unwrap();
1048        assert!(matches!(record, AofRecord::Del { .. }));
1049    }
1050
1051    #[test]
1052    fn to_aof_record_skips_failed_del() {
1053        let req = ShardRequest::Del { key: "k".into() };
1054        let resp = ShardResponse::Bool(false);
1055        assert!(to_aof_record(&req, &resp).is_none());
1056    }
1057
1058    #[test]
1059    fn dispatch_incr_new_key() {
1060        let mut ks = Keyspace::new();
1061        let resp = dispatch(&mut ks, &ShardRequest::Incr { key: "c".into() });
1062        assert!(matches!(resp, ShardResponse::Integer(1)));
1063    }
1064
1065    #[test]
1066    fn dispatch_decr_existing() {
1067        let mut ks = Keyspace::new();
1068        ks.set("n".into(), Bytes::from("10"), None);
1069        let resp = dispatch(&mut ks, &ShardRequest::Decr { key: "n".into() });
1070        assert!(matches!(resp, ShardResponse::Integer(9)));
1071    }
1072
1073    #[test]
1074    fn dispatch_incr_non_integer() {
1075        let mut ks = Keyspace::new();
1076        ks.set("s".into(), Bytes::from("hello"), None);
1077        let resp = dispatch(&mut ks, &ShardRequest::Incr { key: "s".into() });
1078        assert!(matches!(resp, ShardResponse::Err(_)));
1079    }
1080
1081    #[test]
1082    fn to_aof_record_for_incr() {
1083        let req = ShardRequest::Incr { key: "c".into() };
1084        let resp = ShardResponse::Integer(1);
1085        let record = to_aof_record(&req, &resp).unwrap();
1086        assert!(matches!(record, AofRecord::Incr { .. }));
1087    }
1088
1089    #[test]
1090    fn to_aof_record_for_decr() {
1091        let req = ShardRequest::Decr { key: "c".into() };
1092        let resp = ShardResponse::Integer(-1);
1093        let record = to_aof_record(&req, &resp).unwrap();
1094        assert!(matches!(record, AofRecord::Decr { .. }));
1095    }
1096
1097    #[test]
1098    fn dispatch_persist_removes_ttl() {
1099        let mut ks = Keyspace::new();
1100        ks.set(
1101            "key".into(),
1102            Bytes::from("val"),
1103            Some(Duration::from_secs(60)),
1104        );
1105
1106        let resp = dispatch(&mut ks, &ShardRequest::Persist { key: "key".into() });
1107        assert!(matches!(resp, ShardResponse::Bool(true)));
1108
1109        let resp = dispatch(&mut ks, &ShardRequest::Ttl { key: "key".into() });
1110        assert!(matches!(resp, ShardResponse::Ttl(TtlResult::NoExpiry)));
1111    }
1112
1113    #[test]
1114    fn dispatch_persist_missing_key() {
1115        let mut ks = Keyspace::new();
1116        let resp = dispatch(&mut ks, &ShardRequest::Persist { key: "nope".into() });
1117        assert!(matches!(resp, ShardResponse::Bool(false)));
1118    }
1119
1120    #[test]
1121    fn dispatch_pttl() {
1122        let mut ks = Keyspace::new();
1123        ks.set(
1124            "key".into(),
1125            Bytes::from("val"),
1126            Some(Duration::from_secs(60)),
1127        );
1128
1129        let resp = dispatch(&mut ks, &ShardRequest::Pttl { key: "key".into() });
1130        match resp {
1131            ShardResponse::Ttl(TtlResult::Milliseconds(ms)) => {
1132                assert!(ms > 59_000 && ms <= 60_000);
1133            }
1134            other => panic!("expected Ttl(Milliseconds), got {other:?}"),
1135        }
1136    }
1137
1138    #[test]
1139    fn dispatch_pttl_missing() {
1140        let mut ks = Keyspace::new();
1141        let resp = dispatch(&mut ks, &ShardRequest::Pttl { key: "nope".into() });
1142        assert!(matches!(resp, ShardResponse::Ttl(TtlResult::NotFound)));
1143    }
1144
1145    #[test]
1146    fn dispatch_pexpire() {
1147        let mut ks = Keyspace::new();
1148        ks.set("key".into(), Bytes::from("val"), None);
1149
1150        let resp = dispatch(
1151            &mut ks,
1152            &ShardRequest::Pexpire {
1153                key: "key".into(),
1154                milliseconds: 5000,
1155            },
1156        );
1157        assert!(matches!(resp, ShardResponse::Bool(true)));
1158
1159        let resp = dispatch(&mut ks, &ShardRequest::Pttl { key: "key".into() });
1160        match resp {
1161            ShardResponse::Ttl(TtlResult::Milliseconds(ms)) => {
1162                assert!(ms > 4000 && ms <= 5000);
1163            }
1164            other => panic!("expected Ttl(Milliseconds), got {other:?}"),
1165        }
1166    }
1167
1168    #[test]
1169    fn to_aof_record_for_persist() {
1170        let req = ShardRequest::Persist { key: "k".into() };
1171        let resp = ShardResponse::Bool(true);
1172        let record = to_aof_record(&req, &resp).unwrap();
1173        assert!(matches!(record, AofRecord::Persist { .. }));
1174    }
1175
1176    #[test]
1177    fn to_aof_record_skips_failed_persist() {
1178        let req = ShardRequest::Persist { key: "k".into() };
1179        let resp = ShardResponse::Bool(false);
1180        assert!(to_aof_record(&req, &resp).is_none());
1181    }
1182
1183    #[test]
1184    fn to_aof_record_for_pexpire() {
1185        let req = ShardRequest::Pexpire {
1186            key: "k".into(),
1187            milliseconds: 5000,
1188        };
1189        let resp = ShardResponse::Bool(true);
1190        let record = to_aof_record(&req, &resp).unwrap();
1191        match record {
1192            AofRecord::Pexpire { key, milliseconds } => {
1193                assert_eq!(key, "k");
1194                assert_eq!(milliseconds, 5000);
1195            }
1196            other => panic!("expected Pexpire, got {other:?}"),
1197        }
1198    }
1199
1200    #[test]
1201    fn to_aof_record_skips_failed_pexpire() {
1202        let req = ShardRequest::Pexpire {
1203            key: "k".into(),
1204            milliseconds: 5000,
1205        };
1206        let resp = ShardResponse::Bool(false);
1207        assert!(to_aof_record(&req, &resp).is_none());
1208    }
1209
1210    #[test]
1211    fn dispatch_set_nx_when_key_missing() {
1212        let mut ks = Keyspace::new();
1213        let resp = dispatch(
1214            &mut ks,
1215            &ShardRequest::Set {
1216                key: "k".into(),
1217                value: Bytes::from("v"),
1218                expire: None,
1219                nx: true,
1220                xx: false,
1221            },
1222        );
1223        assert!(matches!(resp, ShardResponse::Ok));
1224        assert!(ks.exists("k"));
1225    }
1226
1227    #[test]
1228    fn dispatch_set_nx_when_key_exists() {
1229        let mut ks = Keyspace::new();
1230        ks.set("k".into(), Bytes::from("old"), None);
1231
1232        let resp = dispatch(
1233            &mut ks,
1234            &ShardRequest::Set {
1235                key: "k".into(),
1236                value: Bytes::from("new"),
1237                expire: None,
1238                nx: true,
1239                xx: false,
1240            },
1241        );
1242        // NX should block — returns nil
1243        assert!(matches!(resp, ShardResponse::Value(None)));
1244        // original value should remain
1245        match ks.get("k").unwrap() {
1246            Some(Value::String(data)) => assert_eq!(data, Bytes::from("old")),
1247            other => panic!("expected old value, got {other:?}"),
1248        }
1249    }
1250
1251    #[test]
1252    fn dispatch_set_xx_when_key_exists() {
1253        let mut ks = Keyspace::new();
1254        ks.set("k".into(), Bytes::from("old"), None);
1255
1256        let resp = dispatch(
1257            &mut ks,
1258            &ShardRequest::Set {
1259                key: "k".into(),
1260                value: Bytes::from("new"),
1261                expire: None,
1262                nx: false,
1263                xx: true,
1264            },
1265        );
1266        assert!(matches!(resp, ShardResponse::Ok));
1267        match ks.get("k").unwrap() {
1268            Some(Value::String(data)) => assert_eq!(data, Bytes::from("new")),
1269            other => panic!("expected new value, got {other:?}"),
1270        }
1271    }
1272
1273    #[test]
1274    fn dispatch_set_xx_when_key_missing() {
1275        let mut ks = Keyspace::new();
1276        let resp = dispatch(
1277            &mut ks,
1278            &ShardRequest::Set {
1279                key: "k".into(),
1280                value: Bytes::from("v"),
1281                expire: None,
1282                nx: false,
1283                xx: true,
1284            },
1285        );
1286        // XX should block — returns nil
1287        assert!(matches!(resp, ShardResponse::Value(None)));
1288        assert!(!ks.exists("k"));
1289    }
1290
1291    #[test]
1292    fn to_aof_record_skips_nx_blocked_set() {
1293        let req = ShardRequest::Set {
1294            key: "k".into(),
1295            value: Bytes::from("v"),
1296            expire: None,
1297            nx: true,
1298            xx: false,
1299        };
1300        // when NX blocks, the shard returns Value(None), not Ok
1301        let resp = ShardResponse::Value(None);
1302        assert!(to_aof_record(&req, &resp).is_none());
1303    }
1304
1305    #[test]
1306    fn dispatch_flushdb_clears_all_keys() {
1307        let mut ks = Keyspace::new();
1308        ks.set("a".into(), Bytes::from("1"), None);
1309        ks.set("b".into(), Bytes::from("2"), None);
1310
1311        assert_eq!(ks.len(), 2);
1312
1313        let resp = dispatch(&mut ks, &ShardRequest::FlushDb);
1314        assert!(matches!(resp, ShardResponse::Ok));
1315        assert_eq!(ks.len(), 0);
1316    }
1317
1318    #[test]
1319    fn dispatch_scan_returns_keys() {
1320        let mut ks = Keyspace::new();
1321        ks.set("user:1".into(), Bytes::from("a"), None);
1322        ks.set("user:2".into(), Bytes::from("b"), None);
1323        ks.set("item:1".into(), Bytes::from("c"), None);
1324
1325        let resp = dispatch(
1326            &mut ks,
1327            &ShardRequest::Scan {
1328                cursor: 0,
1329                count: 10,
1330                pattern: None,
1331            },
1332        );
1333
1334        match resp {
1335            ShardResponse::Scan { cursor, keys } => {
1336                assert_eq!(cursor, 0); // complete in one pass
1337                assert_eq!(keys.len(), 3);
1338            }
1339            _ => panic!("expected Scan response"),
1340        }
1341    }
1342
1343    #[test]
1344    fn dispatch_scan_with_pattern() {
1345        let mut ks = Keyspace::new();
1346        ks.set("user:1".into(), Bytes::from("a"), None);
1347        ks.set("user:2".into(), Bytes::from("b"), None);
1348        ks.set("item:1".into(), Bytes::from("c"), None);
1349
1350        let resp = dispatch(
1351            &mut ks,
1352            &ShardRequest::Scan {
1353                cursor: 0,
1354                count: 10,
1355                pattern: Some("user:*".into()),
1356            },
1357        );
1358
1359        match resp {
1360            ShardResponse::Scan { cursor, keys } => {
1361                assert_eq!(cursor, 0);
1362                assert_eq!(keys.len(), 2);
1363                for k in &keys {
1364                    assert!(k.starts_with("user:"));
1365                }
1366            }
1367            _ => panic!("expected Scan response"),
1368        }
1369    }
1370}