1use crate::*;
2use mech_core::*;
3#[cfg(feature = "matrix")]
4use mech_core::matrix::Matrix;
5use nalgebra::ComplexField;
6
/// Defines a generic binary "solve" kernel struct `$struct_name<T>` plus its runtime impls.
///
/// Generated items:
/// - `pub struct $struct_name<T>` holding `lhs: Ref<$arg1_type>`, `rhs: Ref<$arg2_type>` and
///   `out: Ref<$out_type>`;
/// - `MechFunctionFactory`, which builds the struct from `FunctionArgs::Binary(out, lhs, rhs)`
///   and reports `IncorrectNumberOfArguments` for any other argument shape;
/// - `MechFunctionImpl`, whose `solve` invokes `$op!(lhs_ptr, rhs_ptr, out_ptr)` with raw
///   pointers to the three cells, and whose `out` returns the output cell as a `Value`;
/// - with the `compiler` feature, `MechFunctionCompiler`, which emits a binop named
///   `"$struct_name<kind>"` tagged with `$feature_flag`.
#[macro_export]
macro_rules! impl_binop_solve {
  ($struct_name:ident, $arg1_type:ty, $arg2_type:ty, $out_type:ty, $op:ident, $feature_flag:expr) => {
    #[derive(Debug)]
    pub struct $struct_name<T> {
      pub lhs: Ref<$arg1_type>,
      pub rhs: Ref<$arg2_type>,
      pub out: Ref<$out_type>,
    }
    impl<T> MechFunctionFactory for $struct_name<T>
    where
      // Bounds shared by every build configuration.
      T: Copy + Debug + Display + Clone + Sync + Send + 'static +
        PartialEq + PartialOrd + ComplexField + AsValueKind +
        Add<Output = T> + AddAssign +
        Sub<Output = T> + SubAssign +
        Mul<Output = T> + MulAssign +
        Div<Output = T> + DivAssign +
        Zero + One,
      // Additional bounds required only when the `compiler` feature is enabled.
      #[cfg(feature = "compiler")] T: ConstElem + CompileConst,
      Ref<$out_type>: ToValue,
    {
      fn new(args: FunctionArgs) -> MResult<Box<dyn MechFunction>> {
        match args {
          FunctionArgs::Binary(out, arg1, arg2) => {
            // NOTE(review): `as_unchecked` presumably performs no type check; callers must pass
            // values whose payloads really are `$arg1_type` / `$arg2_type` / `$out_type`.
            let lhs: Ref<$arg1_type> = unsafe { arg1.as_unchecked() }.clone();
            let rhs: Ref<$arg2_type> = unsafe { arg2.as_unchecked() }.clone();
            let out: Ref<$out_type> = unsafe { out.as_unchecked() }.clone();
            Ok(Box::new(Self { lhs, rhs, out }))
          },
          _ => Err(MechError2::new(
              IncorrectNumberOfArguments { expected: 2, found: args.len() },
              None
            ).with_compiler_loc()
          ),
        }
      }
    }
    impl<T> MechFunctionImpl for $struct_name<T>
    where
      T: Copy + Debug + Display + Clone + Sync + Send + 'static +
        PartialEq + PartialOrd + ComplexField +
        Add<Output = T> + AddAssign +
        Sub<Output = T> + SubAssign +
        Mul<Output = T> + MulAssign +
        Div<Output = T> + DivAssign +
        Zero + One,
      Ref<$out_type>: ToValue
    {
      fn solve(&self) {
        // `$op` receives raw pointers and writes its result through `out_ptr`.
        let lhs_ptr = self.lhs.as_ptr();
        let rhs_ptr = self.rhs.as_ptr();
        let out_ptr = self.out.as_mut_ptr();
        $op!(lhs_ptr, rhs_ptr, out_ptr);
      }
      fn out(&self) -> Value { self.out.to_value() }
      fn to_string(&self) -> String { format!("{:#?}", self) }
    }
    #[cfg(feature = "compiler")]
    impl<T> MechFunctionCompiler for $struct_name<T>
    where
      T: ConstElem + CompileConst + AsValueKind
    {
      fn compile(&self, ctx: &mut CompileCtx) -> MResult<Register> {
        let name = format!("{}<{}>", stringify!($struct_name), T::as_value_kind());
        compile_binop!(name, self.out, self.lhs, self.rhs, ctx, $feature_flag);
      }
    }
  };
}
72
/// Solves `A x = b` by LU decomposition and writes `x` through `$out`.
///
/// `$a`, `$b` and `$out` are raw pointers to the coefficient matrix, the right-hand side and
/// the output cell. The matrix is cloned because `lu()` consumes its receiver.
///
/// # Panics
/// Panics if the coefficient matrix is singular (`LU::solve` returns `None`). The runtime
/// `solve` hook that expands this macro returns `()`, so there is no error channel to report it
/// through. nalgebra's LU solve also asserts that the system is square.
macro_rules! solve_op {
  ($a:expr, $b:expr, $out:expr) => {
    // SAFETY: the pointers must be valid for the duration of the call and `$out` must not alias
    // `$a` or `$b`; the caller (`impl_binop_solve!`'s `solve`) derives them from three distinct
    // `Ref` cells owned by the kernel struct.
    unsafe {
      let decomposition = (*$a).clone().lu();
      *$out = decomposition
        .solve(&*$b)
        .expect("matrix/solve: coefficient matrix is singular");
    }
  };
}
77
/// Instantiates a solve kernel struct through `impl_binop_solve!` (using `solve_op` as the
/// kernel body and the `Solve` builtin feature flag), then registers its `f64` descriptor.
macro_rules! impl_solve {
  ($kernel:ident, $coeff_ty:ty, $rhs_ty:ty, $result_ty:ty) => {
    impl_binop_solve!(
      $kernel,
      $coeff_ty,
      $rhs_ty,
      $result_ty,
      solve_op,
      FeatureFlag::Builtin(FeatureKind::Solve)
    );
    register_fxn_descriptor!($kernel, f64, "f64");
  };
}
84
// Dynamic matrix `A` paired with a dynamic vector `b`; the kernel writes the solution of
// `A x = b` into a dynamic vector.
#[cfg(all(feature = "matrixd", feature = "vectord"))]
impl_solve! {
  MatrixSolveMDVD, DMatrix<T>, DVector<T>, DVector<T>
}
87
/// Expands to a `match` over a `(lhs, rhs)` pair of `Value`s that builds a `MatrixSolveMDVD`
/// kernel for each listed element kind.
///
/// Each `;`-separated group is `ValueVariant, element_type, "feature"`. For every group:
/// - a dynamic matrix paired with a dynamic vector of the same kind is validated (the matrix
///   must be square and its row count must equal the vector length) and turned into a solve
///   kernel with a zero-initialised output vector;
/// - any other matrix pair of that kind is rejected with `DimensionMismatch`.
///
/// Pairs of any other kinds yield `UnhandledFunctionArgumentKind2`.
macro_rules! impl_solve_match_arms {
  ($arg:expr, $($($matrix_kind:tt, $target_type:tt, $value_string:tt),+);+ $(;)?) => {
    match $arg {
      $(
        $(
          #[cfg(all(feature = $value_string, feature = "matrixd", feature = "vectord"))]
          (Value::$matrix_kind(Matrix::DMatrix(lhs)), Value::$matrix_kind(Matrix::DVector(rhs))) => {
            let (a_rows, a_cols) = lhs.borrow().shape();
            // A `DVector` always has exactly one column, so only its row count needs checking.
            let (b_rows, b_cols) = rhs.borrow().shape();
            // LU solve only handles square systems; reject others here rather than letting the
            // kernel panic later at evaluation time.
            if a_rows != a_cols {
              return Err(MechError2::new(
                DimensionMismatch { dims: vec![a_rows, a_cols, b_rows, b_cols] },
                Some("Coefficient matrix must be square".to_string())
              ).with_compiler_loc());
            }
            if a_rows != b_rows {
              return Err(MechError2::new(
                DimensionMismatch { dims: vec![a_rows, a_cols, b_rows, b_cols] },
                Some("Matrix rows must match vector rows".to_string())
              ).with_compiler_loc());
            }
            // The solution has one entry per unknown, i.e. per column of the matrix.
            let out = Ref::new(DVector::from_element(a_cols, $target_type::zero()));
            Ok(Box::new(MatrixSolveMDVD { lhs: lhs.clone(), rhs: rhs.clone(), out }))
          },
          #[cfg(feature = $value_string)]
          (Value::$matrix_kind(lhs), Value::$matrix_kind(rhs)) => {
            let lhs_shape = lhs.shape();
            let rhs_shape = rhs.shape();
            return Err(MechError2::new(
              DimensionMismatch { dims: vec![lhs_shape[0], lhs_shape[1], rhs_shape[0], rhs_shape[1]] },
              Some("Matrix solve is only implemented for `matrixd` and `vectord` types".to_string())
            ).with_compiler_loc());
          }
        )+
      )+
      (arg1, arg2) => Err(MechError2::new(
        UnhandledFunctionArgumentKind2 { arg: (arg1.kind(), arg2.kind()), fxn_name: "MatrixSolve".to_string() },
        Some("Unsupported types for matrix solve".to_string())
      ).with_compiler_loc()),
    }
  }
}
129
130fn impl_solve_fxn(lhs_value: Value, rhs_value: Value) -> MResult<Box<dyn MechFunction>> {
131 impl_solve_match_arms!(
132 (lhs_value, rhs_value),
133 MatrixF32, f32, "f32";
134 MatrixF64, f64, "f64";
135 )
138}
139
140impl_mech_binop_fxn!(MatrixSolve, impl_solve_fxn, "matrix/solve");