Skip to main content

dolang_bytecode/
varint.rs

1use std::{
2    hint::unreachable_unchecked,
3    io,
4    ptr::{self, NonNull},
5    slice,
6};
7
8use super::{DecResult, Decode, EncResult, Encode, UnsafeDecode};
9
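// Implementation of the vint64 encoding by Tony Arcieri (https://crates.io/crates/vint64):
// the number of trailing zero bits in the first byte is the number of bytes that follow it,
// and a zero first byte is followed by the full 8-byte little-endian value.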
11
/// Signed integer type handled by this codec; it is zigzag-mapped to [`UVar`] before encoding.
pub(crate) type IVar = i64;
/// Unsigned integer type handled by this codec.
pub(crate) type UVar = u64;

/// Longest possible encoding: one header byte followed by the full 8-byte value.
const MAX_LEN: usize = 1 + UVar::BITS as usize / 8;
16
17const fn zag(x: IVar) -> UVar {
18    ((x << 1) ^ (x >> (IVar::BITS - 1))) as UVar
19}
20
21const fn zig(x: UVar) -> IVar {
22    (x >> 1) as IVar ^ -(x as IVar & 1)
23}
24
25fn encoded_len(x: UVar) -> usize {
26    match x.leading_zeros() {
27        57..65 => 1,
28        50..57 => 2,
29        43..50 => 3,
30        36..43 => 4,
31        29..36 => 5,
32        22..29 => 6,
33        15..22 => 7,
34        8..15 => 8,
35        0..8 => 9,
36        _ => unsafe { unreachable_unchecked() },
37    }
38}
39
40impl Encode for UVar {
41    fn encode(&self, w: &mut impl io::Write) -> EncResult<()> {
42        let len = encoded_len(*self);
43        if len == 9 {
44            let mut res = [0u8; MAX_LEN];
45            res[1..].copy_from_slice(&self.to_le_bytes());
46            w.write_all(&res)?;
47        } else {
48            let res = ((*self << 1 | 1) << (len - 1)).to_le_bytes();
49            w.write_all(&res[..len])?;
50        }
51        Ok(())
52    }
53}
54
55impl Decode for UVar {
56    fn decode<R: io::Read + io::Seek>(r: &mut R) -> DecResult<Self> {
57        let mut header = 0u8;
58        let mut data = [0u8; UVar::BITS as usize / 8];
59        r.read_exact(slice::from_mut(&mut header))?;
60        let trailing = header.trailing_zeros() as usize;
61
62        if trailing == 8 {
63            r.read_exact(&mut data)?;
64            Ok(UVar::from_le_bytes(data))
65        } else if trailing == 0 {
66            Ok(header as UVar >> 1)
67        } else {
68            data[0] = header;
69            r.read_exact(&mut data[1..trailing + 1])?;
70            Ok(UVar::from_le_bytes(data) >> (1 + trailing))
71        }
72    }
73}
74
75impl UnsafeDecode for UVar {
76    unsafe fn decode(r: &mut NonNull<u8>) -> Self {
77        unsafe {
78            let header = *r.as_ptr();
79            *r = r.add(1);
80            let mut data = [0u8; UVar::BITS as usize / 8];
81            let trailing = header.trailing_zeros() as usize;
82            if trailing == 8 {
83                ptr::copy_nonoverlapping(r.as_ptr(), data.as_mut_ptr(), size_of_val(&data));
84                *r = r.add(size_of_val(&data));
85                UVar::from_le_bytes(data)
86            } else {
87                data[0] = header;
88                ptr::copy_nonoverlapping(r.as_ptr(), data[1..].as_mut_ptr(), trailing);
89                *r = r.add(trailing);
90                UVar::from_le_bytes(data) >> (1 + trailing)
91            }
92        }
93    }
94}
95
96impl Encode for IVar {
97    fn encode(&self, w: &mut impl io::Write) -> EncResult<()> {
98        zag(*self).encode(w)
99    }
100}
101
102impl Decode for IVar {
103    fn decode<R: io::Read + io::Seek>(r: &mut R) -> DecResult<Self> {
104        Decode::decode(r).map(zig)
105    }
106}
107
108impl UnsafeDecode for IVar {
109    unsafe fn decode(r: &mut NonNull<u8>) -> Self {
110        zig(unsafe { UnsafeDecode::decode(r) })
111    }
112}
113
#[cfg(test)]
mod test {
    //! These tests drive the `unsafe` decoder, so they are kept Miri-compatible:
    //! every raw pointer handed to `UnsafeDecode` has provenance over the whole
    //! encoded buffer.

    use super::*;

    /// Appends `value` to `items` unless an equal element is already present.
    fn push_unique<T: PartialEq>(items: &mut Vec<T>, value: T) {
        if !items.contains(&value) {
            items.push(value);
        }
    }

    /// Inclusive `(min, max)` range of values whose encoding is exactly `len` bytes.
    fn len_bounds(len: usize) -> (UVar, UVar) {
        match len {
            1 => (0, (1 << 7) - 1),
            2 => (1 << 7, (1 << 14) - 1),
            3 => (1 << 14, (1 << 21) - 1),
            4 => (1 << 21, (1 << 28) - 1),
            5 => (1 << 28, (1 << 35) - 1),
            6 => (1 << 35, (1 << 42) - 1),
            7 => (1 << 42, (1 << 49) - 1),
            8 => (1 << 49, (1 << 56) - 1),
            9 => (1 << 56, UVar::MAX),
            _ => panic!("invalid encoding length"),
        }
    }

    /// Unsigned test values: the edges and midpoint of every length class, plus
    /// powers of two (and their predecessors) and right-shifted `UVar::MAX` values
    /// spread across the bit range.
    fn spread_uvars() -> Vec<UVar> {
        let mut out = Vec::new();

        for len in 1..=MAX_LEN {
            let (lo, hi) = len_bounds(len);
            for value in [
                lo,
                lo.saturating_add(1),
                lo.saturating_add((hi - lo) / 2),
                hi.saturating_sub(1),
                hi,
            ] {
                push_unique(&mut out, value);
            }
        }

        for shift in (0..UVar::BITS).step_by(3) {
            if let Some(value) = 1u64.checked_shl(shift) {
                push_unique(&mut out, value);
                push_unique(&mut out, value.saturating_sub(1));
            }
            push_unique(&mut out, UVar::MAX >> shift);
        }

        out
    }

    /// Signed test values: the extremes and the values around zero, plus the
    /// zigzag-decoded form of every [`spread_uvars`] value.
    fn spread_ivars() -> Vec<IVar> {
        let mut out = vec![IVar::MIN, IVar::MIN + 1, -1, 0, 1, IVar::MAX - 1, IVar::MAX];

        for value in spread_uvars() {
            push_unique(&mut out, zig(value));
        }

        out
    }

    /// Pointer to the first byte of `buf` whose provenance covers the whole slice.
    ///
    /// `NonNull::from_ref(&buf[0])` would only be valid for reading that single
    /// byte, making every multi-byte decode through it undefined behavior.
    fn start_of(buf: &[u8]) -> NonNull<u8> {
        NonNull::from(buf).cast::<u8>()
    }

    #[test]
    fn zigzag() {
        for value in spread_ivars() {
            assert_eq!(value, zig(zag(value)));
        }
        // Small magnitudes of either sign map to small codes.
        assert_eq!(zag(0), 0);
        assert_eq!(zag(-1), 1);
        assert_eq!(zag(1), 2);
        assert_eq!(zag(IVar::MAX), UVar::MAX - 1);
        assert_eq!(zag(IVar::MIN), UVar::MAX);
    }

    #[test]
    fn known_encodings() {
        let cases: [(UVar, &[u8]); 8] = [
            (0, &[0x01]),
            (1, &[0x03]),
            (127, &[0xFF]),
            (128, &[0x02, 0x02]),
            ((1 << 14) - 1, &[0xFE, 0xFF]),
            (1 << 14, &[0x04, 0x00, 0x02]),
            ((1 << 56) - 1, &[0x80, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
            (UVar::MAX, &[0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
        ];
        for (value, expected) in cases {
            let mut buf = Vec::new();
            value.encode(&mut buf).unwrap();
            assert_eq!(buf, expected, "encoding of {value}");
        }
    }

    #[test]
    fn encode_decode_uvar() {
        for value in spread_uvars() {
            let mut buf = Vec::new();
            value.encode(&mut buf).unwrap();
            assert_eq!(buf.len(), encoded_len(value));

            let mut cursor = io::Cursor::new(&buf[..]);
            let decoded: UVar = Decode::decode(&mut cursor).unwrap();
            assert_eq!(value, decoded);
            assert_eq!(cursor.position(), buf.len() as u64);

            let mut pos = start_of(&buf);
            // SAFETY: `pos` is valid for reads of all of `buf`, which holds
            // exactly one complete encoding.
            let decoded: UVar = unsafe { UnsafeDecode::decode(&mut pos) };
            assert_eq!(value, decoded);
            assert_eq!(pos.as_ptr().cast_const(), buf.as_ptr_range().end);
        }
    }

    #[test]
    fn encode_decode_ivar() {
        for value in spread_ivars() {
            let mut buf = Vec::new();
            value.encode(&mut buf).unwrap();
            assert_eq!(buf.len(), encoded_len(zag(value)));

            let mut cursor = io::Cursor::new(&buf[..]);
            let decoded: IVar = Decode::decode(&mut cursor).unwrap();
            assert_eq!(value, decoded);
            assert_eq!(cursor.position(), buf.len() as u64);

            let mut pos = start_of(&buf);
            // SAFETY: `pos` is valid for reads of all of `buf`, which holds
            // exactly one complete encoding.
            let decoded: IVar = unsafe { UnsafeDecode::decode(&mut pos) };
            assert_eq!(value, decoded);
            assert_eq!(pos.as_ptr().cast_const(), buf.as_ptr_range().end);
        }
    }
}