Skip to main content

ferray_core/
writeback.rs

1//! WRITEBACKIFCOPY-style writeback guard for non-contiguous out destinations
2//! (#370).
3//!
4//! NumPy uses `NPY_ARRAY_WRITEBACKIFCOPY` to allow ufuncs to operate on a
5//! contiguous scratch when the user-supplied `out=` is not contiguous,
6//! then write the scratch back to the original on success. This module
7//! provides the same pattern as a Rust scope guard:
8//!
9//! ```ignore
10//! use ferray_core::writeback::WritebackGuard;
11//!
12//! let mut target: Array<f64, Ix2> = ...;  // possibly non-contiguous
13//! let mut guard = WritebackGuard::new(&mut target)?;
14//! run_kernel(guard.scratch_mut());
15//! guard.commit()?;  // writes scratch back into target; without commit
16//!                    // the scratch is discarded and target is untouched.
17//! ```
18//!
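//! The scratch is always a fresh, contiguous deep copy of the target, even
//! when the target is already contiguous. Construction therefore costs one
//! copy, and `commit` costs one element-wise writeback. A zero-copy path
//! that lets the kernel write straight into a contiguous target is not
//! implemented yet.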
22
23use crate::array::owned::Array;
24use crate::dimension::Dimension;
25use crate::dtype::Element;
26use crate::error::FerrayResult;
27
28/// RAII writeback guard for safe out-parameter mutation.
29///
30/// See the [module-level documentation](self) for the usage pattern.
31/// The scratch is dropped without writing back unless [`Self::commit`]
32/// is called — this matches numpy's WRITEBACKIFCOPY semantics where a
33/// kernel that panics or returns an error leaves the original
34/// untouched.
35pub struct WritebackGuard<'a, T: Element + Clone, D: Dimension> {
36    target: &'a mut Array<T, D>,
37    scratch: Array<T, D>,
38    /// True when the source was already contiguous and the scratch is a
39    /// borrowed view into the target. In that case `commit` is a no-op
40    /// (the kernel wrote directly into the target).
41    fast_path: bool,
42}
43
44impl<'a, T: Element + Clone, D: Dimension> WritebackGuard<'a, T, D> {
45    /// Create a new writeback guard targeting `target`.
46    ///
47    /// If `target` is already contiguous, the scratch is a clone of
48    /// the target (ferray does not yet support borrowing the target's
49    /// buffer directly while keeping the lifetime invariant — every
50    /// scratch is an owned Array). In either case the user works on
51    /// the scratch and calls [`Self::commit`] to publish.
52    ///
53    /// # Errors
54    /// Returns an error only if the underlying contiguous allocation
55    /// fails for the target's shape.
56    pub fn new(target: &'a mut Array<T, D>) -> FerrayResult<Self> {
57        // Always materialize a contiguous scratch with a deep copy of
58        // the current target contents — this lets the kernel observe
59        // the existing values when it does an in-place add/sub/etc.
60        let scratch = Array::<T, D>::from_vec(target.dim().clone(), target.to_vec_flat())?;
61        Ok(Self {
62            target,
63            scratch,
64            fast_path: false,
65        })
66    }
67
68    /// Mutable access to the contiguous scratch buffer. The kernel
69    /// writes here; nothing is observable in the target until
70    /// [`Self::commit`] is called.
71    #[inline]
72    pub fn scratch_mut(&mut self) -> &mut Array<T, D> {
73        &mut self.scratch
74    }
75
76    /// Read access to the contiguous scratch (for kernels that need to
77    /// read the input alongside the output write).
78    #[inline]
79    pub const fn scratch(&self) -> &Array<T, D> {
80        &self.scratch
81    }
82
83    /// Publish the scratch contents back into the target. Element-by-
84    /// element copy in logical (row-major) order, which works for any
85    /// target layout (C-contiguous, F-contiguous, strided slice).
86    ///
87    /// Consuming `self` so the guard can't be used twice.
88    ///
89    /// # Errors
90    /// Returns an error if the scratch and target sizes have somehow
91    /// diverged (should be unreachable because the guard owns both).
92    pub fn commit(self) -> FerrayResult<()> {
93        if self.fast_path {
94            return Ok(());
95        }
96        // Iterate in logical order on both sides — the target may have
97        // arbitrary strides while the scratch is contiguous.
98        for (dst, src) in self.target.iter_mut().zip(self.scratch.iter()) {
99            *dst = src.clone();
100        }
101        Ok(())
102    }
103
104    /// Discard the scratch without writing back. Equivalent to letting
105    /// the guard go out of scope without calling [`Self::commit`], but
106    /// makes the intent explicit at the call site.
107    #[inline]
108    pub fn discard(self) {
109        // Drop runs implicitly; nothing to do.
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Build a 1-D `f64` array whose length is taken from `values`.
    fn vector(values: Vec<f64>) -> Array<f64, Ix1> {
        let len = values.len();
        Array::<f64, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    #[test]
    fn commit_writes_scratch_back_to_target() {
        let mut target = vector(vec![0.0; 4]);
        let mut guard = WritebackGuard::new(&mut target).unwrap();
        let mut next = 0.0;
        for slot in guard.scratch_mut().iter_mut() {
            *slot = next;
            next += 10.0;
        }
        guard.commit().unwrap();
        assert_eq!(target.as_slice().unwrap(), &[0.0, 10.0, 20.0, 30.0]);
    }

    #[test]
    fn discard_leaves_target_untouched() {
        let mut target = vector(vec![1.0, 2.0, 3.0]);
        let mut guard = WritebackGuard::new(&mut target).unwrap();
        guard.scratch_mut().iter_mut().for_each(|slot| *slot = -99.0);
        guard.discard();
        assert_eq!(target.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn drop_without_commit_leaves_target_untouched() {
        // Like `discard`, but the guard simply falls out of scope. This
        // mirrors NumPy's WRITEBACKIFCOPY behaviour, where a kernel that
        // bails out (e.g. by panicking) leaves the original alone.
        let mut target = vector(vec![1.0, 2.0, 3.0]);
        {
            let mut guard = WritebackGuard::new(&mut target).unwrap();
            guard.scratch_mut().iter_mut().for_each(|slot| *slot = -99.0);
            // Neither commit nor discard: the guard is dropped right here.
        }
        assert_eq!(target.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn commit_works_for_2d_contiguous_target() {
        let mut target = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let mut guard = WritebackGuard::new(&mut target).unwrap();
        for (slot, value) in guard.scratch_mut().iter_mut().zip(0u32..) {
            *slot = f64::from(value);
        }
        guard.commit().unwrap();
        let published: Vec<f64> = target.iter().copied().collect();
        assert_eq!(published, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn scratch_starts_with_target_values() {
        // The scratch is seeded with the target's current contents, so an
        // in-place compose such as `out += rhs` sees the existing data.
        let mut target = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let guard = WritebackGuard::new(&mut target).unwrap();
        let seen = guard.scratch().as_slice().unwrap();
        assert_eq!(seen, &[10, 20, 30, 40]);
    }

    #[test]
    fn commit_works_for_fortran_order_target() {
        // `from_vec_f` stores the data column-major, so the target is not
        // C-contiguous. Logical (row-major) iteration still yields elements
        // in the right order, so the element-wise writeback handles both
        // layouts.
        //
        // Logical 2x3 view:       Column-major storage:
        //   [[10, 20, 30],          [10, 40, 20, 50, 30, 60]
        //    [40, 50, 60]]
        let storage = vec![10.0, 40.0, 20.0, 50.0, 30.0, 60.0];
        let mut target = Array::<f64, Ix2>::from_vec_f(Ix2::new([2, 3]), storage).unwrap();

        let mut guard = WritebackGuard::new(&mut target).unwrap();
        // The scratch must expose the logical order, not the storage order.
        let before: Vec<f64> = guard.scratch().iter().copied().collect();
        assert_eq!(before, vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0]);

        guard.scratch_mut().iter_mut().for_each(|slot| *slot = -*slot);
        guard.commit().unwrap();

        // Reading the target back in logical order shows every element negated.
        let after: Vec<f64> = target.iter().copied().collect();
        assert_eq!(after, vec![-10.0, -20.0, -30.0, -40.0, -50.0, -60.0]);
    }
}