1use crate::BerError;
4use crate::length::decode_length;
5use crate::tag::{BOOLEAN, Class, ENUMERATED, INTEGER, OCTET_STRING, Tag};
6
/// A zero-copy reader over a buffer of BER-encoded elements.
///
/// Decoded values are returned as slices borrowed from the original input. The reader
/// enforces a limit on how deeply constructed elements may be entered and on the declared
/// value length of any single element.
#[derive(Debug, Clone)]
pub struct BerReader<'a> {
    /// Bytes that have not been consumed yet.
    input: &'a [u8],
    /// Current nesting level; 0 for a reader created with `BerReader::new`.
    depth: u16,
    /// Nesting level at which entering another constructed element fails.
    max_depth: u16,
    /// Largest accepted declared value length, in bytes, for a single element.
    max_element_size: u32,
}
14
15impl<'a> BerReader<'a> {
16 pub fn new(input: &'a [u8]) -> Self {
17 Self {
18 input,
19 depth: 0,
20 max_depth: 32,
21 max_element_size: 10 * 1024 * 1024,
22 }
23 }
24
25 pub fn with_max_depth(mut self, max: u16) -> Self {
26 self.max_depth = max;
27 self
28 }
29
30 pub fn with_max_element_size(mut self, max: u32) -> Self {
31 self.max_element_size = max;
32 self
33 }
34
35 pub fn is_empty(&self) -> bool {
36 self.input.is_empty()
37 }
38
39 pub fn remaining(&self) -> &'a [u8] {
40 self.input
41 }
42
43 pub fn peek_tag(&self) -> Result<Tag, BerError> {
45 if self.input.is_empty() {
46 return Err(BerError::Truncated { need: 1, have: 0 });
47 }
48 let (tag, _) = parse_tag(self.input)?;
49 Ok(tag)
50 }
51
52 pub fn read_element(&mut self) -> Result<(Tag, &'a [u8]), BerError> {
54 let (tag, tag_len) = parse_tag(self.input)?;
55
56 let rest = &self.input[tag_len..];
57 let (len_size, value_len) = decode_length(rest)?.ok_or(BerError::Truncated {
58 need: 1,
59 have: rest.len(),
60 })?;
61
62 if value_len as u64 > self.max_element_size as u64 {
63 return Err(BerError::ElementTooLarge {
64 size: value_len as u64,
65 max: self.max_element_size,
66 });
67 }
68
69 let header_len = tag_len + len_size;
70 let total = header_len + value_len;
71 if self.input.len() < total {
72 return Err(BerError::Truncated {
73 need: total,
74 have: self.input.len(),
75 });
76 }
77
78 let value = &self.input[header_len..total];
79 self.input = &self.input[total..];
80 Ok((tag, value))
81 }
82
83 pub fn read_sequence<F, T>(&mut self, expected_tag: Tag, f: F) -> Result<T, BerError>
86 where
87 F: FnOnce(&mut BerReader<'_>) -> Result<T, BerError>,
88 {
89 let (tag, value) = self.read_element()?;
90 if tag != expected_tag {
91 return Err(BerError::UnexpectedTag {
92 expected: expected_tag,
93 actual: tag,
94 });
95 }
96
97 if self.depth >= self.max_depth {
98 return Err(BerError::RecursionLimit {
99 max: self.max_depth,
100 });
101 }
102
103 let mut sub = BerReader {
104 input: value,
105 depth: self.depth + 1,
106 max_depth: self.max_depth,
107 max_element_size: self.max_element_size,
108 };
109 let result = f(&mut sub)?;
110 if !sub.input.is_empty() {
111 return Err(BerError::TrailingData {
112 remaining: sub.input.len(),
113 });
114 }
115 Ok(result)
116 }
117
118 pub fn read_sequence_lax<F, T>(&mut self, expected_tag: Tag, f: F) -> Result<T, BerError>
120 where
121 F: FnOnce(&mut BerReader<'_>) -> Result<T, BerError>,
122 {
123 let (tag, value) = self.read_element()?;
124 if tag != expected_tag {
125 return Err(BerError::UnexpectedTag {
126 expected: expected_tag,
127 actual: tag,
128 });
129 }
130
131 if self.depth >= self.max_depth {
132 return Err(BerError::RecursionLimit {
133 max: self.max_depth,
134 });
135 }
136
137 let mut sub = BerReader {
138 input: value,
139 depth: self.depth + 1,
140 max_depth: self.max_depth,
141 max_element_size: self.max_element_size,
142 };
143 f(&mut sub)
144 }
145
146 pub fn read_integer(&mut self) -> Result<i64, BerError> {
147 let (tag, value) = self.read_element()?;
148 if tag.number != INTEGER || tag.class != Class::Universal || tag.constructed {
149 return Err(BerError::UnexpectedTag {
150 expected: Tag::universal(INTEGER),
151 actual: tag,
152 });
153 }
154 decode_integer(value)
155 }
156
157 pub fn read_octet_string(&mut self) -> Result<&'a [u8], BerError> {
158 let (tag, value) = self.read_element()?;
159 if tag.number != OCTET_STRING || tag.class != Class::Universal {
160 return Err(BerError::UnexpectedTag {
161 expected: Tag::universal(OCTET_STRING),
162 actual: tag,
163 });
164 }
165 if tag.constructed {
166 return Err(BerError::ConstructedPrimitive);
167 }
168 Ok(value)
169 }
170
171 pub fn read_boolean(&mut self) -> Result<bool, BerError> {
172 let (tag, value) = self.read_element()?;
173 if tag.number != BOOLEAN || tag.class != Class::Universal || tag.constructed {
174 return Err(BerError::UnexpectedTag {
175 expected: Tag::universal(BOOLEAN),
176 actual: tag,
177 });
178 }
179 if value.len() != 1 {
180 return Err(BerError::InvalidBoolean);
181 }
182 Ok(value[0] != 0)
183 }
184
185 pub fn read_enumerated(&mut self) -> Result<i64, BerError> {
186 let (tag, value) = self.read_element()?;
187 if tag.number != ENUMERATED || tag.class != Class::Universal || tag.constructed {
188 return Err(BerError::UnexpectedTag {
189 expected: Tag::universal(ENUMERATED),
190 actual: tag,
191 });
192 }
193 decode_integer(value)
194 }
195
196 pub fn read_tagged_value(&mut self) -> Result<(Tag, &'a [u8]), BerError> {
198 self.read_element()
199 }
200
201 pub fn read_tagged_implicit_octet_string(
203 &mut self,
204 expected_number: u32,
205 ) -> Result<&'a [u8], BerError> {
206 let (tag, value) = self.read_element()?;
207 if tag.class != Class::Context || tag.number != expected_number {
208 return Err(BerError::UnexpectedTag {
209 expected: Tag::context(expected_number),
210 actual: tag,
211 });
212 }
213 Ok(value)
214 }
215}
216
217fn parse_tag(input: &[u8]) -> Result<(Tag, usize), BerError> {
218 if input.is_empty() {
219 return Err(BerError::Truncated { need: 1, have: 0 });
220 }
221
222 let first = input[0];
223 let class = Class::from_byte(first);
224 let constructed = (first & 0x20) != 0;
225 let tag_bits = first & 0x1F;
226
227 if tag_bits < 0x1F {
228 return Ok((
229 Tag {
230 class,
231 constructed,
232 number: tag_bits as u32,
233 },
234 1,
235 ));
236 }
237
238 let mut number: u32 = 0;
240 let mut i = 1;
241 loop {
242 if i >= input.len() {
243 return Err(BerError::Truncated {
244 need: i + 1,
245 have: input.len(),
246 });
247 }
248 if i > 5 {
249 return Err(BerError::TagOverflow);
250 }
251 let b = input[i];
252 number = number
253 .checked_shl(7)
254 .and_then(|n| n.checked_add((b & 0x7F) as u32))
255 .ok_or(BerError::TagOverflow)?;
256 i += 1;
257 if b & 0x80 == 0 {
258 break;
259 }
260 }
261
262 Ok((
263 Tag {
264 class,
265 constructed,
266 number,
267 },
268 i,
269 ))
270}
271
272fn decode_integer(bytes: &[u8]) -> Result<i64, BerError> {
273 if bytes.is_empty() || bytes.len() > 8 {
274 return Err(BerError::InvalidInteger);
275 }
276
277 let negative = bytes[0] & 0x80 != 0;
278 let mut result: i64 = if negative { -1 } else { 0 };
279
280 for &b in bytes {
281 result = (result << 8) | b as i64;
282 }
283 Ok(result)
284}