Skip to main content

compressed_sparse_fiber/
lib.rs

1
2use std::hash::Hash;
3use std::iter::Sum;
4use sequence_trie::SequenceTrie;
5
/// One tensor entry: its key (one id per level) paired with the stored value.
type Row<T, U> = (Vec<U>, T);
/// A list of [`Row`]s.
type Rows<T, U> = Vec<Row<T, U>>;

/// Position bookkeeping used when a `CompressedSparseFiber` is iterated.
#[derive(Debug, Clone, Default)]
struct IteratorState {
    next_index: usize,
}

/// A sparse tensor stored level by level as a compressed tree of ids.
///
/// * `fids[level]` holds the ids of all nodes at that level; the last level
///   has exactly one id per entry of `vals`.
/// * `fptr[level]` holds offsets into `fids[level + 1]`: the children of node
///   `k` at `level` occupy `fptr[level][k]..fptr[level][k + 1]`. There is one
///   pointer row fewer than there are id rows.
/// * `vals[i]` is the value belonging to the `i`-th id of the last level.
#[derive(Debug, Clone)]
pub struct CompressedSparseFiber<T, U> {
    pub fptr: Vec<Vec<usize>>,
    pub fids: Vec<Vec<U>>,
    pub vals: Vec<T>,
    _state: IteratorState,
}

impl<T, U> CompressedSparseFiber<T, U> {
    /// Assembles a tensor from already-compressed parts.
    ///
    /// No validation is performed; the other methods assume the layout
    /// described on the struct and panic on inconsistent data.
    pub fn new(fptr: Vec<Vec<usize>>, fids: Vec<Vec<U>>, vals: Vec<T>) -> Self {
        CompressedSparseFiber {
            fptr,
            fids,
            vals,
            _state: IteratorState::default(),
        }
    }

    /// Reconstructs the full key and the value of the `index`-th stored entry.
    ///
    /// # Panics
    ///
    /// Panics if the tensor has no levels, if `index` is out of range for the
    /// last level or for `vals`, or if the pointer rows are inconsistent with
    /// the id rows.
    pub fn expand_row(&self, index: usize) -> Row<T, U>
    where
        T: Copy,
        U: Copy + Default,
    {
        let depth = self.fids.len();
        let mut key = vec![U::default(); depth];

        // The last level lines up one-to-one with `vals`.
        key[depth - 1] = self.fids[depth - 1][index];

        // Walk towards the root: the parent of `node` is the last pointer
        // entry that is <= `node`.
        let mut node = index;
        for level in (0..depth - 1).rev() {
            let parent = self.fptr[level].partition_point(|&start| start <= node) - 1;
            key[level] = self.fids[level][parent];
            node = parent;
        }

        (key, self.vals[index])
    }

    /// For every node at level `col_index`, counts the stored entries below it.
    ///
    /// Must only be called with `col_index < self.fptr.len()`.
    fn weights(&self, col_index: usize) -> Vec<usize> {
        let last_level = self.fptr.len() - 1;

        // Entries directly below each node of the deepest pointer level.
        let leaf_counts: Vec<usize> = self.fptr[last_level]
            .windows(2)
            .map(|bounds| bounds[1] - bounds[0])
            .collect();

        // Fold the counts upwards, borrowing each pointer row instead of
        // cloning it.
        self.fptr[col_index..last_level]
            .iter()
            .rev()
            .fold(leaf_counts, |below, ptrs| {
                ptrs.windows(2)
                    .map(|bounds| below[bounds[0]..bounds[1]].iter().sum::<usize>())
                    .collect::<Vec<usize>>()
            })
    }

    /// Sums the ids of level `col_index` over all stored entries, i.e. each id
    /// is counted once per entry that lies below it.
    ///
    /// # Panics
    ///
    /// Panics if `col_index` is not a valid level.
    pub fn sum_column(&self, col_index: usize) -> U
    where
        T: Copy,
        U: Sum<U> + Copy,
    {
        let ids = &self.fids[col_index];

        if col_index == self.fptr.len() {
            // Last level: one id per entry, nothing to weight.
            ids.iter().copied().sum::<U>()
        } else {
            // Repeat and take to avoid a Mul<usize, Output = U> constraint.
            ids.iter()
                .zip(self.weights(col_index))
                .map(|(&id, count)| std::iter::repeat(id).take(count).sum::<U>())
                .sum::<U>()
        }
    }
}
97
98impl<T, U> From<&SequenceTrie<U, T>> for CompressedSparseFiber<T, U>
99    where T: Copy,
100          U: Clone + Eq + Hash + Ord + Copy {
101   fn from(trie: &SequenceTrie<U, T>) -> Self {
102        let mut i = vec![trie];
103        let mut fids: Vec<Vec<U>> = vec![];
104        let mut fptr = vec![];
105        let mut vals: Vec<T> = vec![];
106        let mut initial = true;
107
108        while !i.is_empty() {
109            let mut offset = 0;
110            let mut fptr_row = vec![0];
111            let (keys, children): (Vec<&U>, Vec<_>) = i.into_iter()
112                .flat_map(|y| {
113                    let mut x = y.children_with_keys();
114                    offset += x.len();
115                    fptr_row.push(offset);
116                    x.sort_by(|(a, _), (b, _)| a.cmp(b));
117                    x
118                })
119                .unzip();
120            if !keys.is_empty() {
121                let row = keys.into_iter().map(|&f| f).collect();
122                fids.push(row);
123                if !initial {
124                    fptr.push(fptr_row);
125                } else {
126                    initial = false;
127                }
128            }
129
130            let mut values: Vec<T> = children.iter()
131                .filter_map(|x| x.value())
132                .map(|&x| x)
133                .collect::<Vec<_>>();
134            vals.append(&mut values);
135            i = children
136        }
137        CompressedSparseFiber::new(fptr, fids, vals)
138    }
139}
140
141impl<T, U> Iterator for CompressedSparseFiber<T, U>
142    where T: Copy,
143          U: Clone + Copy + Default + Default {
144    type Item = Row<T, U>;
145
146    fn next(&mut self) -> Option<Row<T, U>> {
147        self._state.next_index += 1;
148        if self._state.next_index < self.vals.len() {
149            Some(self.expand_row(self._state.next_index - 1))
150        } else {
151            None
152        }
153    }
154}
155
156impl<T, U> std::iter::FromIterator<Row<T, U>> for CompressedSparseFiber<T, U>
157    where T: Copy,
158          U: Clone + Copy + Hash + Ord {
159    fn from_iter<I: IntoIterator<Item = Row<T, U>>>(iter: I) -> Self {
160        let mut trie: SequenceTrie<U, T> = SequenceTrie::new();
161        for (row, x) in iter {
162            trie.insert(&row, x);
163        }
164        CompressedSparseFiber::<T, U>::from(&trie)
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight depth-four rows in ascending key order, valued 1.0 through 8.0.
    fn sample_rows() -> Rows<f32, i32> {
        let keys: [[i32; 4]; 8] = [
            [1, 1, 1, 2],
            [1, 1, 1, 3],
            [1, 2, 1, 1],
            [1, 2, 1, 3],
            [1, 2, 2, 1],
            [2, 2, 2, 1],
            [2, 2, 2, 2],
            [2, 2, 2, 3],
        ];
        keys.iter()
            .zip(1u8..)
            .map(|(key, n)| (key.to_vec(), f32::from(n)))
            .collect()
    }

    /// The hand-compressed form of `sample_rows`.
    fn sample_csf() -> CompressedSparseFiber<f32, i32> {
        let fptr = vec![vec![0, 2, 3], vec![0, 1, 3, 4], vec![0, 2, 4, 5, 8]];
        let fids = vec![
            vec![1, 2],
            vec![1, 2, 2],
            vec![1, 1, 2, 2],
            vec![2, 3, 1, 3, 1, 1, 2, 3],
        ];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        CompressedSparseFiber::new(fptr, fids, vals)
    }

    #[test]
    fn test_build() {
        let built: CompressedSparseFiber<_, _> = sample_rows().into_iter().collect();

        let expected_fids: [Vec<i32>; 4] = [
            vec![1, 2],
            vec![1, 2, 2],
            vec![1, 1, 2, 2],
            vec![2, 3, 1, 3, 1, 1, 2, 3],
        ];
        for (level, ids) in expected_fids.iter().enumerate() {
            assert_eq!(&built.fids[level], ids);
        }

        let expected_fptr: [Vec<usize>; 3] =
            [vec![0, 2, 3], vec![0, 1, 3, 4], vec![0, 2, 4, 5, 8]];
        for (level, ptrs) in expected_fptr.iter().enumerate() {
            assert_eq!(&built.fptr[level], ptrs);
        }

        assert_eq!(built.vals, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_expand_row() {
        let csf = sample_csf();
        let cases: [(usize, [i32; 4], f32); 3] = [
            (0, [1, 1, 1, 2], 1.0),
            (4, [1, 2, 2, 1], 5.0),
            (6, [2, 2, 2, 2], 7.0),
        ];

        for (index, expected_key, expected_val) in cases.iter() {
            let (key, val) = csf.expand_row(*index);
            assert_eq!(key, expected_key.to_vec());
            assert_eq!(val, *expected_val);
        }
    }

    #[test]
    fn test_iterate() {
        let reference = sample_rows();
        let csf: CompressedSparseFiber<_, _> = sample_rows().into_iter().collect();

        // Every row produced by iteration must exist in the input with the
        // same value.
        for (key_out, val_out) in csf {
            let matching = reference
                .iter()
                .find(|(key, _)| *key == key_out)
                .map(|(_, value)| *value);
            assert_eq!(matching.unwrap(), val_out);
        }
    }

    /// Reference column sum computed directly from the uncompressed rows.
    fn expected_sum(rows: &Rows<f32, i32>, col_index: usize) -> i32 {
        rows.iter().map(|(key, _)| key[col_index]).sum()
    }

    #[test]
    fn test_sum() {
        let csf = sample_csf();
        let rows = sample_rows();

        for col_index in 0..4 {
            assert_eq!(expected_sum(&rows, col_index), csf.sum_column(col_index));
        }
    }
}