opensrdk_linear_algebra/tensor/
mod.rs1pub mod matrix;
2pub mod sparse;
3
4use crate::Number;
5use rand::prelude::*;
6use std::{collections::HashMap, error::Error, fmt::Debug};
7
/// Index selecting one of a tensor's ranks; used as the `rank` argument of
/// [`Tensor::size`].
pub type RankIndex = usize;

/// Identifier produced by [`generate_rank_combination_id`] to tie together
/// ranks that are combined with each other.
pub type RankCombinationId = String;
10
11pub fn generate_rank_combination_id() -> RankCombinationId {
12 thread_rng().gen::<u32>().to_string()
13}
14
15pub fn generate_rank_combinations(
16 rank_pairs: &[[RankIndex; 2]],
17) -> [HashMap<RankIndex, String>; 2] {
18 let mut rank_combinations = [HashMap::new(), HashMap::new()];
19 for rank_pair in rank_pairs.iter() {
20 let id = generate_rank_combination_id();
21 rank_combinations[0].insert(rank_pair[0], id.to_string());
22 rank_combinations[1].insert(rank_pair[1], id.to_string());
23 }
24
25 rank_combinations
26}
27
/// Enumerates every index tuple `[i0, i1, ...]` with `0 <= ik < sizes[k]`.
///
/// Each returned `Vec` has length `sizes.len()`. Tuples come in lexicographic
/// order, with the last index varying fastest; e.g. `[2, 2]` yields
/// `[0, 0], [0, 1], [1, 0], [1, 1]`.
///
/// If any size is `0` the product is empty. An empty `sizes` slice also yields
/// an empty `Vec` (no tuples), matching the historical behaviour of this
/// function.
pub fn indices_cartesian_product(sizes: &[usize]) -> Vec<Vec<usize>> {
    if sizes.is_empty() {
        return Vec::new();
    }

    // Seed with a single empty prefix so every axis, including the first, goes
    // through the same extension step. A zero-sized axis then empties the
    // product, and it stays empty. Previously the following axis was mistaken
    // for the first one and restarted the product.
    let seed: Vec<Vec<usize>> = vec![Vec::new()];
    sizes.iter().fold(seed, |prefixes, &size| {
        prefixes
            .iter()
            .flat_map(|prefix| (0..size).map(move |i| [&prefix[..], &[i]].concat()))
            .collect()
    })
}
47
48pub trait Tensor<T>: Clone + Debug + PartialEq + Send + Sync
49where
50 T: Number,
51{
52 fn rank(&self) -> usize;
53 fn size(&self, rank: RankIndex) -> usize;
54 fn elem(&self, indices: &[usize]) -> T;
55 fn elem_mut(&mut self, indices: &[usize]) -> &mut T;
56}
57
58#[derive(thiserror::Error, Debug)]
59pub enum TensorError {
60 #[error("Dimension mismatch.")]
61 RankMismatch,
62 #[error("Out of range.")]
63 OutOfRange,
64 #[error("Others")]
65 Others(Box<dyn Error + Send + Sync>),
66}
67
#[cfg(test)]
mod tests {
    #[test]
    fn generate_rank_combinations() {
        use super::generate_rank_combinations as combine;

        let a = combine(&[[0, 0], [1, 1]]);
        println!("a:{:?}", a);
        let b = combine(&[[0, 0], [1, 1], [2, 2]]);
        println!("b:{:?}", b);

        // Every pair is [r, r], so rank r must carry the same id on both sides.
        for &(maps, rank_count) in [(&a, 2usize), (&b, 3)].iter() {
            for rank in 0..rank_count {
                assert_eq!(maps[0].get(&rank).unwrap(), maps[1].get(&rank).unwrap());
            }
        }
    }

    #[test]
    fn indices_cartesian_product() {
        use super::indices_cartesian_product as product;

        let a = product(&[2, 2]);
        let b = product(&[2, 3, 4]);
        assert_eq!(a, vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]]);

        // Expected order: lexicographic, with the last index varying fastest.
        let mut expected_b: Vec<Vec<usize>> = Vec::new();
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    expected_b.push(vec![i, j, k]);
                }
            }
        }
        assert_eq!(b, expected_b);
    }
}