opensrdk_symbolic_computation/expression/tensor_expression/
variable.rs1use crate::{Expression, Size, TensorExpression};
2use std::collections::HashSet;
3
4pub fn new_variable_tensor(id: String, sizes: Vec<Size>) -> Expression {
5 Expression::Variable(id, sizes)
6}
7
8impl TensorExpression {
9 pub fn variable_ids(&self) -> HashSet<&str> {
10 match self {
11 TensorExpression::KroneckerDeltas(_) => HashSet::new(),
12 TensorExpression::DotProduct {
13 terms,
14 rank_combinations: _,
15 } => terms.iter().map(|t| t.variable_ids()).flatten().collect(),
16 TensorExpression::DirectProduct(terms) => {
17 terms.iter().map(|t| t.variable_ids()).flatten().collect()
18 }
19 }
20 }
21}
22
#[cfg(test)]
mod tests {
    use crate::{new_variable_tensor, Expression, Size};

    /// `new_variable_tensor` must produce an `Expression::Variable` that keeps
    /// the given identifier and one entry per supplied size.
    #[test]
    fn it_works() {
        let id = "x";
        let ea = new_variable_tensor(id.to_string(), vec![Size::Many, Size::Many, Size::Many]);
        println!("{:?}", ea);

        match &ea {
            Expression::Variable(var_id, sizes) => {
                assert_eq!(var_id.as_str(), id);
                assert_eq!(sizes.len(), 3);
            }
            other => panic!("expected Expression::Variable, got {:?}", other),
        }
    }
}