fst_map/
lib.rs

1use std::collections::BTreeMap;
2use std::marker::PhantomData;
3use std::mem;
4
5use fst::{Map, MapBuilder};
6use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned};
7
/// An immutable byte-keyed map whose keys and values live in one contiguous buffer.
///
/// `V` is the value type handed back by lookups; it is never stored directly,
/// only as raw bytes inside the buffer.
#[derive(Clone)]
pub struct FstMap<V> {
    /// The serialized map bytes.
    buf: Vec<u8>,
    /// Ties the map to its value type without owning a `V`.
    _pd: PhantomData<V>,
}
13
14#[derive(
15    Debug, Clone, Copy, PartialEq, Eq, FromBytes, KnownLayout, Immutable, IntoBytes, Unaligned,
16)]
17#[repr(C)]
18struct Header {
19    value_start_offset: [u8; 8],
20}
21
22impl Header {
23    fn value_start_offset(&self) -> usize {
24        usize::try_from(u64::from_le_bytes(self.value_start_offset)).unwrap()
25    }
26}
27
28/// Build an FstMap from a sorted map of byte-viewable keys and values.
29///
30/// Serializes the map as:
31/// [Header, Fst from key to offset, array of values]
32impl<K, V> From<&BTreeMap<K, V>> for FstMap<V>
33where
34    K: AsRef<[u8]>,
35    V: Immutable + IntoBytes + KnownLayout + FromBytes + Unaligned,
36{
37    fn from(map: &BTreeMap<K, V>) -> FstMap<V> {
38        let mut buf = vec![0; mem::size_of::<Header>()];
39        let mut map_builder = MapBuilder::new(&mut buf).unwrap();
40        let mut value_array: Vec<u8> = vec![];
41
42        for (i, (k, v)) in map.iter().enumerate() {
43            let value_bytes: &[u8] = v.as_bytes();
44            value_array.extend_from_slice(value_bytes);
45            map_builder.insert(k.as_ref(), i as u64).unwrap();
46        }
47
48        map_builder.finish().unwrap();
49
50        let value_start_offset = (buf.len() as u64).to_le_bytes();
51
52        let header = Header { value_start_offset };
53        buf[0..mem::size_of::<Header>()].copy_from_slice(header.as_bytes());
54
55        // take the bytes from value_array and append them to
56        // the end of the key buffer.
57        buf.append(&mut value_array);
58
59        FstMap {
60            buf: buf,
61            _pd: PhantomData,
62        }
63    }
64}
65
66impl<V> FstMap<V>
67where
68    V: Immutable + IntoBytes + KnownLayout + FromBytes + Unaligned,
69{
70    fn header(&self) -> &Header {
71        let header_len = size_of::<Header>();
72
73        let header: &Header = Header::ref_from_bytes(&self.buf[..header_len]).unwrap();
74
75        header
76    }
77
78    fn map(&self) -> Map<&[u8]> {
79        let header_len = size_of::<Header>();
80        let header = self.header();
81        let map = Map::new(&self.buf[header_len..header.value_start_offset()]).unwrap();
82        map
83    }
84
85    fn value_array(&self) -> &[u8] {
86        let header = self.header();
87        &self.buf[header.value_start_offset()..]
88    }
89
90    pub fn get<K>(&self, key: K) -> Option<&V>
91    where
92        K: AsRef<[u8]>,
93    {
94        let map = self.map();
95
96        // early-exit if our key map doesn't contain what we're looking for
97        let value_index_u64: u64 = map.get(key.as_ref())?;
98        let value_index: usize = value_index_u64.try_into().unwrap();
99
100        let value_start = value_index * mem::size_of::<V>();
101        let value_end = (value_index + 1) * mem::size_of::<V>();
102
103        let value_bytes = &self.value_array()[value_start..value_end];
104
105        let value = V::ref_from_bytes(value_bytes).unwrap();
106
107        Some(value)
108    }
109}
110
/// Round-trips randomly generated maps through `FstMap` and checks every
/// lookup against the `BTreeMap` it was built from, then reports throughput.
#[test]
fn smoke_fst_map() {
    use rand::{thread_rng, Rng};

    const N_TESTS: usize = 1024;

    const TEST_SIZE: usize = 1024;

    let mut rng = thread_rng();

    let before = std::time::Instant::now();

    for _ in 0..N_TESTS {
        // Random 8-byte keys; each value carries the same bytes as its key so a
        // mismatched lookup is detectable.
        let model: BTreeMap<Vec<u8>, Header> = (0..TEST_SIZE)
            .map(|_| {
                let k: u64 = rng.gen();
                let k_buf = k.to_le_bytes();

                (
                    k_buf.to_vec(),
                    Header {
                        value_start_offset: k_buf,
                    },
                )
            })
            .collect();

        let map = FstMap::from(&model);

        for (k_a, v_a) in &model {
            let v_b = map.get(k_a).unwrap();
            assert_eq!(v_a, v_b);
        }

        // Every inserted key is exactly 8 bytes, so a 9-byte key must miss.
        assert!(map.get([0u8; 9]).is_none());
    }

    // Clamp to 1 ms so a sub-millisecond run cannot divide by zero.
    let elapsed_ms = before.elapsed().as_millis().max(1);
    let wps = (N_TESTS * TEST_SIZE) as u128 * 1000 / elapsed_ms;
    dbg!(wps);
}