faer_ext/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3#[cfg(feature = "ndarray")]
4use ndarray::{IntoDimension, ShapeArg};
5
/// Converts a matrix view owned by another linear-algebra library into its `faer` counterpart.
pub trait IntoFaer {
    /// The `faer` type produced by [`IntoFaer::into_faer`].
    type Faer;

    /// Consumes `self`, returning the equivalent `faer` value.
    fn into_faer(self) -> Self::Faer;
}
11
#[cfg(feature = "nalgebra")]
#[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))]
/// Converts a matrix view owned by another linear-algebra library into its `nalgebra` counterpart.
pub trait IntoNalgebra {
    /// The `nalgebra` type produced by [`IntoNalgebra::into_nalgebra`].
    type Nalgebra;

    /// Consumes `self`, returning the equivalent `nalgebra` value.
    fn into_nalgebra(self) -> Self::Nalgebra;
}
19
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
/// Converts a matrix view owned by another linear-algebra library into its `ndarray` counterpart.
pub trait IntoNdarray {
    /// The `ndarray` type produced by [`IntoNdarray::into_ndarray`].
    type Ndarray;

    /// Consumes `self`, returning the equivalent `ndarray` value.
    fn into_ndarray(self) -> Self::Ndarray;
}
27
/// Converts a complex-valued matrix view owned by another library into its `faer` counterpart.
pub trait IntoFaerComplex {
    /// The complex `faer` type produced by [`IntoFaerComplex::into_faer_complex`].
    type Faer;

    /// Consumes `self`, returning the equivalent complex `faer` value.
    fn into_faer_complex(self) -> Self::Faer;
}
33
#[cfg(feature = "nalgebra")]
#[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))]
/// Converts a complex-valued matrix view owned by another library into its `nalgebra`
/// counterpart.
pub trait IntoNalgebraComplex {
    /// The complex `nalgebra` type produced by [`IntoNalgebraComplex::into_nalgebra_complex`].
    type Nalgebra;

    /// Consumes `self`, returning the equivalent complex `nalgebra` value.
    fn into_nalgebra_complex(self) -> Self::Nalgebra;
}
41
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
/// Converts a complex-valued matrix view owned by another library into its `ndarray`
/// counterpart.
pub trait IntoNdarrayComplex {
    /// The complex `ndarray` type produced by [`IntoNdarrayComplex::into_ndarray_complex`].
    type Ndarray;

    /// Consumes `self`, returning the equivalent complex `ndarray` value.
    fn into_ndarray_complex(self) -> Self::Ndarray;
}
49
#[cfg(feature = "nalgebra")]
#[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))]
const _: () = {
    use faer::complex_native::*;
    use faer::prelude::*;
    use faer::SimpleEntity;
    use nalgebra::{Dim, Dyn, MatrixView, MatrixViewMut, ViewStorage, ViewStorageMut};
    use num_complex::{Complex32, Complex64};

    /// Converts a `nalgebra` stride (in elements) into a `faer` stride.
    ///
    /// # Panics
    ///
    /// Panics if `stride` does not fit in an `isize`.
    #[track_caller]
    fn faer_stride(stride: usize) -> isize {
        isize::try_from(stride).expect("matrix stride does not fit in an `isize`")
    }

    /// Converts a `faer` stride (in elements) into a `nalgebra` dynamic stride.
    ///
    /// # Panics
    ///
    /// Panics if `stride` is negative (e.g. a `faer` view with reversed rows or columns):
    /// `nalgebra` view strides are unsigned.
    #[track_caller]
    fn nalgebra_stride(stride: isize) -> Dyn {
        Dyn(usize::try_from(stride).expect("nalgebra views do not support negative strides"))
    }

    /// Builds a read-only, dynamically sized `nalgebra` view over raw matrix storage.
    ///
    /// # Safety
    ///
    /// `ptr` must point to the `(0, 0)` element of an `nrows x ncols` matrix laid out with the
    /// given strides, and every element it addresses must stay valid for reads and must not be
    /// mutated for the lifetime `'a`.
    ///
    /// # Panics
    ///
    /// Panics if either stride is negative.
    #[track_caller]
    unsafe fn nalgebra_view<'a, T>(
        ptr: *const T,
        nrows: usize,
        ncols: usize,
        row_stride: isize,
        col_stride: isize,
    ) -> MatrixView<'a, T, Dyn, Dyn, Dyn, Dyn> {
        let strides = (nalgebra_stride(row_stride), nalgebra_stride(col_stride));
        // SAFETY: the pointer/shape/stride contract is forwarded from the caller.
        unsafe {
            MatrixView::<'a, T, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorage::<
                'a,
                T,
                Dyn,
                Dyn,
                Dyn,
                Dyn,
            >::from_raw_parts(
                ptr,
                (Dyn(nrows), Dyn(ncols)),
                strides,
            ))
        }
    }

    /// Builds a mutable, dynamically sized `nalgebra` view over raw matrix storage.
    ///
    /// # Safety
    ///
    /// `ptr` must point to the `(0, 0)` element of an `nrows x ncols` matrix laid out with the
    /// given strides, every element it addresses must stay valid for reads and writes for `'a`,
    /// and no other reference may access those elements during `'a`.
    ///
    /// # Panics
    ///
    /// Panics if either stride is negative.
    #[track_caller]
    unsafe fn nalgebra_view_mut<'a, T>(
        ptr: *mut T,
        nrows: usize,
        ncols: usize,
        row_stride: isize,
        col_stride: isize,
    ) -> MatrixViewMut<'a, T, Dyn, Dyn, Dyn, Dyn> {
        let strides = (nalgebra_stride(row_stride), nalgebra_stride(col_stride));
        // SAFETY: the pointer/shape/stride/exclusivity contract is forwarded from the caller.
        unsafe {
            MatrixViewMut::<'a, T, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorageMut::<
                'a,
                T,
                Dyn,
                Dyn,
                Dyn,
                Dyn,
            >::from_raw_parts(
                ptr,
                (Dyn(nrows), Dyn(ncols)),
                strides,
            ))
        }
    }

    impl<'a, T: SimpleEntity, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoFaer
        for MatrixView<'a, T, R, C, RStride, CStride>
    {
        type Faer = MatRef<'a, T>;

        #[track_caller]
        fn into_faer(self) -> Self::Faer {
            let nrows = self.nrows();
            let ncols = self.ncols();
            let (row_stride, col_stride) = self.strides();
            let ptr = self.as_ptr();
            // SAFETY: pointer, dimensions and strides all describe the `nalgebra` view `self`,
            // whose data is borrowed immutably for `'a`.
            unsafe {
                faer::mat::from_raw_parts(
                    ptr,
                    nrows,
                    ncols,
                    faer_stride(row_stride),
                    faer_stride(col_stride),
                )
            }
        }
    }

    impl<'a, T: SimpleEntity, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoFaer
        for MatrixViewMut<'a, T, R, C, RStride, CStride>
    {
        type Faer = MatMut<'a, T>;

        #[track_caller]
        fn into_faer(self) -> Self::Faer {
            let nrows = self.nrows();
            let ncols = self.ncols();
            let (row_stride, col_stride) = self.strides();
            // `{ self }` moves the view so the mutable borrow of its storage is handed over.
            let ptr = { self }.as_mut_ptr();
            // SAFETY: pointer, dimensions and strides all describe the consumed `nalgebra` view,
            // which held exclusive access to its data for `'a`.
            unsafe {
                faer::mat::from_raw_parts_mut::<'_, T, _, _>(
                    ptr,
                    nrows,
                    ncols,
                    faer_stride(row_stride),
                    faer_stride(col_stride),
                )
            }
        }
    }

    impl<'a, T: SimpleEntity> IntoNalgebra for MatRef<'a, T> {
        type Nalgebra = MatrixView<'a, T, Dyn, Dyn, Dyn, Dyn>;

        /// # Panics
        ///
        /// Panics if the `faer` view has a negative row or column stride.
        #[track_caller]
        fn into_nalgebra(self) -> Self::Nalgebra {
            let nrows = self.nrows();
            let ncols = self.ncols();
            let row_stride = self.row_stride();
            let col_stride = self.col_stride();
            let ptr = self.as_ptr();
            // SAFETY: the raw parts come from the `faer` view `self`, borrowed immutably for `'a`.
            unsafe { nalgebra_view(ptr, nrows, ncols, row_stride, col_stride) }
        }
    }

    impl<'a, T: SimpleEntity> IntoNalgebra for MatMut<'a, T> {
        type Nalgebra = MatrixViewMut<'a, T, Dyn, Dyn, Dyn, Dyn>;

        /// # Panics
        ///
        /// Panics if the `faer` view has a negative row or column stride.
        #[track_caller]
        fn into_nalgebra(self) -> Self::Nalgebra {
            let nrows = self.nrows();
            let ncols = self.ncols();
            let row_stride = self.row_stride();
            let col_stride = self.col_stride();
            let ptr = self.as_ptr_mut();
            // SAFETY: the raw parts come from the consumed `faer` view, which held exclusive
            // access to its data for `'a`.
            unsafe { nalgebra_view_mut(ptr, nrows, ncols, row_stride, col_stride) }
        }
    }

    // Implements the complex conversions between a `faer` complex type and the matching
    // `num_complex` type. The pointer casts rely on both being `(re, im)` pairs with identical
    // layout: `num_complex::Complex<T>` is `#[repr(C)]` with fields `re, im`, and faer's
    // `c32`/`c64` are assumed to mirror it.
    macro_rules! impl_nalgebra_complex {
        ($faer_ty: ty, $num_ty: ty) => {
            impl<'a, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoFaerComplex
                for MatrixView<'a, $num_ty, R, C, RStride, CStride>
            {
                type Faer = MatRef<'a, $faer_ty>;

                #[track_caller]
                fn into_faer_complex(self) -> Self::Faer {
                    let nrows = self.nrows();
                    let ncols = self.ncols();
                    let (row_stride, col_stride) = self.strides();
                    let ptr = self.as_ptr() as *const $faer_ty;
                    // SAFETY: raw parts of the borrowed `nalgebra` view; element layouts match.
                    unsafe {
                        faer::mat::from_raw_parts(
                            ptr,
                            nrows,
                            ncols,
                            faer_stride(row_stride),
                            faer_stride(col_stride),
                        )
                    }
                }
            }

            impl<'a, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoFaerComplex
                for MatrixViewMut<'a, $num_ty, R, C, RStride, CStride>
            {
                type Faer = MatMut<'a, $faer_ty>;

                #[track_caller]
                fn into_faer_complex(self) -> Self::Faer {
                    let nrows = self.nrows();
                    let ncols = self.ncols();
                    let (row_stride, col_stride) = self.strides();
                    let ptr = { self }.as_mut_ptr() as *mut $faer_ty;
                    // SAFETY: raw parts of the consumed, exclusive `nalgebra` view; element
                    // layouts match.
                    unsafe {
                        faer::mat::from_raw_parts_mut(
                            ptr,
                            nrows,
                            ncols,
                            faer_stride(row_stride),
                            faer_stride(col_stride),
                        )
                    }
                }
            }

            impl<'a> IntoNalgebraComplex for MatRef<'a, $faer_ty> {
                type Nalgebra = MatrixView<'a, $num_ty, Dyn, Dyn, Dyn, Dyn>;

                #[track_caller]
                fn into_nalgebra_complex(self) -> Self::Nalgebra {
                    let nrows = self.nrows();
                    let ncols = self.ncols();
                    let row_stride = self.row_stride();
                    let col_stride = self.col_stride();
                    let ptr = self.as_ptr() as *const $num_ty;
                    // SAFETY: raw parts of the borrowed `faer` view; element layouts match.
                    unsafe { nalgebra_view(ptr, nrows, ncols, row_stride, col_stride) }
                }
            }

            impl<'a> IntoNalgebraComplex for MatMut<'a, $faer_ty> {
                type Nalgebra = MatrixViewMut<'a, $num_ty, Dyn, Dyn, Dyn, Dyn>;

                #[track_caller]
                fn into_nalgebra_complex(self) -> Self::Nalgebra {
                    let nrows = self.nrows();
                    let ncols = self.ncols();
                    let row_stride = self.row_stride();
                    let col_stride = self.col_stride();
                    let ptr = self.as_ptr_mut() as *mut $num_ty;
                    // SAFETY: raw parts of the consumed, exclusive `faer` view; element layouts
                    // match.
                    unsafe { nalgebra_view_mut(ptr, nrows, ncols, row_stride, col_stride) }
                }
            }
        };
    }

    impl_nalgebra_complex!(c32, Complex32);
    impl_nalgebra_complex!(c64, Complex64);
};
377
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
const _: () = {
    use faer::complex_native::*;
    use faer::prelude::*;
    use faer::SimpleEntity;
    use ndarray::{ArrayView, ArrayViewMut, Ix2, ShapeBuilder};
    use num_complex::{Complex32, Complex64};

    /// Returns the `[row, column]` strides of a 2-D `ndarray` view, in elements.
    ///
    /// `ndarray` strides are signed and may be negative; `faer` accepts them unchanged.
    #[track_caller]
    fn faer_strides(strides: &[isize]) -> [isize; 2] {
        strides
            .try_into()
            .expect("a 2-D array view has exactly two strides")
    }

    /// Converts a `faer` stride (in elements) into a stride for `ndarray`'s `from_shape_ptr`.
    ///
    /// # Panics
    ///
    /// Panics if `stride` is negative (e.g. a `faer` view with reversed rows or columns).
    #[track_caller]
    fn ndarray_stride(stride: isize) -> usize {
        usize::try_from(stride)
            .expect("negative strides are not supported when converting to ndarray")
    }

    impl<'a, T: SimpleEntity> IntoFaer for ArrayView<'a, T, Ix2> {
        type Faer = MatRef<'a, T>;

        #[track_caller]
        fn into_faer(self) -> Self::Faer {
            let nrows = self.nrows();
            let ncols = self.ncols();
            let [row_stride, col_stride] = faer_strides(self.strides());
            let ptr = self.as_ptr();
            // SAFETY: `as_ptr` addresses the logical first element and the strides are those of
            // the same view, whose data is borrowed immutably for `'a`.
            unsafe { faer::mat::from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) }
        }
    }

    impl<'a, T: SimpleEntity> IntoFaer for ArrayViewMut<'a, T, Ix2> {
        type Faer = MatMut<'a, T>;

        #[track_caller]
        fn into_faer(self) -> Self::Faer {
            let nrows = self.nrows();
            let ncols = self.ncols();
            let [row_stride, col_stride] = faer_strides(self.strides());
            // `{ self }` moves the view so its exclusive borrow is handed over.
            let ptr = { self }.as_mut_ptr();
            // SAFETY: raw parts of the consumed view, which held exclusive access for `'a`.
            unsafe {
                faer::mat::from_raw_parts_mut::<'_, T, _, _>(
                    ptr, nrows, ncols, row_stride, col_stride,
                )
            }
        }
    }

    impl<'a, T: SimpleEntity> IntoNdarray for MatRef<'a, T> {
        type Ndarray = ArrayView<'a, T, Ix2>;

        /// # Panics
        ///
        /// Panics if the `faer` view has a negative row or column stride.
        #[track_caller]
        fn into_ndarray(self) -> Self::Ndarray {
            let nrows = self.nrows();
            let ncols = self.ncols();
            let row_stride = ndarray_stride(self.row_stride());
            let col_stride = ndarray_stride(self.col_stride());
            let ptr = self.as_ptr();
            // SAFETY: raw parts of the `faer` view `self`, borrowed immutably for `'a`; strides
            // were checked to be non-negative.
            unsafe {
                ArrayView::<'_, T, Ix2>::from_shape_ptr(
                    (nrows, ncols).strides((row_stride, col_stride)),
                    ptr,
                )
            }
        }
    }

    impl<'a, T: SimpleEntity> IntoNdarray for MatMut<'a, T> {
        type Ndarray = ArrayViewMut<'a, T, Ix2>;

        /// # Panics
        ///
        /// Panics if the `faer` view has a negative row or column stride.
        #[track_caller]
        fn into_ndarray(self) -> Self::Ndarray {
            let nrows = self.nrows();
            let ncols = self.ncols();
            let row_stride = ndarray_stride(self.row_stride());
            let col_stride = ndarray_stride(self.col_stride());
            let ptr = self.as_ptr_mut();
            // SAFETY: raw parts of the consumed, exclusive `faer` view; strides were checked to
            // be non-negative.
            unsafe {
                ArrayViewMut::<'_, T, Ix2>::from_shape_ptr(
                    (nrows, ncols).strides((row_stride, col_stride)),
                    ptr,
                )
            }
        }
    }

    // Implements the complex conversions between a `faer` complex type and the matching
    // `num_complex` type. The pointer casts rely on both being `(re, im)` pairs with identical
    // layout: `num_complex::Complex<T>` is `#[repr(C)]` with fields `re, im`, and faer's
    // `c32`/`c64` are assumed to mirror it.
    macro_rules! impl_ndarray_complex {
        ($faer_ty: ty, $num_ty: ty) => {
            impl<'a> IntoFaerComplex for ArrayView<'a, $num_ty, Ix2> {
                type Faer = MatRef<'a, $faer_ty>;

                #[track_caller]
                fn into_faer_complex(self) -> Self::Faer {
                    let nrows = self.nrows();
                    let ncols = self.ncols();
                    let [row_stride, col_stride] = faer_strides(self.strides());
                    let ptr = self.as_ptr() as *const $faer_ty;
                    // SAFETY: raw parts of the borrowed `ndarray` view; element layouts match.
                    unsafe {
                        faer::mat::from_raw_parts(ptr, nrows, ncols, row_stride, col_stride)
                    }
                }
            }

            impl<'a> IntoFaerComplex for ArrayViewMut<'a, $num_ty, Ix2> {
                type Faer = MatMut<'a, $faer_ty>;

                #[track_caller]
                fn into_faer_complex(self) -> Self::Faer {
                    let nrows = self.nrows();
                    let ncols = self.ncols();
                    let [row_stride, col_stride] = faer_strides(self.strides());
                    let ptr = { self }.as_mut_ptr() as *mut $faer_ty;
                    // SAFETY: raw parts of the consumed, exclusive `ndarray` view; element
                    // layouts match.
                    unsafe {
                        faer::mat::from_raw_parts_mut(ptr, nrows, ncols, row_stride, col_stride)
                    }
                }
            }

            impl<'a> IntoNdarrayComplex for MatRef<'a, $faer_ty> {
                type Ndarray = ArrayView<'a, $num_ty, Ix2>;

                #[track_caller]
                fn into_ndarray_complex(self) -> Self::Ndarray {
                    let nrows = self.nrows();
                    let ncols = self.ncols();
                    let row_stride = ndarray_stride(self.row_stride());
                    let col_stride = ndarray_stride(self.col_stride());
                    let ptr = self.as_ptr() as *const $num_ty;
                    // SAFETY: raw parts of the borrowed `faer` view; element layouts match and
                    // strides were checked to be non-negative.
                    unsafe {
                        ArrayView::<'_, $num_ty, Ix2>::from_shape_ptr(
                            (nrows, ncols).strides((row_stride, col_stride)),
                            ptr,
                        )
                    }
                }
            }

            impl<'a> IntoNdarrayComplex for MatMut<'a, $faer_ty> {
                type Ndarray = ArrayViewMut<'a, $num_ty, Ix2>;

                #[track_caller]
                fn into_ndarray_complex(self) -> Self::Ndarray {
                    let nrows = self.nrows();
                    let ncols = self.ncols();
                    let row_stride = ndarray_stride(self.row_stride());
                    let col_stride = ndarray_stride(self.col_stride());
                    let ptr = self.as_ptr_mut() as *mut $num_ty;
                    // SAFETY: raw parts of the consumed, exclusive `faer` view; element layouts
                    // match and strides were checked to be non-negative.
                    unsafe {
                        ArrayViewMut::<'_, $num_ty, Ix2>::from_shape_ptr(
                            (nrows, ncols).strides((row_stride, col_stride)),
                            ptr,
                        )
                    }
                }
            }
        };
    }

    impl_ndarray_complex!(c32, Complex32);
    impl_ndarray_complex!(c64, Complex64);
};
583
#[cfg(all(feature = "nalgebra", feature = "ndarray"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "nalgebra", feature = "ndarray"))))]
const _: () = {
    use nalgebra::{Dim, Dyn, MatrixView, MatrixViewMut, ViewStorage, ViewStorageMut};
    use ndarray::{ArrayView, ArrayViewMut, Ix2, ShapeBuilder};
    use num_complex::Complex;

    /// Converts the strides of a 2-D `ndarray` view into `nalgebra` `(row, column)` strides.
    ///
    /// # Panics
    ///
    /// Panics if either stride is negative (e.g. a view sliced with a negative step), because
    /// `nalgebra` view strides are unsigned.
    #[track_caller]
    fn nalgebra_strides(strides: &[isize]) -> (Dyn, Dyn) {
        let [row_stride, col_stride]: [isize; 2] = strides
            .try_into()
            .expect("a 2-D array view has exactly two strides");
        let row_stride = usize::try_from(row_stride)
            .expect("nalgebra views do not support negative strides");
        let col_stride = usize::try_from(col_stride)
            .expect("nalgebra views do not support negative strides");
        (Dyn(row_stride), Dyn(col_stride))
    }

    impl<'a, T> IntoNalgebra for ArrayView<'a, T, Ix2> {
        type Nalgebra = MatrixView<'a, T, Dyn, Dyn, Dyn, Dyn>;

        /// # Panics
        ///
        /// Panics if the array view has a negative stride.
        #[track_caller]
        fn into_nalgebra(self) -> Self::Nalgebra {
            let shape = (Dyn(self.nrows()), Dyn(self.ncols()));
            let strides = nalgebra_strides(self.strides());
            let ptr = self.as_ptr();
            // SAFETY: `ptr`, `shape` and `strides` describe the `ndarray` view `self`, whose
            // data is borrowed immutably for `'a`; strides were checked to be non-negative.
            unsafe {
                MatrixView::<'a, T, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorage::<
                    'a,
                    T,
                    Dyn,
                    Dyn,
                    Dyn,
                    Dyn,
                >::from_raw_parts(
                    ptr, shape, strides
                ))
            }
        }
    }

    impl<'a, T> IntoNalgebra for ArrayViewMut<'a, T, Ix2> {
        type Nalgebra = MatrixViewMut<'a, T, Dyn, Dyn, Dyn, Dyn>;

        /// # Panics
        ///
        /// Panics if the array view has a negative stride.
        #[track_caller]
        fn into_nalgebra(self) -> Self::Nalgebra {
            let shape = (Dyn(self.nrows()), Dyn(self.ncols()));
            let strides = nalgebra_strides(self.strides());
            // `{ self }` moves the view so its exclusive borrow is handed over.
            let ptr = { self }.as_mut_ptr();
            // SAFETY: raw parts of the consumed, exclusive `ndarray` view; strides were checked
            // to be non-negative.
            unsafe {
                MatrixViewMut::<'a, T, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorageMut::<
                    'a,
                    T,
                    Dyn,
                    Dyn,
                    Dyn,
                    Dyn,
                >::from_raw_parts(
                    ptr, shape, strides
                ))
            }
        }
    }

    impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoNdarray
        for MatrixView<'a, T, R, C, RStride, CStride>
    {
        type Ndarray = ArrayView<'a, T, Ix2>;

        #[track_caller]
        fn into_ndarray(self) -> Self::Ndarray {
            let nrows = self.nrows();
            let ncols = self.ncols();
            let (row_stride, col_stride) = self.strides();
            let ptr = self.as_ptr();
            // SAFETY: raw parts of the `nalgebra` view `self`, borrowed immutably for `'a`;
            // `nalgebra` strides are unsigned, hence non-negative.
            unsafe {
                ArrayView::<'_, T, Ix2>::from_shape_ptr(
                    (nrows, ncols).strides((row_stride, col_stride)),
                    ptr,
                )
            }
        }
    }

    impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoNdarray
        for MatrixViewMut<'a, T, R, C, RStride, CStride>
    {
        type Ndarray = ArrayViewMut<'a, T, Ix2>;

        #[track_caller]
        fn into_ndarray(self) -> Self::Ndarray {
            let nrows = self.nrows();
            let ncols = self.ncols();
            let (row_stride, col_stride) = self.strides();
            let ptr = { self }.as_mut_ptr();
            // SAFETY: raw parts of the consumed, exclusive `nalgebra` view.
            unsafe {
                ArrayViewMut::<'_, T, Ix2>::from_shape_ptr(
                    (nrows, ncols).strides((row_stride, col_stride)),
                    ptr,
                )
            }
        }
    }

    impl<'a, T> IntoNalgebraComplex for ArrayView<'a, Complex<T>, Ix2> {
        type Nalgebra = MatrixView<'a, Complex<T>, Dyn, Dyn, Dyn, Dyn>;

        #[track_caller]
        fn into_nalgebra_complex(self) -> Self::Nalgebra {
            // Both libraries store `Complex<T>` directly, so the generic conversion applies.
            self.into_nalgebra()
        }
    }

    impl<'a, T> IntoNalgebraComplex for ArrayViewMut<'a, Complex<T>, Ix2> {
        type Nalgebra = MatrixViewMut<'a, Complex<T>, Dyn, Dyn, Dyn, Dyn>;

        #[track_caller]
        fn into_nalgebra_complex(self) -> Self::Nalgebra {
            // Both libraries store `Complex<T>` directly, so the generic conversion applies.
            self.into_nalgebra()
        }
    }

    impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoNdarrayComplex
        for MatrixView<'a, Complex<T>, R, C, RStride, CStride>
    {
        type Ndarray = ArrayView<'a, Complex<T>, Ix2>;

        #[track_caller]
        fn into_ndarray_complex(self) -> Self::Ndarray {
            let nrows = self.nrows();
            let ncols = self.ncols();
            let (row_stride, col_stride) = self.strides();
            let ptr = self.as_ptr();
            // SAFETY: raw parts of the `nalgebra` view `self`, borrowed immutably for `'a`.
            unsafe {
                ArrayView::<'_, Complex<T>, Ix2>::from_shape_ptr(
                    (nrows, ncols)
                        .into_shape_and_order()
                        .0
                        .strides((row_stride, col_stride).into_dimension()),
                    ptr,
                )
            }
        }
    }

    impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoNdarrayComplex
        for MatrixViewMut<'a, Complex<T>, R, C, RStride, CStride>
    {
        type Ndarray = ArrayViewMut<'a, Complex<T>, Ix2>;

        #[track_caller]
        fn into_ndarray_complex(self) -> Self::Ndarray {
            let nrows = self.nrows();
            let ncols = self.ncols();
            let (row_stride, col_stride) = self.strides();
            let ptr = { self }.as_mut_ptr();
            // SAFETY: raw parts of the consumed, exclusive `nalgebra` view.
            unsafe {
                ArrayViewMut::<'_, Complex<T>, Ix2>::from_shape_ptr(
                    (nrows, ncols)
                        .into_shape_and_order()
                        .0
                        .strides((row_stride, col_stride).into_dimension()),
                    ptr,
                )
            }
        }
    }
};
793
794#[cfg(feature = "polars")]
795#[cfg_attr(docsrs, doc(cfg(feature = "polars")))]
796pub mod polars {
797    use faer::Mat;
798    use polars::prelude::*;
799
800    pub trait Frame {
801        fn is_valid(self) -> PolarsResult<LazyFrame>;
802    }
803
804    impl Frame for LazyFrame {
805        fn is_valid(self) -> PolarsResult<LazyFrame> {
806            let test_dtypes: bool = self
807                .clone()
808                .limit(0)
809                .collect()?
810                .dtypes()
811                .into_iter()
812                .map(|e| {
813                    matches!(
814                        e,
815                        DataType::UInt8
816                            | DataType::UInt16
817                            | DataType::UInt32
818                            | DataType::UInt64
819                            | DataType::Int8
820                            | DataType::Int16
821                            | DataType::Int32
822                            | DataType::Int64
823                            | DataType::Float32
824                            | DataType::Float64
825                    )
826                })
827                .all(|e| e);
828            let test_no_nulls: bool = self
829                .clone()
830                .null_count()
831                .cast_all(DataType::UInt64, true)
832                .with_column(
833                    fold_exprs(
834                        lit(0).cast(DataType::UInt64),
835                        |acc, x| Ok(Some(acc + x)),
836                        [col("*")],
837                    )
838                    .alias("sum"),
839                )
840                .select(&[col("sum")])
841                .collect()?
842                .column("sum")?
843                .u64()?
844                .into_iter()
845                .map(|e| e.eq(&Some(0u64)))
846                .collect::<Vec<_>>()[0];
847            match (test_dtypes, test_no_nulls) {
848                (true, true) => Ok(self),
849                (false, true) => Err(PolarsError::InvalidOperation(
850                    "frame contains non-numerical data".into(),
851                )),
852                (true, false) => Err(PolarsError::InvalidOperation(
853                    "frame contains null entries".into(),
854                )),
855                (false, false) => Err(PolarsError::InvalidOperation(
856                    "frame contains non-numerical data and null entries".into(),
857                )),
858            }
859        }
860    }
861
862    macro_rules! polars_impl {
863        ($ty: ident, $dtype: ident, $fn_name: ident) => {
864            /// Converts a `polars` lazyframe into a [`Mat`].
865            ///
866            /// Note that this function expects that the frame passed "looks like"
867            /// a numerical array and all values will be cast to either f32 or f64
868            /// prior to building [`Mat`].
869            ///
870            /// Passing a frame with either non-numerical column data or null
871            /// entries will result in a error. Users are expected to reolve
872            /// these issues in `polars` prior calling this function.
873            #[cfg(feature = "polars")]
874            #[cfg_attr(docsrs, doc(cfg(feature = "polars")))]
875            pub fn $fn_name(
876                frame: impl Frame,
877            ) -> PolarsResult<Mat<$ty>> {
878                use core::{iter::zip, mem::MaybeUninit};
879
880                fn implementation(
881                    lf: LazyFrame,
882                ) -> PolarsResult<Mat<$ty>> {
883                    let df = lf
884                        .select(&[col("*").cast(DataType::$dtype)])
885                        .collect()?;
886
887                    let nrows = df.height();
888                    let ncols = df.get_column_names().len();
889
890                    let mut out = Mat::<$ty>::with_capacity(df.height(), df.get_column_names().len());
891
892                    df.get_column_names().iter()
893                        .enumerate()
894                        .try_for_each(|(j, col)| -> PolarsResult<()> {
895                            let mut row_start = 0usize;
896
897                            // SAFETY: this is safe since we allocated enough space for `ncols` columns and
898                            // `nrows` rows
899                            let out_col = unsafe {
900                                core::slice::from_raw_parts_mut(
901                                    out.as_mut().ptr_at_mut(0, j) as *mut MaybeUninit<$ty>,
902                                    nrows,
903                                )
904                            };
905
906                            df.column(col)?.$ty()?.downcast_iter().try_for_each(
907                                |chunk| -> PolarsResult<()> {
908                                    let len = chunk.len();
909                                    if len == 0 {
910                                        return Ok(());
911                                    }
912
913                                    match row_start.checked_add(len) {
914                                        Some(next_row_start) => {
915                                            if next_row_start <= nrows {
916                                                let mut out_slice = &mut out_col[row_start..next_row_start];
917                                                let mut values = chunk.values_iter().as_slice();
918                                                let validity = chunk.validity();
919
920                                                assert_eq!(values.len(), len);
921
922                                                match validity {
923                                                    Some(bitmap) => {
924                                                        let (mut bytes, offset, bitmap_len) = bitmap.as_slice();
925                                                        assert_eq!(bitmap_len, len);
926                                                        const BITS_PER_BYTE: usize = 8;
927
928                                                        if offset > 0 {
929                                                            let first_byte_len = Ord::min(len, 8 - offset);
930
931                                                            let (out_prefix, out_suffix) = out_slice.split_at_mut(first_byte_len);
932                                                            let (values_prefix, values_suffix) = values.split_at(first_byte_len);
933
934                                                            for (out_elem, value_elem) in zip(
935                                                                out_prefix,
936                                                                values_prefix,
937                                                            ) {
938                                                                *out_elem = MaybeUninit::new(*value_elem)
939                                                            }
940
941                                                            bytes = &bytes[1..];
942                                                            values = values_suffix;
943                                                            out_slice = out_suffix;
944                                                        }
945
946                                                        if bytes.len() > 0 {
947                                                            for (out_slice8, values8) in zip(
948                                                                out_slice.chunks_exact_mut(BITS_PER_BYTE),
949                                                                values.chunks_exact(BITS_PER_BYTE),
950                                                            ) {
951                                                                for (out_elem, value_elem) in zip(out_slice8, values8) {
952                                                                    *out_elem = MaybeUninit::new(*value_elem);
953                                                                }
954                                                            }
955
956                                                            for (out_elem, value_elem) in zip(
957                                                                out_slice.chunks_exact_mut(BITS_PER_BYTE).into_remainder(),
958                                                                values.chunks_exact(BITS_PER_BYTE).remainder(),
959                                                            ) {
960                                                                *out_elem = MaybeUninit::new(*value_elem);
961                                                            }
962                                                        }
963                                                    }
964                                                    None => {
965                                                        // SAFETY: T and MaybeUninit<T> have the same layout
966                                                        // NOTE: This state should not be reachable
967                                                        let values = unsafe {
968                                                            core::slice::from_raw_parts(
969                                                                values.as_ptr() as *const MaybeUninit<$ty>,
970                                                                values.len(),
971                                                            )
972                                                        };
973                                                        out_slice.copy_from_slice(values);
974                                                    }
975                                                }
976
977                                                row_start = next_row_start;
978                                                Ok(())
979                                            } else {
980                                                Err(PolarsError::ShapeMismatch(
981                                                    format!("too many values in column {col}").into(),
982                                                ))
983                                            }
984                                        }
985                                        None => Err(PolarsError::ShapeMismatch(
986                                            format!("too many values in column {col}").into(),
987                                        )),
988                                    }
989                                },
990                            )?;
991
992                            if row_start < nrows {
993                                Err(PolarsError::ShapeMismatch(
994                                    format!("not enough values in column {col} (column has {row_start} values, while dataframe has {nrows} rows)").into(),
995                                ))
996                            } else {
997                                Ok(())
998                            }
999                        })?;
1000
1001                    // SAFETY: we initialized every `ncols` columns, and each one was initialized with `nrows`
1002                    // elements
1003                    unsafe { out.set_dims(nrows, ncols) };
1004
1005                    Ok(out)
1006                }
1007
1008                implementation(frame.is_valid()?)
1009            }
1010        };
1011    }
1012
1013    polars_impl!(f32, Float32, polars_to_faer_f32);
1014    polars_impl!(f64, Float64, polars_to_faer_f64);
1015}
1016
#[cfg(test)]
mod tests {
    #![allow(unused_imports)]
    #![allow(non_snake_case)]

    use super::*;
    use faer::mat;
    use faer::prelude::*;

    /// Round-trips an 8x7 identity matrix between `faer` and `ndarray` (shared and mutable views).
    #[cfg(feature = "ndarray")]
    #[test]
    fn test_ext_ndarray() {
        let mut I_faer = Mat::<f32>::identity(8, 7);
        let mut I_ndarray = ndarray::Array2::<f32>::zeros([8, 7]);
        I_ndarray.diag_mut().fill(1.0);

        assert_eq!(I_ndarray.view().into_faer(), I_faer);
        assert_eq!(I_faer.as_ref().into_ndarray(), I_ndarray);

        assert_eq!(I_ndarray.view_mut().into_faer(), I_faer);
        assert_eq!(I_faer.as_mut().into_ndarray(), I_ndarray);
    }

    /// Round-trips an 8x7 identity matrix between `faer` and `nalgebra` (shared and mutable views).
    #[cfg(feature = "nalgebra")]
    #[test]
    fn test_ext_nalgebra() {
        let mut I_faer = Mat::<f32>::identity(8, 7);
        let mut I_nalgebra = nalgebra::DMatrix::<f32>::identity(8, 7);

        assert_eq!(I_nalgebra.view_range(.., ..).into_faer(), I_faer);
        assert_eq!(I_faer.as_ref().into_nalgebra(), I_nalgebra);

        assert_eq!(I_nalgebra.view_range_mut(.., ..).into_faer(), I_faer);
        assert_eq!(I_faer.as_mut().into_nalgebra(), I_nalgebra);
    }

    /// Integer columns are cast and laid out column by column.
    #[cfg(feature = "polars")]
    #[test]
    fn test_polars_pos() {
        use crate::polars::{polars_to_faer_f32, polars_to_faer_f64};
        #[rustfmt::skip]
        use ::polars::prelude::*;

        let s0: Series = Series::new("a", [1, 2, 3]);
        let s1: Series = Series::new("b", [10, 11, 12]);

        let lf = DataFrame::new(vec![s0, s1]).unwrap().lazy();

        let arr_32 = polars_to_faer_f32(lf.clone()).unwrap();
        let arr_64 = polars_to_faer_f64(lf).unwrap();

        let expected_32 = mat![[1f32, 10f32], [2f32, 11f32], [3f32, 12f32]];
        let expected_64 = mat![[1f64, 10f64], [2f64, 11f64], [3f64, 12f64]];

        assert_eq!(arr_32, expected_32);
        assert_eq!(arr_64, expected_64);
    }

    /// Columns spread over several chunks are copied back to back.
    #[cfg(feature = "polars")]
    #[test]
    fn test_polars_pos_multi_chunk() {
        use crate::polars::{polars_to_faer_f32, polars_to_faer_f64};
        #[rustfmt::skip]
        use ::polars::prelude::*;

        let top = DataFrame::new(vec![
            Series::new("a", [1, 2, 3]),
            Series::new("b", [10, 11, 12]),
        ])
        .unwrap();
        let bottom = DataFrame::new(vec![Series::new("a", [4, 5]), Series::new("b", [13, 14])])
            .unwrap();

        // `vstack` concatenates chunk lists without rechunking, so each column is expected to
        // span two chunks here.
        let lf = top.vstack(&bottom).unwrap().lazy();

        let arr_32 = polars_to_faer_f32(lf.clone()).unwrap();
        let arr_64 = polars_to_faer_f64(lf).unwrap();

        let expected_32 = mat![
            [1f32, 10f32],
            [2f32, 11f32],
            [3f32, 12f32],
            [4f32, 13f32],
            [5f32, 14f32],
        ];
        let expected_64 = mat![
            [1f64, 10f64],
            [2f64, 11f64],
            [3f64, 12f64],
            [4f64, 13f64],
            [5f64, 14f64],
        ];

        assert_eq!(arr_32, expected_32);
        assert_eq!(arr_64, expected_64);
    }

    /// A sliced frame (non-zero array offsets, optional-valued column) converts correctly.
    #[cfg(feature = "polars")]
    #[test]
    fn test_polars_pos_sliced() {
        use crate::polars::{polars_to_faer_f32, polars_to_faer_f64};
        #[rustfmt::skip]
        use ::polars::prelude::*;

        let df = DataFrame::new(vec![
            Series::new("a", [1, 2, 3, 4]),
            Series::new("b", [Some(10), Some(11), Some(12), Some(13)]),
        ])
        .unwrap();

        let lf = df.slice(1, 3).lazy();

        let arr_32 = polars_to_faer_f32(lf.clone()).unwrap();
        let arr_64 = polars_to_faer_f64(lf).unwrap();

        let expected_32 = mat![[2f32, 11f32], [3f32, 12f32], [4f32, 13f32]];
        let expected_64 = mat![[2f64, 11f64], [3f64, 12f64], [4f64, 13f64]];

        assert_eq!(arr_32, expected_32);
        assert_eq!(arr_64, expected_64);
    }

    /// A null entry is rejected (f32).
    #[cfg(feature = "polars")]
    #[test]
    #[should_panic(expected = "frame contains null entries")]
    fn test_polars_neg_32_null() {
        use crate::polars::polars_to_faer_f32;
        #[rustfmt::skip]
        use ::polars::prelude::*;

        let s0: Series = Series::new("a", [1, 2, 3]);
        let s1: Series = Series::new("b", [Some(10), Some(11), None]);

        let lf = DataFrame::new(vec![s0, s1]).unwrap().lazy();

        polars_to_faer_f32(lf).unwrap();
    }

    /// A string column is rejected (f32).
    #[cfg(feature = "polars")]
    #[test]
    #[should_panic(expected = "frame contains non-numerical data")]
    fn test_polars_neg_32_strl() {
        use crate::polars::polars_to_faer_f32;
        #[rustfmt::skip]
        use ::polars::prelude::*;

        let s0: Series = Series::new("a", [1, 2, 3]);
        let s1: Series = Series::new("b", ["fish", "dog", "crocodile"]);

        let lf = DataFrame::new(vec![s0, s1]).unwrap().lazy();

        polars_to_faer_f32(lf).unwrap();
    }

    /// Both failures are reported together (f32).
    #[cfg(feature = "polars")]
    #[test]
    #[should_panic(expected = "frame contains non-numerical data and null entries")]
    fn test_polars_neg_32_combo() {
        use crate::polars::polars_to_faer_f32;
        #[rustfmt::skip]
        use ::polars::prelude::*;

        let s0: Series = Series::new("a", [1, 2, 3]);
        let s1: Series = Series::new("b", [Some(10), Some(11), None]);
        let s2: Series = Series::new("c", [Some("fish"), Some("dog"), None]);

        let lf = DataFrame::new(vec![s0, s1, s2]).unwrap().lazy();

        polars_to_faer_f32(lf).unwrap();
    }

    /// A null entry is rejected (f64).
    #[cfg(feature = "polars")]
    #[test]
    #[should_panic(expected = "frame contains null entries")]
    fn test_polars_neg_64_null() {
        use crate::polars::polars_to_faer_f64;
        #[rustfmt::skip]
        use ::polars::prelude::*;

        let s0: Series = Series::new("a", [1, 2, 3]);
        let s1: Series = Series::new("b", [Some(10), Some(11), None]);

        let lf = DataFrame::new(vec![s0, s1]).unwrap().lazy();

        polars_to_faer_f64(lf).unwrap();
    }

    /// A string column is rejected (f64).
    #[cfg(feature = "polars")]
    #[test]
    #[should_panic(expected = "frame contains non-numerical data")]
    fn test_polars_neg_64_strl() {
        use crate::polars::polars_to_faer_f64;
        #[rustfmt::skip]
        use ::polars::prelude::*;

        let s0: Series = Series::new("a", [1, 2, 3]);
        let s1: Series = Series::new("b", ["fish", "dog", "crocodile"]);

        let lf = DataFrame::new(vec![s0, s1]).unwrap().lazy();

        polars_to_faer_f64(lf).unwrap();
    }

    /// Both failures are reported together (f64).
    #[cfg(feature = "polars")]
    #[test]
    #[should_panic(expected = "frame contains non-numerical data and null entries")]
    fn test_polars_neg_64_combo() {
        use crate::polars::polars_to_faer_f64;
        #[rustfmt::skip]
        use ::polars::prelude::*;

        let s0: Series = Series::new("a", [1, 2, 3]);
        let s1: Series = Series::new("b", [Some(10), Some(11), None]);
        let s2: Series = Series::new("c", [Some("fish"), Some("dog"), None]);

        let lf = DataFrame::new(vec![s0, s1, s2]).unwrap().lazy();

        polars_to_faer_f64(lf).unwrap();
    }
}