1use std::sync::Arc;
4
5use crate::dimension::Dimension;
6use crate::dtype::Element;
7use crate::layout::MemoryLayout;
8
9use super::ArrayFlags;
10use super::owned::Array;
11use super::view::ArrayView;
12
13pub struct ArcArray<T: Element, D: Dimension> {
21 data: Arc<Vec<T>>,
23 dim: D,
25 strides: Vec<isize>,
27 offset: usize,
29}
30
31impl<T: Element, D: Dimension> ArcArray<T, D> {
32 pub fn from_owned(arr: Array<T, D>) -> Self {
34 let dim = arr.dim.clone();
35 let data = if arr.inner.is_standard_layout() {
37 arr.inner.into_raw_vec_and_offset().0
38 } else {
39 let contiguous = arr.inner.as_standard_layout().into_owned();
41 contiguous.into_raw_vec_and_offset().0
42 };
43 let strides = compute_c_strides(dim.as_slice());
45 Self {
46 data: Arc::new(data),
47 dim,
48 strides,
49 offset: 0,
50 }
51 }
52
53 #[inline]
55 pub fn shape(&self) -> &[usize] {
56 self.dim.as_slice()
57 }
58
59 #[inline]
61 pub fn ndim(&self) -> usize {
62 self.dim.ndim()
63 }
64
65 #[inline]
67 pub fn size(&self) -> usize {
68 self.dim.size()
69 }
70
71 #[inline]
73 pub fn is_empty(&self) -> bool {
74 self.size() == 0
75 }
76
77 #[inline]
79 pub fn strides(&self) -> &[isize] {
80 &self.strides
81 }
82
83 pub fn layout(&self) -> MemoryLayout {
85 crate::layout::detect_layout(self.dim.as_slice(), &self.strides)
86 }
87
88 #[inline]
90 pub fn dim(&self) -> &D {
91 &self.dim
92 }
93
94 pub fn ref_count(&self) -> usize {
96 Arc::strong_count(&self.data)
97 }
98
99 pub fn is_unique(&self) -> bool {
101 Arc::strong_count(&self.data) == 1
102 }
103
104 pub fn as_slice(&self) -> &[T] {
106 &self.data[self.offset..self.offset + self.size()]
107 }
108
109 #[inline]
111 pub fn as_ptr(&self) -> *const T {
112 self.as_slice().as_ptr()
113 }
114
115 pub fn view(&self) -> ArrayView<'_, T, D> {
121 let nd_dim = self.dim.to_ndarray_dim();
122 let slice = self.as_slice();
123 let nd_view = ndarray::ArrayView::from_shape(nd_dim, slice)
124 .expect("ArcArray data should be consistent with shape");
125 ArrayView::from_ndarray(nd_view)
126 }
127
128 fn make_unique(&mut self) {
133 if Arc::strong_count(&self.data) > 1 {
134 let slice = &self.data[self.offset..self.offset + self.size()];
135 self.data = Arc::new(slice.to_vec());
136 self.offset = 0;
137 }
138 }
139
140 pub fn as_slice_mut(&mut self) -> &mut [T] {
142 self.make_unique();
143 let size = self.size();
144 let offset = self.offset;
145 Arc::get_mut(&mut self.data)
146 .expect("make_unique should ensure refcount == 1")
147 .get_mut(offset..offset + size)
148 .expect("offset + size should be in bounds")
149 }
150
151 pub fn mapv_inplace(&mut self, f: impl Fn(T) -> T) {
153 self.make_unique();
154 let size = self.size();
155 let offset = self.offset;
156 let data = Arc::get_mut(&mut self.data).expect("unique after make_unique");
157 for elem in &mut data[offset..offset + size] {
158 *elem = f(elem.clone());
159 }
160 }
161
162 pub fn into_owned(self) -> Array<T, D> {
164 let data: Vec<T> = if self.offset == 0 && self.data.len() == self.size() {
165 match Arc::try_unwrap(self.data) {
166 Ok(v) => v,
167 Err(arc) => arc[..].to_vec(),
168 }
169 } else {
170 self.data[self.offset..self.offset + self.size()].to_vec()
171 };
172 Array::from_vec(self.dim, data).expect("data should match shape")
173 }
174
175 pub fn copy(&self) -> Self {
177 let data = self.as_slice().to_vec();
178 Self {
179 data: Arc::new(data),
180 dim: self.dim.clone(),
181 strides: self.strides.clone(),
182 offset: 0,
183 }
184 }
185
186 pub fn flags(&self) -> ArrayFlags {
188 let layout = self.layout();
189 ArrayFlags {
190 c_contiguous: layout.is_c_contiguous(),
191 f_contiguous: layout.is_f_contiguous(),
192 owndata: true, writeable: true,
194 }
195 }
196}
197
198impl<T: Element, D: Dimension> Clone for ArcArray<T, D> {
199 fn clone(&self) -> Self {
200 Self {
201 data: Arc::clone(&self.data),
202 dim: self.dim.clone(),
203 strides: self.strides.clone(),
204 offset: self.offset,
205 }
206 }
207}
208
209impl<T: Element, D: Dimension> From<Array<T, D>> for ArcArray<T, D> {
210 fn from(arr: Array<T, D>) -> Self {
211 Self::from_owned(arr)
212 }
213}
214
/// Computes C-order (row-major) strides, in elements, for `shape`.
///
/// The last axis gets stride 1 and every earlier axis gets the product of the
/// extents of all axes after it. An empty shape yields an empty vector.
fn compute_c_strides(shape: &[usize]) -> Vec<isize> {
    if shape.is_empty() {
        return Vec::new();
    }
    // Build innermost-first, then flip into axis order.
    let mut running: isize = 1;
    let mut reversed = vec![running];
    for &extent in shape[1..].iter().rev() {
        running *= extent as isize;
        reversed.push(running);
    }
    reversed.reverse();
    reversed
}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Builds the owned 1-D fixture `[1.0, 2.0, 3.0]`.
    fn owned_123() -> Array<f64, Ix1> {
        Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap()
    }

    /// Builds the 1-D fixture already wrapped in an `ArcArray`.
    fn arc_123() -> ArcArray<f64, Ix1> {
        ArcArray::from_owned(owned_123())
    }

    #[test]
    fn arc_from_owned() {
        let arc = arc_123();
        assert_eq!(arc.shape(), &[3]);
        assert_eq!(arc.as_slice(), &[1.0, 2.0, 3.0]);
        assert_eq!(arc.ref_count(), 1);
    }

    #[test]
    fn arc_clone_shares() {
        let first = arc_123();
        let second = first.clone();
        assert_eq!(first.ref_count(), 2);
        assert_eq!(second.ref_count(), 2);
        assert_eq!(first.as_ptr(), second.as_ptr());
    }

    #[test]
    fn arc_cow_on_mutation() {
        let original = arc_123();
        let mut writer = original.clone();

        // Before the write both handles point at the same buffer.
        assert_eq!(original.as_ptr(), writer.as_ptr());
        assert_eq!(original.ref_count(), 2);

        writer.as_slice_mut()[0] = 99.0;

        // The write detached `writer`; `original` is untouched.
        assert_ne!(original.as_ptr(), writer.as_ptr());
        assert_eq!(original.as_slice(), &[1.0, 2.0, 3.0]);
        assert_eq!(writer.as_slice(), &[99.0, 2.0, 3.0]);
        assert_eq!(original.ref_count(), 1);
        assert_eq!(writer.ref_count(), 1);
    }

    #[test]
    fn arc_view_sees_old_data_after_cow() {
        let mut arc = arc_123();
        let arc_clone = arc.clone();

        let view = arc_clone.view();
        assert_eq!(view.as_slice().unwrap(), &[1.0, 2.0, 3.0]);

        arc.as_slice_mut()[0] = 99.0;

        // The view borrows the clone, so it keeps observing the old values.
        assert_eq!(view.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
        assert_eq!(arc.as_slice(), &[99.0, 2.0, 3.0]);
    }

    #[test]
    fn arc_unique_no_clone() {
        let mut arc = arc_123();
        let ptr_before = arc.as_ptr();

        arc.as_slice_mut()[0] = 99.0;

        // A sole owner is mutated in place, without reallocating.
        assert_eq!(arc.as_ptr(), ptr_before);
        assert_eq!(arc.as_slice(), &[99.0, 2.0, 3.0]);
    }

    #[test]
    fn arc_into_owned() {
        let matrix = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        let owned = ArcArray::from_owned(matrix).into_owned();
        assert_eq!(owned.shape(), &[2, 3]);
    }

    #[test]
    fn arc_mapv_inplace() {
        let mut arc = arc_123();
        arc.mapv_inplace(|x| x * 2.0);
        assert_eq!(arc.as_slice(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn arc_copy_is_independent() {
        let arc = arc_123();
        let copy = arc.copy();
        assert_ne!(arc.as_ptr(), copy.as_ptr());
        assert_eq!(arc.ref_count(), 1);
        assert_eq!(copy.ref_count(), 1);
    }
}