1use std::sync::Arc;
14use std::thread;
15
16use crossbeam_channel::{Receiver, Sender};
17
18use crate::command::{Command, CommandResponse};
19use crate::hash::shard_for_key;
20use crate::shard::store::ShardStore;
21use crate::shard::wal_trait::{WalRecord, WalWriter};
22
/// Work item delivered to a shard worker thread over its channel.
pub enum ShardMessage {
    /// Execute one command and send its response on `response_tx`.
    Single {
        /// Command the worker executes against its store.
        command: Command,
        /// Channel on which the worker sends the single response.
        response_tx: ResponseSender,
    },
    /// Execute several commands in order and reply once with all responses.
    Batch {
        /// Commands paired with a caller-chosen index; the worker echoes each
        /// index back next to the matching response.
        commands: Vec<(usize, Command)>,
        /// Channel on which the worker sends every `(index, response)` pair.
        response_tx: BatchResponseSender,
    },
}
40
/// Sending half of the channel a worker uses to answer a single command.
pub type ResponseSender = Sender<CommandResponse>;
/// Receiving half on which a caller waits for a single command's response.
pub type ResponseReceiver = Receiver<CommandResponse>;
/// Sending half used by a worker to answer a whole batch at once; each
/// response is tagged with the index the caller supplied.
type BatchResponseSender = Sender<Vec<(usize, CommandResponse)>>;
/// Receiving half matching [`BatchResponseSender`].
type BatchResponseReceiver = Receiver<Vec<(usize, CommandResponse)>>;
49
50pub fn response_channel() -> (ResponseSender, ResponseReceiver) {
52 crossbeam_channel::bounded(1)
53}
54
55fn batch_response_channel() -> (BatchResponseSender, BatchResponseReceiver) {
57 crossbeam_channel::bounded(1)
58}
59
/// Engine-side handle to one shard worker thread.
struct WorkerHandle {
    /// Sender feeding the worker's message queue.
    tx: Sender<ShardMessage>,
    /// Join handle of the worker thread; taken (set to `None`) when the
    /// engine joins the thread during drop.
    thread: Option<thread::JoinHandle<()>>,
}
64
/// Front end that routes commands to per-shard worker threads.
///
/// Each shard owns its own store on a dedicated thread; the engine only
/// holds the channels and join handles.
pub struct ShardEngine {
    /// One handle per shard, indexed by shard id.
    workers: Vec<WorkerHandle>,
    /// Number of shards; used when mapping a key to its owning shard.
    shard_count: usize,
}
72
73impl ShardEngine {
74 pub fn new(shard_count: usize) -> Self {
76 let wal_writers: Vec<Option<Box<dyn WalWriter>>> = (0..shard_count).map(|_| None).collect();
77 Self::new_with_storage(shard_count, wal_writers)
78 }
79
80 pub fn new_with_storage(
85 shard_count: usize,
86 wal_writers: Vec<Option<Box<dyn WalWriter>>>,
87 ) -> Self {
88 Self::new_with_recovery(shard_count, wal_writers, None)
89 }
90
91 #[allow(clippy::type_complexity)]
98 pub fn new_with_recovery(
99 shard_count: usize,
100 wal_writers: Vec<Option<Box<dyn WalWriter>>>,
101 recovery_fns: Option<Vec<Box<dyn FnOnce(usize, &mut ShardStore) + Send>>>,
102 ) -> Self {
103 assert_eq!(
104 wal_writers.len(),
105 shard_count,
106 "wal_writers length must match shard_count"
107 );
108
109 let mut recovery_iter: Vec<Option<Box<dyn FnOnce(usize, &mut ShardStore) + Send>>> =
110 match recovery_fns {
111 Some(fns) => {
112 assert_eq!(fns.len(), shard_count);
113 fns.into_iter().map(Some).collect()
114 }
115 None => (0..shard_count).map(|_| None).collect(),
116 };
117
118 let mut workers = Vec::with_capacity(shard_count);
119
120 for (i, wal_writer) in wal_writers.into_iter().enumerate() {
121 let (tx, rx) = crossbeam_channel::unbounded::<ShardMessage>();
122 let recovery_fn = recovery_iter[i].take();
123 let handle = thread::Builder::new()
124 .name(format!("kora-shard-{}", i))
125 .spawn(move || {
126 let mut store = ShardStore::new(i as u16);
127 if let Some(recover) = recovery_fn {
128 recover(i, &mut store);
129 }
130 worker_loop(&mut store, &rx, wal_writer);
131 })
132 .expect("failed to spawn shard worker thread");
133
134 workers.push(WorkerHandle {
135 tx,
136 thread: Some(handle),
137 });
138 }
139
140 ShardEngine {
141 workers,
142 shard_count,
143 }
144 }
145
146 pub fn shard_count(&self) -> usize {
148 self.shard_count
149 }
150
151 pub fn dispatch(&self, cmd: Command) -> ResponseReceiver {
153 let (tx, rx) = response_channel();
154
155 if let Some(key) = cmd.key() {
156 let shard_id = shard_for_key(key, self.shard_count) as usize;
157 let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
158 command: cmd,
159 response_tx: tx,
160 });
161 } else if cmd.is_multi_key() {
162 self.dispatch_multi_key(cmd, tx);
163 } else {
164 self.dispatch_keyless(cmd, tx);
165 }
166
167 rx
168 }
169
170 pub fn dispatch_blocking(&self, cmd: Command) -> CommandResponse {
172 let rx = self.dispatch(cmd);
173 rx.recv()
174 .unwrap_or(CommandResponse::Error("ERR internal error".into()))
175 }
176
177 pub fn dispatch_batch_blocking(&self, commands: Vec<Command>) -> Vec<CommandResponse> {
182 let total = commands.len();
183 if total == 0 {
184 return Vec::new();
185 }
186
187 let mut responses = vec![None; total];
188 let mut segment = Vec::new();
189
190 for (idx, command) in commands.into_iter().enumerate() {
191 if command.key().is_some() {
192 segment.push((idx, command));
193 } else {
194 if !segment.is_empty() {
195 self.execute_shard_batch(std::mem::take(&mut segment), &mut responses);
196 }
197 responses[idx] = Some(self.dispatch_blocking(command));
198 }
199 }
200
201 if !segment.is_empty() {
202 self.execute_shard_batch(segment, &mut responses);
203 }
204
205 responses
206 .into_iter()
207 .map(|resp| resp.unwrap_or(CommandResponse::Error("ERR internal error".into())))
208 .collect()
209 }
210
211 fn execute_shard_batch(
212 &self,
213 commands: Vec<(usize, Command)>,
214 responses: &mut [Option<CommandResponse>],
215 ) {
216 if commands.is_empty() {
217 return;
218 }
219
220 let mut shard_batches: Vec<Vec<(usize, Command)>> = vec![Vec::new(); self.shard_count];
221 for (idx, command) in commands {
222 let Some(key) = command.key() else {
223 responses[idx] = Some(self.dispatch_blocking(command));
224 continue;
225 };
226 let shard_id = shard_for_key(key, self.shard_count) as usize;
227 shard_batches[shard_id].push((idx, command));
228 }
229
230 let mut receivers = Vec::new();
231 for (shard_id, commands) in shard_batches.into_iter().enumerate() {
232 if commands.is_empty() {
233 continue;
234 }
235
236 let (resp_tx, resp_rx) = batch_response_channel();
237 let _ = self.workers[shard_id].tx.send(ShardMessage::Batch {
238 commands,
239 response_tx: resp_tx,
240 });
241 receivers.push(resp_rx);
242 }
243
244 for rx in receivers {
245 if let Ok(items) = rx.recv() {
246 for (idx, response) in items {
247 if let Some(slot) = responses.get_mut(idx) {
248 *slot = Some(response);
249 }
250 }
251 }
252 }
253 }
254
255 fn dispatch_multi_key(&self, cmd: Command, tx: ResponseSender) {
256 match cmd {
257 Command::MGet { keys } => {
258 let mut results = vec![CommandResponse::Nil; keys.len()];
259 let mut shard_requests: Vec<Vec<(usize, Vec<u8>)>> = vec![vec![]; self.shard_count];
260 for (i, key) in keys.iter().enumerate() {
261 let shard_id = shard_for_key(key, self.shard_count) as usize;
262 shard_requests[shard_id].push((i, key.clone()));
263 }
264
265 let mut receivers = Vec::new();
266 for (shard_id, reqs) in shard_requests.into_iter().enumerate() {
267 if reqs.is_empty() {
268 continue;
269 }
270 let shard_keys: Vec<Vec<u8>> = reqs.iter().map(|(_, k)| k.clone()).collect();
271 let indices: Vec<usize> = reqs.iter().map(|(i, _)| *i).collect();
272 let (resp_tx, resp_rx) = response_channel();
273 let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
274 command: Command::MGet { keys: shard_keys },
275 response_tx: resp_tx,
276 });
277 receivers.push((indices, resp_rx));
278 }
279
280 for (indices, rx) in receivers {
281 if let Ok(CommandResponse::Array(values)) = rx.recv() {
282 for (idx, val) in indices.into_iter().zip(values) {
283 results[idx] = val;
284 }
285 }
286 }
287 let _ = tx.send(CommandResponse::Array(results));
288 }
289 Command::MSet { entries } => {
290 let mut shard_entries: Vec<Vec<(Vec<u8>, Vec<u8>)>> =
291 vec![vec![]; self.shard_count];
292 for (key, value) in entries {
293 let shard_id = shard_for_key(&key, self.shard_count) as usize;
294 shard_entries[shard_id].push((key, value));
295 }
296
297 let mut receivers = Vec::new();
298 for (shard_id, entries) in shard_entries.into_iter().enumerate() {
299 if entries.is_empty() {
300 continue;
301 }
302 let (resp_tx, resp_rx) = response_channel();
303 let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
304 command: Command::MSet { entries },
305 response_tx: resp_tx,
306 });
307 receivers.push(resp_rx);
308 }
309 for rx in receivers {
310 let _ = rx.recv();
311 }
312 let _ = tx.send(CommandResponse::Ok);
313 }
314 Command::Del { keys } => {
315 let mut shard_keys: Vec<Vec<Vec<u8>>> = vec![vec![]; self.shard_count];
316 for key in keys {
317 let shard_id = shard_for_key(&key, self.shard_count) as usize;
318 shard_keys[shard_id].push(key);
319 }
320
321 let mut total = 0i64;
322 let mut receivers = Vec::new();
323 for (shard_id, keys) in shard_keys.into_iter().enumerate() {
324 if keys.is_empty() {
325 continue;
326 }
327 let (resp_tx, resp_rx) = response_channel();
328 let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
329 command: Command::Del { keys },
330 response_tx: resp_tx,
331 });
332 receivers.push(resp_rx);
333 }
334 for rx in receivers {
335 if let Ok(CommandResponse::Integer(n)) = rx.recv() {
336 total += n;
337 }
338 }
339 let _ = tx.send(CommandResponse::Integer(total));
340 }
341 Command::Exists { keys } => {
342 let mut shard_keys: Vec<Vec<Vec<u8>>> = vec![vec![]; self.shard_count];
343 for key in keys {
344 let shard_id = shard_for_key(&key, self.shard_count) as usize;
345 shard_keys[shard_id].push(key);
346 }
347
348 let mut total = 0i64;
349 let mut receivers = Vec::new();
350 for (shard_id, keys) in shard_keys.into_iter().enumerate() {
351 if keys.is_empty() {
352 continue;
353 }
354 let (resp_tx, resp_rx) = response_channel();
355 let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
356 command: Command::Exists { keys },
357 response_tx: resp_tx,
358 });
359 receivers.push(resp_rx);
360 }
361 for rx in receivers {
362 if let Ok(CommandResponse::Integer(n)) = rx.recv() {
363 total += n;
364 }
365 }
366 let _ = tx.send(CommandResponse::Integer(total));
367 }
368 Command::Unlink { keys } => {
369 let mut shard_keys: Vec<Vec<Vec<u8>>> = vec![vec![]; self.shard_count];
370 for key in keys {
371 let shard_id = shard_for_key(&key, self.shard_count) as usize;
372 shard_keys[shard_id].push(key);
373 }
374
375 let mut total = 0i64;
376 let mut receivers = Vec::new();
377 for (shard_id, keys) in shard_keys.into_iter().enumerate() {
378 if keys.is_empty() {
379 continue;
380 }
381 let (resp_tx, resp_rx) = response_channel();
382 let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
383 command: Command::Unlink { keys },
384 response_tx: resp_tx,
385 });
386 receivers.push(resp_rx);
387 }
388 for rx in receivers {
389 if let Ok(CommandResponse::Integer(n)) = rx.recv() {
390 total += n;
391 }
392 }
393 let _ = tx.send(CommandResponse::Integer(total));
394 }
395 Command::Touch { keys } => {
396 let mut shard_keys: Vec<Vec<Vec<u8>>> = vec![vec![]; self.shard_count];
397 for key in keys {
398 let shard_id = shard_for_key(&key, self.shard_count) as usize;
399 shard_keys[shard_id].push(key);
400 }
401
402 let mut total = 0i64;
403 let mut receivers = Vec::new();
404 for (shard_id, keys) in shard_keys.into_iter().enumerate() {
405 if keys.is_empty() {
406 continue;
407 }
408 let (resp_tx, resp_rx) = response_channel();
409 let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
410 command: Command::Touch { keys },
411 response_tx: resp_tx,
412 });
413 receivers.push(resp_rx);
414 }
415 for rx in receivers {
416 if let Ok(CommandResponse::Integer(n)) = rx.recv() {
417 total += n;
418 }
419 }
420 let _ = tx.send(CommandResponse::Integer(total));
421 }
422 Command::MSetNx { entries } => {
423 let mut shard_entries: Vec<Vec<(Vec<u8>, Vec<u8>)>> =
424 vec![vec![]; self.shard_count];
425 for (key, value) in entries {
426 let shard_id = shard_for_key(&key, self.shard_count) as usize;
427 shard_entries[shard_id].push((key, value));
428 }
429
430 let mut all_set = true;
431 let mut receivers = Vec::new();
432 for (shard_id, entries) in shard_entries.into_iter().enumerate() {
433 if entries.is_empty() {
434 continue;
435 }
436 let (resp_tx, resp_rx) = response_channel();
437 let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
438 command: Command::MSetNx { entries },
439 response_tx: resp_tx,
440 });
441 receivers.push(resp_rx);
442 }
443 for rx in receivers {
444 if let Ok(CommandResponse::Integer(0)) = rx.recv() {
445 all_set = false;
446 }
447 }
448 let _ = tx.send(CommandResponse::Integer(if all_set { 1 } else { 0 }));
449 }
450 _ => {
451 let _ = tx.send(CommandResponse::Error(
452 "ERR unsupported multi-key command".into(),
453 ));
454 }
455 }
456 }
457
458 fn dispatch_vec_query(&self, cmd: Command, tx: ResponseSender) {
459 let (key, k, vector) = match cmd {
460 Command::VecQuery { key, k, vector } => (key, k, vector),
461 _ => {
462 let _ = tx.send(CommandResponse::Error(
463 "ERR internal: not a VecQuery".into(),
464 ));
465 return;
466 }
467 };
468
469 let mut receivers = Vec::with_capacity(self.shard_count);
470 for worker in &self.workers {
471 let (resp_tx, resp_rx) = response_channel();
472 let _ = worker.tx.send(ShardMessage::Single {
473 command: Command::VecQuery {
474 key: key.clone(),
475 k,
476 vector: vector.clone(),
477 },
478 response_tx: resp_tx,
479 });
480 receivers.push(resp_rx);
481 }
482
483 let mut all_results: Vec<(i64, Vec<u8>)> = Vec::new();
484 for rx in receivers {
485 if let Ok(CommandResponse::Array(items)) = rx.recv() {
486 for item in items {
487 if let CommandResponse::Array(pair) = item {
488 if pair.len() == 2 {
489 if let (
490 CommandResponse::Integer(id),
491 CommandResponse::BulkString(dist),
492 ) = (&pair[0], &pair[1])
493 {
494 all_results.push((*id, dist.clone()));
495 }
496 }
497 }
498 }
499 }
500 }
501
502 all_results.sort_by(|a, b| {
503 let da: f64 = String::from_utf8_lossy(&a.1).parse().unwrap_or(f64::MAX);
504 let db: f64 = String::from_utf8_lossy(&b.1).parse().unwrap_or(f64::MAX);
505 da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
506 });
507 all_results.truncate(k);
508
509 let items: Vec<CommandResponse> = all_results
510 .into_iter()
511 .map(|(id, dist)| {
512 CommandResponse::Array(vec![
513 CommandResponse::Integer(id),
514 CommandResponse::BulkString(dist),
515 ])
516 })
517 .collect();
518 let _ = tx.send(CommandResponse::Array(items));
519 }
520
521 fn dispatch_keyless(&self, cmd: Command, tx: ResponseSender) {
522 match cmd {
523 Command::VecQuery { .. } => self.dispatch_vec_query(cmd, tx),
524 Command::DbSize => {
525 let mut total = 0i64;
526 let mut receivers = Vec::new();
527 for worker in &self.workers {
528 let (resp_tx, resp_rx) = response_channel();
529 let _ = worker.tx.send(ShardMessage::Single {
530 command: Command::DbSize,
531 response_tx: resp_tx,
532 });
533 receivers.push(resp_rx);
534 }
535 for rx in receivers {
536 if let Ok(CommandResponse::Integer(n)) = rx.recv() {
537 total += n;
538 }
539 }
540 let _ = tx.send(CommandResponse::Integer(total));
541 }
542 Command::FlushDb | Command::FlushAll => {
543 let mut receivers = Vec::new();
544 for worker in &self.workers {
545 let (resp_tx, resp_rx) = response_channel();
546 let _ = worker.tx.send(ShardMessage::Single {
547 command: Command::FlushDb,
548 response_tx: resp_tx,
549 });
550 receivers.push(resp_rx);
551 }
552 for rx in receivers {
553 let _ = rx.recv();
554 }
555 let _ = tx.send(CommandResponse::Ok);
556 }
557 Command::Dump => {
558 let mut all_entries = Vec::new();
559 let mut receivers = Vec::new();
560 for worker in &self.workers {
561 let (resp_tx, resp_rx) = response_channel();
562 let _ = worker.tx.send(ShardMessage::Single {
563 command: Command::Dump,
564 response_tx: resp_tx,
565 });
566 receivers.push(resp_rx);
567 }
568 for rx in receivers {
569 if let Ok(CommandResponse::Array(entries)) = rx.recv() {
570 all_entries.extend(entries);
571 }
572 }
573 let _ = tx.send(CommandResponse::Array(all_entries));
574 }
575 Command::Keys { pattern } => {
576 let mut all_keys = Vec::new();
577 let mut receivers = Vec::new();
578 for worker in &self.workers {
579 let (resp_tx, resp_rx) = response_channel();
580 let _ = worker.tx.send(ShardMessage::Single {
581 command: Command::Keys {
582 pattern: pattern.clone(),
583 },
584 response_tx: resp_tx,
585 });
586 receivers.push(resp_rx);
587 }
588 for rx in receivers {
589 if let Ok(CommandResponse::Array(keys)) = rx.recv() {
590 all_keys.extend(keys);
591 }
592 }
593 let _ = tx.send(CommandResponse::Array(all_keys));
594 }
595 Command::Scan {
596 cursor,
597 pattern,
598 count,
599 } => {
600 let pattern = pattern.unwrap_or_else(|| "*".to_string());
601 let mut merged_keys: Vec<Vec<u8>> = Vec::new();
602 let mut receivers = Vec::new();
603 for worker in &self.workers {
604 let (resp_tx, resp_rx) = response_channel();
605 let _ = worker.tx.send(ShardMessage::Single {
606 command: Command::Keys {
607 pattern: pattern.clone(),
608 },
609 response_tx: resp_tx,
610 });
611 receivers.push(resp_rx);
612 }
613 for rx in receivers {
614 if let Ok(CommandResponse::Array(keys)) = rx.recv() {
615 for key in keys {
616 if let CommandResponse::BulkString(raw) = key {
617 merged_keys.push(raw);
618 }
619 }
620 }
621 }
622
623 merged_keys.sort();
624 let start = cursor as usize;
625 let limit = count.unwrap_or(10).max(1);
626 if start >= merged_keys.len() {
627 let _ = tx.send(CommandResponse::Array(vec![
628 CommandResponse::BulkString(b"0".to_vec()),
629 CommandResponse::Array(vec![]),
630 ]));
631 return;
632 }
633 let end = start.saturating_add(limit).min(merged_keys.len());
634 let result_keys: Vec<CommandResponse> = merged_keys[start..end]
635 .iter()
636 .map(|k| CommandResponse::BulkString(k.clone()))
637 .collect();
638 let next_cursor = if end >= merged_keys.len() { 0 } else { end };
639 let _ = tx.send(CommandResponse::Array(vec![
640 CommandResponse::BulkString(next_cursor.to_string().into_bytes()),
641 CommandResponse::Array(result_keys),
642 ]));
643 }
644 Command::StatsHotkeys { count } => {
645 let mut all_hot: Vec<(Vec<u8>, i64)> = Vec::new();
646 let mut receivers = Vec::new();
647 for worker in &self.workers {
648 let (resp_tx, resp_rx) = response_channel();
649 let _ = worker.tx.send(ShardMessage::Single {
650 command: Command::StatsHotkeys { count },
651 response_tx: resp_tx,
652 });
653 receivers.push(resp_rx);
654 }
655 for rx in receivers {
656 if let Ok(CommandResponse::Array(hotkeys)) = rx.recv() {
657 for hotkey in hotkeys {
658 if let CommandResponse::Array(pair) = hotkey {
659 if pair.len() != 2 {
660 continue;
661 }
662 if let (
663 CommandResponse::BulkString(key),
664 CommandResponse::Integer(freq),
665 ) = (&pair[0], &pair[1])
666 {
667 all_hot.push((key.clone(), *freq));
668 }
669 }
670 }
671 }
672 }
673 all_hot.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
674 all_hot.truncate(count);
675 let items = all_hot
676 .into_iter()
677 .map(|(key, freq)| {
678 CommandResponse::Array(vec![
679 CommandResponse::BulkString(key),
680 CommandResponse::Integer(freq),
681 ])
682 })
683 .collect();
684 let _ = tx.send(CommandResponse::Array(items));
685 }
686 Command::StatsMemory { prefixes } => {
687 let mut totals = vec![0i64; prefixes.len()];
688 let mut receivers = Vec::new();
689 for worker in &self.workers {
690 let (resp_tx, resp_rx) = response_channel();
691 let _ = worker.tx.send(ShardMessage::Single {
692 command: Command::StatsMemory {
693 prefixes: prefixes.clone(),
694 },
695 response_tx: resp_tx,
696 });
697 receivers.push(resp_rx);
698 }
699 for rx in receivers {
700 if let Ok(CommandResponse::Array(items)) = rx.recv() {
701 for (idx, item) in items.into_iter().enumerate().take(totals.len()) {
702 if let CommandResponse::Integer(val) = item {
703 totals[idx] = totals[idx].saturating_add(val);
704 }
705 }
706 }
707 }
708 let merged = totals.into_iter().map(CommandResponse::Integer).collect();
709 let _ = tx.send(CommandResponse::Array(merged));
710 }
711 Command::RandomKey => {
712 for worker in &self.workers {
713 let (resp_tx, resp_rx) = response_channel();
714 let _ = worker.tx.send(ShardMessage::Single {
715 command: Command::RandomKey,
716 response_tx: resp_tx,
717 });
718 if let Ok(resp) = resp_rx.recv() {
719 if !matches!(resp, CommandResponse::Nil) {
720 let _ = tx.send(resp);
721 return;
722 }
723 }
724 }
725 let _ = tx.send(CommandResponse::Nil);
726 }
727 _ => {
728 let _ = self.workers[0].tx.send(ShardMessage::Single {
729 command: cmd,
730 response_tx: tx,
731 });
732 }
733 }
734 }
735}
736
737impl Drop for ShardEngine {
738 fn drop(&mut self) {
739 for worker in &mut self.workers {
740 let (dummy, _) = crossbeam_channel::unbounded();
741 let old_tx = std::mem::replace(&mut worker.tx, dummy);
742 drop(old_tx);
743 }
744 for worker in &mut self.workers {
745 if let Some(handle) = worker.thread.take() {
746 let _ = handle.join();
747 }
748 }
749 }
750}
751
752fn worker_loop(
754 store: &mut ShardStore,
755 rx: &Receiver<ShardMessage>,
756 mut wal_writer: Option<Box<dyn WalWriter>>,
757) {
758 const EXPIRE_SWEEP_INTERVAL_OPS: u32 = 4096;
759 const EXPIRE_SWEEP_SAMPLE_SIZE: usize = 64;
760
761 fn maybe_sweep(store: &mut ShardStore, ops_since_expire: &mut u32) {
762 *ops_since_expire += 1;
763 if *ops_since_expire >= EXPIRE_SWEEP_INTERVAL_OPS {
764 let _ = store.evict_expired_sample(EXPIRE_SWEEP_SAMPLE_SIZE);
765 *ops_since_expire = 0;
766 }
767 }
768
769 fn execute_with_wal(
770 store: &mut ShardStore,
771 wal_writer: Option<&mut Box<dyn WalWriter>>,
772 command: Command,
773 ) -> CommandResponse {
774 if let Some(writer) = wal_writer {
775 if command.is_mutation() {
776 if let Some(record) = command_to_wal_record(&command) {
777 writer.append(&record);
778 }
779 }
780 }
781 store.execute(command)
782 }
783
784 let mut ops_since_expire = 0u32;
785 while let Ok(msg) = rx.recv() {
786 match msg {
787 ShardMessage::Single {
788 command,
789 response_tx,
790 } => {
791 maybe_sweep(store, &mut ops_since_expire);
792 let response = execute_with_wal(store, wal_writer.as_mut(), command);
793 let _ = response_tx.send(response);
794 }
795 ShardMessage::Batch {
796 commands,
797 response_tx,
798 } => {
799 let mut responses = Vec::with_capacity(commands.len());
800 for (idx, command) in commands {
801 maybe_sweep(store, &mut ops_since_expire);
802 let response = execute_with_wal(store, wal_writer.as_mut(), command);
803 responses.push((idx, response));
804 }
805 let _ = response_tx.send(responses);
806 }
807 }
808 }
809}
810
811pub fn command_to_wal_record(cmd: &Command) -> Option<WalRecord> {
813 match cmd {
814 Command::Set {
815 key, value, ex, px, ..
816 } => {
817 let ttl_ms = if let Some(s) = ex {
818 Some(s * 1000)
819 } else {
820 *px
821 };
822 Some(WalRecord::Set {
823 key: key.clone(),
824 value: value.clone(),
825 ttl_ms,
826 })
827 }
828 Command::SetNx { key, value } => Some(WalRecord::Set {
829 key: key.clone(),
830 value: value.clone(),
831 ttl_ms: None,
832 }),
833 Command::GetSet { key, value } => Some(WalRecord::Set {
834 key: key.clone(),
835 value: value.clone(),
836 ttl_ms: None,
837 }),
838 Command::Append { key, value } => Some(WalRecord::Set {
839 key: key.clone(),
840 value: value.clone(),
841 ttl_ms: None,
842 }),
843 Command::Del { keys } => keys.first().map(|key| WalRecord::Del { key: key.clone() }),
844 Command::Expire { key, seconds } => Some(WalRecord::Expire {
845 key: key.clone(),
846 ttl_ms: seconds * 1000,
847 }),
848 Command::PExpire { key, millis } => Some(WalRecord::Expire {
849 key: key.clone(),
850 ttl_ms: *millis,
851 }),
852 Command::LPush { key, values } => Some(WalRecord::LPush {
853 key: key.clone(),
854 values: values.clone(),
855 }),
856 Command::RPush { key, values } => Some(WalRecord::RPush {
857 key: key.clone(),
858 values: values.clone(),
859 }),
860 Command::HSet { key, fields } => Some(WalRecord::HSet {
861 key: key.clone(),
862 fields: fields.clone(),
863 }),
864 Command::SAdd { key, members } => Some(WalRecord::SAdd {
865 key: key.clone(),
866 members: members.clone(),
867 }),
868 Command::FlushDb | Command::FlushAll => Some(WalRecord::FlushDb),
869 Command::DocSet {
870 collection,
871 doc_id,
872 json,
873 } => Some(WalRecord::DocSet {
874 collection: collection.clone(),
875 doc_id: doc_id.clone(),
876 json: json.clone(),
877 }),
878 Command::DocInsert { .. } => None,
879 Command::DocDel { collection, doc_id } => Some(WalRecord::DocDel {
880 collection: collection.clone(),
881 doc_id: doc_id.clone(),
882 }),
883 Command::DocMSet { .. } => None,
884 Command::VecSet {
885 key,
886 dimensions,
887 vector,
888 } => {
889 let mut vec_bytes = Vec::with_capacity(vector.len() * 4);
890 for &f in vector {
891 vec_bytes.extend_from_slice(&f.to_le_bytes());
892 }
893 Some(WalRecord::VecSet {
894 key: key.clone(),
895 dimensions: *dimensions,
896 vector: vec_bytes,
897 })
898 }
899 Command::VecDel { key } => Some(WalRecord::VecDel { key: key.clone() }),
900 _ => None,
901 }
902}
903
/// Reference-counted handle for sharing one [`ShardEngine`] between threads.
pub type SharedEngine = Arc<ShardEngine>;
906
// Integration-style tests that drive a real `ShardEngine` with live worker
// threads through its public dispatch API.
#[cfg(test)]
mod tests {
    use super::*;

    // A keyless command that is not broadcast (Ping) is answered with PONG.
    #[test]
    fn test_engine_basic() {
        let engine = ShardEngine::new(4);
        let resp = engine.dispatch_blocking(Command::Ping { message: None });
        assert!(matches!(resp, CommandResponse::SimpleString(s) if s == "PONG"));
    }

    // A value written with Set is read back by Get on the same key.
    #[test]
    fn test_engine_set_get() {
        let engine = ShardEngine::new(4);
        engine.dispatch_blocking(Command::Set {
            key: b"hello".to_vec(),
            value: b"world".to_vec(),
            ex: None,
            px: None,
            nx: false,
            xx: false,
        });
        match engine.dispatch_blocking(Command::Get {
            key: b"hello".to_vec(),
        }) {
            CommandResponse::BulkString(v) => assert_eq!(v, b"world"),
            other => panic!("Expected 'world', got {:?}", other),
        }
    }

    // MGet over keys spread across shards returns values in request order.
    #[test]
    fn test_engine_mget_across_shards() {
        let engine = ShardEngine::new(4);
        for i in 0..20 {
            engine.dispatch_blocking(Command::Set {
                key: format!("key:{}", i).into_bytes(),
                value: format!("val:{}", i).into_bytes(),
                ex: None,
                px: None,
                nx: false,
                xx: false,
            });
        }
        let keys: Vec<Vec<u8>> = (0..20).map(|i| format!("key:{}", i).into_bytes()).collect();
        match engine.dispatch_blocking(Command::MGet { keys }) {
            CommandResponse::Array(values) => {
                assert_eq!(values.len(), 20);
                for (i, v) in values.iter().enumerate() {
                    match v {
                        CommandResponse::BulkString(b) => {
                            assert_eq!(*b, format!("val:{}", i).into_bytes());
                        }
                        other => panic!("Expected BulkString for key:{}, got {:?}", i, other),
                    }
                }
            }
            other => panic!("Expected Array, got {:?}", other),
        }
    }

    // DbSize sums the key counts of all shards.
    #[test]
    fn test_engine_dbsize() {
        let engine = ShardEngine::new(4);
        for i in 0..10 {
            engine.dispatch_blocking(Command::Set {
                key: format!("k{}", i).into_bytes(),
                value: b"v".to_vec(),
                ex: None,
                px: None,
                nx: false,
                xx: false,
            });
        }
        match engine.dispatch_blocking(Command::DbSize) {
            CommandResponse::Integer(10) => {}
            other => panic!("Expected 10, got {:?}", other),
        }
    }

    // Eight threads share one engine; each writes and reads back its own keys.
    #[test]
    fn test_engine_concurrent_access() {
        let engine = Arc::new(ShardEngine::new(4));
        let mut handles = Vec::new();

        for t in 0..8 {
            let eng = engine.clone();
            handles.push(thread::spawn(move || {
                for i in 0..100 {
                    let key = format!("t{}:k{}", t, i).into_bytes();
                    let val = format!("v{}", i).into_bytes();
                    eng.dispatch_blocking(Command::Set {
                        key: key.clone(),
                        value: val.clone(),
                        ex: None,
                        px: None,
                        nx: false,
                        xx: false,
                    });
                    match eng.dispatch_blocking(Command::Get { key }) {
                        CommandResponse::BulkString(v) => assert_eq!(v, val),
                        // NOTE(review): Nil is tolerated here; the reason is
                        // not visible in this file — confirm it is intended.
                        CommandResponse::Nil => {}
                        other => panic!("Unexpected: {:?}", other),
                    }
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }
    }

    // A keyless command inside a batch acts as a barrier: the Set before
    // FlushDb is applied first, and the Get after it sees the flushed state.
    #[test]
    fn test_dispatch_batch_blocking_preserves_barrier_order() {
        let engine = ShardEngine::new(4);
        let responses = engine.dispatch_batch_blocking(vec![
            Command::Set {
                key: b"k".to_vec(),
                value: b"v".to_vec(),
                ex: None,
                px: None,
                nx: false,
                xx: false,
            },
            Command::FlushDb,
            Command::Get { key: b"k".to_vec() },
        ]);

        assert_eq!(responses.len(), 3);
        assert!(matches!(responses[0], CommandResponse::Ok));
        assert!(matches!(responses[1], CommandResponse::Ok));
        assert!(matches!(responses[2], CommandResponse::Nil));
    }

    // Dropping the engine joins the worker threads; the test passes if the
    // drop returns instead of hanging.
    #[test]
    fn test_engine_shutdown() {
        let engine = ShardEngine::new(2);
        engine.dispatch_blocking(Command::Set {
            key: b"k".to_vec(),
            value: b"v".to_vec(),
            ex: None,
            px: None,
            nx: false,
            xx: false,
        });
        drop(engine);
    }

    // Vectors stored under one index key can be queried back; at most `k`
    // results are returned.
    #[test]
    fn test_vec_set_and_query_per_shard() {
        let engine = ShardEngine::new(4);

        let vector1 = vec![1.0f32, 0.0, 0.0, 0.0];
        let vector2 = vec![0.0f32, 1.0, 0.0, 0.0];
        let vector3 = vec![1.0f32, 1.0, 0.0, 0.0];

        let resp1 = engine.dispatch_blocking(Command::VecSet {
            key: b"idx".to_vec(),
            dimensions: 4,
            vector: vector1.clone(),
        });
        assert!(matches!(resp1, CommandResponse::Integer(_)));

        let resp2 = engine.dispatch_blocking(Command::VecSet {
            key: b"idx".to_vec(),
            dimensions: 4,
            vector: vector2,
        });
        assert!(matches!(resp2, CommandResponse::Integer(_)));

        let resp3 = engine.dispatch_blocking(Command::VecSet {
            key: b"idx".to_vec(),
            dimensions: 4,
            vector: vector3,
        });
        assert!(matches!(resp3, CommandResponse::Integer(_)));

        let query_resp = engine.dispatch_blocking(Command::VecQuery {
            key: b"idx".to_vec(),
            k: 3,
            vector: vector1,
        });
        match query_resp {
            CommandResponse::Array(results) => {
                assert!(!results.is_empty(), "VecQuery should return results");
                assert!(results.len() <= 3);
            }
            other => panic!("Expected Array, got {:?}", other),
        }
    }

    // VecDel reports 1 for an existing index and 0 once it is gone; querying
    // the deleted index yields an empty result.
    #[test]
    fn test_vec_del() {
        let engine = ShardEngine::new(2);

        engine.dispatch_blocking(Command::VecSet {
            key: b"myidx".to_vec(),
            dimensions: 3,
            vector: vec![1.0, 2.0, 3.0],
        });

        let del_resp = engine.dispatch_blocking(Command::VecDel {
            key: b"myidx".to_vec(),
        });
        assert!(matches!(del_resp, CommandResponse::Integer(1)));

        let del_again = engine.dispatch_blocking(Command::VecDel {
            key: b"myidx".to_vec(),
        });
        assert!(matches!(del_again, CommandResponse::Integer(0)));

        let query_resp = engine.dispatch_blocking(Command::VecQuery {
            key: b"myidx".to_vec(),
            k: 5,
            vector: vec![1.0, 2.0, 3.0],
        });
        match query_resp {
            CommandResponse::Array(results) => assert!(results.is_empty()),
            other => panic!("Expected empty Array, got {:?}", other),
        }
    }

    // A query fanned out to all shards never returns more than `k` results.
    #[test]
    fn test_vec_query_fan_out() {
        let engine = ShardEngine::new(4);

        for i in 0..10 {
            let v: Vec<f32> = (0..8).map(|d| (i * 8 + d) as f32 * 0.1).collect();
            engine.dispatch_blocking(Command::VecSet {
                key: b"fanout-idx".to_vec(),
                dimensions: 8,
                vector: v,
            });
        }

        let query = vec![0.0f32; 8];
        let resp = engine.dispatch_blocking(Command::VecQuery {
            key: b"fanout-idx".to_vec(),
            k: 5,
            vector: query,
        });
        match resp {
            CommandResponse::Array(results) => {
                assert!(results.len() <= 5);
            }
            other => panic!("Expected Array, got {:?}", other),
        }
    }
}