use crate::bytes;
use crate::model::ByteArray;
use core::cmp::max;
use core::mem;
use core::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not};
use subtle::{Choice, ConditionallySelectable};
/// Shared implementation behind the binary bitwise operators on [`ByteArray`].
#[doc(hidden)]
pub(super) struct InternalOps;

impl InternalOps {
    /// Combines two right-aligned byte arrays position-by-position with `byte_op`.
    ///
    /// The shorter operand behaves as if it were left-padded with `identity`
    /// bytes. The result therefore has the length of the longer operand, and is
    /// empty when both operands are empty. Every output byte is seeded with
    /// `identity` before both operands are folded in. Callers pass the
    /// operator's identity element: `0x00` for XOR/OR, `0xFF` for AND.
    ///
    /// Whether a position reads a real operand byte or `identity` is resolved
    /// with `u8::conditional_select`, not an `if` on the byte value.
    /// NOTE(review): the padding decision itself still branches on lengths and
    /// positions, not on secret byte contents.
    fn base_op(
        lhs: ByteArray,
        rhs: ByteArray,
        byte_op: impl Fn(u8, u8) -> u8,
        identity: u8,
    ) -> ByteArray {
        // Length of the answer, measured before any padding is applied.
        let out_len = max(lhs.len(), rhs.len());

        // An empty operand is swapped for one identity byte so that the index
        // clamp inside `pick` (`len - 1`) can never underflow.
        let left = if lhs.is_empty() {
            bytes![identity; 1]
        } else {
            lhs
        };
        let right = if rhs.is_empty() {
            bytes![identity; 1]
        } else {
            rhs
        };

        let width = max(left.len(), right.len());
        let left_pad = width - left.len();
        let right_pad = width - right.len();

        // Byte contributed by `arr` at output position `pos` when `arr` begins
        // `pad` slots into the window; earlier positions yield `identity`.
        // The read index is clamped so the lookup is always in bounds.
        let pick = |arr: &ByteArray, pad: usize, pos: usize| -> u8 {
            let present = pos >= pad && pos - pad < arr.len();
            let read_at = pos.saturating_sub(pad).min(arr.len() - 1);
            u8::conditional_select(&identity, &arr[read_at], Choice::from(u8::from(present)))
        };

        let mut acc = bytes![identity; width];
        for pos in 0..width {
            let from_left = pick(&left, left_pad, pos);
            let from_right = pick(&right, right_pad, pos);
            acc[pos] = byte_op(byte_op(acc[pos], from_left), from_right);
        }

        // Only shrinks the both-empty case, where `width` is 1 but `out_len` is 0.
        acc.truncate(out_len);
        acc
    }

    /// XOR of two byte arrays; missing leading bytes count as `0x00`.
    #[inline]
    fn xor_op(lhs: ByteArray, rhs: ByteArray) -> ByteArray {
        InternalOps::base_op(lhs, rhs, |a, b| a ^ b, 0x00)
    }

    /// AND of two byte arrays; missing leading bytes count as `0xFF`.
    #[inline]
    fn and_op(lhs: ByteArray, rhs: ByteArray) -> ByteArray {
        InternalOps::base_op(lhs, rhs, |a, b| a & b, 0xFF)
    }

    /// OR of two byte arrays; missing leading bytes count as `0x00`.
    #[inline]
    fn or_op(lhs: ByteArray, rhs: ByteArray) -> ByteArray {
        InternalOps::base_op(lhs, rhs, |a, b| a | b, 0x00)
    }
}
/// `a ^ b`: byte-wise XOR with both operands right-aligned. The shorter one is
/// treated as left-padded with `0x00`, and the result is as long as the longer.
impl BitXor for ByteArray {
    type Output = ByteArray;

    fn bitxor(self, other: Self) -> ByteArray {
        InternalOps::xor_op(self, other)
    }
}
/// `a ^= b`: replaces `self` with `self ^ rhs`.
impl BitXorAssign for ByteArray {
    fn bitxor_assign(&mut self, other: Self) {
        // Move the current value out (leaving a default in place) so the
        // by-value helper can consume it.
        let current = mem::take(self);
        *self = InternalOps::xor_op(current, other);
    }
}
/// `a & b`: byte-wise AND with both operands right-aligned. The shorter one is
/// treated as left-padded with `0xFF`, and the result is as long as the longer.
impl BitAnd for ByteArray {
    type Output = ByteArray;

    fn bitand(self, other: Self) -> ByteArray {
        InternalOps::and_op(self, other)
    }
}
/// `a &= b`: replaces `self` with `self & rhs`.
impl BitAndAssign for ByteArray {
    fn bitand_assign(&mut self, other: Self) {
        // Move the current value out (leaving a default in place) so the
        // by-value helper can consume it.
        let current = mem::take(self);
        *self = InternalOps::and_op(current, other);
    }
}
/// `a | b`: byte-wise OR with both operands right-aligned. The shorter one is
/// treated as left-padded with `0x00`, and the result is as long as the longer.
impl BitOr for ByteArray {
    type Output = ByteArray;

    fn bitor(self, other: Self) -> ByteArray {
        InternalOps::or_op(self, other)
    }
}
/// `a |= b`: replaces `self` with `self | rhs`.
impl BitOrAssign for ByteArray {
    fn bitor_assign(&mut self, other: Self) {
        // Move the current value out (leaving a default in place) so the
        // by-value helper can consume it.
        let current = mem::take(self);
        *self = InternalOps::or_op(current, other);
    }
}
/// `!a`: flips every bit of every byte; the length is unchanged.
impl Not for ByteArray {
    type Output = ByteArray;

    fn not(self) -> ByteArray {
        // XOR with an all-ones byte is the bitwise complement of a `u8`.
        let flipped = self.bytes.iter().copied().map(|byte| byte ^ 0xFF).collect();
        ByteArray { bytes: flipped }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` holds exactly the bytes of `expected`, in order.
    fn assert_bytes(actual: &ByteArray, expected: &[u8]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (pos, &want) in expected.iter().enumerate() {
            assert_eq!(actual[pos], want, "byte {} differs", pos);
        }
    }

    // ---- XOR ----

    #[test]
    fn test_xor_simple() {
        let left: ByteArray = [0xAA, 0xBB, 0xCC].into();
        let right = ByteArray::from([0x55, 0x44, 0x33]);
        let out = right ^ left;
        assert_bytes(&out, &[0xAA ^ 0x55, 0xBB ^ 0x44, 0xCC ^ 0x33]);
    }

    #[test]
    fn test_xor_unequal_length() {
        let short: ByteArray = [0xAA, 0xBB].into();
        let long = ByteArray::from([0x11, 0x22, 0x33]);
        let out = short ^ long;
        assert_bytes(&out, &[0x11, 0xAA ^ 0x22, 0xBB ^ 0x33]);
    }

    #[test]
    fn test_xor_single_byte_right_aligned() {
        let long = ByteArray::from([0x12, 0x35, 0x56]);
        let single = ByteArray::from(0xFF);
        let out = long ^ single;
        assert_bytes(&out, &[0x12, 0x35, 0xFF ^ 0x56]);
    }

    #[test]
    fn test_xor_assign() {
        let mut acc: ByteArray = [0xAA, 0xBB, 0xCC].into();
        acc ^= ByteArray::from([0x55, 0x44, 0x33]);
        assert_bytes(&acc, &[0xAA ^ 0x55, 0xBB ^ 0x44, 0xCC ^ 0x33]);
    }

    // ---- AND ----

    #[test]
    fn test_and_simple() {
        let left: ByteArray = [0xFF, 0xAA, 0x55].into();
        let right = ByteArray::from([0x0F, 0xF0, 0x33]);
        let out = left & right;
        assert_bytes(&out, &[0xFF & 0x0F, 0xAA & 0xF0, 0x55 & 0x33]);
    }

    #[test]
    fn test_and_unequal_length() {
        let short: ByteArray = [0xAA, 0xBB].into();
        let long = ByteArray::from([0x11, 0x22, 0x33]);
        let out = short & long;
        assert_bytes(&out, &[0xFF & 0x11, 0xAA & 0x22, 0xBB & 0x33]);
    }

    #[test]
    fn test_and_single_byte() {
        let long = ByteArray::from([0xFF, 0xAA, 0x55]);
        let single = ByteArray::from(0x0F);
        let out = long & single;
        assert_bytes(&out, &[0xFF, 0xFF & 0xAA, 0x55 & 0x0F]);
    }

    #[test]
    fn test_and_assign() {
        let mut acc: ByteArray = [0xFF, 0xAA, 0x55].into();
        acc &= ByteArray::from([0x0F, 0xF0, 0x33]);
        assert_bytes(&acc, &[0xFF & 0x0F, 0xAA & 0xF0, 0x55 & 0x33]);
    }

    // ---- OR ----

    #[test]
    fn test_or_simple() {
        let left: ByteArray = [0x0F, 0xAA, 0x55].into();
        let right = ByteArray::from([0xF0, 0x55, 0xAA]);
        let out = left | right;
        assert_bytes(&out, &[0x0F | 0xF0, 0xAA | 0x55, 0x55 | 0xAA]);
    }

    #[test]
    fn test_or_unequal_length() {
        let short: ByteArray = [0xAA, 0xBB].into();
        let long = ByteArray::from([0x11, 0x22, 0x33]);
        let out = short | long;
        assert_bytes(&out, &[0x00 | 0x11, 0xAA | 0x22, 0xBB | 0x33]);
    }

    #[test]
    fn test_or_single_byte() {
        let long = ByteArray::from([0x10, 0x20, 0x30]);
        let single = ByteArray::from(0x0F);
        let out = long | single;
        assert_bytes(&out, &[0x10, 0x20, 0x30 | 0x0F]);
    }

    #[test]
    fn test_or_assign() {
        let mut acc: ByteArray = [0x0F, 0xAA, 0x55].into();
        acc |= ByteArray::from([0xF0, 0x55, 0xAA]);
        assert_bytes(&acc, &[0x0F | 0xF0, 0xAA | 0x55, 0x55 | 0xAA]);
    }

    // ---- NOT ----

    #[test]
    fn test_not_simple() {
        let input: ByteArray = [0xFF, 0x00, 0xAA].into();
        let out = !input;
        assert_bytes(&out, &[!0xFF, !0x00, !0xAA]);
    }

    #[test]
    fn test_not_all_ones() {
        let input: ByteArray = [0xFF, 0xFF, 0xFF].into();
        let out = !input;
        assert_eq!(out.len(), 3);
        assert_eq!(out.as_bytes(), [0x00; 3]);
    }

    #[test]
    fn test_not_all_zeros() {
        let input: ByteArray = [0x00, 0x00, 0x00].into();
        let out = !input;
        assert_eq!(out.len(), 3);
        assert_eq!(out.as_bytes(), [0xFF; 3]);
    }

    #[test]
    fn test_not_single_byte() {
        let out = !ByteArray::from(0x55);
        assert_bytes(&out, &[0xAA]);
    }

    // ---- Combinations ----

    #[test]
    fn test_combined_operations() {
        let first: ByteArray = [0xFF, 0x00].into();
        let second: ByteArray = [0xF0, 0x0F].into();
        let mask: ByteArray = [0x55, 0xAA].into();
        let out = (first ^ second) & mask;
        assert_bytes(&out, &[(0xFF ^ 0xF0) & 0x55, (0x00 ^ 0x0F) & 0xAA]);
    }

    #[test]
    fn test_chained_xor_assign() {
        let mut acc: ByteArray = [0xFF, 0xFF].into();
        let step_one: ByteArray = [0x0F, 0xF0].into();
        let step_two: ByteArray = [0x11, 0x22].into();
        acc ^= step_one;
        acc ^= step_two;
        assert_eq!(acc[0], 0xFF ^ 0x0F ^ 0x11);
        assert_eq!(acc[1], 0xFF ^ 0xF0 ^ 0x22);
    }
}