1use crate::dimension::Dimension;
4use crate::dtype::Element;
5use crate::layout::MemoryLayout;
6
7use super::ArrayFlags;
8use super::owned::Array;
9use super::view::ArrayView;
10
11pub enum CowArray<'a, T: Element, D: Dimension> {
16 Borrowed(ArrayView<'a, T, D>),
18 Owned(Array<T, D>),
20}
21
22impl<'a, T: Element, D: Dimension> CowArray<'a, T, D> {
23 pub fn from_view(view: ArrayView<'a, T, D>) -> Self {
25 Self::Borrowed(view)
26 }
27
28 pub fn from_owned(arr: Array<T, D>) -> Self {
30 Self::Owned(arr)
31 }
32
33 #[inline]
35 pub fn shape(&self) -> &[usize] {
36 match self {
37 Self::Borrowed(v) => v.shape(),
38 Self::Owned(a) => a.shape(),
39 }
40 }
41
42 #[inline]
44 pub fn ndim(&self) -> usize {
45 match self {
46 Self::Borrowed(v) => v.ndim(),
47 Self::Owned(a) => a.ndim(),
48 }
49 }
50
51 #[inline]
53 pub fn size(&self) -> usize {
54 match self {
55 Self::Borrowed(v) => v.size(),
56 Self::Owned(a) => a.size(),
57 }
58 }
59
60 #[inline]
62 pub fn is_empty(&self) -> bool {
63 self.size() == 0
64 }
65
66 pub fn layout(&self) -> MemoryLayout {
68 match self {
69 Self::Borrowed(v) => v.layout(),
70 Self::Owned(a) => a.layout(),
71 }
72 }
73
74 pub fn is_borrowed(&self) -> bool {
76 matches!(self, Self::Borrowed(_))
77 }
78
79 pub fn is_owned(&self) -> bool {
81 matches!(self, Self::Owned(_))
82 }
83
84 pub fn into_owned(self) -> Array<T, D> {
86 match self {
87 Self::Borrowed(v) => v.to_owned(),
88 Self::Owned(a) => a,
89 }
90 }
91
92 pub fn to_mut(&mut self) -> &mut Array<T, D> {
95 if let Self::Borrowed(v) = self {
96 *self = Self::Owned(v.to_owned());
97 }
98 match self {
99 Self::Owned(a) => a,
100 Self::Borrowed(_) => unreachable!(),
101 }
102 }
103
104 pub fn view(&self) -> ArrayView<'_, T, D> {
109 match self {
110 Self::Borrowed(v) => {
111 ArrayView::from_ndarray(v.inner.view())
113 }
114 Self::Owned(a) => a.view(),
115 }
116 }
117
118 #[inline]
120 pub fn as_ptr(&self) -> *const T {
121 match self {
122 Self::Borrowed(v) => v.as_ptr(),
123 Self::Owned(a) => a.as_ptr(),
124 }
125 }
126
127 pub fn flags(&self) -> ArrayFlags {
129 match self {
130 Self::Borrowed(v) => v.flags(),
131 Self::Owned(a) => {
132 let layout = a.layout();
133 ArrayFlags {
134 c_contiguous: layout.is_c_contiguous(),
135 f_contiguous: layout.is_f_contiguous(),
136 owndata: true,
137 writeable: true,
138 }
139 }
140 }
141 }
142}
143
144impl<T: Element, D: Dimension> Clone for CowArray<'_, T, D> {
145 fn clone(&self) -> Self {
146 match self {
147 Self::Borrowed(v) => Self::Borrowed(v.clone()),
148 Self::Owned(a) => Self::Owned(a.clone()),
149 }
150 }
151}
152
153impl<T: Element, D: Dimension> From<Array<T, D>> for CowArray<'_, T, D> {
154 fn from(arr: Array<T, D>) -> Self {
155 Self::Owned(arr)
156 }
157}
158
159impl<'a, T: Element, D: Dimension> From<ArrayView<'a, T, D>> for CowArray<'a, T, D> {
160 fn from(view: ArrayView<'a, T, D>) -> Self {
161 Self::Borrowed(view)
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::Ix1;

    /// Builds the one-dimensional `f64` array `[1.0, 2.0, 3.0]` used by every test.
    fn sample() -> Array<f64, Ix1> {
        Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap()
    }

    /// A value built from a view reports `Borrowed` and exposes the view's shape.
    #[test]
    fn cow_from_view() {
        let source = sample();
        let cow = CowArray::from_view(source.view());
        assert!(cow.is_borrowed());
        assert!(!cow.is_owned());
        assert_eq!(cow.shape(), &[3]);
    }

    /// A value built from an array reports `Owned`.
    #[test]
    fn cow_from_owned() {
        let cow = CowArray::from_owned(sample());
        assert!(cow.is_owned());
        assert!(!cow.is_borrowed());
    }

    /// `to_mut` on a borrowed value switches it to `Owned` and keeps the shape.
    #[test]
    fn cow_to_mut_clones_when_borrowed() {
        let source = sample();
        let mut cow = CowArray::from_view(source.view());

        assert!(cow.is_borrowed());
        // The mutable borrow from `to_mut` ends with this statement, so `cow`
        // can be inspected again right after.
        assert_eq!(cow.to_mut().shape(), &[3]);
        assert!(cow.is_owned());
    }

    /// `into_owned` on a borrowed value yields an array with the same elements.
    #[test]
    fn cow_into_owned() {
        let source = sample();
        let detached = CowArray::from_view(source.view()).into_owned();
        assert_eq!(detached.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }
}