1use tokio::sync::broadcast;
8
9use crate::dropper::DropHandle;
10use crate::error::ShardError;
11use crate::keyspace::ShardConfig;
12use crate::shard::{
13 self, PreparedShard, ReplicationEvent, ShardHandle, ShardPersistenceConfig, ShardRequest,
14 ShardResponse,
15};
16
/// Default capacity of each shard's request channel, used when
/// `EngineConfig::shard_channel_buffer` is left at 0.
const DEFAULT_SHARD_BUFFER: usize = 4096;
24
/// Configuration for building an [`Engine`].
#[derive(Debug, Clone, Default)]
pub struct EngineConfig {
    /// Base configuration applied to every shard; each shard receives a
    /// clone with its own `shard_id` stamped in.
    pub shard: ShardConfig,
    /// Optional persistence settings, cloned into every shard.
    pub persistence: Option<ShardPersistenceConfig>,
    /// Optional broadcast channel on which shards publish replication
    /// events; callers can attach via `Engine::subscribe_replication`.
    pub replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    /// Optional broadcast channel for expiration notifications
    /// (presumably expired key names — the payload is a `String`;
    /// confirm against the shard implementation).
    pub expired_tx: Option<broadcast::Sender<String>>,
    /// Shared protobuf schema registry handed to every shard.
    #[cfg(feature = "protobuf")]
    pub schema_registry: Option<crate::schema::SharedSchemaRegistry>,
    /// Capacity of each shard's request channel; 0 selects
    /// `DEFAULT_SHARD_BUFFER`.
    pub shard_channel_buffer: usize,
}
48
/// A sharded engine that routes requests to per-shard tasks, hashing keys
/// (FNV-1a, see `shard_index`) onto shard handles.
#[derive(Debug, Clone)]
pub struct Engine {
    // One handle per shard task, indexed by shard id.
    shards: Vec<ShardHandle>,
    // Retained so callers can subscribe via `subscribe_replication`.
    replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    // Retained so callers can subscribe via `subscribe_expired`.
    expired_tx: Option<broadcast::Sender<String>>,
    // Exposed through `schema_registry()` when the `protobuf` feature is on.
    #[cfg(feature = "protobuf")]
    schema_registry: Option<crate::schema::SharedSchemaRegistry>,
}
62
63impl Engine {
64 pub fn new(shard_count: usize) -> Self {
69 Self::with_config(shard_count, EngineConfig::default())
70 }
71
72 pub fn with_config(shard_count: usize, config: EngineConfig) -> Self {
79 assert!(shard_count > 0, "shard count must be at least 1");
80 assert!(
81 shard_count <= u16::MAX as usize,
82 "shard count must fit in u16"
83 );
84
85 let drop_handle = DropHandle::spawn();
86 let buffer = if config.shard_channel_buffer == 0 {
87 DEFAULT_SHARD_BUFFER
88 } else {
89 config.shard_channel_buffer
90 };
91
92 let shards = (0..shard_count)
93 .map(|i| {
94 let mut shard_config = config.shard.clone();
95 shard_config.shard_id = i as u16;
96 shard::spawn_shard(
97 buffer,
98 shard_config,
99 config.persistence.clone(),
100 Some(drop_handle.clone()),
101 config.replication_tx.clone(),
102 config.expired_tx.clone(),
103 #[cfg(feature = "protobuf")]
104 config.schema_registry.clone(),
105 )
106 })
107 .collect();
108
109 Self {
110 shards,
111 replication_tx: config.replication_tx,
112 expired_tx: config.expired_tx,
113 #[cfg(feature = "protobuf")]
114 schema_registry: config.schema_registry,
115 }
116 }
117
118 pub fn prepare(shard_count: usize, config: EngineConfig) -> (Self, Vec<PreparedShard>) {
127 assert!(shard_count > 0, "shard count must be at least 1");
128 assert!(
129 shard_count <= u16::MAX as usize,
130 "shard count must fit in u16"
131 );
132
133 let drop_handle = DropHandle::spawn();
134 let buffer = if config.shard_channel_buffer == 0 {
135 DEFAULT_SHARD_BUFFER
136 } else {
137 config.shard_channel_buffer
138 };
139
140 let mut handles = Vec::with_capacity(shard_count);
141 let mut prepared = Vec::with_capacity(shard_count);
142
143 for i in 0..shard_count {
144 let mut shard_config = config.shard.clone();
145 shard_config.shard_id = i as u16;
146 let (handle, shard) = shard::prepare_shard(
147 buffer,
148 shard_config,
149 config.persistence.clone(),
150 Some(drop_handle.clone()),
151 config.replication_tx.clone(),
152 config.expired_tx.clone(),
153 #[cfg(feature = "protobuf")]
154 config.schema_registry.clone(),
155 );
156 handles.push(handle);
157 prepared.push(shard);
158 }
159
160 let engine = Self {
161 shards: handles,
162 replication_tx: config.replication_tx,
163 expired_tx: config.expired_tx,
164 #[cfg(feature = "protobuf")]
165 schema_registry: config.schema_registry,
166 };
167
168 (engine, prepared)
169 }
170
171 pub fn with_available_cores() -> Self {
175 Self::with_available_cores_config(EngineConfig::default())
176 }
177
178 pub fn with_available_cores_config(config: EngineConfig) -> Self {
181 let cores = std::thread::available_parallelism()
182 .map(|n| n.get())
183 .unwrap_or(1);
184 Self::with_config(cores, config)
185 }
186
187 #[cfg(feature = "protobuf")]
189 pub fn schema_registry(&self) -> Option<&crate::schema::SharedSchemaRegistry> {
190 self.schema_registry.as_ref()
191 }
192
193 pub fn shard_count(&self) -> usize {
195 self.shards.len()
196 }
197
198 pub fn subscribe_replication(&self) -> Option<broadcast::Receiver<ReplicationEvent>> {
204 self.replication_tx.as_ref().map(|tx| tx.subscribe())
205 }
206
207 pub fn subscribe_expired(&self) -> Option<broadcast::Receiver<String>> {
212 self.expired_tx.as_ref().map(|tx| tx.subscribe())
213 }
214
215 pub async fn send_to_shard(
219 &self,
220 shard_idx: usize,
221 request: ShardRequest,
222 ) -> Result<ShardResponse, ShardError> {
223 if shard_idx >= self.shards.len() {
224 return Err(ShardError::Unavailable);
225 }
226 self.shards[shard_idx].send(request).await
227 }
228
229 pub async fn route(
231 &self,
232 key: &str,
233 request: ShardRequest,
234 ) -> Result<ShardResponse, ShardError> {
235 let idx = self.shard_for_key(key);
236 self.shards[idx].send(request).await
237 }
238
239 pub async fn broadcast<F>(&self, make_req: F) -> Result<Vec<ShardResponse>, ShardError>
245 where
246 F: Fn() -> ShardRequest,
247 {
248 let mut receivers = Vec::with_capacity(self.shards.len());
250 for shard in &self.shards {
251 receivers.push(shard.dispatch(make_req()).await?);
252 }
253
254 let mut results = Vec::with_capacity(receivers.len());
256 for rx in receivers {
257 results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
258 }
259 Ok(results)
260 }
261
262 pub async fn route_multi<F>(
268 &self,
269 keys: &[String],
270 make_req: F,
271 ) -> Result<Vec<ShardResponse>, ShardError>
272 where
273 F: Fn(String) -> ShardRequest,
274 {
275 let mut receivers = Vec::with_capacity(keys.len());
276 for key in keys {
277 let idx = self.shard_for_key(key);
278 let rx = self.shards[idx].dispatch(make_req(key.clone())).await?;
279 receivers.push(rx);
280 }
281
282 let mut results = Vec::with_capacity(receivers.len());
283 for rx in receivers {
284 results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
285 }
286 Ok(results)
287 }
288
289 pub fn same_shard(&self, key1: &str, key2: &str) -> bool {
291 self.shard_for_key(key1) == self.shard_for_key(key2)
292 }
293
294 pub fn shard_for_key(&self, key: &str) -> usize {
296 shard_index(key, self.shards.len())
297 }
298
299 pub async fn dispatch_to_shard(
303 &self,
304 shard_idx: usize,
305 request: ShardRequest,
306 ) -> Result<tokio::sync::oneshot::Receiver<ShardResponse>, ShardError> {
307 if shard_idx >= self.shards.len() {
308 return Err(ShardError::Unavailable);
309 }
310 self.shards[shard_idx].dispatch(request).await
311 }
312
313 pub async fn dispatch_reusable_to_shard(
317 &self,
318 shard_idx: usize,
319 request: ShardRequest,
320 reply: tokio::sync::mpsc::Sender<ShardResponse>,
321 ) -> Result<(), ShardError> {
322 if shard_idx >= self.shards.len() {
323 return Err(ShardError::Unavailable);
324 }
325 self.shards[shard_idx]
326 .dispatch_reusable(request, reply)
327 .await
328 }
329
330 pub async fn dispatch_batch_to_shard(
337 &self,
338 shard_idx: usize,
339 requests: Vec<ShardRequest>,
340 ) -> Result<Vec<tokio::sync::oneshot::Receiver<ShardResponse>>, ShardError> {
341 if shard_idx >= self.shards.len() {
342 return Err(ShardError::Unavailable);
343 }
344 self.shards[shard_idx].dispatch_batch(requests).await
345 }
346}
347
/// Maps `key` onto a slot in `0..shard_count` using the 64-bit FNV-1a
/// hash of the key's bytes.
fn shard_index(key: &str, shard_count: usize) -> usize {
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    let hash = key
        .as_bytes()
        .iter()
        .fold(FNV_OFFSET, |acc, &b| (acc ^ u64::from(b)).wrapping_mul(FNV_PRIME));
    (hash as usize) % shard_count
}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Value;
    use bytes::Bytes;

    // Hashing the same key twice must yield the same shard.
    #[test]
    fn same_key_same_shard() {
        assert_eq!(shard_index("foo", 8), shard_index("foo", 8));
    }

    // 100 distinct keys over 4 shards should not all land on one shard.
    #[test]
    fn keys_spread_across_shards() {
        let hit: std::collections::HashSet<usize> = (0..100)
            .map(|i| {
                let key = format!("key:{i}");
                shard_index(&key, 4)
            })
            .collect();
        assert!(hit.len() > 1, "expected keys to spread across shards");
    }

    // With a single shard, every key maps to index 0.
    #[test]
    fn single_shard_always_zero() {
        for key in ["anything", "other"] {
            assert_eq!(shard_index(key, 1), 0);
        }
    }

    // Set then Get through `route` returns the stored value.
    #[tokio::test]
    async fn engine_round_trip() {
        let engine = Engine::new(4);

        let set_resp = engine
            .route(
                "greeting",
                ShardRequest::Set {
                    key: "greeting".into(),
                    value: Bytes::from("hello"),
                    expire: None,
                    nx: false,
                    xx: false,
                },
            )
            .await
            .unwrap();
        assert!(matches!(set_resp, ShardResponse::Ok));

        let get_resp = engine
            .route(
                "greeting",
                ShardRequest::Get {
                    key: "greeting".into(),
                },
            )
            .await
            .unwrap();
        match get_resp {
            ShardResponse::Value(Some(Value::String(data))) => {
                assert_eq!(data, Bytes::from("hello"));
            }
            other => panic!("expected Value(Some(String)), got {other:?}"),
        }
    }

    // Deleting keys spread over shards counts only keys that existed.
    #[tokio::test]
    async fn multi_shard_del() {
        let engine = Engine::new(4);

        for key in ["a", "b", "c", "d"] {
            engine
                .route(
                    key,
                    ShardRequest::Set {
                        key: key.to_string(),
                        value: Bytes::from("v"),
                        expire: None,
                        nx: false,
                        xx: false,
                    },
                )
                .await
                .unwrap();
        }

        let mut deleted = 0i64;
        for key in ["a", "b", "c", "d", "missing"] {
            let resp = engine
                .route(
                    key,
                    ShardRequest::Del {
                        key: key.to_string(),
                    },
                )
                .await
                .unwrap();
            deleted += i64::from(matches!(resp, ShardResponse::Bool(true)));
        }
        assert_eq!(deleted, 4);
    }

    // Shard count of zero is rejected at construction time.
    #[test]
    #[should_panic(expected = "shard count must be at least 1")]
    fn zero_shards_panics() {
        Engine::new(0);
    }
}