use crate::{
iterators::{
DenseTensorIterator, DenseTensorLinearIterator, FiberData, IteratableTensor,
SparseTensorIterator, SparseTensorLinearIterator, TensorStructureIndexIterator,
},
structure::{
OrderedStructure, PermutedStructure,
concrete_index::{ExpandedIndex, FlatIndex},
representation::{Euclidean, RepName},
},
tensors::data::{DenseTensor, SetTensorData, SparseTensor},
};
use std::collections::HashSet;
/// Exercises the dense iterators on a 2x3 tensor whose flat storage holds `0..6`.
///
/// Covers the expanded iterator, the linear iterator (values must come out in storage
/// order), the `IntoIterator` impl for `&DenseTensor`, and a fiber built from a `&[bool]`
/// filter via `.into()`.
#[test]
fn test_dense_tensor_iterators() {
    let rep = Euclidean {};
    let structure: OrderedStructure<Euclidean> =
        OrderedStructure::new(vec![rep.new_slot(2, 0), rep.new_slot(3, 0)]).structure;
    let mut tensor = DenseTensor::<i32, _>::zero(structure);
    for i in 0..6 {
        tensor.data[i] = i as i32;
    }

    let expanded_iter = DenseTensorIterator::new(&tensor);
    let expanded_results: Vec<(ExpandedIndex, &i32)> = expanded_iter.collect();
    assert_eq!(expanded_results.len(), 6);

    let linear_iter = DenseTensorLinearIterator::new(&tensor);
    let linear_results: Vec<(FlatIndex, &i32)> = linear_iter.collect();
    assert_eq!(linear_results.len(), 6);
    for (i, &(_, &val)) in linear_results.iter().enumerate() {
        assert_eq!(val, i as i32);
    }

    let into_iter_results: Vec<(ExpandedIndex, &i32)> = (&tensor).into_iter().collect();
    assert_eq!(into_iter_results.len(), 6);

    // Build the filter through the `&[bool]` conversion (rather than naming a variant) so
    // that conversion is exercised too.
    // NOTE(review): assumes it yields the same boolean filter as `FiberData::BoolFilter` — confirm.
    let fiber_data = [true, false].as_slice().into();
    let fiber = tensor.fiber(fiber_data);
    let fiber_iter = fiber.iter();
    let fiber_results: Vec<(&i32, ())> = fiber_iter.collect();
    assert_eq!(fiber_results.len(), 2);
    // Only the first axis is free, so the fiber visits [0, 0] and [1, 0]. Since each
    // stored value equals its flat position, those are the values 0 and 3.
    assert_eq!(*fiber_results[0].0, 0);
    assert_eq!(*fiber_results[1].0, 3);
}
/// Exercises the sparse iterators on a 2x3 tensor with three stored entries.
///
/// Both the linear and the expanded iterator must yield exactly the stored entries. The
/// expanded iterator is also checked to report each stored index once, with its value.
#[test]
fn test_sparse_tensor_iterators() {
    let rep = Euclidean {};
    let structure: OrderedStructure<Euclidean> =
        OrderedStructure::new(vec![rep.new_slot(2, 0), rep.new_slot(3, 0)]).structure;
    let mut tensor = SparseTensor::<i32, _>::empty(structure, 0);
    tensor.set(&[0, 0], 1).unwrap();
    tensor.set(&[1, 1], 2).unwrap();
    tensor.set(&[1, 2], 3).unwrap();

    let linear_iter = SparseTensorLinearIterator::new(&tensor);
    let linear_results: Vec<(FlatIndex, &i32)> = linear_iter.collect();
    assert_eq!(linear_results.len(), 3);

    let expanded_iter = SparseTensorIterator::new(&tensor);
    let expanded_results: Vec<(ExpandedIndex, &i32)> = expanded_iter.collect();
    assert_eq!(expanded_results.len(), 3);

    // Record which stored entry each yielded index corresponds to. The length check alone
    // would accept an iterator that repeats a single entry three times.
    let mut seen = [false; 3];
    for &(ref indices, &val) in &expanded_results {
        let slot = match indices.indices.as_slice() {
            [0, 0] => {
                assert_eq!(val, 1);
                0
            }
            [1, 1] => {
                assert_eq!(val, 2);
                1
            }
            [1, 2] => {
                assert_eq!(val, 3);
                2
            }
            _ => panic!("Unexpected index"),
        };
        assert!(!seen[slot], "stored entry {slot} was yielded more than once");
        seen[slot] = true;
    }
    assert!(seen.iter().all(|&s| s), "not every stored entry was yielded");
}
/// A single fiber along the first axis of a 2x3 tensor whose storage holds `0..6`.
#[test]
fn test_fiber_iterators() {
    let euclid = Euclidean {};
    let permuted: PermutedStructure<OrderedStructure<Euclidean>> =
        OrderedStructure::new(vec![euclid.new_slot(2, 0), euclid.new_slot(3, 0)]);
    let mut tensor = DenseTensor::<i32, _>::zero(permuted.structure.clone());
    (0..6).for_each(|flat| tensor.data[flat] = flat as i32);

    // Only the first axis is free, so the fiber visits [0, 0] and then [1, 0].
    let fiber = tensor.fiber(FiberData::BoolFilter(&[true, false]));
    let collected: Vec<(&i32, ())> = fiber.iter().collect();
    assert_eq!(collected.len(), 2);

    let values: Vec<i32> = collected.iter().map(|&(value, ())| *value).collect();
    assert_eq!(values, [0, 3]);
}
/// A fiber class on a 2x3 tensor with the second axis free splits the tensor into two
/// fibers of length 3.
///
/// Both fibers are checked. Together they must cover every stored value exactly once.
#[test]
fn test_fiber_class_iterators() {
    let rep = Euclidean {};
    let structure: PermutedStructure<OrderedStructure<Euclidean>> =
        OrderedStructure::new(vec![rep.new_slot(2, 0), rep.new_slot(3, 0)]);
    let mut tensor = DenseTensor::<i32, _>::zero((structure.structure).clone());
    for i in 0..6 {
        tensor.data[i] = i as i32;
    }

    let fiber_class = tensor.fiber_class(FiberData::BoolFilter(&[false, true]));
    let mut class_iter = fiber_class.iter();
    // One fiber per position of the fixed (first, size-2) axis.
    let fiber1 = class_iter
        .next()
        .expect("fiber class should yield a first fiber");
    let fiber2 = class_iter
        .next()
        .expect("fiber class should yield a second fiber");
    assert!(
        class_iter.next().is_none(),
        "fiber class should yield exactly two fibers"
    );

    let results1: Vec<(&i32, ())> = fiber1.collect();
    assert_eq!(results1.len(), 3);
    let results2: Vec<(&i32, ())> = fiber2.collect();
    assert_eq!(results2.len(), 3);

    // The fibers of a class partition the tensor, so their union is every value once.
    let mut all_values: Vec<i32> = results1
        .iter()
        .chain(&results2)
        .map(|&(value, ())| *value)
        .collect();
    all_values.sort_unstable();
    assert_eq!(all_values, [0, 1, 2, 3, 4, 5]);
}
/// Index iteration over a 2x3 structure must produce six distinct, in-bounds indices.
#[test]
fn test_tensor_structure_index_iterator() {
    let euclid = Euclidean {};
    let structure: OrderedStructure<Euclidean> =
        OrderedStructure::new(vec![euclid.new_slot(2, 0), euclid.new_slot(3, 0)]).structure;

    let all_indices: Vec<ExpandedIndex> = TensorStructureIndexIterator::new(&structure).collect();
    assert_eq!(all_indices.len(), 6);

    // No index may be produced twice.
    let distinct = all_indices.iter().collect::<HashSet<_>>().len();
    assert_eq!(distinct, 6);

    // Every coordinate stays below its axis dimension (2 and 3 respectively).
    for index in &all_indices {
        assert!(index[0] < 2);
        assert!(index[1] < 3);
    }
}