// ferray_core/writeback.rs
//! WRITEBACKIFCOPY-style writeback guard for non-contiguous out destinations
//! (#370).
//!
//! NumPy uses `NPY_ARRAY_WRITEBACKIFCOPY` to allow ufuncs to operate on a
//! contiguous scratch when the user-supplied `out=` is not contiguous,
//! then write the scratch back to the original on success. This module
//! provides the same pattern as a Rust scope guard:
//!
//! ```ignore
//! use ferray_core::writeback::WritebackGuard;
//!
//! let mut target: Array<f64, Ix2> = ...; // possibly non-contiguous
//! let mut guard = WritebackGuard::new(&mut target)?;
//! run_kernel(guard.scratch_mut());
//! guard.commit()?; // writes scratch back into target; without commit
//!                  // the scratch is discarded and target is untouched.
//! ```
//!
//! The scratch is always a fresh contiguous deep copy of the target's
//! current contents, and `commit` always performs an element-wise
//! writeback — a zero-copy fast path for already-contiguous targets is
//! planned but not yet implemented.
22
23use crate::array::owned::Array;
24use crate::dimension::Dimension;
25use crate::dtype::Element;
26use crate::error::FerrayResult;
27
28/// RAII writeback guard for safe out-parameter mutation.
29///
30/// See the [module-level documentation](self) for the usage pattern.
31/// The scratch is dropped without writing back unless [`Self::commit`]
32/// is called — this matches numpy's WRITEBACKIFCOPY semantics where a
33/// kernel that panics or returns an error leaves the original
34/// untouched.
35pub struct WritebackGuard<'a, T: Element + Clone, D: Dimension> {
36 target: &'a mut Array<T, D>,
37 scratch: Array<T, D>,
38 /// True when the source was already contiguous and the scratch is a
39 /// borrowed view into the target. In that case `commit` is a no-op
40 /// (the kernel wrote directly into the target).
41 fast_path: bool,
42}
43
44impl<'a, T: Element + Clone, D: Dimension> WritebackGuard<'a, T, D> {
45 /// Create a new writeback guard targeting `target`.
46 ///
47 /// If `target` is already contiguous, the scratch is a clone of
48 /// the target (ferray does not yet support borrowing the target's
49 /// buffer directly while keeping the lifetime invariant — every
50 /// scratch is an owned Array). In either case the user works on
51 /// the scratch and calls [`Self::commit`] to publish.
52 ///
53 /// # Errors
54 /// Returns an error only if the underlying contiguous allocation
55 /// fails for the target's shape.
56 pub fn new(target: &'a mut Array<T, D>) -> FerrayResult<Self> {
57 // Always materialize a contiguous scratch with a deep copy of
58 // the current target contents — this lets the kernel observe
59 // the existing values when it does an in-place add/sub/etc.
60 let scratch = Array::<T, D>::from_vec(target.dim().clone(), target.to_vec_flat())?;
61 Ok(Self {
62 target,
63 scratch,
64 fast_path: false,
65 })
66 }
67
68 /// Mutable access to the contiguous scratch buffer. The kernel
69 /// writes here; nothing is observable in the target until
70 /// [`Self::commit`] is called.
71 #[inline]
72 pub fn scratch_mut(&mut self) -> &mut Array<T, D> {
73 &mut self.scratch
74 }
75
76 /// Read access to the contiguous scratch (for kernels that need to
77 /// read the input alongside the output write).
78 #[inline]
79 pub const fn scratch(&self) -> &Array<T, D> {
80 &self.scratch
81 }
82
83 /// Publish the scratch contents back into the target. Element-by-
84 /// element copy in logical (row-major) order, which works for any
85 /// target layout (C-contiguous, F-contiguous, strided slice).
86 ///
87 /// Consuming `self` so the guard can't be used twice.
88 ///
89 /// # Errors
90 /// Returns an error if the scratch and target sizes have somehow
91 /// diverged (should be unreachable because the guard owns both).
92 pub fn commit(self) -> FerrayResult<()> {
93 if self.fast_path {
94 return Ok(());
95 }
96 // Iterate in logical order on both sides — the target may have
97 // arbitrary strides while the scratch is contiguous.
98 for (dst, src) in self.target.iter_mut().zip(self.scratch.iter()) {
99 *dst = src.clone();
100 }
101 Ok(())
102 }
103
104 /// Discard the scratch without writing back. Equivalent to letting
105 /// the guard go out of scope without calling [`Self::commit`], but
106 /// makes the intent explicit at the call site.
107 #[inline]
108 pub fn discard(self) {
109 // Drop runs implicitly; nothing to do.
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    #[test]
    fn commit_writes_scratch_back_to_target() {
        // Kernel output staged in the scratch becomes visible in the
        // target only after commit.
        let mut dst = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![0.0; 4]).unwrap();
        let mut guard = WritebackGuard::new(&mut dst).unwrap();
        guard
            .scratch_mut()
            .iter_mut()
            .enumerate()
            .for_each(|(idx, slot)| *slot = (idx as f64) * 10.0);
        guard.commit().unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[0.0, 10.0, 20.0, 30.0]);
    }

    #[test]
    fn discard_leaves_target_untouched() {
        let mut dst = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let mut guard = WritebackGuard::new(&mut dst).unwrap();
        guard.scratch_mut().iter_mut().for_each(|slot| *slot = -99.0);
        guard.discard();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn drop_without_commit_leaves_target_untouched() {
        // Same as discard() but the guard simply falls out of scope —
        // matches the numpy WRITEBACKIFCOPY semantics where a
        // panicking kernel leaves the original untouched.
        let mut dst = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        {
            let mut guard = WritebackGuard::new(&mut dst).unwrap();
            guard.scratch_mut().iter_mut().for_each(|slot| *slot = -99.0);
            // Neither commit nor discard: the guard drops here.
        }
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn commit_works_for_2d_contiguous_target() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let mut guard = WritebackGuard::new(&mut dst).unwrap();
        for (idx, slot) in guard.scratch_mut().iter_mut().enumerate() {
            *slot = idx as f64;
        }
        guard.commit().unwrap();
        let observed: Vec<f64> = dst.iter().copied().collect();
        assert_eq!(observed, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn scratch_starts_with_target_values() {
        // Kernels that compose in place (e.g. `out += rhs`) must see
        // the target's existing data, so the scratch is seeded with it.
        let mut dst = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let guard = WritebackGuard::new(&mut dst).unwrap();
        assert_eq!(guard.scratch().as_slice().unwrap(), &[10, 20, 30, 40]);
    }

    #[test]
    fn commit_works_for_fortran_order_target() {
        // Non-C-contiguous target: from_vec_f produces an F-major
        // layout, yet logical row-major iteration via iter / iter_mut
        // still yields the right ordering, so the element-by-element
        // writeback works for either layout.
        //
        // 2x3 logical:
        //   [[10, 20, 30],
        //    [40, 50, 60]]
        // F-order storage: [10, 40, 20, 50, 30, 60]
        let mut dst = Array::<f64, Ix2>::from_vec_f(
            Ix2::new([2, 3]),
            vec![10.0, 40.0, 20.0, 50.0, 30.0, 60.0],
        )
        .unwrap();
        let mut guard = WritebackGuard::new(&mut dst).unwrap();
        // Sanity-check: the scratch exposes logical order.
        assert_eq!(
            guard.scratch().iter().copied().collect::<Vec<f64>>(),
            vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
        );
        // Negate every element through the scratch, then publish.
        guard.scratch_mut().iter_mut().for_each(|slot| *slot = -*slot);
        guard.commit().unwrap();
        // The target reads back in logical order with negation applied.
        assert_eq!(
            dst.iter().copied().collect::<Vec<f64>>(),
            vec![-10.0, -20.0, -30.0, -40.0, -50.0, -60.0]
        );
    }
}