1use ember_protocol::command::{BitOpKind, BitRange, BitRangeUnit};
8
9use super::*;
10
11impl Keyspace {
12 pub fn getbit(&mut self, key: &str, offset: u64) -> Result<u8, WrongType> {
18 if self.remove_if_expired(key) {
19 return Ok(0);
20 }
21 match self.entries.get(key) {
22 None => Ok(0),
23 Some(e) => match &e.value {
24 Value::String(data) => {
25 let byte_idx = (offset / 8) as usize;
26 if byte_idx >= data.len() {
27 return Ok(0);
28 }
29 let bit_pos = 7 - (offset % 8) as u32;
30 Ok((data[byte_idx] >> bit_pos) & 1)
31 }
32 _ => Err(WrongType),
33 },
34 }
35 }
36
37 pub fn setbit(&mut self, key: &str, offset: u64, value: u8) -> Result<u8, WriteError> {
44 self.remove_if_expired(key);
45
46 let byte_idx = (offset / 8) as usize;
47 let bit_pos = 7 - (offset % 8) as u32;
48 let mask = 1u8 << bit_pos;
49
50 let (existing, expire) = match self.entries.get(key) {
51 Some(entry) => match &entry.value {
52 Value::String(data) => {
53 let expire = time::remaining_ms(entry.expires_at_ms).map(Duration::from_millis);
54 (data.clone(), expire)
55 }
56 _ => return Err(WriteError::WrongType),
57 },
58 None => (Bytes::new(), None),
59 };
60
61 let old_bit = if byte_idx < existing.len() {
62 (existing[byte_idx] >> bit_pos) & 1
63 } else {
64 0
65 };
66
67 let new_len = existing.len().max(byte_idx + 1);
69 let mut buf = existing.to_vec();
70 buf.resize(new_len, 0);
71
72 if value == 1 {
73 buf[byte_idx] |= mask;
74 } else {
75 buf[byte_idx] &= !mask;
76 }
77
78 match self.set(key.to_owned(), Bytes::from(buf), expire, false, false) {
79 SetResult::Ok | SetResult::Blocked => Ok(old_bit),
80 SetResult::OutOfMemory => Err(WriteError::OutOfMemory),
81 }
82 }
83
84 pub fn bitcount(&mut self, key: &str, range: Option<BitRange>) -> Result<u64, WrongType> {
92 if self.remove_if_expired(key) {
93 return Ok(0);
94 }
95 let data = match self.entries.get(key) {
96 None => return Ok(0),
97 Some(e) => match &e.value {
98 Value::String(b) => b.clone(),
99 _ => return Err(WrongType),
100 },
101 };
102
103 match range {
104 None => Ok(data.iter().map(|b| b.count_ones() as u64).sum()),
105 Some(r) if r.unit == BitRangeUnit::Bit => {
106 let len_bits = data.len() as i64 * 8;
108 let start = normalize_bit_index(r.start, len_bits).min(len_bits);
109 let end = normalize_bit_index(r.end, len_bits).min(len_bits - 1);
110 if start > end {
111 return Ok(0);
112 }
113 let mut count = 0u64;
114 for bit_idx in start..=end {
115 let byte_idx = (bit_idx / 8) as usize;
116 let bit_pos = 7 - (bit_idx % 8) as u32;
117 count += ((data[byte_idx] >> bit_pos) & 1) as u64;
118 }
119 Ok(count)
120 }
121 Some(r) => {
122 let slice = bit_range_slice(&data, r);
123 Ok(slice.iter().map(|b| b.count_ones() as u64).sum())
124 }
125 }
126 }
127
128 pub fn bitpos(
137 &mut self,
138 key: &str,
139 bit: u8,
140 range: Option<BitRange>,
141 ) -> Result<i64, WrongType> {
142 if self.remove_if_expired(key) {
143 return Ok(if bit == 0 { 0 } else { -1 });
145 }
146 let data = match self.entries.get(key) {
147 None => {
148 return Ok(if bit == 0 { 0 } else { -1 });
149 }
150 Some(e) => match &e.value {
151 Value::String(b) => b.clone(),
152 _ => return Err(WrongType),
153 },
154 };
155
156 let has_explicit_end = range.map(|r| r.end != -1).unwrap_or(false);
158
159 let (slice, bit_offset) = match range {
160 None => (&data[..], 0i64),
161 Some(r) if r.unit == BitRangeUnit::Bit => {
162 let len_bits = data.len() as i64 * 8;
164 let start = normalize_bit_index(r.start, len_bits).min(len_bits);
165 let end = normalize_bit_index(r.end, len_bits).min(len_bits - 1);
166 if start > end {
167 return Ok(-1);
168 }
169 for bit_idx in start..=end {
171 let byte_idx = (bit_idx / 8) as usize;
172 let bit_pos = 7 - (bit_idx % 8) as u32;
173 let found = (data[byte_idx] >> bit_pos) & 1;
174 if found == bit {
175 return Ok(bit_idx);
176 }
177 }
178 return Ok(-1);
179 }
180 Some(r) => {
181 let (s, e) = resolve_byte_range(r.start, r.end, data.len());
183 if s >= data.len() {
184 return Ok(-1);
185 }
186 let end = e.min(data.len() - 1);
187 (&data[s..=end], (s as i64) * 8)
188 }
189 };
190
191 for (i, &byte) in slice.iter().enumerate() {
193 let b = if bit == 1 { byte } else { !byte };
194 if b != 0 {
195 let bit_in_byte = b.leading_zeros() as i64;
196 return Ok(bit_offset + (i as i64) * 8 + bit_in_byte);
197 }
198 }
199
200 if bit == 0 && !has_explicit_end {
203 Ok((data.len() as i64) * 8)
204 } else {
205 Ok(-1)
206 }
207 }
208
209 pub fn bitop(
215 &mut self,
216 op: BitOpKind,
217 dest: String,
218 keys: &[String],
219 ) -> Result<usize, WriteError> {
220 let mut sources: Vec<Bytes> = Vec::with_capacity(keys.len());
222 for key in keys {
223 self.remove_if_expired(key);
224 match self.entries.get(key.as_str()) {
225 None => sources.push(Bytes::new()),
226 Some(e) => match &e.value {
227 Value::String(b) => sources.push(b.clone()),
228 _ => return Err(WriteError::WrongType),
229 },
230 }
231 }
232
233 let result_len = sources.iter().map(|s| s.len()).max().unwrap_or(0);
234 let mut result = vec![0u8; result_len];
235
236 match op {
237 BitOpKind::Not => {
238 let src = sources.first().map(|b| b.as_ref()).unwrap_or(&[]);
240 for (i, b) in result.iter_mut().enumerate() {
241 *b = if i < src.len() { !src[i] } else { 0xFF };
242 }
243 }
244 BitOpKind::And => {
245 if let Some(first) = sources.first() {
247 for (i, b) in result.iter_mut().enumerate() {
248 *b = if i < first.len() { first[i] } else { 0 };
249 }
250 }
251 for src in sources.iter().skip(1) {
252 for (i, b) in result.iter_mut().enumerate() {
253 let s = if i < src.len() { src[i] } else { 0 };
254 *b &= s;
255 }
256 }
257 }
258 BitOpKind::Or => {
259 for src in &sources {
260 for (i, b) in result.iter_mut().enumerate() {
261 if i < src.len() {
262 *b |= src[i];
263 }
264 }
265 }
266 }
267 BitOpKind::Xor => {
268 for src in &sources {
269 for (i, b) in result.iter_mut().enumerate() {
270 if i < src.len() {
271 *b ^= src[i];
272 }
273 }
274 }
275 }
276 }
277
278 match self.set(dest, Bytes::from(result), None, false, false) {
281 SetResult::Ok | SetResult::Blocked => Ok(result_len),
282 SetResult::OutOfMemory => Err(WriteError::OutOfMemory),
283 }
284 }
285}
286
/// Resolves a Redis-style inclusive byte range against a string of `len` bytes.
///
/// Negative indexes count back from the end (`-1` is the last byte). Both
/// returned indexes are clamped into `0..=len`. A value of `len` therefore
/// means "past the end", and callers still clamp `end` to `len - 1` before
/// slicing. Clamping before the `i64 -> usize` conversion keeps the cast
/// lossless on every target. Without it, a huge index could truncate into a
/// small in-range one where `usize` is 32 bits.
fn resolve_byte_range(start: i64, end: i64, len: usize) -> (usize, usize) {
    let len_i = len as i64;
    let resolve = |idx: i64| -> usize {
        let idx = if idx < 0 { len_i + idx } else { idx };
        idx.clamp(0, len_i) as usize
    };
    (resolve(start), resolve(end))
}
302
/// Resolves a possibly-negative bit index against a string of `len_bits` bits.
///
/// Non-negative indexes are returned unchanged, and callers clamp the upper
/// end themselves. Negative indexes count back from the end (`-1` is the last
/// bit) and saturate at 0.
fn normalize_bit_index(idx: i64, len_bits: i64) -> i64 {
    match idx {
        non_negative if non_negative >= 0 => non_negative,
        from_end => std::cmp::max(len_bits + from_end, 0),
    }
}
311
312fn bit_range_slice(data: &[u8], range: BitRange) -> &[u8] {
317 match range.unit {
318 BitRangeUnit::Byte => {
319 let (s, e) = resolve_byte_range(range.start, range.end, data.len());
320 if s >= data.len() {
321 return &[];
322 }
323 let end = e.min(data.len() - 1);
324 if s > end {
325 &[]
326 } else {
327 &data[s..=end]
328 }
329 }
330 BitRangeUnit::Bit => {
331 let len_bits = data.len() as i64 * 8;
333 let start_bit = normalize_bit_index(range.start, len_bits).min(len_bits);
334 let end_bit = normalize_bit_index(range.end, len_bits).min(len_bits - 1);
335 if start_bit > end_bit || data.is_empty() {
336 return &[];
337 }
338 let start_byte = (start_bit / 8) as usize;
341 let end_byte = (end_bit / 8) as usize;
342 &data[start_byte..=end_byte.min(data.len() - 1)]
343 }
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    // ---- helpers --------------------------------------------------------

    /// Stores `bytes` as a plain string (no TTL) under `key`.
    fn put(store: &mut Keyspace, key: &str, bytes: &[u8]) {
        let _ = store.set(key.to_owned(), Bytes::from(bytes.to_vec()), None, false, false);
    }

    /// Reads the string stored under `key`, panicking if it is anything else.
    fn string_at(store: &mut Keyspace, key: &str) -> Bytes {
        match store.get(key).unwrap() {
            Some(Value::String(b)) => b.clone(),
            other => panic!("expected String, got {other:?}"),
        }
    }

    /// Builds a keyspace whose `"list"` key holds a list, for wrong-type checks.
    fn keyspace_with_list() -> Keyspace {
        let mut store = Keyspace::new();
        store.lpush("list", &[Bytes::from("a")]).unwrap();
        store
    }

    /// Shorthand for an optional range argument.
    fn span(start: i64, end: i64, unit: BitRangeUnit) -> Option<BitRange> {
        Some(BitRange { start, end, unit })
    }

    /// Converts key names into the owned form `bitop` expects.
    fn names(keys: &[&str]) -> Vec<String> {
        keys.iter().map(|&k| k.to_owned()).collect()
    }

    // ---- GETBIT ---------------------------------------------------------

    #[test]
    fn getbit_missing_key_returns_zero() {
        let mut store = Keyspace::new();
        for offset in [0, 100] {
            assert_eq!(store.getbit("nope", offset).unwrap(), 0);
        }
    }

    #[test]
    fn getbit_reads_msb_first() {
        let mut store = Keyspace::new();
        for (byte, expected) in [(0xFFu8, 1u8), (0x00, 0)] {
            put(&mut store, "k", &[byte]);
            for offset in 0..8 {
                assert_eq!(store.getbit("k", offset).unwrap(), expected, "offset {offset}");
            }
        }
    }

    #[test]
    fn getbit_big_endian_ordering() {
        let mut store = Keyspace::new();

        // High bit set: only offset 0 reads as 1.
        put(&mut store, "k", &[0x80]);
        assert_eq!(store.getbit("k", 0).unwrap(), 1);
        assert_eq!(store.getbit("k", 1).unwrap(), 0);

        // Low bit set: only offset 7 reads as 1.
        put(&mut store, "k", &[0x01]);
        assert_eq!(store.getbit("k", 7).unwrap(), 1);
        assert_eq!(store.getbit("k", 0).unwrap(), 0);
    }

    #[test]
    fn getbit_beyond_string_returns_zero() {
        let mut store = Keyspace::new();
        put(&mut store, "k", &[0xFF]);
        assert_eq!(store.getbit("k", 8).unwrap(), 0);
    }

    #[test]
    fn getbit_wrong_type() {
        assert!(keyspace_with_list().getbit("list", 0).is_err());
    }

    // ---- SETBIT ---------------------------------------------------------

    #[test]
    fn setbit_returns_old_bit() {
        let mut store = Keyspace::new();
        // Each step: (value written, previous bit expected back).
        for (value, previous) in [(1, 0), (1, 1), (0, 1), (0, 0)] {
            assert_eq!(store.setbit("k", 7, value).unwrap(), previous);
        }
    }

    #[test]
    fn setbit_roundtrip_with_getbit() {
        let mut store = Keyspace::new();
        store.setbit("k", 10, 1).unwrap();
        assert_eq!(store.getbit("k", 10).unwrap(), 1);
        assert_eq!(store.getbit("k", 0).unwrap(), 0);
    }

    #[test]
    fn setbit_extends_string() {
        let mut store = Keyspace::new();
        store.setbit("k", 15, 1).unwrap();
        assert_eq!(string_at(&mut store, "k").len(), 2);
        assert_eq!(store.getbit("k", 15).unwrap(), 1);
    }

    #[test]
    fn setbit_preserves_ttl() {
        let mut store = Keyspace::new();
        let _ = store.set(
            "k".to_owned(),
            Bytes::from(vec![0u8]),
            Some(Duration::from_secs(60)),
            false,
            false,
        );
        store.setbit("k", 0, 1).unwrap();
        assert!(matches!(store.ttl("k"), TtlResult::Seconds(_)));
    }

    #[test]
    fn setbit_wrong_type() {
        assert!(keyspace_with_list().setbit("list", 0, 1).is_err());
    }

    // ---- BITCOUNT -------------------------------------------------------

    #[test]
    fn bitcount_missing_key_returns_zero() {
        assert_eq!(Keyspace::new().bitcount("nope", None).unwrap(), 0);
    }

    #[test]
    fn bitcount_full_string() {
        let mut store = Keyspace::new();
        put(&mut store, "k", &[0xFF, 0x0F]);
        assert_eq!(store.bitcount("k", None).unwrap(), 12);
    }

    #[test]
    fn bitcount_byte_range() {
        let mut store = Keyspace::new();
        put(&mut store, "k", &[0xFF, 0x00, 0xFF]);
        assert_eq!(store.bitcount("k", span(0, 0, BitRangeUnit::Byte)).unwrap(), 8);
        assert_eq!(store.bitcount("k", span(0, 1, BitRangeUnit::Byte)).unwrap(), 8);
    }

    #[test]
    fn bitcount_bit_range() {
        let mut store = Keyspace::new();
        put(&mut store, "k", &[0xFF]);
        assert_eq!(store.bitcount("k", span(0, 7, BitRangeUnit::Bit)).unwrap(), 8);
        assert_eq!(store.bitcount("k", span(0, 3, BitRangeUnit::Bit)).unwrap(), 4);
    }

    #[test]
    fn bitcount_wrong_type() {
        assert!(keyspace_with_list().bitcount("list", None).is_err());
    }

    // ---- BITPOS ---------------------------------------------------------

    #[test]
    fn bitpos_missing_key_bit1_returns_minus_one() {
        assert_eq!(Keyspace::new().bitpos("nope", 1, None).unwrap(), -1);
    }

    #[test]
    fn bitpos_missing_key_bit0_returns_zero() {
        assert_eq!(Keyspace::new().bitpos("nope", 0, None).unwrap(), 0);
    }

    #[test]
    fn bitpos_find_first_set_bit() {
        let mut store = Keyspace::new();
        put(&mut store, "k", &[0x00, 0x01]);
        assert_eq!(store.bitpos("k", 1, None).unwrap(), 15);
    }

    #[test]
    fn bitpos_find_first_clear_bit_in_all_ones() {
        let mut store = Keyspace::new();
        put(&mut store, "k", &[0xFF, 0xFF]);
        assert_eq!(store.bitpos("k", 0, None).unwrap(), 16);
    }

    #[test]
    fn bitpos_wrong_type() {
        assert!(keyspace_with_list().bitpos("list", 1, None).is_err());
    }

    // ---- BITOP ----------------------------------------------------------

    #[test]
    fn bitop_and() {
        let mut store = Keyspace::new();
        put(&mut store, "a", &[0xFF, 0x0F]);
        put(&mut store, "b", &[0x0F, 0xFF]);
        let written = store
            .bitop(BitOpKind::And, "dest".to_owned(), &names(&["a", "b"]))
            .unwrap();
        assert_eq!(written, 2);
        let out = string_at(&mut store, "dest");
        assert_eq!(&out[..2], &[0x0F, 0x0F]);
    }

    #[test]
    fn bitop_or() {
        let mut store = Keyspace::new();
        put(&mut store, "a", &[0xF0]);
        put(&mut store, "b", &[0x0F]);
        store
            .bitop(BitOpKind::Or, "dest".to_owned(), &names(&["a", "b"]))
            .unwrap();
        assert_eq!(string_at(&mut store, "dest")[0], 0xFF);
    }

    #[test]
    fn bitop_xor() {
        let mut store = Keyspace::new();
        put(&mut store, "a", &[0xFF]);
        put(&mut store, "b", &[0xFF]);
        store
            .bitop(BitOpKind::Xor, "dest".to_owned(), &names(&["a", "b"]))
            .unwrap();
        assert_eq!(string_at(&mut store, "dest")[0], 0x00);
    }

    #[test]
    fn bitop_not() {
        let mut store = Keyspace::new();
        put(&mut store, "src", &[0xF0, 0x0F]);
        let written = store
            .bitop(BitOpKind::Not, "dest".to_owned(), &names(&["src"]))
            .unwrap();
        assert_eq!(written, 2);
        let out = string_at(&mut store, "dest");
        assert_eq!(&out[..2], &[0x0F, 0xF0]);
    }

    #[test]
    fn bitop_wrong_type() {
        let outcome = keyspace_with_list().bitop(BitOpKind::And, "dest".to_owned(), &names(&["list"]));
        assert!(outcome.is_err());
    }

    #[test]
    fn bitop_extends_to_longest_source() {
        let mut store = Keyspace::new();
        put(&mut store, "a", &[0xFF, 0xFF, 0xFF]);
        put(&mut store, "b", &[0xFF]);
        let written = store
            .bitop(BitOpKind::Or, "dest".to_owned(), &names(&["a", "b"]))
            .unwrap();
        assert_eq!(written, 3);
    }
}