Skip to main content

aorist_extendr_api/wrapper/
matrix.rs

1//! Wrappers for matrices with deferred arithmetic.
2
3use super::*;
4use std::ops::{Index, IndexMut};
5
/// Wrapper for creating and using matrices and arrays.
///
/// An `RArray` stores the owning [`Robj`] together with a raw pointer to its
/// element storage and the array's dimensions `D`. The aliases [`RColumn`],
/// [`RMatrix`] and [`RMatrix3D`] fix `D` to one, two and three dimensions.
///
/// ```
/// use extendr_api::prelude::*;
/// test! {
///     let matrix = RMatrix::new_matrix(3, 2, |r, c| [
///         [1., 2., 3.],
///          [4., 5., 6.]][c][r]);
///     let robj = r!(matrix);
///     assert_eq!(robj.is_matrix(), true);
///     assert_eq!(robj.nrows(), 3);
///     assert_eq!(robj.ncols(), 2);
///
///     let matrix2 : RMatrix<f64> = robj.as_matrix().ok_or("error")?;
///     assert_eq!(matrix2.data().len(), 6);
///     assert_eq!(matrix2.nrows(), 3);
///     assert_eq!(matrix2.ncols(), 2);
/// }
/// ```
// NOTE: the derived `PartialEq` compares all three fields, including the raw
// `data` pointer, not just the element values.
#[derive(Debug, PartialEq)]
pub struct RArray<T, D> {
    /// Owning Robj (probably should be a Pin).
    /// Presumably keeps the storage behind `data` alive — TODO confirm.
    robj: Robj,

    /// Pointer to the first element of `robj`'s data; `data()` reads
    /// `robj.len()` elements starting here.
    data: *mut T,

    /// Dimensions of the array.
    dim: D,
}
36
/// One-dimensional array (a column of values).
pub type RColumn<T> = RArray<T, [usize; 1]>;
/// Two-dimensional array with `[rows, columns]` dimensions.
pub type RMatrix<T> = RArray<T, [usize; 2]>;
/// Three-dimensional array with `[rows, columns, submatrices]` dimensions.
pub type RMatrix3D<T> = RArray<T, [usize; 3]>;

/// Index origin subtracted by the `Offset` implementations (0 = zero-based).
const BASE: usize = 0;
42
/// Maps a multi-dimensional index of type `D` to a linear position in an
/// array's flat element storage.
trait Offset<D> {
    /// Return the linear offset into the data for `index`.
    fn offset(&self, index: D) -> usize;
}
47
48impl<T> Offset<[usize; 1]> for RArray<T, [usize; 1]> {
49    /// Get the offset into the array for a given index.
50    fn offset(&self, index: [usize; 1]) -> usize {
51        if index[0] - BASE > self.dim[0] {
52            panic!("array index: row overflow");
53        }
54        index[0] - BASE
55    }
56}
57
58impl<T> Offset<[usize; 2]> for RArray<T, [usize; 2]> {
59    /// Get the offset into the array for a given index.
60    fn offset(&self, index: [usize; 2]) -> usize {
61        if index[0] - BASE > self.dim[0] {
62            panic!("matrix index: row overflow");
63        }
64        if index[1] - BASE > self.dim[1] {
65            panic!("matrix index: column overflow");
66        }
67        (index[0] - BASE) + self.dim[0] * (index[1] - BASE)
68    }
69}
70
71impl<T> Offset<[usize; 3]> for RArray<T, [usize; 3]> {
72    /// Get the offset into the array for a given index.
73    fn offset(&self, index: [usize; 3]) -> usize {
74        if index[0] - BASE > self.dim[0] {
75            panic!("RMatrix3D index: row overflow");
76        }
77        if index[1] - BASE > self.dim[1] {
78            panic!("RMatrix3D index: column overflow");
79        }
80        if index[2] - BASE > self.dim[2] {
81            panic!("RMatrix3D index: submatrix overflow");
82        }
83        (index[0] - BASE) + self.dim[0] * (index[1] - BASE + self.dim[1] * (index[2] - BASE))
84    }
85}
86
87impl<T, D> RArray<T, D> {
88    pub fn from_parts(robj: Robj, data: *mut T, dim: D) -> Self {
89        Self { robj, data, dim }
90    }
91
92    /// Get the underlying data fro this array.
93    pub fn data(&self) -> &[T] {
94        unsafe { std::slice::from_raw_parts(self.data, self.robj.len()) }
95    }
96
97    /// Get the dimensions for this array.
98    pub fn dim(&self) -> &D {
99        &self.dim
100    }
101}
102
103impl<'a, T: ToVectorValue + 'a> RColumn<T>
104where
105    Robj: AsTypedSlice<'a, T>,
106{
107    /// Make a new column type.
108    pub fn new_column<F: FnMut(usize) -> T>(nrows: usize, mut f: F) -> Self {
109        let robj = (0..nrows).map(|r| f(r)).collect_robj();
110        let dim = [nrows];
111        let mut robj = robj.set_attrib(wrapper::symbol::dim_symbol(), dim).unwrap();
112        let slice = robj.as_typed_slice_mut().unwrap();
113        let data = slice.as_mut_ptr();
114        RArray::from_parts(robj, data, dim)
115    }
116
117    /// Get the number of rows.
118    pub fn nrows(&self) -> usize {
119        self.dim[0]
120    }
121}
122
123impl<'a, T: ToVectorValue + 'a> RMatrix<T>
124where
125    Robj: AsTypedSlice<'a, T>,
126{
127    /// Create a new matrix wrapper.
128    /// Make a new column type.
129    pub fn new_matrix<F: Clone + FnMut(usize, usize) -> T>(
130        nrows: usize,
131        ncols: usize,
132        f: F,
133    ) -> Self {
134        let robj = (0..ncols)
135            .map(|c| {
136                let mut g = f.clone();
137                (0..nrows).map(move |r| g(r, c))
138            })
139            .flatten()
140            .collect_robj();
141        let dim = [nrows, ncols];
142        let mut robj = robj.set_attrib(wrapper::symbol::dim_symbol(), dim).unwrap();
143        let data = robj.as_typed_slice_mut().unwrap().as_mut_ptr();
144        RArray::from_parts(robj, data, dim)
145    }
146
147    /// Get the number of rows.
148    pub fn nrows(&self) -> usize {
149        self.dim[0]
150    }
151
152    /// Get the number of columns.
153    pub fn ncols(&self) -> usize {
154        self.dim[1]
155    }
156}
157
158impl<'a, T: ToVectorValue + 'a> RMatrix3D<T>
159where
160    Robj: AsTypedSlice<'a, T>,
161{
162    pub fn new_matrix3d<F: Clone + FnMut(usize, usize, usize) -> T>(
163        nrows: usize,
164        ncols: usize,
165        nmatrix: usize,
166        f: F,
167    ) -> Self {
168        let robj = (0..nmatrix)
169            .map(|m| {
170                let h = f.clone();
171                (0..ncols)
172                    .map(move |c| {
173                        let mut g = h.clone();
174                        (0..nrows).map(move |r| g(r, c, m))
175                    })
176                    .flatten()
177            })
178            .flatten()
179            .collect_robj();
180        let dim = [nrows, ncols, nmatrix];
181        let mut robj = robj.set_attrib(wrapper::symbol::dim_symbol(), dim).unwrap();
182        let data = robj.as_typed_slice_mut().unwrap().as_mut_ptr();
183        RArray::from_parts(robj, data, dim)
184    }
185
186    /// Get the number of rows.
187    pub fn nrows(&self) -> usize {
188        self.dim[0]
189    }
190
191    /// Get the number of columns.
192    pub fn ncols(&self) -> usize {
193        self.dim[1]
194    }
195
196    /// Get the number of submatrices.
197    pub fn nsub(&self) -> usize {
198        self.dim[2]
199    }
200}
201
202impl<'a, T: 'a> TryFrom<Robj> for RColumn<T>
203where
204    Robj: AsTypedSlice<'a, T>,
205{
206    type Error = Error;
207
208    fn try_from(mut robj: Robj) -> Result<Self> {
209        if let Some(slice) = robj.as_typed_slice_mut() {
210            Ok(RArray::from_parts(robj, slice.as_mut_ptr(), [slice.len()]))
211        } else {
212            Err(Error::ExpectedVector(robj))
213        }
214    }
215}
216
217impl<'a, T: 'a> TryFrom<Robj> for RMatrix<T>
218where
219    Robj: AsTypedSlice<'a, T>,
220{
221    type Error = Error;
222
223    fn try_from(mut robj: Robj) -> Result<Self> {
224        if !robj.is_matrix() {
225            Err(Error::ExpectedMatrix(robj))
226        } else if let Some(slice) = robj.as_typed_slice_mut() {
227            if let Some(dim) = robj.dim() {
228                let dim: Vec<_> = dim.map(|d| d as usize).collect();
229                if dim.len() != 2 {
230                    Err(Error::ExpectedMatrix(robj))
231                } else {
232                    Ok(RArray::from_parts(
233                        robj,
234                        slice.as_mut_ptr(),
235                        [dim[0], dim[1]],
236                    ))
237                }
238            } else {
239                Err(Error::ExpectedMatrix(robj))
240            }
241        } else {
242            Err(Error::TypeMismatch(robj))
243        }
244    }
245}
246
247impl<'a, T: 'a> TryFrom<Robj> for RMatrix3D<T>
248where
249    Robj: AsTypedSlice<'a, T>,
250{
251    type Error = Error;
252
253    fn try_from(mut robj: Robj) -> Result<Self> {
254        if let Some(slice) = robj.as_typed_slice_mut() {
255            if let Some(dim) = robj.dim() {
256                if dim.len() != 3 {
257                    Err(Error::ExpectedMatrix3D(robj))
258                } else {
259                    let dim: Vec<_> = dim.map(|d| d as usize).collect();
260                    Ok(RArray::from_parts(
261                        robj,
262                        slice.as_mut_ptr(),
263                        [dim[0], dim[1], dim[2]],
264                    ))
265                }
266            } else {
267                Err(Error::ExpectedMatrix3D(robj))
268            }
269        } else {
270            Err(Error::TypeMismatch(robj))
271        }
272    }
273}
274
275impl<T, D> From<RArray<T, D>> for Robj {
276    /// Convert a column, matrix or matrix3d to an Robj.
277    fn from(array: RArray<T, D>) -> Self {
278        array.robj
279    }
280}
281
282impl Robj {
283    pub fn as_column<'a, E: 'a>(&self) -> Option<RColumn<E>>
284    where
285        Self: AsTypedSlice<'a, E>,
286    {
287        <RColumn<E>>::try_from(self.clone()).ok()
288    }
289
290    pub fn as_matrix<'a, E: 'a>(&self) -> Option<RMatrix<E>>
291    where
292        Self: AsTypedSlice<'a, E>,
293    {
294        <RMatrix<E>>::try_from(self.clone()).ok()
295    }
296
297    pub fn as_matrix3d<'a, E: 'a>(&self) -> Option<RMatrix3D<E>>
298    where
299        Self: AsTypedSlice<'a, E>,
300    {
301        <RMatrix3D<E>>::try_from(self.clone()).ok()
302    }
303}
304
305impl<T> Index<[usize; 2]> for RArray<T, [usize; 2]> {
306    type Output = T;
307
308    /// Zero-based indexing in row, column order.
309    ///
310    /// Panics if out of bounds.
311    /// ```
312    /// use extendr_api::prelude::*;
313    /// test! {
314    ///    let matrix = RArray::new_matrix(3, 2, |r, c| [
315    ///        [1., 2., 3.],
316    ///        [4., 5., 6.]][c][r]);
317    ///     assert_eq!(matrix[[0, 0]], 1.);
318    ///     assert_eq!(matrix[[1, 0]], 2.);
319    ///     assert_eq!(matrix[[2, 1]], 6.);
320    /// }
321    /// ```
322    fn index(&self, index: [usize; 2]) -> &Self::Output {
323        unsafe { self.data.add(self.offset(index)).as_ref().unwrap() }
324    }
325}
326
327impl<T> IndexMut<[usize; 2]> for RArray<T, [usize; 2]> {
328    /// Zero-based mutable indexing in row, column order.
329    ///
330    /// Panics if out of bounds.
331    /// ```
332    /// use extendr_api::prelude::*;
333    /// test! {
334    ///     let mut matrix = RMatrix::new_matrix(3, 2, |_, _| 0.);
335    ///     matrix[[0, 0]] = 1.;
336    ///     matrix[[1, 0]] = 2.;
337    ///     matrix[[2, 0]] = 3.;
338    ///     matrix[[0, 1]] = 4.;
339    ///     assert_eq!(matrix.as_real_slice().unwrap(), &[1., 2., 3., 4., 0., 0.]);
340    /// }
341    /// ```
342    fn index_mut(&mut self, index: [usize; 2]) -> &mut Self::Output {
343        unsafe { self.data.add(self.offset(index)).as_mut().unwrap() }
344    }
345}
346
347impl<T, D> Deref for RArray<T, D> {
348    type Target = Robj;
349
350    fn deref(&self) -> &Self::Target {
351        &self.robj
352    }
353}
354
#[test]
fn matrix_ops() {
    test! {
        let vector = RColumn::new_column(3, |r| [1., 2., 3.][r]);
        let robj = r!(vector);
        assert_eq!(robj.is_vector(), true);
        assert_eq!(robj.nrows(), 3);

        let vector2 : RColumn<f64> = robj.as_column().ok_or("expected array")?;
        assert_eq!(vector2.data().len(), 3);
        assert_eq!(vector2.nrows(), 3);
        // Elements are stored in generation order.
        assert_eq!(vector2.data(), &[1., 2., 3.]);

        let matrix = RMatrix::new_matrix(3, 2, |r, c| [
            [1., 2., 3.],
            [4., 5., 6.]][c][r]);
        let robj = r!(matrix);
        assert_eq!(robj.is_matrix(), true);
        assert_eq!(robj.nrows(), 3);
        assert_eq!(robj.ncols(), 2);
        let matrix2 : RMatrix<f64> = robj.as_matrix().ok_or("expected matrix")?;
        assert_eq!(matrix2.data().len(), 6);
        assert_eq!(matrix2.nrows(), 3);
        assert_eq!(matrix2.ncols(), 2);
        // Column-major storage, indexed as [row, column].
        assert_eq!(matrix2.data(), &[1., 2., 3., 4., 5., 6.]);
        assert_eq!(matrix2[[0, 0]], 1.);
        assert_eq!(matrix2[[2, 0]], 3.);
        assert_eq!(matrix2[[0, 1]], 4.);
        assert_eq!(matrix2[[2, 1]], 6.);

        let array = RMatrix3D::new_matrix3d(2, 2, 2, |r, c, m| [
            [[1., 2.],  [3., 4.]],
            [[5.,  6.], [7., 8.]]][m][c][r]);
        let robj = r!(array);
        assert_eq!(robj.is_array(), true);
        assert_eq!(robj.nrows(), 2);
        assert_eq!(robj.ncols(), 2);
        let array2 : RMatrix3D<f64> = robj.as_matrix3d().ok_or("expected matrix3d")?;
        assert_eq!(array2.data().len(), 8);
        assert_eq!(array2.nrows(), 2);
        assert_eq!(array2.ncols(), 2);
        assert_eq!(array2.nsub(), 2);
        // Rows vary fastest, then columns, then submatrices.
        assert_eq!(array2.data(), &[1., 2., 3., 4., 5., 6., 7., 8.]);
    }
}