// ember_core/src/engine.rs
1//! The engine: coordinator for the sharded keyspace.
2//!
3//! Routes single-key operations to the correct shard based on a hash
4//! of the key. Each shard is an independent tokio task — no locks on
5//! the hot path.
6
7use tokio::sync::broadcast;
8
9use crate::dropper::DropHandle;
10use crate::error::ShardError;
11use crate::keyspace::ShardConfig;
12use crate::shard::{
13    self, PreparedShard, ReplicationEvent, ShardHandle, ShardPersistenceConfig, ShardRequest,
14    ShardResponse,
15};
16
/// Default channel buffer size per shard.
///
/// Used when `EngineConfig::shard_channel_buffer` is 0 (i.e. "use the
/// default").
///
/// With batch dispatch, a single pipeline of 16 commands targeting 8 shards
/// consumes ~8 channel slots (one batch per shard) instead of 16. 4096 gives
/// generous headroom even under pathological key distribution, at ~400KB per
/// shard on 64-bit.
const DEFAULT_SHARD_BUFFER: usize = 4096;
24
/// Configuration for the engine, passed down to each shard.
#[derive(Debug, Clone, Default)]
pub struct EngineConfig {
    /// Per-shard configuration (memory limits, eviction policy).
    pub shard: ShardConfig,
    /// Optional persistence configuration. When set, each shard gets
    /// its own AOF and snapshot files under this directory.
    pub persistence: Option<ShardPersistenceConfig>,
    /// Optional broadcast sender for replication events.
    ///
    /// When set, every successful mutation is published as a
    /// [`ReplicationEvent`] so replication clients can stream it to
    /// replicas.
    pub replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    /// Optional schema registry for protobuf value validation.
    /// When set, enables PROTO.* commands.
    #[cfg(feature = "protobuf")]
    pub schema_registry: Option<crate::schema::SharedSchemaRegistry>,
    /// Channel buffer size per shard. 0 means use the default
    /// ([`DEFAULT_SHARD_BUFFER`], currently 4096).
    pub shard_channel_buffer: usize,
}
46
/// The sharded engine. Owns handles to all shard tasks and routes
/// requests by key hash.
///
/// `Clone` is cheap — it just clones the `Vec<ShardHandle>` (which are
/// mpsc senders under the hood).
#[derive(Debug, Clone)]
pub struct Engine {
    // One handle per shard; vector index doubles as the shard id
    // assigned during construction.
    shards: Vec<ShardHandle>,
    // Kept so `subscribe_replication` can hand out fresh receivers.
    replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    #[cfg(feature = "protobuf")]
    schema_registry: Option<crate::schema::SharedSchemaRegistry>,
}
59
60impl Engine {
61    /// Creates an engine with `shard_count` shards using default config.
62    ///
63    /// Each shard is spawned as a tokio task immediately.
64    /// Panics if `shard_count` is zero.
65    pub fn new(shard_count: usize) -> Self {
66        Self::with_config(shard_count, EngineConfig::default())
67    }
68
69    /// Creates an engine with `shard_count` shards and the given config.
70    ///
71    /// Spawns a single background drop thread shared by all shards for
72    /// lazy-freeing large values.
73    ///
74    /// Panics if `shard_count` is zero.
75    pub fn with_config(shard_count: usize, config: EngineConfig) -> Self {
76        assert!(shard_count > 0, "shard count must be at least 1");
77        assert!(
78            shard_count <= u16::MAX as usize,
79            "shard count must fit in u16"
80        );
81
82        let drop_handle = DropHandle::spawn();
83        let buffer = if config.shard_channel_buffer == 0 {
84            DEFAULT_SHARD_BUFFER
85        } else {
86            config.shard_channel_buffer
87        };
88
89        let shards = (0..shard_count)
90            .map(|i| {
91                let mut shard_config = config.shard.clone();
92                shard_config.shard_id = i as u16;
93                shard::spawn_shard(
94                    buffer,
95                    shard_config,
96                    config.persistence.clone(),
97                    Some(drop_handle.clone()),
98                    config.replication_tx.clone(),
99                    #[cfg(feature = "protobuf")]
100                    config.schema_registry.clone(),
101                )
102            })
103            .collect();
104
105        Self {
106            shards,
107            replication_tx: config.replication_tx,
108            #[cfg(feature = "protobuf")]
109            schema_registry: config.schema_registry,
110        }
111    }
112
113    /// Creates the engine and prepared shards without spawning any tasks.
114    ///
115    /// The caller is responsible for running each [`PreparedShard`] on the
116    /// desired runtime via [`shard::run_prepared`]. This is the entry
117    /// point for thread-per-core deployment where each OS thread runs its
118    /// own single-threaded tokio runtime and one shard.
119    ///
120    /// Panics if `shard_count` is zero.
121    pub fn prepare(shard_count: usize, config: EngineConfig) -> (Self, Vec<PreparedShard>) {
122        assert!(shard_count > 0, "shard count must be at least 1");
123        assert!(
124            shard_count <= u16::MAX as usize,
125            "shard count must fit in u16"
126        );
127
128        let drop_handle = DropHandle::spawn();
129        let buffer = if config.shard_channel_buffer == 0 {
130            DEFAULT_SHARD_BUFFER
131        } else {
132            config.shard_channel_buffer
133        };
134
135        let mut handles = Vec::with_capacity(shard_count);
136        let mut prepared = Vec::with_capacity(shard_count);
137
138        for i in 0..shard_count {
139            let mut shard_config = config.shard.clone();
140            shard_config.shard_id = i as u16;
141            let (handle, shard) = shard::prepare_shard(
142                buffer,
143                shard_config,
144                config.persistence.clone(),
145                Some(drop_handle.clone()),
146                config.replication_tx.clone(),
147                #[cfg(feature = "protobuf")]
148                config.schema_registry.clone(),
149            );
150            handles.push(handle);
151            prepared.push(shard);
152        }
153
154        let engine = Self {
155            shards: handles,
156            replication_tx: config.replication_tx,
157            #[cfg(feature = "protobuf")]
158            schema_registry: config.schema_registry,
159        };
160
161        (engine, prepared)
162    }
163
164    /// Creates an engine with one shard per available CPU core.
165    ///
166    /// Falls back to a single shard if the core count can't be determined.
167    pub fn with_available_cores() -> Self {
168        Self::with_available_cores_config(EngineConfig::default())
169    }
170
171    /// Creates an engine with one shard per available CPU core and the
172    /// given config.
173    pub fn with_available_cores_config(config: EngineConfig) -> Self {
174        let cores = std::thread::available_parallelism()
175            .map(|n| n.get())
176            .unwrap_or(1);
177        Self::with_config(cores, config)
178    }
179
180    /// Returns a reference to the schema registry, if protobuf is enabled.
181    #[cfg(feature = "protobuf")]
182    pub fn schema_registry(&self) -> Option<&crate::schema::SharedSchemaRegistry> {
183        self.schema_registry.as_ref()
184    }
185
186    /// Returns the number of shards.
187    pub fn shard_count(&self) -> usize {
188        self.shards.len()
189    }
190
191    /// Creates a new broadcast receiver for replication events.
192    ///
193    /// Returns `None` if no replication channel was configured. Each
194    /// caller gets an independent receiver starting from the current
195    /// broadcast position — not from the beginning of the stream.
196    pub fn subscribe_replication(&self) -> Option<broadcast::Receiver<ReplicationEvent>> {
197        self.replication_tx.as_ref().map(|tx| tx.subscribe())
198    }
199
200    /// Sends a request to a specific shard by index.
201    ///
202    /// Used by SCAN to iterate through shards sequentially.
203    pub async fn send_to_shard(
204        &self,
205        shard_idx: usize,
206        request: ShardRequest,
207    ) -> Result<ShardResponse, ShardError> {
208        if shard_idx >= self.shards.len() {
209            return Err(ShardError::Unavailable);
210        }
211        self.shards[shard_idx].send(request).await
212    }
213
214    /// Routes a request to the shard that owns `key`.
215    pub async fn route(
216        &self,
217        key: &str,
218        request: ShardRequest,
219    ) -> Result<ShardResponse, ShardError> {
220        let idx = self.shard_for_key(key);
221        self.shards[idx].send(request).await
222    }
223
224    /// Sends a request to every shard and collects all responses.
225    ///
226    /// Dispatches to all shards first (so they start processing in
227    /// parallel), then collects the replies. Used for commands like
228    /// DBSIZE and INFO that need data from all shards.
229    pub async fn broadcast<F>(&self, make_req: F) -> Result<Vec<ShardResponse>, ShardError>
230    where
231        F: Fn() -> ShardRequest,
232    {
233        // dispatch to all shards without waiting for responses
234        let mut receivers = Vec::with_capacity(self.shards.len());
235        for shard in &self.shards {
236            receivers.push(shard.dispatch(make_req()).await?);
237        }
238
239        // now collect all responses
240        let mut results = Vec::with_capacity(receivers.len());
241        for rx in receivers {
242            results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
243        }
244        Ok(results)
245    }
246
247    /// Routes requests for multiple keys concurrently.
248    ///
249    /// Dispatches all requests without waiting, then collects responses.
250    /// The response order matches the key order. Used for multi-key
251    /// commands like DEL and EXISTS.
252    pub async fn route_multi<F>(
253        &self,
254        keys: &[String],
255        make_req: F,
256    ) -> Result<Vec<ShardResponse>, ShardError>
257    where
258        F: Fn(String) -> ShardRequest,
259    {
260        let mut receivers = Vec::with_capacity(keys.len());
261        for key in keys {
262            let idx = self.shard_for_key(key);
263            let rx = self.shards[idx].dispatch(make_req(key.clone())).await?;
264            receivers.push(rx);
265        }
266
267        let mut results = Vec::with_capacity(receivers.len());
268        for rx in receivers {
269            results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
270        }
271        Ok(results)
272    }
273
274    /// Returns true if both keys are owned by the same shard.
275    pub fn same_shard(&self, key1: &str, key2: &str) -> bool {
276        self.shard_for_key(key1) == self.shard_for_key(key2)
277    }
278
279    /// Determines which shard owns a given key.
280    pub fn shard_for_key(&self, key: &str) -> usize {
281        shard_index(key, self.shards.len())
282    }
283
284    /// Sends a request to a shard and returns the reply channel without
285    /// waiting for the response. Used by the connection handler to
286    /// dispatch commands and collect responses separately.
287    pub async fn dispatch_to_shard(
288        &self,
289        shard_idx: usize,
290        request: ShardRequest,
291    ) -> Result<tokio::sync::oneshot::Receiver<ShardResponse>, ShardError> {
292        if shard_idx >= self.shards.len() {
293            return Err(ShardError::Unavailable);
294        }
295        self.shards[shard_idx].dispatch(request).await
296    }
297
298    /// Sends a request to a shard using a caller-owned mpsc reply channel.
299    ///
300    /// Avoids the per-command oneshot allocation on the P=1 path.
301    pub async fn dispatch_reusable_to_shard(
302        &self,
303        shard_idx: usize,
304        request: ShardRequest,
305        reply: tokio::sync::mpsc::Sender<ShardResponse>,
306    ) -> Result<(), ShardError> {
307        if shard_idx >= self.shards.len() {
308            return Err(ShardError::Unavailable);
309        }
310        self.shards[shard_idx]
311            .dispatch_reusable(request, reply)
312            .await
313    }
314
315    /// Sends a batch of requests to a single shard as one channel message.
316    ///
317    /// Returns one receiver per request, preserving order. This is the
318    /// pipeline batching optimization: N commands targeting the same shard
319    /// consume 1 channel slot instead of N, eliminating head-of-line
320    /// blocking under high pipeline depths.
321    pub async fn dispatch_batch_to_shard(
322        &self,
323        shard_idx: usize,
324        requests: Vec<ShardRequest>,
325    ) -> Result<Vec<tokio::sync::oneshot::Receiver<ShardResponse>>, ShardError> {
326        if shard_idx >= self.shards.len() {
327            return Err(ShardError::Unavailable);
328        }
329        self.shards[shard_idx].dispatch_batch(requests).await
330    }
331}
332
/// Pure function: maps a key to a shard index.
///
/// Uses FNV-1a hashing for deterministic shard routing across restarts.
/// This is critical for AOF/snapshot recovery — keys must hash to the
/// same shard on every startup, otherwise recovered data lands in the
/// wrong shard.
///
/// FNV-1a is simple, fast for short keys, and completely deterministic
/// (no per-process randomization). Shard routing is trusted internal
/// logic so DoS-resistant hashing is unnecessary here.
fn shard_index(key: &str, shard_count: usize) -> usize {
    // FNV-1a 64-bit parameters (offset basis and prime).
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    // xor-then-multiply per byte, folded over the key.
    let digest = key
        .bytes()
        .fold(FNV_OFFSET, |acc, b| (acc ^ u64::from(b)).wrapping_mul(FNV_PRIME));
    digest as usize % shard_count
}
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Value;
    use bytes::Bytes;

    /// Convenience: builds a plain SET request with no expiry or flags.
    fn plain_set(key: &str, payload: &'static str) -> ShardRequest {
        ShardRequest::Set {
            key: key.to_string(),
            value: Bytes::from(payload),
            expire: None,
            nx: false,
            xx: false,
        }
    }

    #[test]
    fn same_key_same_shard() {
        // routing must be a pure function of the key
        assert_eq!(shard_index("foo", 8), shard_index("foo", 8));
    }

    #[test]
    fn keys_spread_across_shards() {
        // with enough keys we should land on more than one shard
        let hit: std::collections::HashSet<_> = (0..100)
            .map(|i| shard_index(&format!("key:{i}"), 4))
            .collect();
        assert!(hit.len() > 1, "expected keys to spread across shards");
    }

    #[test]
    fn single_shard_always_zero() {
        for key in ["anything", "other"] {
            assert_eq!(shard_index(key, 1), 0);
        }
    }

    #[tokio::test]
    async fn engine_round_trip() {
        let engine = Engine::new(4);

        // write a value through the router…
        let reply = engine
            .route("greeting", plain_set("greeting", "hello"))
            .await
            .unwrap();
        assert!(matches!(reply, ShardResponse::Ok));

        // …and read it back via the same key
        let reply = engine
            .route(
                "greeting",
                ShardRequest::Get {
                    key: "greeting".into(),
                },
            )
            .await
            .unwrap();
        match reply {
            ShardResponse::Value(Some(Value::String(data))) => {
                assert_eq!(data, Bytes::from("hello"));
            }
            other => panic!("expected Value(Some(String)), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_shard_del() {
        let engine = Engine::new(4);

        // seed four keys (likely landing on different shards)
        for key in ["a", "b", "c", "d"] {
            engine.route(key, plain_set(key, "v")).await.unwrap();
        }

        // delete the seeded keys plus one that never existed
        let mut deleted = 0i64;
        for key in ["a", "b", "c", "d", "missing"] {
            let reply = engine
                .route(
                    key,
                    ShardRequest::Del {
                        key: key.to_string(),
                    },
                )
                .await
                .unwrap();
            if matches!(reply, ShardResponse::Bool(true)) {
                deleted += 1;
            }
        }
        assert_eq!(deleted, 4);
    }

    #[test]
    #[should_panic(expected = "shard count must be at least 1")]
    fn zero_shards_panics() {
        Engine::new(0);
    }
}