Skip to main content

ferray_core/array/
cow.rs

1// ferray-core: CowArray<'a, T, D> — owned-or-borrowed (REQ-3)
2
3use crate::dimension::Dimension;
4use crate::dtype::Element;
5use crate::layout::MemoryLayout;
6
7use super::ArrayFlags;
8use super::owned::Array;
9use super::view::ArrayView;
10
11/// A copy-on-write array that is either a borrowed view or an owned array.
12///
13/// This is useful when a function might or might not need to allocate:
14/// it can accept borrowed data and only clone if mutation is required.
15pub enum CowArray<'a, T: Element, D: Dimension> {
16    /// Borrowed — refers to data owned by another array.
17    Borrowed(ArrayView<'a, T, D>),
18    /// Owned — has its own data buffer.
19    Owned(Array<T, D>),
20}
21
22impl<'a, T: Element, D: Dimension> CowArray<'a, T, D> {
23    /// Create a `CowArray` from a borrowed view.
24    pub fn from_view(view: ArrayView<'a, T, D>) -> Self {
25        Self::Borrowed(view)
26    }
27
28    /// Create a `CowArray` from an owned array.
29    pub fn from_owned(arr: Array<T, D>) -> Self {
30        Self::Owned(arr)
31    }
32
33    /// Shape as a slice.
34    #[inline]
35    pub fn shape(&self) -> &[usize] {
36        match self {
37            Self::Borrowed(v) => v.shape(),
38            Self::Owned(a) => a.shape(),
39        }
40    }
41
42    /// Number of dimensions.
43    #[inline]
44    pub fn ndim(&self) -> usize {
45        match self {
46            Self::Borrowed(v) => v.ndim(),
47            Self::Owned(a) => a.ndim(),
48        }
49    }
50
51    /// Total number of elements.
52    #[inline]
53    pub fn size(&self) -> usize {
54        match self {
55            Self::Borrowed(v) => v.size(),
56            Self::Owned(a) => a.size(),
57        }
58    }
59
60    /// Whether the array has zero elements.
61    #[inline]
62    pub fn is_empty(&self) -> bool {
63        self.size() == 0
64    }
65
66    /// Memory layout.
67    pub fn layout(&self) -> MemoryLayout {
68        match self {
69            Self::Borrowed(v) => v.layout(),
70            Self::Owned(a) => a.layout(),
71        }
72    }
73
74    /// Whether this is a borrowed (view) variant.
75    pub fn is_borrowed(&self) -> bool {
76        matches!(self, Self::Borrowed(_))
77    }
78
79    /// Whether this is an owned variant.
80    pub fn is_owned(&self) -> bool {
81        matches!(self, Self::Owned(_))
82    }
83
84    /// Convert to an owned array, cloning if currently borrowed.
85    pub fn into_owned(self) -> Array<T, D> {
86        match self {
87            Self::Borrowed(v) => v.to_owned(),
88            Self::Owned(a) => a,
89        }
90    }
91
92    /// Ensure this is the owned variant, cloning if necessary,
93    /// and return a mutable reference to the owned array.
94    pub fn to_mut(&mut self) -> &mut Array<T, D> {
95        if let Self::Borrowed(v) = self {
96            *self = Self::Owned(v.to_owned());
97        }
98        match self {
99            Self::Owned(a) => a,
100            Self::Borrowed(_) => unreachable!(),
101        }
102    }
103
104    /// Get a read-only view of the data.
105    ///
106    /// If this is a borrowed variant, returns a view with the same lifetime
107    /// as `&self`. If owned, returns a view borrowing from `self`.
108    pub fn view(&self) -> ArrayView<'_, T, D> {
109        match self {
110            Self::Borrowed(v) => {
111                // Reborrow the inner ndarray view with &self lifetime
112                ArrayView::from_ndarray(v.inner.view())
113            }
114            Self::Owned(a) => a.view(),
115        }
116    }
117
118    /// Raw pointer to the first element.
119    #[inline]
120    pub fn as_ptr(&self) -> *const T {
121        match self {
122            Self::Borrowed(v) => v.as_ptr(),
123            Self::Owned(a) => a.as_ptr(),
124        }
125    }
126
127    /// Array flags.
128    pub fn flags(&self) -> ArrayFlags {
129        match self {
130            Self::Borrowed(v) => v.flags(),
131            Self::Owned(a) => {
132                let layout = a.layout();
133                ArrayFlags {
134                    c_contiguous: layout.is_c_contiguous(),
135                    f_contiguous: layout.is_f_contiguous(),
136                    owndata: true,
137                    writeable: true,
138                }
139            }
140        }
141    }
142}
143
144impl<T: Element, D: Dimension> Clone for CowArray<'_, T, D> {
145    fn clone(&self) -> Self {
146        match self {
147            Self::Borrowed(v) => Self::Borrowed(v.clone()),
148            Self::Owned(a) => Self::Owned(a.clone()),
149        }
150    }
151}
152
153impl<T: Element, D: Dimension> From<Array<T, D>> for CowArray<'_, T, D> {
154    fn from(arr: Array<T, D>) -> Self {
155        Self::Owned(arr)
156    }
157}
158
159impl<'a, T: Element, D: Dimension> From<ArrayView<'a, T, D>> for CowArray<'a, T, D> {
160    fn from(view: ArrayView<'a, T, D>) -> Self {
161        Self::Borrowed(view)
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::Ix1;

    /// Builds the 1-D array `[1.0, 2.0, 3.0]` used by most tests.
    fn sample() -> Array<f64, Ix1> {
        Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap()
    }

    #[test]
    fn cow_from_view() {
        let arr = sample();
        let cow = CowArray::from_view(arr.view());
        assert!(cow.is_borrowed());
        assert!(!cow.is_owned());
        assert_eq!(cow.shape(), &[3]);
    }

    #[test]
    fn cow_from_owned() {
        let cow = CowArray::from_owned(sample());
        assert!(cow.is_owned());
        assert!(!cow.is_borrowed());
    }

    #[test]
    fn cow_to_mut_clones_when_borrowed() {
        let arr = sample();
        let mut cow = CowArray::from_view(arr.view());

        assert!(cow.is_borrowed());
        let owned = cow.to_mut();
        assert_eq!(owned.shape(), &[3]);
        // The copy must live in a fresh buffer rather than alias `arr`.
        assert_ne!(owned.as_ptr(), arr.as_ptr());
        assert_eq!(owned.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
        assert!(cow.is_owned());
    }

    #[test]
    fn cow_to_mut_reuses_owned_buffer() {
        let arr = sample();
        let ptr = arr.as_ptr();
        let mut cow = CowArray::from_owned(arr);
        // Already owned: `to_mut` must not copy.
        assert_eq!(cow.to_mut().as_ptr(), ptr);
        assert!(cow.is_owned());
    }

    #[test]
    fn cow_into_owned() {
        let arr = sample();
        let owned = CowArray::from_view(arr.view()).into_owned();
        assert_eq!(owned.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn cow_into_owned_keeps_owned_buffer() {
        let arr = sample();
        let ptr = arr.as_ptr();
        let owned = CowArray::from_owned(arr).into_owned();
        assert_eq!(owned.as_ptr(), ptr);
    }

    #[test]
    fn cow_view_aliases_data() {
        let arr = sample();
        let borrowed = CowArray::from_view(arr.view());
        let view = borrowed.view();
        assert_eq!(view.shape(), &[3]);
        assert_eq!(view.as_ptr(), arr.as_ptr());

        let owned = CowArray::from_owned(sample());
        assert_eq!(owned.view().as_ptr(), owned.as_ptr());
    }

    #[test]
    fn cow_size_ndim_and_empty() {
        let arr = sample();
        let cow = CowArray::from_view(arr.view());
        assert_eq!(cow.size(), 3);
        assert_eq!(cow.ndim(), 1);
        assert!(!cow.is_empty());

        let empty = Array::<f64, Ix1>::from_vec(Ix1::new([0]), Vec::new()).unwrap();
        let cow = CowArray::from_owned(empty);
        assert_eq!(cow.size(), 0);
        assert!(cow.is_empty());
    }

    #[test]
    fn cow_owned_flags() {
        let flags = CowArray::from_owned(sample()).flags();
        assert!(flags.owndata);
        assert!(flags.writeable);
        assert!(flags.c_contiguous);
    }

    #[test]
    fn cow_clone_preserves_variant() {
        let arr = sample();
        let borrowed = CowArray::from_view(arr.view());
        let borrowed_clone = borrowed.clone();
        assert!(borrowed_clone.is_borrowed());
        assert_eq!(borrowed_clone.as_ptr(), borrowed.as_ptr());

        let owned = CowArray::from_owned(sample());
        let owned_clone = owned.clone();
        assert!(owned_clone.is_owned());
        // Cloning an owned array performs a deep copy into a new buffer.
        assert_ne!(owned_clone.as_ptr(), owned.as_ptr());
    }

    #[test]
    fn cow_from_conversions() {
        let arr = sample();
        let borrowed: CowArray<'_, f64, Ix1> = arr.view().into();
        assert!(borrowed.is_borrowed());

        let owned: CowArray<'_, f64, Ix1> = sample().into();
        assert!(owned.is_owned());
    }
}