Skip to main content

ldap_client_ber/
reader.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
2
3use crate::BerError;
4use crate::length::decode_length;
5use crate::tag::{BOOLEAN, Class, ENUMERATED, INTEGER, OCTET_STRING, Tag};
6
/// Zero-copy BER decoder over a byte slice.
///
/// Slices returned by the reading methods borrow directly from the input, so no
/// value bytes are copied. The reader is cheap to clone, which allows
/// speculative parsing: clone it, try a decode, and keep whichever state wins.
#[derive(Debug, Clone)]
pub struct BerReader<'a> {
    /// Unread portion of the input; advanced as elements are consumed.
    input: &'a [u8],
    /// Nesting level of constructed elements this reader is scoped to
    /// (0 for a reader created with `new`).
    depth: u16,
    /// Nesting level at which entering another constructed element fails.
    max_depth: u16,
    /// Largest accepted value length, in bytes, for a single element.
    max_element_size: u32,
}
14
15impl<'a> BerReader<'a> {
16    pub fn new(input: &'a [u8]) -> Self {
17        Self {
18            input,
19            depth: 0,
20            max_depth: 32,
21            max_element_size: 10 * 1024 * 1024,
22        }
23    }
24
25    pub fn with_max_depth(mut self, max: u16) -> Self {
26        self.max_depth = max;
27        self
28    }
29
30    pub fn with_max_element_size(mut self, max: u32) -> Self {
31        self.max_element_size = max;
32        self
33    }
34
35    pub fn is_empty(&self) -> bool {
36        self.input.is_empty()
37    }
38
39    pub fn remaining(&self) -> &'a [u8] {
40        self.input
41    }
42
43    /// Peek at the tag of the next element without consuming it.
44    pub fn peek_tag(&self) -> Result<Tag, BerError> {
45        if self.input.is_empty() {
46            return Err(BerError::Truncated { need: 1, have: 0 });
47        }
48        let (tag, _) = parse_tag(self.input)?;
49        Ok(tag)
50    }
51
52    /// Read the next TLV element, returning `(tag, value_bytes)`.
53    pub fn read_element(&mut self) -> Result<(Tag, &'a [u8]), BerError> {
54        let (tag, tag_len) = parse_tag(self.input)?;
55
56        let rest = &self.input[tag_len..];
57        let (len_size, value_len) = decode_length(rest)?.ok_or(BerError::Truncated {
58            need: 1,
59            have: rest.len(),
60        })?;
61
62        if value_len as u64 > self.max_element_size as u64 {
63            return Err(BerError::ElementTooLarge {
64                size: value_len as u64,
65                max: self.max_element_size,
66            });
67        }
68
69        let header_len = tag_len + len_size;
70        let total = header_len + value_len;
71        if self.input.len() < total {
72            return Err(BerError::Truncated {
73                need: total,
74                have: self.input.len(),
75            });
76        }
77
78        let value = &self.input[header_len..total];
79        self.input = &self.input[total..];
80        Ok((tag, value))
81    }
82
83    /// Read a constructed element (SEQUENCE, SET, or context-tagged),
84    /// passing a sub-reader scoped to its contents.
85    pub fn read_sequence<F, T>(&mut self, expected_tag: Tag, f: F) -> Result<T, BerError>
86    where
87        F: FnOnce(&mut BerReader<'_>) -> Result<T, BerError>,
88    {
89        let (tag, value) = self.read_element()?;
90        if tag != expected_tag {
91            return Err(BerError::UnexpectedTag {
92                expected: expected_tag,
93                actual: tag,
94            });
95        }
96
97        if self.depth >= self.max_depth {
98            return Err(BerError::RecursionLimit {
99                max: self.max_depth,
100            });
101        }
102
103        let mut sub = BerReader {
104            input: value,
105            depth: self.depth + 1,
106            max_depth: self.max_depth,
107            max_element_size: self.max_element_size,
108        };
109        let result = f(&mut sub)?;
110        if !sub.input.is_empty() {
111            return Err(BerError::TrailingData {
112                remaining: sub.input.len(),
113            });
114        }
115        Ok(result)
116    }
117
118    /// Like `read_sequence` but allows trailing data in the constructed element.
119    pub fn read_sequence_lax<F, T>(&mut self, expected_tag: Tag, f: F) -> Result<T, BerError>
120    where
121        F: FnOnce(&mut BerReader<'_>) -> Result<T, BerError>,
122    {
123        let (tag, value) = self.read_element()?;
124        if tag != expected_tag {
125            return Err(BerError::UnexpectedTag {
126                expected: expected_tag,
127                actual: tag,
128            });
129        }
130
131        if self.depth >= self.max_depth {
132            return Err(BerError::RecursionLimit {
133                max: self.max_depth,
134            });
135        }
136
137        let mut sub = BerReader {
138            input: value,
139            depth: self.depth + 1,
140            max_depth: self.max_depth,
141            max_element_size: self.max_element_size,
142        };
143        f(&mut sub)
144    }
145
146    pub fn read_integer(&mut self) -> Result<i64, BerError> {
147        let (tag, value) = self.read_element()?;
148        if tag.number != INTEGER || tag.class != Class::Universal || tag.constructed {
149            return Err(BerError::UnexpectedTag {
150                expected: Tag::universal(INTEGER),
151                actual: tag,
152            });
153        }
154        decode_integer(value)
155    }
156
157    pub fn read_octet_string(&mut self) -> Result<&'a [u8], BerError> {
158        let (tag, value) = self.read_element()?;
159        if tag.number != OCTET_STRING || tag.class != Class::Universal {
160            return Err(BerError::UnexpectedTag {
161                expected: Tag::universal(OCTET_STRING),
162                actual: tag,
163            });
164        }
165        if tag.constructed {
166            return Err(BerError::ConstructedPrimitive);
167        }
168        Ok(value)
169    }
170
171    pub fn read_boolean(&mut self) -> Result<bool, BerError> {
172        let (tag, value) = self.read_element()?;
173        if tag.number != BOOLEAN || tag.class != Class::Universal || tag.constructed {
174            return Err(BerError::UnexpectedTag {
175                expected: Tag::universal(BOOLEAN),
176                actual: tag,
177            });
178        }
179        if value.len() != 1 {
180            return Err(BerError::InvalidBoolean);
181        }
182        Ok(value[0] != 0)
183    }
184
185    pub fn read_enumerated(&mut self) -> Result<i64, BerError> {
186        let (tag, value) = self.read_element()?;
187        if tag.number != ENUMERATED || tag.class != Class::Universal || tag.constructed {
188            return Err(BerError::UnexpectedTag {
189                expected: Tag::universal(ENUMERATED),
190                actual: tag,
191            });
192        }
193        decode_integer(value)
194    }
195
196    /// Read an element with any tag, returning its raw bytes.
197    pub fn read_tagged_value(&mut self) -> Result<(Tag, &'a [u8]), BerError> {
198        self.read_element()
199    }
200
201    /// Read a tagged implicit octet string (context-tagged primitive).
202    pub fn read_tagged_implicit_octet_string(
203        &mut self,
204        expected_number: u32,
205    ) -> Result<&'a [u8], BerError> {
206        let (tag, value) = self.read_element()?;
207        if tag.class != Class::Context || tag.number != expected_number {
208            return Err(BerError::UnexpectedTag {
209                expected: Tag::context(expected_number),
210                actual: tag,
211            });
212        }
213        Ok(value)
214    }
215}
216
217fn parse_tag(input: &[u8]) -> Result<(Tag, usize), BerError> {
218    if input.is_empty() {
219        return Err(BerError::Truncated { need: 1, have: 0 });
220    }
221
222    let first = input[0];
223    let class = Class::from_byte(first);
224    let constructed = (first & 0x20) != 0;
225    let tag_bits = first & 0x1F;
226
227    if tag_bits < 0x1F {
228        return Ok((
229            Tag {
230                class,
231                constructed,
232                number: tag_bits as u32,
233            },
234            1,
235        ));
236    }
237
238    // High-tag-number form (at most 5 continuation bytes for u32).
239    let mut number: u32 = 0;
240    let mut i = 1;
241    loop {
242        if i >= input.len() {
243            return Err(BerError::Truncated {
244                need: i + 1,
245                have: input.len(),
246            });
247        }
248        if i > 5 {
249            return Err(BerError::TagOverflow);
250        }
251        let b = input[i];
252        number = number
253            .checked_shl(7)
254            .and_then(|n| n.checked_add((b & 0x7F) as u32))
255            .ok_or(BerError::TagOverflow)?;
256        i += 1;
257        if b & 0x80 == 0 {
258            break;
259        }
260    }
261
262    Ok((
263        Tag {
264            class,
265            constructed,
266            number,
267        },
268        i,
269    ))
270}
271
272fn decode_integer(bytes: &[u8]) -> Result<i64, BerError> {
273    if bytes.is_empty() || bytes.len() > 8 {
274        return Err(BerError::InvalidInteger);
275    }
276
277    let negative = bytes[0] & 0x80 != 0;
278    let mut result: i64 = if negative { -1 } else { 0 };
279
280    for &b in bytes {
281        result = (result << 8) | b as i64;
282    }
283    Ok(result)
284}