1use crate::dropper::DropHandle;
8use crate::error::ShardError;
9use crate::keyspace::ShardConfig;
10use crate::shard::{self, ShardHandle, ShardPersistenceConfig, ShardRequest, ShardResponse};
11
/// Buffer size handed to `shard::spawn_shard` for every shard the engine
/// creates (presumably the bound of each shard's request channel — confirm
/// in the `shard` module).
const SHARD_BUFFER: usize = 256;
15
/// Configuration consumed by `Engine::with_config` when spawning shards.
#[derive(Debug, Clone, Default)]
pub struct EngineConfig {
    /// Template shard configuration. Every shard receives its own clone with
    /// `shard_id` overwritten by that shard's index.
    pub shard: ShardConfig,
    /// Optional persistence settings; a clone is passed to every shard.
    pub persistence: Option<ShardPersistenceConfig>,
    /// Optional protobuf schema registry. A clone is passed to every shard,
    /// and the original is kept on the resulting `Engine`.
    #[cfg(feature = "protobuf")]
    pub schema_registry: Option<crate::schema::SharedSchemaRegistry>,
}
29
/// Front door to a fixed set of shards. Each key is routed to exactly one
/// shard by hashing.
///
/// Cloning an `Engine` clones every `ShardHandle`. Whether clones talk to the
/// same underlying shards depends on `ShardHandle` (presumably yes — confirm
/// in the `shard` module).
#[derive(Debug, Clone)]
pub struct Engine {
    /// One handle per shard; the handle at index `i` was spawned with
    /// `shard_id == i`.
    shards: Vec<ShardHandle>,
    /// Registry taken from `EngineConfig::schema_registry` at construction.
    #[cfg(feature = "protobuf")]
    schema_registry: Option<crate::schema::SharedSchemaRegistry>,
}
41
42impl Engine {
43 pub fn new(shard_count: usize) -> Self {
48 Self::with_config(shard_count, EngineConfig::default())
49 }
50
51 pub fn with_config(shard_count: usize, config: EngineConfig) -> Self {
58 assert!(shard_count > 0, "shard count must be at least 1");
59 assert!(
60 shard_count <= u16::MAX as usize,
61 "shard count must fit in u16"
62 );
63
64 let drop_handle = DropHandle::spawn();
65
66 let shards = (0..shard_count)
67 .map(|i| {
68 let mut shard_config = config.shard.clone();
69 shard_config.shard_id = i as u16;
70 shard::spawn_shard(
71 SHARD_BUFFER,
72 shard_config,
73 config.persistence.clone(),
74 Some(drop_handle.clone()),
75 #[cfg(feature = "protobuf")]
76 config.schema_registry.clone(),
77 )
78 })
79 .collect();
80
81 Self {
82 shards,
83 #[cfg(feature = "protobuf")]
84 schema_registry: config.schema_registry,
85 }
86 }
87
88 pub fn with_available_cores() -> Self {
92 Self::with_available_cores_config(EngineConfig::default())
93 }
94
95 pub fn with_available_cores_config(config: EngineConfig) -> Self {
98 let cores = std::thread::available_parallelism()
99 .map(|n| n.get())
100 .unwrap_or(1);
101 Self::with_config(cores, config)
102 }
103
104 #[cfg(feature = "protobuf")]
106 pub fn schema_registry(&self) -> Option<&crate::schema::SharedSchemaRegistry> {
107 self.schema_registry.as_ref()
108 }
109
110 pub fn shard_count(&self) -> usize {
112 self.shards.len()
113 }
114
115 pub async fn send_to_shard(
119 &self,
120 shard_idx: usize,
121 request: ShardRequest,
122 ) -> Result<ShardResponse, ShardError> {
123 if shard_idx >= self.shards.len() {
124 return Err(ShardError::Unavailable);
125 }
126 self.shards[shard_idx].send(request).await
127 }
128
129 pub async fn route(
131 &self,
132 key: &str,
133 request: ShardRequest,
134 ) -> Result<ShardResponse, ShardError> {
135 let idx = self.shard_for_key(key);
136 self.shards[idx].send(request).await
137 }
138
139 pub async fn broadcast<F>(&self, make_req: F) -> Result<Vec<ShardResponse>, ShardError>
145 where
146 F: Fn() -> ShardRequest,
147 {
148 let mut receivers = Vec::with_capacity(self.shards.len());
150 for shard in &self.shards {
151 receivers.push(shard.dispatch(make_req()).await?);
152 }
153
154 let mut results = Vec::with_capacity(receivers.len());
156 for rx in receivers {
157 results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
158 }
159 Ok(results)
160 }
161
162 pub async fn route_multi<F>(
168 &self,
169 keys: &[String],
170 make_req: F,
171 ) -> Result<Vec<ShardResponse>, ShardError>
172 where
173 F: Fn(String) -> ShardRequest,
174 {
175 let mut receivers = Vec::with_capacity(keys.len());
176 for key in keys {
177 let idx = self.shard_for_key(key);
178 let rx = self.shards[idx].dispatch(make_req(key.clone())).await?;
179 receivers.push(rx);
180 }
181
182 let mut results = Vec::with_capacity(receivers.len());
183 for rx in receivers {
184 results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
185 }
186 Ok(results)
187 }
188
189 pub fn same_shard(&self, key1: &str, key2: &str) -> bool {
191 self.shard_for_key(key1) == self.shard_for_key(key2)
192 }
193
194 pub fn shard_for_key(&self, key: &str) -> usize {
196 shard_index(key, self.shards.len())
197 }
198
199 pub async fn dispatch_to_shard(
203 &self,
204 shard_idx: usize,
205 request: ShardRequest,
206 ) -> Result<tokio::sync::oneshot::Receiver<ShardResponse>, ShardError> {
207 if shard_idx >= self.shards.len() {
208 return Err(ShardError::Unavailable);
209 }
210 self.shards[shard_idx].dispatch(request).await
211 }
212}
213
/// Maps `key` to a shard index in `0..shard_count` using 64-bit FNV-1a.
///
/// The 64-bit hash is reduced modulo `shard_count` *before* narrowing to
/// `usize`. This gives a key the same shard on 32-bit and 64-bit targets;
/// narrowing first would keep only the hash's low 32 bits on 32-bit
/// platforms.
///
/// # Panics
///
/// Panics if `shard_count` is 0 (remainder by zero).
fn shard_index(key: &str, shard_count: usize) -> usize {
    const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let hash = key
        .as_bytes()
        .iter()
        .fold(FNV_OFFSET, |hash, &byte| {
            (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        });
    // The remainder is below `shard_count`, so narrowing back is lossless.
    (hash % shard_count as u64) as usize
}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Value;
    use bytes::Bytes;

    /// Builds a plain SET request with no expiry and no NX/XX condition.
    fn set_request(key: &str, value: &'static str) -> ShardRequest {
        ShardRequest::Set {
            key: key.to_string(),
            value: Bytes::from(value),
            expire: None,
            nx: false,
            xx: false,
        }
    }

    #[test]
    fn same_key_same_shard() {
        let idx1 = shard_index("foo", 8);
        let idx2 = shard_index("foo", 8);
        assert_eq!(idx1, idx2);
    }

    #[test]
    fn keys_spread_across_shards() {
        let mut seen = std::collections::HashSet::new();
        for i in 0..100 {
            let key = format!("key:{i}");
            seen.insert(shard_index(&key, 4));
        }
        assert!(seen.len() > 1, "expected keys to spread across shards");
    }

    #[test]
    fn shard_index_stays_in_range() {
        for shards in 1..=9 {
            for i in 0..100 {
                let key = format!("key:{i}");
                assert!(shard_index(&key, shards) < shards);
            }
        }
    }

    #[test]
    fn single_shard_always_zero() {
        assert_eq!(shard_index("anything", 1), 0);
        assert_eq!(shard_index("other", 1), 0);
    }

    #[tokio::test]
    async fn engine_round_trip() {
        let engine = Engine::new(4);

        let resp = engine
            .route("greeting", set_request("greeting", "hello"))
            .await
            .unwrap();
        assert!(matches!(resp, ShardResponse::Ok));

        let resp = engine
            .route(
                "greeting",
                ShardRequest::Get {
                    key: "greeting".into(),
                },
            )
            .await
            .unwrap();
        match resp {
            ShardResponse::Value(Some(Value::String(data))) => {
                assert_eq!(data, Bytes::from("hello"));
            }
            other => panic!("expected Value(Some(String)), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_shard_del() {
        let engine = Engine::new(4);

        for &key in &["a", "b", "c", "d"] {
            engine.route(key, set_request(key, "v")).await.unwrap();
        }

        let mut count = 0i64;
        for &key in &["a", "b", "c", "d", "missing"] {
            let resp = engine
                .route(
                    key,
                    ShardRequest::Del {
                        key: key.to_string(),
                    },
                )
                .await
                .unwrap();
            if let ShardResponse::Bool(true) = resp {
                count += 1;
            }
        }
        assert_eq!(count, 4);
    }

    #[tokio::test]
    async fn route_multi_preserves_key_order() {
        let engine = Engine::new(4);

        for &key in &["a", "b", "c"] {
            engine.route(key, set_request(key, "v")).await.unwrap();
        }

        let keys: Vec<String> = ["a", "missing", "b", "c"]
            .iter()
            .map(|k| k.to_string())
            .collect();
        let responses = engine
            .route_multi(&keys, |key| ShardRequest::Del { key })
            .await
            .unwrap();
        let deleted: Vec<bool> = responses
            .iter()
            .map(|resp| matches!(resp, ShardResponse::Bool(true)))
            .collect();
        assert_eq!(deleted, vec![true, false, true, true]);
    }

    #[tokio::test]
    async fn send_to_shard_rejects_out_of_range_index() {
        let engine = Engine::new(2);
        let result = engine
            .send_to_shard(2, ShardRequest::Get { key: "k".into() })
            .await;
        assert!(matches!(result, Err(ShardError::Unavailable)));
    }

    #[tokio::test]
    async fn dispatch_to_shard_rejects_out_of_range_index() {
        let engine = Engine::new(2);
        let result = engine
            .dispatch_to_shard(usize::MAX, ShardRequest::Get { key: "k".into() })
            .await;
        assert!(matches!(result, Err(ShardError::Unavailable)));
    }

    #[test]
    #[should_panic(expected = "shard count must be at least 1")]
    fn zero_shards_panics() {
        Engine::new(0);
    }
}