1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
pub mod assign;
pub mod differential;
pub mod operations;
pub mod size;
pub mod tex_code;
pub mod variable;

pub use assign::*;
pub use differential::*;
use serde::{Deserialize, Serialize};
pub use size::*;
pub use tex_code::*;
pub use variable::*;

use crate::Expression;
use std::collections::{HashMap, HashSet};

/// Tensor-valued node of the expression tree.
///
/// Values of this type are carried by the `Expression::Tensor` variant; see
/// `Expression::tensor`, `Expression::into_tensor` and the
/// `From<TensorExpression>` impl in this module for moving between the two.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum TensorExpression {
    /// A list of index pairs.
    ///
    /// NOTE(review): presumably each `[usize; 2]` names the two ranks tied
    /// together by one Kronecker delta — confirm against the code that
    /// constructs this variant.
    KroneckerDeltas(Vec<[usize; 2]>),
    /// A product of `terms` accompanied by per-rank string labels.
    DotProduct {
        /// The expressions taking part in the product.
        terms: Vec<Expression>,
        /// Maps from a `usize` key to a `String` label.
        ///
        /// NOTE(review): presumably one map per term, keyed by rank index,
        /// where equal labels mark ranks that are summed over — TODO confirm.
        rank_combinations: Vec<HashMap<usize, String>>,
    },
    /// A direct product of the listed expressions.
    DirectProduct(Vec<Expression>),
}

impl Expression {
    /// Consumes the expression and returns its tensor payload, if any.
    ///
    /// Returns `Some` with the unboxed [`TensorExpression`] when `self` is
    /// `Expression::Tensor`, and `None` for every other variant.
    pub fn tensor(self) -> Option<TensorExpression> {
        if let Expression::Tensor(boxed) = self {
            Some(*boxed)
        } else {
            None
        }
    }

    /// Consumes the expression and returns its tensor payload.
    ///
    /// This is the panicking counterpart of [`Expression::tensor`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not `Expression::Tensor`.
    pub fn into_tensor(self) -> TensorExpression {
        if let Expression::Tensor(boxed) = self {
            return *boxed;
        }

        panic!("The expression is not a tensor expression.")
    }
}

impl From<TensorExpression> for Expression {
    fn from(t: TensorExpression) -> Self {
        Expression::Tensor(t.into())
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::Expression;

    /// Smoke test: a `SparseTensor` can be converted into an `Expression`.
    ///
    /// The tensor is built from a size vector of eight 6s and four entries
    /// whose keys are eight-element index vectors. The test only checks that
    /// construction and conversion do not panic; see the TODO below.
    #[test]
    fn it_works() {
        let mut hash = HashMap::new();
        hash.insert(vec![3usize; 8], 2.0);
        hash.insert(vec![1usize; 8], 3.0);
        hash.insert(vec![4usize; 8], 4.0);
        hash.insert(vec![5usize; 8], 2.0);
        let c = SparseTensor::from(vec![6usize; 8], hash).unwrap();

        // Bound to `_ec` because nothing is asserted on the result yet.
        let _ec = Expression::from(c);
        // TODO: Assert on the converted value. `into_tensor` is for internal
        // use, so do not call it here; instead, first extract the
        // ConstantValue from the expression and then convert it back to a
        // SparseTensor for comparison.
    }
}