1use std::path::PathBuf;
9use std::time::Duration;
10
11use bytes::Bytes;
12use ember_persistence::aof::{AofRecord, AofWriter, FsyncPolicy};
13use ember_persistence::recovery::{self, RecoveredValue};
14use ember_persistence::snapshot::{self, SnapEntry, SnapValue, SnapshotWriter};
15use tokio::sync::{mpsc, oneshot};
16use tracing::{info, warn};
17
18use crate::dropper::DropHandle;
19use crate::error::ShardError;
20use crate::expiry;
21use crate::keyspace::{
22 IncrError, IncrFloatError, Keyspace, KeyspaceStats, SetResult, ShardConfig, TtlResult,
23 WriteError,
24};
25use crate::types::sorted_set::ZAddFlags;
26use crate::types::Value;
27
/// Period of the shard's active-expiration cycle (`expiry::run_expiration_cycle`).
const EXPIRY_TICK: Duration = Duration::from_millis(100);

/// Period of the background AOF sync, used only under `FsyncPolicy::EverySec`.
const FSYNC_INTERVAL: Duration = Duration::from_millis(1_000);
34
35#[derive(Debug, Clone)]
37pub struct ShardPersistenceConfig {
38 pub data_dir: PathBuf,
40 pub append_only: bool,
42 pub fsync_policy: FsyncPolicy,
44 #[cfg(feature = "encryption")]
47 pub encryption_key: Option<ember_persistence::encryption::EncryptionKey>,
48}
49
50#[derive(Debug)]
52pub enum ShardRequest {
53 Get {
54 key: String,
55 },
56 Set {
57 key: String,
58 value: Bytes,
59 expire: Option<Duration>,
60 nx: bool,
62 xx: bool,
64 },
65 Incr {
66 key: String,
67 },
68 Decr {
69 key: String,
70 },
71 IncrBy {
72 key: String,
73 delta: i64,
74 },
75 DecrBy {
76 key: String,
77 delta: i64,
78 },
79 IncrByFloat {
80 key: String,
81 delta: f64,
82 },
83 Append {
84 key: String,
85 value: Bytes,
86 },
87 Strlen {
88 key: String,
89 },
90 Keys {
92 pattern: String,
93 },
94 Rename {
96 key: String,
97 newkey: String,
98 },
99 Del {
100 key: String,
101 },
102 Unlink {
104 key: String,
105 },
106 Exists {
107 key: String,
108 },
109 Expire {
110 key: String,
111 seconds: u64,
112 },
113 Ttl {
114 key: String,
115 },
116 Persist {
117 key: String,
118 },
119 Pttl {
120 key: String,
121 },
122 Pexpire {
123 key: String,
124 milliseconds: u64,
125 },
126 LPush {
127 key: String,
128 values: Vec<Bytes>,
129 },
130 RPush {
131 key: String,
132 values: Vec<Bytes>,
133 },
134 LPop {
135 key: String,
136 },
137 RPop {
138 key: String,
139 },
140 LRange {
141 key: String,
142 start: i64,
143 stop: i64,
144 },
145 LLen {
146 key: String,
147 },
148 Type {
149 key: String,
150 },
151 ZAdd {
152 key: String,
153 members: Vec<(f64, String)>,
154 nx: bool,
155 xx: bool,
156 gt: bool,
157 lt: bool,
158 ch: bool,
159 },
160 ZRem {
161 key: String,
162 members: Vec<String>,
163 },
164 ZScore {
165 key: String,
166 member: String,
167 },
168 ZRank {
169 key: String,
170 member: String,
171 },
172 ZCard {
173 key: String,
174 },
175 ZRange {
176 key: String,
177 start: i64,
178 stop: i64,
179 with_scores: bool,
180 },
181 HSet {
182 key: String,
183 fields: Vec<(String, Bytes)>,
184 },
185 HGet {
186 key: String,
187 field: String,
188 },
189 HGetAll {
190 key: String,
191 },
192 HDel {
193 key: String,
194 fields: Vec<String>,
195 },
196 HExists {
197 key: String,
198 field: String,
199 },
200 HLen {
201 key: String,
202 },
203 HIncrBy {
204 key: String,
205 field: String,
206 delta: i64,
207 },
208 HKeys {
209 key: String,
210 },
211 HVals {
212 key: String,
213 },
214 HMGet {
215 key: String,
216 fields: Vec<String>,
217 },
218 SAdd {
219 key: String,
220 members: Vec<String>,
221 },
222 SRem {
223 key: String,
224 members: Vec<String>,
225 },
226 SMembers {
227 key: String,
228 },
229 SIsMember {
230 key: String,
231 member: String,
232 },
233 SCard {
234 key: String,
235 },
236 DbSize,
238 Stats,
240 Snapshot,
242 RewriteAof,
244 FlushDb,
246 FlushDbAsync,
248 Scan {
250 cursor: u64,
251 count: usize,
252 pattern: Option<String>,
253 },
254 CountKeysInSlot {
256 slot: u16,
257 },
258 GetKeysInSlot {
260 slot: u16,
261 count: usize,
262 },
263 #[cfg(feature = "vector")]
265 VAdd {
266 key: String,
267 element: String,
268 vector: Vec<f32>,
269 metric: u8,
270 quantization: u8,
271 connectivity: u32,
272 expansion_add: u32,
273 },
274 #[cfg(feature = "vector")]
276 VAddBatch {
277 key: String,
278 entries: Vec<(String, Vec<f32>)>,
279 dim: usize,
280 metric: u8,
281 quantization: u8,
282 connectivity: u32,
283 expansion_add: u32,
284 },
285 #[cfg(feature = "vector")]
287 VSim {
288 key: String,
289 query: Vec<f32>,
290 count: usize,
291 ef_search: usize,
292 },
293 #[cfg(feature = "vector")]
295 VRem {
296 key: String,
297 element: String,
298 },
299 #[cfg(feature = "vector")]
301 VGet {
302 key: String,
303 element: String,
304 },
305 #[cfg(feature = "vector")]
307 VCard {
308 key: String,
309 },
310 #[cfg(feature = "vector")]
312 VDim {
313 key: String,
314 },
315 #[cfg(feature = "vector")]
317 VInfo {
318 key: String,
319 },
320 #[cfg(feature = "protobuf")]
322 ProtoSet {
323 key: String,
324 type_name: String,
325 data: Bytes,
326 expire: Option<Duration>,
327 nx: bool,
328 xx: bool,
329 },
330 #[cfg(feature = "protobuf")]
332 ProtoGet {
333 key: String,
334 },
335 #[cfg(feature = "protobuf")]
337 ProtoType {
338 key: String,
339 },
340 #[cfg(feature = "protobuf")]
344 ProtoRegisterAof {
345 name: String,
346 descriptor: Bytes,
347 },
348 #[cfg(feature = "protobuf")]
351 ProtoSetField {
352 key: String,
353 field_path: String,
354 value: String,
355 },
356 #[cfg(feature = "protobuf")]
359 ProtoDelField {
360 key: String,
361 field_path: String,
362 },
363}
364
365#[derive(Debug)]
367pub enum ShardResponse {
368 Value(Option<Value>),
370 Ok,
372 Integer(i64),
374 Bool(bool),
376 Ttl(TtlResult),
378 OutOfMemory,
380 KeyCount(usize),
382 Stats(KeyspaceStats),
384 Len(usize),
386 Array(Vec<Bytes>),
388 TypeName(&'static str),
390 ZAddLen {
392 count: usize,
393 applied: Vec<(f64, String)>,
394 },
395 ZRemLen { count: usize, removed: Vec<String> },
397 Score(Option<f64>),
399 Rank(Option<usize>),
401 ScoredArray(Vec<(String, f64)>),
403 BulkString(String),
405 WrongType,
407 Err(String),
409 Scan { cursor: u64, keys: Vec<String> },
411 HashFields(Vec<(String, Bytes)>),
413 HDelLen { count: usize, removed: Vec<String> },
415 StringArray(Vec<String>),
417 OptionalArray(Vec<Option<Bytes>>),
419 #[cfg(feature = "vector")]
421 VAddResult {
422 element: String,
423 vector: Vec<f32>,
424 added: bool,
425 },
426 #[cfg(feature = "vector")]
428 VAddBatchResult {
429 added_count: usize,
430 applied: Vec<(String, Vec<f32>)>,
431 },
432 #[cfg(feature = "vector")]
434 VSimResult(Vec<(String, f32)>),
435 #[cfg(feature = "vector")]
437 VectorData(Option<Vec<f32>>),
438 #[cfg(feature = "vector")]
440 VectorInfo(Option<Vec<(String, String)>>),
441 #[cfg(feature = "protobuf")]
443 ProtoValue(Option<(String, Bytes, Option<Duration>)>),
444 #[cfg(feature = "protobuf")]
446 ProtoTypeName(Option<String>),
447 #[cfg(feature = "protobuf")]
450 ProtoFieldUpdated {
451 type_name: String,
452 data: Bytes,
453 expire: Option<Duration>,
454 },
455}
456
457#[derive(Debug)]
459pub struct ShardMessage {
460 pub request: ShardRequest,
461 pub reply: oneshot::Sender<ShardResponse>,
462}
463
464#[derive(Debug, Clone)]
469pub struct ShardHandle {
470 tx: mpsc::Sender<ShardMessage>,
471}
472
473impl ShardHandle {
474 pub async fn send(&self, request: ShardRequest) -> Result<ShardResponse, ShardError> {
478 let rx = self.dispatch(request).await?;
479 rx.await.map_err(|_| ShardError::Unavailable)
480 }
481
482 pub async fn dispatch(
487 &self,
488 request: ShardRequest,
489 ) -> Result<oneshot::Receiver<ShardResponse>, ShardError> {
490 let (reply_tx, reply_rx) = oneshot::channel();
491 let msg = ShardMessage {
492 request,
493 reply: reply_tx,
494 };
495 self.tx
496 .send(msg)
497 .await
498 .map_err(|_| ShardError::Unavailable)?;
499 Ok(reply_rx)
500 }
501}
502
503pub fn spawn_shard(
509 buffer: usize,
510 config: ShardConfig,
511 persistence: Option<ShardPersistenceConfig>,
512 drop_handle: Option<DropHandle>,
513 #[cfg(feature = "protobuf")] schema_registry: Option<crate::schema::SharedSchemaRegistry>,
514) -> ShardHandle {
515 let (tx, rx) = mpsc::channel(buffer);
516 tokio::spawn(run_shard(
517 rx,
518 config,
519 persistence,
520 drop_handle,
521 #[cfg(feature = "protobuf")]
522 schema_registry,
523 ));
524 ShardHandle { tx }
525}
526
527async fn run_shard(
530 mut rx: mpsc::Receiver<ShardMessage>,
531 config: ShardConfig,
532 persistence: Option<ShardPersistenceConfig>,
533 drop_handle: Option<DropHandle>,
534 #[cfg(feature = "protobuf")] schema_registry: Option<crate::schema::SharedSchemaRegistry>,
535) {
536 let shard_id = config.shard_id;
537 let mut keyspace = Keyspace::with_config(config);
538
539 if let Some(handle) = drop_handle.clone() {
540 keyspace.set_drop_handle(handle);
541 }
542
543 if let Some(ref pcfg) = persistence {
545 #[cfg(feature = "encryption")]
546 let result = if let Some(ref key) = pcfg.encryption_key {
547 recovery::recover_shard_encrypted(&pcfg.data_dir, shard_id, key.clone())
548 } else {
549 recovery::recover_shard(&pcfg.data_dir, shard_id)
550 };
551 #[cfg(not(feature = "encryption"))]
552 let result = recovery::recover_shard(&pcfg.data_dir, shard_id);
553 let count = result.entries.len();
554 for entry in result.entries {
555 let value = match entry.value {
556 RecoveredValue::String(data) => Value::String(data),
557 RecoveredValue::List(deque) => Value::List(deque),
558 RecoveredValue::SortedSet(members) => {
559 let mut ss = crate::types::sorted_set::SortedSet::new();
560 for (score, member) in members {
561 ss.add(member, score);
562 }
563 Value::SortedSet(ss)
564 }
565 RecoveredValue::Hash(map) => Value::Hash(map),
566 RecoveredValue::Set(set) => Value::Set(set),
567 #[cfg(feature = "vector")]
568 RecoveredValue::Vector {
569 metric,
570 quantization,
571 connectivity,
572 expansion_add,
573 elements,
574 } => {
575 use crate::types::vector::{DistanceMetric, QuantizationType, VectorSet};
576 let dim = elements.first().map(|(_, v)| v.len()).unwrap_or(0);
577 match VectorSet::new(
578 dim,
579 DistanceMetric::from_u8(metric),
580 QuantizationType::from_u8(quantization),
581 connectivity as usize,
582 expansion_add as usize,
583 ) {
584 Ok(mut vs) => {
585 for (element, vector) in elements {
586 if let Err(e) = vs.add(element, &vector) {
587 warn!("vector recovery: failed to add element: {e}");
588 }
589 }
590 Value::Vector(vs)
591 }
592 Err(e) => {
593 warn!("vector recovery: failed to create index: {e}");
594 continue;
595 }
596 }
597 }
598 #[cfg(feature = "protobuf")]
599 RecoveredValue::Proto { type_name, data } => Value::Proto { type_name, data },
600 };
601 keyspace.restore(entry.key, value, entry.ttl);
602 }
603 if count > 0 {
604 info!(
605 shard_id,
606 recovered_keys = count,
607 snapshot = result.loaded_snapshot,
608 aof = result.replayed_aof,
609 "recovered shard state"
610 );
611 }
612
613 #[cfg(feature = "protobuf")]
615 if let Some(ref registry) = schema_registry {
616 if !result.schemas.is_empty() {
617 if let Ok(mut reg) = registry.write() {
618 let schema_count = result.schemas.len();
619 for (name, descriptor) in result.schemas {
620 reg.restore(name, descriptor);
621 }
622 info!(
623 shard_id,
624 schemas = schema_count,
625 "restored schemas from AOF"
626 );
627 }
628 }
629 }
630 }
631
632 let mut aof_writer: Option<AofWriter> = match &persistence {
634 Some(pcfg) if pcfg.append_only => {
635 let path = ember_persistence::aof::aof_path(&pcfg.data_dir, shard_id);
636 #[cfg(feature = "encryption")]
637 let result = if let Some(ref key) = pcfg.encryption_key {
638 AofWriter::open_encrypted(path, key.clone())
639 } else {
640 AofWriter::open(path)
641 };
642 #[cfg(not(feature = "encryption"))]
643 let result = AofWriter::open(path);
644 match result {
645 Ok(w) => Some(w),
646 Err(e) => {
647 warn!(shard_id, "failed to open AOF writer: {e}");
648 None
649 }
650 }
651 }
652 _ => None,
653 };
654
655 let fsync_policy = persistence
656 .as_ref()
657 .map(|p| p.fsync_policy)
658 .unwrap_or(FsyncPolicy::No);
659
660 let mut expiry_tick = tokio::time::interval(EXPIRY_TICK);
662 expiry_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
663
664 let mut fsync_tick = tokio::time::interval(FSYNC_INTERVAL);
665 fsync_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
666
667 loop {
668 tokio::select! {
669 msg = rx.recv() => {
670 match msg {
671 Some(msg) => {
672 process_message(
673 msg,
674 &mut keyspace,
675 &mut aof_writer,
676 fsync_policy,
677 &persistence,
678 &drop_handle,
679 shard_id,
680 #[cfg(feature = "protobuf")]
681 &schema_registry,
682 );
683
684 while let Ok(msg) = rx.try_recv() {
689 process_message(
690 msg,
691 &mut keyspace,
692 &mut aof_writer,
693 fsync_policy,
694 &persistence,
695 &drop_handle,
696 shard_id,
697 #[cfg(feature = "protobuf")]
698 &schema_registry,
699 );
700 }
701 }
702 None => break, }
704 }
705 _ = expiry_tick.tick() => {
706 expiry::run_expiration_cycle(&mut keyspace);
707 }
708 _ = fsync_tick.tick(), if fsync_policy == FsyncPolicy::EverySec => {
709 if let Some(ref mut writer) = aof_writer {
710 if let Err(e) = writer.sync() {
711 warn!(shard_id, "periodic aof sync failed: {e}");
712 }
713 }
714 }
715 }
716 }
717
718 if let Some(ref mut writer) = aof_writer {
720 let _ = writer.sync();
721 }
722}
723
724#[allow(clippy::too_many_arguments)]
730fn process_message(
731 msg: ShardMessage,
732 keyspace: &mut Keyspace,
733 aof_writer: &mut Option<AofWriter>,
734 fsync_policy: FsyncPolicy,
735 persistence: &Option<ShardPersistenceConfig>,
736 drop_handle: &Option<DropHandle>,
737 shard_id: u16,
738 #[cfg(feature = "protobuf")] schema_registry: &Option<crate::schema::SharedSchemaRegistry>,
739) {
740 let request_kind = describe_request(&msg.request);
741 let response = dispatch(
742 keyspace,
743 &msg.request,
744 #[cfg(feature = "protobuf")]
745 schema_registry,
746 );
747
748 if let Some(ref mut writer) = aof_writer {
750 let records = to_aof_records(&msg.request, &response);
751 for record in &records {
752 if let Err(e) = writer.write_record(record) {
753 warn!(shard_id, "aof write failed: {e}");
754 }
755 }
756 if !records.is_empty() && fsync_policy == FsyncPolicy::Always {
757 if let Err(e) = writer.sync() {
758 warn!(shard_id, "aof sync failed: {e}");
759 }
760 }
761 }
762
763 match request_kind {
765 RequestKind::Snapshot => {
766 let resp = handle_snapshot(keyspace, persistence, shard_id);
767 let _ = msg.reply.send(resp);
768 return;
769 }
770 RequestKind::RewriteAof => {
771 let resp = handle_rewrite(
772 keyspace,
773 persistence,
774 aof_writer,
775 shard_id,
776 #[cfg(feature = "protobuf")]
777 schema_registry,
778 );
779 let _ = msg.reply.send(resp);
780 return;
781 }
782 RequestKind::FlushDbAsync => {
783 let old_entries = keyspace.flush_async();
784 if let Some(ref handle) = drop_handle {
785 handle.defer_entries(old_entries);
786 }
787 let _ = msg.reply.send(ShardResponse::Ok);
788 return;
789 }
790 RequestKind::Other => {}
791 }
792
793 let _ = msg.reply.send(response);
794}
795
/// Classifies requests that `process_message` must finish itself because
/// they need more than the keyspace to complete.
enum RequestKind {
    /// Write a snapshot of the keyspace.
    Snapshot,
    /// Rewrite the AOF.
    RewriteAof,
    /// Flush the keyspace, deferring drops to the drop handle.
    FlushDbAsync,
    /// Anything whose reply comes straight from `dispatch`.
    Other,
}
804
805fn describe_request(req: &ShardRequest) -> RequestKind {
806 match req {
807 ShardRequest::Snapshot => RequestKind::Snapshot,
808 ShardRequest::RewriteAof => RequestKind::RewriteAof,
809 ShardRequest::FlushDbAsync => RequestKind::FlushDbAsync,
810 _ => RequestKind::Other,
811 }
812}
813
814fn incr_result(result: Result<i64, IncrError>) -> ShardResponse {
816 match result {
817 Ok(val) => ShardResponse::Integer(val),
818 Err(IncrError::WrongType) => ShardResponse::WrongType,
819 Err(IncrError::OutOfMemory) => ShardResponse::OutOfMemory,
820 Err(e) => ShardResponse::Err(e.to_string()),
821 }
822}
823
824fn write_result_len(result: Result<usize, WriteError>) -> ShardResponse {
826 match result {
827 Ok(len) => ShardResponse::Len(len),
828 Err(WriteError::WrongType) => ShardResponse::WrongType,
829 Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
830 }
831}
832
833fn dispatch(
835 ks: &mut Keyspace,
836 req: &ShardRequest,
837 #[cfg(feature = "protobuf")] schema_registry: &Option<crate::schema::SharedSchemaRegistry>,
838) -> ShardResponse {
839 match req {
840 ShardRequest::Get { key } => match ks.get_string(key) {
841 Ok(val) => ShardResponse::Value(val.map(Value::String)),
842 Err(_) => ShardResponse::WrongType,
843 },
844 ShardRequest::Set {
845 key,
846 value,
847 expire,
848 nx,
849 xx,
850 } => {
851 if *nx && ks.exists(key) {
853 return ShardResponse::Value(None);
854 }
855 if *xx && !ks.exists(key) {
857 return ShardResponse::Value(None);
858 }
859 match ks.set(key.clone(), value.clone(), *expire) {
860 SetResult::Ok => ShardResponse::Ok,
861 SetResult::OutOfMemory => ShardResponse::OutOfMemory,
862 }
863 }
864 ShardRequest::Incr { key } => incr_result(ks.incr(key)),
865 ShardRequest::Decr { key } => incr_result(ks.decr(key)),
866 ShardRequest::IncrBy { key, delta } => incr_result(ks.incr_by(key, *delta)),
867 ShardRequest::DecrBy { key, delta } => match delta.checked_neg() {
868 Some(neg) => incr_result(ks.incr_by(key, neg)),
869 None => ShardResponse::Err("ERR increment or decrement would overflow".into()),
870 },
871 ShardRequest::IncrByFloat { key, delta } => match ks.incr_by_float(key, *delta) {
872 Ok(val) => ShardResponse::BulkString(val),
873 Err(IncrFloatError::WrongType) => ShardResponse::WrongType,
874 Err(IncrFloatError::OutOfMemory) => ShardResponse::OutOfMemory,
875 Err(e) => ShardResponse::Err(e.to_string()),
876 },
877 ShardRequest::Append { key, value } => write_result_len(ks.append(key, value)),
878 ShardRequest::Strlen { key } => match ks.strlen(key) {
879 Ok(len) => ShardResponse::Len(len),
880 Err(_) => ShardResponse::WrongType,
881 },
882 ShardRequest::Keys { pattern } => {
883 let keys = ks.keys(pattern);
884 ShardResponse::StringArray(keys)
885 }
886 ShardRequest::Rename { key, newkey } => {
887 use crate::keyspace::RenameError;
888 match ks.rename(key, newkey) {
889 Ok(()) => ShardResponse::Ok,
890 Err(RenameError::NoSuchKey) => ShardResponse::Err("ERR no such key".into()),
891 }
892 }
893 ShardRequest::Del { key } => ShardResponse::Bool(ks.del(key)),
894 ShardRequest::Unlink { key } => ShardResponse::Bool(ks.unlink(key)),
895 ShardRequest::Exists { key } => ShardResponse::Bool(ks.exists(key)),
896 ShardRequest::Expire { key, seconds } => ShardResponse::Bool(ks.expire(key, *seconds)),
897 ShardRequest::Ttl { key } => ShardResponse::Ttl(ks.ttl(key)),
898 ShardRequest::Persist { key } => ShardResponse::Bool(ks.persist(key)),
899 ShardRequest::Pttl { key } => ShardResponse::Ttl(ks.pttl(key)),
900 ShardRequest::Pexpire { key, milliseconds } => {
901 ShardResponse::Bool(ks.pexpire(key, *milliseconds))
902 }
903 ShardRequest::LPush { key, values } => write_result_len(ks.lpush(key, values)),
904 ShardRequest::RPush { key, values } => write_result_len(ks.rpush(key, values)),
905 ShardRequest::LPop { key } => match ks.lpop(key) {
906 Ok(val) => ShardResponse::Value(val.map(Value::String)),
907 Err(_) => ShardResponse::WrongType,
908 },
909 ShardRequest::RPop { key } => match ks.rpop(key) {
910 Ok(val) => ShardResponse::Value(val.map(Value::String)),
911 Err(_) => ShardResponse::WrongType,
912 },
913 ShardRequest::LRange { key, start, stop } => match ks.lrange(key, *start, *stop) {
914 Ok(items) => ShardResponse::Array(items),
915 Err(_) => ShardResponse::WrongType,
916 },
917 ShardRequest::LLen { key } => match ks.llen(key) {
918 Ok(len) => ShardResponse::Len(len),
919 Err(_) => ShardResponse::WrongType,
920 },
921 ShardRequest::Type { key } => ShardResponse::TypeName(ks.value_type(key)),
922 ShardRequest::ZAdd {
923 key,
924 members,
925 nx,
926 xx,
927 gt,
928 lt,
929 ch,
930 } => {
931 let flags = ZAddFlags {
932 nx: *nx,
933 xx: *xx,
934 gt: *gt,
935 lt: *lt,
936 ch: *ch,
937 };
938 match ks.zadd(key, members, &flags) {
939 Ok(result) => ShardResponse::ZAddLen {
940 count: result.count,
941 applied: result.applied,
942 },
943 Err(WriteError::WrongType) => ShardResponse::WrongType,
944 Err(WriteError::OutOfMemory) => ShardResponse::OutOfMemory,
945 }
946 }
947 ShardRequest::ZRem { key, members } => match ks.zrem(key, members) {
948 Ok(removed) => ShardResponse::ZRemLen {
949 count: removed.len(),
950 removed,
951 },
952 Err(_) => ShardResponse::WrongType,
953 },
954 ShardRequest::ZScore { key, member } => match ks.zscore(key, member) {
955 Ok(score) => ShardResponse::Score(score),
956 Err(_) => ShardResponse::WrongType,
957 },
958 ShardRequest::ZRank { key, member } => match ks.zrank(key, member) {
959 Ok(rank) => ShardResponse::Rank(rank),
960 Err(_) => ShardResponse::WrongType,
961 },
962 ShardRequest::ZCard { key } => match ks.zcard(key) {
963 Ok(len) => ShardResponse::Len(len),
964 Err(_) => ShardResponse::WrongType,
965 },
966 ShardRequest::ZRange {
967 key, start, stop, ..
968 } => match ks.zrange(key, *start, *stop) {
969 Ok(items) => ShardResponse::ScoredArray(items),
970 Err(_) => ShardResponse::WrongType,
971 },
972 ShardRequest::DbSize => ShardResponse::KeyCount(ks.len()),
973 ShardRequest::Stats => ShardResponse::Stats(ks.stats()),
974 ShardRequest::FlushDb => {
975 ks.clear();
976 ShardResponse::Ok
977 }
978 ShardRequest::Scan {
979 cursor,
980 count,
981 pattern,
982 } => {
983 let (next_cursor, keys) = ks.scan_keys(*cursor, *count, pattern.as_deref());
984 ShardResponse::Scan {
985 cursor: next_cursor,
986 keys,
987 }
988 }
989 ShardRequest::HSet { key, fields } => write_result_len(ks.hset(key, fields)),
990 ShardRequest::HGet { key, field } => match ks.hget(key, field) {
991 Ok(val) => ShardResponse::Value(val.map(Value::String)),
992 Err(_) => ShardResponse::WrongType,
993 },
994 ShardRequest::HGetAll { key } => match ks.hgetall(key) {
995 Ok(fields) => ShardResponse::HashFields(fields),
996 Err(_) => ShardResponse::WrongType,
997 },
998 ShardRequest::HDel { key, fields } => match ks.hdel(key, fields) {
999 Ok(removed) => ShardResponse::HDelLen {
1000 count: removed.len(),
1001 removed,
1002 },
1003 Err(_) => ShardResponse::WrongType,
1004 },
1005 ShardRequest::HExists { key, field } => match ks.hexists(key, field) {
1006 Ok(exists) => ShardResponse::Bool(exists),
1007 Err(_) => ShardResponse::WrongType,
1008 },
1009 ShardRequest::HLen { key } => match ks.hlen(key) {
1010 Ok(len) => ShardResponse::Len(len),
1011 Err(_) => ShardResponse::WrongType,
1012 },
1013 ShardRequest::HIncrBy { key, field, delta } => incr_result(ks.hincrby(key, field, *delta)),
1014 ShardRequest::HKeys { key } => match ks.hkeys(key) {
1015 Ok(keys) => ShardResponse::StringArray(keys),
1016 Err(_) => ShardResponse::WrongType,
1017 },
1018 ShardRequest::HVals { key } => match ks.hvals(key) {
1019 Ok(vals) => ShardResponse::Array(vals),
1020 Err(_) => ShardResponse::WrongType,
1021 },
1022 ShardRequest::HMGet { key, fields } => match ks.hmget(key, fields) {
1023 Ok(vals) => ShardResponse::OptionalArray(vals),
1024 Err(_) => ShardResponse::WrongType,
1025 },
1026 ShardRequest::SAdd { key, members } => write_result_len(ks.sadd(key, members)),
1027 ShardRequest::SRem { key, members } => match ks.srem(key, members) {
1028 Ok(count) => ShardResponse::Len(count),
1029 Err(_) => ShardResponse::WrongType,
1030 },
1031 ShardRequest::SMembers { key } => match ks.smembers(key) {
1032 Ok(members) => ShardResponse::StringArray(members),
1033 Err(_) => ShardResponse::WrongType,
1034 },
1035 ShardRequest::SIsMember { key, member } => match ks.sismember(key, member) {
1036 Ok(exists) => ShardResponse::Bool(exists),
1037 Err(_) => ShardResponse::WrongType,
1038 },
1039 ShardRequest::SCard { key } => match ks.scard(key) {
1040 Ok(count) => ShardResponse::Len(count),
1041 Err(_) => ShardResponse::WrongType,
1042 },
1043 ShardRequest::CountKeysInSlot { slot } => {
1044 ShardResponse::KeyCount(ks.count_keys_in_slot(*slot))
1045 }
1046 ShardRequest::GetKeysInSlot { slot, count } => {
1047 ShardResponse::StringArray(ks.get_keys_in_slot(*slot, *count))
1048 }
1049 #[cfg(feature = "vector")]
1050 ShardRequest::VAdd {
1051 key,
1052 element,
1053 vector,
1054 metric,
1055 quantization,
1056 connectivity,
1057 expansion_add,
1058 } => {
1059 use crate::types::vector::{DistanceMetric, QuantizationType};
1060 match ks.vadd(
1061 key,
1062 element.clone(),
1063 vector.clone(),
1064 DistanceMetric::from_u8(*metric),
1065 QuantizationType::from_u8(*quantization),
1066 *connectivity as usize,
1067 *expansion_add as usize,
1068 ) {
1069 Ok(result) => ShardResponse::VAddResult {
1070 element: result.element,
1071 vector: result.vector,
1072 added: result.added,
1073 },
1074 Err(crate::keyspace::VectorWriteError::WrongType) => ShardResponse::WrongType,
1075 Err(crate::keyspace::VectorWriteError::OutOfMemory) => ShardResponse::OutOfMemory,
1076 Err(crate::keyspace::VectorWriteError::IndexError(e))
1077 | Err(crate::keyspace::VectorWriteError::PartialBatch { message: e, .. }) => {
1078 ShardResponse::Err(format!("ERR vector index: {e}"))
1079 }
1080 }
1081 }
1082 #[cfg(feature = "vector")]
1083 ShardRequest::VAddBatch {
1084 key,
1085 entries,
1086 metric,
1087 quantization,
1088 connectivity,
1089 expansion_add,
1090 ..
1091 } => {
1092 use crate::types::vector::{DistanceMetric, QuantizationType};
1093 match ks.vadd_batch(
1094 key,
1095 entries,
1096 DistanceMetric::from_u8(*metric),
1097 QuantizationType::from_u8(*quantization),
1098 *connectivity as usize,
1099 *expansion_add as usize,
1100 ) {
1101 Ok(result) => ShardResponse::VAddBatchResult {
1102 added_count: result.added_count,
1103 applied: result.applied,
1104 },
1105 Err(crate::keyspace::VectorWriteError::WrongType) => ShardResponse::WrongType,
1106 Err(crate::keyspace::VectorWriteError::OutOfMemory) => ShardResponse::OutOfMemory,
1107 Err(crate::keyspace::VectorWriteError::IndexError(e)) => {
1108 ShardResponse::Err(format!("ERR vector index: {e}"))
1109 }
1110 Err(crate::keyspace::VectorWriteError::PartialBatch { applied, .. }) => {
1111 ShardResponse::VAddBatchResult {
1113 added_count: applied.len(),
1114 applied,
1115 }
1116 }
1117 }
1118 }
1119 #[cfg(feature = "vector")]
1120 ShardRequest::VSim {
1121 key,
1122 query,
1123 count,
1124 ef_search,
1125 } => match ks.vsim(key, query, *count, *ef_search) {
1126 Ok(results) => ShardResponse::VSimResult(
1127 results
1128 .into_iter()
1129 .map(|r| (r.element, r.distance))
1130 .collect(),
1131 ),
1132 Err(_) => ShardResponse::WrongType,
1133 },
1134 #[cfg(feature = "vector")]
1135 ShardRequest::VRem { key, element } => match ks.vrem(key, element) {
1136 Ok(removed) => ShardResponse::Bool(removed),
1137 Err(_) => ShardResponse::WrongType,
1138 },
1139 #[cfg(feature = "vector")]
1140 ShardRequest::VGet { key, element } => match ks.vget(key, element) {
1141 Ok(data) => ShardResponse::VectorData(data),
1142 Err(_) => ShardResponse::WrongType,
1143 },
1144 #[cfg(feature = "vector")]
1145 ShardRequest::VCard { key } => match ks.vcard(key) {
1146 Ok(count) => ShardResponse::Integer(count as i64),
1147 Err(_) => ShardResponse::WrongType,
1148 },
1149 #[cfg(feature = "vector")]
1150 ShardRequest::VDim { key } => match ks.vdim(key) {
1151 Ok(dim) => ShardResponse::Integer(dim as i64),
1152 Err(_) => ShardResponse::WrongType,
1153 },
1154 #[cfg(feature = "vector")]
1155 ShardRequest::VInfo { key } => match ks.vinfo(key) {
1156 Ok(Some(info)) => {
1157 let fields = vec![
1158 ("dim".to_owned(), info.dim.to_string()),
1159 ("count".to_owned(), info.count.to_string()),
1160 ("metric".to_owned(), info.metric.to_string()),
1161 ("quantization".to_owned(), info.quantization.to_string()),
1162 ("connectivity".to_owned(), info.connectivity.to_string()),
1163 ("expansion_add".to_owned(), info.expansion_add.to_string()),
1164 ];
1165 ShardResponse::VectorInfo(Some(fields))
1166 }
1167 Ok(None) => ShardResponse::VectorInfo(None),
1168 Err(_) => ShardResponse::WrongType,
1169 },
1170 #[cfg(feature = "protobuf")]
1171 ShardRequest::ProtoSet {
1172 key,
1173 type_name,
1174 data,
1175 expire,
1176 nx,
1177 xx,
1178 } => {
1179 if *nx && ks.exists(key) {
1180 return ShardResponse::Value(None);
1181 }
1182 if *xx && !ks.exists(key) {
1183 return ShardResponse::Value(None);
1184 }
1185 match ks.proto_set(key.clone(), type_name.clone(), data.clone(), *expire) {
1186 SetResult::Ok => ShardResponse::Ok,
1187 SetResult::OutOfMemory => ShardResponse::OutOfMemory,
1188 }
1189 }
1190 #[cfg(feature = "protobuf")]
1191 ShardRequest::ProtoGet { key } => match ks.proto_get(key) {
1192 Ok(val) => ShardResponse::ProtoValue(val),
1193 Err(_) => ShardResponse::WrongType,
1194 },
1195 #[cfg(feature = "protobuf")]
1196 ShardRequest::ProtoType { key } => match ks.proto_type(key) {
1197 Ok(name) => ShardResponse::ProtoTypeName(name),
1198 Err(_) => ShardResponse::WrongType,
1199 },
1200 #[cfg(feature = "protobuf")]
1203 ShardRequest::ProtoRegisterAof { .. } => ShardResponse::Ok,
1204 #[cfg(feature = "protobuf")]
1205 ShardRequest::ProtoSetField {
1206 key,
1207 field_path,
1208 value,
1209 } => dispatch_proto_field_op(ks, schema_registry, key, |reg, type_name, data, ttl| {
1210 let new_data = reg.set_field(type_name, data, field_path, value)?;
1211 Ok(ShardResponse::ProtoFieldUpdated {
1212 type_name: type_name.to_owned(),
1213 data: new_data,
1214 expire: ttl,
1215 })
1216 }),
1217 #[cfg(feature = "protobuf")]
1218 ShardRequest::ProtoDelField { key, field_path } => {
1219 dispatch_proto_field_op(ks, schema_registry, key, |reg, type_name, data, ttl| {
1220 let new_data = reg.clear_field(type_name, data, field_path)?;
1221 Ok(ShardResponse::ProtoFieldUpdated {
1222 type_name: type_name.to_owned(),
1223 data: new_data,
1224 expire: ttl,
1225 })
1226 })
1227 }
1228 ShardRequest::Snapshot | ShardRequest::RewriteAof | ShardRequest::FlushDbAsync => {
1230 ShardResponse::Ok
1231 }
1232 }
1233}
1234
/// Read-modify-write helper for protobuf field updates.
///
/// Loads the message stored at `key` and passes it, along with its type name
/// and remaining TTL, to `mutate` under a read lock on the schema registry.
/// If `mutate` returns [`ShardResponse::ProtoFieldUpdated`], the new message
/// is written back with that expiry.
///
/// Replies:
/// - `Err` if no registry is configured, the registry lock is poisoned, or
///   `mutate` fails (the schema error's text).
/// - `Value(None)` if the key does not exist.
/// - `WrongType` if the key holds a non-protobuf value.
/// - `OutOfMemory` if writing the updated message back is rejected. Reporting
///   success here would tell the client (and any AOF logging keyed off the
///   reply) about a write that never happened in memory.
/// - Otherwise, whatever `mutate` returned.
#[cfg(feature = "protobuf")]
fn dispatch_proto_field_op<F>(
    ks: &mut Keyspace,
    schema_registry: &Option<crate::schema::SharedSchemaRegistry>,
    key: &str,
    mutate: F,
) -> ShardResponse
where
    F: FnOnce(
        &crate::schema::SchemaRegistry,
        &str,
        &[u8],
        Option<Duration>,
    ) -> Result<ShardResponse, crate::schema::SchemaError>,
{
    let registry = match schema_registry {
        Some(r) => r,
        None => return ShardResponse::Err("protobuf support is not enabled".into()),
    };

    let (type_name, data, remaining_ttl) = match ks.proto_get(key) {
        Ok(Some(tuple)) => tuple,
        Ok(None) => return ShardResponse::Value(None),
        Err(_) => return ShardResponse::WrongType,
    };

    // Scope the read guard so the registry lock is released before the
    // keyspace is written below.
    let resp = {
        let reg = match registry.read() {
            Ok(r) => r,
            Err(_) => return ShardResponse::Err("schema registry lock poisoned".into()),
        };
        match mutate(&reg, &type_name, &data, remaining_ttl) {
            Ok(r) => r,
            Err(e) => return ShardResponse::Err(e.to_string()),
        }
    };

    if let ShardResponse::ProtoFieldUpdated {
        ref type_name,
        ref data,
        expire,
    } = resp
    {
        match ks.proto_set(key.to_owned(), type_name.clone(), data.clone(), expire) {
            SetResult::Ok => {}
            SetResult::OutOfMemory => return ShardResponse::OutOfMemory,
        }
    }

    resp
}
1288
/// Converts a TTL into the whole-millisecond value stored in AOF records.
///
/// Sub-millisecond remainders are truncated. Durations too large for an
/// `i64` saturate to `i64::MAX`.
fn duration_to_expire_ms(d: Duration) -> i64 {
    i64::try_from(d.as_millis()).unwrap_or(i64::MAX)
}
1302
/// Translates a write that was applied to the keyspace into the AOF records
/// that reproduce it on replay.
///
/// Both the request and the response are inspected: a record is emitted only
/// when `resp` shows the write took effect. Reads, failed or no-op writes
/// (e.g. `Bool(false)` from DEL on a missing key, `Value(None)` from a SET
/// blocked by NX/XX, `OutOfMemory`), and any other unmatched pair fall through
/// to an empty vec. Most arms yield a single record; `VAddBatch` yields one
/// `VAdd` record per applied element.
fn to_aof_records(req: &ShardRequest, resp: &ShardResponse) -> Vec<AofRecord> {
    match (req, resp) {
        // SET is logged only on `Ok`. A SET without a TTL is logged with
        // `expire_ms: -1`; oversized TTLs saturate via `duration_to_expire_ms`.
        (
            ShardRequest::Set {
                key, value, expire, ..
            },
            ShardResponse::Ok,
        ) => {
            let expire_ms = expire.map(duration_to_expire_ms).unwrap_or(-1);
            vec![AofRecord::Set {
                key: key.clone(),
                value: value.clone(),
                expire_ms,
            }]
        }
        // DEL and UNLINK produce the same record, and only when a key was removed.
        (ShardRequest::Del { key }, ShardResponse::Bool(true))
        | (ShardRequest::Unlink { key }, ShardResponse::Bool(true)) => {
            vec![AofRecord::Del { key: key.clone() }]
        }
        // NOTE(review): the relative `seconds` from the request is logged, so the
        // deadline is presumably recomputed from replay time — confirm in recovery.
        (ShardRequest::Expire { key, seconds }, ShardResponse::Bool(true)) => {
            vec![AofRecord::Expire {
                key: key.clone(),
                seconds: *seconds,
            }]
        }
        (ShardRequest::LPush { key, values }, ShardResponse::Len(_)) => vec![AofRecord::LPush {
            key: key.clone(),
            values: values.clone(),
        }],
        (ShardRequest::RPush { key, values }, ShardResponse::Len(_)) => vec![AofRecord::RPush {
            key: key.clone(),
            values: values.clone(),
        }],
        // Pops are logged only when an element was actually returned.
        (ShardRequest::LPop { key }, ShardResponse::Value(Some(_))) => {
            vec![AofRecord::LPop { key: key.clone() }]
        }
        (ShardRequest::RPop { key }, ShardResponse::Value(Some(_))) => {
            vec![AofRecord::RPop { key: key.clone() }]
        }
        // ZADD/ZREM log only the members the response reports as changed,
        // not the full request list.
        (ShardRequest::ZAdd { key, .. }, ShardResponse::ZAddLen { applied, .. })
            if !applied.is_empty() =>
        {
            vec![AofRecord::ZAdd {
                key: key.clone(),
                members: applied.clone(),
            }]
        }
        (ShardRequest::ZRem { key, .. }, ShardResponse::ZRemLen { removed, .. })
            if !removed.is_empty() =>
        {
            vec![AofRecord::ZRem {
                key: key.clone(),
                members: removed.clone(),
            }]
        }
        // Integer counters are logged as the operation (not the result), so
        // replay re-applies the same increment/decrement.
        (ShardRequest::Incr { key }, ShardResponse::Integer(_)) => {
            vec![AofRecord::Incr { key: key.clone() }]
        }
        (ShardRequest::Decr { key }, ShardResponse::Integer(_)) => {
            vec![AofRecord::Decr { key: key.clone() }]
        }
        (ShardRequest::IncrBy { key, delta }, ShardResponse::Integer(_)) => {
            vec![AofRecord::IncrBy {
                key: key.clone(),
                delta: *delta,
            }]
        }
        (ShardRequest::DecrBy { key, delta }, ShardResponse::Integer(_)) => {
            vec![AofRecord::DecrBy {
                key: key.clone(),
                delta: *delta,
            }]
        }
        // INCRBYFLOAT is logged as a SET of the resulting string from the
        // response (the request's `delta` is not recorded), so replay stores
        // exactly the value that was returned.
        // NOTE(review): `expire_ms: -1` is written unconditionally; if the key had
        // a TTL this record may clear it on replay — confirm intended TTL semantics.
        (ShardRequest::IncrByFloat { key, .. }, ShardResponse::BulkString(val)) => {
            vec![AofRecord::Set {
                key: key.clone(),
                value: Bytes::from(val.clone()),
                expire_ms: -1,
            }]
        }
        (ShardRequest::Append { key, value }, ShardResponse::Len(_)) => vec![AofRecord::Append {
            key: key.clone(),
            value: value.clone(),
        }],
        (ShardRequest::Rename { key, newkey }, ShardResponse::Ok) => vec![AofRecord::Rename {
            key: key.clone(),
            newkey: newkey.clone(),
        }],
        (ShardRequest::Persist { key }, ShardResponse::Bool(true)) => {
            vec![AofRecord::Persist { key: key.clone() }]
        }
        (ShardRequest::Pexpire { key, milliseconds }, ShardResponse::Bool(true)) => {
            vec![AofRecord::Pexpire {
                key: key.clone(),
                milliseconds: *milliseconds,
            }]
        }
        // HSET is logged for any `Len` response, with all requested fields.
        (ShardRequest::HSet { key, fields }, ShardResponse::Len(_)) => vec![AofRecord::HSet {
            key: key.clone(),
            fields: fields.clone(),
        }],
        // HDEL logs only the fields the response reports as removed.
        (ShardRequest::HDel { key, .. }, ShardResponse::HDelLen { removed, .. })
            if !removed.is_empty() =>
        {
            vec![AofRecord::HDel {
                key: key.clone(),
                fields: removed.clone(),
            }]
        }
        (ShardRequest::HIncrBy { key, field, delta }, ShardResponse::Integer(_)) => {
            vec![AofRecord::HIncrBy {
                key: key.clone(),
                field: field.clone(),
                delta: *delta,
            }]
        }
        // SADD/SREM are logged only when at least one member changed, but the
        // record carries the full member list from the request.
        (ShardRequest::SAdd { key, members }, ShardResponse::Len(count)) if *count > 0 => {
            vec![AofRecord::SAdd {
                key: key.clone(),
                members: members.clone(),
            }]
        }
        (ShardRequest::SRem { key, members }, ShardResponse::Len(count)) if *count > 0 => {
            vec![AofRecord::SRem {
                key: key.clone(),
                members: members.clone(),
            }]
        }
        #[cfg(feature = "protobuf")]
        (
            ShardRequest::ProtoSet {
                key,
                type_name,
                data,
                expire,
                ..
            },
            ShardResponse::Ok,
        ) => {
            let expire_ms = expire.map(duration_to_expire_ms).unwrap_or(-1);
            vec![AofRecord::ProtoSet {
                key: key.clone(),
                type_name: type_name.clone(),
                data: data.clone(),
                expire_ms,
            }]
        }
        #[cfg(feature = "protobuf")]
        (ShardRequest::ProtoRegisterAof { name, descriptor }, ShardResponse::Ok) => {
            vec![AofRecord::ProtoRegister {
                name: name.clone(),
                descriptor: descriptor.clone(),
            }]
        }
        // Field-level proto edits are logged as a full ProtoSet of the updated
        // message (type, bytes, and expiry) carried in the response.
        #[cfg(feature = "protobuf")]
        (
            ShardRequest::ProtoSetField { key, .. } | ShardRequest::ProtoDelField { key, .. },
            ShardResponse::ProtoFieldUpdated {
                type_name,
                data,
                expire,
            },
        ) => {
            let expire_ms = expire.map(duration_to_expire_ms).unwrap_or(-1);
            vec![AofRecord::ProtoSet {
                key: key.clone(),
                type_name: type_name.clone(),
                data: data.clone(),
                expire_ms,
            }]
        }
        // VADD takes the element and vector from the response and the index
        // parameters from the request.
        #[cfg(feature = "vector")]
        (
            ShardRequest::VAdd {
                key,
                metric,
                quantization,
                connectivity,
                expansion_add,
                ..
            },
            ShardResponse::VAddResult {
                element, vector, ..
            },
        ) => vec![AofRecord::VAdd {
            key: key.clone(),
            element: element.clone(),
            vector: vector.clone(),
            metric: *metric,
            quantization: *quantization,
            connectivity: *connectivity,
            expansion_add: *expansion_add,
        }],
        // A batch add expands into one VAdd record per element the response
        // reports as applied (an empty `applied` list yields no records).
        #[cfg(feature = "vector")]
        (
            ShardRequest::VAddBatch {
                key,
                metric,
                quantization,
                connectivity,
                expansion_add,
                ..
            },
            ShardResponse::VAddBatchResult { applied, .. },
        ) => applied
            .iter()
            .map(|(element, vector)| AofRecord::VAdd {
                key: key.clone(),
                element: element.clone(),
                vector: vector.clone(),
                metric: *metric,
                quantization: *quantization,
                connectivity: *connectivity,
                expansion_add: *expansion_add,
            })
            .collect(),
        #[cfg(feature = "vector")]
        (ShardRequest::VRem { key, element }, ShardResponse::Bool(true)) => vec![AofRecord::VRem {
            key: key.clone(),
            element: element.clone(),
        }],
        // Reads, failed/no-op writes, and anything not listed above: nothing to log.
        _ => vec![],
    }
}
1538
1539fn handle_snapshot(
1541 keyspace: &Keyspace,
1542 persistence: &Option<ShardPersistenceConfig>,
1543 shard_id: u16,
1544) -> ShardResponse {
1545 let pcfg = match persistence {
1546 Some(p) => p,
1547 None => return ShardResponse::Err("persistence not configured".into()),
1548 };
1549
1550 let path = snapshot::snapshot_path(&pcfg.data_dir, shard_id);
1551 let result = write_snapshot(
1552 keyspace,
1553 &path,
1554 shard_id,
1555 #[cfg(feature = "encryption")]
1556 pcfg.encryption_key.as_ref(),
1557 );
1558 match result {
1559 Ok(count) => {
1560 info!(shard_id, entries = count, "snapshot written");
1561 ShardResponse::Ok
1562 }
1563 Err(e) => {
1564 warn!(shard_id, "snapshot failed: {e}");
1565 ShardResponse::Err(format!("snapshot failed: {e}"))
1566 }
1567 }
1568}
1569
1570fn handle_rewrite(
1575 keyspace: &Keyspace,
1576 persistence: &Option<ShardPersistenceConfig>,
1577 aof_writer: &mut Option<AofWriter>,
1578 shard_id: u16,
1579 #[cfg(feature = "protobuf")] schema_registry: &Option<crate::schema::SharedSchemaRegistry>,
1580) -> ShardResponse {
1581 let pcfg = match persistence {
1582 Some(p) => p,
1583 None => return ShardResponse::Err("persistence not configured".into()),
1584 };
1585
1586 let path = snapshot::snapshot_path(&pcfg.data_dir, shard_id);
1587 let result = write_snapshot(
1588 keyspace,
1589 &path,
1590 shard_id,
1591 #[cfg(feature = "encryption")]
1592 pcfg.encryption_key.as_ref(),
1593 );
1594 match result {
1595 Ok(count) => {
1596 if let Some(ref mut writer) = aof_writer {
1598 if let Err(e) = writer.truncate() {
1599 warn!(shard_id, "aof truncate after rewrite failed: {e}");
1600 }
1601
1602 #[cfg(feature = "protobuf")]
1604 if let Some(ref registry) = schema_registry {
1605 if let Ok(reg) = registry.read() {
1606 for (name, descriptor) in reg.iter_schemas() {
1607 let record = AofRecord::ProtoRegister {
1608 name: name.to_owned(),
1609 descriptor: descriptor.clone(),
1610 };
1611 if let Err(e) = writer.write_record(&record) {
1612 warn!(shard_id, "failed to re-persist schema after rewrite: {e}");
1613 }
1614 }
1615 }
1616 }
1617
1618 if let Err(e) = writer.sync() {
1620 warn!(shard_id, "aof sync after rewrite failed: {e}");
1621 }
1622 }
1623 info!(shard_id, entries = count, "aof rewrite complete");
1624 ShardResponse::Ok
1625 }
1626 Err(e) => {
1627 warn!(shard_id, "aof rewrite failed: {e}");
1628 ShardResponse::Err(format!("rewrite failed: {e}"))
1629 }
1630 }
1631}
1632
1633fn write_snapshot(
1635 keyspace: &Keyspace,
1636 path: &std::path::Path,
1637 shard_id: u16,
1638 #[cfg(feature = "encryption")] encryption_key: Option<
1639 &ember_persistence::encryption::EncryptionKey,
1640 >,
1641) -> Result<u32, ember_persistence::format::FormatError> {
1642 #[cfg(feature = "encryption")]
1643 let mut writer = if let Some(key) = encryption_key {
1644 SnapshotWriter::create_encrypted(path, shard_id, key.clone())?
1645 } else {
1646 SnapshotWriter::create(path, shard_id)?
1647 };
1648 #[cfg(not(feature = "encryption"))]
1649 let mut writer = SnapshotWriter::create(path, shard_id)?;
1650 let mut count = 0u32;
1651
1652 for (key, value, ttl_ms) in keyspace.iter_entries() {
1653 let snap_value = match value {
1654 Value::String(data) => SnapValue::String(data.clone()),
1655 Value::List(deque) => SnapValue::List(deque.clone()),
1656 Value::SortedSet(ss) => {
1657 let members: Vec<(f64, String)> = ss
1658 .iter()
1659 .map(|(member, score)| (score, member.to_owned()))
1660 .collect();
1661 SnapValue::SortedSet(members)
1662 }
1663 Value::Hash(map) => SnapValue::Hash(map.clone()),
1664 Value::Set(set) => SnapValue::Set(set.clone()),
1665 #[cfg(feature = "vector")]
1666 Value::Vector(ref vs) => {
1667 let mut elements = Vec::with_capacity(vs.len());
1668 for name in vs.elements() {
1669 if let Some(vec) = vs.get(name) {
1670 elements.push((name.to_owned(), vec));
1671 }
1672 }
1673 SnapValue::Vector {
1674 metric: vs.metric().into(),
1675 quantization: vs.quantization().into(),
1676 connectivity: vs.connectivity() as u32,
1677 expansion_add: vs.expansion_add() as u32,
1678 dim: vs.dim() as u32,
1679 elements,
1680 }
1681 }
1682 #[cfg(feature = "protobuf")]
1683 Value::Proto { type_name, data } => SnapValue::Proto {
1684 type_name: type_name.clone(),
1685 data: data.clone(),
1686 },
1687 };
1688 writer.write_entry(&SnapEntry {
1689 key: key.to_owned(),
1690 value: snap_value,
1691 expire_ms: ttl_ms,
1692 })?;
1693 count += 1;
1694 }
1695
1696 writer.finish()?;
1697 Ok(count)
1698}
1699
1700#[cfg(test)]
1701mod tests {
1702 use super::*;
1703
1704 fn test_dispatch(ks: &mut Keyspace, req: &ShardRequest) -> ShardResponse {
1706 dispatch(
1707 ks,
1708 req,
1709 #[cfg(feature = "protobuf")]
1710 &None,
1711 )
1712 }
1713
1714 #[test]
1715 fn dispatch_set_and_get() {
1716 let mut ks = Keyspace::new();
1717
1718 let resp = test_dispatch(
1719 &mut ks,
1720 &ShardRequest::Set {
1721 key: "k".into(),
1722 value: Bytes::from("v"),
1723 expire: None,
1724 nx: false,
1725 xx: false,
1726 },
1727 );
1728 assert!(matches!(resp, ShardResponse::Ok));
1729
1730 let resp = test_dispatch(&mut ks, &ShardRequest::Get { key: "k".into() });
1731 match resp {
1732 ShardResponse::Value(Some(Value::String(data))) => {
1733 assert_eq!(data, Bytes::from("v"));
1734 }
1735 other => panic!("expected Value(Some(String)), got {other:?}"),
1736 }
1737 }
1738
1739 #[test]
1740 fn dispatch_get_missing() {
1741 let mut ks = Keyspace::new();
1742 let resp = test_dispatch(&mut ks, &ShardRequest::Get { key: "nope".into() });
1743 assert!(matches!(resp, ShardResponse::Value(None)));
1744 }
1745
1746 #[test]
1747 fn dispatch_del() {
1748 let mut ks = Keyspace::new();
1749 ks.set("key".into(), Bytes::from("val"), None);
1750
1751 let resp = test_dispatch(&mut ks, &ShardRequest::Del { key: "key".into() });
1752 assert!(matches!(resp, ShardResponse::Bool(true)));
1753
1754 let resp = test_dispatch(&mut ks, &ShardRequest::Del { key: "key".into() });
1755 assert!(matches!(resp, ShardResponse::Bool(false)));
1756 }
1757
1758 #[test]
1759 fn dispatch_exists() {
1760 let mut ks = Keyspace::new();
1761 ks.set("yes".into(), Bytes::from("here"), None);
1762
1763 let resp = test_dispatch(&mut ks, &ShardRequest::Exists { key: "yes".into() });
1764 assert!(matches!(resp, ShardResponse::Bool(true)));
1765
1766 let resp = test_dispatch(&mut ks, &ShardRequest::Exists { key: "no".into() });
1767 assert!(matches!(resp, ShardResponse::Bool(false)));
1768 }
1769
1770 #[test]
1771 fn dispatch_expire_and_ttl() {
1772 let mut ks = Keyspace::new();
1773 ks.set("key".into(), Bytes::from("val"), None);
1774
1775 let resp = test_dispatch(
1776 &mut ks,
1777 &ShardRequest::Expire {
1778 key: "key".into(),
1779 seconds: 60,
1780 },
1781 );
1782 assert!(matches!(resp, ShardResponse::Bool(true)));
1783
1784 let resp = test_dispatch(&mut ks, &ShardRequest::Ttl { key: "key".into() });
1785 match resp {
1786 ShardResponse::Ttl(TtlResult::Seconds(s)) => assert!((58..=60).contains(&s)),
1787 other => panic!("expected Ttl(Seconds), got {other:?}"),
1788 }
1789 }
1790
1791 #[test]
1792 fn dispatch_ttl_missing() {
1793 let mut ks = Keyspace::new();
1794 let resp = test_dispatch(&mut ks, &ShardRequest::Ttl { key: "gone".into() });
1795 assert!(matches!(resp, ShardResponse::Ttl(TtlResult::NotFound)));
1796 }
1797
1798 #[tokio::test]
1799 async fn shard_round_trip() {
1800 let handle = spawn_shard(
1801 16,
1802 ShardConfig::default(),
1803 None,
1804 None,
1805 #[cfg(feature = "protobuf")]
1806 None,
1807 );
1808
1809 let resp = handle
1810 .send(ShardRequest::Set {
1811 key: "hello".into(),
1812 value: Bytes::from("world"),
1813 expire: None,
1814 nx: false,
1815 xx: false,
1816 })
1817 .await
1818 .unwrap();
1819 assert!(matches!(resp, ShardResponse::Ok));
1820
1821 let resp = handle
1822 .send(ShardRequest::Get {
1823 key: "hello".into(),
1824 })
1825 .await
1826 .unwrap();
1827 match resp {
1828 ShardResponse::Value(Some(Value::String(data))) => {
1829 assert_eq!(data, Bytes::from("world"));
1830 }
1831 other => panic!("expected Value(Some(String)), got {other:?}"),
1832 }
1833 }
1834
1835 #[tokio::test]
1836 async fn expired_key_through_shard() {
1837 let handle = spawn_shard(
1838 16,
1839 ShardConfig::default(),
1840 None,
1841 None,
1842 #[cfg(feature = "protobuf")]
1843 None,
1844 );
1845
1846 handle
1847 .send(ShardRequest::Set {
1848 key: "temp".into(),
1849 value: Bytes::from("gone"),
1850 expire: Some(Duration::from_millis(10)),
1851 nx: false,
1852 xx: false,
1853 })
1854 .await
1855 .unwrap();
1856
1857 tokio::time::sleep(Duration::from_millis(30)).await;
1858
1859 let resp = handle
1860 .send(ShardRequest::Get { key: "temp".into() })
1861 .await
1862 .unwrap();
1863 assert!(matches!(resp, ShardResponse::Value(None)));
1864 }
1865
1866 #[tokio::test]
1867 async fn active_expiration_cleans_up_without_access() {
1868 let handle = spawn_shard(
1869 16,
1870 ShardConfig::default(),
1871 None,
1872 None,
1873 #[cfg(feature = "protobuf")]
1874 None,
1875 );
1876
1877 handle
1879 .send(ShardRequest::Set {
1880 key: "ephemeral".into(),
1881 value: Bytes::from("temp"),
1882 expire: Some(Duration::from_millis(10)),
1883 nx: false,
1884 xx: false,
1885 })
1886 .await
1887 .unwrap();
1888
1889 handle
1891 .send(ShardRequest::Set {
1892 key: "persistent".into(),
1893 value: Bytes::from("stays"),
1894 expire: None,
1895 nx: false,
1896 xx: false,
1897 })
1898 .await
1899 .unwrap();
1900
1901 tokio::time::sleep(Duration::from_millis(250)).await;
1904
1905 let resp = handle
1907 .send(ShardRequest::Exists {
1908 key: "ephemeral".into(),
1909 })
1910 .await
1911 .unwrap();
1912 assert!(matches!(resp, ShardResponse::Bool(false)));
1913
1914 let resp = handle
1916 .send(ShardRequest::Exists {
1917 key: "persistent".into(),
1918 })
1919 .await
1920 .unwrap();
1921 assert!(matches!(resp, ShardResponse::Bool(true)));
1922 }
1923
1924 #[tokio::test]
1925 async fn shard_with_persistence_snapshot_and_recovery() {
1926 let dir = tempfile::tempdir().unwrap();
1927 let pcfg = ShardPersistenceConfig {
1928 data_dir: dir.path().to_owned(),
1929 append_only: true,
1930 fsync_policy: FsyncPolicy::Always,
1931 #[cfg(feature = "encryption")]
1932 encryption_key: None,
1933 };
1934 let config = ShardConfig {
1935 shard_id: 0,
1936 ..ShardConfig::default()
1937 };
1938
1939 {
1941 let handle = spawn_shard(
1942 16,
1943 config.clone(),
1944 Some(pcfg.clone()),
1945 None,
1946 #[cfg(feature = "protobuf")]
1947 None,
1948 );
1949 handle
1950 .send(ShardRequest::Set {
1951 key: "a".into(),
1952 value: Bytes::from("1"),
1953 expire: None,
1954 nx: false,
1955 xx: false,
1956 })
1957 .await
1958 .unwrap();
1959 handle
1960 .send(ShardRequest::Set {
1961 key: "b".into(),
1962 value: Bytes::from("2"),
1963 expire: Some(Duration::from_secs(300)),
1964 nx: false,
1965 xx: false,
1966 })
1967 .await
1968 .unwrap();
1969 handle.send(ShardRequest::Snapshot).await.unwrap();
1970 handle
1972 .send(ShardRequest::Set {
1973 key: "c".into(),
1974 value: Bytes::from("3"),
1975 expire: None,
1976 nx: false,
1977 xx: false,
1978 })
1979 .await
1980 .unwrap();
1981 }
1983
1984 tokio::time::sleep(Duration::from_millis(50)).await;
1986
1987 {
1989 let handle = spawn_shard(
1990 16,
1991 config,
1992 Some(pcfg),
1993 None,
1994 #[cfg(feature = "protobuf")]
1995 None,
1996 );
1997 tokio::time::sleep(Duration::from_millis(50)).await;
1999
2000 let resp = handle
2001 .send(ShardRequest::Get { key: "a".into() })
2002 .await
2003 .unwrap();
2004 match resp {
2005 ShardResponse::Value(Some(Value::String(data))) => {
2006 assert_eq!(data, Bytes::from("1"));
2007 }
2008 other => panic!("expected a=1, got {other:?}"),
2009 }
2010
2011 let resp = handle
2012 .send(ShardRequest::Get { key: "b".into() })
2013 .await
2014 .unwrap();
2015 assert!(matches!(resp, ShardResponse::Value(Some(_))));
2016
2017 let resp = handle
2018 .send(ShardRequest::Get { key: "c".into() })
2019 .await
2020 .unwrap();
2021 match resp {
2022 ShardResponse::Value(Some(Value::String(data))) => {
2023 assert_eq!(data, Bytes::from("3"));
2024 }
2025 other => panic!("expected c=3, got {other:?}"),
2026 }
2027 }
2028 }
2029
2030 #[test]
2031 fn to_aof_records_for_set() {
2032 let req = ShardRequest::Set {
2033 key: "k".into(),
2034 value: Bytes::from("v"),
2035 expire: Some(Duration::from_secs(60)),
2036 nx: false,
2037 xx: false,
2038 };
2039 let resp = ShardResponse::Ok;
2040 let record = to_aof_records(&req, &resp).into_iter().next().unwrap();
2041 match record {
2042 AofRecord::Set { key, expire_ms, .. } => {
2043 assert_eq!(key, "k");
2044 assert_eq!(expire_ms, 60_000);
2045 }
2046 other => panic!("expected Set, got {other:?}"),
2047 }
2048 }
2049
2050 #[test]
2051 fn to_aof_records_skips_failed_set() {
2052 let req = ShardRequest::Set {
2053 key: "k".into(),
2054 value: Bytes::from("v"),
2055 expire: None,
2056 nx: false,
2057 xx: false,
2058 };
2059 let resp = ShardResponse::OutOfMemory;
2060 assert!(to_aof_records(&req, &resp).is_empty());
2061 }
2062
2063 #[test]
2064 fn to_aof_records_for_del() {
2065 let req = ShardRequest::Del { key: "k".into() };
2066 let resp = ShardResponse::Bool(true);
2067 let record = to_aof_records(&req, &resp).into_iter().next().unwrap();
2068 assert!(matches!(record, AofRecord::Del { .. }));
2069 }
2070
2071 #[test]
2072 fn to_aof_records_skips_failed_del() {
2073 let req = ShardRequest::Del { key: "k".into() };
2074 let resp = ShardResponse::Bool(false);
2075 assert!(to_aof_records(&req, &resp).is_empty());
2076 }
2077
2078 #[test]
2079 fn dispatch_incr_new_key() {
2080 let mut ks = Keyspace::new();
2081 let resp = test_dispatch(&mut ks, &ShardRequest::Incr { key: "c".into() });
2082 assert!(matches!(resp, ShardResponse::Integer(1)));
2083 }
2084
2085 #[test]
2086 fn dispatch_decr_existing() {
2087 let mut ks = Keyspace::new();
2088 ks.set("n".into(), Bytes::from("10"), None);
2089 let resp = test_dispatch(&mut ks, &ShardRequest::Decr { key: "n".into() });
2090 assert!(matches!(resp, ShardResponse::Integer(9)));
2091 }
2092
2093 #[test]
2094 fn dispatch_incr_non_integer() {
2095 let mut ks = Keyspace::new();
2096 ks.set("s".into(), Bytes::from("hello"), None);
2097 let resp = test_dispatch(&mut ks, &ShardRequest::Incr { key: "s".into() });
2098 assert!(matches!(resp, ShardResponse::Err(_)));
2099 }
2100
2101 #[test]
2102 fn dispatch_incrby() {
2103 let mut ks = Keyspace::new();
2104 ks.set("n".into(), Bytes::from("10"), None);
2105 let resp = test_dispatch(
2106 &mut ks,
2107 &ShardRequest::IncrBy {
2108 key: "n".into(),
2109 delta: 5,
2110 },
2111 );
2112 assert!(matches!(resp, ShardResponse::Integer(15)));
2113 }
2114
2115 #[test]
2116 fn dispatch_decrby() {
2117 let mut ks = Keyspace::new();
2118 ks.set("n".into(), Bytes::from("10"), None);
2119 let resp = test_dispatch(
2120 &mut ks,
2121 &ShardRequest::DecrBy {
2122 key: "n".into(),
2123 delta: 3,
2124 },
2125 );
2126 assert!(matches!(resp, ShardResponse::Integer(7)));
2127 }
2128
2129 #[test]
2130 fn dispatch_incrby_new_key() {
2131 let mut ks = Keyspace::new();
2132 let resp = test_dispatch(
2133 &mut ks,
2134 &ShardRequest::IncrBy {
2135 key: "new".into(),
2136 delta: 42,
2137 },
2138 );
2139 assert!(matches!(resp, ShardResponse::Integer(42)));
2140 }
2141
2142 #[test]
2143 fn dispatch_incrbyfloat() {
2144 let mut ks = Keyspace::new();
2145 ks.set("n".into(), Bytes::from("10.5"), None);
2146 let resp = test_dispatch(
2147 &mut ks,
2148 &ShardRequest::IncrByFloat {
2149 key: "n".into(),
2150 delta: 2.3,
2151 },
2152 );
2153 match resp {
2154 ShardResponse::BulkString(val) => {
2155 let f: f64 = val.parse().unwrap();
2156 assert!((f - 12.8).abs() < 0.001);
2157 }
2158 other => panic!("expected BulkString, got {other:?}"),
2159 }
2160 }
2161
2162 #[test]
2163 fn dispatch_append() {
2164 let mut ks = Keyspace::new();
2165 ks.set("k".into(), Bytes::from("hello"), None);
2166 let resp = test_dispatch(
2167 &mut ks,
2168 &ShardRequest::Append {
2169 key: "k".into(),
2170 value: Bytes::from(" world"),
2171 },
2172 );
2173 assert!(matches!(resp, ShardResponse::Len(11)));
2174 }
2175
2176 #[test]
2177 fn dispatch_strlen() {
2178 let mut ks = Keyspace::new();
2179 ks.set("k".into(), Bytes::from("hello"), None);
2180 let resp = test_dispatch(&mut ks, &ShardRequest::Strlen { key: "k".into() });
2181 assert!(matches!(resp, ShardResponse::Len(5)));
2182 }
2183
2184 #[test]
2185 fn dispatch_strlen_missing() {
2186 let mut ks = Keyspace::new();
2187 let resp = test_dispatch(&mut ks, &ShardRequest::Strlen { key: "nope".into() });
2188 assert!(matches!(resp, ShardResponse::Len(0)));
2189 }
2190
2191 #[test]
2192 fn to_aof_records_for_append() {
2193 let req = ShardRequest::Append {
2194 key: "k".into(),
2195 value: Bytes::from("data"),
2196 };
2197 let resp = ShardResponse::Len(10);
2198 let record = to_aof_records(&req, &resp).into_iter().next().unwrap();
2199 match record {
2200 AofRecord::Append { key, value } => {
2201 assert_eq!(key, "k");
2202 assert_eq!(value, Bytes::from("data"));
2203 }
2204 other => panic!("expected Append, got {other:?}"),
2205 }
2206 }
2207
2208 #[test]
2209 fn dispatch_incrbyfloat_new_key() {
2210 let mut ks = Keyspace::new();
2211 let resp = test_dispatch(
2212 &mut ks,
2213 &ShardRequest::IncrByFloat {
2214 key: "new".into(),
2215 delta: 2.72,
2216 },
2217 );
2218 match resp {
2219 ShardResponse::BulkString(val) => {
2220 let f: f64 = val.parse().unwrap();
2221 assert!((f - 2.72).abs() < 0.001);
2222 }
2223 other => panic!("expected BulkString, got {other:?}"),
2224 }
2225 }
2226
2227 #[test]
2228 fn to_aof_records_for_incr() {
2229 let req = ShardRequest::Incr { key: "c".into() };
2230 let resp = ShardResponse::Integer(1);
2231 let record = to_aof_records(&req, &resp).into_iter().next().unwrap();
2232 assert!(matches!(record, AofRecord::Incr { .. }));
2233 }
2234
2235 #[test]
2236 fn to_aof_records_for_decr() {
2237 let req = ShardRequest::Decr { key: "c".into() };
2238 let resp = ShardResponse::Integer(-1);
2239 let record = to_aof_records(&req, &resp).into_iter().next().unwrap();
2240 assert!(matches!(record, AofRecord::Decr { .. }));
2241 }
2242
2243 #[test]
2244 fn to_aof_records_for_incrby() {
2245 let req = ShardRequest::IncrBy {
2246 key: "c".into(),
2247 delta: 5,
2248 };
2249 let resp = ShardResponse::Integer(15);
2250 let record = to_aof_records(&req, &resp).into_iter().next().unwrap();
2251 match record {
2252 AofRecord::IncrBy { key, delta } => {
2253 assert_eq!(key, "c");
2254 assert_eq!(delta, 5);
2255 }
2256 other => panic!("expected IncrBy, got {other:?}"),
2257 }
2258 }
2259
2260 #[test]
2261 fn to_aof_records_for_decrby() {
2262 let req = ShardRequest::DecrBy {
2263 key: "c".into(),
2264 delta: 3,
2265 };
2266 let resp = ShardResponse::Integer(7);
2267 let record = to_aof_records(&req, &resp).into_iter().next().unwrap();
2268 match record {
2269 AofRecord::DecrBy { key, delta } => {
2270 assert_eq!(key, "c");
2271 assert_eq!(delta, 3);
2272 }
2273 other => panic!("expected DecrBy, got {other:?}"),
2274 }
2275 }
2276
2277 #[test]
2278 fn dispatch_persist_removes_ttl() {
2279 let mut ks = Keyspace::new();
2280 ks.set(
2281 "key".into(),
2282 Bytes::from("val"),
2283 Some(Duration::from_secs(60)),
2284 );
2285
2286 let resp = test_dispatch(&mut ks, &ShardRequest::Persist { key: "key".into() });
2287 assert!(matches!(resp, ShardResponse::Bool(true)));
2288
2289 let resp = test_dispatch(&mut ks, &ShardRequest::Ttl { key: "key".into() });
2290 assert!(matches!(resp, ShardResponse::Ttl(TtlResult::NoExpiry)));
2291 }
2292
2293 #[test]
2294 fn dispatch_persist_missing_key() {
2295 let mut ks = Keyspace::new();
2296 let resp = test_dispatch(&mut ks, &ShardRequest::Persist { key: "nope".into() });
2297 assert!(matches!(resp, ShardResponse::Bool(false)));
2298 }
2299
2300 #[test]
2301 fn dispatch_pttl() {
2302 let mut ks = Keyspace::new();
2303 ks.set(
2304 "key".into(),
2305 Bytes::from("val"),
2306 Some(Duration::from_secs(60)),
2307 );
2308
2309 let resp = test_dispatch(&mut ks, &ShardRequest::Pttl { key: "key".into() });
2310 match resp {
2311 ShardResponse::Ttl(TtlResult::Milliseconds(ms)) => {
2312 assert!(ms > 59_000 && ms <= 60_000);
2313 }
2314 other => panic!("expected Ttl(Milliseconds), got {other:?}"),
2315 }
2316 }
2317
2318 #[test]
2319 fn dispatch_pttl_missing() {
2320 let mut ks = Keyspace::new();
2321 let resp = test_dispatch(&mut ks, &ShardRequest::Pttl { key: "nope".into() });
2322 assert!(matches!(resp, ShardResponse::Ttl(TtlResult::NotFound)));
2323 }
2324
2325 #[test]
2326 fn dispatch_pexpire() {
2327 let mut ks = Keyspace::new();
2328 ks.set("key".into(), Bytes::from("val"), None);
2329
2330 let resp = test_dispatch(
2331 &mut ks,
2332 &ShardRequest::Pexpire {
2333 key: "key".into(),
2334 milliseconds: 5000,
2335 },
2336 );
2337 assert!(matches!(resp, ShardResponse::Bool(true)));
2338
2339 let resp = test_dispatch(&mut ks, &ShardRequest::Pttl { key: "key".into() });
2340 match resp {
2341 ShardResponse::Ttl(TtlResult::Milliseconds(ms)) => {
2342 assert!(ms > 4000 && ms <= 5000);
2343 }
2344 other => panic!("expected Ttl(Milliseconds), got {other:?}"),
2345 }
2346 }
2347
2348 #[test]
2349 fn to_aof_records_for_persist() {
2350 let req = ShardRequest::Persist { key: "k".into() };
2351 let resp = ShardResponse::Bool(true);
2352 let record = to_aof_records(&req, &resp).into_iter().next().unwrap();
2353 assert!(matches!(record, AofRecord::Persist { .. }));
2354 }
2355
2356 #[test]
2357 fn to_aof_records_skips_failed_persist() {
2358 let req = ShardRequest::Persist { key: "k".into() };
2359 let resp = ShardResponse::Bool(false);
2360 assert!(to_aof_records(&req, &resp).is_empty());
2361 }
2362
2363 #[test]
2364 fn to_aof_records_for_pexpire() {
2365 let req = ShardRequest::Pexpire {
2366 key: "k".into(),
2367 milliseconds: 5000,
2368 };
2369 let resp = ShardResponse::Bool(true);
2370 let record = to_aof_records(&req, &resp).into_iter().next().unwrap();
2371 match record {
2372 AofRecord::Pexpire { key, milliseconds } => {
2373 assert_eq!(key, "k");
2374 assert_eq!(milliseconds, 5000);
2375 }
2376 other => panic!("expected Pexpire, got {other:?}"),
2377 }
2378 }
2379
2380 #[test]
2381 fn to_aof_records_skips_failed_pexpire() {
2382 let req = ShardRequest::Pexpire {
2383 key: "k".into(),
2384 milliseconds: 5000,
2385 };
2386 let resp = ShardResponse::Bool(false);
2387 assert!(to_aof_records(&req, &resp).is_empty());
2388 }
2389
2390 #[test]
2391 fn dispatch_set_nx_when_key_missing() {
2392 let mut ks = Keyspace::new();
2393 let resp = test_dispatch(
2394 &mut ks,
2395 &ShardRequest::Set {
2396 key: "k".into(),
2397 value: Bytes::from("v"),
2398 expire: None,
2399 nx: true,
2400 xx: false,
2401 },
2402 );
2403 assert!(matches!(resp, ShardResponse::Ok));
2404 assert!(ks.exists("k"));
2405 }
2406
2407 #[test]
2408 fn dispatch_set_nx_when_key_exists() {
2409 let mut ks = Keyspace::new();
2410 ks.set("k".into(), Bytes::from("old"), None);
2411
2412 let resp = test_dispatch(
2413 &mut ks,
2414 &ShardRequest::Set {
2415 key: "k".into(),
2416 value: Bytes::from("new"),
2417 expire: None,
2418 nx: true,
2419 xx: false,
2420 },
2421 );
2422 assert!(matches!(resp, ShardResponse::Value(None)));
2424 match ks.get("k").unwrap() {
2426 Some(Value::String(data)) => assert_eq!(data, Bytes::from("old")),
2427 other => panic!("expected old value, got {other:?}"),
2428 }
2429 }
2430
2431 #[test]
2432 fn dispatch_set_xx_when_key_exists() {
2433 let mut ks = Keyspace::new();
2434 ks.set("k".into(), Bytes::from("old"), None);
2435
2436 let resp = test_dispatch(
2437 &mut ks,
2438 &ShardRequest::Set {
2439 key: "k".into(),
2440 value: Bytes::from("new"),
2441 expire: None,
2442 nx: false,
2443 xx: true,
2444 },
2445 );
2446 assert!(matches!(resp, ShardResponse::Ok));
2447 match ks.get("k").unwrap() {
2448 Some(Value::String(data)) => assert_eq!(data, Bytes::from("new")),
2449 other => panic!("expected new value, got {other:?}"),
2450 }
2451 }
2452
2453 #[test]
2454 fn dispatch_set_xx_when_key_missing() {
2455 let mut ks = Keyspace::new();
2456 let resp = test_dispatch(
2457 &mut ks,
2458 &ShardRequest::Set {
2459 key: "k".into(),
2460 value: Bytes::from("v"),
2461 expire: None,
2462 nx: false,
2463 xx: true,
2464 },
2465 );
2466 assert!(matches!(resp, ShardResponse::Value(None)));
2468 assert!(!ks.exists("k"));
2469 }
2470
2471 #[test]
2472 fn to_aof_records_skips_nx_blocked_set() {
2473 let req = ShardRequest::Set {
2474 key: "k".into(),
2475 value: Bytes::from("v"),
2476 expire: None,
2477 nx: true,
2478 xx: false,
2479 };
2480 let resp = ShardResponse::Value(None);
2482 assert!(to_aof_records(&req, &resp).is_empty());
2483 }
2484
2485 #[test]
2486 fn dispatch_flushdb_clears_all_keys() {
2487 let mut ks = Keyspace::new();
2488 ks.set("a".into(), Bytes::from("1"), None);
2489 ks.set("b".into(), Bytes::from("2"), None);
2490
2491 assert_eq!(ks.len(), 2);
2492
2493 let resp = test_dispatch(&mut ks, &ShardRequest::FlushDb);
2494 assert!(matches!(resp, ShardResponse::Ok));
2495 assert_eq!(ks.len(), 0);
2496 }
2497
2498 #[test]
2499 fn dispatch_scan_returns_keys() {
2500 let mut ks = Keyspace::new();
2501 ks.set("user:1".into(), Bytes::from("a"), None);
2502 ks.set("user:2".into(), Bytes::from("b"), None);
2503 ks.set("item:1".into(), Bytes::from("c"), None);
2504
2505 let resp = test_dispatch(
2506 &mut ks,
2507 &ShardRequest::Scan {
2508 cursor: 0,
2509 count: 10,
2510 pattern: None,
2511 },
2512 );
2513
2514 match resp {
2515 ShardResponse::Scan { cursor, keys } => {
2516 assert_eq!(cursor, 0); assert_eq!(keys.len(), 3);
2518 }
2519 _ => panic!("expected Scan response"),
2520 }
2521 }
2522
2523 #[test]
2524 fn dispatch_scan_with_pattern() {
2525 let mut ks = Keyspace::new();
2526 ks.set("user:1".into(), Bytes::from("a"), None);
2527 ks.set("user:2".into(), Bytes::from("b"), None);
2528 ks.set("item:1".into(), Bytes::from("c"), None);
2529
2530 let resp = test_dispatch(
2531 &mut ks,
2532 &ShardRequest::Scan {
2533 cursor: 0,
2534 count: 10,
2535 pattern: Some("user:*".into()),
2536 },
2537 );
2538
2539 match resp {
2540 ShardResponse::Scan { cursor, keys } => {
2541 assert_eq!(cursor, 0);
2542 assert_eq!(keys.len(), 2);
2543 for k in &keys {
2544 assert!(k.starts_with("user:"));
2545 }
2546 }
2547 _ => panic!("expected Scan response"),
2548 }
2549 }
2550
/// An HSET request whose response reports one field written is converted
/// into an `AofRecord::HSet` for the same key carrying one field.
#[test]
fn to_aof_records_for_hset() {
    let req = ShardRequest::HSet {
        key: "h".into(),
        fields: vec![("f1".into(), Bytes::from("v1"))],
    };
    let resp = ShardResponse::Len(1);
    let record = to_aof_records(&req, &resp)
        .into_iter()
        .next()
        .expect("HSET with a written field should produce an AOF record");
    match record {
        AofRecord::HSet { key, fields } => {
            assert_eq!(key, "h");
            assert_eq!(fields.len(), 1);
        }
        // Show what was produced instead, to make failures diagnosable.
        other => panic!("expected HSet record, got {other:?}"),
    }
}
2567
/// An HDEL whose response reports two removed fields is converted into an
/// `AofRecord::HDel` for the same key listing two fields.
#[test]
fn to_aof_records_for_hdel() {
    let req = ShardRequest::HDel {
        key: "h".into(),
        fields: vec!["f1".into(), "f2".into()],
    };
    let resp = ShardResponse::HDelLen {
        count: 2,
        removed: vec!["f1".into(), "f2".into()],
    };
    let record = to_aof_records(&req, &resp)
        .into_iter()
        .next()
        .expect("HDEL that removed fields should produce an AOF record");
    match record {
        AofRecord::HDel { key, fields } => {
            assert_eq!(key, "h");
            assert_eq!(fields.len(), 2);
        }
        // Show what was produced instead, to make failures diagnosable.
        other => panic!("expected HDel record, got {other:?}"),
    }
}
2587
/// An HDEL that removed nothing must not emit any AOF record: the response
/// reports a count of zero and an empty `removed` list.
#[test]
fn to_aof_records_skips_hdel_when_none_removed() {
    let no_op_resp = ShardResponse::HDelLen {
        count: 0,
        removed: vec![],
    };
    let hdel_req = ShardRequest::HDel {
        key: "h".into(),
        fields: vec!["f1".into()],
    };

    let records = to_aof_records(&hdel_req, &no_op_resp);
    assert!(records.is_empty());
}
2600
/// HINCRBY is persisted as an `AofRecord::HIncrBy` carrying the request's
/// key, field and delta (not the resulting value from the response).
#[test]
fn to_aof_records_for_hincrby() {
    let req = ShardRequest::HIncrBy {
        key: "h".into(),
        field: "counter".into(),
        delta: 5,
    };
    // The new value (10) differs from the delta (5), so the assertion below
    // distinguishes "logged the delta" from "logged the result".
    let resp = ShardResponse::Integer(10);
    let record = to_aof_records(&req, &resp)
        .into_iter()
        .next()
        .expect("HINCRBY should produce an AOF record");
    match record {
        AofRecord::HIncrBy { key, field, delta } => {
            assert_eq!(key, "h");
            assert_eq!(field, "counter");
            assert_eq!(delta, 5);
        }
        // Show what was produced instead, to make failures diagnosable.
        other => panic!("expected HIncrBy record, got {other:?}"),
    }
}
2619
/// An SADD whose response reports two members added is converted into an
/// `AofRecord::SAdd` for the same key with two members.
#[test]
fn to_aof_records_for_sadd() {
    let req = ShardRequest::SAdd {
        key: "s".into(),
        members: vec!["m1".into(), "m2".into()],
    };
    let resp = ShardResponse::Len(2);
    let record = to_aof_records(&req, &resp)
        .into_iter()
        .next()
        .expect("SADD that added members should produce an AOF record");
    match record {
        AofRecord::SAdd { key, members } => {
            assert_eq!(key, "s");
            assert_eq!(members.len(), 2);
        }
        // Show what was produced instead, to make failures diagnosable.
        other => panic!("expected SAdd record, got {other:?}"),
    }
}
2636
/// When SADD reports `Len(0)` (no member added), nothing is written to the AOF.
#[test]
fn to_aof_records_skips_sadd_when_none_added() {
    let sadd_req = ShardRequest::SAdd {
        key: "s".into(),
        members: vec!["m1".into()],
    };
    let nothing_added = ShardResponse::Len(0);

    let records = to_aof_records(&sadd_req, &nothing_added);
    assert!(records.is_empty());
}
2646
/// An SREM whose response reports one member removed is converted into an
/// `AofRecord::SRem` for the same key with one member.
#[test]
fn to_aof_records_for_srem() {
    let req = ShardRequest::SRem {
        key: "s".into(),
        members: vec!["m1".into()],
    };
    let resp = ShardResponse::Len(1);
    let record = to_aof_records(&req, &resp)
        .into_iter()
        .next()
        .expect("SREM that removed a member should produce an AOF record");
    match record {
        AofRecord::SRem { key, members } => {
            assert_eq!(key, "s");
            assert_eq!(members.len(), 1);
        }
        // Show what was produced instead, to make failures diagnosable.
        other => panic!("expected SRem record, got {other:?}"),
    }
}
2663
/// When SREM reports `Len(0)` (no member removed), nothing is written to the AOF.
#[test]
fn to_aof_records_skips_srem_when_none_removed() {
    let srem_req = ShardRequest::SRem {
        key: "s".into(),
        members: vec!["m1".into()],
    };
    let nothing_removed = ShardResponse::Len(0);

    let records = to_aof_records(&srem_req, &nothing_removed);
    assert!(records.is_empty());
}
2673
/// KEYS with a glob pattern returns exactly the matching keys.
#[test]
fn dispatch_keys() {
    let mut ks = Keyspace::new();
    // Two keys match `user:*`; `item:1` is the non-matching control.
    for (name, val) in [("user:1", "a"), ("user:2", "b"), ("item:1", "c")] {
        ks.set(name.into(), Bytes::from(val), None);
    }

    let keys_req = ShardRequest::Keys {
        pattern: "user:*".into(),
    };
    let outcome = test_dispatch(&mut ks, &keys_req);

    match outcome {
        ShardResponse::StringArray(mut found) => {
            // Order is not part of the contract under test.
            found.sort();
            assert_eq!(found, vec!["user:1", "user:2"]);
        }
        other => panic!("expected StringArray, got {other:?}"),
    }
}
2694
/// RENAME of an existing key answers `Ok`; afterwards the source key is gone
/// and the target key exists.
#[test]
fn dispatch_rename() {
    let mut ks = Keyspace::new();
    ks.set("old".into(), Bytes::from("value"), None);

    let rename_req = ShardRequest::Rename {
        key: "old".into(),
        newkey: "new".into(),
    };
    let outcome = test_dispatch(&mut ks, &rename_req);
    assert!(matches!(outcome, ShardResponse::Ok));

    let (src_present, dst_present) = (ks.exists("old"), ks.exists("new"));
    assert!(!src_present);
    assert!(dst_present);
}
2710
/// RENAME of a key that was never set yields an error response.
#[test]
fn dispatch_rename_missing_key() {
    let rename_req = ShardRequest::Rename {
        key: "missing".into(),
        newkey: "new".into(),
    };
    let mut empty_ks = Keyspace::new();

    let outcome = test_dispatch(&mut empty_ks, &rename_req);
    assert!(matches!(outcome, ShardResponse::Err(_)));
}
2723
/// A successful RENAME is persisted as an `AofRecord::Rename` carrying both
/// the source and destination key names.
#[test]
fn to_aof_records_for_rename() {
    let resp = ShardResponse::Ok;
    let req = ShardRequest::Rename {
        key: "old".into(),
        newkey: "new".into(),
    };

    let mut produced = to_aof_records(&req, &resp).into_iter();
    let record = produced.next().unwrap();

    match record {
        AofRecord::Rename {
            key: from,
            newkey: to,
        } => {
            assert_eq!(from, "old");
            assert_eq!(to, "new");
        }
        other => panic!("expected Rename, got {other:?}"),
    }
}
2740
/// A VADDBATCH is expanded into one `AofRecord::VAdd` per applied entry, in
/// order, each carrying the batch-wide key and index parameters.
#[test]
#[cfg(feature = "vector")]
fn to_aof_records_for_vadd_batch() {
    let req = ShardRequest::VAddBatch {
        key: "vecs".into(),
        entries: vec![
            ("a".into(), vec![1.0, 2.0]),
            ("b".into(), vec![3.0, 4.0]),
            ("c".into(), vec![5.0, 6.0]),
        ],
        dim: 2,
        metric: 0,
        quantization: 0,
        connectivity: 16,
        expansion_add: 64,
    };
    let resp = ShardResponse::VAddBatchResult {
        added_count: 3,
        applied: vec![
            ("a".into(), vec![1.0, 2.0]),
            ("b".into(), vec![3.0, 4.0]),
            ("c".into(), vec![5.0, 6.0]),
        ],
    };

    let records = to_aof_records(&req, &resp);
    assert_eq!(records.len(), 3);

    // Pair each record with the element name expected at that position; the
    // length check above guarantees the zip covers every record.
    for (record, expected_element) in records.iter().zip(["a", "b", "c"]) {
        match record {
            AofRecord::VAdd {
                key,
                element,
                metric,
                quantization,
                connectivity,
                expansion_add,
                ..
            } => {
                // Batch-wide parameters are checked first, then the element.
                assert_eq!(key, "vecs");
                assert_eq!(*metric, 0);
                assert_eq!(*quantization, 0);
                assert_eq!(*connectivity, 16);
                assert_eq!(*expansion_add, 64);
                assert_eq!(element, expected_element);
            }
            other => panic!("expected VAdd, got {other:?}"),
        }
    }
}
2794}