opensrdk_linear_algebra/tensor/mod.rs

1pub mod matrix;
2pub mod sparse;
3
4use crate::Number;
5use rand::prelude::*;
6use std::{collections::HashMap, error::Error, fmt::Debug};
7
/// Position of a rank (axis) within a tensor, counted from `0`.
pub type RankIndex = usize;
/// Opaque identifier shared by two ranks that are paired together,
/// as produced by `generate_rank_combination_id`.
pub type RankCombinationId = String;
10
11pub fn generate_rank_combination_id() -> RankCombinationId {
12    thread_rng().gen::<u32>().to_string()
13}
14
15pub fn generate_rank_combinations(
16    rank_pairs: &[[RankIndex; 2]],
17) -> [HashMap<RankIndex, String>; 2] {
18    let mut rank_combinations = [HashMap::new(), HashMap::new()];
19    for rank_pair in rank_pairs.iter() {
20        let id = generate_rank_combination_id();
21        rank_combinations[0].insert(rank_pair[0], id.to_string());
22        rank_combinations[1].insert(rank_pair[1], id.to_string());
23    }
24
25    rank_combinations
26}
27
/// Enumerates every index tuple of a tensor whose extents are `sizes`, in row-major order.
///
/// Each returned `Vec` has `sizes.len()` entries, and the last position varies fastest.
/// For example, `indices_cartesian_product(&[2, 2])` yields
/// `[0, 0], [0, 1], [1, 0], [1, 1]`.
///
/// If any extent is `0`, the product is empty and an empty `Vec` is returned.
/// An empty `sizes` slice also yields an empty `Vec`, rather than the single empty tuple
/// `[[]]`. This preserves the function's existing behavior for callers.
pub fn indices_cartesian_product(sizes: &[usize]) -> Vec<Vec<usize>> {
    if sizes.is_empty() {
        return Vec::new();
    }

    // Seed with a single empty prefix so each step extends every prefix by one index.
    // Using an empty accumulator as a "first step" marker instead would restart the
    // product after a zero-sized dimension (e.g. `[0, 2]` would yield `[[0], [1]]`).
    sizes.iter().fold(vec![Vec::new()], |prefixes, &size| {
        prefixes
            .into_iter()
            .flat_map(|prefix| (0..size).map(move |i| [&prefix[..], &[i]].concat()))
            .collect()
    })
}
47
48pub trait Tensor<T>: Clone + Debug + PartialEq + Send + Sync
49where
50    T: Number,
51{
52    fn rank(&self) -> usize;
53    fn size(&self, rank: RankIndex) -> usize;
54    fn elem(&self, indices: &[usize]) -> T;
55    fn elem_mut(&mut self, indices: &[usize]) -> &mut T;
56}
57
58#[derive(thiserror::Error, Debug)]
59pub enum TensorError {
60    #[error("Dimension mismatch.")]
61    RankMismatch,
62    #[error("Out of range.")]
63    OutOfRange,
64    #[error("Others")]
65    Others(Box<dyn Error + Send + Sync>),
66}
67
#[cfg(test)]
mod tests {
    use super::{generate_rank_combinations, indices_cartesian_product};
    use std::collections::HashSet;

    /// Each pair's two ranks must share one id, and different pairs must get different ids.
    #[test]
    fn generate_rank_combinations_links_paired_ranks() {
        let cases: [Vec<[usize; 2]>; 3] = [
            vec![[0, 0], [1, 1]],
            vec![[0, 0], [1, 1], [2, 2]],
            vec![[0, 2], [1, 0]],
        ];
        for pairs in cases.iter() {
            let [lhs, rhs] = generate_rank_combinations(pairs);
            assert_eq!(lhs.len(), pairs.len());
            assert_eq!(rhs.len(), pairs.len());
            for [l, r] in pairs.iter() {
                assert_eq!(lhs[l], rhs[r]);
            }
            let distinct: HashSet<&String> = lhs.values().collect();
            assert_eq!(distinct.len(), pairs.len());
        }
    }

    /// Index tuples come out in row-major order (last index varies fastest).
    #[test]
    fn indices_cartesian_product_is_row_major() {
        assert_eq!(
            indices_cartesian_product(&[2, 2]),
            vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]]
        );
        assert_eq!(
            indices_cartesian_product(&[3]),
            vec![vec![0], vec![1], vec![2]]
        );

        // Build the expected product for [2, 3, 4] with explicit nested loops.
        let mut expected = Vec::new();
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    expected.push(vec![i, j, k]);
                }
            }
        }
        assert_eq!(indices_cartesian_product(&[2, 3, 4]), expected);
    }

    /// A trailing zero-sized dimension and an empty size list both produce no tuples.
    #[test]
    fn indices_cartesian_product_empty_cases() {
        assert!(indices_cartesian_product(&[2, 0]).is_empty());
        assert!(indices_cartesian_product(&[]).is_empty());
    }
}