1use std::cell::UnsafeCell;
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use ndarray::{Array, ArrayView, ArrayViewMut, Axis, Dimension, Ix, Ix1, Ix2, Ix3, RemoveAxis};
6
7#[derive(Clone)]
44pub struct HogwildArray<A, D>(Arc<UnsafeCell<Array<A, D>>>);
45
46impl<A, D> HogwildArray<A, D> {
47 #[inline]
48 fn as_mut(&mut self) -> &mut Array<A, D> {
49 let ptr = self.0.as_ref().get();
50 unsafe { &mut *ptr }
51 }
52
53 #[inline]
54 fn as_ref(&self) -> &Array<A, D> {
55 let ptr = self.0.as_ref().get();
56 unsafe { &*ptr }
57 }
58
59 pub fn into_inner(self) -> Arc<UnsafeCell<Array<A, D>>> {
60 self.0
61 }
62}
63
64impl<A, D> HogwildArray<A, D>
65where
66 D: Dimension + RemoveAxis,
67{
68 #[inline]
70 pub fn subview(&self, axis: Axis, index: Ix) -> ArrayView<A, D::Smaller> {
71 self.as_ref().index_axis(axis, index)
72 }
73
74 #[inline]
76 pub fn subview_mut(&mut self, axis: Axis, index: Ix) -> ArrayViewMut<A, D::Smaller> {
77 self.as_mut().index_axis_mut(axis, index)
78 }
79}
80
81impl<A, D> HogwildArray<A, D>
82where
83 D: Dimension,
84{
85 #[inline]
87 pub fn as_slice(&self) -> Option<&[A]> {
88 self.as_ref().as_slice()
89 }
90
91 #[inline]
93 pub fn view(&self) -> ArrayView<A, D> {
94 self.as_ref().view()
95 }
96
97 #[inline]
99 pub fn view_mut(&mut self) -> ArrayViewMut<A, D> {
100 self.as_mut().view_mut()
101 }
102}
103
104impl<A, D> From<Array<A, D>> for HogwildArray<A, D> {
105 fn from(a: Array<A, D>) -> Self {
106 HogwildArray(Arc::new(UnsafeCell::new(a)))
107 }
108}
109
110unsafe impl<A, D> Send for HogwildArray<A, D> {}
111
112unsafe impl<A, D> Sync for HogwildArray<A, D> {}
113
/// One-dimensional `HogwildArray` (dimension type `Ix1`).
pub type HogwildArray1<A> = HogwildArray<A, Ix1>;

/// Two-dimensional `HogwildArray` (dimension type `Ix2`).
pub type HogwildArray2<A> = HogwildArray<A, Ix2>;

/// Three-dimensional `HogwildArray` (dimension type `Ix3`).
pub type HogwildArray3<A> = HogwildArray<A, Ix3>;
122
/// A value shared by all of its clones, without any synchronization.
///
/// Cloning yields another handle to the *same* `T` (an `Arc` refcount bump).
/// The `Deref`/`DerefMut` impls hand out references into that shared cell,
/// so a write through one handle is observed through every other handle.
pub struct Hogwild<T>(Arc<UnsafeCell<T>>);

// Written by hand because `#[derive(Clone)]` would require `T: Clone`, even
// though cloning never copies the `T`; it only shares the `Arc`.
impl<T> Clone for Hogwild<T> {
    /// Returns another handle to the same shared value.
    fn clone(&self) -> Self {
        Hogwild(Arc::clone(&self.0))
    }
}
131
132impl<T> Default for Hogwild<T>
133where
134 T: Default,
135{
136 fn default() -> Self {
137 Hogwild(Arc::new(UnsafeCell::new(T::default())))
138 }
139}
140
141impl<T> Deref for Hogwild<T> {
142 type Target = T;
143
144 fn deref(&self) -> &Self::Target {
145 let ptr = self.0.as_ref().get();
146 unsafe { &*ptr }
147 }
148}
149
150impl<T> DerefMut for Hogwild<T> {
151 fn deref_mut(&mut self) -> &mut T {
152 let ptr = self.0.as_ref().get();
153 unsafe { &mut *ptr }
154 }
155}
156
157unsafe impl<T> Send for Hogwild<T> {}
158
159unsafe impl<T> Sync for Hogwild<T> {}
160
#[cfg(test)]
mod test {
    use ndarray::Array2;

    use super::{Hogwild, HogwildArray2};

    /// Two `Hogwild` handles share one value: a write through either handle
    /// must be visible through the other.
    #[test]
    pub fn hogwild_test() {
        let mut a1: Hogwild<usize> = Hogwild::default();
        let mut a2 = a1.clone();

        *a1 = 1;
        assert_eq!(*a2, 1);
        *a2 = 2;
        assert_eq!(*a1, 2);
    }

    /// Two `HogwildArray2` handles share one array: a value written through
    /// a mutable view of `a1` is read back and used to write through `a2`,
    /// and `a2` then observes both writes.
    #[test]
    pub fn hogwild_array_test() {
        let mut a1: HogwildArray2<f32> = Array2::zeros((2, 2)).into();
        let mut a2 = a1.clone();

        let mut a1_view = a1.view_mut();

        // Write element (0, 0) through a1's view.
        let c00 = &mut a1_view[(0, 0)];
        *c00 = 1.0;

        // The right-hand side (reading `*c00`) is evaluated before the
        // assignee, so (1, 1) is set to 2.0 through a fresh view of a2.
        a2.view_mut()[(1, 1)] = *c00 * 2.0;

        // a2 sees both the write made via a1 and its own write.
        assert_eq!(&[1.0, 0.0, 0.0, 2.0], a2.as_slice().unwrap());
    }
}
193}