use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use std::ops::{BitAnd, BitOr, BitXor, Not, Shl, Shr};
/// Computes the element-wise bitwise AND of two arrays that have the same shape.
///
/// The elements of both inputs are copied out with `to_vec`, combined pairwise
/// with `&`, and the combined values are reshaped to the shape of `x1`.
///
/// # Errors
///
/// Returns [`NumRs2Error::ShapeMismatch`] (with `expected` = `x1`'s shape and
/// `actual` = `x2`'s shape) when the two shapes differ. No broadcasting is done.
pub fn bitwise_and<T>(x1: &Array<T>, x2: &Array<T>) -> Result<Array<T>>
where
    T: Clone + BitAnd<Output = T>,
{
    let lhs_shape = x1.shape();
    let rhs_shape = x2.shape();
    if lhs_shape != rhs_shape {
        return Err(NumRs2Error::ShapeMismatch {
            expected: lhs_shape,
            actual: rhs_shape,
        });
    }

    let combined = x1
        .to_vec()
        .into_iter()
        .zip(x2.to_vec())
        .map(|(lhs, rhs)| lhs & rhs)
        .collect::<Vec<T>>();
    Ok(Array::from_vec(combined).reshape(&lhs_shape))
}
/// Computes the element-wise bitwise OR of two arrays that have the same shape.
///
/// The elements of both inputs are copied out with `to_vec`, combined pairwise
/// with `|`, and the combined values are reshaped to the shape of `x1`.
///
/// # Errors
///
/// Returns [`NumRs2Error::ShapeMismatch`] (with `expected` = `x1`'s shape and
/// `actual` = `x2`'s shape) when the two shapes differ. No broadcasting is done.
pub fn bitwise_or<T>(x1: &Array<T>, x2: &Array<T>) -> Result<Array<T>>
where
    T: Clone + BitOr<Output = T>,
{
    let (left_shape, right_shape) = (x1.shape(), x2.shape());
    if left_shape != right_shape {
        return Err(NumRs2Error::ShapeMismatch {
            expected: left_shape,
            actual: right_shape,
        });
    }

    let merged = x1
        .to_vec()
        .into_iter()
        .zip(x2.to_vec())
        .map(|(left, right)| left | right)
        .collect::<Vec<T>>();
    Ok(Array::from_vec(merged).reshape(&left_shape))
}
/// Computes the element-wise bitwise XOR of two arrays that have the same shape.
///
/// The elements of both inputs are copied out with `to_vec`, combined pairwise
/// with `^`, and the combined values are reshaped to the shape of `x1`.
///
/// # Errors
///
/// Returns [`NumRs2Error::ShapeMismatch`] (with `expected` = `x1`'s shape and
/// `actual` = `x2`'s shape) when the two shapes differ. No broadcasting is done.
pub fn bitwise_xor<T>(x1: &Array<T>, x2: &Array<T>) -> Result<Array<T>>
where
    T: Clone + BitXor<Output = T>,
{
    let first_shape = x1.shape();
    let second_shape = x2.shape();
    if first_shape != second_shape {
        return Err(NumRs2Error::ShapeMismatch {
            expected: first_shape,
            actual: second_shape,
        });
    }

    let toggled = x1
        .to_vec()
        .into_iter()
        .zip(x2.to_vec())
        .map(|(first, second)| first ^ second)
        .collect::<Vec<T>>();
    Ok(Array::from_vec(toggled).reshape(&first_shape))
}
/// Applies the `!` operator to every element of `x` and returns a new array.
///
/// The work is delegated to `Array::map`. For unsigned integers this flips
/// every bit, e.g. `13u8` becomes `242u8`.
pub fn bitwise_not<T>(x: &Array<T>) -> Array<T>
where
    T: Clone + Not<Output = T>,
{
    x.map(|element| element.not())
}
/// Alternative name for [`bitwise_not`]: returns a new array with `!` applied
/// to every element of `x`.
///
/// This function simply forwards to [`bitwise_not`], so the two always agree.
pub fn invert<T>(x: &Array<T>) -> Array<T>
where
    T: Clone + Not<Output = T>,
{
    bitwise_not::<T>(x)
}
/// Shifts each element of `x1` left by the matching element of `x2`.
///
/// The shift amounts may have a different element type `U` than the values,
/// as long as `T: Shl<U, Output = T>`. The result is reshaped to `x1`'s shape.
///
/// # Errors
///
/// Returns [`NumRs2Error::ShapeMismatch`] (with `expected` = `x1`'s shape and
/// `actual` = `x2`'s shape) when the two shapes differ. No broadcasting is done.
///
/// # Panics
///
/// Behaviour is whatever `T`'s `Shl` impl does. For primitive integers:
/// - With overflow checks enabled (the debug-build default), Rust's built-in
///   `<<` panics when a shift amount is not smaller than the bit width of `T`.
/// - With overflow checks disabled, the amount is masked instead.
pub fn left_shift<T, U>(x1: &Array<T>, x2: &Array<U>) -> Result<Array<T>>
where
    T: Clone + Shl<U, Output = T>,
    U: Clone,
{
    let (value_shape, amount_shape) = (x1.shape(), x2.shape());
    if value_shape != amount_shape {
        return Err(NumRs2Error::ShapeMismatch {
            expected: value_shape,
            actual: amount_shape,
        });
    }

    let shifted = x1
        .to_vec()
        .into_iter()
        .zip(x2.to_vec())
        .map(|(value, amount)| value.shl(amount))
        .collect::<Vec<T>>();
    Ok(Array::from_vec(shifted).reshape(&value_shape))
}
/// Shifts each element of `x1` right by the matching element of `x2`.
///
/// The shift amounts may have a different element type `U` than the values,
/// as long as `T: Shr<U, Output = T>`. The result is reshaped to `x1`'s shape.
///
/// # Errors
///
/// Returns [`NumRs2Error::ShapeMismatch`] (with `expected` = `x1`'s shape and
/// `actual` = `x2`'s shape) when the two shapes differ. No broadcasting is done.
///
/// # Panics
///
/// Behaviour is whatever `T`'s `Shr` impl does. For primitive integers:
/// - With overflow checks enabled (the debug-build default), Rust's built-in
///   `>>` panics when a shift amount is not smaller than the bit width of `T`.
/// - With overflow checks disabled, the amount is masked instead.
pub fn right_shift<T, U>(x1: &Array<T>, x2: &Array<U>) -> Result<Array<T>>
where
    T: Clone + Shr<U, Output = T>,
    U: Clone,
{
    let (value_shape, amount_shape) = (x1.shape(), x2.shape());
    if value_shape != amount_shape {
        return Err(NumRs2Error::ShapeMismatch {
            expected: value_shape,
            actual: amount_shape,
        });
    }

    let shifted = x1
        .to_vec()
        .into_iter()
        .zip(x2.to_vec())
        .map(|(value, amount)| value.shr(amount))
        .collect::<Vec<T>>();
    Ok(Array::from_vec(shifted).reshape(&value_shape))
}
/// Shifts every element of `x` left by the same amount `shift`.
///
/// `shift` is cloned once per element because `Shl` consumes its right-hand
/// operand. For primitive integers, the same overflow rules as [`left_shift`]
/// apply: a panic when overflow checks are on and `shift` is not smaller than
/// the bit width.
pub fn left_shift_scalar<T, U>(x: &Array<T>, shift: U) -> Array<T>
where
    T: Clone + Shl<U, Output = T>,
    U: Clone,
{
    x.map(|element| Shl::shl(element, shift.clone()))
}
/// Shifts every element of `x` right by the same amount `shift`.
///
/// `shift` is cloned once per element because `Shr` consumes its right-hand
/// operand. For primitive integers, the same overflow rules as [`right_shift`]
/// apply: a panic when overflow checks are on and `shift` is not smaller than
/// the bit width.
pub fn right_shift_scalar<T, U>(x: &Array<T>, shift: U) -> Array<T>
where
    T: Clone + Shr<U, Output = T>,
    U: Clone,
{
    x.map(|element| Shr::shr(element, shift.clone()))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bitwise_and() {
        let a = Array::from_vec(vec![13, 17, 21]);
        let b = Array::from_vec(vec![9, 7, 15]);
        let result = bitwise_and(&a, &b).expect("bitwise_and should succeed");
        assert_eq!(result.to_vec(), vec![9, 1, 5]);
    }

    #[test]
    fn test_bitwise_or() {
        let a = Array::from_vec(vec![13, 17, 21]);
        let b = Array::from_vec(vec![9, 7, 15]);
        let result = bitwise_or(&a, &b).expect("bitwise_or should succeed");
        assert_eq!(result.to_vec(), vec![13, 23, 31]);
    }

    #[test]
    fn test_bitwise_xor() {
        let a = Array::from_vec(vec![13, 17, 21]);
        let b = Array::from_vec(vec![9, 7, 15]);
        let result = bitwise_xor(&a, &b).expect("bitwise_xor should succeed");
        assert_eq!(result.to_vec(), vec![4, 22, 26]);
    }

    #[test]
    fn test_bitwise_not() {
        let a = Array::from_vec(vec![13u8, 17u8, 21u8]);
        let result = bitwise_not(&a);
        assert_eq!(result.to_vec(), vec![242u8, 238u8, 234u8]);
    }

    /// `invert` is documented as an alias, so it must agree with `bitwise_not`.
    #[test]
    fn test_invert_matches_bitwise_not() {
        let a = Array::from_vec(vec![0u8, 13u8, 255u8]);
        assert_eq!(invert(&a).to_vec(), vec![255u8, 242u8, 0u8]);
        assert_eq!(invert(&a).to_vec(), bitwise_not(&a).to_vec());
    }

    #[test]
    fn test_left_shift() {
        let a = Array::from_vec(vec![5, 10, 15]);
        let shift = Array::from_vec(vec![1, 2, 3]);
        let result = left_shift(&a, &shift).expect("left_shift should succeed");
        assert_eq!(result.to_vec(), vec![10, 40, 120]);
    }

    #[test]
    fn test_right_shift() {
        let a = Array::from_vec(vec![40, 80, 120]);
        let shift = Array::from_vec(vec![1, 2, 3]);
        let result = right_shift(&a, &shift).expect("right_shift should succeed");
        assert_eq!(result.to_vec(), vec![20, 20, 15]);
    }

    #[test]
    fn test_shift_scalar() {
        let a = Array::from_vec(vec![5, 10, 15]);
        let left_result = left_shift_scalar(&a, 2);
        assert_eq!(left_result.to_vec(), vec![20, 40, 60]);
        let right_result = right_shift_scalar(&left_result, 2);
        assert_eq!(right_result.to_vec(), vec![5, 10, 15]);
    }

    /// Every binary operation shares the same shape guard. Check that each one
    /// rejects mismatched inputs with the `ShapeMismatch` variant
    /// specifically, not just any error.
    #[test]
    fn test_shape_mismatch() {
        let a = Array::from_vec(vec![1, 2, 3]);
        let b = Array::from_vec(vec![1, 2]);
        assert!(matches!(
            bitwise_and(&a, &b),
            Err(NumRs2Error::ShapeMismatch { .. })
        ));
        assert!(matches!(
            bitwise_or(&a, &b),
            Err(NumRs2Error::ShapeMismatch { .. })
        ));
        assert!(matches!(
            bitwise_xor(&a, &b),
            Err(NumRs2Error::ShapeMismatch { .. })
        ));
        assert!(matches!(
            left_shift(&a, &b),
            Err(NumRs2Error::ShapeMismatch { .. })
        ));
        assert!(matches!(
            right_shift(&a, &b),
            Err(NumRs2Error::ShapeMismatch { .. })
        ));
    }
}