1use tokio::sync::broadcast;
8
9use crate::dropper::DropHandle;
10use crate::error::ShardError;
11use crate::keyspace::ShardConfig;
12use crate::shard::{
13 self, PreparedShard, ReplicationEvent, ShardHandle, ShardPersistenceConfig, ShardRequest,
14 ShardResponse,
15};
16
/// Fallback capacity for each shard's request channel, used when
/// `EngineConfig::shard_channel_buffer` is left at 0.
const DEFAULT_SHARD_BUFFER: usize = 4096;
24
/// Configuration applied to every shard spawned by an [`Engine`].
#[derive(Debug, Clone, Default)]
pub struct EngineConfig {
    /// Per-shard configuration template; each shard receives a clone with
    /// its own `shard_id` filled in.
    pub shard: ShardConfig,
    /// Optional persistence settings forwarded to every shard.
    pub persistence: Option<ShardPersistenceConfig>,
    /// When set, this sender is handed to every shard and
    /// `Engine::subscribe_replication` returns receivers for it.
    pub replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    /// Shared schema registry forwarded to every shard
    /// (only available with the `protobuf` feature).
    #[cfg(feature = "protobuf")]
    pub schema_registry: Option<crate::schema::SharedSchemaRegistry>,
    /// Capacity of each shard's request channel; 0 selects
    /// `DEFAULT_SHARD_BUFFER`.
    pub shard_channel_buffer: usize,
}
46
/// Front-end that routes requests to a fixed set of shards, selecting the
/// shard for a key via an FNV-1a hash modulo the shard count.
#[derive(Debug, Clone)]
pub struct Engine {
    /// One handle per shard; vector index == shard id.
    shards: Vec<ShardHandle>,
    /// Retained so `subscribe_replication` can hand out receivers.
    replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    /// Exposed via `schema_registry()` (only with the `protobuf` feature).
    #[cfg(feature = "protobuf")]
    schema_registry: Option<crate::schema::SharedSchemaRegistry>,
}
59
60impl Engine {
61 pub fn new(shard_count: usize) -> Self {
66 Self::with_config(shard_count, EngineConfig::default())
67 }
68
69 pub fn with_config(shard_count: usize, config: EngineConfig) -> Self {
76 assert!(shard_count > 0, "shard count must be at least 1");
77 assert!(
78 shard_count <= u16::MAX as usize,
79 "shard count must fit in u16"
80 );
81
82 let drop_handle = DropHandle::spawn();
83 let buffer = if config.shard_channel_buffer == 0 {
84 DEFAULT_SHARD_BUFFER
85 } else {
86 config.shard_channel_buffer
87 };
88
89 let shards = (0..shard_count)
90 .map(|i| {
91 let mut shard_config = config.shard.clone();
92 shard_config.shard_id = i as u16;
93 shard::spawn_shard(
94 buffer,
95 shard_config,
96 config.persistence.clone(),
97 Some(drop_handle.clone()),
98 config.replication_tx.clone(),
99 #[cfg(feature = "protobuf")]
100 config.schema_registry.clone(),
101 )
102 })
103 .collect();
104
105 Self {
106 shards,
107 replication_tx: config.replication_tx,
108 #[cfg(feature = "protobuf")]
109 schema_registry: config.schema_registry,
110 }
111 }
112
113 pub fn prepare(shard_count: usize, config: EngineConfig) -> (Self, Vec<PreparedShard>) {
122 assert!(shard_count > 0, "shard count must be at least 1");
123 assert!(
124 shard_count <= u16::MAX as usize,
125 "shard count must fit in u16"
126 );
127
128 let drop_handle = DropHandle::spawn();
129 let buffer = if config.shard_channel_buffer == 0 {
130 DEFAULT_SHARD_BUFFER
131 } else {
132 config.shard_channel_buffer
133 };
134
135 let mut handles = Vec::with_capacity(shard_count);
136 let mut prepared = Vec::with_capacity(shard_count);
137
138 for i in 0..shard_count {
139 let mut shard_config = config.shard.clone();
140 shard_config.shard_id = i as u16;
141 let (handle, shard) = shard::prepare_shard(
142 buffer,
143 shard_config,
144 config.persistence.clone(),
145 Some(drop_handle.clone()),
146 config.replication_tx.clone(),
147 #[cfg(feature = "protobuf")]
148 config.schema_registry.clone(),
149 );
150 handles.push(handle);
151 prepared.push(shard);
152 }
153
154 let engine = Self {
155 shards: handles,
156 replication_tx: config.replication_tx,
157 #[cfg(feature = "protobuf")]
158 schema_registry: config.schema_registry,
159 };
160
161 (engine, prepared)
162 }
163
164 pub fn with_available_cores() -> Self {
168 Self::with_available_cores_config(EngineConfig::default())
169 }
170
171 pub fn with_available_cores_config(config: EngineConfig) -> Self {
174 let cores = std::thread::available_parallelism()
175 .map(|n| n.get())
176 .unwrap_or(1);
177 Self::with_config(cores, config)
178 }
179
180 #[cfg(feature = "protobuf")]
182 pub fn schema_registry(&self) -> Option<&crate::schema::SharedSchemaRegistry> {
183 self.schema_registry.as_ref()
184 }
185
186 pub fn shard_count(&self) -> usize {
188 self.shards.len()
189 }
190
191 pub fn subscribe_replication(&self) -> Option<broadcast::Receiver<ReplicationEvent>> {
197 self.replication_tx.as_ref().map(|tx| tx.subscribe())
198 }
199
200 pub async fn send_to_shard(
204 &self,
205 shard_idx: usize,
206 request: ShardRequest,
207 ) -> Result<ShardResponse, ShardError> {
208 if shard_idx >= self.shards.len() {
209 return Err(ShardError::Unavailable);
210 }
211 self.shards[shard_idx].send(request).await
212 }
213
214 pub async fn route(
216 &self,
217 key: &str,
218 request: ShardRequest,
219 ) -> Result<ShardResponse, ShardError> {
220 let idx = self.shard_for_key(key);
221 self.shards[idx].send(request).await
222 }
223
224 pub async fn broadcast<F>(&self, make_req: F) -> Result<Vec<ShardResponse>, ShardError>
230 where
231 F: Fn() -> ShardRequest,
232 {
233 let mut receivers = Vec::with_capacity(self.shards.len());
235 for shard in &self.shards {
236 receivers.push(shard.dispatch(make_req()).await?);
237 }
238
239 let mut results = Vec::with_capacity(receivers.len());
241 for rx in receivers {
242 results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
243 }
244 Ok(results)
245 }
246
247 pub async fn route_multi<F>(
253 &self,
254 keys: &[String],
255 make_req: F,
256 ) -> Result<Vec<ShardResponse>, ShardError>
257 where
258 F: Fn(String) -> ShardRequest,
259 {
260 let mut receivers = Vec::with_capacity(keys.len());
261 for key in keys {
262 let idx = self.shard_for_key(key);
263 let rx = self.shards[idx].dispatch(make_req(key.clone())).await?;
264 receivers.push(rx);
265 }
266
267 let mut results = Vec::with_capacity(receivers.len());
268 for rx in receivers {
269 results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
270 }
271 Ok(results)
272 }
273
274 pub fn same_shard(&self, key1: &str, key2: &str) -> bool {
276 self.shard_for_key(key1) == self.shard_for_key(key2)
277 }
278
279 pub fn shard_for_key(&self, key: &str) -> usize {
281 shard_index(key, self.shards.len())
282 }
283
284 pub async fn dispatch_to_shard(
288 &self,
289 shard_idx: usize,
290 request: ShardRequest,
291 ) -> Result<tokio::sync::oneshot::Receiver<ShardResponse>, ShardError> {
292 if shard_idx >= self.shards.len() {
293 return Err(ShardError::Unavailable);
294 }
295 self.shards[shard_idx].dispatch(request).await
296 }
297
298 pub async fn dispatch_reusable_to_shard(
302 &self,
303 shard_idx: usize,
304 request: ShardRequest,
305 reply: tokio::sync::mpsc::Sender<ShardResponse>,
306 ) -> Result<(), ShardError> {
307 if shard_idx >= self.shards.len() {
308 return Err(ShardError::Unavailable);
309 }
310 self.shards[shard_idx]
311 .dispatch_reusable(request, reply)
312 .await
313 }
314
315 pub async fn dispatch_batch_to_shard(
322 &self,
323 shard_idx: usize,
324 requests: Vec<ShardRequest>,
325 ) -> Result<Vec<tokio::sync::oneshot::Receiver<ShardResponse>>, ShardError> {
326 if shard_idx >= self.shards.len() {
327 return Err(ShardError::Unavailable);
328 }
329 self.shards[shard_idx].dispatch_batch(requests).await
330 }
331}
332
/// Maps `key` to a shard index in `0..shard_count` using the 64-bit
/// FNV-1a hash of the key's bytes, reduced modulo `shard_count`.
fn shard_index(key: &str, shard_count: usize) -> usize {
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    // FNV-1a: xor each byte into the accumulator, then multiply by the
    // prime (wrapping, as the algorithm works modulo 2^64).
    let digest = key
        .bytes()
        .fold(FNV_OFFSET, |acc, byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        });
    (digest as usize) % shard_count
}
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Value;
    use bytes::Bytes;

    #[test]
    fn same_key_same_shard() {
        // Routing must be deterministic: hashing the same key twice yields
        // the same shard index.
        let idx1 = shard_index("foo", 8);
        let idx2 = shard_index("foo", 8);
        assert_eq!(idx1, idx2);
    }

    #[test]
    fn keys_spread_across_shards() {
        // Distribution sanity check: 100 distinct keys should not all
        // collapse onto a single shard out of 4.
        let mut seen = std::collections::HashSet::new();
        for i in 0..100 {
            let key = format!("key:{i}");
            seen.insert(shard_index(&key, 4));
        }
        assert!(seen.len() > 1, "expected keys to spread across shards");
    }

    #[test]
    fn single_shard_always_zero() {
        // With one shard, every key must map to index 0.
        assert_eq!(shard_index("anything", 1), 0);
        assert_eq!(shard_index("other", 1), 0);
    }

    #[tokio::test]
    async fn engine_round_trip() {
        // Set a key, then read it back through the same key-based routing.
        let engine = Engine::new(4);

        let resp = engine
            .route(
                "greeting",
                ShardRequest::Set {
                    key: "greeting".into(),
                    value: Bytes::from("hello"),
                    expire: None,
                    nx: false,
                    xx: false,
                },
            )
            .await
            .unwrap();
        assert!(matches!(resp, ShardResponse::Ok));

        let resp = engine
            .route(
                "greeting",
                ShardRequest::Get {
                    key: "greeting".into(),
                },
            )
            .await
            .unwrap();
        match resp {
            ShardResponse::Value(Some(Value::String(data))) => {
                assert_eq!(data, Bytes::from("hello"));
            }
            other => panic!("expected Value(Some(String)), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_shard_del() {
        // Set keys that may land on different shards, then verify deletes
        // report success exactly for the keys that existed.
        let engine = Engine::new(4);

        for key in &["a", "b", "c", "d"] {
            engine
                .route(
                    key,
                    ShardRequest::Set {
                        key: key.to_string(),
                        value: Bytes::from("v"),
                        expire: None,
                        nx: false,
                        xx: false,
                    },
                )
                .await
                .unwrap();
        }

        let mut count = 0i64;
        for key in &["a", "b", "c", "d", "missing"] {
            let resp = engine
                .route(
                    key,
                    ShardRequest::Del {
                        key: key.to_string(),
                    },
                )
                .await
                .unwrap();
            if let ShardResponse::Bool(true) = resp {
                count += 1;
            }
        }
        // Four keys were set; "missing" never existed.
        assert_eq!(count, 4);
    }

    #[test]
    #[should_panic(expected = "shard count must be at least 1")]
    fn zero_shards_panics() {
        Engine::new(0);
    }
}