use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::error::FerrayResult;
use crate::helpers::{binary_float_op, unary_float_op};
/// Marker trait for element types supporting the by-value bitwise operators
/// `&`, `|`, `^` and `!`, each of which must produce `Self` again.
///
/// `Copy` is required so element-wise kernels can pass values around freely.
pub trait BitwiseOps:
    Copy
    + std::ops::BitAnd<Output = Self>
    + std::ops::BitOr<Output = Self>
    + std::ops::BitXor<Output = Self>
    + std::ops::Not<Output = Self>
{
}
/// Marker trait for [`BitwiseOps`] types that can additionally be shifted
/// left and right by a `u32` amount, producing `Self`.
pub trait ShiftOps:
    std::ops::Shl<u32, Output = Self>
    + std::ops::Shr<u32, Output = Self>
    + BitwiseOps
{
}
// Expands to one empty `impl BitwiseOps for <type> {}` per comma-separated type.
macro_rules! impl_bitwise_ops {
    ( $( $t:ty ),* ) => {
        $(
            impl BitwiseOps for $t {}
        )*
    };
}
// Expands to one empty `impl ShiftOps for <type> {}` per comma-separated type.
macro_rules! impl_shift_ops {
    ( $( $t:ty ),* ) => {
        $(
            impl ShiftOps for $t {}
        )*
    };
}
// All primitive integers plus `bool` support `&`, `|`, `^` and `!`.
impl_bitwise_ops!(
    i8, i16, i32, i64, i128,
    u8, u16, u32, u64, u128,
    bool
);
// Shifts are only provided for the integers; `bool` is deliberately left out.
impl_shift_ops!(
    i8, i16, i32, i64, i128,
    u8, u16, u32, u64, u128
);
/// Element-wise bitwise AND of `a` and `b`.
///
/// # Errors
///
/// Returns whatever error `binary_float_op` reports for this pair of inputs
/// (presumably a shape mismatch — confirm in `crate::helpers`).
pub fn bitwise_and<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + BitwiseOps,
    D: Dimension,
{
    binary_float_op(a, b, <T as std::ops::BitAnd>::bitand)
}
/// Element-wise bitwise OR of `a` and `b`.
///
/// # Errors
///
/// Returns whatever error `binary_float_op` reports for this pair of inputs
/// (presumably a shape mismatch — confirm in `crate::helpers`).
pub fn bitwise_or<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + BitwiseOps,
    D: Dimension,
{
    binary_float_op(a, b, <T as std::ops::BitOr>::bitor)
}
/// Element-wise bitwise XOR of `a` and `b`.
///
/// # Errors
///
/// Returns whatever error `binary_float_op` reports for this pair of inputs
/// (presumably a shape mismatch — confirm in `crate::helpers`).
pub fn bitwise_xor<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + BitwiseOps,
    D: Dimension,
{
    binary_float_op(a, b, <T as std::ops::BitXor>::bitxor)
}
/// Element-wise bitwise complement (`!x`) of every element of `input`.
///
/// # Errors
///
/// Returns whatever error `unary_float_op` reports for `input`.
pub fn bitwise_not<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + BitwiseOps,
    D: Dimension,
{
    unary_float_op(input, <T as std::ops::Not>::not)
}
/// Alias for [`bitwise_not`]: element-wise bitwise complement of `input`.
///
/// # Errors
///
/// Exactly the errors of [`bitwise_not`].
pub fn invert<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + BitwiseOps,
    D: Dimension,
{
    bitwise_not::<T, D>(input)
}
/// Element-wise left shift: `out[i] = a[i] << b[i]`.
///
/// Shift amounts greater than or equal to the bit width of `T` are well
/// defined here: every bit is shifted out, so the result is `0`. The raw `<<`
/// operator would instead panic in debug builds ("attempt to shift left with
/// overflow") or silently mask the shift amount in release builds.
///
/// The bit width is taken as `8 * size_of::<T>()`, which is exact for the
/// primitive integers `ShiftOps` is implemented for in this module.
///
/// # Errors
///
/// Returns a shape-mismatch error if `a` and `b` do not have the same shape.
pub fn left_shift<T, D>(a: &Array<T, D>, b: &Array<u32, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + ShiftOps,
    D: Dimension,
{
    if a.shape() != b.shape() {
        return Err(ferray_core::error::FerrayError::shape_mismatch(format!(
            "left_shift: shapes {:?} and {:?} do not match",
            a.shape(),
            b.shape()
        )));
    }
    // Bit width of `T` (e.g. 32 for i32, 128 for u128); far below u32::MAX for
    // any primitive integer, so the cast cannot truncate.
    let bits = (std::mem::size_of::<T>() * 8) as u32;
    let data: Vec<T> = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &s)| {
            if s < bits {
                x << s
            } else {
                // All bits shifted out: zero, built generically as `x ^ x`.
                x ^ x
            }
        })
        .collect();
    Array::from_vec(a.dim().clone(), data)
}
/// Element-wise right shift: `out[i] = a[i] >> b[i]`.
///
/// Shift amounts greater than or equal to the bit width of `T` are well
/// defined here: the result is the fully shifted-out value. For signed types
/// (arithmetic shift) that is `-1` when the element is negative and `0`
/// otherwise; for unsigned types it is always `0`. The raw `>>` operator
/// would instead panic in debug builds ("attempt to shift right with
/// overflow") or silently mask the shift amount in release builds.
///
/// The bit width is taken as `8 * size_of::<T>()`, which is exact for the
/// primitive integers `ShiftOps` is implemented for in this module.
///
/// # Errors
///
/// Returns a shape-mismatch error if `a` and `b` do not have the same shape.
pub fn right_shift<T, D>(a: &Array<T, D>, b: &Array<u32, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + ShiftOps,
    D: Dimension,
{
    if a.shape() != b.shape() {
        return Err(ferray_core::error::FerrayError::shape_mismatch(format!(
            "right_shift: shapes {:?} and {:?} do not match",
            a.shape(),
            b.shape()
        )));
    }
    // Bit width of `T` (e.g. 32 for i32, 128 for u128); far below u32::MAX for
    // any primitive integer, so the cast cannot truncate.
    let bits = (std::mem::size_of::<T>() * 8) as u32;
    // Largest valid shift amount; saturating so a zero-width type cannot underflow.
    let top = bits.saturating_sub(1);
    let data: Vec<T> = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &s)| {
            if s < bits {
                x >> s
            } else {
                // `x >> (bits - 1)` leaves only copies of the top bit
                // (signed: 0 or -1, unsigned: 0 or 1). One more shift gives
                // the fully shifted-out value: -1 for negative signed, else 0.
                (x >> top) >> 1
            }
        })
        .collect();
    Array::from_vec(a.dim().clone(), data)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    // Build a 1-D i32 array whose length matches the supplied values.
    fn arr1_i32(values: Vec<i32>) -> Array<i32, Ix1> {
        let len = values.len();
        Array::from_vec(Ix1::new([len]), values).unwrap()
    }

    // Build a 1-D u32 array (used as shift amounts).
    fn arr1_u32(values: Vec<u32>) -> Array<u32, Ix1> {
        let len = values.len();
        Array::from_vec(Ix1::new([len]), values).unwrap()
    }

    // Build a 1-D u8 array (handy for checking complements bit by bit).
    fn arr1_u8(values: Vec<u8>) -> Array<u8, Ix1> {
        let len = values.len();
        Array::from_vec(Ix1::new([len]), values).unwrap()
    }

    #[test]
    fn test_bitwise_and() {
        let lhs = arr1_i32(vec![0b1100, 0b1010]);
        let rhs = arr1_i32(vec![0b1010, 0b1010]);
        let expected: [i32; 2] = [0b1000, 0b1010];
        let out = bitwise_and(&lhs, &rhs).unwrap();
        assert_eq!(out.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_bitwise_or() {
        let lhs = arr1_i32(vec![0b1100, 0b1010]);
        let rhs = arr1_i32(vec![0b1010, 0b0101]);
        let expected: [i32; 2] = [0b1110, 0b1111];
        let out = bitwise_or(&lhs, &rhs).unwrap();
        assert_eq!(out.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_bitwise_xor() {
        let lhs = arr1_i32(vec![0b1100, 0b1010]);
        let rhs = arr1_i32(vec![0b1010, 0b1010]);
        let expected: [i32; 2] = [0b0110, 0b0000];
        let out = bitwise_xor(&lhs, &rhs).unwrap();
        assert_eq!(out.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_bitwise_not() {
        let input = arr1_u8(vec![0b0000_1111]);
        let expected: [u8; 1] = [0b1111_0000];
        let out = bitwise_not(&input).unwrap();
        assert_eq!(out.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_invert() {
        let input = arr1_u8(vec![0b0000_1111]);
        let expected: [u8; 1] = [0b1111_0000];
        let out = invert(&input).unwrap();
        assert_eq!(out.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_left_shift() {
        let values = arr1_i32(vec![1, 2, 4]);
        let amounts = arr1_u32(vec![1, 2, 3]);
        let expected: [i32; 3] = [2, 8, 32];
        let out = left_shift(&values, &amounts).unwrap();
        assert_eq!(out.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_right_shift() {
        let values = arr1_i32(vec![8, 16, 32]);
        let amounts = arr1_u32(vec![1, 2, 3]);
        let expected: [i32; 3] = [4, 4, 4];
        let out = right_shift(&values, &amounts).unwrap();
        assert_eq!(out.as_slice().unwrap(), &expected);
    }
}