feanor_math/matrix/
submatrix.rs

1use std::ops::{Deref, Range};
2use std::marker::PhantomData;
3use std::ptr::{addr_of_mut, NonNull};
4
5#[cfg(feature = "ndarray")]
6use ndarray::{ArrayBase, DataMut, Ix2};
7
8use crate::seq::SwappableVectorViewMut;
9use crate::seq::{VectorView, VectorViewMut};
10
///
/// Trait for objects that can be viewed as a contiguous block of memory holding elements
/// of type `T`. More precisely, the pointer returned by `get_pointer()` is to be read as a
/// pointer to the first of several consecutive `T`s (i.e. a C-style array). In this sense,
/// the trait is an unsafe counterpart of `Deref<Target = [T]>`.
/// 
/// # Safety
/// 
/// Since this trait is used to build iterators that do not follow the natural layout of
/// the data, the following restrictions are necessary:
///  - `get_pointer()` may be called several times on the same object, and every pointer
///    obtained in this way must be valid to dereference.
///  - In this situation, several mutable references obtained by dereferencing these
///    pointers may be held at the same time, *as long as they do not alias, i.e. refer to
///    different elements*.
///  - If `Self: Sync`, then also `T: Send`, and the above must stay valid even if the
///    pointers returned by `get_pointer()` are produced and used on different threads.
/// 
pub unsafe trait AsPointerToSlice<T> {

    ///
    /// Returns a pointer to the first of several contiguous `T`s.
    /// 
    /// # Safety
    /// 
    /// `self_` must point to a valid object of type `Self`. This function may legally be
    /// called while mutable references to `T`s exist that were obtained by dereferencing
    /// earlier results of `get_pointer()`. Consequently, in some situations `self_` cannot
    /// be dereferenced without violating the aliasing rules.
    /// 
    /// Callers must however guarantee that no mutable reference exists to any part of
    /// `*self_` that is not pointed to by a result of `get_pointer()`. Immutable references
    /// may exist.
    /// 
    /// For further details, see the trait-level documentation of [`AsPointerToSlice`].
    /// 
    unsafe fn get_pointer(self_: NonNull<Self>) -> NonNull<T>;
}
49
50unsafe impl<T> AsPointerToSlice<T> for Vec<T> {
51
52    unsafe fn get_pointer(self_: NonNull<Self>) -> NonNull<T> {
53        // Safe, because "This method guarantees that for the purpose of the aliasing model, this 
54        // method does not materialize a reference to the underlying slice" (quote from the doc of 
55        // [`Vec::as_mut_ptr()`])
56        unsafe {
57            NonNull::new((*self_.as_ptr()).as_mut_ptr()).unwrap()
58        }
59    }
60}
61
///
/// Newtype for `[T; SIZE]` that implements `Deref<Target = [T]>` so that it can be used
/// to store rows and access them through [`Submatrix`].
/// 
/// This is necessary, since [`Submatrix::from_2d`] requires that `V: Deref<Target = [T]>`.
/// 
/// The type is `#[repr(transparent)]`, so it has exactly the same layout as `[T; SIZE]`;
/// the reference conversions below rely on this.
/// 
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct DerefArray<T, const SIZE: usize> {
    pub data: [T; SIZE]
}

impl<T: std::fmt::Debug, const SIZE: usize> std::fmt::Debug for DerefArray<T, SIZE> {

    /// Formats exactly like the wrapped array.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // fully qualified, so this does not depend on which `fmt` traits are in scope
        std::fmt::Debug::fmt(&self.data, f)
    }
}

impl<T, const SIZE: usize> From<[T; SIZE]> for DerefArray<T, SIZE> {

    /// Wraps an owned array.
    fn from(value: [T; SIZE]) -> Self {
        Self { data: value }
    }
}

impl<'a, T, const SIZE: usize> From<&'a [T; SIZE]> for &'a DerefArray<T, SIZE> {

    /// Reinterprets a reference to an array as a reference to a `DerefArray`, without copying.
    fn from(value: &'a [T; SIZE]) -> Self {
        // SAFETY: `DerefArray<T, SIZE>` is `#[repr(transparent)]` over `[T; SIZE]`, so both types
        // have the same layout and validity invariants; the lifetime `'a` is preserved.
        unsafe { &*(value as *const [T; SIZE]).cast::<DerefArray<T, SIZE>>() }
    }
}

impl<'a, T, const SIZE: usize> From<&'a mut [T; SIZE]> for &'a mut DerefArray<T, SIZE> {

    /// Reinterprets a mutable reference to an array as a mutable reference to a `DerefArray`,
    /// without copying.
    fn from(value: &'a mut [T; SIZE]) -> Self {
        // SAFETY: `DerefArray<T, SIZE>` is `#[repr(transparent)]` over `[T; SIZE]`, so both types
        // have the same layout and validity invariants; the unique borrow for `'a` is passed on.
        unsafe { &mut *(value as *mut [T; SIZE]).cast::<DerefArray<T, SIZE>>() }
    }
}

impl<T, const SIZE: usize> Deref for DerefArray<T, SIZE> {

    type Target = [T];

    /// Views the wrapped array as a slice.
    fn deref(&self) -> &Self::Target {
        &self.data[..]
    }
}
110
111unsafe impl<T, const SIZE: usize> AsPointerToSlice<T> for DerefArray<T, SIZE> {
112
113    unsafe fn get_pointer(self_: NonNull<Self>) -> NonNull<T> {
114        unsafe { 
115            let self_ptr = self_.as_ptr();
116            let data_ptr = addr_of_mut!((*self_ptr).data);
117            NonNull::new((*data_ptr).as_mut_ptr()).unwrap()
118        }
119    }
120}
121
122///
123/// Represents a contiguous batch of `T`s by their first element.
124/// In other words, a pointer to the batch is equal to a pointer to 
125/// the first value.
126/// 
127#[repr(transparent)]
128pub struct AsFirstElement<T>(T);
129
130unsafe impl<'a, T> AsPointerToSlice<T> for AsFirstElement<T> {
131
132    unsafe fn get_pointer(self_: NonNull<Self>) -> NonNull<T> {
133        std::mem::transmute(self_)
134    }
135}
136
137///
138/// A submatrix that works on raw pointers, thus does not care about mutability
139/// and borrowing. It already takes care about bounds checking and indexing.
140/// Nevertheless, it is quite difficult to use this correctly, best not use it at
141/// all. I mainly made it public to allow doctests.
142/// 
143/// More concretely, when having a 2d-structure, given by a sequence of `V`s, we
144/// can consider a rectangular sub-block. This is encapsulated by SubmatrixRaw.
145/// 
146/// # Safety
147/// 
148/// The individual safety contracts are described at the corresponding functions.
149/// However, in total be careful when actuall transforming the entry pointers
150/// (given by [`SubmatrixRaw::entry_at`] or [`SubmatrixRaw::row_at`]) into references.
151/// Since `SubmatrixRaw` does not borrow-check (and is `Copy`!), it is easy to create
152/// aliasing pointers, that must not be converted into references.
153/// 
154/// ## Example of illegal use
155/// ```
156/// # use feanor_math::matrix::*;
157/// # use core::ptr::NonNull;
158/// let mut data = [1, 2, 3];
159/// // this is actuall safe and intended use
160/// let mut matrix = unsafe { SubmatrixRaw::<AsFirstElement<i64>, i64>::new(std::mem::transmute(NonNull::new(data.as_mut_ptr()).unwrap()), 1, 3, 0, 3) };
161/// // this is safe, but note that ptr1 and ptr2 alias...
162/// let mut ptr1: NonNull<i64> = matrix.entry_at(0, 0);
163/// let mut ptr2: NonNull<i64> = matrix.entry_at(0, 0);
164/// // this is UB now!
165/// let (ref1, ref2) = unsafe { (ptr1.as_mut(), ptr2.as_mut()) };
166/// ```
167///
168#[stability::unstable(feature = "enable")]
169pub struct SubmatrixRaw<V, T>
170    where V: AsPointerToSlice<T>
171{
172    entry: PhantomData<*mut T>,
173    rows: NonNull<V>,
174    row_count: usize,
175    row_step: isize,
176    col_start: usize,
177    col_count: usize
178}
179
180///
181/// Requiring `T: Sync` is the more conservative choice. If `SubmatrixRaw`
182/// acts as a mutable reference, we would only require `T: Send`, but we also
183/// want `SubmatrixRaw` to be usable as an immutable reference, thus it can be
184/// shared between threads, which requires `T: Sync`.
185/// 
186unsafe impl<V, T> Send for SubmatrixRaw<V, T> 
187    where V: AsPointerToSlice<T> + Sync, T: Sync
188{}
189
190unsafe impl<V, T> Sync for SubmatrixRaw<V, T> 
191    where V: AsPointerToSlice<T> + Sync, T: Sync
192{}
193
194impl<V, T> Clone for SubmatrixRaw<V, T> 
195    where V: AsPointerToSlice<T>
196{
197    fn clone(&self) -> Self {
198        *self
199    }
200}
201
202impl<V, T> Copy for SubmatrixRaw<V, T> 
203    where V: AsPointerToSlice<T>
204{}
205
206impl<V, T> SubmatrixRaw<V, T> 
207    where V: AsPointerToSlice<T>
208{
209    ///
210    /// Create a new SubmatrixRaw object.
211    /// 
212    /// # Safety
213    /// 
214    /// We require that each pointer `rows.offset(row_step * i)` for `0 <= i < row_count` points to a
215    /// valid object and can be dereferenced. Furthermore, if `ptr` is the pointer returned by `[`AsPointerToSlice::get_pointer()`]`,
216    /// then `ptr.offset(i + cols_start)` must point to a valid `T` for `0 <= i < col_count`.
217    /// 
218    /// Furthermore, we require any two of these (for different i) to represent disjunct "slices", i.e. if they
219    /// give pointers `ptr1` and `ptr2` (via [`AsPointerToSlice::get_pointer()`]), then `ptr1.offset(cols_start + k)` and
220    /// `ptr2.offset(cols_start + l)` for `0 <= k, l < col_count` never alias.
221    /// 
222    #[stability::unstable(feature = "enable")]
223    pub unsafe fn new(rows: NonNull<V>, row_count: usize, row_step: isize, cols_start: usize, col_count: usize) -> Self {
224        Self {
225            entry: PhantomData,
226            row_count: row_count,
227            row_step: row_step,
228            rows: rows,
229            col_start: cols_start,
230            col_count
231        }
232    }
233
234    #[stability::unstable(feature = "enable")]
235    pub fn restrict_rows(mut self, rows: Range<usize>) -> Self {
236        assert!(rows.start <= rows.end);
237        assert!(rows.end <= self.row_count);
238        // this is safe since we require (during the constructor) that all pointers `rows.offset(i * row_step)`
239        // are valid for `0 <= i < row_count`. Technically, this is not completely legal, as in the case 
240        // `rows.start == rows.end == row_count`, the resulting pointer might point outside of the allocated area
241        // - this is legal only when we are exactly one byte after it, but if `row_step` has a weird value, this does
242        // not work. However, `row_step` has suitable values in all safe interfaces.
243        unsafe {
244            self.row_count = rows.end - rows.start;
245            self.rows = self.rows.offset(rows.start as isize * self.row_step);
246        }
247        self
248    }
249
250    #[stability::unstable(feature = "enable")]
251    pub fn restrict_cols(mut self, cols: Range<usize>) -> Self {
252        assert!(cols.end <= self.col_count);
253        self.col_count = cols.end - cols.start;
254        self.col_start += cols.start;
255        self
256    }
257
258    ///
259    /// Returns a pointer to the `row`-th row of the matrix.
260    /// Be carefull about aliasing when making this into a reference!
261    /// 
262    #[stability::unstable(feature = "enable")]
263    pub fn row_at(&self, row: usize) -> NonNull<[T]> {
264        assert!(row < self.row_count);
265        // this is safe since `row < row_count` and we require `rows.offset(row * row_step)` to point
266        // to a valid element of `V`
267        let row_ref = unsafe {
268            V::get_pointer(self.rows.offset(row as isize * self.row_step))
269        };
270        // similarly safe by constructor requirements
271        unsafe {
272            NonNull::slice_from_raw_parts(row_ref.offset(self.col_start as isize), self.col_count)
273        }
274    }
275
276    ///
277    /// Returns a pointer to the `(row, col)`-th entry of the matrix.
278    /// Be carefull about aliasing when making this into a reference!
279    /// 
280    #[stability::unstable(feature = "enable")]
281    pub fn entry_at(&self, row: usize, col: usize) -> NonNull<T> {
282        assert!(row < self.row_count, "Row index {} out of range 0..{}", row, self.row_count);
283        assert!(col < self.col_count, "Col index {} out of range 0..{}", col, self.col_count);
284        // this is safe since `row < row_count` and we require `rows.offset(row * row_step)` to point
285        // to a valid element of `V`
286        let row_ref = unsafe {
287            V::get_pointer(self.rows.offset(row as isize * self.row_step))
288        };
289        // similarly safe by constructor requirements
290        unsafe {
291            row_ref.offset(self.col_start as isize + col as isize)
292        }
293    }
294}
295
296///
297/// An immutable view on a column of a matrix [`Submatrix`]. 
298/// 
299pub struct Column<'a, V, T>
300    where V: AsPointerToSlice<T>
301{
302    entry: PhantomData<&'a T>,
303    raw_data: SubmatrixRaw<V, T>
304}
305
306impl<'a, V, T> Column<'a, V, T>
307    where V: AsPointerToSlice<T>
308{
309    ///
310    /// Creates a new column object representing the given submatrix. Thus,
311    /// the submatrix must only have one column.
312    /// 
313    /// # Safety
314    /// 
315    /// Since `Column` represents immutable borrowing, callers of this method
316    /// must ensure that for the lifetime `'a`, there are no mutable references
317    /// to any object pointed to by `raw_data` (this includes both the "row descriptors"
318    /// `V` and the actual elements `T`).
319    /// 
320    unsafe fn new(raw_data: SubmatrixRaw<V, T>) -> Self {
321        assert!(raw_data.col_count == 1);
322        Self {
323            entry: PhantomData,
324            raw_data: raw_data
325        }
326    }
327}
328
329impl<'a, V, T> Clone for Column<'a, V, T>
330    where V: AsPointerToSlice<T>
331{
332    fn clone(&self) -> Self {
333        *self
334    }
335}
336
337impl<'a, V, T> Copy for Column<'a, V, T>
338    where V: AsPointerToSlice<T>
339{}
340
341impl<'a, V, T> VectorView<T> for Column<'a, V, T>
342    where V: AsPointerToSlice<T>
343{
344    fn len(&self) -> usize {
345        self.raw_data.row_count
346    }
347
348    fn at(&self, i: usize) -> &T {
349        // safe since we assume that there are no mutable references to `raw_data` 
350        unsafe {
351            self.raw_data.entry_at(i, 0).as_ref()
352        }
353    }
354}
355
356///
357///
358/// A mutable view on a column of a matrix [`SubmatrixMut`]. 
359/// 
360/// Clearly must not be Copy/Clone.
361/// 
362pub struct ColumnMut<'a, V, T>
363    where V: AsPointerToSlice<T>
364{
365    entry: PhantomData<&'a mut T>,
366    raw_data: SubmatrixRaw<V, T>
367}
368
369impl<'a, V, T> ColumnMut<'a, V, T>
370    where V: AsPointerToSlice<T>
371{
372    ///
373    /// Creates a new column object representing the given submatrix. Thus,
374    /// the submatrix must only have one column.
375    /// 
376    /// # Safety
377    /// 
378    /// Since `ColumnMut` represents mutable borrowing, callers of this method
379    /// must ensure that for the lifetime `'a`, there are no other references
380    /// to any matrix entry pointed to by `raw_data` (meaning the "content" elements
381    /// `T`). It is allowed to have immutable references to the "row descriptors" `V`
382    /// (assuming they are strictly different from the content `T`).
383    /// 
384    unsafe fn new(raw_data: SubmatrixRaw<V, T>) -> Self {
385        assert!(raw_data.col_count == 1);
386        Self {
387            entry: PhantomData,
388            raw_data: raw_data
389        }
390    }
391    
392    pub fn reborrow<'b>(&'b mut self) -> ColumnMut<'b, V, T> {
393        ColumnMut {
394            entry: PhantomData,
395            raw_data: self.raw_data
396        }
397    }
398
399    pub fn two_entries<'b>(&'b mut self, i: usize, j: usize) -> (&'b mut T, &'b mut T) {
400        assert!(i != j);
401        // safe since i != j
402        unsafe {
403            (self.raw_data.entry_at(i, 0).as_mut(), self.raw_data.entry_at(j, 0).as_mut())
404        }
405    }
406    
407    pub fn as_const<'b>(&'b self) -> Column<'b, V, T> {
408        Column {
409            entry: PhantomData,
410            raw_data: self.raw_data
411        }
412    }
413}
414
415unsafe impl<'a, V, T> Send for ColumnMut<'a, V, T>
416    where V: AsPointerToSlice<T> + Sync, T: Send
417{}
418
419impl<'a, V, T> VectorView<T> for ColumnMut<'a, V, T>
420    where V: AsPointerToSlice<T>
421{
422    fn len(&self) -> usize {
423        self.raw_data.row_count
424    }
425
426    fn at(&self, i: usize) -> &T {
427        // safe since we assume that there are no other references to `raw_data` 
428        unsafe {
429            self.raw_data.entry_at(i, 0).as_ref()
430        }
431    }
432}
433
434///
435/// Iterator over mutable references to the entries of a column
436/// of a matrix [`SubmatrixMut`].
437/// 
438pub struct ColumnMutIter<'a, V, T> 
439    where V: AsPointerToSlice<T>
440{
441    column_mut: ColumnMut<'a, V, T>
442}
443
444impl<'a, V, T> Iterator for ColumnMutIter<'a, V, T>
445    where V: AsPointerToSlice<T>
446{
447    type Item = &'a mut T;
448
449    fn next(&mut self) -> Option<Self::Item> {
450        if self.column_mut.raw_data.row_count > 0 {
451            let mut result = self.column_mut.raw_data.entry_at(0, 0);
452            self.column_mut.raw_data = self.column_mut.raw_data.restrict_rows(1..self.column_mut.raw_data.row_count);
453            // safe since for the result lifetime, one cannot legally access this value using only the new value of `self.column_mut.raw_data`
454            unsafe {
455                Some(result.as_mut())
456            }
457        } else {
458            None
459        }
460    }
461}
462
463impl<'a, V, T> IntoIterator for ColumnMut<'a, V, T>
464    where V: AsPointerToSlice<T>
465{
466    type Item = &'a mut T;
467    type IntoIter = ColumnMutIter<'a, V, T>;
468
469    fn into_iter(self) -> ColumnMutIter<'a, V, T> {
470        ColumnMutIter { column_mut: self }
471    }
472}
473
474impl<'a, V, T> VectorViewMut<T> for ColumnMut<'a, V, T>
475    where V: AsPointerToSlice<T>
476{
477    fn at_mut<'b>(&'b mut self, i: usize) -> &'b mut T {
478        // safe since self is borrow mutably
479        unsafe {
480            self.raw_data.entry_at(i, 0).as_mut()
481        }
482    }
483}
484
485impl<'a, V, T> SwappableVectorViewMut<T> for ColumnMut<'a, V, T>
486    where V: AsPointerToSlice<T>
487{
488    fn swap(&mut self, i: usize, j: usize) {
489        if i != j {
490            // safe since i != j, so these pointers do not alias; I think it is also safe for
491            // zero sized type, even though it is slightly weird since the pointer might point
492            // to the same location even if i != j
493            unsafe {
494                std::mem::swap(self.raw_data.entry_at(i, 0).as_mut(), self.raw_data.entry_at(j, 0).as_mut());
495            }
496        }
497    }
498}
499
500///
501/// Immutable view on a matrix that stores elements of type `T`.
502/// 
503/// This view is designed to work with various underlying representations
504/// of the matrix, as described by [`AsPointerToSlice`].
505/// 
506pub struct Submatrix<'a, V, T>
507    where V: 'a + AsPointerToSlice<T>
508{
509    entry: PhantomData<&'a T>,
510    raw_data: SubmatrixRaw<V, T>
511}
512
513impl<'a, V, T> Submatrix<'a, V, T>
514    where V: 'a + AsPointerToSlice<T>
515{
516    pub fn submatrix(self, rows: Range<usize>, cols: Range<usize>) -> Self {
517        self.restrict_rows(rows).restrict_cols(cols)
518    }
519
520    pub fn restrict_rows(self, rows: Range<usize>) -> Self {
521        Self {
522            entry: PhantomData,
523            raw_data: self.raw_data.restrict_rows(rows)
524        }
525    }
526
527    pub fn into_at(self, i: usize, j: usize) -> &'a T {
528        &self.into_row_at(i)[j]
529    }
530    
531    pub fn at<'b>(&'b self, i: usize, j: usize) -> &'b T {
532        &self.row_at(i)[j]
533    }
534
535    pub fn restrict_cols(self, cols: Range<usize>) -> Self {
536        Self {
537            entry: PhantomData,
538            raw_data: self.raw_data.restrict_cols(cols)
539        }
540    }
541
542    pub fn row_iter(self) -> impl 'a + Clone + ExactSizeIterator<Item = &'a [T]> {
543        (0..self.raw_data.row_count).map(move |i| 
544        // safe since there are no immutable references to self.raw_data    
545        unsafe {
546            self.raw_data.row_at(i).as_ref()
547        })
548    }
549
550    pub fn col_iter(self) -> impl 'a + Clone + ExactSizeIterator<Item = Column<'a, V, T>> {
551        (0..self.raw_data.col_count).map(move |j| {
552            debug_assert!(j < self.raw_data.col_count);
553            let mut result_raw = self.raw_data;
554            result_raw.col_start += j;
555            result_raw.col_count = 1;
556            // safe since there are no immutable references to self.raw_data
557            unsafe {
558                return Column::new(result_raw);
559            }
560        })
561    }
562
563    pub fn into_row_at(self, i: usize) -> &'a [T] {
564        // safe since there are no mutable references to self.raw_data
565        unsafe {
566            self.raw_data.row_at(i).as_ref()
567        }
568    }
569
570    pub fn row_at<'b>(&'b self, i: usize) -> &'b [T] {
571        // safe since there are no immutable references to self.raw_data
572        unsafe {
573            self.raw_data.row_at(i).as_ref()
574        }
575    }
576
577    pub fn into_col_at(self, j: usize) -> Column<'a, V, T> {
578        assert!(j < self.raw_data.col_count);
579        let mut result_raw = self.raw_data;
580        result_raw.col_start += j;
581        result_raw.col_count = 1;
582        // safe since there are no immutable references to self.raw_data
583        unsafe {
584            return Column::new(result_raw);
585        }
586    }
587
588    pub fn col_at<'b>(&'b self, j: usize) -> Column<'b, V, T> {
589        assert!(j < self.raw_data.col_count);
590        let mut result_raw = self.raw_data;
591        result_raw.col_start += j;
592        result_raw.col_count = 1;
593        // safe since there are no immutable references to self.raw_data
594        unsafe {
595            return Column::new(result_raw);
596        }
597    }
598
599    pub fn col_count(&self) -> usize {
600        self.raw_data.col_count
601    }
602
603    pub fn row_count(&self) -> usize {
604        self.raw_data.row_count
605    }
606}
607
608impl<'a, V, T> Clone for Submatrix<'a, V, T>
609    where V: 'a + AsPointerToSlice<T>
610{
611    fn clone(&self) -> Self {
612        *self
613    }
614}
615
616impl<'a, V, T> Copy for Submatrix<'a, V, T>
617    where V: 'a + AsPointerToSlice<T>
618{}
619
620///
621/// Mutable view on a matrix that stores elements of type `T`.
622/// 
623/// This view is designed to work with various underlying representations
624/// of the matrix, as described by [`AsPointerToSlice`].
625/// 
626pub struct SubmatrixMut<'a, V, T>
627    where V: 'a + AsPointerToSlice<T>
628{
629    entry: PhantomData<&'a mut T>,
630    raw_data: SubmatrixRaw<V, T>
631}
632
633impl<'a, V, T> SubmatrixMut<'a, V, T>
634    where V: 'a + AsPointerToSlice<T>
635{
636    pub fn submatrix(self, rows: Range<usize>, cols: Range<usize>) -> Self {
637        self.restrict_rows(rows).restrict_cols(cols)
638    }
639
640    pub fn restrict_rows(self, rows: Range<usize>) -> Self {
641        Self {
642            entry: PhantomData,
643            raw_data: self.raw_data.restrict_rows(rows)
644        }
645    }
646
647    pub fn restrict_cols(self, cols: Range<usize>) -> Self {
648        Self {
649            entry: PhantomData,
650            raw_data: self.raw_data.restrict_cols(cols)
651        }
652    }
653
654    pub fn split_rows(self, fst_rows: Range<usize>, snd_rows: Range<usize>) -> (Self, Self) {
655        assert!(fst_rows.end <= snd_rows.start || snd_rows.end <= fst_rows.start);
656        (
657            Self {
658                entry: PhantomData,
659                raw_data: self.raw_data.restrict_rows(fst_rows)
660            },
661            Self {
662                entry: PhantomData,
663                raw_data: self.raw_data.restrict_rows(snd_rows)
664            },
665        )
666    }
667
668    pub fn split_cols(self, fst_cols: Range<usize>, snd_cols: Range<usize>) -> (Self, Self) {
669        assert!(fst_cols.end <= snd_cols.start || snd_cols.end <= fst_cols.start);
670        (
671            Self {
672                entry: PhantomData,
673                raw_data: self.raw_data.restrict_cols(fst_cols)
674            },
675            Self {
676                entry: PhantomData,
677                raw_data: self.raw_data.restrict_cols(snd_cols)
678            },
679        )
680    }
681
682    pub fn row_iter(self) -> impl 'a + ExactSizeIterator<Item = &'a mut [T]> {
683        (0..self.raw_data.row_count).map(move |i| 
684        // safe since each access goes to a different location
685        unsafe {
686            self.raw_data.row_at(i).as_mut()
687        })
688    }
689
690    pub fn col_iter(self) -> impl 'a + ExactSizeIterator<Item = ColumnMut<'a, V, T>> {
691        (0..self.raw_data.col_count).map(move |j| {
692            let mut result_raw = self.raw_data;
693            result_raw.col_start += j;
694            result_raw.col_count = 1;
695            // safe since each time, the `result_raw` don't overlap
696            unsafe {
697                return ColumnMut::new(result_raw);
698            }
699        })
700    }
701
702    pub fn into_at_mut(self, i: usize, j: usize) -> &'a mut T {
703        &mut self.into_row_mut_at(i)[j]
704    }
705
706    pub fn at_mut<'b>(&'b mut self, i: usize, j: usize) -> &'b mut T {
707        &mut self.row_mut_at(i)[j]
708    }
709
710    pub fn at<'b>(&'b self, i: usize, j: usize) -> &'b T {
711        self.as_const().into_at(i, j)
712    }
713
714    pub fn row_at<'b>(&'b self, i: usize) -> &'b [T] {
715        self.as_const().into_row_at(i)
716    }
717
718    pub fn into_row_mut_at(self, i: usize) -> &'a mut [T] {
719        // safe since self is exists borrowed for 'a
720        unsafe {
721            self.raw_data.row_at(i).as_mut()
722        }
723    }
724
725    pub fn row_mut_at<'b>(&'b mut self, i: usize) -> &'b mut [T] {
726        self.reborrow().into_row_mut_at(i)
727    }
728
729    pub fn col_at<'b>(&'b self, j: usize) -> Column<'b, V, T> {
730        self.as_const().into_col_at(j)
731    }
732
733    pub fn col_mut_at<'b>(&'b mut self, j: usize) -> ColumnMut<'b, V, T> {
734        assert!(j < self.raw_data.col_count);
735        let mut result_raw = self.raw_data;
736        result_raw.col_start += j;
737        result_raw.col_count = 1;
738        // safe since self is mutably borrowed for 'b
739        unsafe {
740            return ColumnMut::new(result_raw);
741        }
742    }
743
744    pub fn reborrow<'b>(&'b mut self) -> SubmatrixMut<'b, V, T> {
745        SubmatrixMut {
746            entry: PhantomData,
747            raw_data: self.raw_data
748        }
749    }
750
751    pub fn as_const<'b>(&'b self) -> Submatrix<'b, V, T> {
752        Submatrix {
753            entry: PhantomData,
754            raw_data: self.raw_data
755        }
756    }
757
758    pub fn col_count(&self) -> usize {
759        self.raw_data.col_count
760    }
761
762    pub fn row_count(&self) -> usize {
763        self.raw_data.row_count
764    }
765}
766
767impl<'a, T> SubmatrixMut<'a, AsFirstElement<T>, T> {
768
769    ///
770    /// Creates a view on the given data slice, interpreting it as a matrix of given shape.
771    /// Assumes row-major order, i.e. contigous subslices of `data` will be the rows of the
772    /// resulting matrix.
773    /// 
774    pub fn from_1d(data: &'a mut [T], row_count: usize, col_count: usize) -> Self {
775        assert_eq!(row_count * col_count, data.len());
776        unsafe {
777            Self {
778                entry: PhantomData,
779                raw_data: SubmatrixRaw::new(std::mem::transmute(NonNull::new(data.as_mut_ptr()).unwrap_unchecked()), row_count, col_count as isize, 0, col_count)
780            }
781        }
782    }
783
784    #[doc(cfg(feature = "ndarray"))]
785    #[cfg(feature = "ndarray")]
786    pub fn from_ndarray<S>(data: &'a mut ArrayBase<S, Ix2>) -> Self
787        where S: DataMut<Elem = T>
788    {
789        assert!(data.is_standard_layout());
790        let (nrows, ncols) = (data.nrows(), data.ncols());
791        return Self::new(data.as_slice_mut().unwrap(), nrows, ncols);
792    }
793}
794
795impl<'a, V: AsPointerToSlice<T> + Deref<Target = [T]>, T> SubmatrixMut<'a, V, T> {
796
797    ///
798    /// Interprets the given slice of slices as a matrix, by using the elements
799    /// of the outer slice as the rows of the matrix.
800    /// 
801    pub fn from_2d(data: &'a mut [V]) -> Self {
802        assert!(data.len() > 0);
803        let row_count = data.len();
804        let col_count = data[0].len();
805        for row in data.iter() {
806            assert_eq!(col_count, row.len());
807        }
808        unsafe {
809            Self {
810                entry: PhantomData,
811                raw_data: SubmatrixRaw::new(NonNull::new(data.as_mut_ptr() as *mut _).unwrap_unchecked(), row_count, 1, 0, col_count)
812            }
813        }
814    }
815}
816
817impl<'a, T> Submatrix<'a, AsFirstElement<T>, T> {
818
819    ///
820    /// Creates a view on the given data slice, interpreting it as a matrix of given shape.
821    /// Assumes row-major order, i.e. contigous subslices of `data` will be the rows of the
822    /// resulting matrix.
823    /// 
824    pub fn from_1d(data: &'a [T], row_count: usize, col_count: usize) -> Self {
825        assert_eq!(row_count * col_count, data.len());
826        unsafe {
827            Self {
828                entry: PhantomData,
829                raw_data: SubmatrixRaw::new(std::mem::transmute(NonNull::new(data.as_ptr() as *mut T).unwrap_unchecked()), row_count, col_count as isize, 0, col_count)
830            }
831        }
832    }
833
834    #[doc(cfg(feature = "ndarray"))]
835    #[cfg(feature = "ndarray")]
836    pub fn from_ndarray<S>(data: &'a ArrayBase<S, Ix2>) -> Self
837        where S: DataMut<Elem = T>
838    {
839        let (nrows, ncols) = (data.nrows(), data.ncols());
840        return Self::new(data.as_slice().unwrap(), nrows, ncols);
841    }
842}
843
844impl<'a, V: AsPointerToSlice<T> + Deref<Target = [T]>, T> Submatrix<'a, V, T> {
845
846    ///
847    /// Interprets the given slice of slices as a matrix, by using the elements
848    /// of the outer slice as the rows of the matrix.
849    /// 
850    pub fn from_2d(data: &'a [V]) -> Self {
851        assert!(data.len() > 0);
852        let row_count = data.len();
853        let col_count = data[0].len();
854        for row in data.iter() {
855            assert_eq!(col_count, row.len());
856        }
857        unsafe {
858            Self {
859                entry: PhantomData,
860                raw_data: SubmatrixRaw::new(NonNull::new(data.as_ptr() as *mut _).unwrap_unchecked(), row_count, 1, 0, col_count)
861            }
862        }
863    }
864}
865
866#[cfg(test)]
867use std::fmt::Debug;
868
#[cfg(test)]
fn assert_submatrix_eq<V: AsPointerToSlice<T>, T: PartialEq + Debug, const N: usize, const M: usize>(expected: [[T; M]; N], actual: &mut SubmatrixMut<V, T>) {
    // shape first, then every entry via both the mutable view and its immutable counterpart
    assert_eq!(N, actual.row_count());
    assert_eq!(M, actual.col_count());
    for (i, expected_row) in expected.iter().enumerate() {
        for (j, expected_entry) in expected_row.iter().enumerate() {
            assert_eq!(expected_entry, actual.at(i, j));
            assert_eq!(expected_entry, actual.as_const().at(i, j));
        }
    }
}
880
/// Passes the 3x5 test matrix `1..=15`, stored as a slice of `Vec`s, to `f`.
#[cfg(test)]
fn with_testmatrix_vec<F>(f: F)
    where F: FnOnce(SubmatrixMut<Vec<i64>, i64>)
{
    let mut rows = [
        vec![1, 2, 3, 4, 5],
        vec![6, 7, 8, 9, 10],
        vec![11, 12, 13, 14, 15]
    ];
    f(SubmatrixMut::<Vec<_>, _>::from_2d(&mut rows[..]))
}

/// Passes the 3x5 test matrix `1..=15`, stored as a slice of `DerefArray`s, to `f`.
#[cfg(test)]
fn with_testmatrix_array<F>(f: F)
    where F: FnOnce(SubmatrixMut<DerefArray<i64, 5>, i64>)
{
    let mut rows = [
        DerefArray::from([1, 2, 3, 4, 5]),
        DerefArray::from([6, 7, 8, 9, 10]),
        DerefArray::from([11, 12, 13, 14, 15])
    ];
    f(SubmatrixMut::<DerefArray<_, 5>, _>::from_2d(&mut rows[..]))
}

/// Passes the 3x5 test matrix `1..=15`, stored in row-major order in one slice, to `f`.
#[cfg(test)]
fn with_testmatrix_linmem<F>(f: F)
    where F: FnOnce(SubmatrixMut<AsFirstElement<i64>, i64>)
{
    let mut entries = [
        1, 2, 3, 4, 5,
        6, 7, 8, 9, 10,
        11, 12, 13, 14, 15
    ];
    f(SubmatrixMut::<AsFirstElement<_>, _>::from_1d(&mut entries[..], 3, 5))
}
919
///
/// Calls `f` with a mutable view of the 3x5 test matrix whose entries are `1..=15`
/// in row-major order. The storage is a 2D `ndarray` array, converted through `from_ndarray`.
///
#[cfg(feature = "ndarray")]
#[cfg(test)]
fn with_testmatrix_ndarray<F>(f: F)
    where F: FnOnce(SubmatrixMut<AsFirstElement<i64>, i64>)
{
    use ndarray::array;

    let mut grid = array![
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
        [11, 12, 13, 14, 15]
    ];
    f(SubmatrixMut::<AsFirstElement<_>, _>::from_ndarray(&mut grid));
}
935
///
/// Stand-in used when the `ndarray` feature is disabled. It keeps the same signature
/// as the feature-enabled version but never invokes the given test closure.
///
#[cfg(not(feature = "ndarray"))]
#[cfg(test)]
fn with_testmatrix_ndarray<F: FnOnce(SubmatrixMut<AsFirstElement<i64>, i64>)>(_skipped_test: F) {}
943
///
/// Read-only checks on the 3x5 test matrix (entries `1..=15`, row-major) produced by the
/// `with_testmatrix_*` helpers. It covers whole-matrix access, several `submatrix()`
/// windows and the two halves returned by `split_cols()`.
///
#[cfg(test)]
fn test_submatrix<V: AsPointerToSlice<i64>>(mut matrix: SubmatrixMut<V, i64>) {
    assert_submatrix_eq([
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
        [11, 12, 13, 14, 15]
    ], &mut matrix);
    // each window is taken from `matrix.reborrow()`; `matrix` is used again after each one
    assert_submatrix_eq([[2, 3], [7, 8]], &mut matrix.reborrow().submatrix(0..2, 1..3));
    assert_submatrix_eq([[8, 9, 10]], &mut matrix.reborrow().submatrix(1..2, 2..5));
    assert_submatrix_eq([[8, 9, 10], [13, 14, 15]], &mut matrix.reborrow().submatrix(1..3, 2..5));

    // adjacent column ranges 0..3 and 3..5 together cover the whole matrix
    let (mut left, mut right) = matrix.split_cols(0..3, 3..5);
    assert_submatrix_eq([[1, 2, 3], [6, 7, 8], [11, 12, 13]], &mut left);
    assert_submatrix_eq([[4, 5], [9, 10], [14, 15]], &mut right);
}
959
///
/// Runs `test_submatrix` against every test matrix backend. The ndarray runner does
/// nothing unless the `ndarray` feature is enabled.
///
#[test]
fn test_submatrix_wrapper() {
    with_testmatrix_vec(test_submatrix);
    with_testmatrix_array(test_submatrix);
    with_testmatrix_linmem(test_submatrix);
    with_testmatrix_ndarray(test_submatrix);
}
967
///
/// Writes through disjoint views of the 3x5 test matrix and checks that each view sees
/// exactly its own writes. The views come first from `split_cols()`; the left part is
/// then split again with `split_rows()`.
///
#[cfg(test)]
fn test_submatrix_mutate<V: AsPointerToSlice<i64>>(mut matrix: SubmatrixMut<V, i64>) {
    assert_submatrix_eq([
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
        [11, 12, 13, 14, 15]
    ], &mut matrix);
    let (mut left, mut right) = matrix.split_cols(0..3, 3..5);
    assert_submatrix_eq([[1, 2, 3], [6, 7, 8], [11, 12, 13]], &mut left);
    assert_submatrix_eq([[4, 5], [9, 10], [14, 15]], &mut right);
    // `left` and `right` are both alive here, so writes go through two mutable views at once
    *left.at_mut(1, 1) += 1;
    *right.at_mut(0, 0) += 1;
    *right.at_mut(2, 1) += 1;
    assert_submatrix_eq([[1, 2, 3], [6, 8, 8], [11, 12, 13]], &mut left);
    assert_submatrix_eq([[5, 5], [9, 10], [14, 16]], &mut right);

    // split the (already modified) left part once more, this time by rows
    let (mut top, mut bottom) = left.split_rows(0..1, 1..3);
    assert_submatrix_eq([[1, 2, 3]], &mut top);
    assert_submatrix_eq([[6, 8, 8], [11, 12, 13]], &mut bottom);
    *top.at_mut(0, 0) -= 1;
    *top.at_mut(0, 2) += 3;
    *bottom.at_mut(0, 2) -= 1;
    *bottom.at_mut(1, 0) += 3;
    assert_submatrix_eq([[0, 2, 6]], &mut top);
    assert_submatrix_eq([[6, 8, 7], [14, 12, 13]], &mut bottom);
}
994
///
/// Runs `test_submatrix_mutate` against every test matrix backend. The ndarray runner
/// does nothing unless the `ndarray` feature is enabled.
///
#[test]
fn test_submatrix_mutate_wrapper() {
    with_testmatrix_vec(test_submatrix_mutate);
    with_testmatrix_array(test_submatrix_mutate);
    with_testmatrix_linmem(test_submatrix_mutate);
    with_testmatrix_ndarray(test_submatrix_mutate);
}
1002
///
/// Checks `col_iter()` on the full 3x5 test matrix and on both parts of a column split.
/// It reads columns via `as_iter()`, `reborrow()` and `into_iter()`, and writes through
/// the yielded column views.
///
#[cfg(test)]
fn test_submatrix_col_iter<V: AsPointerToSlice<i64>>(mut matrix: SubmatrixMut<V, i64>) {
    assert_submatrix_eq([
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
        [11, 12, 13, 14, 15]
    ], &mut matrix);
    {
        // `by_ref().skip(1).next()` skips one column and yields the next, advancing `it` itself
        let mut it = matrix.reborrow().col_iter();
        assert_eq!(vec![2, 7, 12], it.by_ref().skip(1).next().unwrap().into_iter().map(|x| *x).collect::<Vec<_>>());
        assert_eq!(vec![4, 9, 14], it.by_ref().skip(1).next().unwrap().into_iter().map(|x| *x).collect::<Vec<_>>());
        // the remaining (last) column: double it through a reborrow, then read it by consuming it
        let mut last_col = it.next().unwrap();
        for x in last_col.reborrow() {
            *x *= 2;
        }
        assert_eq!(vec![10, 20, 30], last_col.into_iter().map(|x| *x).collect::<Vec<_>>());
    }
    assert_submatrix_eq([
        [1, 2, 3, 4, 10],
        [6, 7, 8, 9, 20],
        [11, 12, 13, 14, 30]], 
        &mut matrix
    );
    
    // column ranges 0..2 and 3..4; column 2 belongs to neither part
    let (left, _right) = matrix.reborrow().split_cols(0..2, 3..4);
    {
        let mut it = left.col_iter();
        let mut col1 = it.next().unwrap();
        let mut col2 = it.next().unwrap();
        assert!(it.next().is_none());
        assert_eq!(vec![1, 6, 11], col1.as_iter().map(|x| *x).collect::<Vec<_>>());
        assert_eq!(vec![2, 7, 12], col2.as_iter().map(|x| *x).collect::<Vec<_>>());
        assert_eq!(vec![1, 6, 11], col1.reborrow().into_iter().map(|x| *x).collect::<Vec<_>>());
        assert_eq!(vec![2, 7, 12], col2.reborrow().into_iter().map(|x| *x).collect::<Vec<_>>());
        // adds 5 to the second entry of column 0 (6 -> 11)
        *col1.into_iter().skip(1).next().unwrap() += 5;
    }
    assert_submatrix_eq([
        [1, 2, 3, 4, 10],
        [11, 7, 8, 9, 20],
        [11, 12, 13, 14, 30]], 
        &mut matrix
    );

    // the right part is the single column 3
    let (_left, right) = matrix.reborrow().split_cols(0..2, 3..4);
    {
        let mut it = right.col_iter();
        let mut col = it.next().unwrap();
        assert!(it.next().is_none());
        assert_eq!(vec![4, 9, 14], col.reborrow().as_iter().map(|x| *x).collect::<Vec<_>>());
        // adds 3 to the first entry of column 3 (4 -> 7)
        *col.into_iter().next().unwrap() += 3;
    }
    assert_submatrix_eq([
        [1, 2, 3, 7, 10],
        [11, 7, 8, 9, 20],
        [11, 12, 13, 14, 30]], 
        &mut matrix
    );
}
1061
///
/// Runs `test_submatrix_col_iter` against every test matrix backend. The ndarray runner
/// does nothing unless the `ndarray` feature is enabled.
///
#[test]
fn test_submatrix_col_iter_wrapper() {
    with_testmatrix_vec(test_submatrix_col_iter);
    with_testmatrix_array(test_submatrix_col_iter);
    with_testmatrix_linmem(test_submatrix_col_iter);
    with_testmatrix_ndarray(test_submatrix_col_iter);
}
1069
///
/// Checks `row_iter()` on the full 3x5 test matrix, on the left part of a column split
/// and on a `submatrix()` of the right part. It writes through the yielded row slices
/// and verifies the writes afterwards.
///
#[cfg(test)]
fn test_submatrix_row_iter<V: AsPointerToSlice<i64>>(mut matrix: SubmatrixMut<V, i64>) {
    assert_submatrix_eq([
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
        [11, 12, 13, 14, 15]
    ], &mut matrix);
    {
        // skip row 0, take row 1; the following `next()` then yields the last row
        let mut it = matrix.reborrow().row_iter();
        assert_eq!(&[6, 7, 8, 9, 10], it.by_ref().skip(1).next().unwrap());
        let row = it.next().unwrap();
        assert!(it.next().is_none());
        row[1] += 6;
        row[4] *= 2;
    }
    assert_submatrix_eq([
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
        [11, 18, 13, 14, 30]], 
        &mut matrix
    );
    // left = columns 0..2, right = column 3
    let (mut left, mut right) = matrix.reborrow().split_cols(0..2, 3..4);
    {
        let mut it = left.reborrow().row_iter();
        let row1 = it.next().unwrap();
        let row2 = it.next().unwrap();
        // a third row exists, and nothing after it
        assert!(it.next().is_some());
        assert!(it.next().is_none());
        assert_eq!(&[1, 2], row1);
        assert_eq!(&[6, 7], row2);
    }
    {
        // same iteration over a fresh reborrow, this time writing through a row slice
        let mut it = left.reborrow().row_iter();
        let row1 = it.next().unwrap();
        let row2 = it.next().unwrap();
        assert!(it.next().is_some());
        assert!(it.next().is_none());
        assert_eq!(&[1, 2], row1);
        assert_eq!(&[6, 7], row2);
        row2[1] += 1;
    }
    assert_submatrix_eq([[1, 2], [6, 8], [11, 18]], &mut left);
    {
        // rows 1..3 of the single-column right part, i.e. entries 9 and 14
        right = right.submatrix(1..3, 0..1);
        let mut it = right.reborrow().row_iter();
        let row1 = it.next().unwrap();
        let row2 = it.next().unwrap();
        assert_eq!(&[9], row1);
        assert_eq!(&[14], row2);
        row1[0] += 1;
    }
    assert_submatrix_eq([[10], [14]], &mut right);
}
1123
///
/// Runs `test_submatrix_row_iter` against every test matrix backend. The ndarray runner
/// does nothing unless the `ndarray` feature is enabled.
///
#[test]
fn test_submatrix_row_iter_wrapper() {
    with_testmatrix_vec(test_submatrix_row_iter);
    with_testmatrix_array(test_submatrix_row_iter);
    with_testmatrix_linmem(test_submatrix_row_iter);
    with_testmatrix_ndarray(test_submatrix_row_iter);
}
1131
///
/// Checks `col_at` / `col_mut_at` in three places: on the full 3x5 test matrix, on its
/// const view, and on the two single-row parts obtained from `restrict_rows(0..2)`
/// followed by `split_rows`.
///
#[cfg(test)]
fn test_submatrix_col_at<V: AsPointerToSlice<i64>>(mut matrix: SubmatrixMut<V, i64>) {
    assert_submatrix_eq([
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
        [11, 12, 13, 14, 15]
    ], &mut matrix);
    assert_eq!(&[2, 7, 12], &matrix.col_at(1).as_iter().copied().collect::<Vec<_>>()[..]);
    assert_eq!(&[2, 7, 12], &matrix.as_const().col_at(1).as_iter().copied().collect::<Vec<_>>()[..]);
    assert_eq!(&[5, 10, 15], &matrix.col_at(4).as_iter().copied().collect::<Vec<_>>()[..]);
    assert_eq!(&[5, 10, 15], &matrix.as_const().col_at(4).as_iter().copied().collect::<Vec<_>>()[..]);

    {
        // top = row 0, bottom = row 1; row 2 is excluded by `restrict_rows(0..2)`
        let (mut top, mut bottom) = matrix.reborrow().restrict_rows(0..2).split_rows(0..1, 1..2);
        assert_eq!(&[1], &top.col_mut_at(0).as_iter().copied().collect::<Vec<_>>()[..]);
        assert_eq!(&[1], &top.as_const().col_at(0).as_iter().copied().collect::<Vec<_>>()[..]);
        assert_eq!(&[1], &top.col_at(0).as_iter().copied().collect::<Vec<_>>()[..]);
        assert_eq!(&[5], &top.col_mut_at(4).as_iter().copied().collect::<Vec<_>>()[..]);
        assert_eq!(&[5], &top.as_const().col_at(4).as_iter().copied().collect::<Vec<_>>()[..]);
        assert_eq!(&[5], &top.col_at(4).as_iter().copied().collect::<Vec<_>>()[..]);

        assert_eq!(&[6], &bottom.col_mut_at(0).as_iter().copied().collect::<Vec<_>>()[..]);
        assert_eq!(&[6], &bottom.as_const().col_at(0).as_iter().copied().collect::<Vec<_>>()[..]);
        assert_eq!(&[6], &bottom.col_at(0).as_iter().copied().collect::<Vec<_>>()[..]);
        assert_eq!(&[10], &bottom.col_mut_at(4).as_iter().copied().collect::<Vec<_>>()[..]);
        assert_eq!(&[10], &bottom.as_const().col_at(4).as_iter().copied().collect::<Vec<_>>()[..]);
        assert_eq!(&[10], &bottom.col_at(4).as_iter().copied().collect::<Vec<_>>()[..]);
    }
}
1161
///
/// Runs `test_submatrix_col_at` against every test matrix backend. The ndarray runner
/// does nothing unless the `ndarray` feature is enabled.
///
#[test]
fn test_submatrix_col_at_wrapper() {
    with_testmatrix_vec(test_submatrix_col_at);
    with_testmatrix_array(test_submatrix_col_at);
    with_testmatrix_linmem(test_submatrix_col_at);
    with_testmatrix_ndarray(test_submatrix_col_at);
}
1169
///
/// Checks `row_at` / `row_mut_at` in three places: on the full 3x5 test matrix, on its
/// const view, and on the two column halves obtained from `restrict_cols(1..5)` followed
/// by `split_cols`.
///
#[cfg(test)]
fn test_submatrix_row_at<V: AsPointerToSlice<i64>>(mut matrix: SubmatrixMut<V, i64>) {
    assert_submatrix_eq([
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
        [11, 12, 13, 14, 15]
    ], &mut matrix);
    // Row access on the unrestricted matrix, through every accessor. The column
    // counterparts of these checks live in `test_submatrix_col_at`.
    assert_eq!(&[6, 7, 8, 9, 10], matrix.row_at(1));
    assert_eq!(&[6, 7, 8, 9, 10], matrix.as_const().row_at(1));
    assert_eq!(&[6, 7, 8, 9, 10], matrix.row_mut_at(1));
    assert_eq!(&[11, 12, 13, 14, 15], matrix.row_at(2));
    assert_eq!(&[11, 12, 13, 14, 15], matrix.as_const().row_at(2));
    assert_eq!(&[11, 12, 13, 14, 15], matrix.row_mut_at(2));

    {
        // left = original columns 1..3, right = original columns 3..5
        let (mut left, mut right) = matrix.reborrow().restrict_cols(1..5).split_cols(0..2, 2..4);
        assert_eq!(&[2, 3], left.row_mut_at(0));
        assert_eq!(&[4, 5], right.row_mut_at(0));
        assert_eq!(&[2, 3], left.as_const().row_at(0));
        assert_eq!(&[4, 5], right.as_const().row_at(0));
        assert_eq!(&[2, 3], left.row_at(0));
        assert_eq!(&[4, 5], right.row_at(0));

        assert_eq!(&[7, 8], left.row_mut_at(1));
        assert_eq!(&[9, 10], right.row_mut_at(1));
        assert_eq!(&[7, 8], left.as_const().row_at(1));
        assert_eq!(&[9, 10], right.as_const().row_at(1));
        assert_eq!(&[7, 8], left.row_at(1));
        assert_eq!(&[9, 10], right.row_at(1));
    }
}
1199
///
/// Runs `test_submatrix_row_at` against every test matrix backend. The ndarray runner
/// does nothing unless the `ndarray` feature is enabled.
///
#[test]
fn test_submatrix_row_at_wrapper() {
    with_testmatrix_vec(test_submatrix_row_at);
    with_testmatrix_array(test_submatrix_row_at);
    with_testmatrix_linmem(test_submatrix_row_at);
    with_testmatrix_ndarray(test_submatrix_row_at);
}