Skip to main content

ferray_ufunc/ops/
bitwise.rs

// ferray-ufunc: Bitwise functions
//
// bitwise_and, bitwise_or, bitwise_xor, bitwise_not, invert,
// left_shift, right_shift, bitwise_count
5
6use ferray_core::Array;
7use ferray_core::dimension::Dimension;
8use ferray_core::dtype::Element;
9use ferray_core::error::FerrayResult;
10
11use crate::helpers::{binary_elementwise_op, binary_mixed_op, unary_float_op};
12
/// Marker trait for element types supporting `&`, `|`, `^` and `!`.
///
/// Every operator must produce `Self`, and the type must be `Copy` so values
/// can be combined by value. The trait has no methods of its own; this module
/// implements it for the primitive integer types and `bool`.
pub trait BitwiseOps:
    Copy
    + std::ops::Not<Output = Self>
    + std::ops::BitAnd<Output = Self>
    + std::ops::BitOr<Output = Self>
    + std::ops::BitXor<Output = Self>
{
}

/// Marker trait for [`BitwiseOps`] types that can also be shifted left and
/// right by a `u32` amount, with the result staying `Self`.
///
/// `bool` does not qualify, since it has no shift operators.
pub trait ShiftOps:
    BitwiseOps + std::ops::Shr<u32, Output = Self> + std::ops::Shl<u32, Output = Self>
{
}
28
29macro_rules! impl_bitwise_ops {
30    ($($ty:ty),*) => {
31        $(impl BitwiseOps for $ty {})*
32    };
33}
34
35macro_rules! impl_shift_ops {
36    ($($ty:ty),*) => {
37        $(impl ShiftOps for $ty {})*
38    };
39}
40
41impl_bitwise_ops!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128, bool);
42impl_shift_ops!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);
43
44/// Elementwise bitwise AND.
45pub fn bitwise_and<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
46where
47    T: Element + BitwiseOps,
48    D: Dimension,
49{
50    binary_elementwise_op(a, b, |x, y| x & y)
51}
52
53/// Elementwise bitwise OR.
54pub fn bitwise_or<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
55where
56    T: Element + BitwiseOps,
57    D: Dimension,
58{
59    binary_elementwise_op(a, b, |x, y| x | y)
60}
61
62/// Elementwise bitwise XOR.
63pub fn bitwise_xor<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
64where
65    T: Element + BitwiseOps,
66    D: Dimension,
67{
68    binary_elementwise_op(a, b, |x, y| x ^ y)
69}
70
71/// Elementwise bitwise NOT.
72pub fn bitwise_not<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
73where
74    T: Element + BitwiseOps,
75    D: Dimension,
76{
77    unary_float_op(input, |x| !x)
78}
79
80/// Alias for [`bitwise_not`].
81pub fn invert<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
82where
83    T: Element + BitwiseOps,
84    D: Dimension,
85{
86    bitwise_not(input)
87}
88
89/// Elementwise left shift with `NumPy` broadcasting.
90///
91/// Each element of `a` is shifted left by the corresponding element of `b`.
92pub fn left_shift<T, D>(a: &Array<T, D>, b: &Array<u32, D>) -> FerrayResult<Array<T, D>>
93where
94    T: Element + ShiftOps,
95    D: Dimension,
96{
97    binary_mixed_op(a, b, |x, s| x << s)
98}
99
100/// Elementwise right shift with `NumPy` broadcasting.
101///
102/// Each element of `a` is shifted right by the corresponding element of `b`.
103pub fn right_shift<T, D>(a: &Array<T, D>, b: &Array<u32, D>) -> FerrayResult<Array<T, D>>
104where
105    T: Element + ShiftOps,
106    D: Dimension,
107{
108    binary_mixed_op(a, b, |x, s| x >> s)
109}
110
111/// Trait for integer types that expose `count_ones` (population count).
112///
113/// Implemented for every signed and unsigned integer width that
114/// `BitwiseOps` already covers, with `bool` mapping `true` -> 1 and
115/// `false` -> 0 to match `NumPy`'s `np.bitwise_count(np.bool_(...))`
116/// behaviour. The output type is always `u32` since the largest
117/// possible popcount over a 128-bit input is 128.
118pub trait BitwiseCount {
119    /// Number of set bits in the value.
120    fn bitwise_count(self) -> u32;
121}
122
123macro_rules! impl_bitwise_count {
124    ($($ty:ty),*) => {
125        $(
126            impl BitwiseCount for $ty {
127                #[inline]
128                fn bitwise_count(self) -> u32 {
129                    self.count_ones()
130                }
131            }
132        )*
133    };
134}
135
136impl_bitwise_count!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);
137
138impl BitwiseCount for bool {
139    #[inline]
140    fn bitwise_count(self) -> u32 {
141        u32::from(self)
142    }
143}
144
145/// Elementwise population count: number of set bits in each element.
146///
147/// Mirrors `NumPy` 2.0's `numpy.bitwise_count` (issue #396). Routes to
148/// the underlying integer's `count_ones` intrinsic, which compiles to
149/// `POPCNT` on x86 and the equivalent instruction on ARM/RISC-V where
150/// available.
151pub fn bitwise_count<T, D>(input: &Array<T, D>) -> FerrayResult<Array<u32, D>>
152where
153    T: Element + BitwiseCount + Copy,
154    D: Dimension,
155{
156    let data: Vec<u32> = input.iter().map(|&x| x.bitwise_count()).collect();
157    Array::from_vec(input.dim().clone(), data)
158}
159
160#[cfg(test)]
161mod tests {
162    use super::*;
163    use ferray_core::dimension::Ix1;
164
165    fn arr1_i32(data: Vec<i32>) -> Array<i32, Ix1> {
166        let n = data.len();
167        Array::from_vec(Ix1::new([n]), data).unwrap()
168    }
169
170    fn arr1_u32(data: Vec<u32>) -> Array<u32, Ix1> {
171        let n = data.len();
172        Array::from_vec(Ix1::new([n]), data).unwrap()
173    }
174
175    fn arr1_u8(data: Vec<u8>) -> Array<u8, Ix1> {
176        let n = data.len();
177        Array::from_vec(Ix1::new([n]), data).unwrap()
178    }
179
180    #[test]
181    fn test_bitwise_and() {
182        let a = arr1_i32(vec![0b1100, 0b1010]);
183        let b = arr1_i32(vec![0b1010, 0b1010]);
184        let r = bitwise_and(&a, &b).unwrap();
185        assert_eq!(r.as_slice().unwrap(), &[0b1000, 0b1010]);
186    }
187
188    #[test]
189    fn test_bitwise_or() {
190        let a = arr1_i32(vec![0b1100, 0b1010]);
191        let b = arr1_i32(vec![0b1010, 0b0101]);
192        let r = bitwise_or(&a, &b).unwrap();
193        assert_eq!(r.as_slice().unwrap(), &[0b1110, 0b1111]);
194    }
195
196    #[test]
197    fn test_bitwise_xor() {
198        let a = arr1_i32(vec![0b1100, 0b1010]);
199        let b = arr1_i32(vec![0b1010, 0b1010]);
200        let r = bitwise_xor(&a, &b).unwrap();
201        assert_eq!(r.as_slice().unwrap(), &[0b0110, 0b0000]);
202    }
203
204    #[test]
205    fn test_bitwise_not() {
206        let a = arr1_u8(vec![0b0000_1111]);
207        let r = bitwise_not(&a).unwrap();
208        assert_eq!(r.as_slice().unwrap(), &[0b1111_0000]);
209    }
210
211    #[test]
212    fn test_invert() {
213        let a = arr1_u8(vec![0b0000_1111]);
214        let r = invert(&a).unwrap();
215        assert_eq!(r.as_slice().unwrap(), &[0b1111_0000]);
216    }
217
218    // ----- bitwise_count (#396) -----
219
220    #[test]
221    fn test_bitwise_count_u8() {
222        let a = arr1_u8(vec![0u8, 1, 0b0000_1111, 0b1111_1111, 0b1010_0101]);
223        let r = bitwise_count(&a).unwrap();
224        assert_eq!(r.as_slice().unwrap(), &[0u32, 1, 4, 8, 4]);
225    }
226
227    #[test]
228    fn test_bitwise_count_i32_negative() {
229        // -1 as i32 is all-bits-set: 32 ones.
230        let a = arr1_i32(vec![-1, 0, 1, 7]);
231        let r = bitwise_count(&a).unwrap();
232        assert_eq!(r.as_slice().unwrap(), &[32u32, 0, 1, 3]);
233    }
234
235    #[test]
236    fn test_bitwise_count_u32() {
237        let a = arr1_u32(vec![0u32, 1, 0xFFFF_FFFF, 0xFF00_FF00]);
238        let r = bitwise_count(&a).unwrap();
239        assert_eq!(r.as_slice().unwrap(), &[0u32, 1, 32, 16]);
240    }
241
242    #[test]
243    fn test_bitwise_count_bool() {
244        let n = 4;
245        let a = Array::from_vec(Ix1::new([n]), vec![true, false, true, true]).unwrap();
246        let r = bitwise_count(&a).unwrap();
247        assert_eq!(r.as_slice().unwrap(), &[1u32, 0, 1, 1]);
248    }
249
250    #[test]
251    fn test_left_shift() {
252        let a = arr1_i32(vec![1, 2, 4]);
253        let s = arr1_u32(vec![1, 2, 3]);
254        let r = left_shift(&a, &s).unwrap();
255        assert_eq!(r.as_slice().unwrap(), &[2, 8, 32]);
256    }
257
258    #[test]
259    fn test_right_shift() {
260        let a = arr1_i32(vec![8, 16, 32]);
261        let s = arr1_u32(vec![1, 2, 3]);
262        let r = right_shift(&a, &s).unwrap();
263        assert_eq!(r.as_slice().unwrap(), &[4, 4, 4]);
264    }
265
266    // -----------------------------------------------------------------------
267    // Broadcasting tests for bitwise ops (issue #379)
268    // -----------------------------------------------------------------------
269
270    #[test]
271    fn test_bitwise_and_broadcasts() {
272        use ferray_core::dimension::Ix2;
273        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 1]), vec![0xFF, 0x0F]).unwrap();
274        let b = Array::<i32, Ix2>::from_vec(Ix2::new([1, 2]), vec![0x0F, 0xF0]).unwrap();
275        let r = bitwise_and(&a, &b).unwrap();
276        assert_eq!(r.shape(), &[2, 2]);
277        // 0xFF & 0x0F = 0x0F, 0xFF & 0xF0 = 0xF0, 0x0F & 0x0F = 0x0F, 0x0F & 0xF0 = 0x00
278        assert_eq!(
279            r.iter().copied().collect::<Vec<_>>(),
280            vec![0x0F, 0xF0, 0x0F, 0x00]
281        );
282    }
283
284    #[test]
285    fn test_bitwise_or_broadcasts() {
286        use ferray_core::dimension::Ix2;
287        let a = Array::<u8, Ix2>::from_vec(Ix2::new([2, 1]), vec![0b1010, 0b0101]).unwrap();
288        let b = Array::<u8, Ix2>::from_vec(Ix2::new([1, 2]), vec![0b0011, 0b1100]).unwrap();
289        let r = bitwise_or(&a, &b).unwrap();
290        assert_eq!(r.shape(), &[2, 2]);
291        assert_eq!(
292            r.iter().copied().collect::<Vec<_>>(),
293            vec![0b1011, 0b1110, 0b0111, 0b1101]
294        );
295    }
296
297    #[test]
298    fn test_left_shift_broadcasts() {
299        use ferray_core::dimension::Ix2;
300        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 1]), vec![1, 2]).unwrap();
301        let s = Array::<u32, Ix2>::from_vec(Ix2::new([1, 3]), vec![0, 1, 2]).unwrap();
302        let r = left_shift(&a, &s).unwrap();
303        assert_eq!(r.shape(), &[2, 3]);
304        // 1 << {0,1,2} = {1,2,4}, 2 << {0,1,2} = {2,4,8}
305        assert_eq!(
306            r.iter().copied().collect::<Vec<_>>(),
307            vec![1, 2, 4, 2, 4, 8]
308        );
309    }
310
311    #[test]
312    fn test_right_shift_broadcasts() {
313        use ferray_core::dimension::Ix2;
314        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 1]), vec![16, 64]).unwrap();
315        let s = Array::<u32, Ix2>::from_vec(Ix2::new([1, 3]), vec![0, 1, 2]).unwrap();
316        let r = right_shift(&a, &s).unwrap();
317        assert_eq!(r.shape(), &[2, 3]);
318        assert_eq!(
319            r.iter().copied().collect::<Vec<_>>(),
320            vec![16, 8, 4, 64, 32, 16]
321        );
322    }
323}