// rocketmq_filter/utils/bits_array.rs

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/// A fixed-length sequence of bits packed into a byte vector.
///
/// Bit `n` is stored in byte `n / 8` under the mask `1 << (n % 8)`, i.e. the
/// least significant bit of byte 0 is bit 0.
#[derive(Clone, Debug)]
pub struct BitsArray {
    /// Backing storage for the bits.
    bytes: Vec<u8>,
    /// Number of addressable bits; bit positions must be `< bit_length`.
    bit_length: usize,
}

impl BitsArray {
    /// Number of bits held by one storage byte.
    const BITS_PER_BYTE: usize = 8;

    /// Creates an all-zero array able to hold `bit_length` bits.
    ///
    /// The storage is `ceil(bit_length / 8)` bytes long. A `bit_length` of `0`
    /// yields an array with no storage; every checked operation on it panics
    /// with `"Not initialized!"`.
    pub fn create(bit_length: usize) -> Self {
        Self {
            bytes: vec![0; bit_length.div_ceil(Self::BITS_PER_BYTE)],
            bit_length,
        }
    }

    /// Builds an array from a copy of `bytes`.
    ///
    /// When `bit_length` is `None` it defaults to `bytes.len() * 8`.
    ///
    /// # Panics
    ///
    /// * `"Bytes is empty!"` if `bytes` is empty.
    /// * `"Bit is less than 1."` if the resolved bit length is `0`.
    /// * `"BitLength is less than bytes.len() * 8"` if the resolved bit length
    ///   is smaller than `bytes.len() * 8`.
    ///
    /// NOTE(review): a bit length *larger* than `bytes.len() * 8` is accepted
    /// here; accessing a bit beyond the stored bytes then fails on the slice
    /// index rather than on the position check. Conversely, an array made by
    /// [`BitsArray::create`] with a length that is not a multiple of 8 cannot
    /// be rebuilt from its own `bytes()`/`bit_length()`. The check is kept as
    /// is because callers may rely on it — confirm the intended contract.
    pub fn from_bytes(bytes: &[u8], bit_length: Option<usize>) -> Self {
        assert!(!bytes.is_empty(), "Bytes is empty!");
        let capacity = bytes.len() * Self::BITS_PER_BYTE;
        let bit_length = bit_length.unwrap_or(capacity);
        assert!(bit_length >= 1, "Bit is less than 1.");
        assert!(
            bit_length >= capacity,
            "BitLength is less than bytes.len() * 8"
        );
        Self {
            bytes: bytes.to_vec(),
            bit_length,
        }
    }

    /// Returns the number of addressable bits.
    pub fn bit_length(&self) -> usize {
        self.bit_length
    }

    /// Returns the number of storage bytes.
    pub fn byte_length(&self) -> usize {
        self.bytes.len()
    }

    /// Returns the storage bytes.
    pub fn bytes(&self) -> &[u8] {
        &self.bytes
    }

    /// XORs `other` into `self`, byte by byte, over the shorter of the two
    /// storages; any remaining bytes of `self` are left untouched.
    ///
    /// # Panics
    ///
    /// Panics with `"Not initialized!"` if either array has no storage.
    pub fn xor(&mut self, other: &BitsArray) {
        self.combine_with(other, |lhs, rhs| lhs ^ rhs);
    }

    /// Replaces the bit at `bit_pos` with `bit ^ set`.
    ///
    /// # Panics
    ///
    /// Panics if the array has no storage or `bit_pos >= bit_length()`.
    pub fn xor_bit(&mut self, bit_pos: usize, set: bool) {
        // `get_bit` validates the position before anything is modified.
        let current = self.get_bit(bit_pos);
        self.set_bit(bit_pos, current ^ set);
    }

    /// ORs `other` into `self`, byte by byte, over the shorter of the two
    /// storages; any remaining bytes of `self` are left untouched.
    ///
    /// # Panics
    ///
    /// Panics with `"Not initialized!"` if either array has no storage.
    pub fn or(&mut self, other: &BitsArray) {
        self.combine_with(other, |lhs, rhs| lhs | rhs);
    }

    /// Replaces the bit at `bit_pos` with `bit | set`.
    ///
    /// # Panics
    ///
    /// Panics if the array has no storage or `bit_pos >= bit_length()`,
    /// regardless of the value of `set`.
    pub fn or_bit(&mut self, bit_pos: usize, set: bool) {
        // Validate even when `set` is false so a bad position is never silent.
        self.check_bit_position(bit_pos);
        if set {
            self.set_bit(bit_pos, true);
        }
    }

    /// ANDs `other` into `self`, byte by byte, over the shorter of the two
    /// storages; any remaining bytes of `self` are left untouched.
    ///
    /// # Panics
    ///
    /// Panics with `"Not initialized!"` if either array has no storage.
    pub fn and(&mut self, other: &BitsArray) {
        self.combine_with(other, |lhs, rhs| lhs & rhs);
    }

    /// Replaces the bit at `bit_pos` with `bit & set`.
    ///
    /// # Panics
    ///
    /// Panics if the array has no storage or `bit_pos >= bit_length()`,
    /// regardless of the value of `set`.
    pub fn and_bit(&mut self, bit_pos: usize, set: bool) {
        // Validate even when `set` is true so a bad position is never silent.
        self.check_bit_position(bit_pos);
        if !set {
            self.set_bit(bit_pos, false);
        }
    }

    /// Inverts the bit at `bit_pos`.
    ///
    /// # Panics
    ///
    /// Panics if the array has no storage or `bit_pos >= bit_length()`.
    pub fn not(&mut self, bit_pos: usize) {
        let current = self.get_bit(bit_pos);
        self.set_bit(bit_pos, !current);
    }

    /// Sets (`set == true`) or clears (`set == false`) the bit at `bit_pos`.
    ///
    /// # Panics
    ///
    /// Panics if the array has no storage or `bit_pos >= bit_length()`.
    pub fn set_bit(&mut self, bit_pos: usize, set: bool) {
        self.check_bit_position(bit_pos);
        let index = self.subscript(bit_pos);
        let mask = self.position(bit_pos);
        if set {
            self.bytes[index] |= mask;
        } else {
            self.bytes[index] &= !mask;
        }
    }

    /// Overwrites the storage byte at `byte_pos` with `set`.
    ///
    /// # Panics
    ///
    /// Panics if the array has no storage or `byte_pos >= byte_length()`.
    pub fn set_byte(&mut self, byte_pos: usize, set: u8) {
        self.check_byte_position(byte_pos);
        self.bytes[byte_pos] = set;
    }

    /// Returns `true` when the bit at `bit_pos` is set.
    ///
    /// # Panics
    ///
    /// Panics if the array has no storage or `bit_pos >= bit_length()`.
    pub fn get_bit(&self, bit_pos: usize) -> bool {
        self.check_bit_position(bit_pos);
        let index = self.subscript(bit_pos);
        let mask = self.position(bit_pos);
        (self.bytes[index] & mask) != 0
    }

    /// Returns the storage byte at `byte_pos`.
    ///
    /// # Panics
    ///
    /// Panics if the array has no storage or `byte_pos >= byte_length()`.
    pub fn get_byte(&self, byte_pos: usize) -> u8 {
        self.check_byte_position(byte_pos);
        self.bytes[byte_pos]
    }

    /// Applies `op(self_byte, other_byte)` to every byte pair of the two
    /// storages, stopping at the end of the shorter one.
    fn combine_with(&mut self, other: &BitsArray, op: impl Fn(u8, u8) -> u8) {
        self.check_initialized();
        other.check_initialized();
        // `zip` ends with the shorter side, matching a loop over the minimum length.
        for (target, &source) in self.bytes.iter_mut().zip(&other.bytes) {
            *target = op(*target, source);
        }
    }

    /// Index of the storage byte that holds `bit_pos`.
    fn subscript(&self, bit_pos: usize) -> usize {
        bit_pos / Self::BITS_PER_BYTE
    }

    /// Single-bit mask selecting `bit_pos` inside its storage byte.
    fn position(&self, bit_pos: usize) -> u8 {
        1 << (bit_pos % Self::BITS_PER_BYTE)
    }

    /// Panics unless the array has storage and `byte_pos` is a valid byte index.
    fn check_byte_position(&self, byte_pos: usize) {
        self.check_initialized();
        assert!(
            byte_pos < self.byte_length(),
            "BytePos is greater than {}",
            self.bytes.len()
        );
    }

    /// Panics unless the array has storage and `bit_pos < bit_length()`.
    fn check_bit_position(&self, bit_pos: usize) {
        self.check_initialized();
        assert!(
            bit_pos < self.bit_length(),
            "BitPos is greater than {}",
            self.bit_length
        );
    }

    /// Panics with `"Not initialized!"` when the array has no storage bytes.
    fn check_initialized(&self) {
        assert!(!self.bytes.is_empty(), "Not initialized!");
    }

    /// Returns an independent copy of this array (same as `Clone::clone`).
    pub fn clone_bits(&self) -> BitsArray {
        self.clone()
    }
}

/// Renders the bits as a string of `'0'`/`'1'` characters.
///
/// Bytes are written from the highest index down to byte 0, each byte most
/// significant bit first. When `bit_length` is not a multiple of 8, only the
/// low `bit_length % 8` bits of the last storage byte are written. A `'\n'` is
/// appended after every byte whose index is a multiple of 8 (so the output
/// always ends with one). An array without storage renders as `"null"`.
impl std::fmt::Display for BitsArray {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.bytes.is_empty() {
            return f.write_str("null");
        }

        let last_index = self.bytes.len() - 1;
        let partial_bits = self.bit_length % 8;
        let mut rendered = String::with_capacity(self.bytes.len() * 8);

        for (index, &byte) in self.bytes.iter().enumerate().rev() {
            // Only the final storage byte may be partially used.
            let width = if index == last_index && partial_bits > 0 {
                partial_bits
            } else {
                8
            };
            for shift in (0..width).rev() {
                let is_set = byte & (1 << shift) != 0;
                rendered.push(if is_set { '1' } else { '0' });
            }
            if index % 8 == 0 {
                rendered.push('\n');
            }
        }

        f.write_str(&rendered)
    }
}
198}