1use crate::*;
2use mech_core::*;
3#[cfg(feature = "matrix")]
4use mech_core::matrix::Matrix;
5
/// Scalar multiplication kernel; it is the operation handed to `impl_binop!`
/// for `DotScalar`.
///
/// Accepts three pointer-like expressions and stores `lhs * rhs` through the
/// third one.
macro_rules! mul_op {
    ($a:expr, $b:expr, $dst:expr) => {
        // SAFETY: all three operands are dereferenced here. The invoking macro
        // is presumably responsible for handing in valid, live pointers --
        // TODO confirm against `impl_binop!`.
        unsafe {
            *$dst = (*$a) * (*$b);
        }
    };
}
12
/// Dot-product kernel; it is the operation handed to `impl_binop!` by
/// `impl_dot!`.
///
/// Accepts three pointer-like expressions, calls `.dot(..)` on the first
/// operand with a reference to the second, and stores the result through the
/// third.
macro_rules! dot_op {
    ($a:expr, $b:expr, $dst:expr) => {
        // SAFETY: all three operands are dereferenced here. The invoking macro
        // is presumably responsible for handing in valid, live pointers --
        // TODO confirm against `impl_binop!`.
        unsafe {
            *$dst = (*$a).dot(&*$b);
        }
    };
}
17
/// Declares one dot-product function for a concrete pair of operand types.
///
/// Forwards the name and the operand/output types to `impl_binop!` with
/// `dot_op` as the kernel and `FeatureKind::Dot` as the feature flag, then
/// calls `register_fxn_descriptor!` for the same name over the listed element
/// types (`u8` through `F64`).
macro_rules! impl_dot {
    ($fxn:ident, $lhs_ty:ty, $rhs_ty:ty, $ret_ty:ty) => {
        impl_binop!(
            $fxn,
            $lhs_ty,
            $rhs_ty,
            $ret_ty,
            dot_op,
            FeatureFlag::Builtin(FeatureKind::Dot)
        );
        register_fxn_descriptor!(
            $fxn,
            u8, "u8",
            u16, "u16",
            u32, "u32",
            u64, "u64",
            u128, "u128",
            i8, "i8",
            i16, "i16",
            i32, "i32",
            i64, "i64",
            i128, "i128",
            F32, "f32",
            F64, "f64"
        );
    };
}
24
25impl_binop!(DotScalar, T, T, T, mul_op, FeatureFlag::Builtin(FeatureKind::Dot));
26register_fxn_descriptor!(DotScalar, u8, "u8", u16, "u16", u32, "u32", u64, "u64", u128, "u128", i8, "i8", i16, "i16", i32, "i32", i64, "i64", i128, "i128", F32, "f32", F64, "f64");
27
// Fixed-size vector kernels. Each one is compiled only when the feature for
// its shape is enabled.
#[cfg(feature = "row_vector2")]
impl_dot!(DotR2R2, RowVector2<T>, RowVector2<T>, T);
#[cfg(feature = "vector2")]
impl_dot!(DotV2V2, Vector2<T>, Vector2<T>, T);

#[cfg(feature = "row_vector3")]
impl_dot!(DotR3R3, RowVector3<T>, RowVector3<T>, T);
#[cfg(feature = "vector3")]
impl_dot!(DotV3V3, Vector3<T>, Vector3<T>, T);

#[cfg(feature = "row_vector4")]
impl_dot!(DotR4R4, RowVector4<T>, RowVector4<T>, T);
#[cfg(feature = "vector4")]
impl_dot!(DotV4V4, Vector4<T>, Vector4<T>, T);
42
// Fixed-size square matrix kernels. `DotM1M1` is the 1x1 kernel, so it is
// instantiated over `Matrix1` (it previously reused `Matrix2`, which made it a
// duplicate of `DotM2M2`).
#[cfg(feature = "matrix1")]
impl_dot!(DotM1M1, Matrix1<T>, Matrix1<T>, T);
#[cfg(feature = "matrix2")]
impl_dot!(DotM2M2, Matrix2<T>, Matrix2<T>, T);
#[cfg(feature = "matrix3")]
impl_dot!(DotM3M3, Matrix3<T>, Matrix3<T>, T);
#[cfg(feature = "matrix4")]
impl_dot!(DotM4M4, Matrix4<T>, Matrix4<T>, T);

// Dynamically sized kernels. Operand dimensions are validated at dispatch
// time in `impl_dot_match_arms!` below.
#[cfg(feature = "matrixd")]
impl_dot!(DotMDMD, DMatrix<T>, DMatrix<T>, T);
#[cfg(feature = "vectord")]
impl_dot!(DotVDVD, DVector<T>, DVector<T>, T);
#[cfg(feature = "row_vectord")]
impl_dot!(DotRDRD, RowDVector<T>, RowDVector<T>, T);

/// Expands to a `match` over an `(lhs, rhs)` pair of `Value`s that picks the
/// dot-product kernel for the operands' element kind and shape.
///
/// After the scrutinee, each `;`-separated row has the form
/// `ScalarVariant, MatrixVariant, element_type, "feature"`.
///
/// * Every arm is gated on the row's element-type feature; shaped arms are
///   additionally gated on the feature for their shape.
/// * Both operands must have the same kind and shape; there are no
///   mixed-shape arms.
/// * Dynamically sized operands are checked for equal dimensions first and
///   produce a `MechError` describing the mismatch otherwise.
/// * Any pairing not covered by an enabled arm produces a `MechError` of kind
///   `MechErrorKind::UnhandledFunctionArgumentKind` whose message is the
///   `Debug` rendering of the operand pair.
macro_rules! impl_dot_match_arms {
    ($arg:expr, $($lhs_type:tt, $($matrix_kind:tt, $target_type:tt, $value_string:tt),+);+ $(;)?) => {
        match $arg {
            $(
                $(
                    // Scalar x scalar.
                    #[cfg(feature = $value_string)]
                    (Value::$lhs_type(lhs), Value::$lhs_type(rhs)) => Ok(Box::new(DotScalar { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) })),

                    // Fixed-size vectors.
                    #[cfg(all(feature = $value_string, feature = "vector2"))]
                    (Value::$matrix_kind(Matrix::Vector2(lhs)), Value::$matrix_kind(Matrix::Vector2(rhs))) => Ok(Box::new(DotV2V2 { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) })),
                    #[cfg(all(feature = $value_string, feature = "row_vector2"))]
                    (Value::$matrix_kind(Matrix::RowVector2(lhs)), Value::$matrix_kind(Matrix::RowVector2(rhs))) => Ok(Box::new(DotR2R2 { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) })),

                    #[cfg(all(feature = $value_string, feature = "vector3"))]
                    (Value::$matrix_kind(Matrix::Vector3(lhs)), Value::$matrix_kind(Matrix::Vector3(rhs))) => Ok(Box::new(DotV3V3 { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) })),
                    #[cfg(all(feature = $value_string, feature = "row_vector3"))]
                    (Value::$matrix_kind(Matrix::RowVector3(lhs)), Value::$matrix_kind(Matrix::RowVector3(rhs))) => Ok(Box::new(DotR3R3 { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) })),

                    #[cfg(all(feature = $value_string, feature = "vector4"))]
                    (Value::$matrix_kind(Matrix::Vector4(lhs)), Value::$matrix_kind(Matrix::Vector4(rhs))) => Ok(Box::new(DotV4V4 { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) })),
                    #[cfg(all(feature = $value_string, feature = "row_vector4"))]
                    (Value::$matrix_kind(Matrix::RowVector4(lhs)), Value::$matrix_kind(Matrix::RowVector4(rhs))) => Ok(Box::new(DotR4R4 { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) })),

                    // Fixed-size square matrices. The 1x1 arm matches
                    // `Matrix::Matrix1`; matching `Matrix::Matrix2` here would
                    // shadow the 2x2 arm directly below and leave 1x1 operands
                    // unhandled.
                    #[cfg(all(feature = $value_string, feature = "matrix1"))]
                    (Value::$matrix_kind(Matrix::Matrix1(lhs)), Value::$matrix_kind(Matrix::Matrix1(rhs))) => Ok(Box::new(DotM1M1 { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) })),
                    #[cfg(all(feature = $value_string, feature = "matrix2"))]
                    (Value::$matrix_kind(Matrix::Matrix2(lhs)), Value::$matrix_kind(Matrix::Matrix2(rhs))) => Ok(Box::new(DotM2M2 { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) })),
                    #[cfg(all(feature = $value_string, feature = "matrix3"))]
                    (Value::$matrix_kind(Matrix::Matrix3(lhs)), Value::$matrix_kind(Matrix::Matrix3(rhs))) => Ok(Box::new(DotM3M3 { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) })),
                    #[cfg(all(feature = $value_string, feature = "matrix4"))]
                    (Value::$matrix_kind(Matrix::Matrix4(lhs)), Value::$matrix_kind(Matrix::Matrix4(rhs))) => Ok(Box::new(DotM4M4 { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) })),

                    // Dynamically sized matrix: both dimensions must agree.
                    #[cfg(all(feature = $value_string, feature = "matrixd"))]
                    (Value::$matrix_kind(Matrix::DMatrix(lhs)), Value::$matrix_kind(Matrix::DMatrix(rhs))) => {
                        // The borrows are scoped so they end before the
                        // operands are cloned.
                        let (lhs_rows,lhs_cols) = {lhs.borrow().shape()};
                        let (rhs_rows,rhs_cols) = {rhs.borrow().shape()};
                        if lhs_rows != rhs_rows || lhs_cols != rhs_cols {
                            return Err(MechError{file: file!().to_string(), tokens: vec![], msg: format!("Matrix dimensions must agree: lhs is {}x{}, rhs is {}x{}", lhs_rows, lhs_cols, rhs_rows, rhs_cols), id: line!(), kind: MechErrorKind::None });
                        }
                        Ok(Box::new(DotMDMD { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) }))
                    },
                    // Dynamically sized column vector: lengths must agree.
                    #[cfg(all(feature = $value_string, feature = "vectord"))]
                    (Value::$matrix_kind(Matrix::DVector(lhs)), Value::$matrix_kind(Matrix::DVector(rhs))) => {
                        let lhs_len = {lhs.borrow().len()};
                        let rhs_len = {rhs.borrow().len()};
                        if lhs_len != rhs_len {
                            return Err(MechError{file: file!().to_string(), tokens: vec![], msg: format!("Vector dimensions must agree: lhs is {}, rhs is {}", lhs_len, rhs_len), id: line!(), kind: MechErrorKind::None });
                        }
                        Ok(Box::new(DotVDVD { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) }))
                    },
                    // Dynamically sized row vector: lengths must agree.
                    #[cfg(all(feature = $value_string, feature = "row_vectord"))]
                    (Value::$matrix_kind(Matrix::RowDVector(lhs)), Value::$matrix_kind(Matrix::RowDVector(rhs))) => {
                        let lhs_len = {lhs.borrow().len()};
                        let rhs_len = {rhs.borrow().len()};
                        if lhs_len != rhs_len {
                            return Err(MechError{file: file!().to_string(), tokens: vec![], msg: format!("Vector dimensions must agree: lhs is {}, rhs is {}", lhs_len, rhs_len), id: line!(), kind: MechErrorKind::None });
                        }
                        Ok(Box::new(DotRDRD { lhs: lhs.clone(), rhs: rhs.clone(), out: Ref::new($target_type::default()) }))
                    },
                )+
            )+
            // No enabled arm matched this operand pair.
            x => Err(MechError{file: file!().to_string(), tokens: vec![], msg: format!("{:?}",x), id: line!(), kind: MechErrorKind::UnhandledFunctionArgumentKind }),
        }
    }
}
123
124fn impl_dot_fxn(lhs_value: Value, rhs_value: Value) -> Result<Box<dyn MechFunction>, MechError> {
125 impl_dot_match_arms!(
126 (lhs_value, rhs_value),
127 I8, MatrixI8, i8, "i8";
128 I16, MatrixI16, i16, "i16";
129 I32, MatrixI32, i32, "i32";
130 I64, MatrixI64, i64, "i64";
131 I128, MatrixI128, i128, "i128";
132 U8, MatrixU8, u8, "u8";
133 U16, MatrixU16, u16, "u16";
134 U32, MatrixU32, u32, "u32";
135 U64, MatrixU64, u64, "u64";
136 U128, MatrixU128, u128, "u128";
137 F32, MatrixF32, F32, "f32";
138 F64, MatrixF64, F64, "f64";
139 R64, MatrixR64, R64, "rational";
140 C64, MatrixC64, C64, "complex";
141 )
142}
143
// Invokes `impl_mech_binop_fxn!` with the identifier `MatrixDot`, the kernel
// selector `impl_dot_fxn`, and the string "matrix/dot".
// NOTE(review): the macro is defined outside this file; presumably it declares
// `MatrixDot` as the entry point that delegates to `impl_dot_fxn` under that
// name -- confirm against its definition.
impl_mech_binop_fxn!(MatrixDot,impl_dot_fxn,"matrix/dot");