dolang_bytecode/
varint.rs1use std::{
2 hint::unreachable_unchecked,
3 io,
4 ptr::{self, NonNull},
5 slice,
6};
7
8use super::{DecResult, Decode, EncResult, Encode, UnsafeDecode};
9
/// Signed integer type handled by this codec; it is zigzag-mapped onto [`UVar`] for encoding.
pub(crate) type IVar = i64;
/// Unsigned integer type handled by this codec.
pub(crate) type UVar = u64;

/// Longest possible encoding: a one-byte header followed by every byte of a [`UVar`].
const MAX_LEN: usize = UVar::BITS as usize / 8 + 1;
16
17const fn zag(x: IVar) -> UVar {
18 ((x << 1) ^ (x >> (IVar::BITS - 1))) as UVar
19}
20
21const fn zig(x: UVar) -> IVar {
22 (x >> 1) as IVar ^ -(x as IVar & 1)
23}
24
25fn encoded_len(x: UVar) -> usize {
26 match x.leading_zeros() {
27 57..65 => 1,
28 50..57 => 2,
29 43..50 => 3,
30 36..43 => 4,
31 29..36 => 5,
32 22..29 => 6,
33 15..22 => 7,
34 8..15 => 8,
35 0..8 => 9,
36 _ => unsafe { unreachable_unchecked() },
37 }
38}
39
40impl Encode for UVar {
41 fn encode(&self, w: &mut impl io::Write) -> EncResult<()> {
42 let len = encoded_len(*self);
43 if len == 9 {
44 let mut res = [0u8; MAX_LEN];
45 res[1..].copy_from_slice(&self.to_le_bytes());
46 w.write_all(&res)?;
47 } else {
48 let res = ((*self << 1 | 1) << (len - 1)).to_le_bytes();
49 w.write_all(&res[..len])?;
50 }
51 Ok(())
52 }
53}
54
55impl Decode for UVar {
56 fn decode<R: io::Read + io::Seek>(r: &mut R) -> DecResult<Self> {
57 let mut header = 0u8;
58 let mut data = [0u8; UVar::BITS as usize / 8];
59 r.read_exact(slice::from_mut(&mut header))?;
60 let trailing = header.trailing_zeros() as usize;
61
62 if trailing == 8 {
63 r.read_exact(&mut data)?;
64 Ok(UVar::from_le_bytes(data))
65 } else if trailing == 0 {
66 Ok(header as UVar >> 1)
67 } else {
68 data[0] = header;
69 r.read_exact(&mut data[1..trailing + 1])?;
70 Ok(UVar::from_le_bytes(data) >> (1 + trailing))
71 }
72 }
73}
74
75impl UnsafeDecode for UVar {
76 unsafe fn decode(r: &mut NonNull<u8>) -> Self {
77 unsafe {
78 let header = *r.as_ptr();
79 *r = r.add(1);
80 let mut data = [0u8; UVar::BITS as usize / 8];
81 let trailing = header.trailing_zeros() as usize;
82 if trailing == 8 {
83 ptr::copy_nonoverlapping(r.as_ptr(), data.as_mut_ptr(), size_of_val(&data));
84 *r = r.add(size_of_val(&data));
85 UVar::from_le_bytes(data)
86 } else {
87 data[0] = header;
88 ptr::copy_nonoverlapping(r.as_ptr(), data[1..].as_mut_ptr(), trailing);
89 *r = r.add(trailing);
90 UVar::from_le_bytes(data) >> (1 + trailing)
91 }
92 }
93 }
94}
95
96impl Encode for IVar {
97 fn encode(&self, w: &mut impl io::Write) -> EncResult<()> {
98 zag(*self).encode(w)
99 }
100}
101
102impl Decode for IVar {
103 fn decode<R: io::Read + io::Seek>(r: &mut R) -> DecResult<Self> {
104 Decode::decode(r).map(zig)
105 }
106}
107
108impl UnsafeDecode for IVar {
109 unsafe fn decode(r: &mut NonNull<u8>) -> Self {
110 zig(unsafe { UnsafeDecode::decode(r) })
111 }
112}
113
#[cfg(test)]
mod test {
    use super::*;

    /// Appends `value` to `items` unless an equal element is already present.
    fn push_unique<T: PartialEq>(items: &mut Vec<T>, value: T) {
        if !items.contains(&value) {
            items.push(value);
        }
    }

    /// Inclusive range of values whose encoding is exactly `len` bytes long.
    fn len_bounds(len: usize) -> (UVar, UVar) {
        match len {
            1 => (0, (1 << 7) - 1),
            2 => (1 << 7, (1 << 14) - 1),
            3 => (1 << 14, (1 << 21) - 1),
            4 => (1 << 21, (1 << 28) - 1),
            5 => (1 << 28, (1 << 35) - 1),
            6 => (1 << 35, (1 << 42) - 1),
            7 => (1 << 42, (1 << 49) - 1),
            8 => (1 << 49, (1 << 56) - 1),
            9 => (1 << 56, UVar::MAX),
            _ => panic!("invalid encoding length"),
        }
    }

    /// Unsigned sample values: the edges and midpoint of every length class, plus
    /// powers of two and all-ones masks at a spread of bit positions.
    fn spread_uvars() -> Vec<UVar> {
        let mut out = Vec::new();

        for len in 1..=MAX_LEN {
            let (lo, hi) = len_bounds(len);
            for value in [
                lo,
                lo.saturating_add(1),
                lo.saturating_add((hi - lo) / 2),
                hi.saturating_sub(1),
                hi,
            ] {
                push_unique(&mut out, value);
            }
        }

        for shift in (0..UVar::BITS).step_by(3) {
            if let Some(value) = 1u64.checked_shl(shift) {
                push_unique(&mut out, value);
                push_unique(&mut out, value.saturating_sub(1));
            }
            push_unique(&mut out, UVar::MAX >> shift);
        }

        out
    }

    /// Signed sample values: the extremes around zero and the range limits, plus the
    /// zigzag preimage of every unsigned sample.
    fn spread_ivars() -> Vec<IVar> {
        let mut out = vec![IVar::MIN, IVar::MIN + 1, -1, 0, 1, IVar::MAX - 1, IVar::MAX];

        for value in spread_uvars() {
            push_unique(&mut out, zig(value));
        }

        out
    }

    /// Decodes `buf` with the pointer-based decoder.
    ///
    /// Returns the value together with the number of bytes the decoder advanced.
    ///
    /// The pointer is derived from the whole slice, not from `&buf[0]`. A pointer made
    /// from a single-element reference may only access that one byte, so reading the
    /// rest of a multi-byte encoding through it would be undefined behavior.
    fn decode_unchecked<T: UnsafeDecode>(buf: &[u8]) -> (T, usize) {
        let start = NonNull::from(buf).cast::<u8>();
        let mut cursor = start;
        // SAFETY: `buf` holds one complete encoding written by `Encode`, and `cursor`
        // carries provenance for every byte of `buf`.
        let value = unsafe { T::decode(&mut cursor) };
        let consumed = cursor.as_ptr() as usize - start.as_ptr() as usize;
        (value, consumed)
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn zigzag() {
        for value in spread_ivars() {
            assert_eq!(value, zig(zag(value)));
        }
        for value in spread_uvars() {
            assert_eq!(value, zag(zig(value)));
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn encode_decode_uvar() {
        for value in spread_uvars() {
            let mut buf = Vec::new();
            value.encode(&mut buf).unwrap();
            assert_eq!(buf.len(), encoded_len(value));

            // The stream decoder must return the value and stop right after it.
            let mut stream = io::Cursor::new(&buf[..]);
            assert_eq!(value, Decode::decode(&mut stream).unwrap());
            assert_eq!(stream.position(), buf.len() as u64);

            // Dropping the last byte must make the stream decoder fail, not misread.
            let truncated = &buf[..buf.len() - 1];
            assert!(<UVar as Decode>::decode(&mut io::Cursor::new(truncated)).is_err());

            assert_eq!((value, buf.len()), decode_unchecked::<UVar>(&buf));
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn encode_decode_ivar() {
        for value in spread_ivars() {
            let mut buf = Vec::new();
            value.encode(&mut buf).unwrap();
            assert_eq!(buf.len(), encoded_len(zag(value)));

            let mut stream = io::Cursor::new(&buf[..]);
            assert_eq!(value, Decode::decode(&mut stream).unwrap());
            assert_eq!(stream.position(), buf.len() as u64);

            assert_eq!((value, buf.len()), decode_unchecked::<IVar>(&buf));
        }
    }
}