Skip to main content

ferray_core/array/
cow.rs

1// ferray-core: CowArray<'a, T, D> — owned-or-borrowed (REQ-3)
2
3use crate::dimension::Dimension;
4use crate::dtype::Element;
5use crate::layout::MemoryLayout;
6
7use super::ArrayFlags;
8use super::owned::Array;
9use super::view::ArrayView;
10
11/// A copy-on-write array that is either a borrowed view or an owned array.
12///
13/// This is useful when a function might or might not need to allocate:
14/// it can accept borrowed data and only clone if mutation is required.
15pub enum CowArray<'a, T: Element, D: Dimension> {
16    /// Borrowed — refers to data owned by another array.
17    Borrowed(ArrayView<'a, T, D>),
18    /// Owned — has its own data buffer.
19    Owned(Array<T, D>),
20}
21
22impl<'a, T: Element, D: Dimension> CowArray<'a, T, D> {
23    /// Create a `CowArray` from a borrowed view.
24    pub const fn from_view(view: ArrayView<'a, T, D>) -> Self {
25        Self::Borrowed(view)
26    }
27
28    /// Create a `CowArray` from an owned array.
29    pub const fn from_owned(arr: Array<T, D>) -> Self {
30        Self::Owned(arr)
31    }
32
33    /// Shape as a slice.
34    #[inline]
35    pub fn shape(&self) -> &[usize] {
36        match self {
37            Self::Borrowed(v) => v.shape(),
38            Self::Owned(a) => a.shape(),
39        }
40    }
41
42    /// Number of dimensions.
43    #[inline]
44    pub fn ndim(&self) -> usize {
45        match self {
46            Self::Borrowed(v) => v.ndim(),
47            Self::Owned(a) => a.ndim(),
48        }
49    }
50
51    /// Total number of elements.
52    #[inline]
53    pub fn size(&self) -> usize {
54        match self {
55            Self::Borrowed(v) => v.size(),
56            Self::Owned(a) => a.size(),
57        }
58    }
59
60    /// Whether the array has zero elements.
61    #[inline]
62    pub fn is_empty(&self) -> bool {
63        self.size() == 0
64    }
65
66    /// Memory layout.
67    pub fn layout(&self) -> MemoryLayout {
68        match self {
69            Self::Borrowed(v) => v.layout(),
70            Self::Owned(a) => a.layout(),
71        }
72    }
73
74    /// Whether this is a borrowed (view) variant.
75    pub const fn is_borrowed(&self) -> bool {
76        matches!(self, Self::Borrowed(_))
77    }
78
79    /// Whether this is an owned variant.
80    pub const fn is_owned(&self) -> bool {
81        matches!(self, Self::Owned(_))
82    }
83
84    /// Convert to an owned array, cloning if currently borrowed.
85    pub fn into_owned(self) -> Array<T, D> {
86        match self {
87            Self::Borrowed(v) => v.to_owned(),
88            Self::Owned(a) => a,
89        }
90    }
91
92    /// Ensure this is the owned variant, cloning if necessary,
93    /// and return a mutable reference to the owned array.
94    pub fn to_mut(&mut self) -> &mut Array<T, D> {
95        if let Self::Borrowed(v) = self {
96            *self = Self::Owned(v.to_owned());
97        }
98        match self {
99            Self::Owned(a) => a,
100            Self::Borrowed(_) => unreachable!(),
101        }
102    }
103
104    /// Get a read-only view of the data.
105    ///
106    /// If this is a borrowed variant, returns a view with the same lifetime
107    /// as `&self`. If owned, returns a view borrowing from `self`.
108    pub fn view(&self) -> ArrayView<'_, T, D> {
109        match self {
110            Self::Borrowed(v) => {
111                // Reborrow the inner ndarray view with &self lifetime
112                ArrayView::from_ndarray(v.inner.view())
113            }
114            Self::Owned(a) => a.view(),
115        }
116    }
117
118    /// Raw pointer to the first element.
119    #[inline]
120    pub fn as_ptr(&self) -> *const T {
121        match self {
122            Self::Borrowed(v) => v.as_ptr(),
123            Self::Owned(a) => a.as_ptr(),
124        }
125    }
126
127    /// Array flags.
128    pub fn flags(&self) -> ArrayFlags {
129        match self {
130            Self::Borrowed(v) => v.flags(),
131            Self::Owned(a) => {
132                let layout = a.layout();
133                ArrayFlags {
134                    c_contiguous: layout.is_c_contiguous(),
135                    f_contiguous: layout.is_f_contiguous(),
136                    owndata: true,
137                    writeable: true,
138                    // CowArray::Owned wraps an Array (Vec<T> path) — aligned (#345).
139                    aligned: true,
140                }
141            }
142        }
143    }
144}
145
146impl<T: Element, D: Dimension> Clone for CowArray<'_, T, D> {
147    fn clone(&self) -> Self {
148        match self {
149            Self::Borrowed(v) => Self::Borrowed(v.clone()),
150            Self::Owned(a) => Self::Owned(a.clone()),
151        }
152    }
153}
154
155impl<T: Element, D: Dimension> From<Array<T, D>> for CowArray<'_, T, D> {
156    fn from(arr: Array<T, D>) -> Self {
157        Self::Owned(arr)
158    }
159}
160
161impl<'a, T: Element, D: Dimension> From<ArrayView<'a, T, D>> for CowArray<'a, T, D> {
162    fn from(view: ArrayView<'a, T, D>) -> Self {
163        Self::Borrowed(view)
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::Ix1;

    /// Build the 3-element array `[1.0, 2.0, 3.0]` shared by the tests below.
    fn sample() -> Array<f64, Ix1> {
        Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap()
    }

    /// Assert the metadata accessors report the shape of `sample()`.
    fn assert_sample_metadata(cow: &CowArray<'_, f64, Ix1>) {
        assert_eq!(cow.shape(), &[3]);
        assert_eq!(cow.ndim(), 1);
        assert_eq!(cow.size(), 3);
        assert!(!cow.is_empty());
    }

    #[test]
    fn cow_from_view() {
        let arr = sample();
        let view = arr.view();
        let cow = CowArray::from_view(view);
        assert!(cow.is_borrowed());
        assert!(!cow.is_owned());
        assert_eq!(cow.shape(), &[3]);
    }

    #[test]
    fn cow_from_owned() {
        let arr = sample();
        let cow = CowArray::from_owned(arr);
        assert!(cow.is_owned());
        assert!(!cow.is_borrowed());
    }

    #[test]
    fn cow_metadata_matches_source_for_both_variants() {
        let arr = sample();
        let borrowed = CowArray::from_view(arr.view());
        assert_sample_metadata(&borrowed);
        let owned = CowArray::from_owned(arr.clone());
        assert_sample_metadata(&owned);
    }

    #[test]
    fn cow_borrowed_shares_data() {
        let arr = sample();
        let cow = CowArray::from_view(arr.view());
        // Borrowing must not copy: both the cow and its view point at `arr`.
        assert_eq!(cow.as_ptr(), arr.as_ptr());
        assert_eq!(cow.view().as_ptr(), arr.as_ptr());
    }

    #[test]
    fn cow_to_mut_clones_when_borrowed() {
        let arr = sample();
        let view = arr.view();
        let mut cow = CowArray::from_view(view);

        assert!(cow.is_borrowed());
        let owned = cow.to_mut();
        assert_eq!(owned.shape(), &[3]);
        assert!(cow.is_owned());
        // The owned copy lives in a fresh buffer while `arr` is still alive.
        assert_ne!(cow.as_ptr(), arr.as_ptr());
    }

    #[test]
    fn cow_to_mut_keeps_owned_buffer() {
        let mut cow = CowArray::from_owned(sample());
        let before = cow.as_ptr();
        let after = cow.to_mut().as_ptr();
        assert_eq!(after, before);
        assert!(cow.is_owned());
    }

    #[test]
    fn cow_into_owned() {
        let arr = sample();
        let view = arr.view();
        let cow = CowArray::from_view(view);
        let owned = cow.into_owned();
        assert_eq!(owned.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn cow_owned_flags() {
        let cow = CowArray::from_owned(sample());
        let flags = cow.flags();
        assert!(flags.owndata);
        assert!(flags.writeable);
        assert!(flags.aligned);
    }

    #[test]
    fn cow_clone_preserves_variant() {
        let arr = sample();
        let borrowed = CowArray::from_view(arr.view());
        assert!(borrowed.clone().is_borrowed());

        let owned = CowArray::from_owned(arr.clone());
        let copy = owned.clone();
        assert!(copy.is_owned());
        assert_eq!(copy.into_owned().as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn cow_from_conversions() {
        let arr = sample();
        let borrowed: CowArray<'_, f64, Ix1> = arr.view().into();
        assert!(borrowed.is_borrowed());
        let owned: CowArray<'_, f64, Ix1> = arr.clone().into();
        assert!(owned.is_owned());
    }
}