1use std::collections::hash_map::DefaultHasher;
8use std::hash::{Hash, Hasher};
9
10use crate::error::ShardError;
11use crate::keyspace::ShardConfig;
12use crate::shard::{self, ShardHandle, ShardRequest, ShardResponse};
13
/// Buffer size passed as the first argument to `shard::spawn_shard` for every
/// shard the engine creates.
///
/// NOTE(review): presumably the capacity of each shard's request channel —
/// confirm against `shard::spawn_shard`.
const SHARD_BUFFER: usize = 1 << 8;
17
18#[derive(Debug, Clone, Default)]
20pub struct EngineConfig {
21 pub shard: ShardConfig,
23}
24
25#[derive(Debug, Clone)]
31pub struct Engine {
32 shards: Vec<ShardHandle>,
33}
34
35impl Engine {
36 pub fn new(shard_count: usize) -> Self {
41 Self::with_config(shard_count, EngineConfig::default())
42 }
43
44 pub fn with_config(shard_count: usize, config: EngineConfig) -> Self {
48 assert!(shard_count > 0, "shard count must be at least 1");
49
50 let shards = (0..shard_count)
51 .map(|_| shard::spawn_shard(SHARD_BUFFER, config.shard.clone()))
52 .collect();
53
54 Self { shards }
55 }
56
57 pub fn with_available_cores() -> Self {
61 Self::with_available_cores_config(EngineConfig::default())
62 }
63
64 pub fn with_available_cores_config(config: EngineConfig) -> Self {
67 let cores = std::thread::available_parallelism()
68 .map(|n| n.get())
69 .unwrap_or(1);
70 Self::with_config(cores, config)
71 }
72
73 pub fn shard_count(&self) -> usize {
75 self.shards.len()
76 }
77
78 pub async fn route(
80 &self,
81 key: &str,
82 request: ShardRequest,
83 ) -> Result<ShardResponse, ShardError> {
84 let idx = self.shard_for_key(key);
85 self.shards[idx].send(request).await
86 }
87
88 pub async fn broadcast<F>(&self, make_req: F) -> Result<Vec<ShardResponse>, ShardError>
94 where
95 F: Fn() -> ShardRequest,
96 {
97 let mut receivers = Vec::with_capacity(self.shards.len());
99 for shard in &self.shards {
100 receivers.push(shard.dispatch(make_req()).await?);
101 }
102
103 let mut results = Vec::with_capacity(receivers.len());
105 for rx in receivers {
106 results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
107 }
108 Ok(results)
109 }
110
111 pub async fn route_multi<F>(
117 &self,
118 keys: &[String],
119 make_req: F,
120 ) -> Result<Vec<ShardResponse>, ShardError>
121 where
122 F: Fn(String) -> ShardRequest,
123 {
124 let mut receivers = Vec::with_capacity(keys.len());
125 for key in keys {
126 let idx = self.shard_for_key(key);
127 let rx = self.shards[idx].dispatch(make_req(key.clone())).await?;
128 receivers.push(rx);
129 }
130
131 let mut results = Vec::with_capacity(receivers.len());
132 for rx in receivers {
133 results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
134 }
135 Ok(results)
136 }
137
138 fn shard_for_key(&self, key: &str) -> usize {
140 shard_index(key, self.shards.len())
141 }
142}
143
/// Maps `key` to a shard index in `0..shard_count`.
///
/// Hashes with a `DefaultHasher` built via `new()`. Such hashers behave
/// identically within a build, so a key always maps to the same shard in a
/// running process. The std docs do not promise the algorithm is stable across
/// Rust releases, so the index must not be persisted.
///
/// # Panics
///
/// Panics (division by zero) if `shard_count` is 0.
fn shard_index(key: &str, shard_count: usize) -> usize {
    let mut hasher = DefaultHasher::new();
    key.hash(&mut hasher);
    // Reduce in u64 space before narrowing. The remainder is < shard_count, so
    // the cast back to usize is lossless on every target. Casting the raw
    // 64-bit hash first would drop its high half on 32-bit platforms.
    (hasher.finish() % shard_count as u64) as usize
}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Value;
    use bytes::Bytes;

    /// Stores `value` under `key` through the engine and asserts the shard
    /// acknowledged it with `ShardResponse::Ok`.
    async fn set(engine: &Engine, key: &str, value: &'static str) {
        let resp = engine
            .route(
                key,
                ShardRequest::Set {
                    key: key.to_string(),
                    value: Bytes::from(value),
                    expire: None,
                },
            )
            .await
            .unwrap();
        assert!(matches!(resp, ShardResponse::Ok));
    }

    #[test]
    fn same_key_same_shard() {
        let idx1 = shard_index("foo", 8);
        let idx2 = shard_index("foo", 8);
        assert_eq!(idx1, idx2);
    }

    #[test]
    fn keys_spread_across_shards() {
        let mut seen = std::collections::HashSet::new();
        for i in 0..100 {
            let key = format!("key:{i}");
            seen.insert(shard_index(&key, 4));
        }
        assert!(seen.len() > 1, "expected keys to spread across shards");
    }

    #[test]
    fn single_shard_always_zero() {
        assert_eq!(shard_index("anything", 1), 0);
        assert_eq!(shard_index("other", 1), 0);
    }

    #[test]
    fn shard_index_stays_in_range() {
        for count in 1..=16 {
            for i in 0..64 {
                let key = format!("key:{i}");
                assert!(shard_index(&key, count) < count);
            }
        }
    }

    #[tokio::test]
    async fn shard_count_matches_constructor() {
        let engine = Engine::new(3);
        assert_eq!(engine.shard_count(), 3);
    }

    #[tokio::test]
    async fn engine_round_trip() {
        let engine = Engine::new(4);

        let resp = engine
            .route(
                "greeting",
                ShardRequest::Set {
                    key: "greeting".into(),
                    value: Bytes::from("hello"),
                    expire: None,
                },
            )
            .await
            .unwrap();
        assert!(matches!(resp, ShardResponse::Ok));

        let resp = engine
            .route(
                "greeting",
                ShardRequest::Get {
                    key: "greeting".into(),
                },
            )
            .await
            .unwrap();
        match resp {
            ShardResponse::Value(Some(Value::String(data))) => {
                assert_eq!(data, Bytes::from("hello"));
            }
            other => panic!("expected Value(Some(String)), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_shard_del() {
        let engine = Engine::new(4);

        for key in &["a", "b", "c", "d"] {
            engine
                .route(
                    key,
                    ShardRequest::Set {
                        key: key.to_string(),
                        value: Bytes::from("v"),
                        expire: None,
                    },
                )
                .await
                .unwrap();
        }

        let mut count = 0i64;
        for key in &["a", "b", "c", "d", "missing"] {
            let resp = engine
                .route(
                    key,
                    ShardRequest::Del {
                        key: key.to_string(),
                    },
                )
                .await
                .unwrap();
            if let ShardResponse::Bool(true) = resp {
                count += 1;
            }
        }
        assert_eq!(count, 4);
    }

    #[tokio::test]
    async fn route_multi_returns_one_response_per_key_in_order() {
        let engine = Engine::new(4);
        for key in ["a", "b", "c"] {
            set(&engine, key, "v").await;
        }

        let keys: Vec<String> = ["a", "missing", "b", "c"]
            .iter()
            .map(|k| k.to_string())
            .collect();
        let responses = engine
            .route_multi(&keys, |key| ShardRequest::Del { key })
            .await
            .unwrap();

        assert_eq!(responses.len(), keys.len());
        let deleted: Vec<bool> = responses
            .iter()
            .map(|r| matches!(r, ShardResponse::Bool(true)))
            .collect();
        assert_eq!(deleted, vec![true, false, true, true]);
    }

    #[tokio::test]
    async fn broadcast_reaches_every_shard() {
        let engine = Engine::new(4);
        set(&engine, "greeting", "hello").await;

        let responses = engine
            .broadcast(|| ShardRequest::Get {
                key: "greeting".into(),
            })
            .await
            .unwrap();

        // One response per shard; only the owning shard holds the key.
        assert_eq!(responses.len(), engine.shard_count());
        let hits = responses
            .iter()
            .filter(|r| matches!(r, ShardResponse::Value(Some(_))))
            .count();
        assert_eq!(hits, 1, "key should live on exactly one shard");
    }

    #[test]
    #[should_panic(expected = "shard count must be at least 1")]
    fn zero_shards_panics() {
        Engine::new(0);
    }
}