hogwild/
lib.rs

1use std::cell::UnsafeCell;
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use ndarray::{Array, ArrayView, ArrayViewMut, Axis, Dimension, Ix, Ix1, Ix2, Ix3, RemoveAxis};
6
7/// Array for Hogwild parallel optimization.
8///
9/// This array type can be used for the Hogwild (Niu, et al. 2011) method
10/// of parallel Stochastic Gradient descent. In Hogwild different threads
11/// share the same parameters without locking. If SGD is performed on a
12/// sparse optimization problem, where only a small subset of parameters
13/// is updated in each gradient descent, the impact of data races is
14/// negligible.
15///
16/// In order to use Hogwild in Rust, we have to subvert the ownership
17/// system. This is what the `HogwildArray` type does. It uses reference
18/// counting to share an *ndarray* `Array` type between multiple
19/// `HogwildArray` instances. Views of the underling `Array` can be borrowed
20/// mutably from each instance, without mutual exclusion between mutable
21/// borrows in different `HogwildArray` instances.
22///
23/// # Example
24///
25/// ```
26/// use hogwild::HogwildArray2;
27/// use ndarray::Array2;
28///
29/// let mut a1: HogwildArray2<f32> = Array2::zeros((2, 2)).into();
30/// let mut a2 = a1.clone();
31///
32/// let mut a1_view = a1.view_mut();
33///
34/// let c00 = &mut a1_view[(0, 0)];
35/// *c00 = 1.0;
36///
37/// // Two simultaneous mutable borrows of the underlying array.
38/// a2.view_mut()[(1, 1)] = *c00 * 2.0;
39///
40/// assert_eq!(&[1.0, 0.0, 0.0, 2.0], a2.as_slice().unwrap());
41/// ```
42
43#[derive(Clone)]
44pub struct HogwildArray<A, D>(Arc<UnsafeCell<Array<A, D>>>);
45
46impl<A, D> HogwildArray<A, D> {
47    #[inline]
48    fn as_mut(&mut self) -> &mut Array<A, D> {
49        let ptr = self.0.as_ref().get();
50        unsafe { &mut *ptr }
51    }
52
53    #[inline]
54    fn as_ref(&self) -> &Array<A, D> {
55        let ptr = self.0.as_ref().get();
56        unsafe { &*ptr }
57    }
58
59    pub fn into_inner(self) -> Arc<UnsafeCell<Array<A, D>>> {
60        self.0
61    }
62}
63
64impl<A, D> HogwildArray<A, D>
65where
66    D: Dimension + RemoveAxis,
67{
68    /// Get an immutable subview of the Hogwild array.
69    #[inline]
70    pub fn subview(&self, axis: Axis, index: Ix) -> ArrayView<A, D::Smaller> {
71        self.as_ref().index_axis(axis, index)
72    }
73
74    /// Get a mutable subview of the Hogwild array.
75    #[inline]
76    pub fn subview_mut(&mut self, axis: Axis, index: Ix) -> ArrayViewMut<A, D::Smaller> {
77        self.as_mut().index_axis_mut(axis, index)
78    }
79}
80
81impl<A, D> HogwildArray<A, D>
82where
83    D: Dimension,
84{
85    /// Get a slice reference to the underlying data array.
86    #[inline]
87    pub fn as_slice(&self) -> Option<&[A]> {
88        self.as_ref().as_slice()
89    }
90
91    /// Get an immutable view of the Hogwild array.
92    #[inline]
93    pub fn view(&self) -> ArrayView<A, D> {
94        self.as_ref().view()
95    }
96
97    /// Get an mutable view of the Hogwild array.
98    #[inline]
99    pub fn view_mut(&mut self) -> ArrayViewMut<A, D> {
100        self.as_mut().view_mut()
101    }
102}
103
104impl<A, D> From<Array<A, D>> for HogwildArray<A, D> {
105    fn from(a: Array<A, D>) -> Self {
106        HogwildArray(Arc::new(UnsafeCell::new(a)))
107    }
108}
109
110unsafe impl<A, D> Send for HogwildArray<A, D> {}
111
112unsafe impl<A, D> Sync for HogwildArray<A, D> {}
113
114/// One-dimensional Hogwild array.
115pub type HogwildArray1<A> = HogwildArray<A, Ix1>;
116
117/// Two-dimensional Hogwild array.
118pub type HogwildArray2<A> = HogwildArray<A, Ix2>;
119
120/// Three-dimensional Hogwild array.
121pub type HogwildArray3<A> = HogwildArray<A, Ix3>;
122
/// Hogwild for arbitrary data types.
///
/// `Hogwild` subverts Rust's type system by allowing concurrent modification
/// of values. Every clone refers to the same shared value. This should only
/// be used for data types that cannot end up in an inconsistent state due to
/// data races. For arrays, `HogwildArray` should be preferred.
///
/// Cloning only increments a reference count. The wrapped value is never
/// copied, so `T` does not need to implement `Clone`.
pub struct Hogwild<T>(Arc<UnsafeCell<T>>);

// Implemented by hand rather than derived. `#[derive(Clone)]` would demand
// `T: Clone`, even though only the `Arc` handle is cloned.
impl<T> Clone for Hogwild<T> {
    /// Returns another handle to the same shared value (no deep copy).
    fn clone(&self) -> Self {
        Hogwild(Arc::clone(&self.0))
    }
}
131
132impl<T> Default for Hogwild<T>
133where
134    T: Default,
135{
136    fn default() -> Self {
137        Hogwild(Arc::new(UnsafeCell::new(T::default())))
138    }
139}
140
141impl<T> Deref for Hogwild<T> {
142    type Target = T;
143
144    fn deref(&self) -> &Self::Target {
145        let ptr = self.0.as_ref().get();
146        unsafe { &*ptr }
147    }
148}
149
150impl<T> DerefMut for Hogwild<T> {
151    fn deref_mut(&mut self) -> &mut T {
152        let ptr = self.0.as_ref().get();
153        unsafe { &mut *ptr }
154    }
155}
156
157unsafe impl<T> Send for Hogwild<T> {}
158
159unsafe impl<T> Sync for Hogwild<T> {}
160
#[cfg(test)]
mod test {
    use std::thread;

    use ndarray::Array2;

    use super::{Hogwild, HogwildArray2};

    #[test]
    fn hogwild_test() {
        let mut a1: Hogwild<usize> = Hogwild::default();
        let mut a2 = a1.clone();

        *a1 = 1;
        assert_eq!(*a2, 1);
        *a2 = 2;
        assert_eq!(*a1, 2);
    }

    #[test]
    fn hogwild_array_test() {
        let mut a1: HogwildArray2<f32> = Array2::zeros((2, 2)).into();
        let mut a2 = a1.clone();

        let mut a1_view = a1.view_mut();

        let c00 = &mut a1_view[(0, 0)];
        *c00 = 1.0;

        // Two simultaneous mutable borrows of the underlying array.
        a2.view_mut()[(1, 1)] = *c00 * 2.0;

        assert_eq!(&[1.0, 0.0, 0.0, 2.0], a2.as_slice().unwrap());
    }

    // A clone moved to another thread writes to the shared value. `join`
    // orders that write before the read below, so there is no actual race.
    #[test]
    fn hogwild_cross_thread_test() {
        let shared: Hogwild<usize> = Hogwild::default();
        let mut worker = shared.clone();
        thread::spawn(move || {
            *worker = 42;
        })
        .join()
        .unwrap();
        assert_eq!(*shared, 42);
    }

    // Same as above for arrays. It also checks that common element types
    // remain `Send`.
    #[test]
    fn hogwild_array_cross_thread_test() {
        let shared: HogwildArray2<f32> = Array2::zeros((2, 2)).into();
        let mut worker = shared.clone();
        thread::spawn(move || {
            worker.view_mut()[(0, 1)] = 3.0;
        })
        .join()
        .unwrap();
        assert_eq!(&[0.0, 3.0, 0.0, 0.0], shared.as_slice().unwrap());
    }
}