// opensrdk_symbolic_computation/expression/tensor_expression/mod.rs

pub mod assign;
pub mod differential;
pub mod operations;
pub mod size;
pub mod tex_code;
pub mod variable;

pub use assign::*;
pub use differential::*;
use serde::{Deserialize, Serialize};
pub use size::*;
pub use tex_code::*;
pub use variable::*;

use crate::Expression;
use std::collections::{HashMap, HashSet};
17
18#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
19pub enum TensorExpression {
20    KroneckerDeltas(Vec<[usize; 2]>),
21    DotProduct {
22        terms: Vec<Expression>,
23        rank_combinations: Vec<HashMap<usize, String>>,
24    },
25    DirectProduct(Vec<Expression>),
26}
27
28impl Expression {
29    pub fn tensor(self) -> Option<TensorExpression> {
30        match self {
31            Expression::Tensor(t) => Some(*t),
32            _ => None,
33        }
34    }
35
36    pub fn into_tensor(self) -> TensorExpression {
37        match self {
38            Expression::Tensor(t) => *t,
39            _ => panic!("The expression is not a tensor expression."),
40        }
41    }
42}
43
44impl From<TensorExpression> for Expression {
45    fn from(t: TensorExpression) -> Self {
46        Expression::Tensor(t.into())
47    }
48}
49
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::Expression;

    /// Smoke test: a sparse tensor can be built and converted into an
    /// `Expression` without panicking. It makes no assertions on the result.
    #[test]
    fn it_works() {
        // Four explicit entries, each addressed by an 8-component index.
        let mut hash = HashMap::new();
        hash.insert(vec![3usize; 8], 2.0);
        hash.insert(vec![1usize; 8], 3.0);
        hash.insert(vec![4usize; 8], 4.0);
        hash.insert(vec![5usize; 8], 2.0);
        let c = SparseTensor::from(vec![6usize; 8], hash).unwrap();

        let _ec = Expression::from(c);
        // TODO: Assert on the converted value. `into_tensor` is for internal
        // use and must not be used here; instead, first extract the
        // ConstantValue from the expression, then convert it back to a
        // SparseTensor and compare it with the input.
    }
}