Skip to main content

irox_tools/codec/
code_dictionary.rs

1// SPDX-License-Identifier: MIT
2// Copyright 2025 IROX Contributors
3//
4extern crate alloc;
5
6use crate::codec::{DecodeGroupVarintFrom, EncodeGroupVarintTo};
7use alloc::sync::Arc;
8use core::hash::Hash;
9use core::ops::DerefMut;
10use core::sync::atomic::{AtomicU64, Ordering};
11use irox_bits::{
12    Bits, BitsWrapper, Error, MutBits, ReadFromBEBits, SharedROCounter, WriteToBEBits,
13};
14use std::collections::HashMap;
15
///
/// Simple auto-incrementing dictionary indexed by hash value.  Creates new codes
/// for new values when first seen.  Increments in order of query.
///
/// Codes handed out by [`CodeDictionary::get_code`] start at `1`.  Codes learned while
/// decoding (via [`CodeDictionary::read_code`]) advance the counter past themselves, so a
/// dictionary used in both directions never re-issues a code that is already mapped.
#[derive(Debug)]
pub struct CodeDictionary<T: Eq + Hash> {
    /// value -> code
    dictionary: HashMap<T, u32>,
    /// code -> value (the reverse of `dictionary`)
    inverse: HashMap<u32, T>,
    /// Next code to issue.  Held in an `Arc` so it can be observed outside the dictionary.
    counter: Arc<AtomicU64>,
}
impl<T: Eq + Hash> Default for CodeDictionary<T> {
    fn default() -> CodeDictionary<T> {
        Self {
            dictionary: HashMap::new(),
            inverse: HashMap::new(),
            // `get_code` starts issuing at 1.
            counter: Arc::new(AtomicU64::new(1)),
        }
    }
}
impl<T: Eq + Hash> CodeDictionary<T> {
    /// Creates an empty dictionary; equivalent to [`Default::default`].
    pub fn new() -> CodeDictionary<T> {
        Default::default()
    }
}
impl<T: Eq + Hash + Clone> CodeDictionary<T> {
    ///
    /// Looks up a code for a specific value without assigning a new one.
    pub fn lookup_value(&self, value: &T) -> Option<u32> {
        self.dictionary.get(value).copied()
    }

    ///
    /// Returns the code for the specified value and if a new code was generated
    /// for the value (first time seeing the value).
    ///
    /// The returned tuple is `(is_new, code)`.
    pub fn get_code(&mut self, value: &T) -> (bool, u32) {
        // Fast path: known values need no clone and no counter bump.
        if let Some(code) = self.lookup_value(value) {
            return (false, code);
        }
        // NOTE: the `u64 -> u32` cast wraps once more than `u32::MAX` codes have been issued.
        let code = self.counter.fetch_add(1, Ordering::SeqCst) as u32;
        self.inverse.insert(code, value.clone());
        self.dictionary.insert(value.clone(), code);
        (true, code)
    }

    ///
    /// Resolves `code` back to its value.  If the code has not been seen before, the value
    /// is obtained from `value_producer` (e.g. by reading it from an input stream) and
    /// remembered for subsequent lookups in both directions.
    ///
    /// # Errors
    /// Returns any error produced by `value_producer`; the dictionary is left unchanged
    /// in that case.
    pub fn read_code<F: FnOnce() -> Result<T, E>, E>(
        &mut self,
        code: u32,
        value_producer: F,
    ) -> Result<T, E> {
        if let Some(val) = self.inverse.get(&code) {
            return Ok(val.clone());
        }
        let val = value_producer()?;
        self.inverse.insert(code, val.clone());
        self.dictionary.insert(val.clone(), code);
        // Keep the counter ahead of every learned code so a later `get_code` on this
        // dictionary cannot hand out a code that already maps to a different value.
        self.counter
            .fetch_max(u64::from(code) + 1, Ordering::SeqCst);
        Ok(val)
    }
}
73
74///
75/// Converts values into codes using [`CodeDictionary`], then uses [`GroupVarintCodeEncoder`]
76/// to encode a sequence of 4 codes to the stream.  If a code hasn't been written before,
77/// we immediately follow the group varint block with the specific coded value(s) (up to 4).
78///
79/// Block format: `[control byte][4..=16 code bytes][0..=4 code-mapped-values]`
80pub struct GroupVarintCodeEncoder<'a, T: Eq + Hash, B: MutBits> {
81    inner: BitsWrapper<'a, B>,
82    dict: CodeDictionary<T>,
83}
84impl<'a, T: Eq + Hash + Default, B: MutBits> GroupVarintCodeEncoder<'a, T, B> {
85    pub fn new(inner: BitsWrapper<'a, B>) -> Self {
86        Self {
87            inner,
88            dict: CodeDictionary::new(),
89        }
90    }
91}
92impl<T: Eq + Hash + Default + Clone + WriteToBEBits, B: MutBits> GroupVarintCodeEncoder<'_, T, B> {
93    pub fn encode_4(&mut self, vals: &[T; 4]) -> Result<usize, Error> {
94        let [a, b, c, d] = vals;
95        let ea = self.dict.get_code(a);
96        let eb = self.dict.get_code(b);
97        let ec = self.dict.get_code(c);
98        let ed = self.dict.get_code(d);
99
100        let codes = [ea.1, eb.1, ec.1, ed.1];
101        let mut used = codes.encode_group_varint_to(self.inner.deref_mut())?;
102        if ea.0 {
103            used += a.write_be_to(self.inner.deref_mut())?;
104        }
105        if eb.0 {
106            used += b.write_be_to(self.inner.deref_mut())?;
107        }
108        if ec.0 {
109            used += c.write_be_to(self.inner.deref_mut())?;
110        }
111        if ed.0 {
112            used += d.write_be_to(self.inner.deref_mut())?;
113        }
114
115        Ok(used)
116    }
117
118    pub fn counter(&self) -> SharedROCounter {
119        SharedROCounter::new(self.dict.counter.clone())
120    }
121
122    pub fn flush(&mut self) -> Result<(), Error> {
123        self.inner.flush()
124    }
125}
126
127///
128/// Wraps [`CodeDictionary`] in an `Arc<RwLock>>` for shared access.
129#[derive(Debug, Default, Clone)]
130pub struct SharedCodeDictionary<T: Eq + Hash> {
131    inner: Arc<std::sync::RwLock<CodeDictionary<T>>>,
132}
133impl<T: Eq + Hash + Default> SharedCodeDictionary<T> {
134    pub fn new() -> SharedCodeDictionary<T> {
135        Default::default()
136    }
137}
138impl<T: Eq + Hash + Copy + Default> SharedCodeDictionary<T> {
139    ///
140    /// Looks up a code for a specific value
141    pub fn lookup_value(&self, value: &T) -> Option<u32> {
142        if let Ok(lock) = self.inner.read() {
143            if let Some(code) = lock.lookup_value(value) {
144                return Some(code);
145            }
146        }
147
148        None
149    }
150
151    ///
152    /// Returns the code for the specified value and if a new code was generated
153    /// for the value (first time seeing the value).
154    pub fn get_code(&mut self, value: &T) -> (bool, u32) {
155        if let Ok(lock) = self.inner.read() {
156            if let Some(code) = lock.lookup_value(value) {
157                return (false, code);
158            }
159        }
160        if let Ok(mut lock) = self.inner.write() {
161            return lock.get_code(value);
162        }
163        (false, 0)
164    }
165    pub fn read_code<F: FnOnce() -> Result<T, E>, E>(
166        &mut self,
167        code: u32,
168        value_producer: F,
169    ) -> Result<T, E> {
170        if let Ok(lock) = self.inner.read() {
171            if let Some(val) = lock.inverse.get(&code) {
172                return Ok(*val);
173            }
174        }
175        if let Ok(mut lock) = self.inner.write() {
176            let val = value_producer()?;
177            lock.inverse.insert(code, val);
178            lock.dictionary.insert(val, code);
179            return Ok(val);
180        }
181        Ok(T::default())
182    }
183}
184
185///
186/// Converts values into codes using [`CodeDictionary`], then uses [`GroupVarintCodeEncoder`]
187/// to encode a sequence of 4 codes to the stream.  If a code hasn't been written before,
188/// we immediately follow the group varint block with the specific coded value(s) (up to 4).
189///
190/// Block format: `[control byte][4..=16 code bytes][0..=4 code-mapped-values]`
191///
192/// Must provide a shared dictionary to use this struct.  Decoding MUST be performed in the
193/// exact same order as encoding or else the mapped values won't align correctly.
194pub struct SharedGroupVarintCodeEncoder<'a, T: Eq + Hash, B: MutBits> {
195    inner: BitsWrapper<'a, B>,
196    dict: SharedCodeDictionary<T>,
197}
198impl<'a, T: Eq + Hash + Default, B: MutBits> SharedGroupVarintCodeEncoder<'a, T, B> {
199    pub fn new(inner: BitsWrapper<'a, B>, dict: SharedCodeDictionary<T>) -> Self {
200        Self { inner, dict }
201    }
202}
203impl<T: Eq + Hash + Default + Copy + WriteToBEBits, B: MutBits>
204    SharedGroupVarintCodeEncoder<'_, T, B>
205{
206    pub fn encode_4(&mut self, vals: &[T; 4]) -> Result<usize, Error> {
207        let [a, b, c, d] = vals;
208        let ea = self.dict.get_code(a);
209        let eb = self.dict.get_code(b);
210        let ec = self.dict.get_code(c);
211        let ed = self.dict.get_code(d);
212
213        let codes = [ea.1, eb.1, ec.1, ed.1];
214        let mut used = codes.encode_group_varint_to(self.inner.deref_mut())?;
215        if ea.0 {
216            used += a.write_be_to(self.inner.deref_mut())?;
217        }
218        if eb.0 {
219            used += b.write_be_to(self.inner.deref_mut())?;
220        }
221        if ec.0 {
222            used += c.write_be_to(self.inner.deref_mut())?;
223        }
224        if ed.0 {
225            used += d.write_be_to(self.inner.deref_mut())?;
226        }
227
228        Ok(used)
229    }
230}
231
232pub struct GroupVarintCodeDecoder<'a, T: Hash + Eq, B: Bits> {
233    inner: BitsWrapper<'a, B>,
234    dict: CodeDictionary<T>,
235}
236impl<'a, T: Hash + Eq + Default, B: Bits> GroupVarintCodeDecoder<'a, T, B> {
237    pub fn new(inner: BitsWrapper<'a, B>) -> Self {
238        Self {
239            inner,
240            dict: CodeDictionary::new(),
241        }
242    }
243}
244impl<T: Hash + Eq + Default + ReadFromBEBits + Clone, B: Bits> GroupVarintCodeDecoder<'_, T, B> {
245    fn decode_1(&mut self, code: u32) -> Result<T, Error> {
246        self.dict
247            .read_code(code, || T::read_from_be_bits(self.inner.deref_mut()))
248    }
249
250    pub fn decode_4(&mut self) -> Result<Option<[T; 4]>, Error> {
251        let Some(val) = u32::decode_group_varint_from(self.inner.deref_mut())? else {
252            return Ok(None);
253        };
254        let [a, b, c, d] = val;
255
256        Ok(Some([
257            self.decode_1(a)?,
258            self.decode_1(b)?,
259            self.decode_1(c)?,
260            self.decode_1(d)?,
261        ]))
262    }
263}
264
#[cfg(test)]
mod test {
    use crate::buf::{Buffer, FixedU8Buf, RoundU8Buffer};
    use crate::codec::{GroupVarintCodeDecoder, GroupVarintCodeEncoder};
    use crate::hex::HexDump;
    use irox_bits::{BitsWrapper, Error};

    #[test]
    pub fn test_encoder() -> Result<(), Error> {
        let mut buf = FixedU8Buf::<48>::new();
        {
            let mut codec = GroupVarintCodeEncoder::<u32, _>::new(BitsWrapper::Borrowed(&mut buf));
            // First block: all four values are new, so all four literals follow the codes.
            let used = codec.encode_4(&[0xAAAA, 0xBBBBBB, 0xCC, 0xDDDDDDDD])?;
            assert_eq!(used, 5 + 16);
            // Second block: same values, so only the control byte + codes are written.
            let used = codec.encode_4(&[0xAAAA, 0xBBBBBB, 0xCC, 0xDDDDDDDD])?;
            assert_eq!(used, 5);
        }
        buf.as_ref_used().hexdump();

        assert_eq!(5 + 16 + 5, buf.len());
        assert_eq_hex_slice!(
            &[
                0x00, // control char for first code block
                0x01, 0x02, 0x03, 0x04, // first 4 codes in code block
                0x00, 0x00, 0xAA, 0xAA, // first coded value,
                0x00, 0xBB, 0xBB, 0xBB, // second coded value,
                0x00, 0x00, 0x00, 0xCC, // third coded value,
                0xDD, 0xDD, 0xDD, 0xDD, // fourth coded value
                0x00, // control char for second code block
                0x01, 0x02, 0x03, 0x04, // second 4 code in code block
            ],
            buf.as_ref_used()
        );
        Ok(())
    }

    #[test]
    pub fn test_decoder() -> Result<(), Error> {
        let mut buf = RoundU8Buffer::from([
            0x00, // control char for first code block
            0x01, 0x02, 0x03, 0x04, // first 4 codes in code block
            0x00, 0x00, 0xAA, 0xAA, // first coded value,
            0x00, 0xBB, 0xBB, 0xBB, // second coded value,
            0x00, 0x00, 0x00, 0xCC, // third coded value,
            0xDD, 0xDD, 0xDD, 0xDD, // fourth coded value
            0x00, // control char for second code block
            0x01, 0x02, 0x03, 0x04, // second 4 code in code block
        ]);
        let mut dec = GroupVarintCodeDecoder::<u32, _>::new(BitsWrapper::Borrowed(&mut buf));
        let block1 = dec.decode_4()?;
        assert!(block1.is_some());
        if let Some(block1) = block1 {
            assert_eq_hex_slice!(&[0xAAAA, 0xBBBBBB, 0xCC, 0xDDDDDDDD], block1.as_ref())
        }
        let block2 = dec.decode_4()?;
        assert!(block2.is_some());
        if let Some(block2) = block2 {
            assert_eq_hex_slice!(&[0xAAAA, 0xBBBBBB, 0xCC, 0xDDDDDDDD], block2.as_ref())
        }
        // Input fully consumed: no further block.
        let block3 = dec.decode_4()?;
        assert!(block3.is_none());
        assert_eq!(0, buf.len());
        Ok(())
    }

    /// A value repeated inside one block is written once (at its first position), and the
    /// decoder serves the later occurrences from its dictionary.
    #[test]
    pub fn test_repeated_value_in_block() -> Result<(), Error> {
        let mut buf = FixedU8Buf::<48>::new();
        {
            let mut codec = GroupVarintCodeEncoder::<u32, _>::new(BitsWrapper::Borrowed(&mut buf));
            let used = codec.encode_4(&[0x07, 0x07, 0x09, 0x07])?;
            assert_eq!(used, 5 + 8);
        }
        let expected: [u8; 13] = [
            0x00, // control char for the code block
            0x01, 0x01, 0x02, 0x01, // codes: 0x07 -> 1, 0x09 -> 2
            0x00, 0x00, 0x00, 0x07, // literal for code 1
            0x00, 0x00, 0x00, 0x09, // literal for code 2
        ];
        assert_eq_hex_slice!(&expected, buf.as_ref_used());

        let mut input = RoundU8Buffer::from(expected);
        let mut dec = GroupVarintCodeDecoder::<u32, _>::new(BitsWrapper::Borrowed(&mut input));
        assert_eq!(Some([0x07, 0x07, 0x09, 0x07]), dec.decode_4()?);
        assert_eq!(None, dec.decode_4()?);
        assert_eq!(0, input.len());
        Ok(())
    }
}