lmutils/
matrix.rs

1#![allow(mutable_transmutes)]
2use std::{
3    collections::{HashMap, HashSet},
4    f64,
5    mem::MaybeUninit,
6    ops::{Deref, DerefMut},
7    str::FromStr,
8    sync::Arc,
9};
10
11#[cfg(feature = "r")]
12use extendr_api::{
13    io::Load, scalar::Scalar, single_threaded, wrapper, AsStrIter, Attributes, Conversions,
14    IntoRobj, MatrixConversions, RMatrix, Rinternals, Robj, Rtype,
15};
16use faer::{c64, linalg::qr, ColMut, Mat, MatMut, MatRef, RowMut};
17use rand_distr::Distribution;
18use rayon::prelude::*;
19use regex::Regex;
20use tracing::{debug, error, info, trace};
21
22use crate::{file::File, mean, standardize_column, standardize_row, Error};
23
/// How rows of two matrices are matched when joining them.
///
/// The explicit discriminants (`0`, `1`, `2`) are the values accepted from R
/// when converting a scalar into a `Join`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Join {
    /// Inner join, only rows that are present in both matrices are kept
    Inner = 0,
    /// Left join, all rows from the left matrix must be matched
    Left = 1,
    /// Right join, all rows from the right matrix must be matched
    Right = 2,
}

/// Error message returned when a value cannot be converted into a [`Join`].
const INVALID_JOIN_TYPE: &str = "invalid join type, must be one of 0, 1, or 2";
35
#[cfg(feature = "r")]
impl TryFrom<Robj> for Join {
    type Error = &'static str;

    /// Converts an R scalar (integer, double, or logical) into a [`Join`].
    ///
    /// Accepted values are `0` (inner), `1` (left) and `2` (right). Doubles
    /// must be finite whole numbers. Non-scalar vectors, `NA`, fractional
    /// doubles and any other R type are rejected with `INVALID_JOIN_TYPE`
    /// instead of panicking or being silently truncated.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn try_from(obj: Robj) -> Result<Self, Self::Error> {
        let val = if obj.is_integer() {
            obj.as_integer().ok_or(INVALID_JOIN_TYPE)?
        } else if obj.is_real() {
            let x = obj.as_real().ok_or(INVALID_JOIN_TYPE)?;
            // `as i32` would turn NaN (R's NA_real_) into 0 (= Inner) and
            // truncate e.g. 1.5 into 1 (= Left); reject both instead.
            if !x.is_finite() || x.fract() != 0.0 {
                return Err(INVALID_JOIN_TYPE);
            }
            x as i32
        } else if obj.is_logical() {
            obj.as_logical().ok_or(INVALID_JOIN_TYPE)?.inner()
        } else {
            return Err(INVALID_JOIN_TYPE);
        };
        match val {
            0 => Ok(Join::Inner),
            1 => Ok(Join::Left),
            2 => Ok(Join::Right),
            _ => Err(INVALID_JOIN_TYPE),
        }
    }
}
59
60impl std::fmt::Display for Join {
61    #[cfg_attr(coverage_nightly, coverage(off))]
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        match self {
64            Join::Inner => write!(f, "inner"),
65            Join::Left => write!(f, "left"),
66            Join::Right => write!(f, "right"),
67        }
68    }
69}
70
71#[doc(hidden)]
72pub trait DerefMatrix: Deref<Target = Matrix> + DerefMut + std::fmt::Debug {}
73
74impl<T> DerefMatrix for T where T: Deref<Target = Matrix> + DerefMut + std::fmt::Debug {}
75
/// An `f64` matrix that may be held by R, owned in memory, backed by a file,
/// wrapped behind a user type, or produced lazily by pending transformations.
pub enum Matrix {
    /// A matrix owned by R (only available with the `r` feature).
    #[cfg(feature = "r")]
    R(RMatrix<f64>),
    /// An in-memory matrix owned by Rust.
    Owned(OwnedMatrix),
    /// A matrix backed by a file; its data is read when the matrix is loaded
    /// (e.g. by `into_owned`).
    File(File),
    /// Any user value that dereferences (mutably) to a `Matrix`.
    Dyn(Box<dyn DerefMatrix>),
    /// Transformations that are applied, in order, to the boxed matrix when
    /// it is loaded (see `transform`).
    Transform(
        #[allow(clippy::type_complexity)]
        Vec<Arc<dyn for<'a> Fn(&'a mut Matrix) -> Result<&'a mut Matrix, Error>>>,
        Box<Matrix>,
    ),
}
88
89impl PartialEq for Matrix {
90    #[cfg_attr(coverage_nightly, coverage(off))]
91    fn eq(&self, other: &Self) -> bool {
92        match (self, other) {
93            #[cfg(feature = "r")]
94            (Matrix::R(a), Matrix::R(b)) => a == b,
95            (Matrix::Owned(a), Matrix::Owned(b)) => a == b,
96            (Matrix::File(a), Matrix::File(b)) => a == b,
97            (Matrix::Dyn(a), Matrix::Dyn(b)) => ***a == ***b,
98            (Matrix::Transform(_, a), Matrix::Transform(_, b)) => a == b,
99            _ => false,
100        }
101    }
102}
103
104impl std::fmt::Debug for Matrix {
105    #[cfg_attr(coverage_nightly, coverage(off))]
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        match self {
108            #[cfg(feature = "r")]
109            Matrix::R(m) => write!(f, "Matrix::R({:?})", m),
110            Matrix::Owned(m) => write!(f, "Matrix::Owned({:?})", m),
111            Matrix::File(m) => write!(f, "Matrix::File({:?})", m),
112            Matrix::Dyn(m) => write!(f, "Matrix::Dyn({:?})", m),
113            Matrix::Transform(t, m) => write!(f, "Matrix::Transform({:?}, {:?})", t.len(), m),
114        }
115    }
116}
117
// SAFETY: This is always safe except when calling into R, for example by
// loading an .RData file
//
// NOTE(review): these impls assert thread-safety the compiler cannot check.
// `Matrix::R` wraps an R object, `Matrix::Dyn` boxes a `dyn DerefMatrix`
// with no `Send`/`Sync` bound, and `Matrix::Transform` stores
// `Arc<dyn Fn ...>` closures with no `Send`/`Sync` bound. Moving or sharing
// such a matrix across threads is only sound if the wrapped values are
// actually thread-safe and no R calls happen off the R thread — confirm
// against the callers.
unsafe impl Send for Matrix {}
unsafe impl Sync for Matrix {}
122
123impl Matrix {
124    #[cfg_attr(coverage_nightly, coverage(off))]
125    pub fn as_mat_ref(&mut self) -> Result<MatRef<'_, f64>, crate::Error> {
126        Ok(match self {
127            m @ (Matrix::File(_) | Matrix::Transform(..)) => m.into_owned()?,
128            m => m,
129        }
130        .as_mat_ref_loaded())
131    }
132
133    #[cfg_attr(coverage_nightly, coverage(off))]
134    pub fn as_mat_ref_loaded(&self) -> MatRef<'_, f64> {
135        match self {
136            #[cfg(feature = "r")]
137            Matrix::R(m) => MatRef::from_column_major_slice(m.data(), m.nrows(), m.ncols()),
138            Matrix::Owned(m) => {
139                MatRef::from_column_major_slice(m.data.as_slice(), m.nrows, m.ncols)
140            },
141            Matrix::File(_) => panic!("cannot call this function on a file"),
142            Matrix::Dyn(m) => m.as_mat_ref_loaded(),
143            Matrix::Transform(..) => panic!("cannot call this function on a transform"),
144        }
145    }
146
147    #[cfg_attr(coverage_nightly, coverage(off))]
148    pub fn as_mat_mut(&mut self) -> Result<MatMut<'_, f64>, crate::Error> {
149        Ok(match self {
150            // SAFETY: We know that the data is valid
151            #[cfg(feature = "r")]
152            Matrix::R(m) => unsafe {
153                MatMut::from_raw_parts_mut(
154                    m.data().as_ptr().cast_mut(),
155                    m.nrows(),
156                    m.ncols(),
157                    1,
158                    m.nrows() as isize,
159                )
160            },
161            Matrix::Owned(m) => {
162                MatMut::from_column_major_slice_mut(m.data.as_mut(), m.nrows, m.ncols)
163            },
164            m @ (Matrix::File(_) | Matrix::Transform(..)) => m.into_owned()?.as_mat_mut()?,
165            Matrix::Dyn(m) => m.as_mat_mut()?,
166        })
167    }
168
169    #[cfg_attr(coverage_nightly, coverage(off))]
170    pub fn as_owned_ref(&mut self) -> Result<&OwnedMatrix, crate::Error> {
171        self.into_owned()?;
172        match self {
173            Matrix::Owned(m) => Ok(&*m),
174            Matrix::Dyn(m) => m.as_owned_ref(),
175            _ => unreachable!(),
176        }
177    }
178
179    #[cfg_attr(coverage_nightly, coverage(off))]
180    pub fn as_owned_mut(&mut self) -> Result<&mut OwnedMatrix, crate::Error> {
181        self.into_owned()?;
182        match self {
183            Matrix::Owned(m) => Ok(m),
184            Matrix::Dyn(m) => m.as_owned_mut(),
185            _ => unreachable!(),
186        }
187    }
188
189    #[cfg_attr(coverage_nightly, coverage(off))]
190    pub fn is_loaded(&self) -> bool {
191        match self {
192            #[cfg(feature = "r")]
193            Matrix::R(_) => true,
194            Matrix::Owned(_) => true,
195            Matrix::File(_) | Matrix::Transform(..) => false,
196            Matrix::Dyn(m) => m.is_loaded(),
197        }
198    }
199
200    #[cfg_attr(coverage_nightly, coverage(off))]
201    #[cfg(feature = "r")]
202    pub fn to_rmatrix(&mut self) -> Result<RMatrix<f64>, crate::Error> {
203        Ok(match self {
204            #[cfg(feature = "r")]
205            Matrix::R(m) => m.clone().into_robj().clone().as_matrix().unwrap(),
206            Matrix::Owned(m) => {
207                use extendr_api::prelude::*;
208
209                let mut mat = RMatrix::new_matrix(
210                    m.nrows,
211                    m.ncols,
212                    #[inline(always)]
213                    |r, c| m.data[c * m.nrows + r],
214                );
215                let mut dimnames = List::from_values([NULL, NULL]);
216                if let Some(colnames) = &m.colnames {
217                    dimnames.set_elt(1, colnames.into_robj()).unwrap();
218                }
219                mat.set_attrib(wrapper::symbol::dimnames_symbol(), dimnames)
220                    .unwrap();
221                mat
222            },
223            m @ (Matrix::File(_) | Matrix::Transform(..)) => m.into_owned()?.to_rmatrix()?,
224            Matrix::Dyn(m) => m.to_rmatrix()?,
225        })
226    }
227
228    #[cfg_attr(coverage_nightly, coverage(off))]
229    #[cfg(feature = "r")]
230    pub fn into_robj(&mut self) -> Result<Robj, crate::Error> {
231        Ok(self.to_rmatrix().into_robj())
232    }
233
234    #[cfg_attr(coverage_nightly, coverage(off))]
235    #[cfg(feature = "r")]
236    pub fn from_rdata(mut reader: impl std::io::Read) -> Result<Self, crate::Error> {
237        let mut buf = [0; 5];
238        reader.read_exact(&mut buf)?;
239        if buf != *b"RDX3\n" {
240            return Err(crate::Error::InvalidRdataFile);
241        }
242        let obj = Robj::from_reader(
243            &mut reader,
244            extendr_api::io::PstreamFormat::R_pstream_xdr_format,
245            None,
246        )?;
247        let mat = obj
248            .as_pairlist()
249            .ok_or(crate::Error::InvalidRdataFile)?
250            .into_iter()
251            .next()
252            .ok_or(crate::Error::InvalidRdataFile)?
253            .1;
254        Matrix::from_robj(mat)
255    }
256
257    #[cfg_attr(coverage_nightly, coverage(off))]
258    #[cfg(feature = "r")]
259    pub fn from_rds(mut reader: impl std::io::Read) -> Result<Self, crate::Error> {
260        let obj = Robj::from_reader(
261            &mut reader,
262            extendr_api::io::PstreamFormat::R_pstream_xdr_format,
263            None,
264        )?;
265        Matrix::from_robj(obj)
266    }
267
268    #[cfg_attr(coverage_nightly, coverage(off))]
269    #[cfg(feature = "r")]
270    pub fn from_robj(r: Robj) -> Result<Self, crate::Error> {
271        fn r_int_to_f64(r: i32) -> f64 {
272            if r == i32::MIN {
273                f64::NAN
274            } else {
275                r as f64
276            }
277        }
278
279        if r.is_matrix() {
280            let float = RMatrix::<f64>::try_from(r);
281            match float {
282                Ok(float) => Ok(float.into()),
283                Err(extendr_api::Error::TypeMismatch(r)) => {
284                    Ok(RMatrix::<i32>::try_from(r)?.into_matrix())
285                },
286                Err(e) => Err(e.into()),
287            }
288        } else if r.is_string() {
289            Ok(File::from_str(r.as_str().expect("i is a string"))?.into())
290        } else if r.is_integer() {
291            let v = r
292                .as_integer_slice()
293                .expect("data is an integer vector")
294                .iter()
295                .map(|i| r_int_to_f64(*i))
296                .collect::<Vec<_>>();
297            Ok(Matrix::Owned(OwnedMatrix::new(v.len(), 1, v, None)))
298        } else if r.is_real() {
299            let v = r.as_real_vector().expect("data is a real vector");
300            Ok(Matrix::Owned(OwnedMatrix::new(v.len(), 1, v, None)))
301        } else if r.is_list() {
302            if r.class()
303                .map(|x| x.into_iter().any(|c| c == "data.frame"))
304                .unwrap_or(false)
305            {
306                use extendr_api::prelude::*;
307
308                let df = r.as_list().expect("data is a list");
309                struct Par(pub Robj);
310                unsafe impl Send for Par {}
311                let mut names = df.iter().map(|(n, r)| (n, Par(r))).collect::<Vec<_>>();
312                let (names, data) = names.into_iter().unzip::<_, _, Vec<_>, Vec<_>>();
313                let data = data
314                    .into_par_iter()
315                    .map(|Par(r)| {
316                        if r.is_string() {
317                            Ok(r.as_str_iter()
318                                .unwrap()
319                                .map(|x| x.parse::<f64>())
320                                .collect::<std::result::Result<Vec<_>, _>>()?)
321                        } else if r.is_integer() {
322                            Ok(r.as_integer_slice()
323                                .unwrap()
324                                .iter()
325                                .map(|x| r_int_to_f64(*x))
326                                .collect())
327                        } else if r.is_real() {
328                            Ok(r.as_real_vector().unwrap().to_vec())
329                        } else if r.is_logical() {
330                            Ok(r.as_logical_slice()
331                                .unwrap()
332                                .iter()
333                                .map(|x| r_int_to_f64(x.inner()))
334                                .collect())
335                        } else {
336                            Err(crate::Error::InvalidItemType)
337                        }
338                    })
339                    .collect::<std::result::Result<Vec<_>, crate::Error>>()?;
340                let ncols = data.len();
341                let nrows = data[0].len();
342                for i in data.iter().skip(1) {
343                    if i.len() != nrows {
344                        return Err(crate::Error::UnequalColumnLengths);
345                    }
346                }
347                Ok(Matrix::Owned(OwnedMatrix::new(
348                    nrows,
349                    ncols,
350                    data.concat(),
351                    Some(names.into_iter().map(|x| x.to_string()).collect()),
352                )))
353            } else {
354                Err(crate::Error::InvalidItemType)
355            }
356        } else {
357            Err(crate::Error::InvalidItemType)
358        }
359    }
360
361    #[cfg_attr(coverage_nightly, coverage(off))]
362    #[doc(hidden)]
363    pub fn to_owned_loaded(self) -> OwnedMatrix {
364        if let Matrix::Owned(m) = self {
365            m
366        } else {
367            panic!("matrix is not owned");
368        }
369    }
370
371    #[tracing::instrument(skip(self))]
372    #[cfg_attr(coverage_nightly, coverage(off))]
373    pub fn into_owned(&mut self) -> Result<&mut Self, crate::Error> {
374        match self {
375            #[cfg(feature = "r")]
376            Matrix::R(_) => {
377                let colnames = self
378                    .colnames()?
379                    .map(|x| x.into_iter().map(|x| x.to_string()).collect());
380                let Matrix::R(m) = self else { unreachable!() };
381                *self = Matrix::Owned(OwnedMatrix::new(
382                    m.nrows(),
383                    m.ncols(),
384                    m.data().to_vec(),
385                    colnames,
386                ));
387                Ok(self)
388            },
389            Matrix::Owned(_) => Ok(self),
390            Matrix::File(m) => {
391                *self = m.read()?;
392                Ok(self)
393            },
394            Matrix::Dyn(m) => {
395                m.into_owned()?;
396                Ok(self)
397            },
398            Matrix::Transform(..) => self.transform(),
399        }
400    }
401
402    #[tracing::instrument(skip(self))]
403    #[cfg_attr(coverage_nightly, coverage(off))]
404    pub fn to_owned(&mut self) -> Result<Matrix, crate::Error> {
405        match self {
406            #[cfg(feature = "r")]
407            Matrix::R(_) => {
408                let colnames = self
409                    .colnames()?
410                    .map(|x| x.into_iter().map(|x| x.to_string()).collect());
411                let Matrix::R(m) = self else { unreachable!() };
412                Ok(Matrix::Owned(OwnedMatrix::new(
413                    m.nrows(),
414                    m.ncols(),
415                    m.data().to_vec(),
416                    colnames,
417                )))
418            },
419            Matrix::Owned(mat) => Ok(Matrix::Owned(mat.clone())),
420            Matrix::File(m) => Ok(m.read()?),
421            Matrix::Dyn(m) => m.to_owned(),
422            Matrix::Transform(fns, mat) => {
423                let mut mat = Matrix::Transform(fns.clone(), Box::new(mat.to_owned()?));
424                mat.into_owned()?;
425                Ok(Matrix::Owned(mat.to_owned_loaded()))
426            },
427        }
428    }
429
430    #[cfg_attr(coverage_nightly, coverage(off))]
431    pub fn colnames(&mut self) -> Result<Option<Vec<&str>>, crate::Error> {
432        Ok(match self {
433            #[cfg(feature = "r")]
434            Matrix::R(m) => m.dimnames().and_then(|mut dimnames| {
435                dimnames
436                    .nth(1)
437                    .unwrap()
438                    .as_str_iter()
439                    .map(|x| x.collect::<Vec<_>>())
440            }),
441            Matrix::Owned(m) => m
442                .colnames
443                .as_deref()
444                .map(|x| x.iter().map(|x| x.as_str()).collect()),
445            m @ (Matrix::File(_) | Matrix::Transform(..)) => m.into_owned()?.colnames()?,
446            Matrix::Dyn(_) => None,
447        })
448    }
449
450    #[cfg_attr(coverage_nightly, coverage(off))]
451    pub fn colnames_loaded(&self) -> Option<Vec<&str>> {
452        match self {
453            #[cfg(feature = "r")]
454            Matrix::R(m) => m.dimnames().and_then(|mut dimnames| {
455                dimnames
456                    .nth(1)
457                    .unwrap()
458                    .as_str_iter()
459                    .map(|x| x.collect::<Vec<_>>())
460            }),
461            Matrix::Owned(m) => m
462                .colnames
463                .as_deref()
464                .map(|x| x.iter().map(|x| x.as_str()).collect()),
465            Matrix::File(_) => None,
466            Matrix::Dyn(_) => None,
467            Matrix::Transform(..) => None,
468        }
469    }
470
471    #[cfg_attr(coverage_nightly, coverage(off))]
472    pub fn set_colnames(&mut self, colnames: Vec<String>) -> Result<&mut Self, crate::Error> {
473        if colnames.len() != self.ncols()? {
474            return Err(crate::Error::ColumnNamesMismatch);
475        }
476        match self {
477            #[cfg(feature = "r")]
478            Matrix::R(m) => {
479                use extendr_api::prelude::*;
480
481                let mut dimnames = List::from_values([NULL, NULL]);
482                dimnames.set_elt(1, colnames.into_robj()).unwrap();
483                m.set_attrib(wrapper::symbol::dimnames_symbol(), dimnames)
484                    .unwrap();
485                Ok(self)
486            },
487            Matrix::Owned(m) => {
488                m.colnames = Some(colnames);
489                Ok(self)
490            },
491            m @ (Matrix::File(_) | Matrix::Transform(..)) => m.into_owned()?.set_colnames(colnames),
492            Matrix::Dyn(m) => m.set_colnames(colnames),
493        }
494    }
495
496    #[cfg_attr(coverage_nightly, coverage(off))]
497    pub fn from_slice(data: &[f64], rows: usize, cols: usize) -> Self {
498        Matrix::Owned(OwnedMatrix::new(rows, cols, data.to_vec(), None))
499    }
500
501    #[cfg_attr(coverage_nightly, coverage(off))]
502    #[cfg(feature = "r")]
503    pub fn from_rmatrix(r: RMatrix<f64>) -> Self {
504        Matrix::R(r)
505    }
506
507    #[cfg_attr(coverage_nightly, coverage(off))]
508    pub fn from_owned(m: OwnedMatrix) -> Self {
509        Matrix::Owned(m)
510    }
511
512    #[cfg_attr(coverage_nightly, coverage(off))]
513    pub fn from_file(f: File) -> Self {
514        Matrix::File(f)
515    }
516
517    #[cfg_attr(coverage_nightly, coverage(off))]
518    pub fn from_deref(m: impl DerefMatrix + 'static) -> Matrix {
519        Matrix::Dyn(Box::new(m))
520    }
521
522    #[cfg_attr(coverage_nightly, coverage(off))]
523    pub fn from_mat_ref(m: faer::MatRef<'_, f64>) -> Self {
524        let data = vec![MaybeUninit::<f64>::uninit(); m.nrows() * m.ncols()];
525        m.par_col_chunks(1).enumerate().for_each(|(i, c)| {
526            let col = c.col(0);
527            let slice = unsafe {
528                std::slice::from_raw_parts_mut(
529                    data.as_ptr().add(i * m.nrows()).cast::<f64>().cast_mut(),
530                    m.nrows(),
531                )
532            };
533            for (i, x) in col.iter().enumerate() {
534                slice[i] = *x;
535            }
536        });
537        Matrix::Owned(OwnedMatrix::new(
538            m.nrows(),
539            m.ncols(),
540            // SAFETY: The data is initialized now
541            unsafe {
542                std::mem::transmute::<std::vec::Vec<std::mem::MaybeUninit<f64>>, std::vec::Vec<f64>>(
543                    data,
544                )
545            },
546            None,
547        ))
548    }
549
550    pub fn generate_normal_matrix(rows: usize, cols: usize, mean: f64, std_dev: f64) -> Self {
551        let data = rand_distr::Normal::new(mean, std_dev)
552            .unwrap()
553            .sample_iter(rand::thread_rng())
554            .take(rows * cols)
555            .collect::<Vec<_>>();
556        Matrix::Owned(OwnedMatrix::new(rows, cols, data, None))
557    }
558
559    pub fn generate_standard_normal_matrix(rows: usize, cols: usize) -> Self {
560        Self::generate_normal_matrix(rows, cols, 0.0, 1.0)
561    }
562
563    pub fn eigen(&mut self, symmetric: Option<bool>) -> Result<Eigen, crate::Error> {
564        Eigen::new(self, symmetric)
565    }
566
567    pub fn is_symmetric(&mut self) -> Result<bool, crate::Error> {
568        let m = self.as_mat_ref()?;
569        if m.nrows() != m.ncols() {
570            return Err(crate::Error::MatrixDimensionsMismatch);
571        }
572        let mut is_symmetric = true;
573        'outer: for i in 1..m.nrows() {
574            for j in 0..i {
575                if (m.get(i, j) - m.get(j, i)).abs() >= 100.0 * f64::EPSILON {
576                    is_symmetric = false;
577                    break 'outer;
578                }
579            }
580        }
581        Ok(is_symmetric)
582    }
583}
584
/// Eigenvalues and eigenvectors produced by [`Eigen::new`].
///
/// For an `n x n` input, `vectors` holds `n * n` entries in column-major
/// order: the eigenvector paired with `values[j]` is
/// `vectors[j * n..(j + 1) * n]`.
pub enum Eigen {
    /// Every eigenvalue and eigenvector component is real.
    Real { values: Vec<f64>, vectors: Vec<f64> },
    /// At least one eigenvalue or eigenvector component has a nonzero
    /// imaginary part.
    Complex { values: Vec<c64>, vectors: Vec<c64> },
}
589
590impl Eigen {
591    pub fn new(m: &mut Matrix, symmetric: Option<bool>) -> Result<Self, crate::Error> {
592        enum E {
593            Generic(faer::linalg::solvers::Eigen<f64>),
594            SelfAdjoint(faer::linalg::solvers::SelfAdjointEigen<f64>),
595        }
596        let symmetric = match symmetric {
597            Some(s) => s,
598            None => m.is_symmetric()?,
599        };
600        let m = m.as_mat_ref()?;
601        if m.nrows() != m.ncols() {
602            return Err(crate::Error::MatrixDimensionsMismatch);
603        }
604        if symmetric {
605            let eigen = m.self_adjoint_eigen(faer::Side::Lower)?;
606            let s = eigen.S();
607            let u = eigen.U();
608            let mut values = Vec::with_capacity(m.nrows());
609            for i in 0..m.nrows() {
610                values.push(s[i]);
611            }
612            let mut zero = 0.0;
613            let mut vectors: Vec<f64> = vec![zero; m.nrows() * m.nrows()];
614            u.par_col_chunks(1).enumerate().for_each(|(i, c)| {
615                let col = c.col(0);
616                let mut vector = unsafe {
617                    std::slice::from_raw_parts_mut(
618                        vectors.as_ptr().cast_mut().add(i * m.nrows()),
619                        m.nrows(),
620                    )
621                };
622                for (j, x) in col.iter().enumerate() {
623                    vector[j] = *x;
624                }
625            });
626            Ok(Eigen::Real { values, vectors })
627        } else {
628            let eigen = m.eigen()?;
629            let s = eigen.S();
630            let u = eigen.U();
631            let complex = (0..m.nrows()).into_par_iter().any(|i| {
632                if s[i].im != 0.0 {
633                    // is complex
634                    return true;
635                }
636                let col = u.col(i);
637                for j in 0..m.nrows() {
638                    if col[j].im != 0.0 {
639                        return true;
640                    }
641                }
642                false
643            });
644            if complex {
645                let mut values = Vec::with_capacity(m.nrows());
646                for i in 0..m.nrows() {
647                    values.push(s[i]);
648                }
649                let mut zero = c64::new(0.0, 0.0);
650                let mut vectors: Vec<c64> = vec![zero; m.nrows() * m.nrows()];
651                u.par_col_chunks(1).enumerate().for_each(|(i, c)| {
652                    let col = c.col(0);
653                    let mut vector = unsafe {
654                        std::slice::from_raw_parts_mut(
655                            vectors.as_ptr().cast_mut().add(i * m.nrows()),
656                            m.nrows(),
657                        )
658                    };
659                    for (j, x) in col.iter().enumerate() {
660                        vector[j] = *x;
661                    }
662                });
663                Ok(Eigen::Complex { values, vectors })
664            } else {
665                let mut values = Vec::with_capacity(m.nrows());
666                for i in 0..m.nrows() {
667                    values.push(s[i].re);
668                }
669                let mut zero = 0.0;
670                let mut vectors: Vec<f64> = vec![zero; m.nrows() * m.nrows()];
671                u.par_col_chunks(1).enumerate().for_each(|(i, c)| {
672                    let col = c.col(0);
673                    let mut vector = unsafe {
674                        std::slice::from_raw_parts_mut(
675                            vectors.as_ptr().cast_mut().add(i * m.nrows()),
676                            m.nrows(),
677                        )
678                    };
679                    for (j, x) in col.iter().enumerate() {
680                        vector[j] = x.re;
681                    }
682                });
683                Ok(Eigen::Real { values, vectors })
684            }
685        }
686    }
687}
688
689impl Matrix {
690    pub fn transform(&mut self) -> Result<&mut Self, crate::Error> {
691        if let Matrix::Transform(fns, mat) = self {
692            let slf = std::mem::replace(self, Matrix::Owned(OwnedMatrix::new(0, 0, vec![], None)));
693            if let Matrix::Transform(fns, mat) = slf {
694                let mut mat = *mat;
695                for f in fns {
696                    f(&mut mat)?;
697                }
698                *self = mat;
699            }
700        }
701        Ok(self)
702    }
703
704    fn add_transformation(
705        &mut self,
706        f: impl for<'a> Fn(&'a mut Matrix) -> Result<&'a mut Matrix, Error> + 'static,
707    ) -> &mut Self {
708        match self {
709            Matrix::Transform(fns, _) => fns.push(Arc::new(f)),
710            _ => {
711                let m =
712                    std::mem::replace(self, Matrix::Owned(OwnedMatrix::new(0, 0, vec![], None)));
713                *self = Matrix::Transform(vec![], Box::new(m));
714                self.add_transformation(f);
715            },
716        }
717        self
718    }
719
720    pub fn t_combine_columns(&mut self, mut others: Vec<Self>) -> &mut Self {
721        self.add_transformation(move |m| {
722            let others =
723                unsafe { std::mem::transmute::<&[Self], &mut [Matrix]>(others.as_slice()) };
724            m.combine_columns(others)
725        })
726    }
727
728    #[tracing::instrument(skip(self, others))]
729    pub fn combine_columns(&mut self, others: &mut [Self]) -> Result<&mut Self, crate::Error> {
730        if others.is_empty() {
731            return Ok(self);
732        }
733
734        // Only retain the column names if they are present in all matrices
735        let colnames = if self.colnames()?.is_some()
736            && others
737                .iter_mut()
738                .map(|i| i.colnames())
739                .collect::<Result<Vec<Option<Vec<&str>>>, _>>()?
740                .into_iter()
741                .all(|i| i.is_some())
742        {
743            let mut colnames = self
744                .colnames()?
745                .unwrap()
746                .into_iter()
747                .map(|x| x.to_string())
748                .collect::<Vec<_>>();
749            for i in others.iter_mut() {
750                colnames.extend(i.colnames()?.unwrap().iter().map(|x| x.to_string()));
751            }
752            Some(colnames)
753        } else {
754            None
755        };
756
757        // Ensure that all the matrices are properly sized
758        let others = others
759            .iter_mut()
760            .map(|x| x.as_mat_ref())
761            .collect::<Result<Vec<_>, _>>()?;
762        let nrows = self.nrows()?;
763        if others.iter().any(|i| i.nrows() != nrows) {
764            return Err(crate::Error::MatrixDimensionsMismatch);
765        }
766        let ncols = self.ncols()? + others.iter().map(|i| i.ncols()).sum::<usize>();
767        let data = vec![MaybeUninit::<f64>::uninit(); nrows * ncols];
768        debug!("nrows: {}, ncols: {}", nrows, ncols);
769
770        // Combine the matrices
771        let mats = [&[self.as_mat_ref_loaded()], others.as_slice()].concat();
772        mats.par_iter().enumerate().for_each(|(i, m)| {
773            let cols_before = mats.iter().take(i).map(|m| m.ncols()).sum::<usize>();
774            // SAFETY: No two threads will write to the same location
775            (0..m.ncols()).into_par_iter().for_each(|c| unsafe {
776                let src = m
777                    .get_unchecked(.., c)
778                    .try_as_col_major()
779                    .expect("could not get slice")
780                    .as_slice();
781                let dst = data
782                    .as_ptr()
783                    .add(m.nrows() * (c + cols_before))
784                    .cast::<f64>()
785                    .cast_mut();
786                let slice = std::slice::from_raw_parts_mut(dst, m.nrows());
787                slice.copy_from_slice(src);
788            });
789        });
790        *self = Matrix::Owned(OwnedMatrix::new(
791            nrows,
792            ncols,
793            // SAFETY: The data is initialized now
794            unsafe {
795                std::mem::transmute::<std::vec::Vec<std::mem::MaybeUninit<f64>>, std::vec::Vec<f64>>(
796                    data,
797                )
798            },
799            colnames,
800        ));
801        Ok(self)
802    }
803
804    pub fn t_combine_rows(&mut self, mut others: Vec<Self>) -> &mut Self {
805        self.add_transformation(move |m| {
806            let others = unsafe { std::mem::transmute::<&[Self], &mut [Self]>(others.as_slice()) };
807            m.combine_rows(others)
808        })
809    }
810
811    #[tracing::instrument(skip(self, others))]
812    pub fn combine_rows(&mut self, others: &mut [Self]) -> Result<&mut Self, crate::Error> {
813        if others.is_empty() {
814            return Ok(self);
815        }
816        let colnames = self.colnames()?;
817        if others
818            .iter_mut()
819            .map(|i| i.colnames())
820            .collect::<Result<Vec<_>, _>>()?
821            .iter()
822            .any(|i| *i != colnames)
823        {
824            return Err(crate::Error::ColumnNamesMismatch);
825        }
826        let colnames = colnames.map(|x| x.iter().map(|x| x.to_string()).collect());
827        let ncols = self.ncols()?;
828        let others = others
829            .iter_mut()
830            .map(|x| x.as_mat_ref())
831            .collect::<Result<Vec<_>, _>>()?;
832        if others.iter().any(|i| i.ncols() != ncols) {
833            return Err(crate::Error::MatrixDimensionsMismatch);
834        }
835        let nrows = self.nrows()? + others.iter().map(|i| i.nrows()).sum::<usize>();
836        debug!("nrows: {}, ncols: {}", nrows, ncols);
837        let mats = [&[self.as_mat_ref_loaded()], others.as_slice()].concat();
838        let data = vec![MaybeUninit::<f64>::uninit(); nrows * ncols];
839        mats.par_iter().enumerate().for_each(|(i, m)| {
840            let rows_before = mats.iter().take(i).map(|m| m.nrows()).sum::<usize>();
841            (0..ncols).into_par_iter().for_each(|c| unsafe {
842                let src = m
843                    .get_unchecked(.., c)
844                    .try_as_col_major()
845                    .expect("could not get slice")
846                    .as_slice();
847                debug!("{i} {c} src: {:?}", src);
848                let dst = data
849                    .as_ptr()
850                    .add(nrows * c + rows_before)
851                    .cast::<f64>()
852                    .cast_mut();
853                debug!("{i} {c} dst: {:?}", dst);
854                let slice = std::slice::from_raw_parts_mut(dst, m.nrows());
855                debug!("{i} {c} slice: {:?}", slice);
856                slice.copy_from_slice(src);
857                debug!("{i} {c} slice: {:?}", slice);
858            });
859        });
860        *self = Matrix::Owned(OwnedMatrix::new(
861            nrows,
862            ncols,
863            // SAFETY: The data is initialized now
864            unsafe {
865                std::mem::transmute::<std::vec::Vec<std::mem::MaybeUninit<f64>>, std::vec::Vec<f64>>(
866                    data,
867                )
868            },
869            colnames,
870        ));
871        Ok(self)
872    }
873
874    pub fn t_remove_rows(&mut self, removing: HashSet<usize>) -> &mut Self {
875        self.add_transformation(move |m| m.remove_rows(&removing))
876    }
877
878    #[tracing::instrument(skip(self, removing))]
879    pub fn remove_rows(&mut self, removing: &HashSet<usize>) -> Result<&mut Self, crate::Error> {
880        if removing.is_empty() {
881            return Ok(self);
882        }
883        for i in removing.iter() {
884            if *i >= self.nrows()? {
885                return Err(crate::Error::RowIndexOutOfBounds(*i));
886            }
887        }
888        let new_nrows = self.nrows()? - removing.len();
889        let m = self.as_mat_ref()?;
890        let data = vec![MaybeUninit::<f64>::uninit(); new_nrows * m.ncols()];
891        (0..m.ncols()).into_par_iter().for_each(|c| {
892            let mut j = 0;
893            for i in 0..m.nrows() {
894                if removing.contains(&i) {
895                    continue;
896                }
897                let src = m.get(i, c);
898                unsafe {
899                    let dst = data
900                        .as_ptr()
901                        .add(c * new_nrows + j)
902                        .cast::<f64>()
903                        .cast_mut();
904                    dst.write(*src);
905                }
906                j += 1;
907            }
908        });
909        *self = Matrix::Owned(OwnedMatrix::new(
910            new_nrows,
911            self.ncols()?,
912            // SAFETY: The data is initialized now
913            unsafe {
914                std::mem::transmute::<std::vec::Vec<std::mem::MaybeUninit<f64>>, std::vec::Vec<f64>>(
915                    data,
916                )
917            },
918            self.colnames()?
919                .map(|x| x.iter().map(|x| x.to_string()).collect()),
920        ));
921        Ok(self)
922    }
923
924    pub fn t_remove_columns(&mut self, removing: HashSet<usize>) -> &mut Self {
925        self.add_transformation(move |m| m.remove_columns(&removing))
926    }
927
928    #[tracing::instrument(skip(self, removing))]
929    pub fn remove_columns(&mut self, removing: &HashSet<usize>) -> Result<&mut Self, crate::Error> {
930        if removing.is_empty() {
931            return Ok(self);
932        }
933        let m = self.as_mat_ref()?;
934        for i in removing.iter() {
935            if *i >= m.ncols() {
936                return Err(crate::Error::ColumnIndexOutOfBounds(*i));
937            }
938        }
939        let new_ncols = m.ncols() - removing.len();
940        let data = vec![MaybeUninit::<f64>::uninit(); m.nrows() * new_ncols];
941        debug!("nrows: {}, ncols: {}", m.nrows(), new_ncols);
942        (0..m.ncols())
943            .filter(|x| !removing.contains(x))
944            .collect::<Vec<_>>()
945            .into_par_iter()
946            .enumerate()
947            // SAFETY: No two threads will write to the same location
948            .for_each(|(n, o)| unsafe {
949                let src = m.get_unchecked(.., o).try_as_col_major().expect("could not get slice").as_slice();
950                let dst = data.as_ptr().add(n * m.nrows()).cast::<f64>().cast_mut();
951                let slice = std::slice::from_raw_parts_mut(dst, m.nrows());
952                slice.copy_from_slice(src);
953            });
954        *self = Matrix::Owned(OwnedMatrix::new(
955            m.nrows(),
956            new_ncols,
957            // SAFETY: The data is initialized now
958            unsafe {
959                std::mem::transmute::<std::vec::Vec<std::mem::MaybeUninit<f64>>, std::vec::Vec<f64>>(
960                    data,
961                )
962            },
963            self.colnames()?.map(|x| {
964                x.iter()
965                    .enumerate()
966                    .filter_map(|(i, x)| {
967                        if removing.contains(&i) {
968                            None
969                        } else {
970                            Some(x.to_string())
971                        }
972                    })
973                    .collect()
974            }),
975        ));
976        Ok(self)
977    }
978
979    pub fn t_remove_column_by_name(&mut self, name: &str) -> &mut Self {
980        let name = name.to_string();
981        self.add_transformation(move |m| m.remove_column_by_name(&name))
982    }
983
984    #[tracing::instrument(skip(self))]
985    pub fn remove_column_by_name(&mut self, name: &str) -> Result<&mut Self, crate::Error> {
986        let colnames = self.colnames()?;
987        if colnames.is_none() {
988            return Err(crate::Error::MissingColumnNames);
989        }
990        let exists = colnames
991            .expect("colnames should be present")
992            .iter()
993            .position(|x| *x == name);
994        if let Some(i) = exists {
995            self.remove_columns(&[i].iter().copied().collect())?;
996        } else {
997            return Err(crate::Error::ColumnNameNotFound(name.to_string()));
998        }
999        Ok(self)
1000    }
1001
1002    pub fn t_remove_columns_by_name(&mut self, names: HashSet<String>) -> &mut Self {
1003        self.add_transformation(move |m| m.remove_columns_by_name(&names))
1004    }
1005
1006    #[tracing::instrument(skip(self, names))]
1007    pub fn remove_columns_by_name(
1008        &mut self,
1009        names: &HashSet<String>,
1010    ) -> Result<&mut Self, crate::Error> {
1011        let names = HashSet::<&str>::from_iter(names.iter().map(|x| x.as_str()));
1012        let colnames = self.colnames()?;
1013        if colnames.is_none() {
1014            return Err(crate::Error::MissingColumnNames);
1015        }
1016        let colnames = colnames.expect("colnames should be present");
1017        let removing = colnames
1018            .iter()
1019            .enumerate()
1020            .filter_map(|(i, x)| if names.contains(x) { Some(i) } else { None })
1021            .collect();
1022        self.remove_columns(&removing)
1023    }
1024
1025    pub fn t_remove_column_by_name_if_exists(&mut self, name: &str) -> &mut Self {
1026        let name = name.to_string();
1027        self.add_transformation(move |m| m.remove_column_by_name_if_exists(&name))
1028    }
1029
1030    #[tracing::instrument(skip(self))]
1031    pub fn remove_column_by_name_if_exists(
1032        &mut self,
1033        name: &str,
1034    ) -> Result<&mut Self, crate::Error> {
1035        match self.remove_column_by_name(name) {
1036            Ok(_) => Ok(self),
1037            Err(crate::Error::ColumnNameNotFound(_) | crate::Error::MissingColumnNames) => Ok(self),
1038            Err(e) => Err(e),
1039        }
1040    }
1041
1042    pub fn t_transpose(&mut self) -> &mut Self {
1043        self.add_transformation(|m| m.transpose())
1044    }
1045
1046    #[tracing::instrument(skip(self))]
1047    pub fn transpose(&mut self) -> Result<&mut Self, crate::Error> {
1048        let m = self.as_mat_ref()?;
1049        let new_data = vec![MaybeUninit::<f64>::uninit(); m.nrows() * m.ncols()];
1050        m.par_col_chunks(1).enumerate().for_each(|(new_row, c)| {
1051            c.col(0).iter().enumerate().for_each(|(new_col, x)| {
1052                let i = new_col * m.ncols() + new_row;
1053                unsafe {
1054                    new_data.as_ptr().add(i).cast_mut().cast::<f64>().write(*x);
1055                };
1056            });
1057        });
1058        *self = Matrix::Owned(OwnedMatrix::new(
1059            m.ncols(),
1060            m.nrows(),
1061            // SAFETY: The data is initialized now
1062            unsafe {
1063                std::mem::transmute::<std::vec::Vec<std::mem::MaybeUninit<f64>>, std::vec::Vec<f64>>(
1064                    new_data,
1065                )
1066            },
1067            None,
1068        ));
1069        Ok(self)
1070    }
1071
1072    pub fn t_sort_by_column(&mut self, by: usize) -> &mut Self {
1073        self.add_transformation(move |m| m.sort_by_column(by))
1074    }
1075
1076    #[tracing::instrument(skip(self))]
1077    pub fn sort_by_column(&mut self, by: usize) -> Result<&mut Self, crate::Error> {
1078        if by >= self.ncols()? {
1079            return Err(crate::Error::ColumnIndexOutOfBounds(by));
1080        }
1081        let col = self.col(by)?.unwrap();
1082        let mut order = col.iter().copied().enumerate().collect::<Vec<_>>();
1083        order.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("could not compare"));
1084        self.sort_by_order(
1085            order
1086                .into_iter()
1087                .map(|(i, _)| i)
1088                .collect::<Vec<_>>()
1089                .as_slice(),
1090        )
1091    }
1092
1093    pub fn t_sort_by_column_name(&mut self, by: &str) -> &mut Self {
1094        let by = by.to_string();
1095        self.add_transformation(move |m| m.sort_by_column_name(&by))
1096    }
1097
1098    #[tracing::instrument(skip(self))]
1099    pub fn sort_by_column_name(&mut self, by: &str) -> Result<&mut Self, crate::Error> {
1100        let colnames = self.colnames()?;
1101        if colnames.is_none() {
1102            return Err(crate::Error::MissingColumnNames);
1103        }
1104        let by_col_idx = colnames
1105            .expect("colnames should be present")
1106            .iter()
1107            .position(|x| *x == by);
1108        if let Some(i) = by_col_idx {
1109            self.sort_by_column(i)?;
1110            Ok(self)
1111        } else {
1112            Err(crate::Error::ColumnNameNotFound(by.to_string()))
1113        }
1114    }
1115
1116    pub fn t_sort_by_order(&mut self, order: Vec<usize>) -> &mut Self {
1117        self.add_transformation(move |m| m.sort_by_order(&order))
1118    }
1119
1120    #[tracing::instrument(skip(self, order))]
1121    pub fn sort_by_order(&mut self, order: &[usize]) -> Result<&mut Self, crate::Error> {
1122        if order.len() != self.nrows()? {
1123            return Err(crate::Error::OrderLengthMismatch(order.len()));
1124        }
1125        for i in order.iter() {
1126            if *i >= self.nrows()? {
1127                return Err(crate::Error::RowIndexOutOfBounds(*i));
1128            }
1129        }
1130        let data = vec![MaybeUninit::<f64>::uninit(); self.as_mat_ref()?.nrows() * self.ncols()?];
1131        let m = self.as_mat_ref()?;
1132        m.par_col_chunks(1).enumerate().for_each(|(i, c)| {
1133            let col = c.col(0);
1134            let slice = unsafe {
1135                std::slice::from_raw_parts_mut(
1136                    data.as_ptr().add(i * m.nrows()).cast::<f64>().cast_mut(),
1137                    m.nrows(),
1138                )
1139            };
1140            for (i, o) in order.iter().enumerate() {
1141                slice[i] = col[*o];
1142            }
1143        });
1144        *self = Matrix::Owned(OwnedMatrix::new(
1145            m.nrows(),
1146            m.ncols(),
1147            // SAFETY: The data is initialized now
1148            unsafe {
1149                std::mem::transmute::<std::vec::Vec<std::mem::MaybeUninit<f64>>, std::vec::Vec<f64>>(
1150                    data,
1151                )
1152            },
1153            self.colnames()?
1154                .map(|x| x.iter().map(|x| x.to_string()).collect()),
1155        ));
1156        Ok(self)
1157    }
1158
1159    pub fn t_dedup_by_column(&mut self, by: usize) -> &mut Self {
1160        self.add_transformation(move |m| m.dedup_by_column(by))
1161    }
1162
1163    #[tracing::instrument(skip(self))]
1164    pub fn dedup_by_column(&mut self, by: usize) -> Result<&mut Self, crate::Error> {
1165        if by >= self.ncols()? {
1166            return Err(crate::Error::ColumnIndexOutOfBounds(by));
1167        }
1168        let mut col = self.col(by)?.unwrap().to_vec();
1169        col.sort_by(|a, b| a.partial_cmp(b).expect("could not compare"));
1170        let mut removing = HashSet::new();
1171        for i in 1..col.len() {
1172            if col[i] == col[i - 1] {
1173                removing.insert(i);
1174            }
1175        }
1176        self
1177            .remove_rows(&removing)
1178            // by doing this we avoid nesting another error that's impossible to occur inside
1179            // crate::Error
1180            .expect("all indices should be valid");
1181        Ok(self)
1182    }
1183
1184    pub fn t_dedup_by_column_name(&mut self, by: &str) -> &mut Self {
1185        let by = by.to_string();
1186        self.add_transformation(move |m| m.dedup_by_column_name(&by))
1187    }
1188
1189    #[tracing::instrument(skip(self))]
1190    pub fn dedup_by_column_name(&mut self, by: &str) -> Result<&mut Self, crate::Error> {
1191        let colnames = self.colnames()?;
1192        if colnames.is_none() {
1193            return Err(crate::Error::MissingColumnNames);
1194        }
1195        let by_col_idx = colnames
1196            .expect("colnames should be present")
1197            .iter()
1198            .position(|x| *x == by);
1199        if let Some(i) = by_col_idx {
1200            self.dedup_by_column(i)?;
1201            Ok(self)
1202        } else {
1203            Err(crate::Error::ColumnNameNotFound(by.to_string()))
1204        }
1205    }
1206
1207    pub fn t_match_to(&mut self, with: Vec<f64>, by: usize, join: Join) -> &mut Self {
1208        self.add_transformation(move |m| m.match_to(&with, by, join))
1209    }
1210
1211    #[tracing::instrument(skip(self, with))]
1212    pub fn match_to(
1213        &mut self,
1214        with: &[f64],
1215        by: usize,
1216        join: Join,
1217    ) -> Result<&mut Self, crate::Error> {
1218        if by >= self.ncols()? {
1219            return Err(crate::Error::ColumnIndexOutOfBounds(by));
1220        }
1221        let mut col = self
1222            .col(by)?
1223            .unwrap()
1224            .iter()
1225            .enumerate()
1226            .collect::<Vec<_>>();
1227        col.sort_by(|a, b| a.1.partial_cmp(b.1).expect("could not compare"));
1228        let mut other = with.iter().enumerate().collect::<Vec<_>>();
1229        other.sort_by(|a, b| a.1.partial_cmp(b.1).expect("could not compare"));
1230        let mut order = Vec::with_capacity(match join {
1231            Join::Inner => col.len().min(other.len()),
1232            Join::Left => col.len(),
1233            Join::Right => other.len(),
1234        });
1235        let mut i = 0;
1236        let mut j = 0;
1237        while i < col.len() && j < other.len() {
1238            match col[i].1.partial_cmp(other[j].1) {
1239                Some(std::cmp::Ordering::Less) => i += 1,
1240                Some(std::cmp::Ordering::Equal) => {
1241                    order.push((other[j].0, col[i].0));
1242                    i += 1;
1243                    j += 1;
1244                },
1245                Some(std::cmp::Ordering::Greater) => j += 1,
1246                None => panic!("could not compare"),
1247            }
1248        }
1249        match join {
1250            Join::Inner => (),
1251            Join::Left => {
1252                if order.len() < col.len() {
1253                    return Err(crate::Error::NotAllRowsMatched(join));
1254                }
1255            },
1256            Join::Right => {
1257                if order.len() < other.len() {
1258                    return Err(crate::Error::NotAllRowsMatched(join));
1259                }
1260            },
1261        }
1262        let m = self.as_mat_ref()?;
1263        let data = vec![MaybeUninit::<f64>::uninit(); m.ncols() * order.len()];
1264        order.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("could not compare"));
1265        let order = order.into_iter().map(|(_, o)| o).collect::<Vec<_>>();
1266        m.par_col_chunks(1).enumerate().for_each(|(i, c)| {
1267            let col = c.col(0);
1268            // SAFETY: No two threads will write to the same location
1269            let slice = unsafe {
1270                std::slice::from_raw_parts_mut(
1271                    data.as_ptr().add(i * order.len()).cast::<f64>().cast_mut(),
1272                    m.nrows(),
1273                )
1274            };
1275            for (i, o) in order.iter().enumerate() {
1276                slice[i] = col[*o];
1277            }
1278        });
1279        *self = Matrix::Owned(OwnedMatrix::new(
1280            order.len(),
1281            m.ncols(),
1282            // SAFETY: The data is initialized now
1283            unsafe {
1284                std::mem::transmute::<std::vec::Vec<std::mem::MaybeUninit<f64>>, std::vec::Vec<f64>>(
1285                    data,
1286                )
1287            },
1288            self.colnames()?
1289                .map(|x| x.iter().map(|x| x.to_string()).collect()),
1290        ));
1291
1292        // if self is sorted, then we can just go by pairs otherwise binary search
1293        Ok(self)
1294    }
1295
1296    pub fn t_match_to_by_column_name(
1297        &mut self,
1298        other: Vec<f64>,
1299        col: &str,
1300        join: Join,
1301    ) -> &mut Self {
1302        let col = col.to_string();
1303        self.add_transformation(move |m| m.match_to_by_column_name(&other, &col, join))
1304    }
1305
1306    #[tracing::instrument(skip(self, other))]
1307    pub fn match_to_by_column_name(
1308        &mut self,
1309        other: &[f64],
1310        col: &str,
1311        join: Join,
1312    ) -> Result<&mut Self, crate::Error> {
1313        let colnames = self.colnames()?;
1314        if colnames.is_none() {
1315            return Err(crate::Error::MissingColumnNames);
1316        }
1317        let by_col_idx = colnames
1318            .expect("colnames should be present")
1319            .iter()
1320            .position(|x| *x == col);
1321        if let Some(i) = by_col_idx {
1322            self.match_to(other, i, join)?;
1323            Ok(self)
1324        } else {
1325            Err(crate::Error::ColumnNameNotFound(col.to_string()))
1326        }
1327    }
1328
1329    pub fn t_join(
1330        &mut self,
1331        mut other: Self,
1332        self_by: usize,
1333        other_by: usize,
1334        join: Join,
1335    ) -> &mut Self {
1336        self.add_transformation(move |m| {
1337            let other = unsafe { std::mem::transmute::<&Self, &mut Self>(&other) };
1338            m.join(other, self_by, other_by, join)
1339        })
1340    }
1341
1342    #[tracing::instrument(skip(self, other))]
1343    pub fn join(
1344        &mut self,
1345        other: &mut Matrix,
1346        self_by: usize,
1347        other_by: usize,
1348        join: Join,
1349    ) -> Result<&mut Self, crate::Error> {
1350        if self_by >= self.ncols()? {
1351            return Err(crate::Error::ColumnIndexOutOfBounds(self_by));
1352        }
1353        if other_by >= other.ncols()? {
1354            return Err(crate::Error::ColumnIndexOutOfBounds(other_by));
1355        }
1356        let mut self_col = self
1357            .col(self_by)?
1358            .unwrap()
1359            .iter()
1360            .enumerate()
1361            .collect::<Vec<_>>();
1362        self_col.sort_by(|a, b| a.1.partial_cmp(b.1).expect("could not compare"));
1363        let mut other_col = other
1364            .col(other_by)?
1365            .unwrap()
1366            .iter()
1367            .enumerate()
1368            .collect::<Vec<_>>();
1369        other_col.sort_by(|a, b| a.1.partial_cmp(b.1).expect("could not compare"));
1370        let mut order = Vec::with_capacity(match join {
1371            Join::Inner => self_col.len().min(other_col.len()),
1372            Join::Left => self_col.len(),
1373            Join::Right => other_col.len(),
1374        });
1375        let mut i = 0;
1376        let mut j = 0;
1377        while i < self_col.len() && j < other_col.len() {
1378            match self_col[i].1.partial_cmp(other_col[j].1) {
1379                Some(std::cmp::Ordering::Less) => i += 1,
1380                Some(std::cmp::Ordering::Equal) => {
1381                    order.push((self_col[i].0, other_col[j].0));
1382                    i += 1;
1383                    j += 1;
1384                },
1385                Some(std::cmp::Ordering::Greater) => j += 1,
1386                None => panic!("could not compare"),
1387            }
1388        }
1389        order.sort_by_key(|x| x.0);
1390        trace!("order: {:?}", order);
1391        match join {
1392            Join::Inner => (),
1393            Join::Left => {
1394                if order.len() < self_col.len() {
1395                    return Err(crate::Error::NotAllRowsMatched(join));
1396                }
1397            },
1398            Join::Right => {
1399                if order.len() < other_col.len() {
1400                    return Err(crate::Error::NotAllRowsMatched(join));
1401                }
1402            },
1403        }
1404        let self_m = self.as_mat_ref()?;
1405        let other_m = other.as_mat_ref()?;
1406        let ncols = self_m.ncols() + other_m.ncols() - 1;
1407        let data = vec![MaybeUninit::<f64>::uninit(); ncols * order.len()];
1408        debug!("nrows: {}, ncols: {}", order.len(), ncols);
1409        let self_cols = (0..self_m.ncols()).collect::<Vec<_>>();
1410        let mut other_cols = (0..other_m.ncols()).collect::<Vec<_>>();
1411        other_cols.remove(other_by);
1412        rayon::scope(|s| {
1413            s.spawn(|_| {
1414                self_cols.par_iter().enumerate().for_each(|(i, &c)| {
1415                    let col = self_m.col(c);
1416                    let slice = unsafe {
1417                        std::slice::from_raw_parts_mut(
1418                            data.as_ptr().add(i * order.len()).cast::<f64>().cast_mut(),
1419                            self_m.nrows(),
1420                        )
1421                    };
1422                    for (i, o) in order.iter().enumerate() {
1423                        slice[i] = col[o.0];
1424                    }
1425                });
1426            });
1427            s.spawn(|_| {
1428                other_cols.par_iter().enumerate().for_each(|(i, &c)| {
1429                    let col = other_m.col(c);
1430                    let slice = unsafe {
1431                        std::slice::from_raw_parts_mut(
1432                            data.as_ptr()
1433                                .add((i + self_m.ncols()) * order.len())
1434                                .cast::<f64>()
1435                                .cast_mut(),
1436                            other_m.nrows(),
1437                        )
1438                    };
1439                    for (i, o) in order.iter().enumerate() {
1440                        slice[i] = col[o.1];
1441                    }
1442                });
1443            });
1444        });
1445        let mut self_colnames = self
1446            .colnames()?
1447            .map(|x| x.iter().map(|x| x.to_string()).collect::<Vec<_>>());
1448        let other_colnames = other.colnames()?;
1449        if let (Some(self_colnames), Some(mut other_colnames)) =
1450            (&mut self_colnames, other_colnames)
1451        {
1452            other_colnames.remove(other_by);
1453            self_colnames.extend(other_colnames.into_iter().map(|x| x.to_string()));
1454        }
1455        *self = Matrix::Owned(OwnedMatrix::new(
1456            order.len(),
1457            ncols,
1458            // SAFETY: The data is initialized now
1459            unsafe {
1460                std::mem::transmute::<std::vec::Vec<std::mem::MaybeUninit<f64>>, std::vec::Vec<f64>>(
1461                    data,
1462                )
1463            },
1464            self_colnames,
1465        ));
1466
1467        Ok(self)
1468    }
1469
1470    pub fn t_join_by_column_name(&mut self, mut other: Matrix, by: &str, join: Join) -> &mut Self {
1471        let by = by.to_string();
1472        self.add_transformation(move |m| {
1473            let other = unsafe { std::mem::transmute::<&Self, &mut Self>(&other) };
1474            m.join_by_column_name(other, &by, join)
1475        })
1476    }
1477
1478    #[tracing::instrument(skip(self, other))]
1479    pub fn join_by_column_name(
1480        &mut self,
1481        other: &mut Matrix,
1482        by: &str,
1483        join: Join,
1484    ) -> Result<&mut Self, crate::Error> {
1485        let self_colnames = self.colnames()?;
1486        let other_colnames = other.colnames()?;
1487        if self_colnames.is_none() || other_colnames.is_none() {
1488            if self_colnames.is_none() {
1489                debug!("self colnames are missing");
1490            } else {
1491                debug!("other colnames are missing");
1492            }
1493            return Err(crate::Error::MissingColumnNames);
1494        }
1495        let self_by_col_idx = self_colnames
1496            .expect("colnames should be present")
1497            .iter()
1498            .position(|x| *x == by);
1499        let other_by_col_idx = other_colnames
1500            .expect("colnames should be present")
1501            .iter()
1502            .position(|x| *x == by);
1503        if let (Some(i), Some(j)) = (self_by_col_idx, other_by_col_idx) {
1504            self.join(other, i, j, join)?;
1505            Ok(self)
1506        } else {
1507            Err(crate::Error::ColumnNameNotFound(by.to_string()))
1508        }
1509    }
1510
1511    pub fn t_standardize_columns(&mut self) -> &mut Self {
1512        self.add_transformation(|m| m.standardize_columns())
1513    }
1514
1515    #[tracing::instrument(skip(self))]
1516    pub fn standardize_columns(&mut self) -> Result<&mut Self, crate::Error> {
1517        debug!("Standardizing matrix");
1518        self.as_mat_mut()?
1519            .par_col_chunks_mut(1)
1520            .for_each(|c| standardize_column(c.col_mut(0)));
1521        debug!("Standardized matrix");
1522        Ok(self)
1523    }
1524
1525    pub fn t_standardize_rows(&mut self) -> &mut Self {
1526        self.add_transformation(|m| m.standardize_rows())
1527    }
1528
1529    #[tracing::instrument(skip(self))]
1530    pub fn standardize_rows(&mut self) -> Result<&mut Self, crate::Error> {
1531        debug!("Standardizing matrix");
1532        self.as_mat_mut()?
1533            .par_row_chunks_mut(1)
1534            .for_each(|r| standardize_row(r.row_mut(0)));
1535        debug!("Standardized matrix");
1536        Ok(self)
1537    }
1538
1539    pub fn t_remove_nan_rows(&mut self) -> &mut Self {
1540        self.add_transformation(|m| m.remove_nan_rows())
1541    }
1542
1543    #[tracing::instrument(skip(self))]
1544    pub fn remove_nan_rows(&mut self) -> Result<&mut Self, crate::Error> {
1545        let removing = self
1546            .as_mat_ref()?
1547            .par_row_chunks(1)
1548            .enumerate()
1549            .filter(|(_, row)| !row.is_all_finite())
1550            .map(|(i, _)| i)
1551            .collect::<HashSet<_>>();
1552        debug!("Removed {} rows with NaN values", removing.len());
1553        self.remove_rows(&removing)
1554    }
1555
1556    pub fn t_remove_nan_columns(&mut self) -> &mut Self {
1557        self.add_transformation(|m| m.remove_nan_columns())
1558    }
1559
1560    #[tracing::instrument(skip(self))]
1561    pub fn remove_nan_columns(&mut self) -> Result<&mut Self, crate::Error> {
1562        let removing = self
1563            .as_mat_ref()?
1564            .par_col_chunks(1)
1565            .enumerate()
1566            .filter(|(_, col)| !col.is_all_finite())
1567            .map(|(i, _)| i)
1568            .collect::<HashSet<_>>();
1569        debug!("Removed {} columns with NaN values", removing.len());
1570        self.remove_columns(&removing)
1571    }
1572
1573    pub fn t_nan_to_value(&mut self, val: f64) -> &mut Self {
1574        self.add_transformation(move |m| m.nan_to_value(val))
1575    }
1576
1577    #[tracing::instrument(skip(self))]
1578    pub fn nan_to_value(&mut self, val: f64) -> Result<&mut Self, crate::Error> {
1579        self.as_mat_mut()?.par_col_chunks_mut(1).for_each(|c| {
1580            c.col_mut(0).iter_mut().for_each(|x| {
1581                if !x.is_finite() {
1582                    *x = val;
1583                }
1584            })
1585        });
1586        Ok(self)
1587    }
1588
1589    pub fn t_nan_to_column_mean(&mut self) -> &mut Self {
1590        self.add_transformation(move |m| m.nan_to_column_mean())
1591    }
1592
1593    #[tracing::instrument(skip(self))]
1594    pub fn nan_to_column_mean(&mut self) -> Result<&mut Self, crate::Error> {
1595        self.as_mat_mut()?.par_col_chunks_mut(1).for_each(|c| {
1596            let col = c.col_mut(0);
1597            let m = mean::mean(
1598                col.as_ref()
1599                    .try_as_col_major()
1600                    .expect("could not get slice")
1601                    .as_slice(),
1602            );
1603            col.iter_mut().for_each(|x| {
1604                if !x.is_finite() {
1605                    *x = m;
1606                }
1607            })
1608        });
1609        Ok(self)
1610    }
1611
1612    pub fn t_nan_to_row_mean(&mut self) -> &mut Self {
1613        self.add_transformation(move |m| m.nan_to_row_mean())
1614    }
1615
1616    #[tracing::instrument(skip(self))]
1617    pub fn nan_to_row_mean(&mut self) -> Result<&mut Self, crate::Error> {
1618        self.as_mat_mut()?.par_row_chunks_mut(1).for_each(|r| {
1619            let row = r.row_mut(0);
1620            let mut m = 0.0;
1621            faer::stats::col_mean(
1622                ColMut::from_mut(&mut m),
1623                row.as_ref().as_mat(),
1624                faer::stats::NanHandling::Ignore,
1625            );
1626            row.iter_mut().for_each(|x| {
1627                if !x.is_finite() {
1628                    *x = m;
1629                }
1630            })
1631        });
1632        Ok(self)
1633    }
1634
1635    pub fn t_min_column_sum(&mut self, sum: f64) -> &mut Self {
1636        self.add_transformation(move |m| m.min_column_sum(sum))
1637    }
1638
1639    #[tracing::instrument(skip(self))]
1640    pub fn min_column_sum(&mut self, sum: f64) -> Result<&mut Self, crate::Error> {
1641        let removing = self
1642            .as_mat_mut()?
1643            .par_col_chunks_mut(1)
1644            .enumerate()
1645            .filter(|(_, c)| {
1646                crate::sum(c.as_ref().col(0).try_as_col_major().unwrap().as_slice()) < sum
1647            })
1648            .map(|(i, _)| i)
1649            .collect::<HashSet<_>>();
1650        debug!("Removed {} columns with sum < {}", removing.len(), sum);
1651        self.remove_columns(&removing)
1652    }
1653
1654    pub fn t_max_column_sum(&mut self, sum: f64) -> &mut Self {
1655        self.add_transformation(move |m| m.max_column_sum(sum))
1656    }
1657
1658    #[tracing::instrument(skip(self))]
1659    pub fn max_column_sum(&mut self, sum: f64) -> Result<&mut Self, crate::Error> {
1660        let removing = self
1661            .as_mat_mut()?
1662            .par_col_chunks_mut(1)
1663            .enumerate()
1664            .filter(|(_, c)| {
1665                crate::sum(c.as_ref().col(0).try_as_col_major().unwrap().as_slice()) > sum
1666            })
1667            .map(|(i, _)| i)
1668            .collect();
1669        self.remove_columns(&removing)
1670    }
1671
1672    pub fn t_min_row_sum(&mut self, sum: f64) -> &mut Self {
1673        self.add_transformation(move |m| m.min_row_sum(sum))
1674    }
1675
1676    #[tracing::instrument(skip(self))]
1677    pub fn min_row_sum(&mut self, sum: f64) -> Result<&mut Self, crate::Error> {
1678        let removing = self
1679            .as_mat_mut()?
1680            .par_row_chunks_mut(1)
1681            .enumerate()
1682            .filter(|(_, r)| r.sum() < sum)
1683            .map(|(i, _)| i)
1684            .collect();
1685        self.remove_rows(&removing)
1686    }
1687
1688    pub fn t_max_row_sum(&mut self, sum: f64) -> &mut Self {
1689        self.add_transformation(move |m| m.max_row_sum(sum))
1690    }
1691
1692    #[tracing::instrument(skip(self))]
1693    pub fn max_row_sum(&mut self, sum: f64) -> Result<&mut Self, crate::Error> {
1694        let removing = self
1695            .as_mat_mut()?
1696            .par_row_chunks_mut(1)
1697            .enumerate()
1698            .filter(|(_, r)| r.sum() > sum)
1699            .map(|(i, _)| i)
1700            .collect();
1701        self.remove_rows(&removing)
1702    }
1703
1704    #[cfg_attr(coverage_nightly, coverage(off))]
1705    pub fn t_rename_column(&mut self, old: &str, new: &str) -> &mut Self {
1706        let old = old.to_string();
1707        let new = new.to_string();
1708        self.add_transformation(move |m| m.rename_column(&old, &new))
1709    }
1710
1711    #[tracing::instrument(skip(self))]
1712    #[cfg_attr(coverage_nightly, coverage(off))]
1713    pub fn rename_column(&mut self, old: &str, new: &str) -> Result<&mut Self, crate::Error> {
1714        let colnames = self.colnames()?;
1715        if colnames.is_none() {
1716            return Err(crate::Error::MissingColumnNames);
1717        }
1718        let idx = colnames
1719            .as_ref()
1720            .expect("colnames should be present")
1721            .iter()
1722            .position(|x| *x == old);
1723        if let Some(i) = idx {
1724            let mut colnames = colnames
1725                .expect("colnames should be present")
1726                .into_iter()
1727                .map(|x| {
1728                    if x == old {
1729                        new.to_string()
1730                    } else {
1731                        x.to_string()
1732                    }
1733                })
1734                .collect::<Vec<_>>();
1735            self.into_owned()?;
1736            self.as_owned_mut()?.colnames = Some(colnames);
1737            Ok(self)
1738        } else {
1739            Err(crate::Error::ColumnNameNotFound(old.to_string()))
1740        }
1741    }
1742
1743    #[cfg_attr(coverage_nightly, coverage(off))]
1744    pub fn t_rename_column_if_exists(&mut self, old: &str, new: &str) -> &mut Self {
1745        let old = old.to_string();
1746        let new = new.to_string();
1747        self.add_transformation(move |m| m.rename_column_if_exists(&old, &new))
1748    }
1749
1750    #[tracing::instrument(skip(self))]
1751    #[cfg_attr(coverage_nightly, coverage(off))]
1752    pub fn rename_column_if_exists(
1753        &mut self,
1754        old: &str,
1755        new: &str,
1756    ) -> Result<&mut Self, crate::Error> {
1757        match self.rename_column(old, new) {
1758            Ok(_) => Ok(self),
1759            Err(crate::Error::ColumnNameNotFound(_) | crate::Error::MissingColumnNames) => Ok(self),
1760            Err(e) => Err(e),
1761        }
1762    }
1763
1764    pub fn t_remove_duplicate_columns(&mut self) -> &mut Self {
1765        self.add_transformation(move |m| m.remove_duplicate_columns())
1766    }
1767
1768    #[tracing::instrument(skip(self))]
1769    pub fn remove_duplicate_columns(&mut self) -> Result<&mut Self, crate::Error> {
1770        let ncols = self.ncols()?;
1771        let cols = (0..ncols)
1772            .into_par_iter()
1773            .map(|x| self.col_loaded(x))
1774            .collect::<Vec<_>>();
1775        let cols = (0..ncols)
1776            .into_par_iter()
1777            .flat_map(|i| ((i + 1)..ncols).into_par_iter().map(move |j| (i, j)))
1778            .filter_map(|(i, j)| if cols[i] == cols[j] { Some(j) } else { None })
1779            .collect::<HashSet<_>>();
1780        self.remove_columns(&cols)
1781    }
1782
1783    pub fn t_remove_identical_columns(&mut self) -> &mut Self {
1784        self.add_transformation(move |m| m.remove_identical_columns())
1785    }
1786
1787    #[tracing::instrument(skip(self))]
1788    pub fn remove_identical_columns(&mut self) -> Result<&mut Self, crate::Error> {
1789        let ncols = self.ncols()?;
1790        let cols = (0..ncols)
1791            .into_par_iter()
1792            .map(|x| self.col_loaded(x).unwrap())
1793            .collect::<Vec<_>>();
1794        let cols = (0..ncols)
1795            .into_par_iter()
1796            .filter_map(|i| {
1797                let col = cols[i];
1798                let first = col[0];
1799                for i in col.iter().skip(1) {
1800                    if *i != first {
1801                        return None;
1802                    }
1803                }
1804                Some(i)
1805            })
1806            .collect::<HashSet<_>>();
1807        self.remove_columns(&cols)
1808    }
1809
1810    pub fn t_min_non_nan(&mut self, val: usize) -> &mut Self {
1811        self.add_transformation(move |m| m.min_non_nan(val))
1812    }
1813
1814    #[tracing::instrument(skip(self))]
1815    pub fn min_non_nan(&mut self, val: usize) -> Result<&mut Self, crate::Error> {
1816        let removing = self
1817            .as_mat_mut()?
1818            .par_col_chunks_mut(1)
1819            .enumerate()
1820            .filter(|(_, c)| c.as_ref().col(0).iter().filter(|x| x.is_finite()).count() < val)
1821            .map(|(i, _)| i)
1822            .collect();
1823        self.remove_columns(&removing)
1824    }
1825
1826    pub fn t_max_non_nan(&mut self, val: usize) -> &mut Self {
1827        self.add_transformation(move |m| m.max_non_nan(val))
1828    }
1829
1830    #[tracing::instrument(skip(self))]
1831    pub fn max_non_nan(&mut self, val: usize) -> Result<&mut Self, crate::Error> {
1832        let removing = self
1833            .as_mat_mut()?
1834            .par_col_chunks_mut(1)
1835            .enumerate()
1836            .filter(|(_, c)| c.as_ref().col(0).iter().filter(|x| x.is_finite()).count() > val)
1837            .map(|(i, _)| i)
1838            .collect();
1839        self.remove_columns(&removing)
1840    }
1841
1842    pub fn t_subset_columns(&mut self, cols: HashSet<usize>) -> &mut Self {
1843        self.add_transformation(move |m| m.subset_columns(&cols))
1844    }
1845
1846    #[tracing::instrument(skip(self))]
1847    pub fn subset_columns(&mut self, cols: &HashSet<usize>) -> Result<&mut Self, crate::Error> {
1848        let ncols = self.ncols()?;
1849        let removing = (0..ncols)
1850            .into_par_iter()
1851            .filter(|x| !cols.contains(x))
1852            .collect::<HashSet<_>>();
1853        self.remove_columns(&removing)
1854    }
1855
1856    pub fn t_subset_columns_by_name(&mut self, cols: HashSet<String>) -> &mut Self {
1857        self.add_transformation(move |m| m.subset_columns_by_name(&cols))
1858    }
1859
1860    #[tracing::instrument(skip(self))]
1861    pub fn subset_columns_by_name(
1862        &mut self,
1863        cols: &HashSet<String>,
1864    ) -> Result<&mut Self, crate::Error> {
1865        let colnames = self.colnames()?;
1866        if colnames.is_none() {
1867            return Err(crate::Error::MissingColumnNames);
1868        }
1869        let colnames = colnames
1870            .expect("colnames should be present")
1871            .into_iter()
1872            .map(|x| x.to_string())
1873            .collect::<Vec<_>>();
1874        let cols = colnames
1875            .iter()
1876            .enumerate()
1877            .filter_map(|(i, x)| if cols.contains(x) { Some(i) } else { None })
1878            .collect::<HashSet<_>>();
1879        self.subset_columns(&cols)
1880    }
1881
1882    pub fn t_rename_columns_with_regex(&mut self, regex: &str, replacement: &str) -> &mut Self {
1883        let regex = regex.to_string();
1884        let replacement = replacement.to_string();
1885        self.add_transformation(move |m| m.rename_columns_with_regex(&regex, &replacement))
1886    }
1887
1888    #[tracing::instrument(skip(self))]
1889    pub fn rename_columns_with_regex(
1890        &mut self,
1891        regex: &str,
1892        replacement: &str,
1893    ) -> Result<&mut Self, crate::Error> {
1894        let colnames = self.colnames()?;
1895        if colnames.is_none() {
1896            return Err(crate::Error::MissingColumnNames);
1897        }
1898        let re = Regex::new(regex)?;
1899        let colnames = colnames
1900            .expect("colnames should be present")
1901            .iter()
1902            .map(|x| re.replace_all(x, replacement).to_string())
1903            .collect::<Vec<_>>();
1904        self.set_colnames(colnames)?;
1905        Ok(self)
1906    }
1907
1908    pub fn t_scale_columns(&mut self, scale: Vec<f64>) -> &mut Self {
1909        self.add_transformation(move |m| m.scale_columns(&scale))
1910    }
1911
1912    #[tracing::instrument(skip(self))]
1913    pub fn scale_columns(&mut self, scale: &[f64]) -> Result<&mut Self, crate::Error> {
1914        if scale.len() == 1 {
1915            let mut mat = self.as_mat_mut()?;
1916            let scale = scale[0];
1917            for i in 0..mat.ncols() {
1918                for j in 0..mat.nrows() {
1919                    mat[(j, i)] *= scale;
1920                }
1921            }
1922        } else if scale.len() == self.ncols()? {
1923            let mut mat = self.as_mat_mut()?;
1924            for i in 0..mat.ncols() {
1925                let scale = scale[i];
1926                for j in 0..mat.nrows() {
1927                    mat[(j, i)] *= scale;
1928                }
1929            }
1930        } else {
1931            return Err(crate::Error::InvalidScaleLength(scale.len()));
1932        }
1933
1934        Ok(self)
1935    }
1936
1937    pub fn t_scale_rows(&mut self, scale: Vec<f64>) -> &mut Self {
1938        self.add_transformation(move |m| m.scale_rows(&scale))
1939    }
1940
1941    #[tracing::instrument(skip(self))]
1942    pub fn scale_rows(&mut self, scale: &[f64]) -> Result<&mut Self, crate::Error> {
1943        if scale.len() == 1 {
1944            let mut mat = self.as_mat_mut()?;
1945            let scale = scale[0];
1946            for i in 0..mat.nrows() {
1947                for j in 0..mat.ncols() {
1948                    mat[(i, j)] *= scale;
1949                }
1950            }
1951        } else if scale.len() == self.nrows()? {
1952            let mut mat = self.as_mat_mut()?;
1953            for i in 0..mat.nrows() {
1954                let scale = scale[i];
1955                for j in 0..mat.ncols() {
1956                    mat[(i, j)] *= scale;
1957                }
1958            }
1959        } else {
1960            return Err(crate::Error::InvalidScaleLength(scale.len()));
1961        }
1962
1963        Ok(self)
1964    }
1965}
1966
1967impl Matrix {
1968    #[cfg_attr(coverage_nightly, coverage(off))]
1969    pub fn nrows(&mut self) -> Result<usize, crate::Error> {
1970        self.as_mat_ref().map(|x| x.nrows())
1971    }
1972
1973    #[cfg_attr(coverage_nightly, coverage(off))]
1974    pub fn nrows_loaded(&self) -> usize {
1975        self.as_mat_ref_loaded().nrows()
1976    }
1977
1978    #[cfg_attr(coverage_nightly, coverage(off))]
1979    pub fn ncols(&mut self) -> Result<usize, crate::Error> {
1980        self.as_mat_ref().map(|x| x.ncols())
1981    }
1982
1983    #[cfg_attr(coverage_nightly, coverage(off))]
1984    pub fn ncols_loaded(&self) -> usize {
1985        self.as_mat_ref_loaded().ncols()
1986    }
1987
1988    #[cfg_attr(coverage_nightly, coverage(off))]
1989    pub fn data(&mut self) -> Result<&[f64], crate::Error> {
1990        self.as_owned_ref().map(|x| x.data.as_slice())
1991    }
1992
1993    #[cfg_attr(coverage_nightly, coverage(off))]
1994    pub fn as_mut_slice(&mut self) -> Result<&mut [f64], crate::Error> {
1995        match self {
1996            Matrix::Owned(m) => Ok(&mut m.data),
1997            #[cfg(feature = "r")]
1998            Matrix::R(m) => Ok(unsafe {
1999                std::slice::from_raw_parts_mut(m.data().as_ptr().cast_mut(), m.data().len())
2000            }),
2001            ref m => self.into_owned()?.as_mut_slice(),
2002        }
2003    }
2004
2005    #[cfg_attr(coverage_nightly, coverage(off))]
2006    pub fn col(&mut self, col: usize) -> Result<Option<&[f64]>, crate::Error> {
2007        if col >= self.ncols()? {
2008            return Ok(None);
2009        }
2010        self.as_mat_ref().map(|x| {
2011            Some(unsafe {
2012                x.get_unchecked(.., col)
2013                    .try_as_col_major()
2014                    .expect("could not get slice")
2015                    .as_slice()
2016            })
2017        })
2018    }
2019
2020    #[cfg_attr(coverage_nightly, coverage(off))]
2021    pub fn col_loaded(&self, col: usize) -> Option<&[f64]> {
2022        if col >= self.ncols_loaded() {
2023            return None;
2024        }
2025        Some(unsafe {
2026            self.as_mat_ref_loaded()
2027                .get_unchecked(.., col)
2028                .try_as_col_major()
2029                .expect("could not get slice")
2030                .as_slice()
2031        })
2032    }
2033
2034    #[cfg_attr(coverage_nightly, coverage(off))]
2035    pub fn get(&mut self, row: usize, col: usize) -> Result<Option<f64>, crate::Error> {
2036        let nrows = self.nrows()?;
2037        let ncols = self.ncols()?;
2038        self.as_mat_ref().map(|x| {
2039            if row >= nrows || col > ncols {
2040                None
2041            } else {
2042                Some(unsafe { *x.get_unchecked(row, col) })
2043            }
2044        })
2045    }
2046
2047    #[cfg_attr(coverage_nightly, coverage(off))]
2048    pub fn get_loaded(&self, row: usize, col: usize) -> Option<f64> {
2049        let nrows = self.nrows_loaded();
2050        let ncols = self.ncols_loaded();
2051        if row >= nrows || col > ncols {
2052            None
2053        } else {
2054            Some(unsafe { *self.as_mat_ref_loaded().get_unchecked(row, col) })
2055        }
2056    }
2057
2058    #[cfg_attr(coverage_nightly, coverage(off))]
2059    pub fn column_index(&mut self, name: &str) -> Result<usize, crate::Error> {
2060        let colnames = self.colnames()?;
2061        if colnames.is_none() {
2062            return Err(crate::Error::MissingColumnNames);
2063        }
2064        let idx = colnames
2065            .expect("colnames should be present")
2066            .iter()
2067            .position(|x| *x == name);
2068        match idx {
2069            Some(i) => Ok(i),
2070            None => Err(crate::Error::ColumnNameNotFound(name.to_string())),
2071        }
2072    }
2073
2074    #[cfg_attr(coverage_nightly, coverage(off))]
2075    pub fn has_column(&mut self, name: &str) -> Result<bool, crate::Error> {
2076        self.column_index(name).map(|_| true).or_else(|e| match e {
2077            crate::Error::ColumnNameNotFound(_) => Ok(false),
2078            e => Err(e),
2079        })
2080    }
2081
2082    #[cfg_attr(coverage_nightly, coverage(off))]
2083    pub fn has_column_loaded(&self, name: &str) -> bool {
2084        self.colnames_loaded()
2085            .map(|x| x.contains(&name))
2086            .unwrap_or(false)
2087    }
2088
2089    #[cfg_attr(coverage_nightly, coverage(off))]
2090    pub fn column_by_name(&mut self, name: &str) -> Result<Option<&[f64]>, crate::Error> {
2091        let col = self.column_index(name)?;
2092        self.col(col)
2093    }
2094}
2095
2096impl FromStr for Matrix {
2097    type Err = crate::Error;
2098
2099    #[cfg_attr(coverage_nightly, coverage(off))]
2100    fn from_str(s: &str) -> Result<Self, Self::Err> {
2101        Ok(Matrix::File(s.parse()?))
2102    }
2103}
2104
/// A matrix whose values live in a Rust-owned buffer.
///
/// `data` holds `nrows * ncols` values in column-major order: column `j` occupies
/// `data[j * nrows..(j + 1) * nrows]`. `colnames`, when present, has one entry per
/// column. Equality is implemented manually and compares `data` bit-for-bit.
///
/// NOTE: the derived serde and rkyv encodings depend on field order; do not reorder
/// the fields.
#[derive(
    Clone,
    Debug,
    serde::Serialize,
    serde::Deserialize,
    rkyv::Archive,
    rkyv::Serialize,
    rkyv::Deserialize,
)]
#[archive(check_bytes)]
pub struct OwnedMatrix {
    /// Number of rows.
    pub(crate) nrows: usize,
    /// Number of columns.
    pub(crate) ncols: usize,
    /// Optional column names, one per column.
    pub(crate) colnames: Option<Vec<String>>,
    /// Column-major values, `nrows * ncols` long.
    pub(crate) data: Vec<f64>,
}
2121
2122impl PartialEq for OwnedMatrix {
2123    #[cfg_attr(coverage_nightly, coverage(off))]
2124    fn eq(&self, other: &Self) -> bool {
2125        self.nrows == other.nrows
2126            && self.ncols == other.ncols
2127            && self.colnames == other.colnames
2128            && self.data.len() == other.data.len()
2129            && self
2130                .data
2131                .iter()
2132                .zip(other.data.iter())
2133                .all(|(a, b)| a.to_bits() == b.to_bits())
2134    }
2135}
2136
2137impl OwnedMatrix {
2138    pub fn new(rows: usize, cols: usize, data: Vec<f64>, colnames: Option<Vec<String>>) -> Self {
2139        assert!(rows * cols == data.len());
2140        if let Some(colnames) = &colnames {
2141            assert_eq!(cols, colnames.len());
2142        }
2143        Self {
2144            nrows: rows,
2145            ncols: cols,
2146            data,
2147            colnames,
2148        }
2149    }
2150
2151    #[cfg_attr(coverage_nightly, coverage(off))]
2152    #[doc(hidden)]
2153    pub fn into_data(self) -> Vec<f64> {
2154        self.data
2155    }
2156}
2157
/// Infallible conversion of a matrix-like value into a [`Matrix`].
///
/// Implementors automatically get `From<T> for Matrix` and `TryIntoMatrix` through
/// blanket impls in this module.
pub trait IntoMatrix {
    /// Consumes `self` and returns the corresponding [`Matrix`].
    fn into_matrix(self) -> Matrix;
}
2161
#[cfg(feature = "r")]
impl IntoMatrix for RMatrix<f64> {
    /// Wraps the R matrix handle as [`Matrix::R`]; the values stay in R's memory.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn into_matrix(self) -> Matrix {
        Matrix::R(self)
    }
}
2169
#[cfg(feature = "r")]
impl IntoMatrix for RMatrix<i32> {
    /// Converts an integer R matrix by round-tripping through `Robj` and
    /// [`Matrix::from_robj`].
    ///
    /// # Panics
    /// Panics if `Matrix::from_robj` rejects the object. `IntoMatrix` is infallible, so
    /// the error cannot be propagated; `TryIntoMatrix for Robj` is the fallible path.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn into_matrix(self) -> Matrix {
        Matrix::from_robj(self.into_robj())
            .expect("Matrix::from_robj failed to convert an integer RMatrix")
    }
}
2177
impl IntoMatrix for OwnedMatrix {
    /// Moves the owned matrix into [`Matrix::Owned`] without copying its data.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn into_matrix(self) -> Matrix {
        Matrix::Owned(self)
    }
}
2184
impl IntoMatrix for File {
    /// Wraps the file handle as [`Matrix::File`]; nothing is read at this point.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn into_matrix(self) -> Matrix {
        Matrix::File(self)
    }
}
2191
impl IntoMatrix for MatRef<'_, f64> {
    /// Builds a [`Matrix`] from a faer view via `Matrix::from_mat_ref`.
    ///
    /// The result carries no lifetime, so the values are presumably copied — confirm
    /// in `from_mat_ref`.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn into_matrix(self) -> Matrix {
        Matrix::from_mat_ref(self)
    }
}
2198
2199impl IntoMatrix for MatMut<'_, f64> {
2200    #[cfg_attr(coverage_nightly, coverage(off))]
2201    fn into_matrix(self) -> Matrix {
2202        Matrix::from_mat_ref(self.as_ref())
2203    }
2204}
2205
2206impl IntoMatrix for Mat<f64> {
2207    #[cfg_attr(coverage_nightly, coverage(off))]
2208    fn into_matrix(self) -> Matrix {
2209        Matrix::from_mat_ref(self.as_ref())
2210    }
2211}
2212
/// Fallible conversion into a [`Matrix`].
pub trait TryIntoMatrix {
    /// Error returned when the value cannot be converted.
    type Err;

    /// Consumes `self` and attempts to build a [`Matrix`] from it.
    fn try_into_matrix(self) -> Result<Matrix, Self::Err>;
}
2218
#[cfg(feature = "r")]
impl TryIntoMatrix for Robj {
    type Err = crate::Error;

    /// Delegates to [`Matrix::from_robj`], returning its error unchanged.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn try_into_matrix(self) -> Result<Matrix, Self::Err> {
        Matrix::from_robj(self)
    }
}
2228
2229impl TryIntoMatrix for &str {
2230    type Err = crate::Error;
2231
2232    #[cfg_attr(coverage_nightly, coverage(off))]
2233    fn try_into_matrix(self) -> Result<Matrix, Self::Err> {
2234        Ok(Matrix::File(self.parse()?))
2235    }
2236}
2237
2238impl<T> TryIntoMatrix for T
2239where
2240    T: IntoMatrix,
2241{
2242    type Err = ();
2243
2244    #[cfg_attr(coverage_nightly, coverage(off))]
2245    fn try_into_matrix(self) -> Result<Matrix, Self::Err> {
2246        Ok(self.into_matrix())
2247    }
2248}
2249
2250impl<T> From<T> for Matrix
2251where
2252    T: IntoMatrix,
2253{
2254    #[cfg_attr(coverage_nightly, coverage(off))]
2255    fn from(t: T) -> Self {
2256        t.into_matrix()
2257    }
2258}
2259
2260#[cfg(test)]
2261mod tests {
2262    use faer::traits::pulp::num_complex::Complex;
2263    use test_log::test;
2264
2265    use super::*;
2266
2267    macro_rules! assert_float_eq {
2268        ($a:expr, $b:expr, $tol:expr) => {
2269            assert!(($a - $b).abs() < $tol, "{:.22} != {:.22}", $a, $b);
2270        };
2271    }
2272
2273    macro_rules! float_eq {
2274        ($a:expr, $b:expr) => {
2275            assert_float_eq!($a, $b, 1e-12);
2276        };
2277    }
2278
2279    macro_rules! rough_eq {
2280        ($a:expr, $b:expr) => {
2281            assert_float_eq!($a, $b, 1e-3);
2282        };
2283    }
2284
2285    #[test]
2286    fn test_combine_columns_success() {
2287        let mut m1 = OwnedMatrix::new(
2288            3,
2289            2,
2290            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2291            Some(vec!["a".to_string(), "b".to_string()]),
2292        )
2293        .into_matrix();
2294        let m2 = OwnedMatrix::new(
2295            3,
2296            2,
2297            vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
2298            Some(vec!["c".to_string(), "d".to_string()]),
2299        )
2300        .into_matrix();
2301        let m3 = OwnedMatrix::new(
2302            3,
2303            2,
2304            vec![13.0, 14.0, 15.0, 16.0, 17.0, 18.0],
2305            Some(vec!["e".to_string(), "f".to_string()]),
2306        )
2307        .into_matrix();
2308        let m = m1.t_combine_columns(vec![m2, m3]);
2309        assert_eq!(
2310            m.data().unwrap(),
2311            &[
2312                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
2313                16.0, 17.0, 18.0
2314            ],
2315        );
2316        assert_eq!(
2317            m.colnames().unwrap().unwrap(),
2318            &[
2319                "a".to_string(),
2320                "b".to_string(),
2321                "c".to_string(),
2322                "d".to_string(),
2323                "e".to_string(),
2324                "f".to_string()
2325            ]
2326        );
2327        assert_eq!(m.nrows().unwrap(), 3);
2328        assert_eq!(m.ncols().unwrap(), 6);
2329    }
2330
2331    #[test]
2332    fn test_combine_columns_dimensions_mismatch() {
2333        let mut m1 = OwnedMatrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], None).into_matrix();
2334        let m2 = OwnedMatrix::new(2, 2, vec![19.0, 20.0, 21.0, 22.0], None).into_matrix();
2335        let res = m1.combine_columns(&mut [m2]).unwrap_err();
2336        assert!(matches!(res, Error::MatrixDimensionsMismatch));
2337    }
2338
2339    #[test]
2340    fn test_combine_columns_no_colnames() {
2341        let mut m1 = OwnedMatrix::new(
2342            3,
2343            2,
2344            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2345            Some(vec!["a".to_string(), "b".to_string()]),
2346        )
2347        .into_matrix();
2348        let m2 =
2349            OwnedMatrix::new(3, 2, vec![19.0, 20.0, 21.0, 22.0, 23.0, 24.0], None).into_matrix();
2350        let m = m1.t_combine_columns(vec![m2]);
2351        assert_eq!(
2352            m.data().unwrap(),
2353            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0]
2354        );
2355        assert!(m.colnames().unwrap().is_none());
2356        assert_eq!(m.nrows().unwrap(), 3);
2357        assert_eq!(m.ncols().unwrap(), 4);
2358    }
2359
2360    #[test]
2361    fn test_combine_rows_success() {
2362        let mut m1 = OwnedMatrix::new(
2363            3,
2364            2,
2365            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2366            Some(vec!["a".to_string(), "b".to_string()]),
2367        )
2368        .into_matrix();
2369        let m2 = OwnedMatrix::new(
2370            3,
2371            2,
2372            vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
2373            Some(vec!["a".to_string(), "b".to_string()]),
2374        )
2375        .into_matrix();
2376        let m3 = OwnedMatrix::new(
2377            3,
2378            2,
2379            vec![13.0, 14.0, 15.0, 16.0, 17.0, 18.0],
2380            Some(vec!["a".to_string(), "b".to_string()]),
2381        )
2382        .into_matrix();
2383        let m = m1.t_combine_rows(vec![m2, m3]);
2384        assert_eq!(
2385            m.data().unwrap(),
2386            &[
2387                1.0, 2.0, 3.0, 7.0, 8.0, 9.0, 13.0, 14.0, 15.0, 4.0, 5.0, 6.0, 10.0, 11.0, 12.0,
2388                16.0, 17.0, 18.0
2389            ],
2390        );
2391        assert_eq!(
2392            m.colnames().unwrap().unwrap(),
2393            &["a".to_string(), "b".to_string()]
2394        );
2395        assert_eq!(m.nrows().unwrap(), 9);
2396        assert_eq!(m.ncols().unwrap(), 2);
2397    }
2398
2399    #[test]
2400    fn test_combine_rows_dimensions_mismatch() {
2401        let mut m1 = OwnedMatrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], None).into_matrix();
2402        let m2 = OwnedMatrix::new(2, 2, vec![19.0, 20.0, 21.0, 22.0], None).into_matrix();
2403        let res = m1.combine_rows(&mut [m2]).unwrap_err();
2404        assert!(matches!(res, Error::MatrixDimensionsMismatch));
2405    }
2406
2407    #[test]
2408    fn test_combine_rows_column_names_mismatch() {
2409        let mut m1 = OwnedMatrix::new(
2410            3,
2411            2,
2412            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2413            Some(vec!["a".to_string(), "b".to_string()]),
2414        )
2415        .into_matrix();
2416        let m2 = OwnedMatrix::new(
2417            3,
2418            2,
2419            vec![19.0, 20.0, 21.0, 22.0, 23.0, 24.0],
2420            Some(vec!["c".to_string(), "d".to_string()]),
2421        )
2422        .into_matrix();
2423        let m = m1.combine_rows(&mut [m2]).unwrap_err();
2424        assert!(matches!(m, Error::ColumnNamesMismatch));
2425    }
2426
2427    #[test]
2428    fn test_remove_rows_success() {
2429        let mut m = OwnedMatrix::new(
2430            3,
2431            2,
2432            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2433            Some(vec!["a".to_string(), "b".to_string()]),
2434        )
2435        .into_matrix();
2436        let mut removing = HashSet::new();
2437        removing.insert(1);
2438        let m = m.t_remove_rows(removing);
2439        assert_eq!(m.data().unwrap(), &[1.0, 3.0, 4.0, 6.0]);
2440        assert_eq!(
2441            m.colnames().unwrap().unwrap(),
2442            &["a".to_string(), "b".to_string()]
2443        );
2444        assert_eq!(m.nrows().unwrap(), 2);
2445        assert_eq!(m.ncols().unwrap(), 2);
2446    }
2447
2448    #[test]
2449    fn test_remove_rows_empty() {
2450        let mut m = OwnedMatrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], None).into_matrix();
2451        let removing = HashSet::new();
2452        let m = m.t_remove_rows(removing);
2453        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2454        assert!(m.colnames().unwrap().is_none());
2455        assert_eq!(m.nrows().unwrap(), 3);
2456        assert_eq!(m.ncols().unwrap(), 2);
2457    }
2458
2459    #[test]
2460    fn test_remove_rows_index_out_of_bounds() {
2461        let mut m = OwnedMatrix::new(
2462            3,
2463            2,
2464            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2465            Some(vec!["a".to_string(), "b".to_string()]),
2466        )
2467        .into_matrix();
2468        let mut removing = HashSet::new();
2469        removing.insert(3);
2470        let m = m.remove_rows(&removing).unwrap_err();
2471        assert!(matches!(m, Error::RowIndexOutOfBounds(3)));
2472    }
2473
2474    #[test]
2475    fn test_remove_columns_success() {
2476        let mut m = OwnedMatrix::new(
2477            3,
2478            2,
2479            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2480            Some(vec!["a".to_string(), "b".to_string()]),
2481        )
2482        .into_matrix();
2483        let mut removing = HashSet::new();
2484        removing.insert(1);
2485        let m = m.t_remove_columns(removing);
2486        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 3.0]);
2487        assert_eq!(m.colnames().unwrap().unwrap(), &["a".to_string()]);
2488        assert_eq!(m.nrows().unwrap(), 3);
2489        assert_eq!(m.ncols().unwrap(), 1);
2490    }
2491
2492    #[test]
2493    fn test_remove_columns_empty() {
2494        let mut m = OwnedMatrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], None).into_matrix();
2495        let removing = HashSet::new();
2496        let m = m.t_remove_columns(removing);
2497        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2498        assert!(m.colnames().unwrap().is_none());
2499        assert_eq!(m.nrows().unwrap(), 3);
2500        assert_eq!(m.ncols().unwrap(), 2);
2501    }
2502
2503    #[test]
2504    fn test_remove_columns_index_out_of_bounds() {
2505        let mut m = OwnedMatrix::new(
2506            3,
2507            2,
2508            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2509            Some(vec!["a".to_string(), "b".to_string()]),
2510        )
2511        .into_matrix();
2512        let mut removing = HashSet::new();
2513        removing.insert(2);
2514        let m = m.remove_columns(&removing).unwrap_err();
2515        assert!(matches!(m, Error::ColumnIndexOutOfBounds(2)));
2516    }
2517
2518    #[test]
2519    fn test_remove_column_by_name_success() {
2520        let mut m = OwnedMatrix::new(
2521            3,
2522            2,
2523            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2524            Some(vec!["a".to_string(), "b".to_string()]),
2525        )
2526        .into_matrix();
2527        let m = m.t_remove_column_by_name("a");
2528        assert_eq!(m.data().unwrap(), &[4.0, 5.0, 6.0]);
2529        assert_eq!(m.colnames().unwrap().unwrap(), &["b".to_string()]);
2530        assert_eq!(m.nrows().unwrap(), 3);
2531        assert_eq!(m.ncols().unwrap(), 1);
2532    }
2533
2534    #[test]
2535    fn test_remove_column_by_name_column_not_found() {
2536        let mut m = OwnedMatrix::new(
2537            3,
2538            2,
2539            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2540            Some(vec!["a".to_string(), "b".to_string()]),
2541        )
2542        .into_matrix();
2543        let m = m.remove_column_by_name("c").unwrap_err();
2544        assert!(matches!(m, Error::ColumnNameNotFound(_)));
2545    }
2546
2547    #[test]
2548    fn test_remove_column_by_name_if_exists_success() {
2549        let mut m = OwnedMatrix::new(
2550            3,
2551            2,
2552            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2553            Some(vec!["a".to_string(), "b".to_string()]),
2554        )
2555        .into_matrix();
2556        let m = m.t_remove_column_by_name_if_exists("a");
2557        assert_eq!(m.data().unwrap(), &[4.0, 5.0, 6.0]);
2558        assert_eq!(m.colnames().unwrap().unwrap(), &["b".to_string()]);
2559        assert_eq!(m.nrows().unwrap(), 3);
2560        assert_eq!(m.ncols().unwrap(), 1);
2561    }
2562
2563    #[test]
2564    fn test_remove_column_by_name_if_exists_column_not_found() {
2565        let mut m = OwnedMatrix::new(
2566            3,
2567            2,
2568            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2569            Some(vec!["a".to_string(), "b".to_string()]),
2570        )
2571        .into_matrix();
2572        let m = m.t_remove_column_by_name_if_exists("c");
2573        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2574        assert_eq!(
2575            m.colnames().unwrap().unwrap(),
2576            &["a".to_string(), "b".to_string()]
2577        );
2578        assert_eq!(m.nrows().unwrap(), 3);
2579        assert_eq!(m.ncols().unwrap(), 2);
2580    }
2581
2582    #[test]
2583    fn test_remove_columns_by_name_success() {
2584        let mut m = OwnedMatrix::new(
2585            3,
2586            3,
2587            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
2588            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
2589        )
2590        .into_matrix();
2591        let m = m
2592            .t_remove_columns_by_name(HashSet::from_iter(["a", "c"].iter().map(|x| x.to_string())));
2593        assert_eq!(m.data().unwrap(), &[4.0, 5.0, 6.0]);
2594        assert_eq!(m.colnames().unwrap().unwrap(), &["b".to_string()]);
2595        assert_eq!(m.nrows().unwrap(), 3);
2596        assert_eq!(m.ncols().unwrap(), 1);
2597    }
2598
2599    #[test]
2600    fn test_transpose() {
2601        let mut m = OwnedMatrix::new(
2602            3,
2603            2,
2604            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2605            Some(vec!["a".to_string(), "b".to_string()]),
2606        )
2607        .into_matrix();
2608        let m = m.t_transpose();
2609        assert_eq!(m.data().unwrap(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
2610        assert_eq!(m.nrows().unwrap(), 2);
2611        assert_eq!(m.ncols().unwrap(), 3);
2612        assert!(m.colnames().unwrap().is_none(),);
2613    }
2614
2615    #[test]
2616    fn test_sort_by_column_success() {
2617        let mut m = OwnedMatrix::new(
2618            3,
2619            2,
2620            vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0],
2621            Some(vec!["a".to_string(), "b".to_string()]),
2622        )
2623        .into_matrix();
2624        let m = m.sort_by_column(0).unwrap();
2625        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2626        assert_eq!(m.nrows().unwrap(), 3);
2627        assert_eq!(m.ncols().unwrap(), 2);
2628        assert_eq!(
2629            m.colnames().unwrap().unwrap(),
2630            &["a".to_string(), "b".to_string()]
2631        );
2632    }
2633
2634    #[test]
2635    fn test_sort_by_column_already_sorted() {
2636        let mut m = OwnedMatrix::new(
2637            3,
2638            2,
2639            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2640            Some(vec!["a".to_string(), "b".to_string()]),
2641        )
2642        .into_matrix();
2643        let m = m.t_sort_by_column(0);
2644        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2645        assert_eq!(m.nrows().unwrap(), 3);
2646        assert_eq!(m.ncols().unwrap(), 2);
2647        assert_eq!(
2648            m.colnames().unwrap().unwrap(),
2649            &["a".to_string(), "b".to_string()]
2650        );
2651    }
2652
2653    #[test]
2654    fn test_sort_by_column_index_out_of_bounds() {
2655        let mut m = OwnedMatrix::new(
2656            3,
2657            2,
2658            vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0],
2659            Some(vec!["a".to_string(), "b".to_string()]),
2660        )
2661        .into_matrix();
2662        let m = m.sort_by_column(2).unwrap_err();
2663        assert!(matches!(m, Error::ColumnIndexOutOfBounds(2)));
2664    }
2665
2666    #[test]
2667    fn test_sort_by_column_name_success() {
2668        let mut m = OwnedMatrix::new(
2669            3,
2670            2,
2671            vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0],
2672            Some(vec!["a".to_string(), "b".to_string()]),
2673        )
2674        .into_matrix();
2675        let m = m.t_sort_by_column_name("a");
2676        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2677        assert_eq!(m.nrows().unwrap(), 3);
2678        assert_eq!(m.ncols().unwrap(), 2);
2679        assert_eq!(
2680            m.colnames().unwrap().unwrap(),
2681            &["a".to_string(), "b".to_string()]
2682        );
2683    }
2684
2685    #[test]
2686    fn test_sort_by_column_name_no_colnames() {
2687        let mut m = OwnedMatrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], None).into_matrix();
2688        let m = m.sort_by_column_name("a").unwrap_err();
2689        assert!(matches!(m, Error::MissingColumnNames));
2690    }
2691
2692    #[test]
2693    fn test_sort_by_column_name_not_found() {
2694        let mut m = OwnedMatrix::new(
2695            3,
2696            2,
2697            vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0],
2698            Some(vec!["a".to_string(), "b".to_string()]),
2699        )
2700        .into_matrix();
2701        let m = m.sort_by_column_name("c").unwrap_err();
2702        assert!(matches!(m, Error::ColumnNameNotFound(_)));
2703    }
2704
2705    #[test]
2706    fn test_sort_by_order() {
2707        let mut m = OwnedMatrix::new(
2708            3,
2709            2,
2710            vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0],
2711            Some(vec!["a".to_string(), "b".to_string()]),
2712        )
2713        .into_matrix();
2714        let m = m.t_sort_by_order(vec![2, 0, 1]);
2715        assert_eq!(m.data().unwrap(), &[1.0, 3.0, 2.0, 4.0, 6.0, 5.0]);
2716        assert_eq!(m.nrows().unwrap(), 3);
2717        assert_eq!(m.ncols().unwrap(), 2);
2718        assert_eq!(
2719            m.colnames().unwrap().unwrap(),
2720            &["a".to_string(), "b".to_string()]
2721        );
2722    }
2723
2724    #[test]
2725    fn test_sort_by_order_out_of_bounds() {
2726        let mut m = OwnedMatrix::new(
2727            3,
2728            2,
2729            vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0],
2730            Some(vec!["a".to_string(), "b".to_string()]),
2731        )
2732        .into_matrix();
2733        let m = m.sort_by_order(&[2, 0, 3]).unwrap_err();
2734        assert!(matches!(m, Error::RowIndexOutOfBounds(3)));
2735    }
2736
2737    #[test]
2738    fn test_sort_by_order_length_mismatch() {
2739        let mut m = OwnedMatrix::new(
2740            3,
2741            2,
2742            vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0],
2743            Some(vec!["a".to_string(), "b".to_string()]),
2744        )
2745        .into_matrix();
2746        let m = m.sort_by_order(&[2, 0]).unwrap_err();
2747        assert!(matches!(m, Error::OrderLengthMismatch(2)));
2748    }
2749
2750    #[test]
2751    fn test_dedup_by_column_success() {
2752        let mut m = OwnedMatrix::new(
2753            3,
2754            2,
2755            vec![1.0, 2.0, 2.0, 4.0, 5.0, 6.0],
2756            Some(vec!["a".to_string(), "b".to_string()]),
2757        )
2758        .into_matrix();
2759        let m = m.t_dedup_by_column(0);
2760        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 4.0, 5.0]);
2761        assert_eq!(m.nrows().unwrap(), 2);
2762        assert_eq!(m.ncols().unwrap(), 2);
2763        assert_eq!(
2764            m.colnames().unwrap().unwrap(),
2765            &["a".to_string(), "b".to_string()]
2766        );
2767    }
2768
2769    #[test]
2770    fn test_dedup_by_column_index_out_of_bounds() {
2771        let mut m = OwnedMatrix::new(
2772            3,
2773            2,
2774            vec![1.0, 2.0, 2.0, 4.0, 5.0, 6.0],
2775            Some(vec!["a".to_string(), "b".to_string()]),
2776        )
2777        .into_matrix();
2778        let m = m.dedup_by_column(2).unwrap_err();
2779        assert!(matches!(m, Error::ColumnIndexOutOfBounds(2)));
2780    }
2781
2782    #[test]
2783    fn test_dedup_by_column_name_success() {
2784        let mut m = OwnedMatrix::new(
2785            3,
2786            2,
2787            vec![1.0, 2.0, 2.0, 4.0, 5.0, 6.0],
2788            Some(vec!["a".to_string(), "b".to_string()]),
2789        )
2790        .into_matrix();
2791        let m = m.t_dedup_by_column_name("a");
2792        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 4.0, 5.0]);
2793        assert_eq!(m.nrows().unwrap(), 2);
2794        assert_eq!(m.ncols().unwrap(), 2);
2795        assert_eq!(
2796            m.colnames().unwrap().unwrap(),
2797            &["a".to_string(), "b".to_string()]
2798        );
2799    }
2800
2801    #[test]
2802    fn test_dedup_by_column_name_no_colnames() {
2803        let mut m = OwnedMatrix::new(3, 2, vec![1.0, 2.0, 2.0, 4.0, 5.0, 6.0], None).into_matrix();
2804        let m = m.dedup_by_column_name("a").unwrap_err();
2805        assert!(matches!(m, Error::MissingColumnNames));
2806    }
2807
2808    #[test]
2809    fn test_dedup_by_column_name_not_found() {
2810        let mut m = OwnedMatrix::new(
2811            3,
2812            2,
2813            vec![1.0, 2.0, 2.0, 4.0, 5.0, 6.0],
2814            Some(vec!["a".to_string(), "b".to_string()]),
2815        )
2816        .into_matrix();
2817        let m = m.dedup_by_column_name("c").unwrap_err();
2818        assert!(matches!(m, Error::ColumnNameNotFound(_)));
2819    }
2820
2821    #[test]
2822    fn test_match_to_success_inner() {
2823        let mut m1 = OwnedMatrix::new(
2824            5,
2825            2,
2826            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0],
2827            Some(vec!["a".to_string(), "b".to_string()]),
2828        )
2829        .into_matrix();
2830        let other = vec![5.0, 8.0, 1.0, 2.0, 7.0];
2831        let m = m1.t_match_to(other, 0, Join::Inner);
2832        assert_eq!(m.data().unwrap(), &[5.0, 1.0, 2.0, 5.0, 1.0, 2.0]);
2833        assert_eq!(m.nrows().unwrap(), 3);
2834        assert_eq!(m.ncols().unwrap(), 2);
2835        assert_eq!(
2836            m.colnames().unwrap().unwrap(),
2837            &["a".to_string(), "b".to_string()]
2838        );
2839    }
2840
2841    #[test]
2842    fn test_match_to_success_not_sorted() {
2843        let mut m1 = OwnedMatrix::new(
2844            5,
2845            2,
2846            vec![5.0, 2.0, 1.0, 4.0, 3.0, 5.0, 2.0, 1.0, 4.0, 3.0],
2847            Some(vec!["a".to_string(), "b".to_string()]),
2848        )
2849        .into_matrix();
2850        let other = vec![5.0, 1.0, 6.0, 2.0, 3.0, 4.0, 5.0, 7.0];
2851        let m = m1.t_match_to(other, 0, Join::Inner);
2852        assert_eq!(
2853            m.data().unwrap(),
2854            &[5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0]
2855        );
2856        assert_eq!(m.nrows().unwrap(), 5);
2857        assert_eq!(m.ncols().unwrap(), 2);
2858        assert_eq!(
2859            m.colnames().unwrap().unwrap(),
2860            &["a".to_string(), "b".to_string()]
2861        );
2862    }
2863
2864    #[test]
2865    fn test_match_to_success_left() {
2866        let mut m1 = OwnedMatrix::new(
2867            5,
2868            2,
2869            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0],
2870            Some(vec!["a".to_string(), "b".to_string()]),
2871        )
2872        .into_matrix();
2873        let other = vec![5.0, 1.0, 6.0, 2.0, 3.0, 4.0, 5.0, 7.0];
2874        let m = m1.t_match_to(other, 0, Join::Left);
2875        assert_eq!(
2876            m.data().unwrap(),
2877            &[5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0],
2878        );
2879        assert_eq!(m.nrows().unwrap(), 5);
2880        assert_eq!(m.ncols().unwrap(), 2);
2881        assert_eq!(
2882            m.colnames().unwrap().unwrap(),
2883            &["a".to_string(), "b".to_string()]
2884        );
2885    }
2886
2887    #[test]
2888    fn test_match_to_success_right() {
2889        let mut m1 = OwnedMatrix::new(
2890            5,
2891            2,
2892            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0],
2893            Some(vec!["a".to_string(), "b".to_string()]),
2894        )
2895        .into_matrix();
2896        let other = vec![5.0, 1.0, 2.0];
2897        let m = m1.t_match_to(other, 0, Join::Right);
2898        assert_eq!(m.data().unwrap(), &[5.0, 1.0, 2.0, 5.0, 1.0, 2.0],);
2899        assert_eq!(m.nrows().unwrap(), 3);
2900        assert_eq!(m.ncols().unwrap(), 2);
2901        assert_eq!(
2902            m.colnames().unwrap().unwrap(),
2903            &["a".to_string(), "b".to_string()]
2904        );
2905    }
2906
2907    #[test]
2908    fn test_match_to_success_empty_other() {
2909        let mut m1 = OwnedMatrix::new(
2910            5,
2911            2,
2912            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0],
2913            Some(vec!["a".to_string(), "b".to_string()]),
2914        )
2915        .into_matrix();
2916        let other = vec![];
2917        let m = m1.t_match_to(other, 0, Join::Inner);
2918        assert!(m.data().unwrap().is_empty());
2919        assert_eq!(m.nrows().unwrap(), 0);
2920        assert_eq!(m.ncols().unwrap(), 2);
2921        assert_eq!(
2922            m.colnames().unwrap().unwrap(),
2923            &["a".to_string(), "b".to_string()]
2924        );
2925    }
2926
2927    #[test]
2928    fn test_match_to_index_out_of_bounds() {
2929        let mut m1 = OwnedMatrix::new(
2930            5,
2931            2,
2932            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0],
2933            Some(vec!["a".to_string(), "b".to_string()]),
2934        )
2935        .into_matrix();
2936        let other = [5.0, 1.0, 6.0, 2.0, 3.0, 4.0, 5.0, 7.0];
2937        let m = m1.match_to(&other, 2, Join::Left).unwrap_err();
2938        assert!(matches!(m, Error::ColumnIndexOutOfBounds(2)));
2939    }
2940
2941    #[test]
2942    fn test_match_to_not_all_rows_matched_left() {
2943        let mut m1 = OwnedMatrix::new(
2944            5,
2945            2,
2946            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0],
2947            Some(vec!["a".to_string(), "b".to_string()]),
2948        )
2949        .into_matrix();
2950        let other = [5.0, 6.0, 2.0, 3.0, 4.0, 5.0];
2951        let m = m1.match_to(&other, 0, Join::Left).unwrap_err();
2952        assert!(matches!(m, Error::NotAllRowsMatched(Join::Left)));
2953    }
2954
2955    #[test]
2956    fn test_match_to_not_all_rows_matched_right() {
2957        let mut m1 = OwnedMatrix::new(
2958            5,
2959            2,
2960            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0],
2961            Some(vec!["a".to_string(), "b".to_string()]),
2962        )
2963        .into_matrix();
2964        let other = [5.0, 6.0, 2.0, 3.0, 4.0, 5.0];
2965        let m = m1.match_to(&other, 0, Join::Right).unwrap_err();
2966        assert!(matches!(m, Error::NotAllRowsMatched(Join::Right)));
2967    }
2968
2969    #[test]
2970    fn test_match_to_by_column_name_success_inner() {
2971        let mut m1 = OwnedMatrix::new(
2972            5,
2973            2,
2974            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0],
2975            Some(vec!["a".to_string(), "b".to_string()]),
2976        )
2977        .into_matrix();
2978        let other = vec![5.0, 8.0, 1.0, 2.0, 7.0];
2979        let m = m1.t_match_to_by_column_name(other, "a", Join::Inner);
2980        assert_eq!(m.data().unwrap(), &[5.0, 1.0, 2.0, 5.0, 1.0, 2.0]);
2981        assert_eq!(m.nrows().unwrap(), 3);
2982        assert_eq!(m.ncols().unwrap(), 2);
2983        assert_eq!(
2984            m.colnames().unwrap().unwrap(),
2985            &["a".to_string(), "b".to_string()]
2986        );
2987    }
2988
2989    #[test]
2990    fn test_match_to_by_column_name_success_left() {
2991        let mut m1 = OwnedMatrix::new(
2992            5,
2993            2,
2994            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0],
2995            Some(vec!["a".to_string(), "b".to_string()]),
2996        )
2997        .into_matrix();
2998        let other = vec![5.0, 1.0, 6.0, 2.0, 3.0, 4.0, 5.0, 7.0];
2999        let m = m1.t_match_to_by_column_name(other, "a", Join::Left);
3000        assert_eq!(
3001            m.data().unwrap(),
3002            &[5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0],
3003        );
3004        assert_eq!(m.nrows().unwrap(), 5);
3005        assert_eq!(m.ncols().unwrap(), 2);
3006        assert_eq!(
3007            m.colnames().unwrap().unwrap(),
3008            &["a".to_string(), "b".to_string()]
3009        );
3010    }
3011
3012    #[test]
3013    fn test_match_to_by_column_name_success_right() {
3014        let mut m1 = OwnedMatrix::new(
3015            5,
3016            2,
3017            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0],
3018            Some(vec!["a".to_string(), "b".to_string()]),
3019        )
3020        .into_matrix();
3021        let other = vec![5.0, 1.0, 2.0];
3022        let m = m1.t_match_to_by_column_name(other, "a", Join::Right);
3023        assert_eq!(m.data().unwrap(), &[5.0, 1.0, 2.0, 5.0, 1.0, 2.0],);
3024        assert_eq!(m.nrows().unwrap(), 3);
3025        assert_eq!(m.ncols().unwrap(), 2);
3026        assert_eq!(
3027            m.colnames().unwrap().unwrap(),
3028            &["a".to_string(), "b".to_string()]
3029        );
3030    }
3031
3032    #[test]
3033    fn test_match_to_by_column_name_success_empty_other() {
3034        let mut m1 = OwnedMatrix::new(
3035            5,
3036            2,
3037            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0],
3038            Some(vec!["a".to_string(), "b".to_string()]),
3039        )
3040        .into_matrix();
3041        let other = vec![];
3042        let m = m1.t_match_to_by_column_name(other, "a", Join::Inner);
3043        assert!(m.data().unwrap().is_empty());
3044        assert_eq!(m.nrows().unwrap(), 0);
3045        assert_eq!(m.ncols().unwrap(), 2);
3046        assert_eq!(
3047            m.colnames().unwrap().unwrap(),
3048            &["a".to_string(), "b".to_string()]
3049        );
3050    }
3051
3052    #[test]
3053    fn test_match_to_by_column_name_column_name_not_found() {
3054        let mut m1 = OwnedMatrix::new(
3055            5,
3056            2,
3057            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0],
3058            Some(vec!["a".to_string(), "b".to_string()]),
3059        )
3060        .into_matrix();
3061        let other = [5.0, 8.0, 1.0, 2.0, 7.0];
3062        let m = m1
3063            .match_to_by_column_name(&other, "c", Join::Inner)
3064            .unwrap_err();
3065        assert!(matches!(m, Error::ColumnNameNotFound(_)));
3066    }
3067
3068    #[test]
3069    fn test_match_to_by_column_name_no_colnames() {
3070        let mut m1 = OwnedMatrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], None).into_matrix();
3071        let other = [5.0, 8.0, 1.0, 2.0, 7.0];
3072        let m = m1
3073            .match_to_by_column_name(&other, "a", Join::Inner)
3074            .unwrap_err();
3075        assert!(matches!(m, Error::MissingColumnNames));
3076    }
3077
3078    #[test]
3079    fn test_join_success_inner() {
3080        let mut m1 = OwnedMatrix::new(
3081            5,
3082            2,
3083            vec![6.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0],
3084            Some(vec!["a".to_string(), "b".to_string()]),
3085        )
3086        .into_matrix();
3087        let mut m2 = OwnedMatrix::new(
3088            3,
3089            2,
3090            vec![5.0, 6.0, 2.0, 7.0, 3.0, 4.0],
3091            Some(vec!["a".to_string(), "c".to_string()]),
3092        )
3093        .into_matrix();
3094        let m = m1.t_join(m2, 0, 0, Join::Inner);
3095        assert_eq!(
3096            m.data().unwrap(),
3097            &[6.0, 2.0, 5.0, 6.0, 2.0, 5.0, 3.0, 4.0, 7.0]
3098        );
3099        assert_eq!(m.nrows().unwrap(), 3);
3100        assert_eq!(m.ncols().unwrap(), 3);
3101        assert_eq!(
3102            m.colnames().unwrap().unwrap(),
3103            &["a".to_string(), "b".to_string(), "c".to_string()]
3104        );
3105    }
3106
3107    #[test]
3108    fn test_join_success_left() {
3109        let mut m1 = OwnedMatrix::new(
3110            5,
3111            2,
3112            vec![6.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0],
3113            Some(vec!["a".to_string(), "b".to_string()]),
3114        )
3115        .into_matrix();
3116        let mut m2 = OwnedMatrix::new(
3117            7,
3118            2,
3119            vec![
3120                5.0, 6.0, 2.0, 7.0, 3.0, 4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 4.0, 1.0,
3121            ],
3122            Some(vec!["a".to_string(), "c".to_string()]),
3123        )
3124        .into_matrix();
3125        let m = m1.t_join(m2, 0, 0, Join::Left);
3126        assert_eq!(
3127            m.data().unwrap(),
3128            &[6.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0]
3129        );
3130        assert_eq!(m.nrows().unwrap(), 5);
3131        assert_eq!(m.ncols().unwrap(), 3);
3132        assert_eq!(
3133            m.colnames().unwrap().unwrap(),
3134            &["a".to_string(), "b".to_string(), "c".to_string()]
3135        );
3136    }
3137
3138    #[test]
3139    fn test_join_success_right() {
3140        let mut m1 = OwnedMatrix::new(
3141            5,
3142            2,
3143            vec![6.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0],
3144            Some(vec!["a".to_string(), "b".to_string()]),
3145        )
3146        .into_matrix();
3147        let mut m2 = OwnedMatrix::new(
3148            3,
3149            2,
3150            vec![5.0, 6.0, 2.0, 7.0, 3.0, 4.0],
3151            Some(vec!["a".to_string(), "c".to_string()]),
3152        )
3153        .into_matrix();
3154        let m = m1.t_join(m2, 0, 0, Join::Right);
3155        assert_eq!(
3156            m.data().unwrap(),
3157            &[6.0, 2.0, 5.0, 6.0, 2.0, 5.0, 3.0, 4.0, 7.0]
3158        );
3159        assert_eq!(m.nrows().unwrap(), 3);
3160        assert_eq!(m.ncols().unwrap(), 3);
3161        assert_eq!(
3162            m.colnames().unwrap().unwrap(),
3163            &["a".to_string(), "b".to_string(), "c".to_string()]
3164        );
3165    }
3166
3167    #[test]
3168    fn test_join_index_out_of_bounds() {
3169        let mut m1 = OwnedMatrix::new(
3170            5,
3171            2,
3172            vec![6.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0],
3173            Some(vec!["a".to_string(), "b".to_string()]),
3174        )
3175        .into_matrix();
3176        let mut m2 = OwnedMatrix::new(
3177            3,
3178            2,
3179            vec![5.0, 6.0, 2.0, 7.0, 3.0, 4.0],
3180            Some(vec!["a".to_string(), "c".to_string()]),
3181        )
3182        .into_matrix();
3183        let m = m1.join(&mut m2, 0, 2, Join::Inner).unwrap_err();
3184        assert!(matches!(m, Error::ColumnIndexOutOfBounds(2)));
3185    }
3186
3187    #[test]
3188    fn test_join_not_all_rows_matched_left() {
3189        let mut m1 = OwnedMatrix::new(
3190            5,
3191            2,
3192            vec![8.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0],
3193            Some(vec!["a".to_string(), "b".to_string()]),
3194        )
3195        .into_matrix();
3196        let mut m2 = OwnedMatrix::new(
3197            7,
3198            2,
3199            vec![
3200                5.0, 6.0, 2.0, 7.0, 3.0, 4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 4.0, 1.0,
3201            ],
3202            Some(vec!["a".to_string(), "c".to_string()]),
3203        )
3204        .into_matrix();
3205        let m = m1.join(&mut m2, 0, 0, Join::Left).unwrap_err();
3206        assert!(matches!(m, Error::NotAllRowsMatched(Join::Left)));
3207    }
3208
3209    #[test]
3210    fn test_join_not_all_rows_matched_right() {
3211        let mut m1 = OwnedMatrix::new(
3212            5,
3213            2,
3214            vec![6.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0],
3215            Some(vec!["a".to_string(), "b".to_string()]),
3216        )
3217        .into_matrix();
3218        let mut m2 = OwnedMatrix::new(
3219            3,
3220            2,
3221            vec![8.0, 6.0, 2.0, 7.0, 3.0, 4.0],
3222            Some(vec!["a".to_string(), "c".to_string()]),
3223        )
3224        .into_matrix();
3225        let m = m1.join(&mut m2, 0, 0, Join::Right).unwrap_err();
3226        assert!(matches!(m, Error::NotAllRowsMatched(Join::Right)));
3227    }
3228
3229    #[test]
3230    fn test_join_by_column_name_success_inner() {
3231        let mut m1 = OwnedMatrix::new(
3232            5,
3233            2,
3234            vec![6.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0],
3235            Some(vec!["a".to_string(), "b".to_string()]),
3236        )
3237        .into_matrix();
3238        let mut m2 = OwnedMatrix::new(
3239            3,
3240            2,
3241            vec![5.0, 6.0, 2.0, 7.0, 3.0, 4.0],
3242            Some(vec!["a".to_string(), "c".to_string()]),
3243        )
3244        .into_matrix();
3245        let m = m1.t_join_by_column_name(m2, "a", Join::Inner);
3246        assert_eq!(
3247            m.data().unwrap(),
3248            &[6.0, 2.0, 5.0, 6.0, 2.0, 5.0, 3.0, 4.0, 7.0]
3249        );
3250        assert_eq!(m.nrows().unwrap(), 3);
3251        assert_eq!(m.ncols().unwrap(), 3);
3252        assert_eq!(
3253            m.colnames().unwrap().unwrap(),
3254            &["a".to_string(), "b".to_string(), "c".to_string()]
3255        );
3256    }
3257
3258    #[test]
3259    fn test_join_by_column_name_success_left() {
3260        let mut m1 = OwnedMatrix::new(
3261            5,
3262            2,
3263            vec![6.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0],
3264            Some(vec!["a".to_string(), "b".to_string()]),
3265        )
3266        .into_matrix();
3267        let mut m2 = OwnedMatrix::new(
3268            7,
3269            2,
3270            vec![
3271                5.0, 6.0, 2.0, 7.0, 3.0, 4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 4.0, 1.0,
3272            ],
3273            Some(vec!["a".to_string(), "c".to_string()]),
3274        )
3275        .into_matrix();
3276        let m = m1.t_join_by_column_name(m2, "a", Join::Left);
3277        assert_eq!(
3278            m.data().unwrap(),
3279            &[6.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0]
3280        );
3281        assert_eq!(m.nrows().unwrap(), 5);
3282        assert_eq!(m.ncols().unwrap(), 3);
3283        assert_eq!(
3284            m.colnames().unwrap().unwrap(),
3285            &["a".to_string(), "b".to_string(), "c".to_string()]
3286        );
3287    }
3288
3289    #[test]
3290    fn test_join_by_column_name_success_right() {
3291        let mut m1 = OwnedMatrix::new(
3292            5,
3293            2,
3294            vec![6.0, 2.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 4.0, 5.0],
3295            Some(vec!["a".to_string(), "b".to_string()]),
3296        )
3297        .into_matrix();
3298        let mut m2 = OwnedMatrix::new(
3299            3,
3300            2,
3301            vec![5.0, 6.0, 2.0, 7.0, 3.0, 4.0],
3302            Some(vec!["a".to_string(), "c".to_string()]),
3303        )
3304        .into_matrix();
3305        let m = m1.t_join_by_column_name(m2, "a", Join::Right);
3306        assert_eq!(
3307            m.data().unwrap(),
3308            &[6.0, 2.0, 5.0, 6.0, 2.0, 5.0, 3.0, 4.0, 7.0]
3309        );
3310        assert_eq!(m.nrows().unwrap(), 3);
3311        assert_eq!(m.ncols().unwrap(), 3);
3312        assert_eq!(
3313            m.colnames().unwrap().unwrap(),
3314            &["a".to_string(), "b".to_string(), "c".to_string()]
3315        );
3316    }
3317
3318    #[test]
3319    fn test_standardize_columns() {
3320        let mut m = OwnedMatrix::new(
3321            3,
3322            2,
3323            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3324            Some(vec!["a".to_string(), "b".to_string()]),
3325        )
3326        .into_matrix();
3327        let m = m.t_standardize_columns();
3328        assert_eq!(m.data().unwrap(), &[-1.0, 0.0, 1.0, -1.0, 0.0, 1.0]);
3329        assert_eq!(m.nrows().unwrap(), 3);
3330        assert_eq!(m.ncols().unwrap(), 2);
3331        assert_eq!(
3332            m.colnames().unwrap().unwrap(),
3333            &["a".to_string(), "b".to_string()]
3334        );
3335    }
3336
3337    #[test]
3338    fn test_standardize_rows() {
3339        let mut m = OwnedMatrix::new(
3340            2,
3341            3,
3342            vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0],
3343            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3344        )
3345        .into_matrix();
3346        let m = m.t_standardize_rows();
3347        assert_eq!(m.data().unwrap(), &[-1.0, -1.0, 0.0, 0.0, 1.0, 1.0]);
3348        assert_eq!(m.nrows().unwrap(), 2);
3349        assert_eq!(m.ncols().unwrap(), 3);
3350        assert_eq!(
3351            m.colnames().unwrap().unwrap(),
3352            &["a".to_string(), "b".to_string(), "c".to_string()]
3353        );
3354    }
3355
3356    #[test]
3357    fn test_remove_nan_rows() {
3358        let mut m = OwnedMatrix::new(
3359            3,
3360            2,
3361            vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0],
3362            Some(vec!["a".to_string(), "b".to_string()]),
3363        )
3364        .into_matrix();
3365        let m = m.t_remove_nan_rows();
3366        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 4.0, 5.0]);
3367        assert_eq!(m.nrows().unwrap(), 2);
3368        assert_eq!(m.ncols().unwrap(), 2);
3369        assert_eq!(
3370            m.colnames().unwrap().unwrap(),
3371            &["a".to_string(), "b".to_string()]
3372        );
3373    }
3374
3375    #[test]
3376    fn test_remove_nan_columns() {
3377        let mut m = OwnedMatrix::new(
3378            3,
3379            2,
3380            vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0],
3381            Some(vec!["a".to_string(), "b".to_string()]),
3382        )
3383        .into_matrix();
3384        let m = m.t_remove_nan_columns();
3385        assert_eq!(m.data().unwrap(), &[4.0, 5.0, 6.0]);
3386        assert_eq!(m.nrows().unwrap(), 3);
3387        assert_eq!(m.ncols().unwrap(), 1);
3388        assert_eq!(m.colnames().unwrap().unwrap(), &["b".to_string()]);
3389    }
3390
3391    #[test]
3392    fn test_nan_to_value() {
3393        let mut m = OwnedMatrix::new(
3394            3,
3395            2,
3396            vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0],
3397            Some(vec!["a".to_string(), "b".to_string()]),
3398        )
3399        .into_matrix();
3400        let m = m.t_nan_to_value(0.0);
3401        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 0.0, 4.0, 5.0, 6.0]);
3402        assert_eq!(m.nrows().unwrap(), 3);
3403        assert_eq!(m.ncols().unwrap(), 2);
3404        assert_eq!(
3405            m.colnames().unwrap().unwrap(),
3406            &["a".to_string(), "b".to_string()]
3407        );
3408    }
3409
3410    #[test]
3411    fn test_nan_to_column_mean() {
3412        let mut m = OwnedMatrix::new(
3413            3,
3414            2,
3415            vec![1.0, 2.0, f64::NAN, 4.0, f64::NAN, 6.0],
3416            Some(vec!["a".to_string(), "b".to_string()]),
3417        )
3418        .into_matrix();
3419        let m = m.t_nan_to_column_mean();
3420        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 1.5, 4.0, 5.0, 6.0]);
3421        assert_eq!(m.nrows().unwrap(), 3);
3422        assert_eq!(m.ncols().unwrap(), 2);
3423        assert_eq!(
3424            m.colnames().unwrap().unwrap(),
3425            &["a".to_string(), "b".to_string()]
3426        );
3427    }
3428
3429    #[test]
3430    fn test_nan_to_column_mean_all_nan() {
3431        let mut m = OwnedMatrix::new(3, 2, vec![f64::NAN; 6], None).into_matrix();
3432        let m = m.t_nan_to_column_mean();
3433        assert_eq!(m.data().unwrap(), &[0.0; 6]);
3434        assert_eq!(m.nrows().unwrap(), 3);
3435        assert_eq!(m.ncols().unwrap(), 2);
3436        assert!(m.colnames().unwrap().is_none());
3437    }
3438
3439    #[test]
3440    fn test_nan_to_row_mean() {
3441        let mut m = OwnedMatrix::new(
3442            3,
3443            4,
3444            vec![
3445                1.0,
3446                2.0,
3447                f64::NAN,
3448                4.0,
3449                5.0,
3450                6.0,
3451                7.0,
3452                f64::NAN,
3453                9.0,
3454                11.0,
3455                11.0,
3456                12.0,
3457            ],
3458            Some(vec![
3459                "a".to_string(),
3460                "b".to_string(),
3461                "c".to_string(),
3462                "d".to_string(),
3463            ]),
3464        )
3465        .into_matrix();
3466        let m = m.t_nan_to_row_mean();
3467        assert_eq!(
3468            m.data().unwrap(),
3469            &[1.0, 2.0, 9.0, 4.0, 5.0, 6.0, 7.0, 6.0, 9.0, 11.0, 11.0, 12.0]
3470        );
3471        assert_eq!(m.nrows().unwrap(), 3);
3472        assert_eq!(m.ncols().unwrap(), 4);
3473        assert_eq!(
3474            m.colnames().unwrap().unwrap(),
3475            &[
3476                "a".to_string(),
3477                "b".to_string(),
3478                "c".to_string(),
3479                "d".to_string()
3480            ]
3481        );
3482    }
3483
3484    #[test]
3485    fn test_min_column_sum() {
3486        let mut m = OwnedMatrix::new(
3487            3,
3488            2,
3489            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3490            Some(vec!["a".to_string(), "b".to_string()]),
3491        )
3492        .into_matrix();
3493        let m = m.t_min_column_sum(7.0);
3494        assert_eq!(m.data().unwrap(), &[4.0, 5.0, 6.0]);
3495        assert_eq!(m.nrows().unwrap(), 3);
3496        assert_eq!(m.ncols().unwrap(), 1);
3497        assert_eq!(m.colnames().unwrap().unwrap(), &["b".to_string()]);
3498    }
3499
3500    #[test]
3501    fn test_max_column_sum() {
3502        let mut m = OwnedMatrix::new(
3503            3,
3504            2,
3505            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3506            Some(vec!["a".to_string(), "b".to_string()]),
3507        )
3508        .into_matrix();
3509        let m = m.t_max_column_sum(7.0);
3510        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 3.0]);
3511        assert_eq!(m.nrows().unwrap(), 3);
3512        assert_eq!(m.ncols().unwrap(), 1);
3513        assert_eq!(m.colnames().unwrap().unwrap(), &["a".to_string()]);
3514    }
3515
3516    #[test]
3517    fn test_min_row_sum() {
3518        let mut m = OwnedMatrix::new(
3519            3,
3520            2,
3521            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3522            Some(vec!["a".to_string(), "b".to_string()]),
3523        )
3524        .into_matrix();
3525        let m = m.t_min_row_sum(7.0);
3526        assert_eq!(m.data().unwrap(), &[2.0, 3.0, 5.0, 6.0]);
3527        assert_eq!(m.nrows().unwrap(), 2);
3528        assert_eq!(m.ncols().unwrap(), 2);
3529        assert_eq!(
3530            m.colnames().unwrap().unwrap(),
3531            &["a".to_string(), "b".to_string()]
3532        );
3533    }
3534
3535    #[test]
3536    fn test_max_row_sum() {
3537        let mut m = OwnedMatrix::new(
3538            3,
3539            2,
3540            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3541            Some(vec!["a".to_string(), "b".to_string()]),
3542        )
3543        .into_matrix();
3544        let m = m.t_max_row_sum(7.0);
3545        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 4.0, 5.0]);
3546        assert_eq!(m.nrows().unwrap(), 2);
3547        assert_eq!(m.ncols().unwrap(), 2);
3548        assert_eq!(
3549            m.colnames().unwrap().unwrap(),
3550            &["a".to_string(), "b".to_string()]
3551        );
3552    }
3553
3554    #[test]
3555    fn test_remove_duplicate_columns() {
3556        let mut m = OwnedMatrix::new(
3557            3,
3558            3,
3559            vec![1.0, 2.0, 1.0, 4.0, 5.0, 6.0, 1.0, 2.0, 1.0],
3560            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3561        )
3562        .into_matrix();
3563        let m = m.t_remove_duplicate_columns();
3564        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 1.0, 4.0, 5.0, 6.0]);
3565        assert_eq!(m.nrows().unwrap(), 3);
3566        assert_eq!(m.ncols().unwrap(), 2);
3567        assert_eq!(
3568            m.colnames().unwrap().unwrap(),
3569            &["a".to_string(), "b".to_string()]
3570        );
3571    }
3572
3573    #[test]
3574    fn test_remove_identical_columns() {
3575        let mut m = OwnedMatrix::new(
3576            3,
3577            3,
3578            vec![1.0, 1.0, 1.0, 4.0, 5.0, 6.0, 1.0, 2.0, 1.0],
3579            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3580        )
3581        .into_matrix();
3582        let m = m.t_remove_identical_columns();
3583        assert_eq!(m.data().unwrap(), &[4.0, 5.0, 6.0, 1.0, 2.0, 1.0]);
3584        assert_eq!(m.nrows().unwrap(), 3);
3585        assert_eq!(m.ncols().unwrap(), 2);
3586        assert_eq!(
3587            m.colnames().unwrap().unwrap(),
3588            &["b".to_string(), "c".to_string()]
3589        );
3590    }
3591
3592    #[test]
3593    fn test_min_non_nan() {
3594        let mut m = OwnedMatrix::new(
3595            3,
3596            2,
3597            vec![1.0, f64::NAN, f64::NAN, 4.0, 5.0, 6.0],
3598            Some(vec!["a".to_string(), "b".to_string()]),
3599        )
3600        .into_matrix();
3601        let m = m.t_min_non_nan(2);
3602        assert_eq!(m.data().unwrap(), &[4.0, 5.0, 6.0]);
3603        assert_eq!(m.nrows().unwrap(), 3);
3604        assert_eq!(m.ncols().unwrap(), 1);
3605        assert_eq!(m.colnames().unwrap().unwrap(), &["b".to_string()]);
3606    }
3607
3608    #[test]
3609    fn test_max_non_nan() {
3610        let mut m = OwnedMatrix::new(
3611            3,
3612            2,
3613            vec![1.0, f64::NAN, f64::NAN, 4.0, 5.0, 6.0],
3614            Some(vec!["a".to_string(), "b".to_string()]),
3615        )
3616        .into_matrix();
3617        let m = m.t_max_non_nan(2);
3618        assert_eq!(m.nrows().unwrap(), 3);
3619        assert_eq!(m.ncols().unwrap(), 1);
3620        assert_eq!(m.colnames().unwrap().unwrap(), &["a".to_string()]);
3621    }
3622
3623    #[test]
3624    fn test_subset_columns() {
3625        let mut m = OwnedMatrix::new(
3626            3,
3627            3,
3628            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
3629            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3630        )
3631        .into_matrix();
3632        let m = m.t_subset_columns([0, 2].into_iter().collect::<HashSet<_>>());
3633        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 3.0, 7.0, 8.0, 9.0]);
3634        assert_eq!(m.nrows().unwrap(), 3);
3635        assert_eq!(m.ncols().unwrap(), 2);
3636        assert_eq!(
3637            m.colnames().unwrap().unwrap(),
3638            &["a".to_string(), "c".to_string()]
3639        );
3640    }
3641
3642    #[test]
3643    fn test_subset_columns_by_name() {
3644        let mut m = OwnedMatrix::new(
3645            3,
3646            3,
3647            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
3648            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3649        )
3650        .into_matrix();
3651        let m = m.t_subset_columns_by_name(["a", "c", "d"].iter().map(|s| s.to_string()).collect());
3652        assert_eq!(m.data().unwrap(), &[1.0, 2.0, 3.0, 7.0, 8.0, 9.0]);
3653        assert_eq!(m.nrows().unwrap(), 3);
3654        assert_eq!(m.ncols().unwrap(), 2);
3655        assert_eq!(
3656            m.colnames().unwrap().unwrap(),
3657            &["a".to_string(), "c".to_string()]
3658        );
3659    }
3660
3661    #[test]
3662    fn test_rename_columns_with_regex() {
3663        let mut m = OwnedMatrix::new(
3664            3,
3665            3,
3666            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
3667            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3668        )
3669        .into_matrix();
3670        let m = m.t_rename_columns_with_regex("a", "x");
3671        assert_eq!(
3672            m.data().unwrap(),
3673            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
3674        );
3675        assert_eq!(m.nrows().unwrap(), 3);
3676        assert_eq!(m.ncols().unwrap(), 3);
3677        assert_eq!(
3678            m.colnames().unwrap().unwrap(),
3679            &["x".to_string(), "b".to_string(), "c".to_string()]
3680        );
3681    }
3682
3683    #[test]
3684    fn test_eigen_symmetric_real() {
3685        let mut m = OwnedMatrix::new(
3686            3,
3687            3,
3688            vec![1.0, 2.0, 3.0, 2.0, 5.0, 6.0, 3.0, 6.0, 9.0],
3689            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3690        )
3691        .into_matrix();
3692        let Eigen::Real { values, vectors } = m.eigen(None).unwrap() else {
3693            panic!("Expected real decomposition");
3694        };
3695        assert_eq!(values.len(), 3);
3696        float_eq!(values[0], -4.3296658397194226e-16);
3697        float_eq!(values[1], 0.6992647456322797);
3698        float_eq!(values[2], 14.300735254367696);
3699        assert_eq!(vectors.len(), 9);
3700        let expected = [
3701            0.9486832980505138,
3702            -1.3877787807814457e-15,
3703            -0.3162277660168371,
3704            0.17781910596911185,
3705            -0.8269242138935418,
3706            0.5334573179073411,
3707            0.26149639682478465,
3708            0.5623133863572407,
3709            0.7844891904743533,
3710        ];
3711        for (i, &v) in vectors.iter().enumerate() {
3712            float_eq!(v, expected[i]);
3713        }
3714    }
3715
3716    #[test]
3717    fn test_eigen_not_symmetric_real() {
3718        let mut m = OwnedMatrix::new(
3719            3,
3720            3,
3721            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
3722            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3723        )
3724        .into_matrix();
3725        let Eigen::Real { values, vectors } = m.eigen(None).unwrap() else {
3726            panic!("Expected real decomposition");
3727        };
3728        assert_eq!(values.len(), 3);
3729        float_eq!(values[0], 16.116843969807025);
3730        float_eq!(values[1], -1.116843969807056);
3731        float_eq!(values[2], 0.0);
3732        assert_eq!(vectors.len(), 9);
3733        let expected = [
3734            -0.46454727338767027,
3735            -0.5707955312285774,
3736            -0.6770437890694855,
3737            -0.9178859873651294,
3738            -0.24901002745731335,
3739            0.4198659324505014,
3740            0.4082482904638624,
3741            -0.8164965809277261,
3742            0.40824829046386313,
3743        ];
3744        for (i, &v) in vectors.iter().enumerate() {
3745            float_eq!(v, expected[i]);
3746        }
3747    }
3748
3749    #[test]
3750    fn test_is_symmetric() {
3751        let mut m = OwnedMatrix::new(
3752            3,
3753            3,
3754            vec![1.0, 2.0, 3.0, 2.0, 5.0, 6.0, 3.0, 6.0, 9.0],
3755            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3756        )
3757        .into_matrix();
3758        assert!(m.is_symmetric().unwrap());
3759    }
3760
3761    #[test]
3762    fn test_is_not_symmetric() {
3763        let mut m = OwnedMatrix::new(
3764            3,
3765            3,
3766            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
3767            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3768        )
3769        .into_matrix();
3770        assert!(!m.is_symmetric().unwrap());
3771    }
3772
3773    #[test]
3774    fn test_scale_columns_scalar() {
3775        let mut m = OwnedMatrix::new(
3776            3,
3777            3,
3778            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
3779            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3780        )
3781        .into_matrix();
3782        let m = m.t_scale_columns(vec![2.0]);
3783        assert_eq!(
3784            m.data().unwrap(),
3785            &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0]
3786        );
3787        assert_eq!(m.nrows().unwrap(), 3);
3788        assert_eq!(m.ncols().unwrap(), 3);
3789        assert_eq!(
3790            m.colnames().unwrap().unwrap(),
3791            &["a".to_string(), "b".to_string(), "c".to_string()]
3792        );
3793    }
3794
3795    #[test]
3796    fn test_scale_columns_vector() {
3797        let mut m = OwnedMatrix::new(
3798            3,
3799            3,
3800            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
3801            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3802        )
3803        .into_matrix();
3804        let m = m.t_scale_columns(vec![2.0, 3.0, 4.0]);
3805        assert_eq!(
3806            m.data().unwrap(),
3807            &[2.0, 4.0, 6.0, 12.0, 15.0, 18.0, 28.0, 32.0, 36.0]
3808        );
3809        assert_eq!(m.nrows().unwrap(), 3);
3810        assert_eq!(m.ncols().unwrap(), 3);
3811        assert_eq!(
3812            m.colnames().unwrap().unwrap(),
3813            &["a".to_string(), "b".to_string(), "c".to_string()]
3814        );
3815    }
3816
3817    #[test]
3818    fn test_scale_rows_scalar() {
3819        let mut m = OwnedMatrix::new(
3820            3,
3821            3,
3822            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
3823            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3824        )
3825        .into_matrix();
3826        let m = m.t_scale_rows(vec![2.0]);
3827        assert_eq!(
3828            m.data().unwrap(),
3829            &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0]
3830        );
3831        assert_eq!(m.nrows().unwrap(), 3);
3832        assert_eq!(m.ncols().unwrap(), 3);
3833        assert_eq!(
3834            m.colnames().unwrap().unwrap(),
3835            &["a".to_string(), "b".to_string(), "c".to_string()]
3836        );
3837    }
3838
3839    #[test]
3840    fn test_scale_rows_vector() {
3841        let mut m = OwnedMatrix::new(
3842            3,
3843            3,
3844            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
3845            Some(vec!["a".to_string(), "b".to_string(), "c".to_string()]),
3846        )
3847        .into_matrix();
3848        let m = m.t_scale_rows(vec![2.0, 3.0, 4.0]);
3849        assert_eq!(
3850            m.data().unwrap(),
3851            &[2.0, 6.0, 12.0, 8.0, 15.0, 24.0, 14.0, 24.0, 36.0]
3852        );
3853        assert_eq!(m.nrows().unwrap(), 3);
3854        assert_eq!(m.ncols().unwrap(), 3);
3855        assert_eq!(
3856            m.colnames().unwrap().unwrap(),
3857            &["a".to_string(), "b".to_string(), "c".to_string()]
3858        );
3859    }
3860}