mech_matrix/
solve.rs

1use crate::*;
2use mech_core::*;
3#[cfg(feature = "matrix")]
4use mech_core::matrix::Matrix;
5use nalgebra::ComplexField;
6
7// Solve  ------------------------------------------------------------------
8
/// Defines a binary "solve" kernel struct `$struct_name<T>` together with its
/// `MechFunctionFactory`, `MechFunctionImpl` and (behind the `compiler`
/// feature) `MechFunctionCompiler` implementations.
///
/// * `$struct_name`  – name of the generated kernel struct.
/// * `$arg1_type`    – type held behind the `lhs` reference.
/// * `$arg2_type`    – type held behind the `rhs` reference.
/// * `$out_type`     – type held behind the `out` reference.
/// * `$op`           – name of a macro invoked as `$op!(lhs_ptr, rhs_ptr, out_ptr)`
///                     that computes the result and writes it through `out_ptr`.
/// * `$feature_flag` – feature flag forwarded to `compile_binop!`.
#[macro_export]
macro_rules! impl_binop_solve {
  ($struct_name:ident, $arg1_type:ty, $arg2_type:ty, $out_type:ty, $op:ident, $feature_flag:expr) => {
    // Kernel state: the two operands and the output cell the result is written to.
    #[derive(Debug)]
    pub struct $struct_name<T> {
      pub lhs: Ref<$arg1_type>,
      pub rhs: Ref<$arg2_type>,
      pub out: Ref<$out_type>,
    }
    impl<T> MechFunctionFactory for $struct_name<T>
    where
      #[cfg(feature = "compiler")]      T: Copy + Debug + Display + Clone + Sync + Send + 'static + PartialEq + PartialOrd + ComplexField + AsValueKind + Add<Output = T> + AddAssign + Sub<Output = T> + SubAssign + Mul<Output = T> + MulAssign + Div<Output = T> + DivAssign + Zero + One + ConstElem + CompileConst,
      #[cfg(not(feature = "compiler"))] T: Copy + Debug + Display + Clone + Sync + Send + 'static + PartialEq + PartialOrd + ComplexField + AsValueKind + Add<Output = T> + AddAssign + Sub<Output = T> + SubAssign + Mul<Output = T> + MulAssign + Div<Output = T> + DivAssign + Zero + One,
      Ref<$out_type>: ToValue,
    {
      /// Builds the kernel from `FunctionArgs::Binary(out, lhs, rhs)`.
      ///
      /// # Errors
      /// Any other argument shape yields `IncorrectNumberOfArguments`.
      fn new(args: FunctionArgs) -> MResult<Box<dyn MechFunction>> {
        match args {
          FunctionArgs::Binary(out, arg1, arg2) => {
            // SAFETY (assumed): `as_unchecked` performs no type check here, so the
            // caller must pass values whose concrete types are `$arg1_type`,
            // `$arg2_type` and `$out_type`. NOTE(review): confirm every call site
            // validates this before reaching the factory.
            let lhs: Ref<$arg1_type> = unsafe { arg1.as_unchecked() }.clone();
            let rhs: Ref<$arg2_type> = unsafe { arg2.as_unchecked() }.clone();
            let out: Ref<$out_type> = unsafe { out.as_unchecked() }.clone();
            Ok(Box::new(Self { lhs, rhs, out }))
          },
          _ => Err(MechError2::new(
              IncorrectNumberOfArguments { expected: 2, found: args.len() },
              None
            ).with_compiler_loc()
          ),
        }
      }
    }
    impl<T> MechFunctionImpl for $struct_name<T>
    where
      T: Copy + Debug + Display + Clone + Sync + Send + 'static +
      PartialEq + PartialOrd + ComplexField +
      Add<Output = T> + AddAssign +
      Sub<Output = T> + SubAssign +
      Mul<Output = T> + MulAssign +
      Div<Output = T> + DivAssign +
      Zero + One,
      Ref<$out_type>: ToValue
    {
      /// Runs `$op!` on raw pointers to the operands, writing into `out`.
      fn solve(&self) {
        let lhs_ptr = self.lhs.as_ptr();
        let rhs_ptr = self.rhs.as_ptr();
        let out_ptr = self.out.as_mut_ptr();
        $op!(lhs_ptr, rhs_ptr, out_ptr);
      }
      /// Returns the current contents of the output cell as a `Value`.
      fn out(&self) -> Value { self.out.to_value() }
      fn to_string(&self) -> String { format!("{:#?}", self) }
    }
    #[cfg(feature = "compiler")]
    impl<T> MechFunctionCompiler for $struct_name<T>
    where
      T: ConstElem + CompileConst + AsValueKind
    {
      /// Emits the kernel via `compile_binop!` under the name `StructName<kind>`.
      fn compile(&self, ctx: &mut CompileCtx) -> MResult<Register> {
        let name = format!("{}<{}>", stringify!($struct_name), T::as_value_kind());
        compile_binop!(name, self.out, self.lhs, self.rhs, ctx, $feature_flag);
      }
    }
  };
}
72
/// Solves `a * x = b` by LU decomposition and stores `x` in `*out`.
///
/// `$a` and `$b` are raw pointers to the coefficient matrix and right-hand
/// side; `$out` is a raw mutable pointer to the output. The matrix is cloned
/// because `lu()` consumes its receiver.
///
/// # Panics
/// nalgebra's `LU::solve` returns `None` when the matrix is not invertible.
/// The kernel's `solve` method returns `()`, so that failure cannot be surfaced
/// as an error here; it panics with a descriptive message instead of a bare
/// `unwrap` failure.
macro_rules! solve_op {
  ($a:expr, $b:expr, $out:expr) => {
    // SAFETY (assumed): in this file the macro is invoked from the kernel's
    // `solve` with pointers obtained from its `Ref` fields (`as_ptr` /
    // `as_mut_ptr`). NOTE(review): confirm no other borrow of these cells is
    // live while the kernel runs.
    unsafe {
      *$out = (*$a)
        .clone()
        .lu()
        .solve(&*$b)
        .expect("matrix/solve: coefficient matrix is singular (LU solve failed)");
    }
  };
}
77
/// Instantiates a solve kernel named `$kernel` with operand types `$lhs_ty`
/// and `$rhs_ty` and result type `$result_ty`. The kernel uses `solve_op!`
/// and the `Solve` builtin feature flag. The macro then invokes
/// `register_fxn_descriptor!` for the kernel with the `f64` element type.
///
/// NOTE(review): only `f64` is registered here, although the dispatch in this
/// file also builds `f32` kernels. Confirm whether `f32` needs registering too.
macro_rules! impl_solve {
  ($kernel:ident, $lhs_ty:ty, $rhs_ty:ty, $result_ty:ty) => {
    impl_binop_solve!(
      $kernel,
      $lhs_ty,
      $rhs_ty,
      $result_ty,
      solve_op,
      FeatureFlag::Builtin(FeatureKind::Solve)
    );
    register_fxn_descriptor!($kernel, f64, "f64");
  };
}

// Dense matrix (lhs) with a dense vector (rhs) produces a dense vector result.
#[cfg(all(feature = "matrixd", feature = "vectord"))]
impl_solve!(MatrixSolveMDVD, DMatrix<T>, DVector<T>, DVector<T>);
87
/// Expands to a `match` over a `(lhs, rhs)` value pair that builds a
/// `MatrixSolveMDVD` kernel for each listed element kind.
///
/// Each entry is `ValueVariant, element_type, "feature"`. The behavior depends
/// on the operand pair:
///
/// * With the feature and both `matrixd` and `vectord` enabled, a
///   dense-matrix / dense-vector pair of that variant is validated: the rhs
///   must have one column, the lhs must be square, and the row counts must
///   agree. A valid pair becomes a kernel with a zero-initialised output
///   vector.
/// * Any other pair of that variant yields a `DimensionMismatch` error.
/// * Every remaining pair yields `UnhandledFunctionArgumentKind2`.
macro_rules! impl_solve_match_arms {
  ($arg:expr, $($($matrix_kind:tt, $target_type:tt, $value_string:tt),+);+ $(;)?) => {
    match $arg {
      $(
        $(
          #[cfg(all(feature = $value_string, feature = "matrixd", feature = "vectord"))]
          (Value::$matrix_kind(Matrix::DMatrix(lhs)), Value::$matrix_kind(Matrix::DVector(rhs))) => {
            let (a_rows, a_cols) = lhs.borrow().shape();
            let (b_rows, b_cols) = rhs.borrow().shape();
            if b_cols != 1 {
              return Err(MechError2::new(
                DimensionMismatch { dims: vec![a_rows, a_cols, b_rows, b_cols] },
                Some("Right-hand side must be a vector (1 column)".to_string())
              ).with_compiler_loc());
            }
            // The kernel solves via LU decomposition (`solve_op!`), which needs a
            // square coefficient matrix. Reject other shapes here with an error
            // instead of failing when the kernel runs.
            if a_rows != a_cols {
              return Err(MechError2::new(
                DimensionMismatch { dims: vec![a_rows, a_cols, b_rows, b_cols] },
                Some("Matrix must be square to solve a linear system".to_string())
              ).with_compiler_loc());
            }
            if a_rows != b_rows {
              return Err(MechError2::new(
                DimensionMismatch { dims: vec![a_rows, a_cols, b_rows, b_cols] },
                Some("Matrix rows must match vector rows".to_string())
              ).with_compiler_loc());
            }
            // The solution has one entry per column; this equals `a_rows` because the matrix is square.
            Ok(Box::new(MatrixSolveMDVD { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new(DVector::from_element(a_rows, $target_type::zero())) }))
          },
          #[cfg(feature = $value_string)]
          (Value::$matrix_kind(lhs), Value::$matrix_kind(rhs)) => {
            let lhs_shape = lhs.shape();
            let rhs_shape = rhs.shape();
            return Err(MechError2::new(
              DimensionMismatch { dims: vec![lhs_shape[0], lhs_shape[1], rhs_shape[0], rhs_shape[1]] },
              Some("Matrix solve is only implemented for `matrixd` and `vectord` types".to_string())
            ).with_compiler_loc());
          }
        )+
      )+
      // `$fxn` is not a metavariable of this macro, so name the function explicitly.
      (arg1,arg2) => Err(MechError2::new(
        UnhandledFunctionArgumentKind2 { arg: (arg1.kind(),arg2.kind()), fxn_name: "MatrixSolve".to_string() },
        Some("Unsupported types for matrix solve".to_string())
      ).with_compiler_loc()),
    }
  }
}
129
130fn impl_solve_fxn(lhs_value: Value, rhs_value: Value) -> MResult<Box<dyn MechFunction>> {
131  impl_solve_match_arms!(
132    (lhs_value, rhs_value),
133    MatrixF32,  f32,  "f32";
134    MatrixF64,  f64,  "f64";
135    //R64, MatrixR64, R64, "rational";
136    //C64, MatrixC64, C64, "complex";
137  )
138}
139
140impl_mech_binop_fxn!(MatrixSolve, impl_solve_fxn, "matrix/solve");