1#[cfg(not(feature = "std"))]
6use alloc::{vec, vec::Vec};
7
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::Automaton;
12
13#[derive(Debug, Clone)]
19#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
20#[repr(align(64))]
21pub struct BitwiseClause {
22 automata: Vec<Automaton>,
23 include: Vec<u64>,
24 negated: Vec<u64>,
25 polarity: i8,
26 n_features: usize,
27 dirty: bool
28}
29
30impl BitwiseClause {
31 #[must_use]
35 pub fn new(n_features: usize, n_states: i16, polarity: i8) -> Self {
36 debug_assert!(polarity == 1 || polarity == -1);
37 let n_words = n_features.div_ceil(64);
38 let automata = (0..2 * n_features)
39 .map(|_| Automaton::new(n_states))
40 .collect();
41
42 Self {
43 automata,
44 include: vec![0; n_words],
45 negated: vec![0; n_words],
46 polarity,
47 n_features,
48 dirty: true
49 }
50 }
51
52 #[inline(always)]
53 #[must_use]
54 pub const fn polarity(&self) -> i8 {
55 self.polarity
56 }
57
58 #[inline(always)]
59 #[must_use]
60 pub const fn n_features(&self) -> usize {
61 self.n_features
62 }
63
64 #[inline(always)]
65 #[must_use]
66 pub fn automata(&self) -> &[Automaton] {
67 &self.automata
68 }
69
70 #[inline(always)]
71 pub fn automata_mut(&mut self) -> &mut [Automaton] {
72 self.dirty = true;
73 &mut self.automata
74 }
75
76 pub fn rebuild_masks(&mut self) {
80 if !self.dirty {
81 return;
82 }
83
84 for word in &mut self.include {
85 *word = 0;
86 }
87 for word in &mut self.negated {
88 *word = 0;
89 }
90
91 for k in 0..self.n_features {
92 let word_idx = k / 64;
93 let bit_idx = k % 64;
94
95 if self.automata[2 * k].action() {
96 self.include[word_idx] |= 1u64 << bit_idx;
97 }
98 if self.automata[2 * k + 1].action() {
99 self.negated[word_idx] |= 1u64 << bit_idx;
100 }
101 }
102
103 self.dirty = false;
104 }
105
106 #[inline]
118 #[must_use]
119 pub fn evaluate_packed(&self, x_packed: &[u64]) -> bool {
120 debug_assert!(!self.dirty, "call rebuild_masks() first");
121
122 let n_words = self.include.len().min(x_packed.len());
123
124 for i in 0..n_words {
125 let x = unsafe { *x_packed.get_unchecked(i) };
128 let inc = unsafe { *self.include.get_unchecked(i) };
129 let neg = unsafe { *self.negated.get_unchecked(i) };
130
131 if (inc & !x) | (neg & x) != 0 {
134 return false;
135 }
136 }
137 true
138 }
139
140 #[inline(always)]
142 #[must_use]
143 pub fn vote_packed(&self, x_packed: &[u64]) -> i32 {
144 if self.evaluate_packed(x_packed) {
145 self.polarity as i32
146 } else {
147 0
148 }
149 }
150
151 #[inline]
155 #[must_use]
156 pub fn evaluate(&self, x: &[u8]) -> bool {
157 let n = self.n_features.min(x.len());
158
159 for k in 0..n {
160 let include = unsafe { self.automata.get_unchecked(2 * k).action() };
163 let negated = unsafe { self.automata.get_unchecked(2 * k + 1).action() };
164
165 let xk = unsafe { *x.get_unchecked(k) };
167
168 if include && xk == 0 {
169 return false;
170 }
171 if negated && xk == 1 {
172 return false;
173 }
174 }
175 true
176 }
177}
178
/// Packs one byte per feature into 64-bit words.
///
/// Feature `k` lands in bit `k % 64` of word `k / 64`, and any non-zero byte
/// counts as a set bit. The output has `x.len().div_ceil(64)` words; bits past
/// the last feature in the final word are left clear.
#[inline]
#[must_use]
pub fn pack_input(x: &[u8]) -> Vec<u64> {
    x.chunks(64)
        .map(|block| {
            block
                .iter()
                .enumerate()
                .filter(|&(_, &byte)| byte != 0)
                .fold(0u64, |word, (bit, _)| word | (1u64 << bit))
        })
        .collect()
}
196
197#[inline]
201#[must_use]
202pub fn pack_batch(xs: &[Vec<u8>]) -> Vec<Vec<u64>> {
203 xs.iter().map(|x| pack_input(x)).collect()
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives automaton `idx` far enough that its action should flip on.
    /// NOTE(review): assumes 200 increments from `n_states = 100` is always
    /// enough to turn the action on — confirm against `Automaton`.
    fn saturate(c: &mut BitwiseClause, idx: usize) {
        for _ in 0..200 {
            c.automata_mut()[idx].increment();
        }
    }

    #[test]
    fn pack_input_basic() {
        let x = vec![1, 0, 1, 1, 0, 0, 0, 1];
        let packed = pack_input(&x);

        assert_eq!(packed.len(), 1);
        assert_eq!(packed[0], 0b1000_1101);
    }

    #[test]
    fn pack_input_empty() {
        assert!(pack_input(&[]).is_empty());
    }

    #[test]
    fn bitwise_evaluate_empty() {
        let mut c = BitwiseClause::new(64, 100, 1);
        c.rebuild_masks();

        // No literals included: every input satisfies the clause.
        let x_packed = vec![0xFFFF_FFFF_FFFF_FFFFu64];
        assert!(c.evaluate_packed(&x_packed));
    }

    #[test]
    fn bitwise_evaluate_violation() {
        let mut c = BitwiseClause::new(64, 100, 1);

        // Include the positive literal x_0.
        saturate(&mut c, 0);
        c.rebuild_masks();

        // x_0 = 0 violates the included literal.
        let x_packed = vec![0u64];
        assert!(!c.evaluate_packed(&x_packed));

        // x_0 = 1 satisfies it.
        let x_packed = vec![1u64];
        assert!(c.evaluate_packed(&x_packed));
    }

    #[test]
    fn bitwise_clause_accessors() {
        let c = BitwiseClause::new(128, 100, -1);

        assert_eq!(c.polarity(), -1);
        assert_eq!(c.n_features(), 128);
        // Two automata (x_k and !x_k) per feature.
        assert_eq!(c.automata().len(), 256);
    }

    #[test]
    fn bitwise_automata_mut_sets_dirty() {
        let mut c = BitwiseClause::new(64, 100, 1);
        c.rebuild_masks();
        assert!(c.evaluate_packed(&[0u64]));

        // If automata_mut() failed to mark the masks dirty, this rebuild would
        // be a no-op and the new include literal would be missed.
        saturate(&mut c, 0);
        c.rebuild_masks();
        assert!(!c.evaluate_packed(&[0u64]));
    }

    #[test]
    fn bitwise_vote_packed() {
        let mut c = BitwiseClause::new(64, 100, 1);
        c.rebuild_masks();
        // Empty clause fires and votes its polarity.
        assert_eq!(c.vote_packed(&[0u64]), 1);

        saturate(&mut c, 0);
        c.rebuild_masks();

        // Included x_0 is violated: no vote.
        assert_eq!(c.vote_packed(&[0u64]), 0);
        // Included x_0 is satisfied: positive vote.
        assert_eq!(c.vote_packed(&[1u64]), 1);
    }

    #[test]
    fn bitwise_vote_packed_negative_polarity() {
        let mut c = BitwiseClause::new(64, 100, -1);
        c.rebuild_masks();

        assert_eq!(c.vote_packed(&[0u64]), -1);
    }

    #[test]
    fn bitwise_evaluate_scalar() {
        let mut c = BitwiseClause::new(4, 100, 1);

        // Include x_0 (automaton 0) and !x_2 (automaton 5).
        saturate(&mut c, 0);
        saturate(&mut c, 5);

        // x_0 = 1, x_2 = 0: both literals hold.
        assert!(c.evaluate(&[1, 0, 0, 0]));
        // x_0 = 0 violates x_0.
        assert!(!c.evaluate(&[0, 0, 0, 0]));
        // x_2 = 1 violates !x_2.
        assert!(!c.evaluate(&[1, 0, 1, 0]));
    }

    #[test]
    fn bitwise_evaluate_scalar_empty() {
        let c = BitwiseClause::new(4, 100, 1);
        assert!(c.evaluate(&[0, 0, 0, 0]));
        assert!(c.evaluate(&[1, 1, 1, 1]));
    }

    #[test]
    fn pack_batch_multiple() {
        let xs = vec![vec![1, 0, 0, 0], vec![0, 1, 0, 0], vec![1, 1, 0, 0]];
        let packed = pack_batch(&xs);

        assert_eq!(packed.len(), 3);
        assert_eq!(packed[0][0], 0b0001);
        assert_eq!(packed[1][0], 0b0010);
        assert_eq!(packed[2][0], 0b0011);
    }

    #[test]
    fn pack_input_large() {
        // Exercise both edges of both words.
        let mut x = vec![0u8; 128];
        x[0] = 1;
        x[63] = 1;
        x[64] = 1;
        x[127] = 1;

        let packed = pack_input(&x);
        assert_eq!(packed.len(), 2);
        assert_eq!(packed[0], 1u64 | (1u64 << 63));
        assert_eq!(packed[1], 1u64 | (1u64 << 63));
    }

    #[test]
    fn bitwise_negated_violation() {
        let mut c = BitwiseClause::new(64, 100, 1);

        // Include the negated literal !x_0.
        saturate(&mut c, 1);
        c.rebuild_masks();

        // x_0 = 1 violates !x_0.
        assert!(!c.evaluate_packed(&[1u64]));
        // x_0 = 0 satisfies it.
        assert!(c.evaluate_packed(&[0u64]));
    }

    #[test]
    fn bitwise_multi_word() {
        let mut c = BitwiseClause::new(128, 100, 1);

        // Automaton 128 is the include literal for x_64 (second word, bit 0).
        saturate(&mut c, 128);
        c.rebuild_masks();

        let x_packed = vec![0u64, 1u64];
        assert!(c.evaluate_packed(&x_packed));

        let x_packed = vec![0u64, 0u64];
        assert!(!c.evaluate_packed(&x_packed));
    }
}