Skip to main content

kora_core/shard/
engine.rs

1//! Shard engine — coordinates worker threads and routes commands.
2//!
3//! `ShardEngine` spawns N OS threads, each running a tight event loop over a
4//! crossbeam channel. Incoming commands are routed by key hash to the owning
5//! shard worker. Multi-key commands (MGET, MSET, DEL, EXISTS, etc.) are fanned
6//! out to all relevant shards and their responses merged. Keyless commands
7//! (PING, DBSIZE, FLUSHDB, KEYS, SCAN) are either broadcast or delegated to
8//! shard 0.
9//!
10//! Each worker loop periodically samples expired keys (lazy + sweep) and,
11//! when a WAL writer is attached, appends mutation records before execution.
12
13use std::sync::Arc;
14use std::thread;
15
16use crossbeam_channel::{Receiver, Sender};
17
18use crate::command::{Command, CommandResponse};
19use crate::hash::shard_for_key;
20use crate::shard::store::ShardStore;
21use crate::shard::wal_trait::{WalRecord, WalWriter};
22
23/// Message sent from the dispatcher to a shard worker.
///
/// Both variants carry their own reply channel, so a worker answers the
/// caller directly instead of going through a shared response queue.
24pub enum ShardMessage {
25    /// Execute a single command.
26    Single {
27        /// The command to execute.
28        command: Command,
29        /// Channel to send the response back.
        /// Created with `response_channel`, i.e. it holds at most one reply.
30        response_tx: ResponseSender,
31    },
32    /// Execute multiple commands on the same shard in order.
33    Batch {
34        /// Commands tagged with their original position in the batch.
        /// The dispatcher uses that position to place each response back
        /// into the caller's result vector.
35        commands: Vec<(usize, Command)>,
36        /// Channel to send all responses back.
        /// The dispatcher reads exactly one message from it: a list of
        /// `(position, response)` pairs.
37        response_tx: BatchResponseSender,
38    },
39}
40
41/// A oneshot-like sender for command responses. Uses a crossbeam channel with capacity 1.
///
/// The capacity is a property of `response_channel`, which builds the pair;
/// the alias itself only fixes the payload type.
42pub type ResponseSender = Sender<CommandResponse>;
43/// A oneshot-like receiver for command responses.
///
/// `recv()` fails once the matching sender is dropped without sending.
44pub type ResponseReceiver = Receiver<CommandResponse>;
45/// Sender for batched command responses.
///
/// The payload is one list of `(original position, response)` pairs.
46type BatchResponseSender = Sender<Vec<(usize, CommandResponse)>>;
47/// Receiver for batched command responses.
///
/// Counterpart of `BatchResponseSender`; see `batch_response_channel`.
48type BatchResponseReceiver = Receiver<Vec<(usize, CommandResponse)>>;
49
50/// Create a response channel pair.
51pub fn response_channel() -> (ResponseSender, ResponseReceiver) {
52    crossbeam_channel::bounded(1)
53}
54
55/// Create a response channel pair for batched command responses.
56fn batch_response_channel() -> (BatchResponseSender, BatchResponseReceiver) {
57    crossbeam_channel::bounded(1)
58}
59
/// Dispatcher-side handle to one shard worker thread.
60struct WorkerHandle {
    /// Sending half of the worker's command queue. The worker loop runs
    /// `while let Ok(msg) = rx.recv()`, so dropping this sender ends the loop.
61    tx: Sender<ShardMessage>,
    /// Join handle of the worker thread; taken (left as `None`) when the
    /// engine is dropped and the thread is joined.
62    thread: Option<thread::JoinHandle<()>>,
63}
64
65/// The shard engine coordinates N worker threads, each owning a `ShardStore`.
66///
67/// Commands are routed to the appropriate worker based on key hash.
///
/// Dropping the engine closes every worker queue and joins the threads.
68pub struct ShardEngine {
    /// One handle per shard; the vector index is the shard id used for routing.
69    workers: Vec<WorkerHandle>,
    /// Number of shards, passed to `shard_for_key` when routing a key.
70    shard_count: usize,
71}
72
73impl ShardEngine {
74    /// Create a new ShardEngine with the given number of worker threads.
75    pub fn new(shard_count: usize) -> Self {
76        let wal_writers: Vec<Option<Box<dyn WalWriter>>> = (0..shard_count).map(|_| None).collect();
77        Self::new_with_storage(shard_count, wal_writers)
78    }
79
80    /// Create a new ShardEngine with per-shard WAL writers.
81    ///
82    /// Each element in `wal_writers` corresponds to one shard. The WAL writer
83    /// is moved into the shard's worker thread and called after every mutation.
84    pub fn new_with_storage(
85        shard_count: usize,
86        wal_writers: Vec<Option<Box<dyn WalWriter>>>,
87    ) -> Self {
88        Self::new_with_recovery(shard_count, wal_writers, None)
89    }
90
91    /// Create a new ShardEngine with per-shard WAL writers and optional
92    /// per-shard recovery callbacks.
93    ///
94    /// Each recovery callback receives the shard index and a mutable reference
95    /// to the freshly created `ShardStore`, allowing callers to replay RDB
96    /// snapshots and WAL entries before the worker starts accepting commands.
97    #[allow(clippy::type_complexity)]
98    pub fn new_with_recovery(
99        shard_count: usize,
100        wal_writers: Vec<Option<Box<dyn WalWriter>>>,
101        recovery_fns: Option<Vec<Box<dyn FnOnce(usize, &mut ShardStore) + Send>>>,
102    ) -> Self {
103        assert_eq!(
104            wal_writers.len(),
105            shard_count,
106            "wal_writers length must match shard_count"
107        );
108
109        let mut recovery_iter: Vec<Option<Box<dyn FnOnce(usize, &mut ShardStore) + Send>>> =
110            match recovery_fns {
111                Some(fns) => {
112                    assert_eq!(fns.len(), shard_count);
113                    fns.into_iter().map(Some).collect()
114                }
115                None => (0..shard_count).map(|_| None).collect(),
116            };
117
118        let mut workers = Vec::with_capacity(shard_count);
119
120        for (i, wal_writer) in wal_writers.into_iter().enumerate() {
121            let (tx, rx) = crossbeam_channel::unbounded::<ShardMessage>();
122            let recovery_fn = recovery_iter[i].take();
123            let handle = thread::Builder::new()
124                .name(format!("kora-shard-{}", i))
125                .spawn(move || {
126                    let mut store = ShardStore::new(i as u16);
127                    if let Some(recover) = recovery_fn {
128                        recover(i, &mut store);
129                    }
130                    worker_loop(&mut store, &rx, wal_writer);
131                })
132                .expect("failed to spawn shard worker thread");
133
134            workers.push(WorkerHandle {
135                tx,
136                thread: Some(handle),
137            });
138        }
139
140        ShardEngine {
141            workers,
142            shard_count,
143        }
144    }
145
146    /// Get the number of shards.
147    pub fn shard_count(&self) -> usize {
148        self.shard_count
149    }
150
151    /// Dispatch a command and return a receiver for the response.
152    pub fn dispatch(&self, cmd: Command) -> ResponseReceiver {
153        let (tx, rx) = response_channel();
154
155        if let Some(key) = cmd.key() {
156            let shard_id = shard_for_key(key, self.shard_count) as usize;
157            let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
158                command: cmd,
159                response_tx: tx,
160            });
161        } else if cmd.is_multi_key() {
162            self.dispatch_multi_key(cmd, tx);
163        } else {
164            self.dispatch_keyless(cmd, tx);
165        }
166
167        rx
168    }
169
170    /// Dispatch a command and block until the response is received.
171    pub fn dispatch_blocking(&self, cmd: Command) -> CommandResponse {
172        let rx = self.dispatch(cmd);
173        rx.recv()
174            .unwrap_or(CommandResponse::Error("ERR internal error".into()))
175    }
176
177    /// Dispatch a batch of commands and block until all responses are received.
178    ///
179    /// Commands that target a single key are grouped per shard and executed in-order,
180    /// reducing response-channel overhead for pipelined workloads.
181    pub fn dispatch_batch_blocking(&self, commands: Vec<Command>) -> Vec<CommandResponse> {
182        let total = commands.len();
183        if total == 0 {
184            return Vec::new();
185        }
186
187        let mut responses = vec![None; total];
188        let mut segment = Vec::new();
189
190        for (idx, command) in commands.into_iter().enumerate() {
191            if command.key().is_some() {
192                segment.push((idx, command));
193            } else {
194                if !segment.is_empty() {
195                    self.execute_shard_batch(std::mem::take(&mut segment), &mut responses);
196                }
197                responses[idx] = Some(self.dispatch_blocking(command));
198            }
199        }
200
201        if !segment.is_empty() {
202            self.execute_shard_batch(segment, &mut responses);
203        }
204
205        responses
206            .into_iter()
207            .map(|resp| resp.unwrap_or(CommandResponse::Error("ERR internal error".into())))
208            .collect()
209    }
210
211    fn execute_shard_batch(
212        &self,
213        commands: Vec<(usize, Command)>,
214        responses: &mut [Option<CommandResponse>],
215    ) {
216        if commands.is_empty() {
217            return;
218        }
219
220        let mut shard_batches: Vec<Vec<(usize, Command)>> = vec![Vec::new(); self.shard_count];
221        for (idx, command) in commands {
222            let Some(key) = command.key() else {
223                responses[idx] = Some(self.dispatch_blocking(command));
224                continue;
225            };
226            let shard_id = shard_for_key(key, self.shard_count) as usize;
227            shard_batches[shard_id].push((idx, command));
228        }
229
230        let mut receivers = Vec::new();
231        for (shard_id, commands) in shard_batches.into_iter().enumerate() {
232            if commands.is_empty() {
233                continue;
234            }
235
236            let (resp_tx, resp_rx) = batch_response_channel();
237            let _ = self.workers[shard_id].tx.send(ShardMessage::Batch {
238                commands,
239                response_tx: resp_tx,
240            });
241            receivers.push(resp_rx);
242        }
243
244        for rx in receivers {
245            if let Ok(items) = rx.recv() {
246                for (idx, response) in items {
247                    if let Some(slot) = responses.get_mut(idx) {
248                        *slot = Some(response);
249                    }
250                }
251            }
252        }
253    }
254
255    fn dispatch_multi_key(&self, cmd: Command, tx: ResponseSender) {
256        match cmd {
257            Command::MGet { keys } => {
258                let mut results = vec![CommandResponse::Nil; keys.len()];
259                let mut shard_requests: Vec<Vec<(usize, Vec<u8>)>> = vec![vec![]; self.shard_count];
260                for (i, key) in keys.iter().enumerate() {
261                    let shard_id = shard_for_key(key, self.shard_count) as usize;
262                    shard_requests[shard_id].push((i, key.clone()));
263                }
264
265                let mut receivers = Vec::new();
266                for (shard_id, reqs) in shard_requests.into_iter().enumerate() {
267                    if reqs.is_empty() {
268                        continue;
269                    }
270                    let shard_keys: Vec<Vec<u8>> = reqs.iter().map(|(_, k)| k.clone()).collect();
271                    let indices: Vec<usize> = reqs.iter().map(|(i, _)| *i).collect();
272                    let (resp_tx, resp_rx) = response_channel();
273                    let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
274                        command: Command::MGet { keys: shard_keys },
275                        response_tx: resp_tx,
276                    });
277                    receivers.push((indices, resp_rx));
278                }
279
280                for (indices, rx) in receivers {
281                    if let Ok(CommandResponse::Array(values)) = rx.recv() {
282                        for (idx, val) in indices.into_iter().zip(values) {
283                            results[idx] = val;
284                        }
285                    }
286                }
287                let _ = tx.send(CommandResponse::Array(results));
288            }
289            Command::MSet { entries } => {
290                let mut shard_entries: Vec<Vec<(Vec<u8>, Vec<u8>)>> =
291                    vec![vec![]; self.shard_count];
292                for (key, value) in entries {
293                    let shard_id = shard_for_key(&key, self.shard_count) as usize;
294                    shard_entries[shard_id].push((key, value));
295                }
296
297                let mut receivers = Vec::new();
298                for (shard_id, entries) in shard_entries.into_iter().enumerate() {
299                    if entries.is_empty() {
300                        continue;
301                    }
302                    let (resp_tx, resp_rx) = response_channel();
303                    let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
304                        command: Command::MSet { entries },
305                        response_tx: resp_tx,
306                    });
307                    receivers.push(resp_rx);
308                }
309                for rx in receivers {
310                    let _ = rx.recv();
311                }
312                let _ = tx.send(CommandResponse::Ok);
313            }
314            Command::Del { keys } => {
315                let mut shard_keys: Vec<Vec<Vec<u8>>> = vec![vec![]; self.shard_count];
316                for key in keys {
317                    let shard_id = shard_for_key(&key, self.shard_count) as usize;
318                    shard_keys[shard_id].push(key);
319                }
320
321                let mut total = 0i64;
322                let mut receivers = Vec::new();
323                for (shard_id, keys) in shard_keys.into_iter().enumerate() {
324                    if keys.is_empty() {
325                        continue;
326                    }
327                    let (resp_tx, resp_rx) = response_channel();
328                    let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
329                        command: Command::Del { keys },
330                        response_tx: resp_tx,
331                    });
332                    receivers.push(resp_rx);
333                }
334                for rx in receivers {
335                    if let Ok(CommandResponse::Integer(n)) = rx.recv() {
336                        total += n;
337                    }
338                }
339                let _ = tx.send(CommandResponse::Integer(total));
340            }
341            Command::Exists { keys } => {
342                let mut shard_keys: Vec<Vec<Vec<u8>>> = vec![vec![]; self.shard_count];
343                for key in keys {
344                    let shard_id = shard_for_key(&key, self.shard_count) as usize;
345                    shard_keys[shard_id].push(key);
346                }
347
348                let mut total = 0i64;
349                let mut receivers = Vec::new();
350                for (shard_id, keys) in shard_keys.into_iter().enumerate() {
351                    if keys.is_empty() {
352                        continue;
353                    }
354                    let (resp_tx, resp_rx) = response_channel();
355                    let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
356                        command: Command::Exists { keys },
357                        response_tx: resp_tx,
358                    });
359                    receivers.push(resp_rx);
360                }
361                for rx in receivers {
362                    if let Ok(CommandResponse::Integer(n)) = rx.recv() {
363                        total += n;
364                    }
365                }
366                let _ = tx.send(CommandResponse::Integer(total));
367            }
368            Command::Unlink { keys } => {
369                let mut shard_keys: Vec<Vec<Vec<u8>>> = vec![vec![]; self.shard_count];
370                for key in keys {
371                    let shard_id = shard_for_key(&key, self.shard_count) as usize;
372                    shard_keys[shard_id].push(key);
373                }
374
375                let mut total = 0i64;
376                let mut receivers = Vec::new();
377                for (shard_id, keys) in shard_keys.into_iter().enumerate() {
378                    if keys.is_empty() {
379                        continue;
380                    }
381                    let (resp_tx, resp_rx) = response_channel();
382                    let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
383                        command: Command::Unlink { keys },
384                        response_tx: resp_tx,
385                    });
386                    receivers.push(resp_rx);
387                }
388                for rx in receivers {
389                    if let Ok(CommandResponse::Integer(n)) = rx.recv() {
390                        total += n;
391                    }
392                }
393                let _ = tx.send(CommandResponse::Integer(total));
394            }
395            Command::Touch { keys } => {
396                let mut shard_keys: Vec<Vec<Vec<u8>>> = vec![vec![]; self.shard_count];
397                for key in keys {
398                    let shard_id = shard_for_key(&key, self.shard_count) as usize;
399                    shard_keys[shard_id].push(key);
400                }
401
402                let mut total = 0i64;
403                let mut receivers = Vec::new();
404                for (shard_id, keys) in shard_keys.into_iter().enumerate() {
405                    if keys.is_empty() {
406                        continue;
407                    }
408                    let (resp_tx, resp_rx) = response_channel();
409                    let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
410                        command: Command::Touch { keys },
411                        response_tx: resp_tx,
412                    });
413                    receivers.push(resp_rx);
414                }
415                for rx in receivers {
416                    if let Ok(CommandResponse::Integer(n)) = rx.recv() {
417                        total += n;
418                    }
419                }
420                let _ = tx.send(CommandResponse::Integer(total));
421            }
422            Command::MSetNx { entries } => {
423                let mut shard_entries: Vec<Vec<(Vec<u8>, Vec<u8>)>> =
424                    vec![vec![]; self.shard_count];
425                for (key, value) in entries {
426                    let shard_id = shard_for_key(&key, self.shard_count) as usize;
427                    shard_entries[shard_id].push((key, value));
428                }
429
430                let mut all_set = true;
431                let mut receivers = Vec::new();
432                for (shard_id, entries) in shard_entries.into_iter().enumerate() {
433                    if entries.is_empty() {
434                        continue;
435                    }
436                    let (resp_tx, resp_rx) = response_channel();
437                    let _ = self.workers[shard_id].tx.send(ShardMessage::Single {
438                        command: Command::MSetNx { entries },
439                        response_tx: resp_tx,
440                    });
441                    receivers.push(resp_rx);
442                }
443                for rx in receivers {
444                    if let Ok(CommandResponse::Integer(0)) = rx.recv() {
445                        all_set = false;
446                    }
447                }
448                let _ = tx.send(CommandResponse::Integer(if all_set { 1 } else { 0 }));
449            }
450            _ => {
451                let _ = tx.send(CommandResponse::Error(
452                    "ERR unsupported multi-key command".into(),
453                ));
454            }
455        }
456    }
457
458    fn dispatch_vec_query(&self, cmd: Command, tx: ResponseSender) {
459        let (key, k, vector) = match cmd {
460            Command::VecQuery { key, k, vector } => (key, k, vector),
461            _ => {
462                let _ = tx.send(CommandResponse::Error(
463                    "ERR internal: not a VecQuery".into(),
464                ));
465                return;
466            }
467        };
468
469        let mut receivers = Vec::with_capacity(self.shard_count);
470        for worker in &self.workers {
471            let (resp_tx, resp_rx) = response_channel();
472            let _ = worker.tx.send(ShardMessage::Single {
473                command: Command::VecQuery {
474                    key: key.clone(),
475                    k,
476                    vector: vector.clone(),
477                },
478                response_tx: resp_tx,
479            });
480            receivers.push(resp_rx);
481        }
482
483        let mut all_results: Vec<(i64, Vec<u8>)> = Vec::new();
484        for rx in receivers {
485            if let Ok(CommandResponse::Array(items)) = rx.recv() {
486                for item in items {
487                    if let CommandResponse::Array(pair) = item {
488                        if pair.len() == 2 {
489                            if let (
490                                CommandResponse::Integer(id),
491                                CommandResponse::BulkString(dist),
492                            ) = (&pair[0], &pair[1])
493                            {
494                                all_results.push((*id, dist.clone()));
495                            }
496                        }
497                    }
498                }
499            }
500        }
501
502        all_results.sort_by(|a, b| {
503            let da: f64 = String::from_utf8_lossy(&a.1).parse().unwrap_or(f64::MAX);
504            let db: f64 = String::from_utf8_lossy(&b.1).parse().unwrap_or(f64::MAX);
505            da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
506        });
507        all_results.truncate(k);
508
509        let items: Vec<CommandResponse> = all_results
510            .into_iter()
511            .map(|(id, dist)| {
512                CommandResponse::Array(vec![
513                    CommandResponse::Integer(id),
514                    CommandResponse::BulkString(dist),
515                ])
516            })
517            .collect();
518        let _ = tx.send(CommandResponse::Array(items));
519    }
520
521    fn dispatch_keyless(&self, cmd: Command, tx: ResponseSender) {
522        match cmd {
523            Command::VecQuery { .. } => self.dispatch_vec_query(cmd, tx),
524            Command::DbSize => {
525                let mut total = 0i64;
526                let mut receivers = Vec::new();
527                for worker in &self.workers {
528                    let (resp_tx, resp_rx) = response_channel();
529                    let _ = worker.tx.send(ShardMessage::Single {
530                        command: Command::DbSize,
531                        response_tx: resp_tx,
532                    });
533                    receivers.push(resp_rx);
534                }
535                for rx in receivers {
536                    if let Ok(CommandResponse::Integer(n)) = rx.recv() {
537                        total += n;
538                    }
539                }
540                let _ = tx.send(CommandResponse::Integer(total));
541            }
542            Command::FlushDb | Command::FlushAll => {
543                let mut receivers = Vec::new();
544                for worker in &self.workers {
545                    let (resp_tx, resp_rx) = response_channel();
546                    let _ = worker.tx.send(ShardMessage::Single {
547                        command: Command::FlushDb,
548                        response_tx: resp_tx,
549                    });
550                    receivers.push(resp_rx);
551                }
552                for rx in receivers {
553                    let _ = rx.recv();
554                }
555                let _ = tx.send(CommandResponse::Ok);
556            }
557            Command::Dump => {
558                let mut all_entries = Vec::new();
559                let mut receivers = Vec::new();
560                for worker in &self.workers {
561                    let (resp_tx, resp_rx) = response_channel();
562                    let _ = worker.tx.send(ShardMessage::Single {
563                        command: Command::Dump,
564                        response_tx: resp_tx,
565                    });
566                    receivers.push(resp_rx);
567                }
568                for rx in receivers {
569                    if let Ok(CommandResponse::Array(entries)) = rx.recv() {
570                        all_entries.extend(entries);
571                    }
572                }
573                let _ = tx.send(CommandResponse::Array(all_entries));
574            }
575            Command::Keys { pattern } => {
576                let mut all_keys = Vec::new();
577                let mut receivers = Vec::new();
578                for worker in &self.workers {
579                    let (resp_tx, resp_rx) = response_channel();
580                    let _ = worker.tx.send(ShardMessage::Single {
581                        command: Command::Keys {
582                            pattern: pattern.clone(),
583                        },
584                        response_tx: resp_tx,
585                    });
586                    receivers.push(resp_rx);
587                }
588                for rx in receivers {
589                    if let Ok(CommandResponse::Array(keys)) = rx.recv() {
590                        all_keys.extend(keys);
591                    }
592                }
593                let _ = tx.send(CommandResponse::Array(all_keys));
594            }
595            Command::Scan {
596                cursor,
597                pattern,
598                count,
599            } => {
600                let pattern = pattern.unwrap_or_else(|| "*".to_string());
601                let mut merged_keys: Vec<Vec<u8>> = Vec::new();
602                let mut receivers = Vec::new();
603                for worker in &self.workers {
604                    let (resp_tx, resp_rx) = response_channel();
605                    let _ = worker.tx.send(ShardMessage::Single {
606                        command: Command::Keys {
607                            pattern: pattern.clone(),
608                        },
609                        response_tx: resp_tx,
610                    });
611                    receivers.push(resp_rx);
612                }
613                for rx in receivers {
614                    if let Ok(CommandResponse::Array(keys)) = rx.recv() {
615                        for key in keys {
616                            if let CommandResponse::BulkString(raw) = key {
617                                merged_keys.push(raw);
618                            }
619                        }
620                    }
621                }
622
623                merged_keys.sort();
624                let start = cursor as usize;
625                let limit = count.unwrap_or(10).max(1);
626                if start >= merged_keys.len() {
627                    let _ = tx.send(CommandResponse::Array(vec![
628                        CommandResponse::BulkString(b"0".to_vec()),
629                        CommandResponse::Array(vec![]),
630                    ]));
631                    return;
632                }
633                let end = start.saturating_add(limit).min(merged_keys.len());
634                let result_keys: Vec<CommandResponse> = merged_keys[start..end]
635                    .iter()
636                    .map(|k| CommandResponse::BulkString(k.clone()))
637                    .collect();
638                let next_cursor = if end >= merged_keys.len() { 0 } else { end };
639                let _ = tx.send(CommandResponse::Array(vec![
640                    CommandResponse::BulkString(next_cursor.to_string().into_bytes()),
641                    CommandResponse::Array(result_keys),
642                ]));
643            }
644            Command::StatsHotkeys { count } => {
645                let mut all_hot: Vec<(Vec<u8>, i64)> = Vec::new();
646                let mut receivers = Vec::new();
647                for worker in &self.workers {
648                    let (resp_tx, resp_rx) = response_channel();
649                    let _ = worker.tx.send(ShardMessage::Single {
650                        command: Command::StatsHotkeys { count },
651                        response_tx: resp_tx,
652                    });
653                    receivers.push(resp_rx);
654                }
655                for rx in receivers {
656                    if let Ok(CommandResponse::Array(hotkeys)) = rx.recv() {
657                        for hotkey in hotkeys {
658                            if let CommandResponse::Array(pair) = hotkey {
659                                if pair.len() != 2 {
660                                    continue;
661                                }
662                                if let (
663                                    CommandResponse::BulkString(key),
664                                    CommandResponse::Integer(freq),
665                                ) = (&pair[0], &pair[1])
666                                {
667                                    all_hot.push((key.clone(), *freq));
668                                }
669                            }
670                        }
671                    }
672                }
673                all_hot.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
674                all_hot.truncate(count);
675                let items = all_hot
676                    .into_iter()
677                    .map(|(key, freq)| {
678                        CommandResponse::Array(vec![
679                            CommandResponse::BulkString(key),
680                            CommandResponse::Integer(freq),
681                        ])
682                    })
683                    .collect();
684                let _ = tx.send(CommandResponse::Array(items));
685            }
686            Command::StatsMemory { prefixes } => {
687                let mut totals = vec![0i64; prefixes.len()];
688                let mut receivers = Vec::new();
689                for worker in &self.workers {
690                    let (resp_tx, resp_rx) = response_channel();
691                    let _ = worker.tx.send(ShardMessage::Single {
692                        command: Command::StatsMemory {
693                            prefixes: prefixes.clone(),
694                        },
695                        response_tx: resp_tx,
696                    });
697                    receivers.push(resp_rx);
698                }
699                for rx in receivers {
700                    if let Ok(CommandResponse::Array(items)) = rx.recv() {
701                        for (idx, item) in items.into_iter().enumerate().take(totals.len()) {
702                            if let CommandResponse::Integer(val) = item {
703                                totals[idx] = totals[idx].saturating_add(val);
704                            }
705                        }
706                    }
707                }
708                let merged = totals.into_iter().map(CommandResponse::Integer).collect();
709                let _ = tx.send(CommandResponse::Array(merged));
710            }
711            Command::RandomKey => {
712                for worker in &self.workers {
713                    let (resp_tx, resp_rx) = response_channel();
714                    let _ = worker.tx.send(ShardMessage::Single {
715                        command: Command::RandomKey,
716                        response_tx: resp_tx,
717                    });
718                    if let Ok(resp) = resp_rx.recv() {
719                        if !matches!(resp, CommandResponse::Nil) {
720                            let _ = tx.send(resp);
721                            return;
722                        }
723                    }
724                }
725                let _ = tx.send(CommandResponse::Nil);
726            }
727            _ => {
728                let _ = self.workers[0].tx.send(ShardMessage::Single {
729                    command: cmd,
730                    response_tx: tx,
731                });
732            }
733        }
734    }
735}
736
737impl Drop for ShardEngine {
738    fn drop(&mut self) {
739        for worker in &mut self.workers {
740            let (dummy, _) = crossbeam_channel::unbounded();
741            let old_tx = std::mem::replace(&mut worker.tx, dummy);
742            drop(old_tx);
743        }
744        for worker in &mut self.workers {
745            if let Some(handle) = worker.thread.take() {
746                let _ = handle.join();
747            }
748        }
749    }
750}
751
752/// Worker thread event loop.
753fn worker_loop(
754    store: &mut ShardStore,
755    rx: &Receiver<ShardMessage>,
756    mut wal_writer: Option<Box<dyn WalWriter>>,
757) {
758    const EXPIRE_SWEEP_INTERVAL_OPS: u32 = 4096;
759    const EXPIRE_SWEEP_SAMPLE_SIZE: usize = 64;
760
761    fn maybe_sweep(store: &mut ShardStore, ops_since_expire: &mut u32) {
762        *ops_since_expire += 1;
763        if *ops_since_expire >= EXPIRE_SWEEP_INTERVAL_OPS {
764            let _ = store.evict_expired_sample(EXPIRE_SWEEP_SAMPLE_SIZE);
765            *ops_since_expire = 0;
766        }
767    }
768
769    fn execute_with_wal(
770        store: &mut ShardStore,
771        wal_writer: Option<&mut Box<dyn WalWriter>>,
772        command: Command,
773    ) -> CommandResponse {
774        if let Some(writer) = wal_writer {
775            if command.is_mutation() {
776                if let Some(record) = command_to_wal_record(&command) {
777                    writer.append(&record);
778                }
779            }
780        }
781        store.execute(command)
782    }
783
784    let mut ops_since_expire = 0u32;
785    while let Ok(msg) = rx.recv() {
786        match msg {
787            ShardMessage::Single {
788                command,
789                response_tx,
790            } => {
791                maybe_sweep(store, &mut ops_since_expire);
792                let response = execute_with_wal(store, wal_writer.as_mut(), command);
793                let _ = response_tx.send(response);
794            }
795            ShardMessage::Batch {
796                commands,
797                response_tx,
798            } => {
799                let mut responses = Vec::with_capacity(commands.len());
800                for (idx, command) in commands {
801                    maybe_sweep(store, &mut ops_since_expire);
802                    let response = execute_with_wal(store, wal_writer.as_mut(), command);
803                    responses.push((idx, response));
804                }
805                let _ = response_tx.send(responses);
806            }
807        }
808    }
809}
810
811/// Convert a `Command` to a `WalRecord` for WAL logging.
812pub fn command_to_wal_record(cmd: &Command) -> Option<WalRecord> {
813    match cmd {
814        Command::Set {
815            key, value, ex, px, ..
816        } => {
817            let ttl_ms = if let Some(s) = ex {
818                Some(s * 1000)
819            } else {
820                *px
821            };
822            Some(WalRecord::Set {
823                key: key.clone(),
824                value: value.clone(),
825                ttl_ms,
826            })
827        }
828        Command::SetNx { key, value } => Some(WalRecord::Set {
829            key: key.clone(),
830            value: value.clone(),
831            ttl_ms: None,
832        }),
833        Command::GetSet { key, value } => Some(WalRecord::Set {
834            key: key.clone(),
835            value: value.clone(),
836            ttl_ms: None,
837        }),
838        Command::Append { key, value } => Some(WalRecord::Set {
839            key: key.clone(),
840            value: value.clone(),
841            ttl_ms: None,
842        }),
843        Command::Del { keys } => keys.first().map(|key| WalRecord::Del { key: key.clone() }),
844        Command::Expire { key, seconds } => Some(WalRecord::Expire {
845            key: key.clone(),
846            ttl_ms: seconds * 1000,
847        }),
848        Command::PExpire { key, millis } => Some(WalRecord::Expire {
849            key: key.clone(),
850            ttl_ms: *millis,
851        }),
852        Command::LPush { key, values } => Some(WalRecord::LPush {
853            key: key.clone(),
854            values: values.clone(),
855        }),
856        Command::RPush { key, values } => Some(WalRecord::RPush {
857            key: key.clone(),
858            values: values.clone(),
859        }),
860        Command::HSet { key, fields } => Some(WalRecord::HSet {
861            key: key.clone(),
862            fields: fields.clone(),
863        }),
864        Command::SAdd { key, members } => Some(WalRecord::SAdd {
865            key: key.clone(),
866            members: members.clone(),
867        }),
868        Command::FlushDb | Command::FlushAll => Some(WalRecord::FlushDb),
869        Command::DocSet {
870            collection,
871            doc_id,
872            json,
873        } => Some(WalRecord::DocSet {
874            collection: collection.clone(),
875            doc_id: doc_id.clone(),
876            json: json.clone(),
877        }),
878        Command::DocInsert { .. } => None,
879        Command::DocDel { collection, doc_id } => Some(WalRecord::DocDel {
880            collection: collection.clone(),
881            doc_id: doc_id.clone(),
882        }),
883        Command::DocMSet { .. } => None,
884        Command::VecSet {
885            key,
886            dimensions,
887            vector,
888        } => {
889            let mut vec_bytes = Vec::with_capacity(vector.len() * 4);
890            for &f in vector {
891                vec_bytes.extend_from_slice(&f.to_le_bytes());
892            }
893            Some(WalRecord::VecSet {
894                key: key.clone(),
895                dimensions: *dimensions,
896                vector: vec_bytes,
897            })
898        }
899        Command::VecDel { key } => Some(WalRecord::VecDel { key: key.clone() }),
900        _ => None,
901    }
902}
903
/// A reference-counted handle to a [`ShardEngine`] that can be shared across
/// async tasks. Cloning the handle yields another handle to the same engine.
pub type SharedEngine = Arc<ShardEngine>;
906
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SET` command with no expiry and neither `nx` nor `xx` set.
    fn plain_set(key: Vec<u8>, value: Vec<u8>) -> Command {
        Command::Set {
            key,
            value,
            ex: None,
            px: None,
            nx: false,
            xx: false,
        }
    }

    /// PING with no message answers with the simple string "PONG".
    #[test]
    fn test_engine_basic() {
        let engine = ShardEngine::new(4);
        let pong = engine.dispatch_blocking(Command::Ping { message: None });
        assert!(matches!(pong, CommandResponse::SimpleString(s) if s == "PONG"));
    }

    /// A value written with SET is read back with GET.
    #[test]
    fn test_engine_set_get() {
        let engine = ShardEngine::new(4);
        engine.dispatch_blocking(plain_set(b"hello".to_vec(), b"world".to_vec()));

        let fetched = engine.dispatch_blocking(Command::Get {
            key: b"hello".to_vec(),
        });
        if let CommandResponse::BulkString(v) = fetched {
            assert_eq!(v, b"world");
        } else {
            panic!("Expected 'world', got {:?}", fetched);
        }
    }

    /// MGET returns one value per requested key, in request order.
    #[test]
    fn test_engine_mget_across_shards() {
        let engine = ShardEngine::new(4);
        let key_of = |i: usize| format!("key:{}", i).into_bytes();
        let val_of = |i: usize| format!("val:{}", i).into_bytes();

        // Set multiple keys that will hash to different shards
        (0..20).for_each(|i| {
            engine.dispatch_blocking(plain_set(key_of(i), val_of(i)));
        });

        let keys: Vec<Vec<u8>> = (0..20).map(key_of).collect();
        let values = match engine.dispatch_blocking(Command::MGet { keys }) {
            CommandResponse::Array(values) => values,
            other => panic!("Expected Array, got {:?}", other),
        };

        assert_eq!(values.len(), 20);
        for (i, v) in values.iter().enumerate() {
            if let CommandResponse::BulkString(b) = v {
                assert_eq!(*b, val_of(i));
            } else {
                panic!("Expected BulkString for key:{}, got {:?}", i, v);
            }
        }
    }

    /// DBSIZE reports the number of keys written through the engine.
    #[test]
    fn test_engine_dbsize() {
        let engine = ShardEngine::new(4);
        for i in 0..10 {
            engine.dispatch_blocking(plain_set(format!("k{}", i).into_bytes(), b"v".to_vec()));
        }

        let size = engine.dispatch_blocking(Command::DbSize);
        if !matches!(size, CommandResponse::Integer(10)) {
            panic!("Expected 10, got {:?}", size);
        }
    }

    /// Eight threads issue SET/GET pairs through one shared engine.
    #[test]
    fn test_engine_concurrent_access() {
        let engine = Arc::new(ShardEngine::new(4));

        let handles: Vec<_> = (0..8)
            .map(|t| {
                let eng = engine.clone();
                thread::spawn(move || {
                    for i in 0..100 {
                        let key = format!("t{}:k{}", t, i).into_bytes();
                        let val = format!("v{}", i).into_bytes();
                        eng.dispatch_blocking(plain_set(key.clone(), val.clone()));
                        match eng.dispatch_blocking(Command::Get { key }) {
                            CommandResponse::BulkString(v) => assert_eq!(v, val),
                            CommandResponse::Nil => {} // race with other thread is ok
                            other => panic!("Unexpected: {:?}", other),
                        }
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }
    }

    /// In a batch, a GET placed after FLUSHDB does not see a SET placed before it.
    #[test]
    fn test_dispatch_batch_blocking_preserves_barrier_order() {
        let engine = ShardEngine::new(4);
        let batch = vec![
            plain_set(b"k".to_vec(), b"v".to_vec()),
            Command::FlushDb,
            Command::Get { key: b"k".to_vec() },
        ];
        let responses = engine.dispatch_batch_blocking(batch);

        assert_eq!(responses.len(), 3);
        assert!(matches!(responses[0], CommandResponse::Ok));
        assert!(matches!(responses[1], CommandResponse::Ok));
        assert!(matches!(responses[2], CommandResponse::Nil));
    }

    /// Dropping an engine that has processed a command returns.
    #[test]
    fn test_engine_shutdown() {
        let engine = ShardEngine::new(2);
        engine.dispatch_blocking(plain_set(b"k".to_vec(), b"v".to_vec()));
        drop(engine); // should not hang
    }

    /// Three vectors stored under one key are visible to a query on that key.
    #[test]
    fn test_vec_set_and_query_per_shard() {
        let engine = ShardEngine::new(4);

        let vector1 = vec![1.0f32, 0.0, 0.0, 0.0];
        let vector2 = vec![0.0f32, 1.0, 0.0, 0.0];
        let vector3 = vec![1.0f32, 1.0, 0.0, 0.0];

        for vector in vec![vector1.clone(), vector2, vector3] {
            let set_resp = engine.dispatch_blocking(Command::VecSet {
                key: b"idx".to_vec(),
                dimensions: 4,
                vector,
            });
            assert!(matches!(set_resp, CommandResponse::Integer(_)));
        }

        let query_resp = engine.dispatch_blocking(Command::VecQuery {
            key: b"idx".to_vec(),
            k: 3,
            vector: vector1,
        });
        if let CommandResponse::Array(results) = query_resp {
            assert!(!results.is_empty(), "VecQuery should return results");
            assert!(results.len() <= 3);
        } else {
            panic!("Expected Array, got {:?}", query_resp);
        }
    }

    /// VECDEL reports 1 the first time, 0 afterwards, and leaves nothing to query.
    #[test]
    fn test_vec_del() {
        let engine = ShardEngine::new(2);
        let index_key = || b"myidx".to_vec();

        engine.dispatch_blocking(Command::VecSet {
            key: index_key(),
            dimensions: 3,
            vector: vec![1.0, 2.0, 3.0],
        });

        let del_resp = engine.dispatch_blocking(Command::VecDel { key: index_key() });
        assert!(matches!(del_resp, CommandResponse::Integer(1)));

        let del_again = engine.dispatch_blocking(Command::VecDel { key: index_key() });
        assert!(matches!(del_again, CommandResponse::Integer(0)));

        let query_resp = engine.dispatch_blocking(Command::VecQuery {
            key: index_key(),
            k: 5,
            vector: vec![1.0, 2.0, 3.0],
        });
        if let CommandResponse::Array(results) = query_resp {
            assert!(results.is_empty());
        } else {
            panic!("Expected empty Array, got {:?}", query_resp);
        }
    }

    /// A query with `k = 5` over ten stored vectors returns at most five results.
    #[test]
    fn test_vec_query_fan_out() {
        let engine = ShardEngine::new(4);

        for row in 0..10 {
            let components: Vec<f32> = (0..8).map(|col| (row * 8 + col) as f32 * 0.1).collect();
            engine.dispatch_blocking(Command::VecSet {
                key: b"fanout-idx".to_vec(),
                dimensions: 8,
                vector: components,
            });
        }

        let resp = engine.dispatch_blocking(Command::VecQuery {
            key: b"fanout-idx".to_vec(),
            k: 5,
            vector: vec![0.0f32; 8],
        });
        if let CommandResponse::Array(results) = resp {
            assert!(results.len() <= 5);
        } else {
            panic!("Expected Array, got {:?}", resp);
        }
    }
}