1use std::{collections::HashMap, convert::TryFrom, error, fmt, mem};
2
3use crate::bitcode::{BlockInfo, Payload, Record, Signature};
4use crate::bits::{self, Bits, Cursor};
5use crate::bitstream::{Abbreviation, BlockInfoCode, BuiltinAbbreviationId, Operand};
6use crate::visitor::BitStreamVisitor;
7
/// Errors produced while reading a bitstream.
#[derive(Debug, Clone)]
pub enum Error {
    /// The stream signature (magic number) is invalid; carries the offending 32-bit value.
    InvalidSignature(u32),
    /// An abbreviation definition is malformed, or an abbreviation operand cannot be decoded
    /// as a scalar value.
    InvalidAbbrev,
    /// A sub-block was entered while reading a block info block.
    NestedBlockInBlockInfo,
    /// A block info record or abbreviation definition appeared before any SETBID record
    /// selected the block it applies to.
    MissingSetBid,
    /// A block info record has an unknown code or unexpected operands; carries the record code.
    InvalidBlockInfoRecord(u64),
    /// The abbreviation width is too small; carries the width.
    AbbrevWidthTooSmall(usize),
    /// An abbreviation ID has no definition in the block being read.
    NoSuchAbbrev { block_id: u64, abbrev_id: usize },
    /// Reading of the block with this ID stopped without reaching its end-block marker.
    MissingEndBlock(u64),
    /// A lower-level bit-reading error.
    ReadBits(bits::Error),
}
21
22impl fmt::Display for Error {
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 match self {
25 Error::InvalidSignature(sig) => {
26 write!(f, "invalid signature (magic number): 0x{:x}", sig)
27 }
28 Error::InvalidAbbrev => write!(f, "invalid abbreviation"),
29 Error::NestedBlockInBlockInfo => write!(f, "nested block in block info"),
30 Error::MissingSetBid => write!(f, "missing SETBID"),
31 Error::InvalidBlockInfoRecord(record_id) => {
32 write!(f, "invalid block info record `{}`", record_id)
33 }
34 Error::AbbrevWidthTooSmall(width) => {
35 write!(f, "abbreviation width `{}` is too small", width)
36 }
37 Error::NoSuchAbbrev {
38 block_id,
39 abbrev_id,
40 } => write!(
41 f,
42 "no such abbreviation `{}` in block `{}`",
43 abbrev_id, block_id
44 ),
45 Error::MissingEndBlock(block_id) => write!(f, "missing end block for `{}`", block_id),
46 Error::ReadBits(err) => err.fmt(f),
47 }
48 }
49}
50
// Marks `Error` as a standard error type. No trait methods are overridden, so the defaults
// apply and the message comes from the `Display` implementation.
impl error::Error for Error {}
52
53impl From<bits::Error> for Error {
54 fn from(err: bits::Error) -> Self {
55 Self::ReadBits(err)
56 }
57}
58
/// Reader that decodes a bitstream from a borrowed byte buffer.
#[derive(Debug, Clone)]
pub struct BitStreamReader<'a> {
    /// Cursor over the bits of the input buffer.
    cursor: Cursor<'a>,
    /// Block metadata (block name and record names) keyed by block ID, filled in by
    /// `read_block_info_block`.
    pub(crate) block_info: HashMap<u64, BlockInfo>,
    /// Abbreviation definitions keyed by the ID of the block they apply to.
    global_abbrevs: HashMap<u64, Vec<Abbreviation>>,
}
67
68impl<'a> BitStreamReader<'a> {
/// Block ID to pass to `read_block` for the outermost level of the stream. `read_block` does
/// not report a missing end-block marker for this ID when the input runs out.
pub const TOP_LEVEL_BLOCK_ID: u64 = u64::MAX;
71
72 pub fn new(buffer: &'a [u8]) -> Self {
74 let cursor = Cursor::new(Bits::new(buffer));
75 Self {
76 cursor,
77 block_info: HashMap::new(),
78 global_abbrevs: HashMap::new(),
79 }
80 }
81
82 pub fn read_signature(&mut self) -> Result<Signature, Error> {
84 assert!(self.cursor.is_at_start());
85 let bits = self.cursor.read(mem::size_of::<u32>() * 8)? as u32;
86 Ok(Signature::new(bits))
87 }
88
89 pub fn read_abbrev_op(&mut self) -> Result<Operand, Error> {
91 let is_literal = self.cursor.read(1)?;
92 if is_literal == 1 {
93 return Ok(Operand::Literal(self.cursor.read_vbr(8)?));
94 }
95 let op_type = self.cursor.read(3)?;
96 let op = match op_type {
97 1 => Operand::Fixed(self.cursor.read_vbr(5)? as u8),
98 2 => Operand::Vbr(self.cursor.read_vbr(5)? as u8),
99 3 => Operand::Array(Box::new(self.read_abbrev_op()?)),
100 4 => Operand::Char6,
101 5 => Operand::Blob,
102 _ => return Err(Error::InvalidAbbrev),
103 };
104 Ok(op)
105 }
106
107 pub fn read_abbrev(&mut self, num_ops: usize) -> Result<Abbreviation, Error> {
109 if num_ops == 0 {
110 return Err(Error::InvalidAbbrev);
111 }
112 let mut operands = Vec::new();
113 for i in 0..num_ops {
114 let op = self.read_abbrev_op()?;
115 let is_array = op.is_array();
116 let is_blob = op.is_blob();
117 operands.push(op);
118 if is_array {
119 if i == num_ops - 2 {
120 break;
121 } else {
122 return Err(Error::InvalidAbbrev);
123 }
124 } else if is_blob {
125 if i != num_ops - 1 {
126 return Err(Error::InvalidAbbrev);
127 }
128 }
129 }
130 Ok(Abbreviation { operands })
131 }
132
133 fn read_single_abbreviated_record_operand(&mut self, operand: &Operand) -> Result<u64, Error> {
134 match operand {
135 Operand::Char6 => {
136 let value = self.cursor.read(6)?;
137 return match value {
138 0..=25 => Ok(value + u64::from('a' as u32)),
139 26..=51 => Ok(value + u64::from('A' as u32) - 26),
140 52..=61 => Ok(value + u64::from('0' as u32) - 52),
141 62 => Ok(u64::from('.' as u32)),
142 63 => Ok(u64::from('_' as u32)),
143 _ => Err(Error::InvalidAbbrev),
144 };
145 }
146 Operand::Literal(value) => Ok(*value),
147 Operand::Fixed(width) => Ok(self.cursor.read(*width as usize)?),
148 Operand::Vbr(width) => Ok(self.cursor.read_vbr(*width as usize)?),
149 Operand::Array(_) | Operand::Blob => Err(Error::InvalidAbbrev),
150 }
151 }
152
153 pub fn read_abbreviated_record(&mut self, abbrev: &Abbreviation) -> Result<Record, Error> {
155 let code =
156 self.read_single_abbreviated_record_operand(&abbrev.operands.first().unwrap())?;
157 let last_operand = abbrev.operands.last().unwrap();
158 let last_regular_operand_index =
159 abbrev.operands.len() - (if last_operand.is_payload() { 1 } else { 0 });
160 let mut fields = Vec::new();
161 for op in &abbrev.operands[1..last_regular_operand_index] {
162 fields.push(self.read_single_abbreviated_record_operand(op)?);
163 }
164 let payload = if last_operand.is_payload() {
165 match last_operand {
166 Operand::Array(element) => {
167 let length = self.cursor.read_vbr(6)? as usize;
168 let mut elements = Vec::with_capacity(length);
169 for _ in 0..length {
170 elements.push(self.read_single_abbreviated_record_operand(element)?);
171 }
172 if matches!(**element, Operand::Char6) {
173 let s: String = elements
174 .into_iter()
175 .map(|x| std::char::from_u32(x as u32).unwrap())
176 .collect();
177 Some(Payload::Char6String(s))
178 } else {
179 Some(Payload::Array(elements))
180 }
181 }
182 Operand::Blob => {
183 let length = self.cursor.read_vbr(6)? as usize;
184 self.cursor.advance(32)?;
185 let data = self.cursor.read_bytes(length)?;
186 self.cursor.advance(32)?;
187 Some(Payload::Blob(data))
188 }
189 _ => unreachable!(),
190 }
191 } else {
192 None
193 };
194 Ok(Record {
195 id: code,
196 fields,
197 payload,
198 })
199 }
200
201 pub fn read_block_info_block(&mut self, abbrev_width: usize) -> Result<(), Error> {
203 use BuiltinAbbreviationId::*;
204
205 let mut current_block_id = None;
206 loop {
207 let abbrev_id = self.cursor.read(abbrev_width)?;
208 match BuiltinAbbreviationId::try_from(abbrev_id).map_err(|_| Error::NoSuchAbbrev {
209 block_id: 0,
210 abbrev_id: abbrev_id as usize,
211 })? {
212 EndBlock => {
213 self.cursor.advance(32)?;
214 return Ok(());
215 }
216 EnterSubBlock => {
217 return Err(Error::NestedBlockInBlockInfo);
218 }
219 DefineAbbreviation => {
220 if let Some(block_id) = current_block_id {
221 let num_ops = self.cursor.read_vbr(5)? as usize;
222 let abbrev = self.read_abbrev(num_ops)?;
223 let abbrevs = self
224 .global_abbrevs
225 .entry(block_id)
226 .or_insert_with(|| Vec::new());
227 abbrevs.push(abbrev);
228 } else {
229 return Err(Error::MissingSetBid);
230 }
231 }
232 UnabbreviatedRecord => {
233 let code = self.cursor.read_vbr(6)?;
234 let num_ops = self.cursor.read_vbr(6)? as usize;
235 let mut operands = Vec::with_capacity(num_ops);
236 for _ in 0..num_ops {
237 operands.push(self.cursor.read_vbr(6)?);
238 }
239 match BlockInfoCode::try_from(
240 u8::try_from(code).map_err(|_| Error::InvalidBlockInfoRecord(code))?,
241 )
242 .map_err(|_| Error::InvalidBlockInfoRecord(code))?
243 {
244 BlockInfoCode::SetBid => {
245 if operands.len() != 1 {
246 return Err(Error::InvalidBlockInfoRecord(code));
247 }
248 current_block_id = operands.first().cloned();
249 }
250 BlockInfoCode::BlockName => {
251 if let Some(block_id) = current_block_id {
252 let block_info = self
253 .block_info
254 .entry(block_id)
255 .or_insert_with(|| BlockInfo::default());
256 let name = String::from_utf8(
257 operands.into_iter().map(|x| x as u8).collect::<Vec<u8>>(),
258 )
259 .unwrap_or_else(|_| "<invalid>".to_string());
260 block_info.name = name;
261 } else {
262 return Err(Error::MissingSetBid);
263 }
264 }
265 BlockInfoCode::SetRecordName => {
266 if let Some(block_id) = current_block_id {
267 if let Some(record_id) = operands.first().cloned() {
268 let block_info = self
269 .block_info
270 .entry(block_id)
271 .or_insert_with(|| BlockInfo::default());
272 let name = String::from_utf8(
273 operands
274 .into_iter()
275 .skip(1)
276 .map(|x| x as u8)
277 .collect::<Vec<u8>>(),
278 )
279 .unwrap_or_else(|_| "<invalid>".to_string());
280 block_info.record_names.insert(record_id, name);
281 } else {
282 return Err(Error::InvalidBlockInfoRecord(code));
283 }
284 } else {
285 return Err(Error::MissingSetBid);
286 }
287 }
288 }
289 }
290 }
291 }
292 }
293
294 pub fn read_block<V: BitStreamVisitor>(
296 &mut self,
297 id: u64,
298 abbrev_width: usize,
299 visitor: &mut V,
300 ) -> Result<(), Error> {
301 use BuiltinAbbreviationId::*;
302
303 while !self.cursor.is_at_end() {
304 let abbrev_id = self.cursor.read(abbrev_width)?;
305 match BuiltinAbbreviationId::try_from(abbrev_id) {
306 Ok(abbrev_id) => match abbrev_id {
307 EndBlock => {
308 self.cursor.advance(32)?;
309 visitor.did_exit_block();
310 return Ok(());
311 }
312 EnterSubBlock => {
313 let block_id = self.cursor.read_vbr(8)?;
314 let new_abbrev_width = self.cursor.read_vbr(4)? as usize;
315 self.cursor.advance(32)?;
316 let block_length = self.cursor.read(32)? as usize * 4;
317 match block_id {
318 0 => self.read_block_info_block(new_abbrev_width)?,
319 _ => {
320 if !visitor.should_enter_block(block_id) {
321 self.cursor.skip_bytes(block_length)?;
322 break;
323 }
324 self.read_block(block_id, new_abbrev_width, visitor)?;
325 }
326 }
327 }
328 DefineAbbreviation => {
329 let num_ops = self.cursor.read_vbr(5)? as usize;
330 let abbrev = self.read_abbrev(num_ops)?;
331 let abbrev_info =
332 self.global_abbrevs.entry(id).or_insert_with(|| Vec::new());
333 abbrev_info.push(abbrev);
334 }
335 UnabbreviatedRecord => {
336 let code = self.cursor.read_vbr(6)?;
337 let num_ops = self.cursor.read_vbr(6)? as usize;
338 let mut operands = Vec::with_capacity(num_ops);
339 for _ in 0..num_ops {
340 operands.push(self.cursor.read_vbr(6)?);
341 }
342 visitor.visit(Record {
343 id: code,
344 fields: operands,
345 payload: None,
346 });
347 }
348 },
349 Err(_) => {
350 if let Some(abbrev_info) = self.global_abbrevs.get(&id).cloned() {
351 let abbrev_id = abbrev_id as usize;
352 if abbrev_id - 4 < abbrev_info.len() {
353 visitor
354 .visit(self.read_abbreviated_record(&abbrev_info[abbrev_id - 4])?);
355 continue;
356 }
357 }
358 return Err(Error::NoSuchAbbrev {
359 block_id: id,
360 abbrev_id: abbrev_id as usize,
361 });
362 }
363 }
364 }
365 if id != Self::TOP_LEVEL_BLOCK_ID {
366 return Err(Error::MissingEndBlock(id));
367 }
368 Ok(())
369 }
370}