use crate::array::owned::Array;
use crate::dimension::Dimension;
use crate::dtype::Element;
use crate::error::FerrayResult;
/// Stages edits to an [`Array`] in a private scratch copy and writes them
/// back to the borrowed target only on an explicit [`commit`](Self::commit).
///
/// The guard holds the target's mutable borrow for its whole lifetime, so the
/// target cannot be observed or modified through other paths while edits are
/// being staged. If the guard is dropped or [`discard`](Self::discard)ed
/// without committing, the scratch copy is simply thrown away.
#[must_use = "dropping a WritebackGuard without calling `commit` throws away every change made to the scratch buffer"]
pub struct WritebackGuard<'a, T: Element + Clone, D: Dimension> {
    /// Array that receives the scratch contents when the guard is committed.
    target: &'a mut Array<T, D>,
    /// Owned working copy handed out by `scratch` / `scratch_mut`.
    scratch: Array<T, D>,
    /// When `true`, `commit` skips the copy back into `target`.
    /// NOTE(review): `new` always initialises this to `false`; no code in this
    /// view ever sets it to `true`.
    fast_path: bool,
}
impl<'a, T, D> WritebackGuard<'a, T, D>
where
    T: Element + Clone,
    D: Dimension,
{
    /// Builds a guard around `target`, seeding the scratch array with the
    /// target's current contents.
    ///
    /// The scratch array is created with `Array::from_vec` from a clone of the
    /// target's dimension and the output of `target.to_vec_flat()`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `Array::from_vec` reports while constructing
    /// the scratch array.
    pub fn new(target: &'a mut Array<T, D>) -> FerrayResult<Self> {
        let shape = target.dim().clone();
        let elements = target.to_vec_flat();
        let scratch = Array::<T, D>::from_vec(shape, elements)?;
        Ok(Self {
            target,
            scratch,
            fast_path: false,
        })
    }

    /// Shared view of the scratch array.
    #[inline]
    pub const fn scratch(&self) -> &Array<T, D> {
        &self.scratch
    }

    /// Mutable view of the scratch array; edits made here reach the target
    /// only through [`commit`](Self::commit).
    #[inline]
    pub fn scratch_mut(&mut self) -> &mut Array<T, D> {
        &mut self.scratch
    }

    /// Consumes the guard and copies the scratch contents into the target.
    ///
    /// Elements are paired up by walking `target.iter_mut()` and
    /// `scratch.iter()` in lockstep, and each scratch element is cloned into
    /// the matching target slot. The copy is skipped when the `fast_path`
    /// flag is set.
    ///
    /// NOTE(review): the pairing stops at the shorter of the two iterators, so
    /// if a caller replaced the scratch array with one of a different length
    /// the write-back would be partial — confirm whether that can happen.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok(())`.
    pub fn commit(self) -> FerrayResult<()> {
        if !self.fast_path {
            self.target
                .iter_mut()
                .zip(self.scratch.iter())
                .for_each(|(slot, staged)| *slot = staged.clone());
        }
        Ok(())
    }

    /// Consumes the guard without writing anything to the target.
    #[inline]
    pub fn discard(self) {
        // Dropping the guard is all it takes: the scratch copy goes away and
        // the target is never written.
        drop(self);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Committing copies every staged element into the target.
    #[test]
    fn commit_writes_scratch_back_to_target() {
        let mut dest = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![0.0; 4]).unwrap();
        let mut wb = WritebackGuard::new(&mut dest).unwrap();
        wb.scratch_mut()
            .iter_mut()
            .enumerate()
            .for_each(|(idx, slot)| *slot = (idx as f64) * 10.0);
        wb.commit().unwrap();
        assert_eq!(dest.as_slice().unwrap(), &[0.0, 10.0, 20.0, 30.0]);
    }

    /// An explicit `discard` throws the staged edits away.
    #[test]
    fn discard_leaves_target_untouched() {
        let mut dest = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let mut wb = WritebackGuard::new(&mut dest).unwrap();
        wb.scratch_mut().iter_mut().for_each(|slot| *slot = -99.0);
        wb.discard();
        assert_eq!(dest.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    /// Letting the guard fall out of scope behaves like `discard`.
    #[test]
    fn drop_without_commit_leaves_target_untouched() {
        let mut dest = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        {
            let mut wb = WritebackGuard::new(&mut dest).unwrap();
            wb.scratch_mut().iter_mut().for_each(|slot| *slot = -99.0);
            // `wb` goes out of scope here without a commit.
        }
        assert_eq!(dest.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    /// Write-back works for a two-dimensional array built with `from_vec`.
    #[test]
    fn commit_works_for_2d_contiguous_target() {
        let mut dest = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let mut wb = WritebackGuard::new(&mut dest).unwrap();
        wb.scratch_mut()
            .iter_mut()
            .enumerate()
            .for_each(|(idx, slot)| *slot = idx as f64);
        wb.commit().unwrap();
        let observed = dest.iter().copied().collect::<Vec<f64>>();
        assert_eq!(observed, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    /// A fresh guard exposes the target's current values through `scratch`.
    #[test]
    fn scratch_starts_with_target_values() {
        let mut dest = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let wb = WritebackGuard::new(&mut dest).unwrap();
        assert_eq!(wb.scratch().as_slice().unwrap(), &[10, 20, 30, 40]);
    }

    /// Write-back keeps the iteration order of an array built with `from_vec_f`.
    #[test]
    fn commit_works_for_fortran_order_target() {
        let mut dest = Array::<f64, Ix2>::from_vec_f(
            Ix2::new([2, 3]),
            vec![10.0, 40.0, 20.0, 50.0, 30.0, 60.0],
        )
        .unwrap();
        let mut wb = WritebackGuard::new(&mut dest).unwrap();

        // The scratch must iterate in the same order as the target does.
        let before = wb.scratch().iter().copied().collect::<Vec<f64>>();
        assert_eq!(before, vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0]);

        wb.scratch_mut().iter_mut().for_each(|slot| *slot = -*slot);
        wb.commit().unwrap();

        let after = dest.iter().copied().collect::<Vec<f64>>();
        assert_eq!(after, vec![-10.0, -20.0, -30.0, -40.0, -50.0, -60.0]);
    }
}