1
2use std::hash::Hash;
3use std::iter::Sum;
4use sequence_trie::SequenceTrie;
5
/// One sparse entry: the key path (one id per level) together with its stored value.
type Row<T, U> = (Vec<U>, T);
/// A list of sparse entries, each a `(key path, value)` pair.
type Rows<T, U> = Vec<(Vec<U>, T)>;
8
9
/// Cursor used by the `Iterator` implementation of `CompressedSparseFiber`.
///
/// `next_index` is the row index the iterator will try to expand next; it starts at 0
/// (the derived `Default`).
#[derive(Debug, Clone, Default)]
struct IteratorState {
    next_index: usize,
}
22
/// Compressed Sparse Fiber (CSF) storage for sparse entries keyed by index paths.
///
/// Each row's key has one id per level (`fids.len()` ids in total). Level `l` of every key
/// is stored in `fids[l]`. Sibling nodes that share a parent are stored contiguously, and
/// `fptr[l]` delimits those sibling groups within `fids[l + 1]`.
#[derive(Debug, Clone)]
pub struct CompressedSparseFiber<T, U> {
    /// Pointer rows, one per non-leaf level: the children of node `k` in `fids[l]`
    /// occupy positions `fptr[l][k]..fptr[l][k + 1]` of `fids[l + 1]`.
    pub fptr: Vec<Vec<usize>>,
    /// Node ids for each level, outermost level first.
    pub fids: Vec<Vec<U>>,
    /// Stored values, aligned position-by-position with the last row of `fids`.
    pub vals: Vec<T>,
    /// Cursor used by the `Iterator` implementation.
    _state: IteratorState,
}
30
31impl<'a, T: 'a, U> CompressedSparseFiber<T, U>
32 where U: Clone {
33 pub fn new(fptr: Vec<Vec<usize>>,
34 fids: Vec<Vec<U>>,
35 vals: Vec<T>) -> CompressedSparseFiber<T, U> {
36 CompressedSparseFiber { fptr, fids, vals, _state: IteratorState { next_index: 0 } }
37 }
38
39 pub fn expand_row(self: &CompressedSparseFiber<T, U>, index: usize) -> Row<T, U>
40 where T: Copy,
41 U: Copy + Default {
42 let depth = self.fids.len();
43
44 let mut result: Vec<U> = vec![Default::default(); depth];
46 let last = self.fids[depth - 1][index];
47 result[depth-1] = last;
48 let mut current_index = index;
49 for level in (0..depth - 1).rev() {
50 let j = self.fptr[level].partition_point(|v| v <= ¤t_index);
51 result[level] = self.fids[level][j - 1];
52 current_index = j-1;
53 }
54 (result, self.vals[index])
55 }
56
57 fn weights(self: &CompressedSparseFiber<T, U>, col_index: usize) -> Vec<usize> {
58
59 fn combine(current_weights: Vec<usize>, fptr_row: Vec<usize>) -> Vec<usize> {
60 let tail = &fptr_row[1..];
61 fptr_row.iter()
62 .zip(tail)
63 .map(|(&i,&j)| ¤t_weights[i..j])
64 .map(|x|x.iter().sum::<usize>())
65 .collect::<Vec<_>>()
66 }
67
68 let fptr_row = &self.fptr[self.fptr.len() - 1];
69 let tail = &fptr_row[1..];
70 let initial = fptr_row.iter()
71 .zip(tail)
72 .map(|(&x, &y)| y - x)
73 .collect::<Vec<_>>();
74
75 (&self.fptr[col_index..self.fptr.len() - 1])
76 .iter()
77 .rfold(initial, |w, f| combine(w, f.to_vec()))
78 }
79
80 pub fn sum_column(self: &CompressedSparseFiber<T, U>, col_index: usize) -> U
81 where T: Copy,
82 U: Sum<U> + Copy {
83 let row = self.fids[col_index].clone();
84
85 if col_index == self.fptr.len() {
86 row.into_iter().sum::<U>()
87 } else {
88 let w = self.weights(col_index);
89 row.iter()
91 .zip(w)
92 .map(|(&x,y)| std::iter::repeat(x).take(y).sum())
93 .sum::<U>()
94 }
95 }
96}
97
98impl<T, U> From<&SequenceTrie<U, T>> for CompressedSparseFiber<T, U>
99 where T: Copy,
100 U: Clone + Eq + Hash + Ord + Copy {
101 fn from(trie: &SequenceTrie<U, T>) -> Self {
102 let mut i = vec![trie];
103 let mut fids: Vec<Vec<U>> = vec![];
104 let mut fptr = vec![];
105 let mut vals: Vec<T> = vec![];
106 let mut initial = true;
107
108 while !i.is_empty() {
109 let mut offset = 0;
110 let mut fptr_row = vec![0];
111 let (keys, children): (Vec<&U>, Vec<_>) = i.into_iter()
112 .flat_map(|y| {
113 let mut x = y.children_with_keys();
114 offset += x.len();
115 fptr_row.push(offset);
116 x.sort_by(|(a, _), (b, _)| a.cmp(b));
117 x
118 })
119 .unzip();
120 if !keys.is_empty() {
121 let row = keys.into_iter().map(|&f| f).collect();
122 fids.push(row);
123 if !initial {
124 fptr.push(fptr_row);
125 } else {
126 initial = false;
127 }
128 }
129
130 let mut values: Vec<T> = children.iter()
131 .filter_map(|x| x.value())
132 .map(|&x| x)
133 .collect::<Vec<_>>();
134 vals.append(&mut values);
135 i = children
136 }
137 CompressedSparseFiber::new(fptr, fids, vals)
138 }
139}
140
141impl<T, U> Iterator for CompressedSparseFiber<T, U>
142 where T: Copy,
143 U: Clone + Copy + Default + Default {
144 type Item = Row<T, U>;
145
146 fn next(&mut self) -> Option<Row<T, U>> {
147 self._state.next_index += 1;
148 if self._state.next_index < self.vals.len() {
149 Some(self.expand_row(self._state.next_index - 1))
150 } else {
151 None
152 }
153 }
154}
155
156impl<T, U> std::iter::FromIterator<Row<T, U>> for CompressedSparseFiber<T, U>
157 where T: Copy,
158 U: Clone + Copy + Hash + Ord {
159 fn from_iter<I: IntoIterator<Item = Row<T, U>>>(iter: I) -> Self {
160 let mut trie: SequenceTrie<U, T> = SequenceTrie::new();
161 for (row, x) in iter {
162 trie.insert(&row, x);
163 }
164 CompressedSparseFiber::<T, U>::from(&trie)
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference entries `(key path, value)` in sorted key order; values run 1.0..=8.0.
    fn sample_rows() -> Rows<f32, i32> {
        let keys = [
            [1, 1, 1, 2],
            [1, 1, 1, 3],
            [1, 2, 1, 1],
            [1, 2, 1, 3],
            [1, 2, 2, 1],
            [2, 2, 2, 1],
            [2, 2, 2, 2],
            [2, 2, 2, 3],
        ];
        keys.iter()
            .zip((1u8..=8).map(f32::from))
            .map(|(key, value)| (key.to_vec(), value))
            .collect()
    }

    /// The CSF that `sample_rows` should produce, spelled out level by level.
    fn sample_csf() -> CompressedSparseFiber<f32, i32> {
        let pointers = vec![vec![0, 2, 3], vec![0, 1, 3, 4], vec![0, 2, 4, 5, 8]];
        let ids = vec![
            vec![1, 2],
            vec![1, 2, 2],
            vec![1, 1, 2, 2],
            vec![2, 3, 1, 3, 1, 1, 2, 3],
        ];
        let values = (1u8..=8).map(f32::from).collect();
        CompressedSparseFiber::new(pointers, ids, values)
    }

    #[test]
    fn test_build() {
        let built: CompressedSparseFiber<_, _> = sample_rows().into_iter().collect();

        let expected_fids: [&[i32]; 4] = [
            &[1, 2],
            &[1, 2, 2],
            &[1, 1, 2, 2],
            &[2, 3, 1, 3, 1, 1, 2, 3],
        ];
        for (level, ids) in expected_fids.iter().enumerate() {
            assert_eq!(built.fids[level], ids.to_vec());
        }

        let expected_fptr: [&[usize]; 3] = [&[0, 2, 3], &[0, 1, 3, 4], &[0, 2, 4, 5, 8]];
        for (level, pointers) in expected_fptr.iter().enumerate() {
            assert_eq!(built.fptr[level], pointers.to_vec());
        }

        assert_eq!(built.vals, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_expand_row() {
        let csf = sample_csf();
        let cases = vec![
            (0, vec![1, 1, 1, 2], 1.0),
            (4, vec![1, 2, 2, 1], 5.0),
            (6, vec![2, 2, 2, 2], 7.0),
        ];
        for (index, expected_key, expected_val) in cases {
            let (key, val) = csf.expand_row(index);
            assert_eq!(key, expected_key);
            assert_eq!(val, expected_val);
        }
    }

    #[test]
    fn test_iterate() {
        let built: CompressedSparseFiber<_, _> = sample_rows().into_iter().collect();
        let reference = sample_rows();

        for (key, value) in built {
            let matched = reference
                .iter()
                .find(|(candidate, _)| candidate == &key)
                .map(|(_, v)| *v)
                .unwrap();
            assert_eq!(matched, value);
        }
    }

    /// Brute-force column sum over the plain row list.
    fn expected_sum(rows: &Rows<f32, i32>, col_index: usize) -> i32 {
        rows.iter().map(|(key, _)| key[col_index]).sum()
    }

    #[test]
    fn test_sum() {
        let csf = sample_csf();
        let rows = sample_rows();
        for col in 0..4 {
            assert_eq!(expected_sum(&rows, col), csf.sum_column(col));
        }
    }
}