opensrdk_symbolic_computation/expression/matrix_expression/operations/
t.rs

1use std::collections::HashMap;
2
3use opensrdk_linear_algebra::Matrix;
4
5use crate::{BracketsLevel, ConstantValue, Expression, MatrixExpression, TensorExpression};
6
7impl Expression {
8    pub fn t(self) -> Expression {
9        if let Expression::Constant(v) = &self {
10            let t = |v: &Matrix| v.t().into();
11            return match v {
12                ConstantValue::Scalar(v) => v.abs().into(),
13                ConstantValue::Tensor(v) => t(&v.reduce_1dimension_rank().to_mat()),
14                ConstantValue::Matrix(v) => return t(v),
15            };
16        }
17
18        MatrixExpression::T(self.into()).into()
19    }
20}
21
22impl MatrixExpression {
23    pub(crate) fn diff_t(v: &Expression, symbols: &[&str]) -> Vec<Expression> {
24        let delta_01: Expression = TensorExpression::KroneckerDeltas(vec![[0, 1]]).into();
25        let tensor = delta_01
26            .clone()
27            .dot(v.clone(), &[[0, 1]])
28            .dot(delta_01, &[[0, 1]]);
29
30        tensor.differential(symbols)
31    }
32
33    pub(crate) fn tex_code_t(v: &Expression, symbols: &HashMap<&str, &str>) -> String {
34        format!(
35            r"{}^\top",
36            v._tex_code(symbols, BracketsLevel::ForOperation)
37        )
38    }
39}
40
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::{new_variable_tensor, Expression, MatrixExpression, Size};

    /// Smoke test: transpose a symbolic 2-D tensor variable, differentiate the transpose,
    /// and render both results as TeX.
    #[test]
    fn it_works() {
        let id = "x";
        let ea = new_variable_tensor(id.to_string(), vec![Size::Many, Size::Many]);

        let ea_t = ea.clone().t();

        let diff_ea_t: Vec<Expression> = MatrixExpression::diff_t(&ea, &[id]);
        println!("{:?}", diff_ea_t);
        // One derivative is expected per requested symbol.
        assert_eq!(diff_ea_t.len(), 1);

        let tex_symbols: HashMap<&str, &str> = HashMap::from([("x", "y")]);
        let tex_ea_t = ea_t.tex_code(&tex_symbols);
        let tex_diff_ea_t = diff_ea_t[0].tex_code(&tex_symbols);
        println!("{:?}", tex_ea_t);
        println!("{:?}", tex_diff_ea_t);

        // A non-constant operand stays symbolic, so its TeX must carry the transpose marker.
        assert!(tex_ea_t.contains(r"^\top"));
    }
}