Skip to main content

ember_core/
engine.rs

1//! The engine: coordinator for the sharded keyspace.
2//!
3//! Routes single-key operations to the correct shard based on a hash
4//! of the key. Each shard is an independent tokio task — no locks on
5//! the hot path.
6
7use tokio::sync::broadcast;
8
9use crate::dropper::DropHandle;
10use crate::error::ShardError;
11use crate::keyspace::ShardConfig;
12use crate::shard::{
13    self, PreparedShard, ReplicationEvent, ShardHandle, ShardPersistenceConfig, ShardRequest,
14    ShardResponse,
15};
16
/// Default channel buffer size per shard.
///
/// With batch dispatch, a single pipeline of 16 commands targeting 8 shards
/// consumes ~8 channel slots (one batch per shard) instead of 16. 4096 gives
/// generous headroom even under pathological key distribution, at ~400KB per
/// shard on 64-bit.
///
/// Used whenever `EngineConfig::shard_channel_buffer` is 0.
const DEFAULT_SHARD_BUFFER: usize = 4096;
24
/// Configuration for the engine, passed down to each shard.
#[derive(Debug, Clone, Default)]
pub struct EngineConfig {
    /// Per-shard configuration (memory limits, eviction policy).
    pub shard: ShardConfig,
    /// Optional persistence configuration. When set, each shard gets
    /// its own AOF and snapshot files under this directory.
    pub persistence: Option<ShardPersistenceConfig>,
    /// Optional broadcast sender for replication events.
    ///
    /// When set, every successful mutation is published as a
    /// [`ReplicationEvent`] so replication clients can stream it to
    /// replicas.
    pub replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    /// Optional channel to receive expired key names. Used for keyspace notifications.
    pub expired_tx: Option<broadcast::Sender<String>>,
    /// Optional schema registry for protobuf value validation.
    /// When set, enables PROTO.* commands.
    #[cfg(feature = "protobuf")]
    pub schema_registry: Option<crate::schema::SharedSchemaRegistry>,
    /// Channel buffer size per shard. 0 means use the built-in default
    /// (`DEFAULT_SHARD_BUFFER`, currently 4096).
    pub shard_channel_buffer: usize,
}
48
/// The sharded engine. Owns handles to all shard tasks and routes
/// requests by key hash.
///
/// `Clone` is cheap — it just clones the `Vec<ShardHandle>` (which are
/// mpsc senders under the hood).
#[derive(Debug, Clone)]
pub struct Engine {
    // One handle per shard task; index == shard id.
    shards: Vec<ShardHandle>,
    // Kept so `subscribe_replication` can hand out new receivers.
    replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    // Kept so `subscribe_expired` can hand out new receivers.
    expired_tx: Option<broadcast::Sender<String>>,
    #[cfg(feature = "protobuf")]
    schema_registry: Option<crate::schema::SharedSchemaRegistry>,
}
62
63impl Engine {
64    /// Creates an engine with `shard_count` shards using default config.
65    ///
66    /// Each shard is spawned as a tokio task immediately.
67    /// Panics if `shard_count` is zero.
68    pub fn new(shard_count: usize) -> Self {
69        Self::with_config(shard_count, EngineConfig::default())
70    }
71
72    /// Creates an engine with `shard_count` shards and the given config.
73    ///
74    /// Spawns a single background drop thread shared by all shards for
75    /// lazy-freeing large values.
76    ///
77    /// Panics if `shard_count` is zero.
78    pub fn with_config(shard_count: usize, config: EngineConfig) -> Self {
79        assert!(shard_count > 0, "shard count must be at least 1");
80        assert!(
81            shard_count <= u16::MAX as usize,
82            "shard count must fit in u16"
83        );
84
85        let drop_handle = DropHandle::spawn();
86        let buffer = if config.shard_channel_buffer == 0 {
87            DEFAULT_SHARD_BUFFER
88        } else {
89            config.shard_channel_buffer
90        };
91
92        let shards = (0..shard_count)
93            .map(|i| {
94                let mut shard_config = config.shard.clone();
95                shard_config.shard_id = i as u16;
96                shard::spawn_shard(
97                    buffer,
98                    shard_config,
99                    config.persistence.clone(),
100                    Some(drop_handle.clone()),
101                    config.replication_tx.clone(),
102                    config.expired_tx.clone(),
103                    #[cfg(feature = "protobuf")]
104                    config.schema_registry.clone(),
105                )
106            })
107            .collect();
108
109        Self {
110            shards,
111            replication_tx: config.replication_tx,
112            expired_tx: config.expired_tx,
113            #[cfg(feature = "protobuf")]
114            schema_registry: config.schema_registry,
115        }
116    }
117
118    /// Creates the engine and prepared shards without spawning any tasks.
119    ///
120    /// The caller is responsible for running each [`PreparedShard`] on the
121    /// desired runtime via [`shard::run_prepared`]. This is the entry
122    /// point for thread-per-core deployment where each OS thread runs its
123    /// own single-threaded tokio runtime and one shard.
124    ///
125    /// Panics if `shard_count` is zero.
126    pub fn prepare(shard_count: usize, config: EngineConfig) -> (Self, Vec<PreparedShard>) {
127        assert!(shard_count > 0, "shard count must be at least 1");
128        assert!(
129            shard_count <= u16::MAX as usize,
130            "shard count must fit in u16"
131        );
132
133        let drop_handle = DropHandle::spawn();
134        let buffer = if config.shard_channel_buffer == 0 {
135            DEFAULT_SHARD_BUFFER
136        } else {
137            config.shard_channel_buffer
138        };
139
140        let mut handles = Vec::with_capacity(shard_count);
141        let mut prepared = Vec::with_capacity(shard_count);
142
143        for i in 0..shard_count {
144            let mut shard_config = config.shard.clone();
145            shard_config.shard_id = i as u16;
146            let (handle, shard) = shard::prepare_shard(
147                buffer,
148                shard_config,
149                config.persistence.clone(),
150                Some(drop_handle.clone()),
151                config.replication_tx.clone(),
152                config.expired_tx.clone(),
153                #[cfg(feature = "protobuf")]
154                config.schema_registry.clone(),
155            );
156            handles.push(handle);
157            prepared.push(shard);
158        }
159
160        let engine = Self {
161            shards: handles,
162            replication_tx: config.replication_tx,
163            expired_tx: config.expired_tx,
164            #[cfg(feature = "protobuf")]
165            schema_registry: config.schema_registry,
166        };
167
168        (engine, prepared)
169    }
170
171    /// Creates an engine with one shard per available CPU core.
172    ///
173    /// Falls back to a single shard if the core count can't be determined.
174    pub fn with_available_cores() -> Self {
175        Self::with_available_cores_config(EngineConfig::default())
176    }
177
178    /// Creates an engine with one shard per available CPU core and the
179    /// given config.
180    pub fn with_available_cores_config(config: EngineConfig) -> Self {
181        let cores = std::thread::available_parallelism()
182            .map(|n| n.get())
183            .unwrap_or(1);
184        Self::with_config(cores, config)
185    }
186
    /// Returns a reference to the schema registry, if protobuf is enabled.
    ///
    /// `None` when no registry was supplied in [`EngineConfig`].
    #[cfg(feature = "protobuf")]
    pub fn schema_registry(&self) -> Option<&crate::schema::SharedSchemaRegistry> {
        self.schema_registry.as_ref()
    }
192
    /// Returns the number of shards.
    ///
    /// Fixed at construction time; shard indices range over `0..shard_count()`.
    pub fn shard_count(&self) -> usize {
        self.shards.len()
    }
197
198    /// Creates a new broadcast receiver for replication events.
199    ///
200    /// Returns `None` if no replication channel was configured. Each
201    /// caller gets an independent receiver starting from the current
202    /// broadcast position — not from the beginning of the stream.
203    pub fn subscribe_replication(&self) -> Option<broadcast::Receiver<ReplicationEvent>> {
204        self.replication_tx.as_ref().map(|tx| tx.subscribe())
205    }
206
207    /// Creates a new broadcast receiver for expired key names.
208    ///
209    /// Returns `None` if no expired-key channel was configured. Used by the
210    /// server to subscribe a background task that fires keyspace notifications.
211    pub fn subscribe_expired(&self) -> Option<broadcast::Receiver<String>> {
212        self.expired_tx.as_ref().map(|tx| tx.subscribe())
213    }
214
215    /// Sends a request to a specific shard by index.
216    ///
217    /// Used by SCAN to iterate through shards sequentially.
218    pub async fn send_to_shard(
219        &self,
220        shard_idx: usize,
221        request: ShardRequest,
222    ) -> Result<ShardResponse, ShardError> {
223        if shard_idx >= self.shards.len() {
224            return Err(ShardError::Unavailable);
225        }
226        self.shards[shard_idx].send(request).await
227    }
228
229    /// Routes a request to the shard that owns `key`.
230    pub async fn route(
231        &self,
232        key: &str,
233        request: ShardRequest,
234    ) -> Result<ShardResponse, ShardError> {
235        let idx = self.shard_for_key(key);
236        self.shards[idx].send(request).await
237    }
238
239    /// Sends a request to every shard and collects all responses.
240    ///
241    /// Dispatches to all shards first (so they start processing in
242    /// parallel), then collects the replies. Used for commands like
243    /// DBSIZE and INFO that need data from all shards.
244    pub async fn broadcast<F>(&self, make_req: F) -> Result<Vec<ShardResponse>, ShardError>
245    where
246        F: Fn() -> ShardRequest,
247    {
248        // dispatch to all shards without waiting for responses
249        let mut receivers = Vec::with_capacity(self.shards.len());
250        for shard in &self.shards {
251            receivers.push(shard.dispatch(make_req()).await?);
252        }
253
254        // now collect all responses
255        let mut results = Vec::with_capacity(receivers.len());
256        for rx in receivers {
257            results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
258        }
259        Ok(results)
260    }
261
262    /// Routes requests for multiple keys concurrently.
263    ///
264    /// Dispatches all requests without waiting, then collects responses.
265    /// The response order matches the key order. Used for multi-key
266    /// commands like DEL and EXISTS.
267    pub async fn route_multi<F>(
268        &self,
269        keys: &[String],
270        make_req: F,
271    ) -> Result<Vec<ShardResponse>, ShardError>
272    where
273        F: Fn(String) -> ShardRequest,
274    {
275        let mut receivers = Vec::with_capacity(keys.len());
276        for key in keys {
277            let idx = self.shard_for_key(key);
278            let rx = self.shards[idx].dispatch(make_req(key.clone())).await?;
279            receivers.push(rx);
280        }
281
282        let mut results = Vec::with_capacity(receivers.len());
283        for rx in receivers {
284            results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
285        }
286        Ok(results)
287    }
288
289    /// Returns true if both keys are owned by the same shard.
290    pub fn same_shard(&self, key1: &str, key2: &str) -> bool {
291        self.shard_for_key(key1) == self.shard_for_key(key2)
292    }
293
294    /// Determines which shard owns a given key.
295    pub fn shard_for_key(&self, key: &str) -> usize {
296        shard_index(key, self.shards.len())
297    }
298
299    /// Sends a request to a shard and returns the reply channel without
300    /// waiting for the response. Used by the connection handler to
301    /// dispatch commands and collect responses separately.
302    pub async fn dispatch_to_shard(
303        &self,
304        shard_idx: usize,
305        request: ShardRequest,
306    ) -> Result<tokio::sync::oneshot::Receiver<ShardResponse>, ShardError> {
307        if shard_idx >= self.shards.len() {
308            return Err(ShardError::Unavailable);
309        }
310        self.shards[shard_idx].dispatch(request).await
311    }
312
313    /// Sends a request to a shard using a caller-owned mpsc reply channel.
314    ///
315    /// Avoids the per-command oneshot allocation on the P=1 path.
316    pub async fn dispatch_reusable_to_shard(
317        &self,
318        shard_idx: usize,
319        request: ShardRequest,
320        reply: tokio::sync::mpsc::Sender<ShardResponse>,
321    ) -> Result<(), ShardError> {
322        if shard_idx >= self.shards.len() {
323            return Err(ShardError::Unavailable);
324        }
325        self.shards[shard_idx]
326            .dispatch_reusable(request, reply)
327            .await
328    }
329
330    /// Sends a batch of requests to a single shard as one channel message.
331    ///
332    /// Returns one receiver per request, preserving order. This is the
333    /// pipeline batching optimization: N commands targeting the same shard
334    /// consume 1 channel slot instead of N, eliminating head-of-line
335    /// blocking under high pipeline depths.
336    pub async fn dispatch_batch_to_shard(
337        &self,
338        shard_idx: usize,
339        requests: Vec<ShardRequest>,
340    ) -> Result<Vec<tokio::sync::oneshot::Receiver<ShardResponse>>, ShardError> {
341        if shard_idx >= self.shards.len() {
342            return Err(ShardError::Unavailable);
343        }
344        self.shards[shard_idx].dispatch_batch(requests).await
345    }
346}
347
/// Pure function: maps a key to a shard index.
///
/// Uses FNV-1a hashing for deterministic shard routing across restarts.
/// This is critical for AOF/snapshot recovery — keys must hash to the
/// same shard on every startup, otherwise recovered data lands in the
/// wrong shard.
///
/// FNV-1a is simple, fast for short keys, and completely deterministic
/// (no per-process randomization). Shard routing is trusted internal
/// logic so DoS-resistant hashing is unnecessary here.
fn shard_index(key: &str, shard_count: usize) -> usize {
    // FNV-1a 64-bit parameters.
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    // FNV-1a: xor the byte in first, then multiply by the prime.
    let hash = key
        .as_bytes()
        .iter()
        .fold(FNV_OFFSET, |acc, &byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        });
    (hash as usize) % shard_count
}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Value;
    use bytes::Bytes;

    // Routing must be deterministic: the same key always lands on the
    // same shard (critical for AOF/snapshot recovery).
    #[test]
    fn same_key_same_shard() {
        let idx1 = shard_index("foo", 8);
        let idx2 = shard_index("foo", 8);
        assert_eq!(idx1, idx2);
    }

    #[test]
    fn keys_spread_across_shards() {
        let mut seen = std::collections::HashSet::new();
        // with enough keys, we should hit more than one shard
        for i in 0..100 {
            let key = format!("key:{i}");
            seen.insert(shard_index(&key, 4));
        }
        assert!(seen.len() > 1, "expected keys to spread across shards");
    }

    // Degenerate case: with one shard, every key must map to index 0.
    #[test]
    fn single_shard_always_zero() {
        assert_eq!(shard_index("anything", 1), 0);
        assert_eq!(shard_index("other", 1), 0);
    }

    // End-to-end: SET then GET through the router returns the stored value.
    #[tokio::test]
    async fn engine_round_trip() {
        let engine = Engine::new(4);

        let resp = engine
            .route(
                "greeting",
                ShardRequest::Set {
                    key: "greeting".into(),
                    value: Bytes::from("hello"),
                    expire: None,
                    nx: false,
                    xx: false,
                },
            )
            .await
            .unwrap();
        assert!(matches!(resp, ShardResponse::Ok));

        let resp = engine
            .route(
                "greeting",
                ShardRequest::Get {
                    key: "greeting".into(),
                },
            )
            .await
            .unwrap();
        match resp {
            ShardResponse::Value(Some(Value::String(data))) => {
                assert_eq!(data, Bytes::from("hello"));
            }
            other => panic!("expected Value(Some(String)), got {other:?}"),
        }
    }

    // DEL across keys that hash to different shards: existing keys report
    // Bool(true), a missing key does not inflate the count.
    #[tokio::test]
    async fn multi_shard_del() {
        let engine = Engine::new(4);

        // set several keys (likely landing on different shards)
        for key in &["a", "b", "c", "d"] {
            engine
                .route(
                    key,
                    ShardRequest::Set {
                        key: key.to_string(),
                        value: Bytes::from("v"),
                        expire: None,
                        nx: false,
                        xx: false,
                    },
                )
                .await
                .unwrap();
        }

        // delete them all and count successes
        let mut count = 0i64;
        for key in &["a", "b", "c", "d", "missing"] {
            let resp = engine
                .route(
                    key,
                    ShardRequest::Del {
                        key: key.to_string(),
                    },
                )
                .await
                .unwrap();
            if let ShardResponse::Bool(true) = resp {
                count += 1;
            }
        }
        assert_eq!(count, 4);
    }

    // Constructor contract: zero shards is a programmer error and panics.
    #[test]
    #[should_panic(expected = "shard count must be at least 1")]
    fn zero_shards_panics() {
        Engine::new(0);
    }
}