use crate::dimension::Dimension;
use crate::dtype::Element;
use crate::layout::MemoryLayout;
use super::ArrayFlags;
use super::owned::Array;
/// A mutable, non-owning view into the elements of an array.
///
/// Pairs an `ndarray::ArrayViewMut` with the crate's own dimension type `D`
/// describing the same view.
pub struct ArrayViewMut<'a, T, D>
where
    T: Element,
    D: Dimension,
{
    /// The underlying `ndarray` mutable view that owns the borrow.
    pub(crate) inner: ndarray::ArrayViewMut<'a, T, D::NdarrayDim>,
    /// The view's dimension in the crate's `Dimension` representation.
    pub(crate) dim: D,
}
impl<'a, T: Element, D: Dimension> ArrayViewMut<'a, T, D> {
    /// Wraps a raw `ndarray` mutable view, deriving the crate-level dimension
    /// from the view's `raw_dim()`.
    pub(crate) fn from_ndarray(inner: ndarray::ArrayViewMut<'a, T, D::NdarrayDim>) -> Self {
        let dim = D::from_ndarray_dim(&inner.raw_dim());
        Self { inner, dim }
    }

    /// Returns the length of each axis.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        self.inner.shape()
    }

    /// Returns the number of axes, as reported by the cached dimension.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.dim.ndim()
    }

    /// Returns the total number of elements in the view.
    #[inline]
    pub fn size(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` if the view contains no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns the per-axis strides, measured in elements (not bytes).
    #[inline]
    pub fn strides(&self) -> &[isize] {
        self.inner.strides()
    }

    /// Returns a raw const pointer to the view's first element.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.inner.as_ptr()
    }

    /// Returns a raw mutable pointer to the view's first element.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.inner.as_mut_ptr()
    }

    /// Returns the elements as a slice, or `None` if `ndarray` does not
    /// consider the view contiguous in standard order.
    pub fn as_slice(&self) -> Option<&[T]> {
        self.inner.as_slice()
    }

    /// Returns the elements as a mutable slice, or `None` if `ndarray` does
    /// not consider the view contiguous in standard order.
    pub fn as_slice_mut(&mut self) -> Option<&mut [T]> {
        self.inner.as_slice_mut()
    }

    /// Classifies the memory layout of the viewed elements.
    ///
    /// Views that `ndarray` reports as standard (row-major) layout short-circuit
    /// to [`MemoryLayout::C`]. Everything else is classified by
    /// `crate::layout::detect_layout` from the shape and strides.
    pub fn layout(&self) -> MemoryLayout {
        if self.inner.is_standard_layout() {
            // NOTE(review): this fast path always reports C, even for views
            // (e.g. 1-D) that could also count as column-major contiguous;
            // confirm that matches what `detect_layout` would return.
            return MemoryLayout::C;
        }
        // Borrow the stride slice directly rather than copying it into a
        // temporary Vec only to borrow it again.
        crate::layout::detect_layout(self.dim.as_slice(), self.inner.strides())
    }

    /// Returns the cached crate-level dimension of this view.
    #[inline]
    pub fn dim(&self) -> &D {
        &self.dim
    }

    /// Returns NumPy-style flags for this view.
    ///
    /// Contiguity is derived from [`Self::layout`]. A mutable view never owns
    /// its data (`owndata == false`) and is always writeable.
    pub fn flags(&self) -> ArrayFlags {
        let layout = self.layout();
        ArrayFlags {
            c_contiguous: layout.is_c_contiguous(),
            f_contiguous: layout.is_f_contiguous(),
            owndata: false,
            writeable: true,
        }
    }
}
impl<T, D> Array<T, D>
where
    T: Element,
    D: Dimension,
{
    /// Mutably borrows this array and returns an [`ArrayViewMut`] over all of
    /// its elements.
    ///
    /// The view does not own the data; writes made through it are visible in
    /// this array once the view is dropped.
    pub fn view_mut(&mut self) -> ArrayViewMut<'_, T, D> {
        let borrowed = self.inner.view_mut();
        ArrayViewMut::from_ndarray(borrowed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::Ix1;

    /// Builds the 1-D fixture `[1.0, 2.0, 3.0]` shared by these tests.
    fn sample() -> Array<f64, Ix1> {
        Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0])
            .expect("shape [3] matches 3 elements")
    }

    #[test]
    fn view_mut_from_owned() {
        let mut arr = sample();
        let v = arr.view_mut();
        assert_eq!(v.shape(), &[3]);
        let flags = v.flags();
        assert!(flags.writeable);
        assert!(!flags.owndata);
    }

    #[test]
    fn view_mut_metadata() {
        let mut arr = sample();
        let v = arr.view_mut();
        assert_eq!(v.ndim(), 1);
        assert_eq!(v.size(), 3);
        assert!(!v.is_empty());
        assert_eq!(v.strides(), &[1]);
        assert!(v.flags().c_contiguous);
    }

    #[test]
    fn view_mut_modify() {
        let mut arr = sample();
        {
            let mut v = arr.view_mut();
            // A fresh 1-D view is contiguous, so the slice must exist. Failing
            // here is clearer than silently skipping the write and tripping the
            // value assertion below.
            let s = v
                .as_slice_mut()
                .expect("contiguous 1-D view exposes a mutable slice");
            s[0] = 99.0;
        }
        assert_eq!(arr.as_slice().unwrap()[0], 99.0);
    }
}