Skip to main content

ember_plus/ber/
tag.rs

1//! BER tag encoding and decoding.
2
3use crate::error::{BerError, Result};
4
5/// Tag class in BER encoding.
/// The two-bit class field of a BER identifier octet (its two most significant bits).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TagClass {
    /// Class bits `0b00`: universal.
    Universal = 0b00,
    /// Class bits `0b01`: application.
    Application = 0b01,
    /// Class bits `0b10`: context-specific.
    Context = 0b10,
    /// Class bits `0b11`: private.
    Private = 0b11,
}
18
19impl TagClass {
20    /// Create a tag class from the class bits.
21    pub fn from_bits(bits: u8) -> Result<Self> {
22        match bits & 0b11 {
23            0 => Ok(TagClass::Universal),
24            1 => Ok(TagClass::Application),
25            2 => Ok(TagClass::Context),
26            3 => Ok(TagClass::Private),
27            _ => Err(BerError::UnsupportedTagClass(bits).into()),
28        }
29    }
30
31    /// Get the bits for this tag class.
32    pub fn to_bits(self) -> u8 {
33        self as u8
34    }
35}
36
37/// Tag type (primitive or constructed).
/// Primitive/constructed flag of a BER identifier octet (the `0x20` bit).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TagType {
    /// The element's contents are the encoded value itself.
    Primitive,
    /// The element's contents are a series of nested TLV elements.
    Constructed,
}

impl TagType {
    /// Build a tag type from the P/C bit; a set bit means constructed.
    pub fn from_bit(bit: bool) -> Self {
        match bit {
            true => TagType::Constructed,
            false => TagType::Primitive,
        }
    }

    /// Return the P/C bit for this tag type; `true` means constructed.
    pub fn to_bit(self) -> bool {
        match self {
            TagType::Primitive => false,
            TagType::Constructed => true,
        }
    }
}
61
62/// A BER tag identifier.
63#[derive(Debug, Clone, Copy, PartialEq, Eq)]
64pub struct Tag {
65    /// The tag class
66    pub class: TagClass,
67    /// Whether this is a constructed type
68    pub tag_type: TagType,
69    /// The tag number
70    pub number: u32,
71}
72
73impl Tag {
74    /// Create a new tag.
75    pub const fn new(class: TagClass, tag_type: TagType, number: u32) -> Self {
76        Tag { class, tag_type, number }
77    }
78
79    /// Create a universal primitive tag.
80    pub const fn universal_primitive(number: u32) -> Self {
81        Tag::new(TagClass::Universal, TagType::Primitive, number)
82    }
83
84    /// Create a universal constructed tag.
85    pub const fn universal_constructed(number: u32) -> Self {
86        Tag::new(TagClass::Universal, TagType::Constructed, number)
87    }
88
89    /// Create a context-specific primitive tag.
90    pub const fn context_primitive(number: u32) -> Self {
91        Tag::new(TagClass::Context, TagType::Primitive, number)
92    }
93
94    /// Create a context-specific constructed tag.
95    pub const fn context_constructed(number: u32) -> Self {
96        Tag::new(TagClass::Context, TagType::Constructed, number)
97    }
98
99    /// Create an application primitive tag.
100    pub const fn application_primitive(number: u32) -> Self {
101        Tag::new(TagClass::Application, TagType::Primitive, number)
102    }
103
104    /// Create an application constructed tag.
105    pub const fn application_constructed(number: u32) -> Self {
106        Tag::new(TagClass::Application, TagType::Constructed, number)
107    }
108
109    /// Check if this tag is primitive.
110    pub fn is_primitive(&self) -> bool {
111        self.tag_type == TagType::Primitive
112    }
113
114    /// Check if this tag is constructed.
115    pub fn is_constructed(&self) -> bool {
116        self.tag_type == TagType::Constructed
117    }
118
119    /// Check if this tag is universal class.
120    pub fn is_universal(&self) -> bool {
121        self.class == TagClass::Universal
122    }
123
124    /// Check if this tag is context-specific class.
125    pub fn is_context(&self) -> bool {
126        self.class == TagClass::Context
127    }
128
129    /// Check if this tag is application class.
130    pub fn is_application(&self) -> bool {
131        self.class == TagClass::Application
132    }
133
134    /// Encode this tag to bytes.
135    pub fn encode(&self) -> Vec<u8> {
136        let mut result = Vec::new();
137        
138        // First byte: class (2 bits) | P/C (1 bit) | number or 0x1F (5 bits)
139        let first_byte = (self.class.to_bits() << 6)
140            | ((self.tag_type.to_bit() as u8) << 5)
141            | if self.number < 31 { self.number as u8 } else { 0x1F };
142        
143        result.push(first_byte);
144        
145        // Long form tag number if >= 31
146        if self.number >= 31 {
147            let mut number = self.number;
148            let mut bytes = Vec::new();
149            
150            // Encode in base-128, most significant byte first
151            loop {
152                bytes.push((number & 0x7F) as u8);
153                number >>= 7;
154                if number == 0 {
155                    break;
156                }
157            }
158            
159            // Reverse and set continuation bits
160            for (i, byte) in bytes.iter().rev().enumerate() {
161                if i < bytes.len() - 1 {
162                    result.push(byte | 0x80); // Set continuation bit
163                } else {
164                    result.push(*byte); // Last byte, no continuation
165                }
166            }
167        }
168        
169        result
170    }
171
172    /// Decode a tag from bytes.
173    pub fn decode(data: &[u8]) -> Result<(Self, usize)> {
174        if data.is_empty() {
175            return Err(BerError::UnexpectedEof.into());
176        }
177
178        let first_byte = data[0];
179        let class = TagClass::from_bits(first_byte >> 6)?;
180        let tag_type = TagType::from_bit((first_byte & 0x20) != 0);
181        let short_number = first_byte & 0x1F;
182
183        if short_number < 31 {
184            // Short form
185            Ok((Tag::new(class, tag_type, short_number as u32), 1))
186        } else {
187            // Long form
188            let mut number: u32 = 0;
189            let mut pos = 1;
190
191            loop {
192                if pos >= data.len() {
193                    return Err(BerError::UnexpectedEof.into());
194                }
195
196                let byte = data[pos];
197                number = number
198                    .checked_shl(7)
199                    .ok_or(BerError::LengthOverflow)?
200                    .checked_add((byte & 0x7F) as u32)
201                    .ok_or(BerError::LengthOverflow)?;
202                pos += 1;
203
204                if (byte & 0x80) == 0 {
205                    break;
206                }
207            }
208
209            Ok((Tag::new(class, tag_type, number), pos))
210        }
211    }
212}
213
214// Common tag constants
215impl Tag {
216    /// BOOLEAN tag
217    pub const BOOLEAN: Tag = Tag::universal_primitive(1);
218    /// INTEGER tag
219    pub const INTEGER: Tag = Tag::universal_primitive(2);
220    /// BIT STRING tag
221    pub const BIT_STRING: Tag = Tag::universal_primitive(3);
222    /// OCTET STRING tag
223    pub const OCTET_STRING: Tag = Tag::universal_primitive(4);
224    /// NULL tag
225    pub const NULL: Tag = Tag::universal_primitive(5);
226    /// OBJECT IDENTIFIER tag
227    pub const OID: Tag = Tag::universal_primitive(6);
228    /// REAL tag
229    pub const REAL: Tag = Tag::universal_primitive(9);
230    /// ENUMERATED tag
231    pub const ENUMERATED: Tag = Tag::universal_primitive(10);
232    /// UTF8String tag
233    pub const UTF8_STRING: Tag = Tag::universal_primitive(12);
234    /// RELATIVE-OID tag
235    pub const RELATIVE_OID: Tag = Tag::universal_primitive(13);
236    /// SEQUENCE tag
237    pub const SEQUENCE: Tag = Tag::universal_constructed(16);
238    /// SET tag
239    pub const SET: Tag = Tag::universal_constructed(17);
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_short_tag_roundtrip() {
        for number in 0..31u32 {
            let tag = Tag::universal_primitive(number);
            let encoded = tag.encode();
            let (decoded, len) = Tag::decode(&encoded).unwrap();
            assert_eq!(len, 1);
            assert_eq!(tag, decoded);
        }
    }

    #[test]
    fn test_long_tag_roundtrip() {
        for number in [31u32, 127, 128, 255, 256, 16383, 16384, 0x1FFFFF, u32::MAX] {
            let tag = Tag::context_constructed(number);
            let encoded = tag.encode();
            let (decoded, len) = Tag::decode(&encoded).unwrap();
            // The decoder must consume exactly the octets the encoder produced.
            assert_eq!(len, encoded.len());
            assert_eq!(tag, decoded);
        }
    }

    #[test]
    fn test_long_tag_encoding_bytes() {
        assert_eq!(Tag::context_constructed(31).encode(), vec![0xBF, 0x1F]);
        assert_eq!(Tag::context_constructed(128).encode(), vec![0xBF, 0x81, 0x00]);
        assert_eq!(Tag::application_primitive(201).encode(), vec![0x5F, 0x81, 0x49]);
        assert_eq!(
            Tag::context_constructed(u32::MAX).encode(),
            vec![0xBF, 0x8F, 0xFF, 0xFF, 0xFF, 0x7F]
        );
    }

    #[test]
    fn test_all_classes_and_types_roundtrip() {
        let classes = [
            TagClass::Universal,
            TagClass::Application,
            TagClass::Context,
            TagClass::Private,
        ];
        for class in classes {
            for tag_type in [TagType::Primitive, TagType::Constructed] {
                for number in [0u32, 30, 31, 1000] {
                    let tag = Tag::new(class, tag_type, number);
                    let (decoded, _) = Tag::decode(&tag.encode()).unwrap();
                    assert_eq!(decoded, tag);
                }
            }
        }
    }

    #[test]
    fn test_decode_ignores_trailing_bytes() {
        let (tag, len) = Tag::decode(&[0x02, 0x01, 0x05]).unwrap();
        assert_eq!(tag, Tag::INTEGER);
        assert_eq!(len, 1);
    }

    #[test]
    fn test_decode_truncated_input() {
        assert!(Tag::decode(&[]).is_err());
        // Long-form marker with no number octets.
        assert!(Tag::decode(&[0xBF]).is_err());
        // Continuation bit set on the last available octet.
        assert!(Tag::decode(&[0xBF, 0x81]).is_err());
    }

    #[test]
    fn test_tag_classes() {
        let tag = Tag::new(TagClass::Application, TagType::Constructed, 5);
        assert!(tag.is_application());
        assert!(tag.is_constructed());

        let tag = Tag::context_primitive(10);
        assert!(tag.is_context());
        assert!(tag.is_primitive());
    }
}