Skip to main content

ember_core/keyspace/
bitmap.rs

1//! Bitmap operations on string-typed keys.
2//!
3//! Bitmaps are not a separate data type — they use `Value::String(Bytes)`
4//! with big-endian bit ordering (Redis compatible): bit 0 is the most
5//! significant bit of byte 0.
6
7use ember_protocol::command::{BitOpKind, BitRange, BitRangeUnit};
8
9use super::*;
10
11impl Keyspace {
12    /// Returns the bit at `offset` in the string stored at `key`.
13    ///
14    /// Bit ordering is big-endian: byte `offset / 8`, bit position
15    /// `7 - (offset % 8)`. Returns 0 for missing keys or offsets
16    /// beyond the string's length.
17    pub fn getbit(&mut self, key: &str, offset: u64) -> Result<u8, WrongType> {
18        if self.remove_if_expired(key) {
19            return Ok(0);
20        }
21        match self.entries.get(key) {
22            None => Ok(0),
23            Some(e) => match &e.value {
24                Value::String(data) => {
25                    let byte_idx = (offset / 8) as usize;
26                    if byte_idx >= data.len() {
27                        return Ok(0);
28                    }
29                    let bit_pos = 7 - (offset % 8) as u32;
30                    Ok((data[byte_idx] >> bit_pos) & 1)
31                }
32                _ => Err(WrongType),
33            },
34        }
35    }
36
37    /// Sets the bit at `offset` to `value` (0 or 1).
38    ///
39    /// Extends the string with zero bytes when `offset` reaches beyond the
40    /// current string length. Returns the old bit value.
41    ///
42    /// Follows the same memory-tracking pattern as `setrange`.
43    pub fn setbit(&mut self, key: &str, offset: u64, value: u8) -> Result<u8, WriteError> {
44        self.remove_if_expired(key);
45
46        let byte_idx = (offset / 8) as usize;
47        let bit_pos = 7 - (offset % 8) as u32;
48        let mask = 1u8 << bit_pos;
49
50        let (existing, expire) = match self.entries.get(key) {
51            Some(entry) => match &entry.value {
52                Value::String(data) => {
53                    let expire = time::remaining_ms(entry.expires_at_ms).map(Duration::from_millis);
54                    (data.clone(), expire)
55                }
56                _ => return Err(WriteError::WrongType),
57            },
58            None => (Bytes::new(), None),
59        };
60
61        let old_bit = if byte_idx < existing.len() {
62            (existing[byte_idx] >> bit_pos) & 1
63        } else {
64            0
65        };
66
67        // build the new buffer: existing data + zero-padding if needed
68        let new_len = existing.len().max(byte_idx + 1);
69        let mut buf = existing.to_vec();
70        buf.resize(new_len, 0);
71
72        if value == 1 {
73            buf[byte_idx] |= mask;
74        } else {
75            buf[byte_idx] &= !mask;
76        }
77
78        match self.set(key.to_owned(), Bytes::from(buf), expire, false, false) {
79            SetResult::Ok | SetResult::Blocked => Ok(old_bit),
80            SetResult::OutOfMemory => Err(WriteError::OutOfMemory),
81        }
82    }
83
84    /// Counts set bits in the string at `key`.
85    ///
86    /// When `range` is `None`, counts bits across the entire string.
87    /// When `range` is `Some(r)`, restricts to the given byte or bit range
88    /// (negative indices count from the end of the string).
89    ///
90    /// Returns 0 for missing keys.
91    pub fn bitcount(&mut self, key: &str, range: Option<BitRange>) -> Result<u64, WrongType> {
92        if self.remove_if_expired(key) {
93            return Ok(0);
94        }
95        let data = match self.entries.get(key) {
96            None => return Ok(0),
97            Some(e) => match &e.value {
98                Value::String(b) => b.clone(),
99                _ => return Err(WrongType),
100            },
101        };
102
103        match range {
104            None => Ok(data.iter().map(|b| b.count_ones() as u64).sum()),
105            Some(r) if r.unit == BitRangeUnit::Bit => {
106                // bit-granularity: count each individual bit in [start_bit, end_bit]
107                let len_bits = data.len() as i64 * 8;
108                let start = normalize_bit_index(r.start, len_bits).min(len_bits);
109                let end = normalize_bit_index(r.end, len_bits).min(len_bits - 1);
110                if start > end {
111                    return Ok(0);
112                }
113                let mut count = 0u64;
114                for bit_idx in start..=end {
115                    let byte_idx = (bit_idx / 8) as usize;
116                    let bit_pos = 7 - (bit_idx % 8) as u32;
117                    count += ((data[byte_idx] >> bit_pos) & 1) as u64;
118                }
119                Ok(count)
120            }
121            Some(r) => {
122                let slice = bit_range_slice(&data, r);
123                Ok(slice.iter().map(|b| b.count_ones() as u64).sum())
124            }
125        }
126    }
127
128    /// Returns the position of the first bit equal to `bit` (0 or 1).
129    ///
130    /// `range` works the same as for `bitcount`. When no `end` is given for
131    /// `BITPOS 1`, the search covers the whole string; for `BITPOS 0`, it
132    /// also covers the virtual zero bits beyond the string's end.
133    ///
134    /// Returns -1 if the bit is not found (except for `BITPOS 0` on a missing
135    /// key, which returns 0).
136    pub fn bitpos(
137        &mut self,
138        key: &str,
139        bit: u8,
140        range: Option<BitRange>,
141    ) -> Result<i64, WrongType> {
142        if self.remove_if_expired(key) {
143            // missing key: BITPOS 0 → 0, BITPOS 1 → -1
144            return Ok(if bit == 0 { 0 } else { -1 });
145        }
146        let data = match self.entries.get(key) {
147            None => {
148                return Ok(if bit == 0 { 0 } else { -1 });
149            }
150            Some(e) => match &e.value {
151                Value::String(b) => b.clone(),
152                _ => return Err(WrongType),
153            },
154        };
155
156        // determine whether the caller constrained the end boundary
157        let has_explicit_end = range.map(|r| r.end != -1).unwrap_or(false);
158
159        let (slice, bit_offset) = match range {
160            None => (&data[..], 0i64),
161            Some(r) if r.unit == BitRangeUnit::Bit => {
162                // bit-granularity range: resolve to an inclusive bit range
163                let len_bits = data.len() as i64 * 8;
164                let start = normalize_bit_index(r.start, len_bits).min(len_bits);
165                let end = normalize_bit_index(r.end, len_bits).min(len_bits - 1);
166                if start > end {
167                    return Ok(-1);
168                }
169                // search bit-by-bit within the resolved range
170                for bit_idx in start..=end {
171                    let byte_idx = (bit_idx / 8) as usize;
172                    let bit_pos = 7 - (bit_idx % 8) as u32;
173                    let found = (data[byte_idx] >> bit_pos) & 1;
174                    if found == bit {
175                        return Ok(bit_idx);
176                    }
177                }
178                return Ok(-1);
179            }
180            Some(r) => {
181                // byte-granularity range
182                let (s, e) = resolve_byte_range(r.start, r.end, data.len());
183                if s >= data.len() {
184                    return Ok(-1);
185                }
186                let end = e.min(data.len() - 1);
187                (&data[s..=end], (s as i64) * 8)
188            }
189        };
190
191        // scan bytes for the first matching bit
192        for (i, &byte) in slice.iter().enumerate() {
193            let b = if bit == 1 { byte } else { !byte };
194            if b != 0 {
195                let bit_in_byte = b.leading_zeros() as i64;
196                return Ok(bit_offset + (i as i64) * 8 + bit_in_byte);
197            }
198        }
199
200        // not found — for BITPOS 0 without an explicit end, the answer is the
201        // first virtual bit past the end of the string.
202        if bit == 0 && !has_explicit_end {
203            Ok((data.len() as i64) * 8)
204        } else {
205            Ok(-1)
206        }
207    }
208
209    /// Performs a bitwise operation across `keys` and stores the result in `dest`.
210    ///
211    /// Returns the length of the result string (equal to the longest source).
212    /// Missing keys are treated as zero-filled strings of the same length.
213    /// `NOT` requires exactly one source key (enforced at parse time).
214    pub fn bitop(
215        &mut self,
216        op: BitOpKind,
217        dest: String,
218        keys: &[String],
219    ) -> Result<usize, WriteError> {
220        // collect source bytes — type-check each before mutating anything
221        let mut sources: Vec<Bytes> = Vec::with_capacity(keys.len());
222        for key in keys {
223            self.remove_if_expired(key);
224            match self.entries.get(key.as_str()) {
225                None => sources.push(Bytes::new()),
226                Some(e) => match &e.value {
227                    Value::String(b) => sources.push(b.clone()),
228                    _ => return Err(WriteError::WrongType),
229                },
230            }
231        }
232
233        let result_len = sources.iter().map(|s| s.len()).max().unwrap_or(0);
234        let mut result = vec![0u8; result_len];
235
236        match op {
237            BitOpKind::Not => {
238                // NOT of the single source; bytes beyond the source length → 0xFF
239                let src = sources.first().map(|b| b.as_ref()).unwrap_or(&[]);
240                for (i, b) in result.iter_mut().enumerate() {
241                    *b = if i < src.len() { !src[i] } else { 0xFF };
242                }
243            }
244            BitOpKind::And => {
245                // initialize from first source; bytes beyond any source → AND with 0
246                if let Some(first) = sources.first() {
247                    for (i, b) in result.iter_mut().enumerate() {
248                        *b = if i < first.len() { first[i] } else { 0 };
249                    }
250                }
251                for src in sources.iter().skip(1) {
252                    for (i, b) in result.iter_mut().enumerate() {
253                        let s = if i < src.len() { src[i] } else { 0 };
254                        *b &= s;
255                    }
256                }
257            }
258            BitOpKind::Or => {
259                for src in &sources {
260                    for (i, b) in result.iter_mut().enumerate() {
261                        if i < src.len() {
262                            *b |= src[i];
263                        }
264                    }
265                }
266            }
267            BitOpKind::Xor => {
268                for src in &sources {
269                    for (i, b) in result.iter_mut().enumerate() {
270                        if i < src.len() {
271                            *b ^= src[i];
272                        }
273                    }
274                }
275            }
276        }
277
278        // store the result — set() handles memory accounting and version bumping.
279        // it also implicitly removes any prior value at dest (including wrong-type keys).
280        match self.set(dest, Bytes::from(result), None, false, false) {
281            SetResult::Ok | SetResult::Blocked => Ok(result_len),
282            SetResult::OutOfMemory => Err(WriteError::OutOfMemory),
283        }
284    }
285}
286
/// Turns a signed, inclusive byte range into unsigned byte indices for a
/// string that is `len` bytes long.
///
/// A negative index is taken relative to the end of the string (`-1` is the
/// last byte) and floors at 0 when it would fall before the start. A
/// non-negative index is passed through unchanged, so the returned pair may
/// still point past the end of the string and may be inverted
/// (`start > end`); callers are responsible for clamping and for treating an
/// inverted pair as an empty range.
fn resolve_byte_range(start: i64, end: i64, len: usize) -> (usize, usize) {
    let total = len as i64;
    let anchor = |idx: i64| -> usize {
        let resolved = if idx >= 0 { idx } else { (total + idx).max(0) };
        resolved as usize
    };
    (anchor(start), anchor(end))
}
302
/// Converts a possibly-negative bit index into a non-negative one.
///
/// Negative values are measured back from `len_bits` (`-1` is the last bit)
/// and floor at 0. Non-negative values are returned untouched, even when they
/// lie beyond `len_bits`; clamping the upper bound is left to the caller.
fn normalize_bit_index(idx: i64, len_bits: i64) -> i64 {
    if idx >= 0 {
        return idx;
    }
    (len_bits + idx).max(0)
}
311
312/// Returns the sub-slice of `data` described by `range`.
313///
314/// Handles both byte and bit-unit ranges. For bit-unit ranges, the slice is
315/// rounded to the containing bytes (bit searches happen inside the caller).
316fn bit_range_slice(data: &[u8], range: BitRange) -> &[u8] {
317    match range.unit {
318        BitRangeUnit::Byte => {
319            let (s, e) = resolve_byte_range(range.start, range.end, data.len());
320            if s >= data.len() {
321                return &[];
322            }
323            let end = e.min(data.len() - 1);
324            if s > end {
325                &[]
326            } else {
327                &data[s..=end]
328            }
329        }
330        BitRangeUnit::Bit => {
331            // for BITCOUNT with BIT range, convert to byte boundaries (inclusive)
332            let len_bits = data.len() as i64 * 8;
333            let start_bit = normalize_bit_index(range.start, len_bits).min(len_bits);
334            let end_bit = normalize_bit_index(range.end, len_bits).min(len_bits - 1);
335            if start_bit > end_bit || data.is_empty() {
336                return &[];
337            }
338            // return the byte slice containing all bits in [start_bit, end_bit]
339            // (individual bit masking happens in the caller for bitcount)
340            let start_byte = (start_bit / 8) as usize;
341            let end_byte = (end_bit / 8) as usize;
342            &data[start_byte..=end_byte.min(data.len() - 1)]
343        }
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    // --- shared fixtures ---

    /// Stores `bytes` as a string value under `key`, with no TTL.
    fn put(ks: &mut Keyspace, key: &str, bytes: &[u8]) {
        ks.set(key.into(), Bytes::from(bytes.to_vec()), None, false, false);
    }

    /// Reads back the string stored under `key`, panicking if the key is
    /// missing or holds something other than a string.
    fn stored(ks: &mut Keyspace, key: &str) -> Vec<u8> {
        match ks.get(key).unwrap() {
            Some(Value::String(b)) => b.to_vec(),
            other => panic!("expected String, got {other:?}"),
        }
    }

    /// Builds a keyspace whose only key, `"list"`, holds a list value.
    fn keyspace_with_list() -> Keyspace {
        let mut ks = Keyspace::new();
        ks.lpush("list", &[Bytes::from("a")]).unwrap();
        ks
    }

    /// Shorthand for an inclusive range argument.
    fn span(start: i64, end: i64, unit: BitRangeUnit) -> Option<BitRange> {
        Some(BitRange { start, end, unit })
    }

    /// Owned key names for `bitop`.
    fn keys(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    // --- getbit ---

    #[test]
    fn getbit_missing_key_returns_zero() {
        let mut ks = Keyspace::new();
        for offset in [0, 100] {
            assert_eq!(ks.getbit("nope", offset).unwrap(), 0);
        }
    }

    #[test]
    fn getbit_reads_msb_first() {
        let mut ks = Keyspace::new();
        // 0xFF has every bit set, 0x00 has none
        for (byte, expected) in [(0xFFu8, 1u8), (0x00, 0)] {
            put(&mut ks, "k", &[byte]);
            for offset in 0..8 {
                assert_eq!(ks.getbit("k", offset).unwrap(), expected, "offset {offset}");
            }
        }
    }

    #[test]
    fn getbit_big_endian_ordering() {
        let mut ks = Keyspace::new();
        // 0b1000_0000: only bit 0 (the MSB) is set
        put(&mut ks, "k", &[0x80]);
        assert_eq!(ks.getbit("k", 0).unwrap(), 1);
        assert_eq!(ks.getbit("k", 1).unwrap(), 0);
        // 0b0000_0001: only bit 7 (the LSB) is set
        put(&mut ks, "k", &[0x01]);
        assert_eq!(ks.getbit("k", 7).unwrap(), 1);
        assert_eq!(ks.getbit("k", 0).unwrap(), 0);
    }

    #[test]
    fn getbit_beyond_string_returns_zero() {
        let mut ks = Keyspace::new();
        put(&mut ks, "k", &[0xFF]);
        // bit 8 lives in byte 1, which is past the end
        assert_eq!(ks.getbit("k", 8).unwrap(), 0);
    }

    #[test]
    fn getbit_wrong_type() {
        let mut ks = keyspace_with_list();
        assert!(ks.getbit("list", 0).is_err());
    }

    // --- setbit ---

    #[test]
    fn setbit_returns_old_bit() {
        let mut ks = Keyspace::new();
        // (value written, previous value expected back): set on a fresh key,
        // set again, clear, clear again
        for (value, previous) in [(1u8, 0u8), (1, 1), (0, 1), (0, 0)] {
            assert_eq!(ks.setbit("k", 7, value).unwrap(), previous);
        }
    }

    #[test]
    fn setbit_roundtrip_with_getbit() {
        let mut ks = Keyspace::new();
        ks.setbit("k", 10, 1).unwrap();
        assert_eq!(ks.getbit("k", 10).unwrap(), 1);
        assert_eq!(ks.getbit("k", 0).unwrap(), 0);
    }

    #[test]
    fn setbit_extends_string() {
        let mut ks = Keyspace::new();
        // bit 15 is the last bit of the second byte
        ks.setbit("k", 15, 1).unwrap();
        let val = stored(&mut ks, "k");
        assert_eq!(val.len(), 2);
        assert_eq!(ks.getbit("k", 15).unwrap(), 1);
    }

    #[test]
    fn setbit_preserves_ttl() {
        let mut ks = Keyspace::new();
        let ttl = Some(Duration::from_secs(60));
        ks.set("k".into(), Bytes::from(vec![0u8]), ttl, false, false);
        ks.setbit("k", 0, 1).unwrap();
        assert!(matches!(ks.ttl("k"), TtlResult::Seconds(_)));
    }

    #[test]
    fn setbit_wrong_type() {
        let mut ks = keyspace_with_list();
        assert!(ks.setbit("list", 0, 1).is_err());
    }

    // --- bitcount ---

    #[test]
    fn bitcount_missing_key_returns_zero() {
        let mut ks = Keyspace::new();
        assert_eq!(ks.bitcount("nope", None).unwrap(), 0);
    }

    #[test]
    fn bitcount_full_string() {
        let mut ks = Keyspace::new();
        // 8 set bits in 0xFF plus 4 in 0x0F
        put(&mut ks, "k", &[0xFF, 0x0F]);
        assert_eq!(ks.bitcount("k", None).unwrap(), 12);
    }

    #[test]
    fn bitcount_byte_range() {
        let mut ks = Keyspace::new();
        put(&mut ks, "k", &[0xFF, 0x00, 0xFF]);
        // byte 0 alone
        let only_first = span(0, 0, BitRangeUnit::Byte);
        assert_eq!(ks.bitcount("k", only_first).unwrap(), 8);
        // bytes 0 and 1 (byte 1 contributes nothing)
        let first_two = span(0, 1, BitRangeUnit::Byte);
        assert_eq!(ks.bitcount("k", first_two).unwrap(), 8);
    }

    #[test]
    fn bitcount_bit_range() {
        let mut ks = Keyspace::new();
        // all eight bits of the single byte are set
        put(&mut ks, "k", &[0xFF]);
        let whole_byte = span(0, 7, BitRangeUnit::Bit);
        assert_eq!(ks.bitcount("k", whole_byte).unwrap(), 8);
        let upper_nibble = span(0, 3, BitRangeUnit::Bit);
        assert_eq!(ks.bitcount("k", upper_nibble).unwrap(), 4);
    }

    #[test]
    fn bitcount_wrong_type() {
        let mut ks = keyspace_with_list();
        assert!(ks.bitcount("list", None).is_err());
    }

    // --- bitpos ---

    #[test]
    fn bitpos_missing_key_bit1_returns_minus_one() {
        let mut ks = Keyspace::new();
        assert_eq!(ks.bitpos("nope", 1, None).unwrap(), -1);
    }

    #[test]
    fn bitpos_missing_key_bit0_returns_zero() {
        let mut ks = Keyspace::new();
        assert_eq!(ks.bitpos("nope", 0, None).unwrap(), 0);
    }

    #[test]
    fn bitpos_find_first_set_bit() {
        let mut ks = Keyspace::new();
        // the only set bit is the LSB of byte 1, i.e. position 15
        put(&mut ks, "k", &[0x00, 0x01]);
        assert_eq!(ks.bitpos("k", 1, None).unwrap(), 15);
    }

    #[test]
    fn bitpos_find_first_clear_bit_in_all_ones() {
        let mut ks = Keyspace::new();
        // no clear bit inside the string, so the answer is position 16,
        // just past the end
        put(&mut ks, "k", &[0xFF, 0xFF]);
        assert_eq!(ks.bitpos("k", 0, None).unwrap(), 16);
    }

    #[test]
    fn bitpos_wrong_type() {
        let mut ks = keyspace_with_list();
        assert!(ks.bitpos("list", 1, None).is_err());
    }

    // --- bitop ---

    #[test]
    fn bitop_and() {
        let mut ks = Keyspace::new();
        put(&mut ks, "a", &[0xFF, 0x0F]);
        put(&mut ks, "b", &[0x0F, 0xFF]);
        let len = ks
            .bitop(BitOpKind::And, "dest".into(), &keys(&["a", "b"]))
            .unwrap();
        assert_eq!(len, 2);
        let val = stored(&mut ks, "dest");
        assert_eq!(val[0], 0x0F);
        assert_eq!(val[1], 0x0F);
    }

    #[test]
    fn bitop_or() {
        let mut ks = Keyspace::new();
        put(&mut ks, "a", &[0xF0]);
        put(&mut ks, "b", &[0x0F]);
        ks.bitop(BitOpKind::Or, "dest".into(), &keys(&["a", "b"]))
            .unwrap();
        let val = stored(&mut ks, "dest");
        assert_eq!(val[0], 0xFF);
    }

    #[test]
    fn bitop_xor() {
        let mut ks = Keyspace::new();
        put(&mut ks, "a", &[0xFF]);
        put(&mut ks, "b", &[0xFF]);
        ks.bitop(BitOpKind::Xor, "dest".into(), &keys(&["a", "b"]))
            .unwrap();
        let val = stored(&mut ks, "dest");
        assert_eq!(val[0], 0x00);
    }

    #[test]
    fn bitop_not() {
        let mut ks = Keyspace::new();
        put(&mut ks, "src", &[0xF0, 0x0F]);
        let len = ks
            .bitop(BitOpKind::Not, "dest".into(), &keys(&["src"]))
            .unwrap();
        assert_eq!(len, 2);
        let val = stored(&mut ks, "dest");
        assert_eq!(val[0], 0x0F);
        assert_eq!(val[1], 0xF0);
    }

    #[test]
    fn bitop_wrong_type() {
        let mut ks = keyspace_with_list();
        let outcome = ks.bitop(BitOpKind::And, "dest".into(), &keys(&["list"]));
        assert!(outcome.is_err());
    }

    #[test]
    fn bitop_extends_to_longest_source() {
        let mut ks = Keyspace::new();
        put(&mut ks, "a", &[0xFF, 0xFF, 0xFF]);
        put(&mut ks, "b", &[0xFF]);
        let len = ks
            .bitop(BitOpKind::Or, "dest".into(), &keys(&["a", "b"]))
            .unwrap();
        assert_eq!(len, 3);
    }
}
709            .bitop(BitOpKind::Or, "dest".into(), &["a".into(), "b".into()])
710            .unwrap();
711        assert_eq!(len, 3);
712    }
713}