1use std::collections::hash_map::DefaultHasher;
8use std::hash::{Hash, Hasher};
9
10use crate::error::ShardError;
11use crate::keyspace::ShardConfig;
12use crate::shard::{self, ShardHandle, ShardPersistenceConfig, ShardRequest, ShardResponse};
13
/// Buffer size handed to `shard::spawn_shard` for every shard the engine creates.
///
/// NOTE(review): presumably the capacity of each shard's request channel — confirm
/// against `spawn_shard`.
const SHARD_BUFFER: usize = 256;
17
18#[derive(Debug, Clone, Default)]
20pub struct EngineConfig {
21 pub shard: ShardConfig,
23 pub persistence: Option<ShardPersistenceConfig>,
26}
27
28#[derive(Debug, Clone)]
34pub struct Engine {
35 shards: Vec<ShardHandle>,
36}
37
38impl Engine {
39 pub fn new(shard_count: usize) -> Self {
44 Self::with_config(shard_count, EngineConfig::default())
45 }
46
47 pub fn with_config(shard_count: usize, config: EngineConfig) -> Self {
51 assert!(shard_count > 0, "shard count must be at least 1");
52
53 let shards = (0..shard_count)
54 .map(|i| {
55 let mut shard_config = config.shard.clone();
56 shard_config.shard_id = i as u16;
57 shard::spawn_shard(SHARD_BUFFER, shard_config, config.persistence.clone())
58 })
59 .collect();
60
61 Self { shards }
62 }
63
64 pub fn with_available_cores() -> Self {
68 Self::with_available_cores_config(EngineConfig::default())
69 }
70
71 pub fn with_available_cores_config(config: EngineConfig) -> Self {
74 let cores = std::thread::available_parallelism()
75 .map(|n| n.get())
76 .unwrap_or(1);
77 Self::with_config(cores, config)
78 }
79
80 pub fn shard_count(&self) -> usize {
82 self.shards.len()
83 }
84
85 pub async fn send_to_shard(
89 &self,
90 shard_idx: usize,
91 request: ShardRequest,
92 ) -> Result<ShardResponse, ShardError> {
93 if shard_idx >= self.shards.len() {
94 return Err(ShardError::Unavailable);
95 }
96 self.shards[shard_idx].send(request).await
97 }
98
99 pub async fn route(
101 &self,
102 key: &str,
103 request: ShardRequest,
104 ) -> Result<ShardResponse, ShardError> {
105 let idx = self.shard_for_key(key);
106 self.shards[idx].send(request).await
107 }
108
109 pub async fn broadcast<F>(&self, make_req: F) -> Result<Vec<ShardResponse>, ShardError>
115 where
116 F: Fn() -> ShardRequest,
117 {
118 let mut receivers = Vec::with_capacity(self.shards.len());
120 for shard in &self.shards {
121 receivers.push(shard.dispatch(make_req()).await?);
122 }
123
124 let mut results = Vec::with_capacity(receivers.len());
126 for rx in receivers {
127 results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
128 }
129 Ok(results)
130 }
131
132 pub async fn route_multi<F>(
138 &self,
139 keys: &[String],
140 make_req: F,
141 ) -> Result<Vec<ShardResponse>, ShardError>
142 where
143 F: Fn(String) -> ShardRequest,
144 {
145 let mut receivers = Vec::with_capacity(keys.len());
146 for key in keys {
147 let idx = self.shard_for_key(key);
148 let rx = self.shards[idx].dispatch(make_req(key.clone())).await?;
149 receivers.push(rx);
150 }
151
152 let mut results = Vec::with_capacity(receivers.len());
153 for rx in receivers {
154 results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
155 }
156 Ok(results)
157 }
158
159 fn shard_for_key(&self, key: &str) -> usize {
161 shard_index(key, self.shards.len())
162 }
163}
164
/// Maps `key` to a shard index in `0..shard_count`.
///
/// Within one build the mapping is deterministic: the same key and shard count
/// always produce the same index, because `DefaultHasher::new()` instances all
/// hash identically. The standard library does not promise that
/// `DefaultHasher`'s algorithm stays the same across Rust releases.
/// NOTE(review): if shard persistence ties data to a shard id, a toolchain
/// upgrade could move keys between shards — confirm.
///
/// # Panics
///
/// Panics if `shard_count` is zero (remainder by zero).
fn shard_index(key: &str, shard_count: usize) -> usize {
    let mut hasher = DefaultHasher::new();
    key.hash(&mut hasher);
    // Reduce in u64 before narrowing. This way every platform uses the full
    // 64-bit hash; the remainder is < shard_count, so the cast back is lossless.
    (hasher.finish() % shard_count as u64) as usize
}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Value;
    use bytes::Bytes;

    /// Builds a `Set` request for `key` with `nx`/`xx` unset and no expiry.
    fn set_request(key: &str, value: &'static str) -> ShardRequest {
        ShardRequest::Set {
            key: key.to_owned(),
            value: Bytes::from(value),
            expire: None,
            nx: false,
            xx: false,
        }
    }

    #[test]
    fn same_key_same_shard() {
        let first = shard_index("foo", 8);
        assert_eq!(first, shard_index("foo", 8));
    }

    #[test]
    fn keys_spread_across_shards() {
        // Hash 100 distinct keys over 4 shards; they should not all collide.
        let seen: std::collections::HashSet<usize> = (0..100)
            .map(|i| shard_index(&format!("key:{i}"), 4))
            .collect();
        assert!(seen.len() > 1, "expected keys to spread across shards");
    }

    #[test]
    fn single_shard_always_zero() {
        for key in ["anything", "other"] {
            assert_eq!(shard_index(key, 1), 0);
        }
    }

    #[tokio::test]
    async fn engine_round_trip() {
        let engine = Engine::new(4);

        let set_resp = engine
            .route("greeting", set_request("greeting", "hello"))
            .await
            .unwrap();
        assert!(matches!(set_resp, ShardResponse::Ok));

        let get_resp = engine
            .route(
                "greeting",
                ShardRequest::Get {
                    key: "greeting".into(),
                },
            )
            .await
            .unwrap();
        match get_resp {
            ShardResponse::Value(Some(Value::String(data))) => {
                assert_eq!(data, Bytes::from("hello"))
            }
            other => panic!("expected Value(Some(String)), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_shard_del() {
        let engine = Engine::new(4);

        // Populate keys that (likely) land on different shards.
        for key in ["a", "b", "c", "d"] {
            engine.route(key, set_request(key, "v")).await.unwrap();
        }

        // Deleting the four existing keys plus one missing key should report
        // exactly four successful deletions.
        let mut deleted = 0i64;
        for key in ["a", "b", "c", "d", "missing"] {
            let resp = engine
                .route(
                    key,
                    ShardRequest::Del {
                        key: key.to_string(),
                    },
                )
                .await
                .unwrap();
            if matches!(resp, ShardResponse::Bool(true)) {
                deleted += 1;
            }
        }
        assert_eq!(deleted, 4);
    }

    #[test]
    #[should_panic(expected = "shard count must be at least 1")]
    fn zero_shards_panics() {
        Engine::new(0);
    }
}