/// A fixed-length `i32` slice view backed by a read-only buffer and a mutable buffer.
///
/// When built with [`CowSlice::new`], reads come from the read-only buffer until the
/// first successful write, which copies it into the mutable buffer. When built with
/// [`CowSlice::new_mut`], reads and writes use the mutable buffer from the start.
#[derive(Debug)]
pub struct CowSlice<'a> {
    /// Read-only source; only consulted while `use_mut` is `false`.
    data: &'a [i32],
    /// Destination buffer; holds the live values once `use_mut` is `true`.
    data_mut: &'a mut [i32],
    /// `true` once reads and writes go through `data_mut`.
    use_mut: bool,
}
impl<'a> CowSlice<'a> {
    /// Creates a view that reads from `data` until the first successful write.
    ///
    /// `data_mut` is the buffer that receives a full copy of `data` on the first
    /// in-bounds [`set`](Self::set); its initial contents are never read.
    ///
    /// # Errors
    ///
    /// Returns [`CowSliceSizeMismatchError`] carrying `(data.len(), data_mut.len())`
    /// when the two buffers have different lengths.
    pub fn new(
        data: &'a [i32],
        data_mut: &'a mut [i32],
    ) -> Result<Self, CowSliceSizeMismatchError> {
        // Equal lengths are required so the later `copy_from_slice` cannot panic.
        if data.len() != data_mut.len() {
            return Err(CowSliceSizeMismatchError(data.len(), data_mut.len()));
        }
        Ok(Self {
            data,
            data_mut,
            use_mut: false,
        })
    }

    /// Creates a view that reads and writes `data_mut` directly, with no
    /// read-only source buffer.
    pub fn new_mut(data_mut: &'a mut [i32]) -> Self {
        Self {
            data: &[],
            data_mut,
            use_mut: true,
        }
    }

    /// Returns the element at `index`, or `None` if `index` is out of bounds.
    pub fn get(&self, index: usize) -> Option<i32> {
        self.active().get(index).copied()
    }

    /// Writes `value` at `index`, returning `None` if `index` is out of bounds.
    ///
    /// On the first in-bounds write to a view created by [`new`](Self::new), the
    /// read-only buffer is copied into the mutable buffer, and all later reads and
    /// writes use the mutable buffer. An out-of-bounds write performs no copy and
    /// leaves the view unchanged.
    pub fn set(&mut self, index: usize, value: i32) -> Option<()> {
        // Reject bad indices before copying so a failed write neither pays for
        // the O(n) copy nor switches the view to the mutable buffer.
        if index >= self.len() {
            return None;
        }
        if !self.use_mut {
            self.data_mut.copy_from_slice(self.data);
            self.use_mut = true;
        }
        *self.data_mut.get_mut(index)? = value;
        Some(())
    }

    /// Returns the number of elements in the view.
    pub fn len(&self) -> usize {
        self.active().len()
    }

    /// Returns `true` if the view has no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the buffer that currently backs reads.
    fn active(&self) -> &[i32] {
        if self.use_mut {
            &*self.data_mut
        } else {
            self.data
        }
    }
}
/// Error returned by [`CowSlice::new`] when the read-only and mutable buffers
/// differ in length.
///
/// The tuple fields are `(data.len(), data_mut.len())`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CowSliceSizeMismatchError(usize, usize);
impl std::fmt::Display for CowSliceSizeMismatchError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"size mismatch for immutable and mutable buffers: data.len() = {}, data_mut.len() = {}",
self.0, self.1
)
}
}
#[cfg(test)]
mod tests {
    use super::{CowSlice, CowSliceSizeMismatchError};

    #[test]
    fn size_mismatch_error() {
        let data_mut = &mut [0, 0];
        let result = CowSlice::new(&[1, 2, 3], data_mut);
        assert!(matches!(result, Err(CowSliceSizeMismatchError(3, 2))))
    }

    #[test]
    fn size_mismatch_error_message() {
        let err = CowSliceSizeMismatchError(3, 2);
        assert_eq!(
            err.to_string(),
            "size mismatch for immutable and mutable buffers: data.len() = 3, data_mut.len() = 2"
        );
    }

    #[test]
    fn copy_on_write() {
        let data = std::array::from_fn::<_, 16, _>(|i| i as i32);
        let mut data_mut = [0i32; 16];
        let mut slice = CowSlice::new(&data, &mut data_mut).unwrap();
        assert!(!slice.use_mut);
        assert_eq!(slice.len(), data.len());
        for (i, &expected) in data.iter().enumerate() {
            assert_eq!(slice.get(i), Some(expected));
        }
        for i in 0..data.len() {
            let value = slice.get(i).unwrap();
            slice.set(i, value * 2).unwrap();
        }
        assert!(slice.use_mut);
        for (i, &original) in data.iter().enumerate() {
            assert_eq!(slice.get(i), Some(original * 2));
        }
        // Writes must land in the mutable buffer; the read-only source is untouched.
        assert!(data_mut.iter().zip(&data).all(|(&m, &d)| m == d * 2));
        assert_eq!(data[5], 5);
    }

    #[test]
    fn out_of_bounds() {
        let data_mut = &mut [1, 2];
        let slice = CowSlice::new_mut(data_mut);
        assert_eq!(slice.len(), 2);
        assert_eq!(slice.get(0), Some(1));
        assert_eq!(slice.get(1), Some(2));
        assert_eq!(slice.get(2), None);
    }

    #[test]
    fn out_of_bounds_set_is_rejected() {
        let data = [1, 2, 3];
        let mut data_mut = [0; 3];
        let mut slice = CowSlice::new(&data, &mut data_mut).unwrap();
        assert_eq!(slice.set(3, 42), None);
        // A rejected write must not change any observable value.
        assert_eq!(slice.get(0), Some(1));
        assert_eq!(slice.get(1), Some(2));
        assert_eq!(slice.get(2), Some(3));
        assert_eq!(slice.len(), 3);
    }

    #[test]
    fn new_mut_writes_in_place() {
        let mut buf = [1, 2];
        let mut slice = CowSlice::new_mut(&mut buf);
        assert_eq!(slice.set(1, 7), Some(()));
        assert_eq!(slice.set(2, 7), None);
        assert_eq!(slice.get(1), Some(7));
        assert_eq!(buf, [1, 7]);
    }
}