1use crate::dimension::Dimension;
4use crate::dtype::Element;
5use crate::layout::MemoryLayout;
6
7use super::ArrayFlags;
8use super::owned::Array;
9use super::view::ArrayView;
10
11pub enum CowArray<'a, T: Element, D: Dimension> {
16 Borrowed(ArrayView<'a, T, D>),
18 Owned(Array<T, D>),
20}
21
22impl<'a, T: Element, D: Dimension> CowArray<'a, T, D> {
23 pub const fn from_view(view: ArrayView<'a, T, D>) -> Self {
25 Self::Borrowed(view)
26 }
27
28 pub const fn from_owned(arr: Array<T, D>) -> Self {
30 Self::Owned(arr)
31 }
32
33 #[inline]
35 pub fn shape(&self) -> &[usize] {
36 match self {
37 Self::Borrowed(v) => v.shape(),
38 Self::Owned(a) => a.shape(),
39 }
40 }
41
42 #[inline]
44 pub fn ndim(&self) -> usize {
45 match self {
46 Self::Borrowed(v) => v.ndim(),
47 Self::Owned(a) => a.ndim(),
48 }
49 }
50
51 #[inline]
53 pub fn size(&self) -> usize {
54 match self {
55 Self::Borrowed(v) => v.size(),
56 Self::Owned(a) => a.size(),
57 }
58 }
59
60 #[inline]
62 pub fn is_empty(&self) -> bool {
63 self.size() == 0
64 }
65
66 pub fn layout(&self) -> MemoryLayout {
68 match self {
69 Self::Borrowed(v) => v.layout(),
70 Self::Owned(a) => a.layout(),
71 }
72 }
73
74 pub const fn is_borrowed(&self) -> bool {
76 matches!(self, Self::Borrowed(_))
77 }
78
79 pub const fn is_owned(&self) -> bool {
81 matches!(self, Self::Owned(_))
82 }
83
84 pub fn into_owned(self) -> Array<T, D> {
86 match self {
87 Self::Borrowed(v) => v.to_owned(),
88 Self::Owned(a) => a,
89 }
90 }
91
92 pub fn to_mut(&mut self) -> &mut Array<T, D> {
95 if let Self::Borrowed(v) = self {
96 *self = Self::Owned(v.to_owned());
97 }
98 match self {
99 Self::Owned(a) => a,
100 Self::Borrowed(_) => unreachable!(),
101 }
102 }
103
104 pub fn view(&self) -> ArrayView<'_, T, D> {
109 match self {
110 Self::Borrowed(v) => {
111 ArrayView::from_ndarray(v.inner.view())
113 }
114 Self::Owned(a) => a.view(),
115 }
116 }
117
118 #[inline]
120 pub fn as_ptr(&self) -> *const T {
121 match self {
122 Self::Borrowed(v) => v.as_ptr(),
123 Self::Owned(a) => a.as_ptr(),
124 }
125 }
126
127 pub fn flags(&self) -> ArrayFlags {
129 match self {
130 Self::Borrowed(v) => v.flags(),
131 Self::Owned(a) => {
132 let layout = a.layout();
133 ArrayFlags {
134 c_contiguous: layout.is_c_contiguous(),
135 f_contiguous: layout.is_f_contiguous(),
136 owndata: true,
137 writeable: true,
138 aligned: true,
140 }
141 }
142 }
143 }
144}
145
146impl<T: Element, D: Dimension> Clone for CowArray<'_, T, D> {
147 fn clone(&self) -> Self {
148 match self {
149 Self::Borrowed(v) => Self::Borrowed(v.clone()),
150 Self::Owned(a) => Self::Owned(a.clone()),
151 }
152 }
153}
154
155impl<T: Element, D: Dimension> From<Array<T, D>> for CowArray<'_, T, D> {
156 fn from(arr: Array<T, D>) -> Self {
157 Self::Owned(arr)
158 }
159}
160
161impl<'a, T: Element, D: Dimension> From<ArrayView<'a, T, D>> for CowArray<'a, T, D> {
162 fn from(view: ArrayView<'a, T, D>) -> Self {
163 Self::Borrowed(view)
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::Ix1;

    /// Three-element 1-D fixture shared by the tests below.
    fn sample() -> Array<f64, Ix1> {
        Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap()
    }

    #[test]
    fn cow_from_view() {
        let arr = sample();
        let cow = CowArray::from_view(arr.view());
        assert!(cow.is_borrowed());
        assert!(!cow.is_owned());
        assert_eq!(cow.shape(), &[3]);
        assert_eq!(cow.ndim(), 1);
        assert_eq!(cow.size(), 3);
        assert!(!cow.is_empty());
        // The borrowed variant must alias the source buffer, not copy it.
        assert_eq!(cow.as_ptr(), arr.as_ptr());
    }

    #[test]
    fn cow_from_owned() {
        let cow = CowArray::from_owned(sample());
        assert!(cow.is_owned());
        assert!(!cow.is_borrowed());
        assert_eq!(cow.shape(), &[3]);
        let flags = cow.flags();
        assert!(flags.owndata);
        assert!(flags.writeable);
        assert!(flags.aligned);
    }

    #[test]
    fn cow_to_mut_clones_when_borrowed() {
        let arr = sample();
        let mut cow = CowArray::from_view(arr.view());
        assert!(cow.is_borrowed());

        let owned = cow.to_mut();
        assert_eq!(owned.shape(), &[3]);
        // The promoted copy must live in its own allocation and hold the same data.
        assert_ne!(owned.as_ptr(), arr.as_ptr());
        assert_eq!(owned.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
        assert!(cow.is_owned());
    }

    #[test]
    fn cow_to_mut_keeps_owned_buffer() {
        let mut cow = CowArray::from_owned(sample());
        let before = cow.as_ptr();
        // An already-owned value must be handed out as-is, without reallocating.
        let after = cow.to_mut().as_ptr();
        assert_eq!(before, after);
        assert!(cow.is_owned());
    }

    #[test]
    fn cow_into_owned() {
        let arr = sample();
        let cow = CowArray::from_view(arr.view());
        let owned = cow.into_owned();
        assert_eq!(owned.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn cow_view_matches_source() {
        let arr = sample();
        let borrowed = CowArray::from_view(arr.view());
        assert_eq!(borrowed.view().shape(), &[3]);

        let owned = CowArray::from_owned(sample());
        assert_eq!(owned.view().shape(), &[3]);
    }

    #[test]
    fn cow_clone_and_from_preserve_variant() {
        let arr = sample();

        let from_view: CowArray<'_, f64, Ix1> = arr.view().into();
        assert!(from_view.is_borrowed());
        assert!(from_view.clone().is_borrowed());

        let from_owned: CowArray<'_, f64, Ix1> = sample().into();
        assert!(from_owned.is_owned());
        assert!(from_owned.clone().is_owned());
    }
}