// ember_core/engine.rs
//! The engine: coordinator for the sharded keyspace.
//!
//! Routes single-key operations to the correct shard based on a hash
//! of the key. Each shard is an independent tokio task — no locks on
//! the hot path.
6
7use std::collections::hash_map::DefaultHasher;
8use std::hash::{Hash, Hasher};
9
10use crate::error::ShardError;
11use crate::keyspace::ShardConfig;
12use crate::shard::{self, ShardHandle, ShardPersistenceConfig, ShardRequest, ShardResponse};
13
/// Channel buffer size per shard. 256 is large enough to absorb
/// bursts without putting meaningful back-pressure on connections.
/// Passed to `shard::spawn_shard` as the per-shard mpsc capacity.
const SHARD_BUFFER: usize = 256;
17
18/// Configuration for the engine, passed down to each shard.
19#[derive(Debug, Clone, Default)]
20pub struct EngineConfig {
21    /// Per-shard configuration (memory limits, eviction policy).
22    pub shard: ShardConfig,
23    /// Optional persistence configuration. When set, each shard gets
24    /// its own AOF and snapshot files under this directory.
25    pub persistence: Option<ShardPersistenceConfig>,
26}
27
/// The sharded engine. Owns handles to all shard tasks and routes
/// requests by key hash.
///
/// `Clone` is cheap — it just clones the `Vec<ShardHandle>` (which are
/// mpsc senders under the hood).
#[derive(Debug, Clone)]
pub struct Engine {
    /// One handle per shard task; the vector index doubles as the
    /// shard id assigned at spawn time.
    shards: Vec<ShardHandle>,
}
37
38impl Engine {
39    /// Creates an engine with `shard_count` shards using default config.
40    ///
41    /// Each shard is spawned as a tokio task immediately.
42    /// Panics if `shard_count` is zero.
43    pub fn new(shard_count: usize) -> Self {
44        Self::with_config(shard_count, EngineConfig::default())
45    }
46
47    /// Creates an engine with `shard_count` shards and the given config.
48    ///
49    /// Panics if `shard_count` is zero.
50    pub fn with_config(shard_count: usize, config: EngineConfig) -> Self {
51        assert!(shard_count > 0, "shard count must be at least 1");
52
53        let shards = (0..shard_count)
54            .map(|i| {
55                let mut shard_config = config.shard.clone();
56                shard_config.shard_id = i as u16;
57                shard::spawn_shard(SHARD_BUFFER, shard_config, config.persistence.clone())
58            })
59            .collect();
60
61        Self { shards }
62    }
63
64    /// Creates an engine with one shard per available CPU core.
65    ///
66    /// Falls back to a single shard if the core count can't be determined.
67    pub fn with_available_cores() -> Self {
68        Self::with_available_cores_config(EngineConfig::default())
69    }
70
71    /// Creates an engine with one shard per available CPU core and the
72    /// given config.
73    pub fn with_available_cores_config(config: EngineConfig) -> Self {
74        let cores = std::thread::available_parallelism()
75            .map(|n| n.get())
76            .unwrap_or(1);
77        Self::with_config(cores, config)
78    }
79
80    /// Returns the number of shards.
81    pub fn shard_count(&self) -> usize {
82        self.shards.len()
83    }
84
85    /// Sends a request to a specific shard by index.
86    ///
87    /// Used by SCAN to iterate through shards sequentially.
88    pub async fn send_to_shard(
89        &self,
90        shard_idx: usize,
91        request: ShardRequest,
92    ) -> Result<ShardResponse, ShardError> {
93        if shard_idx >= self.shards.len() {
94            return Err(ShardError::Unavailable);
95        }
96        self.shards[shard_idx].send(request).await
97    }
98
99    /// Routes a request to the shard that owns `key`.
100    pub async fn route(
101        &self,
102        key: &str,
103        request: ShardRequest,
104    ) -> Result<ShardResponse, ShardError> {
105        let idx = self.shard_for_key(key);
106        self.shards[idx].send(request).await
107    }
108
109    /// Sends a request to every shard and collects all responses.
110    ///
111    /// Dispatches to all shards first (so they start processing in
112    /// parallel), then collects the replies. Used for commands like
113    /// DBSIZE and INFO that need data from all shards.
114    pub async fn broadcast<F>(&self, make_req: F) -> Result<Vec<ShardResponse>, ShardError>
115    where
116        F: Fn() -> ShardRequest,
117    {
118        // dispatch to all shards without waiting for responses
119        let mut receivers = Vec::with_capacity(self.shards.len());
120        for shard in &self.shards {
121            receivers.push(shard.dispatch(make_req()).await?);
122        }
123
124        // now collect all responses
125        let mut results = Vec::with_capacity(receivers.len());
126        for rx in receivers {
127            results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
128        }
129        Ok(results)
130    }
131
132    /// Routes requests for multiple keys concurrently.
133    ///
134    /// Dispatches all requests without waiting, then collects responses.
135    /// The response order matches the key order. Used for multi-key
136    /// commands like DEL and EXISTS.
137    pub async fn route_multi<F>(
138        &self,
139        keys: &[String],
140        make_req: F,
141    ) -> Result<Vec<ShardResponse>, ShardError>
142    where
143        F: Fn(String) -> ShardRequest,
144    {
145        let mut receivers = Vec::with_capacity(keys.len());
146        for key in keys {
147            let idx = self.shard_for_key(key);
148            let rx = self.shards[idx].dispatch(make_req(key.clone())).await?;
149            receivers.push(rx);
150        }
151
152        let mut results = Vec::with_capacity(receivers.len());
153        for rx in receivers {
154            results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
155        }
156        Ok(results)
157    }
158
159    /// Determines which shard owns a given key.
160    fn shard_for_key(&self, key: &str) -> usize {
161        shard_index(key, self.shards.len())
162    }
163}
164
/// Pure function: maps a key to a shard index.
///
/// Hashes the key with `DefaultHasher` (SipHash) and reduces the
/// digest modulo `shard_count`. Deterministic within a single
/// process — all that local sharding needs. CRC16 will replace this
/// when cluster-level slot assignment arrives.
fn shard_index(key: &str, shard_count: usize) -> usize {
    let mut state = DefaultHasher::new();
    key.hash(&mut state);
    // truncate to usize first so 32- and 64-bit targets each reduce
    // their native word, matching the engine's historical mapping
    let digest = state.finish() as usize;
    digest % shard_count
}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Value;
    use bytes::Bytes;

    #[test]
    fn same_key_same_shard() {
        // the mapping must be stable within one process
        let first = shard_index("foo", 8);
        let second = shard_index("foo", 8);
        assert_eq!(first, second);
    }

    #[test]
    fn keys_spread_across_shards() {
        // with enough keys, we should hit more than one shard
        let seen: std::collections::HashSet<usize> = (0..100)
            .map(|i| shard_index(&format!("key:{i}"), 4))
            .collect();
        assert!(seen.len() > 1, "expected keys to spread across shards");
    }

    #[test]
    fn single_shard_always_zero() {
        for key in ["anything", "other"] {
            assert_eq!(shard_index(key, 1), 0);
        }
    }

    #[tokio::test]
    async fn engine_round_trip() {
        let engine = Engine::new(4);

        // write a value, then read it back through the same routing path
        let set_req = ShardRequest::Set {
            key: "greeting".into(),
            value: Bytes::from("hello"),
            expire: None,
            nx: false,
            xx: false,
        };
        let resp = engine.route("greeting", set_req).await.unwrap();
        assert!(matches!(resp, ShardResponse::Ok));

        let get_req = ShardRequest::Get {
            key: "greeting".into(),
        };
        match engine.route("greeting", get_req).await.unwrap() {
            ShardResponse::Value(Some(Value::String(data))) => {
                assert_eq!(data, Bytes::from("hello"));
            }
            other => panic!("expected Value(Some(String)), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_shard_del() {
        let engine = Engine::new(4);

        // set several keys (likely landing on different shards)
        for key in ["a", "b", "c", "d"] {
            let req = ShardRequest::Set {
                key: key.to_string(),
                value: Bytes::from("v"),
                expire: None,
                nx: false,
                xx: false,
            };
            engine.route(key, req).await.unwrap();
        }

        // delete them all (plus one absent key) and count successes
        let mut deleted = 0i64;
        for key in ["a", "b", "c", "d", "missing"] {
            let req = ShardRequest::Del {
                key: key.to_string(),
            };
            let resp = engine.route(key, req).await.unwrap();
            if matches!(resp, ShardResponse::Bool(true)) {
                deleted += 1;
            }
        }
        assert_eq!(deleted, 4);
    }

    #[test]
    #[should_panic(expected = "shard count must be at least 1")]
    fn zero_shards_panics() {
        Engine::new(0);
    }
}