opensrdk_symbolic_computation/expression/tensor_expression/
mod.rs1pub mod assign;
2pub mod differential;
3pub mod operations;
4pub mod size;
5pub mod tex_code;
6pub mod variable;
7
8pub use assign::*;
9pub use differential::*;
10use serde::{Deserialize, Serialize};
11pub use size::*;
12pub use tex_code::*;
13pub use variable::*;
14
15use crate::Expression;
16use std::collections::{HashMap, HashSet};
17
18#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
19pub enum TensorExpression {
20 KroneckerDeltas(Vec<[usize; 2]>),
21 DotProduct {
22 terms: Vec<Expression>,
23 rank_combinations: Vec<HashMap<usize, String>>,
24 },
25 DirectProduct(Vec<Expression>),
26}
27
28impl Expression {
29 pub fn tensor(self) -> Option<TensorExpression> {
30 match self {
31 Expression::Tensor(t) => Some(*t),
32 _ => None,
33 }
34 }
35
36 pub fn into_tensor(self) -> TensorExpression {
37 match self {
38 Expression::Tensor(t) => *t,
39 _ => panic!("The expression is not a tensor expression."),
40 }
41 }
42}
43
44impl From<TensorExpression> for Expression {
45 fn from(t: TensorExpression) -> Self {
46 Expression::Tensor(t.into())
47 }
48}
49
#[cfg(test)]
mod tests {
    // NOTE(review): `Add`, `Matrix` and `MatrixExpression` are not used by the
    // test visible here; they are kept in case other tests in this module
    // rely on them.
    use std::{collections::HashMap, ops::Add};

    use opensrdk_linear_algebra::{sparse::SparseTensor, Matrix};

    use crate::{Expression, MatrixExpression};

    /// Smoke test: a `SparseTensor` with four explicit entries can be built
    /// and converted into an `Expression` without panicking.
    #[test]
    fn it_works() {
        // Four entries, each keyed by an 8-element index vector.
        let mut hash = HashMap::new();
        hash.insert(vec![3usize; 8], 2.0);
        hash.insert(vec![1usize; 8], 3.0);
        hash.insert(vec![4usize; 8], 4.0);
        hash.insert(vec![5usize; 8], 2.0);
        // NOTE(review): the first argument is presumably the size of each of
        // the 8 ranks — confirm against `SparseTensor::from`.
        let c = SparseTensor::from(vec![6usize; 8], hash).unwrap();

        // Only the conversion itself is exercised; the result is not inspected.
        let _ec = Expression::from(c);
    }
}