Skip to main content

ferray_stride_tricks/
lib.rs

//! ferray-stride-tricks: low-level view construction via custom strides.
//!
//! Implements the equivalent of NumPy's `numpy.lib.stride_tricks`: sliding-window
//! views, broadcast views, and the safe/unsafe `as_strided` family.
//!
//! # Overview
//!
//! | Function               | Safety | Copies data? |
//! |------------------------|--------|--------------|
//! | `sliding_window_view`  | safe   | no           |
//! | `broadcast_to`         | safe   | no           |
//! | `broadcast_arrays`     | safe   | no           |
//! | `broadcast_shapes`     | safe   | n/a          |
//! | `as_strided`           | safe   | no           |
//! | `as_strided_unchecked` | unsafe | no           |
//!
//! `as_strided` returns an error for stride layouts in which elements overlap.
//! `as_strided_unchecked` accepts such layouts but must be called from an
//! `unsafe` block.
16
17pub mod as_strided;
18pub mod broadcast;
19pub mod overlap_check;
20pub mod sliding_window;
21
22// Re-export primary public functions at crate root for ergonomics.
23
24pub use as_strided::{as_strided, as_strided_unchecked};
25pub use broadcast::{broadcast_arrays, broadcast_shapes, broadcast_to};
26pub use sliding_window::sliding_window_view;
27
#[cfg(test)]
mod integration_tests {
    //! Integration tests for the crate-root re-exports.
    //!
    //! AC-1 through AC-5 each have a dedicated test below. AC-6 ("the test
    //! suite passes and `cargo clippy` is clean") is met by this module passing
    //! and is otherwise enforced outside the code (CI / manual clippy run).
    //! The remaining tests pin extra behavior: the data each returned view
    //! reads and the fact that it shares its source's buffer.

    use ferray_core::Array;
    use ferray_core::dimension::{Ix1, Ix2};

    use crate::{
        as_strided, as_strided_unchecked, broadcast_arrays, broadcast_shapes, broadcast_to,
        sliding_window_view,
    };

    // AC-1: sliding_window_view(&[1,2,3,4,5], (3,)) yields the windows
    //       [[1,2,3], [2,3,4], [3,4,5]].
    #[test]
    fn ac1_sliding_window_view() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let v = sliding_window_view(&a, &[3]).unwrap();
        assert_eq!(v.shape(), &[3, 3]);
        let data: Vec<i32> = v.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3, 2, 3, 4, 3, 4, 5]);
    }

    // AC-2: broadcast_to(&ones(3), (4, 3)) yields a (4, 3) view without
    //       allocating a new buffer.
    #[test]
    fn ac2_broadcast_to_no_alloc() {
        let a = Array::<f64, Ix1>::ones(Ix1::new([3])).unwrap();
        let v = broadcast_to(&a, &[4, 3]).unwrap();
        assert_eq!(v.shape(), &[4, 3]);
        assert_eq!(v.size(), 12);
        // The view starts at the source array's base pointer, i.e. it reads
        // the original buffer rather than a copy.
        assert_eq!(v.as_ptr(), a.as_ptr());
        let data: Vec<f64> = v.iter().copied().collect();
        assert_eq!(data, vec![1.0; 12]);
    }

    // AC-3: broadcast_shapes(&[(3,1), (1,4)]) returns (3, 4).
    #[test]
    fn ac3_broadcast_shapes() {
        let result = broadcast_shapes(&[&[3, 1][..], &[1, 4][..]]).unwrap();
        assert_eq!(result, vec![3, 4]);
    }

    // AC-4: as_strided succeeds for non-overlapping strides and returns Err
    //       for overlapping ones.
    #[test]
    fn ac4_as_strided_safe_vs_overlapping() {
        let a =
            Array::<f64, Ix1>::from_vec(Ix1::new([6]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        // Non-overlapping: a 2x3 view with ordinary row-major strides [3, 1].
        let v = as_strided(&a, &[2, 3], &[3, 1]).unwrap();
        assert_eq!(v.shape(), &[2, 3]);
        let data: Vec<f64> = v.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        // Overlapping: strides [1, 1] (the sliding-window layout) map several
        // logical indices to the same element, so the safe API must refuse it.
        let a5 = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        assert!(as_strided(&a5, &[3, 3], &[1, 1]).is_err());
    }

    // AC-5: as_strided_unchecked can only be called inside `unsafe`.
    // This is enforced by its `unsafe fn` signature: removing the `unsafe`
    // block below turns the call into a compile error.
    #[test]
    fn ac5_as_strided_unchecked_requires_unsafe() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        // SAFETY: with shape [3, 3] and strides [1, 1] the largest element
        // offset is 2 * 1 + 2 * 1 = 4, inside `a`'s 5-element buffer. The
        // resulting overlapping view is only read, never written through, and
        // `a` outlives it.
        let v = unsafe { as_strided_unchecked(&a, &[3, 3], &[1, 1]) }.unwrap();
        assert_eq!(v.shape(), &[3, 3]);
        let data: Vec<i32> = v.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3, 2, 3, 4, 3, 4, 5]);
    }

    // Additional: broadcast_arrays broadcasts every input to the common shape,
    // and each view reads (without copying) from its own input.
    #[test]
    fn broadcast_arrays_integration() {
        // Distinct values so that a wrong broadcast stride, or a view reading
        // the wrong input, shows up in the data and not just the shape.
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([4, 1]), vec![0.0, 1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        let arrays = [a, b];
        let views = broadcast_arrays(&arrays).unwrap();
        assert_eq!(views.len(), 2);
        assert_eq!(views[0].shape(), &[4, 3]);
        assert_eq!(views[1].shape(), &[4, 3]);

        // (4, 1) -> (4, 3): each row repeats its single element.
        let left: Vec<f64> = views[0].iter().copied().collect();
        assert_eq!(
            left,
            vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0]
        );

        // (1, 3) -> (4, 3): the single row repeats four times.
        let right: Vec<f64> = views[1].iter().copied().collect();
        assert_eq!(
            right,
            vec![10.0, 20.0, 30.0, 10.0, 20.0, 30.0, 10.0, 20.0, 30.0, 10.0, 20.0, 30.0]
        );

        // Zero-copy: each view starts at its own input's base pointer.
        assert_eq!(views[0].as_ptr(), arrays[0].as_ptr());
        assert_eq!(views[1].as_ptr(), arrays[1].as_ptr());
    }

    // Additional: sliding_window_view is zero-copy and yields every window.
    #[test]
    fn sliding_window_is_zero_copy() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([10]), (0..10_i32).map(f64::from).collect())
            .unwrap();
        let v = sliding_window_view(&a, &[4]).unwrap();
        // 10 - 4 + 1 = 7 windows of length 4.
        assert_eq!(v.shape(), &[7, 4]);
        // Same base pointer: the windows read the original buffer.
        assert_eq!(v.as_ptr(), a.as_ptr());
        // Window k is [k, k + 1, k + 2, k + 3].
        let data: Vec<f64> = v.iter().copied().collect();
        let expected: Vec<f64> = (0..7_i32)
            .flat_map(|start| start..start + 4)
            .map(f64::from)
            .collect();
        assert_eq!(data, expected);
    }

    // Additional: as_strided accepts strides that skip elements, as long as
    // no two logical indices hit the same element.
    #[test]
    fn as_strided_skip_elements() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([10]), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            .unwrap();
        // Every 3rd element: shape (3,), stride (3,). Offsets 0, 3 and 6 are
        // distinct and all inside the 10-element buffer.
        let v = as_strided(&a, &[3], &[3]).unwrap();
        assert_eq!(v.shape(), &[3]);
        let data: Vec<i32> = v.iter().copied().collect();
        assert_eq!(data, vec![0, 3, 6]);
    }
}