1use crate::bytes;
2use crate::model::ByteArray;
3use core::cmp::max;
4use core::mem;
5use core::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not};
6use subtle::{Choice, ConditionallySelectable};
7
8#[doc(hidden)]
9pub(super) struct InternalOps;
10
11impl InternalOps {
12 fn base_op(
42 lhs: ByteArray,
43 rhs: ByteArray,
44 byte_op: impl Fn(u8, u8) -> u8,
45 identity: u8,
46 ) -> ByteArray {
47 let orig_lhs_len = lhs.len();
49 let orig_rhs_len = rhs.len();
50
51 let lhs_padded = if lhs.is_empty() {
53 bytes![identity; 1]
54 } else {
55 lhs
56 };
57 let rhs_padded = if rhs.is_empty() {
58 bytes![identity; 1]
59 } else {
60 rhs
61 };
62
63 let max_arr_size = max(lhs_padded.len(), rhs_padded.len());
64 let mut res = bytes![identity; max_arr_size];
65
66 let first_offset = max_arr_size - lhs_padded.len();
68 let second_offset = max_arr_size - rhs_padded.len();
69
70 for i in 0..max_arr_size {
73 let lhs_started = i >= first_offset;
76 let lhs_idx = i.saturating_sub(first_offset);
77 let lhs_in_bounds = lhs_started && (lhs_idx < lhs_padded.len());
78 let lhs_valid = Choice::from(lhs_in_bounds as u8);
79
80 let lhs_safe_idx = lhs_idx.min(lhs_padded.len() - 1);
82 let lhs_value = lhs_padded[lhs_safe_idx];
83 let lhs_byte = u8::conditional_select(&identity, &lhs_value, lhs_valid);
84
85 res[i] = byte_op(res[i], lhs_byte);
87
88 let rhs_started = i >= second_offset;
90 let rhs_idx = i.saturating_sub(second_offset);
91 let rhs_in_bounds = rhs_started && (rhs_idx < rhs_padded.len());
92 let rhs_valid = Choice::from(rhs_in_bounds as u8);
93
94 let rhs_safe_idx = rhs_idx.min(rhs_padded.len() - 1);
95 let rhs_value = rhs_padded[rhs_safe_idx];
96 let rhs_byte = u8::conditional_select(&identity, &rhs_value, rhs_valid);
97
98 res[i] = byte_op(res[i], rhs_byte);
99 }
100
101 let final_len = max(orig_lhs_len, orig_rhs_len);
103 res.truncate(final_len);
104
105 res
106 }
107
108 #[inline]
109 fn xor_op(lhs: ByteArray, rhs: ByteArray) -> ByteArray {
110 Self::base_op(lhs, rhs, |x, y| x ^ y, 0x00)
111 }
112
113 #[inline]
114 fn and_op(lhs: ByteArray, rhs: ByteArray) -> ByteArray {
115 Self::base_op(lhs, rhs, |x, y| x & y, 0xFF)
116 }
117
118 #[inline]
119 fn or_op(lhs: ByteArray, rhs: ByteArray) -> ByteArray {
120 Self::base_op(lhs, rhs, |x, y| x | y, 0x00)
121 }
122}
123
124impl BitXor for ByteArray {
125 type Output = Self;
126
127 fn bitxor(self, rhs: Self) -> Self::Output {
128 InternalOps::xor_op(self, rhs)
129 }
130}
131
132impl BitXorAssign for ByteArray {
133 fn bitxor_assign(&mut self, rhs: Self) {
134 *self = InternalOps::xor_op(mem::take(self), rhs);
135 }
136}
137
138impl BitAnd for ByteArray {
139 type Output = Self;
140
141 fn bitand(self, rhs: Self) -> Self::Output {
142 InternalOps::and_op(self, rhs)
143 }
144}
145
146impl BitAndAssign for ByteArray {
147 fn bitand_assign(&mut self, rhs: Self) {
148 *self = InternalOps::and_op(mem::take(self), rhs)
149 }
150}
151
152impl BitOr for ByteArray {
153 type Output = Self;
154
155 fn bitor(self, rhs: Self) -> Self::Output {
156 InternalOps::or_op(self, rhs)
157 }
158}
159
160impl BitOrAssign for ByteArray {
161 fn bitor_assign(&mut self, rhs: Self) {
162 *self = InternalOps::or_op(mem::take(self), rhs)
163 }
164}
165
166impl Not for ByteArray {
167 type Output = Self;
168
169 fn not(self) -> Self::Output {
170 ByteArray {
171 bytes: self.bytes.iter().map(|&b| !b).collect(),
172 }
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_xor_simple() {
        let b1: ByteArray = [0xAA, 0xBB, 0xCC].into();
        let b2 = ByteArray::from([0x55, 0x44, 0x33]);
        let b3 = b2 ^ b1;

        assert_eq!(b3.len(), 3);
        assert_eq!(b3[0], 0xAA ^ 0x55);
        assert_eq!(b3[1], 0xBB ^ 0x44);
        assert_eq!(b3[2], 0xCC ^ 0x33);
    }

    #[test]
    fn test_xor_unequal_length() {
        let b1: ByteArray = [0xAA, 0xBB].into();
        let b2 = ByteArray::from([0x11, 0x22, 0x33]);
        let res = b1 ^ b2;

        assert_eq!(res.len(), 3);
        assert_eq!(res[0], 0x11);
        assert_eq!(res[1], 0xAA ^ 0x22);
        assert_eq!(res[2], 0xBB ^ 0x33);
    }

    #[test]
    fn test_xor_single_byte_right_aligned() {
        let b1 = ByteArray::from([0x12, 0x35, 0x56]);
        let b2 = ByteArray::from(0xFF);
        let res = b1 ^ b2;

        assert_eq!(res.len(), 3);
        assert_eq!(res[0], 0x12);
        assert_eq!(res[1], 0x35);
        assert_eq!(res[2], 0xFF ^ 0x56);
    }

    #[test]
    fn test_xor_assign() {
        let mut b1: ByteArray = [0xAA, 0xBB, 0xCC].into();
        let b2 = ByteArray::from([0x55, 0x44, 0x33]);
        b1 ^= b2;

        assert_eq!(b1.len(), 3);
        assert_eq!(b1[0], 0xAA ^ 0x55);
        assert_eq!(b1[1], 0xBB ^ 0x44);
        assert_eq!(b1[2], 0xCC ^ 0x33);
    }

    #[test]
    fn test_and_simple() {
        let b1: ByteArray = [0xFF, 0xAA, 0x55].into();
        let b2 = ByteArray::from([0x0F, 0xF0, 0x33]);
        let res = b1 & b2;

        assert_eq!(res.len(), 3);
        assert_eq!(res[0], 0xFF & 0x0F);
        assert_eq!(res[1], 0xAA & 0xF0);
        assert_eq!(res[2], 0x55 & 0x33);
    }

    #[test]
    fn test_and_unequal_length() {
        let b1: ByteArray = [0xAA, 0xBB].into();
        let b2 = ByteArray::from([0x11, 0x22, 0x33]);
        let res = b1 & b2;

        assert_eq!(res.len(), 3);
        assert_eq!(res[0], 0xFF & 0x11);
        assert_eq!(res[1], 0xAA & 0x22);
        assert_eq!(res[2], 0xBB & 0x33);
    }

    #[test]
    fn test_and_single_byte() {
        let b1 = ByteArray::from([0xFF, 0xAA, 0x55]);
        let b2 = ByteArray::from(0x0F);
        let res = b1 & b2;

        assert_eq!(res.len(), 3);
        assert_eq!(res[0], 0xFF);
        assert_eq!(res[1], 0xFF & 0xAA);
        assert_eq!(res[2], 0x55 & 0x0F);
    }

    #[test]
    fn test_and_assign() {
        let mut b1: ByteArray = [0xFF, 0xAA, 0x55].into();
        let b2 = ByteArray::from([0x0F, 0xF0, 0x33]);
        b1 &= b2;

        assert_eq!(b1.len(), 3);
        assert_eq!(b1[0], 0xFF & 0x0F);
        assert_eq!(b1[1], 0xAA & 0xF0);
        assert_eq!(b1[2], 0x55 & 0x33);
    }

    #[test]
    fn test_or_simple() {
        let b1: ByteArray = [0x0F, 0xAA, 0x55].into();
        let b2 = ByteArray::from([0xF0, 0x55, 0xAA]);
        let res = b1 | b2;

        assert_eq!(res.len(), 3);
        assert_eq!(res[0], 0x0F | 0xF0);
        assert_eq!(res[1], 0xAA | 0x55);
        assert_eq!(res[2], 0x55 | 0xAA);
    }

    #[test]
    fn test_or_unequal_length() {
        let b1: ByteArray = [0xAA, 0xBB].into();
        let b2 = ByteArray::from([0x11, 0x22, 0x33]);
        let res = b1 | b2;

        assert_eq!(res.len(), 3);
        assert_eq!(res[0], 0x00 | 0x11);
        assert_eq!(res[1], 0xAA | 0x22);
        assert_eq!(res[2], 0xBB | 0x33);
    }

    #[test]
    fn test_or_single_byte() {
        let b1 = ByteArray::from([0x10, 0x20, 0x30]);
        let b2 = ByteArray::from(0x0F);
        let res = b1 | b2;

        assert_eq!(res.len(), 3);
        assert_eq!(res[0], 0x10);
        assert_eq!(res[1], 0x20);
        assert_eq!(res[2], 0x30 | 0x0F);
    }

    #[test]
    fn test_or_assign() {
        let mut b1: ByteArray = [0x0F, 0xAA, 0x55].into();
        let b2 = ByteArray::from([0xF0, 0x55, 0xAA]);
        b1 |= b2;

        assert_eq!(b1.len(), 3);
        assert_eq!(b1[0], 0x0F | 0xF0);
        assert_eq!(b1[1], 0xAA | 0x55);
        assert_eq!(b1[2], 0x55 | 0xAA);
    }

    #[test]
    fn test_not_simple() {
        let b1: ByteArray = [0xFF, 0x00, 0xAA].into();
        let res = !b1;

        assert_eq!(res.len(), 3);
        assert_eq!(res[0], !0xFF);
        assert_eq!(res[1], !0x00);
        assert_eq!(res[2], !0xAA);
    }

    #[test]
    fn test_not_all_ones() {
        let b1: ByteArray = [0xFF, 0xFF, 0xFF].into();
        let res = !b1;

        assert_eq!(res.len(), 3);
        assert_eq!(res.as_bytes(), [0x00, 0x00, 0x00]);
    }

    #[test]
    fn test_not_all_zeros() {
        let b1: ByteArray = [0x00, 0x00, 0x00].into();
        let res = !b1;

        assert_eq!(res.len(), 3);
        assert_eq!(res.as_bytes(), [0xFF, 0xFF, 0xFF]);
    }

    #[test]
    fn test_not_single_byte() {
        let b1 = ByteArray::from(0x55);
        let res = !b1;

        assert_eq!(res.len(), 1);
        assert_eq!(res[0], 0xAA);
    }

    #[test]
    fn test_combined_operations() {
        let b1: ByteArray = [0xFF, 0x00].into();
        let b2: ByteArray = [0xF0, 0x0F].into();
        let b3: ByteArray = [0x55, 0xAA].into();

        let res = (b1 ^ b2) & b3;

        assert_eq!(res.len(), 2);
        assert_eq!(res[0], (0xFF ^ 0xF0) & 0x55);
        assert_eq!(res[1], (0x00 ^ 0x0F) & 0xAA);
    }

    #[test]
    fn test_chained_xor_assign() {
        let mut b1: ByteArray = [0xFF, 0xFF].into();
        let b2: ByteArray = [0x0F, 0xF0].into();
        let b3: ByteArray = [0x11, 0x22].into();

        b1 ^= b2;
        b1 ^= b3;

        assert_eq!(b1[0], 0xFF ^ 0x0F ^ 0x11);
        assert_eq!(b1[1], 0xFF ^ 0xF0 ^ 0x22);
    }

    // The tests below exercise the empty-operand path of the shared byte-wise
    // helper (an empty side contributes only the operator's identity byte),
    // and the compound-assignment operators with operands of different lengths.

    #[test]
    fn test_xor_empty_lhs() {
        let res = ByteArray::default() ^ ByteArray::from([0x12, 0x34]);

        assert_eq!(res.len(), 2);
        assert_eq!(res[0], 0x12);
        assert_eq!(res[1], 0x34);
    }

    #[test]
    fn test_and_empty_rhs_keeps_lhs() {
        // AND pads with 0xFF, so an empty operand leaves the other unchanged.
        let res = ByteArray::from([0x0F, 0xF0]) & ByteArray::default();

        assert_eq!(res.len(), 2);
        assert_eq!(res[0], 0x0F);
        assert_eq!(res[1], 0xF0);
    }

    #[test]
    fn test_or_empty_rhs_keeps_lhs() {
        let res = ByteArray::from([0x10, 0x20]) | ByteArray::default();

        assert_eq!(res.len(), 2);
        assert_eq!(res[0], 0x10);
        assert_eq!(res[1], 0x20);
    }

    #[test]
    fn test_binary_ops_on_two_empty_operands() {
        assert_eq!((ByteArray::default() ^ ByteArray::default()).len(), 0);
        assert_eq!((ByteArray::default() & ByteArray::default()).len(), 0);
        assert_eq!((ByteArray::default() | ByteArray::default()).len(), 0);
    }

    #[test]
    fn test_xor_assign_grows_to_longer_rhs() {
        let mut b1: ByteArray = [0xAA, 0xBB].into();
        b1 ^= ByteArray::from([0x11, 0x22, 0x33]);

        assert_eq!(b1.len(), 3);
        assert_eq!(b1[0], 0x11);
        assert_eq!(b1[1], 0xAA ^ 0x22);
        assert_eq!(b1[2], 0xBB ^ 0x33);
    }

    #[test]
    fn test_and_assign_shorter_rhs() {
        let mut b1: ByteArray = [0xAA, 0xBB, 0xCC].into();
        b1 &= ByteArray::from([0x0F, 0xF0]);

        assert_eq!(b1.len(), 3);
        assert_eq!(b1[0], 0xAA);
        assert_eq!(b1[1], 0xBB & 0x0F);
        assert_eq!(b1[2], 0xCC & 0xF0);
    }

    #[test]
    fn test_or_assign_shorter_rhs() {
        let mut b1: ByteArray = [0x10, 0x20, 0x30].into();
        b1 |= ByteArray::from([0x01, 0x02]);

        assert_eq!(b1.len(), 3);
        assert_eq!(b1[0], 0x10);
        assert_eq!(b1[1], 0x20 | 0x01);
        assert_eq!(b1[2], 0x30 | 0x02);
    }

    #[test]
    fn test_not_empty() {
        let res = !ByteArray::default();

        assert_eq!(res.len(), 0);
    }

    #[test]
    fn test_double_not_is_identity() {
        let b1: ByteArray = [0x12, 0x34, 0x56].into();
        let res = !!b1;

        assert_eq!(res.len(), 3);
        assert_eq!(res.as_bytes(), [0x12, 0x34, 0x56]);
    }
}