use crate::num::{One, Zero};
use crate::Matrix;
use core::mem;
use core::mem::MaybeUninit;
use core::ptr;
impl<const M: usize, const N: usize, T> Matrix<M, N, T> {
#[doc(hidden)]
#[inline]
pub const fn from_column_major_order(data: [[T; M]; N]) -> Self {
Self { data }
}
}
impl<const M: usize, const N: usize, T> Matrix<M, N, T>
where
    T: Zero + Copy,
{
    /// Returns an `M`×`N` matrix with every entry equal to `T::zero()`.
    #[doc(hidden)]
    #[inline]
    pub fn zeros() -> Self {
        let zero = T::zero();
        let column = [zero; M];
        Self::from_column_major_order([column; N])
    }
}
impl<const M: usize, const N: usize, T> Matrix<M, N, T>
where
    T: One + Copy,
{
    /// Returns an `M`×`N` matrix with every entry equal to `T::one()`.
    #[doc(hidden)]
    #[inline]
    pub fn ones() -> Self {
        let one = T::one();
        let column = [one; M];
        Self::from_column_major_order([column; N])
    }
}
impl<const D: usize, T> Matrix<D, D, T>
where
    T: Zero + One + Copy,
{
    /// Returns the `D`×`D` identity matrix.
    ///
    /// Each `(k, k)` entry is `T::one()`, and every other entry is `T::zero()`.
    #[doc(hidden)]
    #[inline]
    pub fn eye() -> Self {
        let mut identity = Self::zeros();
        (0..D).for_each(|k| identity[(k, k)] = T::one());
        identity
    }
}
/// Builds a `Matrix` from a literal written row by row.
///
/// Every token is forwarded to `proc_macro::matrix!`. That macro must expand to
/// the column-major `[[T; M]; N]` array expected by
/// `Matrix::from_column_major_order`. In this file's tests, entries within a row
/// are separated by `,` and each row ends with `;`, e.g.
/// `matrix![1.0, 0.0; 0.0, 1.0;]`.
#[macro_export]
macro_rules! matrix {
    ($($tokens:tt)*) => {
        $crate::Matrix::from_column_major_order(
            $crate::proc_macro::matrix!($($tokens)*)
        )
    };
}
/// Builds a vector-shaped `Matrix` from a literal.
///
/// The expansion is identical to `matrix!`: all tokens go to
/// `proc_macro::matrix!`, and its column-major array is wrapped with
/// `Matrix::from_column_major_order`.
///
/// NOTE(review): elsewhere in this file a `Vector` is the `N == 1` (column)
/// shape. Confirm that `proc_macro::matrix!` yields an `M`×1 array for inputs
/// like `vector![1.0, 2.0, 3.0]`, rather than a 1×`N` row.
#[macro_export]
macro_rules! vector {
    ($($tokens:tt)*) => {
        $crate::Matrix::from_column_major_order(
            $crate::proc_macro::matrix!($($tokens)*)
        )
    };
}
/// Shorthand for `Matrix::zeros()` with the shape written inline.
///
/// - `zeros!(n)` gives an `n`×`n` matrix.
/// - `zeros!(m, n)` gives an `m`×`n` matrix.
/// - `zeros!(m, n, T)` gives an `m`×`n` matrix with element type `T`.
///
/// Without an explicit type, the element type is left to `Matrix`'s own
/// default type parameter.
#[macro_export]
macro_rules! zeros {
    ($n:expr) => {
        $crate::Matrix::<$n, $n>::zeros()
    };
    ($m:expr, $n:expr) => {{
        $crate::Matrix::<$m, $n>::zeros()
    }};
    ($m:expr, $n:expr, $elem:ty) => {{
        $crate::Matrix::<$m, $n, $elem>::zeros()
    }};
}
/// Shorthand for `Matrix::ones()` with the shape written inline.
///
/// - `ones!(n)` gives an `n`×`n` matrix.
/// - `ones!(m, n)` gives an `m`×`n` matrix.
/// - `ones!(m, n, T)` gives an `m`×`n` matrix with element type `T`.
///
/// Without an explicit type, the element type is left to `Matrix`'s own
/// default type parameter.
#[macro_export]
macro_rules! ones {
    ($n:expr) => {
        $crate::Matrix::<$n, $n>::ones()
    };
    ($m:expr, $n:expr) => {{
        $crate::Matrix::<$m, $n>::ones()
    }};
    ($m:expr, $n:expr, $elem:ty) => {{
        $crate::Matrix::<$m, $n, $elem>::ones()
    }};
}
/// Shorthand for `Matrix::eye()`, the square identity matrix.
///
/// - `eye!(n)` gives the `n`×`n` identity with `Matrix`'s default element type.
/// - `eye!(n, T)` gives the `n`×`n` identity with element type `T`.
#[macro_export]
macro_rules! eye {
    ($n:expr) => {
        $crate::Matrix::<$n, $n>::eye()
    };
    ($n:expr, $elem:ty) => {{
        $crate::Matrix::<$n, $n, $elem>::eye()
    }};
}
/// Builds a square diagonal matrix from its diagonal entries.
///
/// `diag!(a, b, c)` produces a 3×3 matrix whose `(i, i)` entry is the `i`-th
/// argument. Every other entry comes from `Matrix::zeros()`. Any positive
/// number of entries is accepted, and a trailing comma is allowed. As before,
/// the element type is `Matrix`'s default, because the matrix is created with
/// `$crate::Matrix::<D, D>::zeros()`.
#[macro_export]
macro_rules! diag {
    ($($d:expr),+ $(,)?) => {{
        // The dimension is the number of arguments. Each argument is turned
        // into a string with `stringify!` so it can be counted in a constant
        // without being evaluated.
        const __DIAG_DIM: usize = [$(stringify!($d)),+].len();
        let diagonal = [$($d),+];
        let mut m = $crate::Matrix::<__DIAG_DIM, __DIAG_DIM>::zeros();
        for i in 0..__DIAG_DIM {
            m[(i, i)] = diagonal[i];
        }
        m
    }};
}
/// Reinterprets the bytes of `a` as a value of type `B`, consuming `a` without
/// running its destructor.
///
/// Unlike `core::mem::transmute`, the compiler is not asked to prove that `A`
/// and `B` have the same size. This makes the function usable when sizes
/// depend on generic parameters, e.g. `Matrix<M, N, MaybeUninit<T>>` to
/// `Matrix<M, N, T>`. Ownership of whatever `a` owned moves to the returned
/// `B`.
///
/// # Safety
///
/// - `size_of::<B>()` must not exceed `size_of::<A>()`. This is checked with
///   `debug_assert!` in debug builds only.
/// - The first `size_of::<B>()` bytes of `a` must be a valid value of `B`.
/// - The returned `B` must be the sole owner of any resources `a` owned, so
///   they are released exactly once.
///
/// Alignment is not required to match: the read is unaligned.
#[inline]
pub unsafe fn transmute_unchecked<A, B>(a: A) -> B {
    debug_assert!(
        mem::size_of::<B>() <= mem::size_of::<A>(),
        "transmute_unchecked: destination type is larger than source type"
    );
    // Suppress `a`'s destructor up front; its contents now belong to the result.
    let a = mem::ManuallyDrop::new(a);
    // SAFETY: the caller guarantees that `B` fits inside `A` and that those
    // bytes form a valid `B`. `a` is only guaranteed to be aligned for `A`,
    // so an unaligned read is used rather than `ptr::read`.
    unsafe { ptr::read_unaligned(&*a as *const A as *const B) }
}
impl<T, const M: usize, const N: usize> Matrix<M, N, MaybeUninit<T>> {
    /// Returns a matrix in which every element is an uninitialized
    /// `MaybeUninit<T>`.
    #[inline]
    pub(crate) fn uninit() -> Self {
        // SAFETY: `Matrix` consists solely of its `data` array (see the struct
        // literal in `from_column_major_order`). Here that array holds only
        // `MaybeUninit<T>`, for which uninitialized memory is a valid value.
        unsafe { MaybeUninit::<Self>::uninit().assume_init() }
    }

    /// Converts a fully written `Matrix<M, N, MaybeUninit<T>>` into
    /// `Matrix<M, N, T>`.
    ///
    /// # Safety
    ///
    /// Every one of the `M * N` elements must have been initialized.
    ///
    /// NOTE(review): this also assumes `Matrix` has the same layout for `T` and
    /// `MaybeUninit<T>`. A `#[repr(transparent)]` on `Matrix` would guarantee
    /// this; confirm.
    #[inline]
    pub(crate) unsafe fn assume_init(self) -> Matrix<M, N, T> {
        // SAFETY: forwarded to the caller's contract above.
        unsafe { transmute_unchecked::<Self, Matrix<M, N, T>>(self) }
    }
}
/// Fills a `Matrix<M, N, T>` with the first `M * N` items of `iter`.
///
/// Items are written to linear positions `0..M * N` via `get_unchecked_mut`.
/// NOTE(review): this is presumably column-major, matching
/// `from_column_major_order`; confirm. Any items beyond `M * N` are left
/// unconsumed in the iterator.
///
/// # Errors
///
/// Returns `Err(len)` if the iterator runs out early. `len` is the number of
/// items it produced, and those items are dropped before returning.
///
/// If `iter.next()` panics, the items already written are dropped during
/// unwinding.
pub fn collect<I, T, const M: usize, const N: usize>(mut iter: I) -> Result<Matrix<M, N, T>, usize>
where
    I: Iterator<Item = T>,
{
    /// Drops the first `init` written elements of `matrix` if filling is
    /// abandoned through an early return or a panic.
    struct Guard<'a, T, const M: usize, const N: usize> {
        matrix: &'a mut Matrix<M, N, MaybeUninit<T>>,
        init: usize,
    }

    impl<T, const M: usize, const N: usize> Drop for Guard<'_, T, M, N> {
        fn drop(&mut self) {
            let initialized = &mut self.matrix.as_mut_slice()[..self.init];
            // SAFETY: exactly the first `init` slots have been written by the
            // fill loop. `MaybeUninit<T>` has the same layout as `T`, so the
            // prefix can be viewed as `[T]`. Dropping the whole slice at once
            // keeps dropping the remaining elements even if one destructor
            // panics, so nothing leaks.
            unsafe { ptr::drop_in_place(initialized as *mut [MaybeUninit<T>] as *mut [T]) };
        }
    }

    let mut matrix: Matrix<M, N, MaybeUninit<T>> = Matrix::uninit();
    let mut guard = Guard {
        matrix: &mut matrix,
        init: 0,
    };
    for _ in 0..(M * N) {
        match iter.next() {
            Some(item) => {
                // SAFETY: this loop runs `M * N` times, so `guard.init < M * N`
                // and the index is in bounds.
                unsafe { guard.matrix.get_unchecked_mut(guard.init).write(item) };
                guard.init += 1;
            }
            // Returning drops `guard`, which releases the written prefix.
            None => return Err(guard.init),
        }
    }
    // All `M * N` slots are written: disarm the guard so ownership passes to the result.
    mem::forget(guard);
    // SAFETY: every element was initialized by the loop above.
    Ok(unsafe { matrix.assume_init() })
}
impl<T, const M: usize, const N: usize> FromIterator<T> for Matrix<M, N, T> {
    /// Builds a matrix from the first `M * N` items of `iter`, using the same
    /// element order as `collect`. Surplus items are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `iter` yields fewer than `M * N` items.
    #[inline]
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = T>,
    {
        match collect(iter.into_iter()) {
            Ok(matrix) => matrix,
            Err(len) => collect_panic::<M, N>(len),
        }
    }
}
/// Panics with a message describing an iterator that was too short to fill an
/// `M`×`N` matrix.
///
/// `len` is the number of items the iterator actually produced. Column-shaped
/// targets (`N == 1`, checked first, so 1×1 counts as a column) and row-shaped
/// targets (`M == 1`) are reported under their alias names.
#[cold]
fn collect_panic<const M: usize, const N: usize>(len: usize) -> ! {
    if N == 1 {
        // NOTE(review): the `Vector`/`RowVector` aliases are not visible here;
        // confirm their parameter order matches these messages.
        panic!("collect iterator of length {} into `Vector<_, {}>`", len, M);
    } else if M == 1 {
        panic!(
            "collect iterator of length {} into `RowVector<_, {}>`",
            len, N
        );
    } else {
        // `Matrix` takes its dimensions first and its element type last
        // (`Matrix<M, N, T>`), so report the type in that order.
        panic!(
            "collect iterator of length {} into `Matrix<{}, {}, _>`",
            len, M, N
        );
    }
}
#[cfg(test)]
mod new_test {
    use approx::assert_relative_eq;

    /// `diag!` must agree with the equivalent dense `matrix!` literal for
    /// sizes 2 through 6.
    #[test]
    fn diag() {
        // 2 x 2
        let actual = diag!(0.1, 0.2);
        let expected = matrix![
            0.1, 0.0;
            0.0, 0.2;
        ];
        assert_relative_eq!(actual, expected, max_relative = 1e-6);

        // 3 x 3
        let actual = diag!(0.1, 0.2, 0.3);
        let expected = matrix![
            0.1, 0.0, 0.0;
            0.0, 0.2, 0.0;
            0.0, 0.0, 0.3;
        ];
        assert_relative_eq!(actual, expected, max_relative = 1e-6);

        // 4 x 4
        let actual = diag!(0.1, 0.2, 0.3, 0.4);
        let expected = matrix![
            0.1, 0.0, 0.0, 0.0;
            0.0, 0.2, 0.0, 0.0;
            0.0, 0.0, 0.3, 0.0;
            0.0, 0.0, 0.0, 0.4;
        ];
        assert_relative_eq!(actual, expected, max_relative = 1e-6);

        // 5 x 5
        let actual = diag!(0.1, 0.2, 0.3, 0.4, 0.5);
        let expected = matrix![
            0.1, 0.0, 0.0, 0.0, 0.0;
            0.0, 0.2, 0.0, 0.0, 0.0;
            0.0, 0.0, 0.3, 0.0, 0.0;
            0.0, 0.0, 0.0, 0.4, 0.0;
            0.0, 0.0, 0.0, 0.0, 0.5;
        ];
        assert_relative_eq!(actual, expected, max_relative = 1e-6);

        // 6 x 6
        let actual = diag!(0.1, 0.2, 0.3, 0.4, 0.5, 0.6);
        let expected = matrix![
            0.1, 0.0, 0.0, 0.0, 0.0, 0.0;
            0.0, 0.2, 0.0, 0.0, 0.0, 0.0;
            0.0, 0.0, 0.3, 0.0, 0.0, 0.0;
            0.0, 0.0, 0.0, 0.4, 0.0, 0.0;
            0.0, 0.0, 0.0, 0.0, 0.5, 0.0;
            0.0, 0.0, 0.0, 0.0, 0.0, 0.6;
        ];
        assert_relative_eq!(actual, expected, max_relative = 1e-6);
    }
}