fst_no_std/raw/
ops.rs

1#[cfg(feature = "alloc")]
2use crate::raw::Output;
3#[cfg(feature = "alloc")]
4use crate::stream::{IntoStreamer, Streamer};
5#[cfg(feature = "alloc")]
6use alloc::{boxed::Box, collections::BinaryHeap};
7#[cfg(feature = "alloc")]
8use alloc::{vec, vec::Vec};
9#[cfg(feature = "alloc")]
10use core::cmp;
11#[cfg(feature = "alloc")]
12use core::iter::FromIterator;
13
/// A type-erased stream of `(key, output)` pairs.
///
/// Boxing permits a single set operation to combine streams whose concrete
/// types differ (i.e., the operands may be heterogeneous). The `'f` lifetime
/// bounds whatever data the boxed stream borrows.
#[cfg(feature = "alloc")]
type BoxedStream<'f> =
    Box<dyn for<'a> Streamer<'a, Item = (&'a [u8], Output)> + 'f>;
18
/// A value tagged with the stream that produced it.
///
/// Set operations report, for every emitted key, one `IndexedValue` per
/// participating stream in which that key was found. `index` identifies the
/// stream by the position at which it was added to the operation (the first
/// stream added is `0`), and `value` is the output associated with the key in
/// that stream.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct IndexedValue {
    /// Position of the producing stream within the operation, from `0`.
    pub index: usize,
    /// The output value stored for the key in that stream.
    pub value: u64,
}
33
/// A builder for collecting fst streams on which to perform set operations
/// on the keys of fsts.
///
/// Set operations include intersection, union, difference and symmetric
/// difference. The result of each set operation is itself a stream that emits
/// pairs of keys and a sequence of each occurrence of that key in the
/// participating streams. This information allows one to perform set
/// operations on fsts and customize how conflicting output values are handled.
///
/// All set operations work on an arbitrary number of streams, using memory
/// proportional to the number of streams (plus the length of the key
/// currently buffered from each stream).
///
/// Streams are merged through a binary heap that holds at most one entry per
/// stream, so the running time of every set operation is
/// `O((n1 + n2 + n3 + ...) * log k)`, where `n1, n2, n3, ...` correspond to
/// the number of elements in each stream and `k` is the number of streams.
/// Every key read from a stream is additionally copied into an internal
/// buffer.
///
/// The `'f` lifetime parameter refers to the lifetime of the underlying set.
#[cfg(feature = "alloc")]
pub struct OpBuilder<'f> {
    streams: Vec<BoxedStream<'f>>,
}
55
#[cfg(feature = "alloc")]
impl<'f> Default for OpBuilder<'f> {
    /// Equivalent to [`OpBuilder::new`]: a builder with no streams yet.
    #[inline]
    fn default() -> OpBuilder<'f> {
        OpBuilder::new()
    }
}
62
#[cfg(feature = "alloc")]
impl<'f> OpBuilder<'f> {
    /// Create a new set operation builder.
    #[inline]
    #[must_use]
    pub fn new() -> OpBuilder<'f> {
        OpBuilder { streams: vec![] }
    }

    /// Add a stream to this set operation.
    ///
    /// This is useful for a chaining style pattern, e.g.,
    /// `builder.add(stream1).add(stream2).union()`.
    ///
    /// The stream must emit a lexicographically ordered sequence of key-value
    /// pairs.
    ///
    /// This consumes the builder and returns it, so discarding the result
    /// discards every stream added so far.
    #[must_use]
    pub fn add<I, S>(mut self, stream: I) -> OpBuilder<'f>
    where
        I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
        S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
    {
        self.push(stream);
        self
    }

    /// Add a stream to this set operation.
    ///
    /// The stream must emit a lexicographically ordered sequence of key-value
    /// pairs. Streams are numbered in the order they are added, starting at
    /// `0`; that number is the `index` reported in each [`IndexedValue`].
    pub fn push<I, S>(&mut self, stream: I)
    where
        I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
        S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
    {
        self.streams.push(Box::new(stream.into_stream()));
    }

    /// Performs a union operation on all streams that have been added.
    ///
    /// Note that this returns a stream of `(&[u8], &[IndexedValue])`. The
    /// first element of the tuple is the byte string key. The second element
    /// of the tuple is a list of all occurrences of that key in participating
    /// streams. The `IndexedValue` contains an index and the value associated
    /// with that key in that stream. The index uniquely identifies each
    /// stream, which is an integer that is auto-incremented when a stream
    /// is added to this operation (starting at `0`).
    #[inline]
    #[must_use]
    pub fn union(self) -> Union<'f> {
        Union {
            heap: StreamHeap::new(self.streams),
            outs: vec![],
            cur_slot: None,
        }
    }

    /// Performs an intersection operation on all streams that have been added.
    ///
    /// Note that this returns a stream of `(&[u8], &[IndexedValue])`. The
    /// first element of the tuple is the byte string key. The second element
    /// of the tuple is a list of all occurrences of that key in participating
    /// streams. The `IndexedValue` contains an index and the value associated
    /// with that key in that stream. The index uniquely identifies each
    /// stream, which is an integer that is auto-incremented when a stream
    /// is added to this operation (starting at `0`).
    #[inline]
    #[must_use]
    pub fn intersection(self) -> Intersection<'f> {
        Intersection {
            heap: StreamHeap::new(self.streams),
            outs: vec![],
            cur_slot: None,
        }
    }

    /// Performs a difference operation with respect to the first stream added.
    /// That is, this returns a stream of all elements in the first stream
    /// that don't exist in any other stream that has been added.
    ///
    /// Note that this returns a stream of `(&[u8], &[IndexedValue])`. The
    /// first element of the tuple is the byte string key. The second element
    /// of the tuple is a list of all occurrences of that key in participating
    /// streams. The `IndexedValue` contains an index and the value associated
    /// with that key in that stream. The index uniquely identifies each
    /// stream, which is an integer that is auto-incremented when a stream
    /// is added to this operation (starting at `0`).
    ///
    /// The interface is the same for all the operations, but due to the nature
    /// of `difference`, each yielded key contains exactly one `IndexedValue`
    /// with `index` set to 0.
    ///
    /// If no streams have been added, the resulting stream is empty, just as
    /// it is for the other set operations.
    #[inline]
    #[must_use]
    pub fn difference(mut self) -> Difference<'f> {
        // A stream that never yields anything. It stands in for the missing
        // first stream when the builder is empty, instead of panicking inside
        // `swap_remove(0)`.
        struct EmptyStream;

        impl<'a> Streamer<'a> for EmptyStream {
            type Item = (&'a [u8], Output);

            fn next(&'a mut self) -> Option<Self::Item> {
                None
            }
        }

        let first: BoxedStream<'f> = if self.streams.is_empty() {
            Box::new(EmptyStream)
        } else {
            // `swap_remove` is O(1). It moves the last stream into position 0,
            // which is harmless here: `Difference` only reports values from
            // the first stream, always with index 0.
            self.streams.swap_remove(0)
        };
        Difference {
            set: first,
            key: vec![],
            heap: StreamHeap::new(self.streams),
            outs: vec![],
        }
    }

    /// Performs a symmetric difference operation on all of the streams that
    /// have been added.
    ///
    /// When there are only two streams, then the keys returned correspond to
    /// keys that are in either stream but *not* in both streams.
    ///
    /// More generally, for any number of streams, keys that occur in an odd
    /// number of streams are returned.
    ///
    /// Note that this returns a stream of `(&[u8], &[IndexedValue])`. The
    /// first element of the tuple is the byte string key. The second element
    /// of the tuple is a list of all occurrences of that key in participating
    /// streams. The `IndexedValue` contains an index and the value associated
    /// with that key in that stream. The index uniquely identifies each
    /// stream, which is an integer that is auto-incremented when a stream
    /// is added to this operation (starting at `0`).
    #[inline]
    #[must_use]
    pub fn symmetric_difference(self) -> SymmetricDifference<'f> {
        SymmetricDifference {
            heap: StreamHeap::new(self.streams),
            outs: vec![],
            cur_slot: None,
        }
    }
}
191
#[cfg(feature = "alloc")]
impl<'f, I, S> Extend<I> for OpBuilder<'f>
where
    I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
    S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
{
    /// Appends every stream yielded by `it` to this operation, in iteration
    /// order, as if each had been passed to [`OpBuilder::push`].
    fn extend<T>(&mut self, it: T)
    where
        T: IntoIterator<Item = I>,
    {
        it.into_iter().for_each(|stream| self.push(stream));
    }
}
207
#[cfg(feature = "alloc")]
impl<'f, I, S> FromIterator<I> for OpBuilder<'f>
where
    I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
    S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
{
    /// Builds an operation over every stream yielded by `it`. Streams are
    /// numbered in iteration order, starting at `0`.
    fn from_iter<T>(it: T) -> OpBuilder<'f>
    where
        T: IntoIterator<Item = I>,
    {
        let mut builder = OpBuilder::new();
        for stream in it {
            builder.push(stream);
        }
        builder
    }
}
223
/// A stream of the set union of several fst streams, in lexicographic key
/// order.
///
/// Each key found in any input stream is yielded together with an
/// [`IndexedValue`] for every stream currently positioned on that key.
///
/// The `'f` lifetime parameter refers to the lifetime of the underlying fst.
#[cfg(feature = "alloc")]
pub struct Union<'f> {
    heap: StreamHeap<'f>,
    outs: Vec<IndexedValue>,
    cur_slot: Option<Slot>,
}

#[cfg(feature = "alloc")]
impl<'a, 'f> Streamer<'a> for Union<'f> {
    type Item = (&'a [u8], &'a [IndexedValue]);

    fn next(&'a mut self) -> Option<(&'a [u8], &'a [IndexedValue])> {
        // The slot behind the previously returned key was kept so the caller
        // could borrow that key; only now can its stream be advanced.
        if let Some(previous) = self.cur_slot.take() {
            self.heap.refill(previous);
        }
        let smallest = self.heap.pop()?;
        self.outs.clear();
        self.outs.push(smallest.indexed_value());
        // Collect every other stream currently positioned on the same key.
        while let Some(same) = self.heap.pop_if_equal(smallest.input()) {
            self.outs.push(same.indexed_value());
            self.heap.refill(same);
        }
        let current = self.cur_slot.insert(smallest);
        Some((current.input(), &self.outs))
    }
}
258
/// A stream of the set intersection of several fst streams, in lexicographic
/// key order.
///
/// A key is yielded only when every participating stream contains it; the
/// accompanying slice then holds one [`IndexedValue`] per stream.
///
/// The `'f` lifetime parameter refers to the lifetime of the underlying fst.
#[cfg(feature = "alloc")]
pub struct Intersection<'f> {
    heap: StreamHeap<'f>,
    outs: Vec<IndexedValue>,
    cur_slot: Option<Slot>,
}

#[cfg(feature = "alloc")]
impl<'a, 'f> Streamer<'a> for Intersection<'f> {
    type Item = (&'a [u8], &'a [IndexedValue]);

    fn next(&'a mut self) -> Option<(&'a [u8], &'a [IndexedValue])> {
        // Advance the stream whose key was handed out on the previous call.
        if let Some(previous) = self.cur_slot.take() {
            self.heap.refill(previous);
        }
        let stream_count = self.heap.num_slots();
        while let Some(smallest) = self.heap.pop() {
            self.outs.clear();
            self.outs.push(smallest.indexed_value());
            while let Some(same) = self.heap.pop_if_equal(smallest.input()) {
                self.outs.push(same.indexed_value());
                self.heap.refill(same);
            }
            // `outs` has one entry per stream holding this key, so the key is
            // in the intersection only when no stream is missing.
            if self.outs.len() >= stream_count {
                let current = self.cur_slot.insert(smallest);
                return Some((current.input(), &self.outs));
            }
            self.heap.refill(smallest);
        }
        None
    }
}
301
/// A stream of the set difference of several fst streams, in lexicographic
/// key order.
///
/// The difference is taken between the first stream and all of the others:
/// a key is yielded only if it appears in the first stream and in none of the
/// remaining streams.
///
/// The `'f` lifetime parameter refers to the lifetime of the underlying fst.
#[cfg(feature = "alloc")]
pub struct Difference<'f> {
    set: BoxedStream<'f>,
    key: Vec<u8>,
    heap: StreamHeap<'f>,
    outs: Vec<IndexedValue>,
}

#[cfg(feature = "alloc")]
impl<'a, 'f> Streamer<'a> for Difference<'f> {
    type Item = (&'a [u8], &'a [IndexedValue]);

    fn next(&'a mut self) -> Option<(&'a [u8], &'a [IndexedValue])> {
        while let Some((key, out)) = self.set.next() {
            // Copy the candidate out of the first stream so that it stays
            // valid after the stream is advanced on a later call.
            self.key.clear();
            self.key.extend_from_slice(key);
            self.outs.clear();
            self.outs.push(IndexedValue { index: 0, value: out.value() });

            // Advance the other streams up to and including the candidate,
            // noting whether any of them contains it.
            let mut found_elsewhere = false;
            while let Some(slot) = self.heap.pop_if_le(&self.key) {
                found_elsewhere |= slot.input() == self.key.as_slice();
                self.heap.refill(slot);
            }
            if !found_elsewhere {
                return Some((self.key.as_slice(), self.outs.as_slice()));
            }
        }
        None
    }
}
347
/// A stream of the set symmetric difference of several fst streams, in
/// lexicographic key order.
///
/// A key is yielded when it occurs in an odd number of the participating
/// streams.
///
/// The `'f` lifetime parameter refers to the lifetime of the underlying fst.
#[cfg(feature = "alloc")]
pub struct SymmetricDifference<'f> {
    heap: StreamHeap<'f>,
    outs: Vec<IndexedValue>,
    cur_slot: Option<Slot>,
}

#[cfg(feature = "alloc")]
impl<'a, 'f> Streamer<'a> for SymmetricDifference<'f> {
    type Item = (&'a [u8], &'a [IndexedValue]);

    fn next(&'a mut self) -> Option<(&'a [u8], &'a [IndexedValue])> {
        // Advance the stream whose key was handed out on the previous call.
        if let Some(previous) = self.cur_slot.take() {
            self.heap.refill(previous);
        }
        while let Some(smallest) = self.heap.pop() {
            self.outs.clear();
            self.outs.push(smallest.indexed_value());
            while let Some(same) = self.heap.pop_if_equal(smallest.input()) {
                self.outs.push(same.indexed_value());
                self.heap.refill(same);
            }
            // `outs` has one entry per stream holding this key. Keys present
            // in an odd number of streams belong to the symmetric difference.
            if self.outs.len() % 2 == 1 {
                let current = self.cur_slot.insert(smallest);
                return Some((current.input(), &self.outs));
            }
            self.heap.refill(smallest);
        }
        None
    }
}
392
/// A min-heap of [`Slot`]s, keyed on each input stream's current key.
///
/// `rdrs[i]` is the stream that feeds the slot whose `idx` is `i`. A slot is
/// present in `heap` only while its stream still has items: when `refill`
/// finds the stream exhausted, the slot is dropped.
#[cfg(feature = "alloc")]
struct StreamHeap<'f> {
    rdrs: Vec<BoxedStream<'f>>,
    heap: BinaryHeap<Slot>,
}

#[cfg(feature = "alloc")]
impl<'f> StreamHeap<'f> {
    /// Takes ownership of `streams` and primes the heap with the first item
    /// of each one.
    fn new(streams: Vec<BoxedStream<'f>>) -> StreamHeap<'f> {
        let stream_count = streams.len();
        let mut merged = StreamHeap { rdrs: streams, heap: BinaryHeap::new() };
        (0..stream_count).for_each(|idx| merged.refill(Slot::new(idx)));
        merged
    }

    /// Removes and returns the slot holding the smallest key, if any.
    fn pop(&mut self) -> Option<Slot> {
        self.heap.pop()
    }

    /// Returns `true` when the top of the heap is positioned exactly on `key`.
    fn peek_is_duplicate(&self, key: &[u8]) -> bool {
        matches!(self.heap.peek(), Some(top) if top.input() == key)
    }

    /// Pops the top slot only if it is positioned exactly on `key`.
    fn pop_if_equal(&mut self, key: &[u8]) -> Option<Slot> {
        if !self.peek_is_duplicate(key) {
            return None;
        }
        self.pop()
    }

    /// Pops the top slot only if its key is less than or equal to `key`.
    fn pop_if_le(&mut self, key: &[u8]) -> Option<Slot> {
        let top_is_le =
            matches!(self.heap.peek(), Some(top) if top.input() <= key);
        if top_is_le {
            self.pop()
        } else {
            None
        }
    }

    /// Number of input streams, including ones that are already exhausted.
    fn num_slots(&self) -> usize {
        self.rdrs.len()
    }

    /// Loads the next item of `slot`'s stream into `slot` and pushes it back
    /// onto the heap; if that stream is exhausted, `slot` is simply dropped.
    fn refill(&mut self, mut slot: Slot) {
        let Some((input, output)) = self.rdrs[slot.idx].next() else {
            return;
        };
        slot.set_input(input);
        slot.set_output(output);
        self.heap.push(slot);
    }
}
445
/// The current head of one input stream: the stream's position in
/// `StreamHeap::rdrs`, plus an owned copy of its most recent key and output.
#[cfg(feature = "alloc")]
#[derive(PartialEq, Eq, Debug)]
struct Slot {
    idx: usize,
    input: Vec<u8>,
    output: Output,
}

#[cfg(feature = "alloc")]
impl Slot {
    /// Creates an empty slot for the stream at `rdr_idx`. The key buffer is
    /// pre-sized to 64 bytes, so keys up to that length are copied in without
    /// reallocating.
    fn new(rdr_idx: usize) -> Slot {
        let input = Vec::with_capacity(64);
        let output = Output::zero();
        Slot { idx: rdr_idx, input, output }
    }

    /// Pairs this slot's current output with the index of its stream.
    fn indexed_value(&self) -> IndexedValue {
        let value = self.output.value();
        IndexedValue { index: self.idx, value }
    }

    /// The key this slot is currently positioned on.
    fn input(&self) -> &[u8] {
        self.input.as_slice()
    }

    /// Replaces the buffered key with a copy of `input`, reusing the buffer.
    fn set_input(&mut self, input: &[u8]) {
        self.input.clear();
        self.input.extend_from_slice(input);
    }

    /// Replaces the buffered output.
    fn set_output(&mut self, output: Output) {
        self.output = output;
    }
}
481
#[cfg(feature = "alloc")]
// `Ord::cmp` for `Slot` is implemented by calling this method, rather than
// this method wrapping `Ord::cmp`; that arrangement is what the lint flags.
#[allow(clippy::non_canonical_partial_ord_impl)]
impl PartialOrd for Slot {
    /// Orders slots by `(input, output)` in *descending* order, so that the
    /// max-heap `BinaryHeap<Slot>` pops the smallest key first. The stream
    /// index `idx` does not take part in the comparison.
    fn partial_cmp(&self, other: &Slot) -> Option<cmp::Ordering> {
        let mine = (&self.input, self.output);
        let theirs = (&other.input, other.output);
        theirs.partial_cmp(&mine)
    }
}
491
#[cfg(feature = "alloc")]
impl Ord for Slot {
    /// Delegates to `PartialOrd::partial_cmp`. The comparison key is a byte
    /// vector plus an `Output`; this assumes `Output`'s ordering (defined
    /// elsewhere) is total, so the `unwrap` never fires.
    fn cmp(&self, other: &Slot) -> cmp::Ordering {
        PartialOrd::partial_cmp(self, other).unwrap()
    }
}
498
#[cfg(test)]
mod tests {
    use crate::raw::tests::{fst_map, fst_set};
    use crate::raw::Fst;
    use crate::stream::{IntoStreamer, Streamer};

    use super::{IndexedValue, OpBuilder};

    fn s(string: &str) -> String {
        string.to_owned()
    }

    // Generates `fn $name(sets) -> Vec<String>` that builds one fst set per
    // input, runs `OpBuilder::$op` over them and collects the emitted keys.
    macro_rules! create_set_op {
        ($name:ident, $op:ident) => {
            fn $name(sets: Vec<Vec<&str>>) -> Vec<String> {
                let fsts: Vec<Fst<_>> =
                    sets.into_iter().map(fst_set).collect();
                let op: OpBuilder = fsts.iter().collect();
                let mut stream = op.$op().into_stream();
                let mut keys = vec![];
                while let Some((key, _)) = stream.next() {
                    keys.push(String::from_utf8(key.to_vec()).unwrap());
                }
                keys
            }
        };
    }

    // Like `create_set_op!`, but for fst maps; each emitted key is paired
    // with the sum of its values across all streams that contain it.
    macro_rules! create_map_op {
        ($name:ident, $op:ident) => {
            fn $name(sets: Vec<Vec<(&str, u64)>>) -> Vec<(String, u64)> {
                let fsts: Vec<Fst<_>> =
                    sets.into_iter().map(fst_map).collect();
                let op: OpBuilder = fsts.iter().collect();
                let mut stream = op.$op().into_stream();
                let mut keys = vec![];
                while let Some((key, outs)) = stream.next() {
                    let merged = outs.iter().fold(0, |a, b| a + b.value);
                    let s = String::from_utf8(key.to_vec()).unwrap();
                    keys.push((s, merged));
                }
                keys
            }
        };
    }

    create_set_op!(fst_union, union);
    create_set_op!(fst_intersection, intersection);
    create_set_op!(fst_symmetric_difference, symmetric_difference);
    create_set_op!(fst_difference, difference);
    create_map_op!(fst_union_map, union);
    create_map_op!(fst_intersection_map, intersection);
    create_map_op!(fst_symmetric_difference_map, symmetric_difference);
    create_map_op!(fst_difference_map, difference);

    #[test]
    fn union_set() {
        let v = fst_union(vec![vec!["a", "b", "c"], vec!["x", "y", "z"]]);
        assert_eq!(v, vec!["a", "b", "c", "x", "y", "z"]);
    }

    #[test]
    fn union_set_dupes() {
        let v = fst_union(vec![vec!["aa", "b", "cc"], vec!["b", "cc", "z"]]);
        assert_eq!(v, vec!["aa", "b", "cc", "z"]);
    }

    #[test]
    fn union_no_streams() {
        let v = fst_union(vec![]);
        assert_eq!(v, Vec::<String>::new());
    }

    #[test]
    fn union_map() {
        let v = fst_union_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("x", 1), ("y", 2), ("z", 3)],
        ]);
        assert_eq!(
            v,
            vec![
                (s("a"), 1),
                (s("b"), 2),
                (s("c"), 3),
                (s("x"), 1),
                (s("y"), 2),
                (s("z"), 3),
            ]
        );
    }

    #[test]
    fn union_map_dupes() {
        let v = fst_union_map(vec![
            vec![("aa", 1), ("b", 2), ("cc", 3)],
            vec![("b", 1), ("cc", 2), ("z", 3)],
            vec![("b", 1)],
        ]);
        assert_eq!(
            v,
            vec![(s("aa"), 1), (s("b"), 4), (s("cc"), 5), (s("z"), 3),]
        );
    }

    #[test]
    fn union_reports_stream_indices() {
        let fsts: Vec<Fst<_>> = vec![
            fst_map(vec![("a", 1), ("b", 2)]),
            fst_map(vec![("b", 3), ("c", 4)]),
        ];
        let op: OpBuilder = fsts.iter().collect();
        let mut stream = op.union().into_stream();
        let mut got = vec![];
        while let Some((key, outs)) = stream.next() {
            // The order of values for a shared key is not part of the
            // contract, so sort by stream index before comparing.
            let mut outs = outs.to_vec();
            outs.sort();
            got.push((String::from_utf8(key.to_vec()).unwrap(), outs));
        }
        assert_eq!(
            got,
            vec![
                (s("a"), vec![IndexedValue { index: 0, value: 1 }]),
                (
                    s("b"),
                    vec![
                        IndexedValue { index: 0, value: 2 },
                        IndexedValue { index: 1, value: 3 },
                    ]
                ),
                (s("c"), vec![IndexedValue { index: 1, value: 4 }]),
            ]
        );
    }

    #[test]
    fn intersect_set() {
        let v =
            fst_intersection(vec![vec!["a", "b", "c"], vec!["x", "y", "z"]]);
        assert_eq!(v, Vec::<String>::new());
    }

    #[test]
    fn intersect_set_dupes() {
        let v = fst_intersection(vec![
            vec!["aa", "b", "cc"],
            vec!["b", "cc", "z"],
        ]);
        assert_eq!(v, vec!["b", "cc"]);
    }

    #[test]
    fn intersect_with_empty_stream() {
        let v = fst_intersection(vec![vec!["a", "b"], vec![]]);
        assert_eq!(v, Vec::<String>::new());
    }

    #[test]
    fn intersect_map() {
        let v = fst_intersection_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("x", 1), ("y", 2), ("z", 3)],
        ]);
        assert_eq!(v, Vec::<(String, u64)>::new());
    }

    #[test]
    fn intersect_map_dupes() {
        let v = fst_intersection_map(vec![
            vec![("aa", 1), ("b", 2), ("cc", 3)],
            vec![("b", 1), ("cc", 2), ("z", 3)],
            vec![("b", 1)],
        ]);
        assert_eq!(v, vec![(s("b"), 4)]);
    }

    #[test]
    fn symmetric_difference() {
        let v = fst_symmetric_difference(vec![
            vec!["a", "b", "c"],
            vec!["a", "b"],
            vec!["a"],
        ]);
        assert_eq!(v, vec!["a", "c"]);
    }

    #[test]
    fn symmetric_difference_even_occurrences() {
        // "a" occurs in two streams (even), so it is excluded.
        let v =
            fst_symmetric_difference(vec![vec!["a", "b"], vec!["a", "c"]]);
        assert_eq!(v, vec!["b", "c"]);
    }

    #[test]
    fn symmetric_difference_map() {
        let v = fst_symmetric_difference_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("a", 1), ("b", 2)],
            vec![("a", 1)],
        ]);
        assert_eq!(v, vec![(s("a"), 3), (s("c"), 3)]);
    }

    #[test]
    fn difference() {
        let v = fst_difference(vec![
            vec!["a", "b", "c"],
            vec!["a", "b"],
            vec!["a"],
        ]);
        assert_eq!(v, vec!["c"]);
    }

    #[test]
    fn difference2() {
        // Regression test: https://github.com/BurntSushi/fst/issues/19
        let v = fst_difference(vec![vec!["a", "c"], vec!["b", "c"]]);
        assert_eq!(v, vec!["a"]);
        let v = fst_difference(vec![vec!["bar", "foo"], vec!["baz", "foo"]]);
        assert_eq!(v, vec!["bar"]);
    }

    #[test]
    fn difference_single_stream() {
        let v = fst_difference(vec![vec!["a", "b"]]);
        assert_eq!(v, vec!["a", "b"]);
    }

    #[test]
    fn difference_map() {
        let v = fst_difference_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("a", 1), ("b", 2)],
            vec![("a", 1)],
        ]);
        assert_eq!(v, vec![(s("c"), 3)]);
    }
}