// ember_core/engine.rs
1//! The engine: coordinator for the sharded keyspace.
2//!
3//! Routes single-key operations to the correct shard based on a hash
4//! of the key. Each shard is an independent tokio task — no locks on
5//! the hot path.
6
7use crate::dropper::DropHandle;
8use crate::error::ShardError;
9use crate::keyspace::ShardConfig;
10use crate::shard::{self, ShardHandle, ShardPersistenceConfig, ShardRequest, ShardResponse};
11
/// Channel buffer size per shard. 256 is large enough to absorb
/// bursts without putting meaningful back-pressure on connections.
/// Passed to `shard::spawn_shard` for every shard the engine creates.
const SHARD_BUFFER: usize = 256;
15
16/// Configuration for the engine, passed down to each shard.
17#[derive(Debug, Clone, Default)]
18pub struct EngineConfig {
19    /// Per-shard configuration (memory limits, eviction policy).
20    pub shard: ShardConfig,
21    /// Optional persistence configuration. When set, each shard gets
22    /// its own AOF and snapshot files under this directory.
23    pub persistence: Option<ShardPersistenceConfig>,
24    /// Optional schema registry for protobuf value validation.
25    /// When set, enables PROTO.* commands.
26    #[cfg(feature = "protobuf")]
27    pub schema_registry: Option<crate::schema::SharedSchemaRegistry>,
28}
29
30/// The sharded engine. Owns handles to all shard tasks and routes
31/// requests by key hash.
32///
33/// `Clone` is cheap — it just clones the `Vec<ShardHandle>` (which are
34/// mpsc senders under the hood).
35#[derive(Debug, Clone)]
36pub struct Engine {
37    shards: Vec<ShardHandle>,
38    #[cfg(feature = "protobuf")]
39    schema_registry: Option<crate::schema::SharedSchemaRegistry>,
40}
41
42impl Engine {
43    /// Creates an engine with `shard_count` shards using default config.
44    ///
45    /// Each shard is spawned as a tokio task immediately.
46    /// Panics if `shard_count` is zero.
47    pub fn new(shard_count: usize) -> Self {
48        Self::with_config(shard_count, EngineConfig::default())
49    }
50
51    /// Creates an engine with `shard_count` shards and the given config.
52    ///
53    /// Spawns a single background drop thread shared by all shards for
54    /// lazy-freeing large values.
55    ///
56    /// Panics if `shard_count` is zero.
57    pub fn with_config(shard_count: usize, config: EngineConfig) -> Self {
58        assert!(shard_count > 0, "shard count must be at least 1");
59        assert!(
60            shard_count <= u16::MAX as usize,
61            "shard count must fit in u16"
62        );
63
64        let drop_handle = DropHandle::spawn();
65
66        let shards = (0..shard_count)
67            .map(|i| {
68                let mut shard_config = config.shard.clone();
69                shard_config.shard_id = i as u16;
70                shard::spawn_shard(
71                    SHARD_BUFFER,
72                    shard_config,
73                    config.persistence.clone(),
74                    Some(drop_handle.clone()),
75                    #[cfg(feature = "protobuf")]
76                    config.schema_registry.clone(),
77                )
78            })
79            .collect();
80
81        Self {
82            shards,
83            #[cfg(feature = "protobuf")]
84            schema_registry: config.schema_registry,
85        }
86    }
87
88    /// Creates an engine with one shard per available CPU core.
89    ///
90    /// Falls back to a single shard if the core count can't be determined.
91    pub fn with_available_cores() -> Self {
92        Self::with_available_cores_config(EngineConfig::default())
93    }
94
95    /// Creates an engine with one shard per available CPU core and the
96    /// given config.
97    pub fn with_available_cores_config(config: EngineConfig) -> Self {
98        let cores = std::thread::available_parallelism()
99            .map(|n| n.get())
100            .unwrap_or(1);
101        Self::with_config(cores, config)
102    }
103
104    /// Returns a reference to the schema registry, if protobuf is enabled.
105    #[cfg(feature = "protobuf")]
106    pub fn schema_registry(&self) -> Option<&crate::schema::SharedSchemaRegistry> {
107        self.schema_registry.as_ref()
108    }
109
110    /// Returns the number of shards.
111    pub fn shard_count(&self) -> usize {
112        self.shards.len()
113    }
114
115    /// Sends a request to a specific shard by index.
116    ///
117    /// Used by SCAN to iterate through shards sequentially.
118    pub async fn send_to_shard(
119        &self,
120        shard_idx: usize,
121        request: ShardRequest,
122    ) -> Result<ShardResponse, ShardError> {
123        if shard_idx >= self.shards.len() {
124            return Err(ShardError::Unavailable);
125        }
126        self.shards[shard_idx].send(request).await
127    }
128
129    /// Routes a request to the shard that owns `key`.
130    pub async fn route(
131        &self,
132        key: &str,
133        request: ShardRequest,
134    ) -> Result<ShardResponse, ShardError> {
135        let idx = self.shard_for_key(key);
136        self.shards[idx].send(request).await
137    }
138
139    /// Sends a request to every shard and collects all responses.
140    ///
141    /// Dispatches to all shards first (so they start processing in
142    /// parallel), then collects the replies. Used for commands like
143    /// DBSIZE and INFO that need data from all shards.
144    pub async fn broadcast<F>(&self, make_req: F) -> Result<Vec<ShardResponse>, ShardError>
145    where
146        F: Fn() -> ShardRequest,
147    {
148        // dispatch to all shards without waiting for responses
149        let mut receivers = Vec::with_capacity(self.shards.len());
150        for shard in &self.shards {
151            receivers.push(shard.dispatch(make_req()).await?);
152        }
153
154        // now collect all responses
155        let mut results = Vec::with_capacity(receivers.len());
156        for rx in receivers {
157            results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
158        }
159        Ok(results)
160    }
161
162    /// Routes requests for multiple keys concurrently.
163    ///
164    /// Dispatches all requests without waiting, then collects responses.
165    /// The response order matches the key order. Used for multi-key
166    /// commands like DEL and EXISTS.
167    pub async fn route_multi<F>(
168        &self,
169        keys: &[String],
170        make_req: F,
171    ) -> Result<Vec<ShardResponse>, ShardError>
172    where
173        F: Fn(String) -> ShardRequest,
174    {
175        let mut receivers = Vec::with_capacity(keys.len());
176        for key in keys {
177            let idx = self.shard_for_key(key);
178            let rx = self.shards[idx].dispatch(make_req(key.clone())).await?;
179            receivers.push(rx);
180        }
181
182        let mut results = Vec::with_capacity(receivers.len());
183        for rx in receivers {
184            results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
185        }
186        Ok(results)
187    }
188
189    /// Returns true if both keys are owned by the same shard.
190    pub fn same_shard(&self, key1: &str, key2: &str) -> bool {
191        self.shard_for_key(key1) == self.shard_for_key(key2)
192    }
193
194    /// Determines which shard owns a given key.
195    pub fn shard_for_key(&self, key: &str) -> usize {
196        shard_index(key, self.shards.len())
197    }
198
199    /// Sends a request to a shard and returns the reply channel without
200    /// waiting for the response. Used by the connection handler to
201    /// dispatch commands and collect responses separately.
202    pub async fn dispatch_to_shard(
203        &self,
204        shard_idx: usize,
205        request: ShardRequest,
206    ) -> Result<tokio::sync::oneshot::Receiver<ShardResponse>, ShardError> {
207        if shard_idx >= self.shards.len() {
208            return Err(ShardError::Unavailable);
209        }
210        self.shards[shard_idx].dispatch(request).await
211    }
212}
213
/// Pure function: maps a key to a shard index.
///
/// Uses FNV-1a hashing for deterministic shard routing across restarts.
/// This is critical for AOF/snapshot recovery — keys must hash to the
/// same shard on every startup, otherwise recovered data lands in the
/// wrong shard.
///
/// FNV-1a is simple, fast for short keys, and completely deterministic
/// (no per-process randomization). Shard routing is trusted internal
/// logic so DoS-resistant hashing is unnecessary here.
fn shard_index(key: &str, shard_count: usize) -> usize {
    // FNV-1a 64-bit parameters.
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const PRIME: u64 = 0x100000001b3;

    // FNV-1a: xor the byte in first, then multiply by the prime.
    let digest = key
        .bytes()
        .fold(OFFSET_BASIS, |acc, byte| {
            (acc ^ u64::from(byte)).wrapping_mul(PRIME)
        });

    digest as usize % shard_count
}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Value;
    use bytes::Bytes;

    /// Routing must be a pure function of the key: two hashes of the
    /// same key always agree.
    #[test]
    fn same_key_same_shard() {
        assert_eq!(shard_index("foo", 8), shard_index("foo", 8));
    }

    /// With enough distinct keys, more than one shard should be hit.
    #[test]
    fn keys_spread_across_shards() {
        let seen: std::collections::HashSet<usize> = (0..100)
            .map(|i| shard_index(&format!("key:{i}"), 4))
            .collect();
        assert!(seen.len() > 1, "expected keys to spread across shards");
    }

    /// A single shard owns everything: every key maps to index 0.
    #[test]
    fn single_shard_always_zero() {
        assert_eq!(shard_index("anything", 1), 0);
        assert_eq!(shard_index("other", 1), 0);
    }

    /// SET then GET through the router round-trips the stored value.
    #[tokio::test]
    async fn engine_round_trip() {
        let engine = Engine::new(4);

        let set_req = ShardRequest::Set {
            key: "greeting".into(),
            value: Bytes::from("hello"),
            expire: None,
            nx: false,
            xx: false,
        };
        let resp = engine.route("greeting", set_req).await.unwrap();
        assert!(matches!(resp, ShardResponse::Ok));

        let get_req = ShardRequest::Get {
            key: "greeting".into(),
        };
        match engine.route("greeting", get_req).await.unwrap() {
            ShardResponse::Value(Some(Value::String(data))) => {
                assert_eq!(data, Bytes::from("hello"));
            }
            other => panic!("expected Value(Some(String)), got {other:?}"),
        }
    }

    /// DEL across keys that likely live on different shards: exactly the
    /// keys that exist report a successful delete.
    #[tokio::test]
    async fn multi_shard_del() {
        let engine = Engine::new(4);

        // set several keys (likely landing on different shards)
        for key in ["a", "b", "c", "d"] {
            let set_req = ShardRequest::Set {
                key: key.to_string(),
                value: Bytes::from("v"),
                expire: None,
                nx: false,
                xx: false,
            };
            engine.route(key, set_req).await.unwrap();
        }

        // delete them all and count successes
        let mut count = 0i64;
        for key in ["a", "b", "c", "d", "missing"] {
            let del_req = ShardRequest::Del {
                key: key.to_string(),
            };
            let resp = engine.route(key, del_req).await.unwrap();
            if matches!(resp, ShardResponse::Bool(true)) {
                count += 1;
            }
        }
        assert_eq!(count, 4);
    }

    /// Zero shards is a programming error and must panic loudly.
    #[test]
    #[should_panic(expected = "shard count must be at least 1")]
    fn zero_shards_panics() {
        Engine::new(0);
    }
}