Skip to main content

facet_asn1/
serializer.rs

1//! ASN.1 DER serializer implementing FormatSerializer.
2
3extern crate alloc;
4
5use alloc::{string::String, vec::Vec};
6
7use facet_format::{FormatSerializer, ScalarValue, SerializeError};
8
// ASN.1 Universal Tags
//
// Each constant is the tag octet written in front of the corresponding value
// by the `write_*` helpers of `Asn1Serializer`.
/// Tag for BOOLEAN values.
const TAG_BOOLEAN: u8 = 0x01;
/// Tag for INTEGER values (used for every signed and unsigned integer width).
const TAG_INTEGER: u8 = 0x02;
/// Tag for OCTET STRING values (raw byte slices).
const TAG_OCTET_STRING: u8 = 0x04;
/// Tag for NULL.
const TAG_NULL: u8 = 0x05;
/// Tag for REAL values (floating point).
const TAG_REAL: u8 = 0x09;
/// Tag for UTF8String values (strings and single characters).
const TAG_UTF8STRING: u8 = 0x0C;
/// Tag number for SEQUENCE; always OR-ed with `CONSTRUCTED_BIT` when emitted.
const TAG_SEQUENCE: u8 = 0x10;

/// Bit OR-ed into a tag octet to mark the encoding as constructed.
const CONSTRUCTED_BIT: u8 = 0x20;
19
// Single-octet REAL contents used for values that have no mantissa/exponent form.
/// Content octet written for positive infinity.
const REAL_INFINITY: u8 = 0x40;
/// Content octet written for negative infinity.
const REAL_NEG_INFINITY: u8 = 0x41;
/// Content octet written for NaN.
const REAL_NAN: u8 = 0x42;
/// Content octet written for negative zero.
const REAL_NEG_ZERO: u8 = 0x43;

/// Mask selecting the 52 stored mantissa bits of an `f64` bit pattern.
const F64_MANTISSA_MASK: u64 = (1 << 52) - 1;
27
/// ASN.1 serialization error.
///
/// Carries a human-readable description; the `Display` implementation prints
/// that description verbatim.
#[derive(Debug)]
pub struct Asn1SerializeError {
    /// Text shown when the error is formatted with `Display`.
    message: String,
}
33
34impl core::fmt::Display for Asn1SerializeError {
35    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
36        f.write_str(&self.message)
37    }
38}
39
// Only compiled when the `std` feature is enabled; the impl relies entirely on
// the trait's default methods.
#[cfg(feature = "std")]
impl std::error::Error for Asn1SerializeError {}
42
/// ASN.1 DER serializer.
///
/// Values are appended to an in-memory buffer; call `finish` to take the
/// encoded bytes.
pub struct Asn1Serializer {
    /// Encoded output accumulated so far.
    out: Vec<u8>,
    /// Stack of positions where container lengths need to be patched
    /// (one entry per SEQUENCE that has been opened but not yet closed).
    stack: Vec<ContainerState>,
}
49
/// Bookkeeping for one open SEQUENCE.
#[derive(Debug)]
struct ContainerState {
    /// Position of the length placeholder
    /// (index into the output buffer, directly after the SEQUENCE tag octet).
    len_pos: usize,
}
55
56impl Asn1Serializer {
57    /// Create a new ASN.1 DER serializer.
58    pub const fn new() -> Self {
59        Self {
60            out: Vec::new(),
61            stack: Vec::new(),
62        }
63    }
64
65    /// Consume the serializer and return the output bytes.
66    pub fn finish(self) -> Vec<u8> {
67        self.out
68    }
69
70    /// Write a TLV with the given tag and value bytes.
71    fn write_tlv(&mut self, tag: u8, value: &[u8]) {
72        self.out.push(tag);
73        self.write_length(value.len());
74        self.out.extend_from_slice(value);
75    }
76
77    /// Write a length in DER format.
78    fn write_length(&mut self, len: usize) {
79        if len < 128 {
80            self.out.push(len as u8);
81        } else {
82            // Count how many bytes we need for the length
83            let mut temp = len;
84            let mut bytes_needed = 0;
85            while temp > 0 {
86                bytes_needed += 1;
87                temp >>= 8;
88            }
89            self.out.push(0x80 | bytes_needed);
90            let len_bytes = len.to_be_bytes();
91            self.out
92                .extend_from_slice(&len_bytes[8 - bytes_needed as usize..]);
93        }
94    }
95
96    /// Write a boolean value.
97    fn write_bool(&mut self, value: bool) {
98        let byte = if value { 0xFF } else { 0x00 };
99        self.write_tlv(TAG_BOOLEAN, &[byte]);
100    }
101
102    /// Write an integer value.
103    fn write_i64(&mut self, value: i64) {
104        let bytes = value.to_be_bytes();
105        // Find the minimal representation
106        let mut leading_redundant = 0;
107        for window in bytes.windows(2) {
108            let byte = window[0] as i8;
109            let bit = window[1] as i8 >> 7;
110            if byte ^ bit == 0 {
111                leading_redundant += 1;
112            } else {
113                break;
114            }
115        }
116        self.write_tlv(TAG_INTEGER, &bytes[leading_redundant..]);
117    }
118
119    /// Write an unsigned integer value.
120    fn write_u64(&mut self, value: u64) {
121        let bytes = value.to_be_bytes();
122        // Find leading zeros, but ensure we don't remove the sign bit
123        let mut start = 0;
124        while start < 7 && bytes[start] == 0 && (bytes[start + 1] & 0x80) == 0 {
125            start += 1;
126        }
127        // If high bit is set, need to add a leading zero to keep it positive
128        if bytes[start] & 0x80 != 0 {
129            self.out.push(TAG_INTEGER);
130            self.write_length(bytes.len() - start + 1);
131            self.out.push(0x00);
132            self.out.extend_from_slice(&bytes[start..]);
133        } else {
134            self.write_tlv(TAG_INTEGER, &bytes[start..]);
135        }
136    }
137
138    /// Write a real (f64) value.
139    fn write_f64(&mut self, value: f64) {
140        use core::num::FpCategory;
141        match value.classify() {
142            FpCategory::Nan => self.write_tlv(TAG_REAL, &[REAL_NAN]),
143            FpCategory::Infinite => {
144                if value.is_sign_positive() {
145                    self.write_tlv(TAG_REAL, &[REAL_INFINITY]);
146                } else {
147                    self.write_tlv(TAG_REAL, &[REAL_NEG_INFINITY]);
148                }
149            }
150            FpCategory::Zero | FpCategory::Subnormal => {
151                // Subnormals are rounded to zero in DER
152                if value.is_sign_positive() {
153                    self.write_tlv(TAG_REAL, &[]); // Positive zero is empty content
154                } else {
155                    self.write_tlv(TAG_REAL, &[REAL_NEG_ZERO]);
156                }
157            }
158            FpCategory::Normal => {
159                let sign_negative = value.is_sign_negative();
160                let bits = value.to_bits();
161                let mut exponent = ((bits >> 52) & 0b11111111111) as i16 - 1023;
162                let mut mantissa = bits & F64_MANTISSA_MASK | (0b1 << 52);
163                let mut normalization_factor = 52;
164                while mantissa & 0b1 == 0 {
165                    mantissa >>= 1;
166                    normalization_factor -= 1;
167                }
168                exponent -= normalization_factor;
169
170                let mantissa_bytes = mantissa.to_be_bytes();
171                let mut leading_zero_bytes = 0;
172                for byte in mantissa_bytes {
173                    if byte == 0 {
174                        leading_zero_bytes += 1;
175                    } else {
176                        break;
177                    }
178                }
179
180                let exponent_bytes = exponent.to_be_bytes();
181                let short_exp = exponent_bytes[0] == 0 || exponent_bytes[0] == 0xFF;
182                let content_len =
183                    2 + (!short_exp as usize) + mantissa_bytes.len() - leading_zero_bytes;
184
185                let structure_byte = 0b10000000 | ((sign_negative as u8) << 6) | (!short_exp as u8);
186
187                self.out.push(TAG_REAL);
188                self.write_length(content_len);
189                self.out.push(structure_byte);
190
191                if short_exp {
192                    self.out.push(exponent_bytes[1]);
193                } else {
194                    self.out.extend_from_slice(&exponent_bytes);
195                }
196                self.out
197                    .extend_from_slice(&mantissa_bytes[leading_zero_bytes..]);
198            }
199        }
200    }
201
202    /// Write a UTF-8 string.
203    fn write_str(&mut self, s: &str) {
204        self.write_tlv(TAG_UTF8STRING, s.as_bytes());
205    }
206
207    /// Write binary data as OCTET STRING.
208    fn write_bytes(&mut self, bytes: &[u8]) {
209        self.write_tlv(TAG_OCTET_STRING, bytes);
210    }
211
212    /// Write a NULL value.
213    fn write_null(&mut self) {
214        self.write_tlv(TAG_NULL, &[]);
215    }
216
217    /// Begin a SEQUENCE (for struct).
218    fn begin_sequence(&mut self) {
219        self.out.push(TAG_SEQUENCE | CONSTRUCTED_BIT);
220        let len_pos = self.out.len();
221        // Placeholder for length - we'll use long form to avoid resizing
222        self.out.extend_from_slice(&[0x84, 0, 0, 0, 0]); // Long form with 4 bytes
223        self.stack.push(ContainerState { len_pos });
224    }
225
226    /// End a SEQUENCE and patch the length.
227    fn end_sequence(&mut self) {
228        if let Some(state) = self.stack.pop() {
229            let content_len = self.out.len() - state.len_pos - 5; // Subtract the 5-byte placeholder
230
231            // Patch the length (we used 4-byte long form)
232            let len_bytes = (content_len as u32).to_be_bytes();
233            self.out[state.len_pos] = 0x84; // Long form, 4 bytes
234            self.out[state.len_pos + 1..state.len_pos + 5].copy_from_slice(&len_bytes);
235        }
236    }
237}
238
239impl Default for Asn1Serializer {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
245impl FormatSerializer for Asn1Serializer {
246    type Error = Asn1SerializeError;
247
248    fn begin_struct(&mut self) -> Result<(), Self::Error> {
249        self.begin_sequence();
250        Ok(())
251    }
252
253    fn field_key(&mut self, _key: &str) -> Result<(), Self::Error> {
254        // ASN.1 DER doesn't encode field names - fields are positional
255        Ok(())
256    }
257
258    fn end_struct(&mut self) -> Result<(), Self::Error> {
259        self.end_sequence();
260        Ok(())
261    }
262
263    fn begin_seq(&mut self) -> Result<(), Self::Error> {
264        self.begin_sequence();
265        Ok(())
266    }
267
268    fn end_seq(&mut self) -> Result<(), Self::Error> {
269        self.end_sequence();
270        Ok(())
271    }
272
273    fn is_self_describing(&self) -> bool {
274        false
275    }
276
277    fn scalar(&mut self, scalar: ScalarValue<'_>) -> Result<(), Self::Error> {
278        match scalar {
279            ScalarValue::Null | ScalarValue::Unit => self.write_null(),
280            ScalarValue::Bool(v) => self.write_bool(v),
281            ScalarValue::Char(c) => {
282                let mut buf = [0u8; 4];
283                self.write_str(c.encode_utf8(&mut buf));
284            }
285            ScalarValue::U64(n) => self.write_u64(n),
286            ScalarValue::I64(n) => self.write_i64(n),
287            ScalarValue::U128(n) => {
288                // ASN.1 supports arbitrary-precision integers
289                // For simplicity, convert to bytes
290                if n <= u64::MAX as u128 {
291                    self.write_u64(n as u64);
292                } else {
293                    let bytes = n.to_be_bytes();
294                    let mut start = 0;
295                    while start < 15 && bytes[start] == 0 {
296                        start += 1;
297                    }
298                    // Ensure positive by checking high bit
299                    if bytes[start] & 0x80 != 0 {
300                        self.out.push(TAG_INTEGER);
301                        self.write_length(bytes.len() - start + 1);
302                        self.out.push(0x00);
303                        self.out.extend_from_slice(&bytes[start..]);
304                    } else {
305                        self.write_tlv(TAG_INTEGER, &bytes[start..]);
306                    }
307                }
308            }
309            ScalarValue::I128(n) => {
310                if n >= i64::MIN as i128 && n <= i64::MAX as i128 {
311                    self.write_i64(n as i64);
312                } else {
313                    let bytes = n.to_be_bytes();
314                    let mut leading_redundant = 0;
315                    for window in bytes.windows(2) {
316                        let byte = window[0] as i8;
317                        let bit = window[1] as i8 >> 7;
318                        if byte ^ bit == 0 {
319                            leading_redundant += 1;
320                        } else {
321                            break;
322                        }
323                    }
324                    self.write_tlv(TAG_INTEGER, &bytes[leading_redundant..]);
325                }
326            }
327            ScalarValue::F64(n) => self.write_f64(n),
328            ScalarValue::Str(s) => self.write_str(&s),
329            ScalarValue::Bytes(bytes) => self.write_bytes(&bytes),
330        }
331        Ok(())
332    }
333
334    fn typed_scalar(
335        &mut self,
336        scalar_type: facet_core::ScalarType,
337        value: facet_reflect::Peek<'_, '_>,
338    ) -> Result<(), Self::Error> {
339        use facet_core::ScalarType;
340
341        // Handle unit type as an empty SEQUENCE (not NULL)
342        // This allows roundtrip since the deserializer expects tuples to be sequences
343        if matches!(scalar_type, ScalarType::Unit) {
344            self.write_tlv(TAG_SEQUENCE | CONSTRUCTED_BIT, &[]);
345            return Ok(());
346        }
347
348        // For other types, use the default implementation which calls scalar()
349        let scalar = match scalar_type {
350            ScalarType::Unit => unreachable!(), // Handled above
351            ScalarType::Bool => ScalarValue::Bool(*value.get::<bool>().unwrap()),
352            ScalarType::Char => {
353                let c = *value.get::<char>().unwrap();
354                let mut buf = [0u8; 4];
355                ScalarValue::Str(alloc::borrow::Cow::Owned(
356                    c.encode_utf8(&mut buf).to_string(),
357                ))
358            }
359            ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
360                ScalarValue::Str(alloc::borrow::Cow::Borrowed(value.as_str().unwrap()))
361            }
362            ScalarType::F32 => ScalarValue::F64(*value.get::<f32>().unwrap() as f64),
363            ScalarType::F64 => ScalarValue::F64(*value.get::<f64>().unwrap()),
364            ScalarType::U8 => ScalarValue::U64(*value.get::<u8>().unwrap() as u64),
365            ScalarType::U16 => ScalarValue::U64(*value.get::<u16>().unwrap() as u64),
366            ScalarType::U32 => ScalarValue::U64(*value.get::<u32>().unwrap() as u64),
367            ScalarType::U64 => ScalarValue::U64(*value.get::<u64>().unwrap()),
368            ScalarType::U128 => ScalarValue::U128(*value.get::<u128>().unwrap()),
369            ScalarType::USize => ScalarValue::U64(*value.get::<usize>().unwrap() as u64),
370            ScalarType::I8 => ScalarValue::I64(*value.get::<i8>().unwrap() as i64),
371            ScalarType::I16 => ScalarValue::I64(*value.get::<i16>().unwrap() as i64),
372            ScalarType::I32 => ScalarValue::I64(*value.get::<i32>().unwrap() as i64),
373            ScalarType::I64 => ScalarValue::I64(*value.get::<i64>().unwrap()),
374            ScalarType::I128 => ScalarValue::I128(*value.get::<i128>().unwrap()),
375            ScalarType::ISize => ScalarValue::I64(*value.get::<isize>().unwrap() as i64),
376            _ => {
377                // For unknown scalar types, try to get a string representation
378                if let Some(s) = value.as_str() {
379                    ScalarValue::Str(alloc::borrow::Cow::Borrowed(s))
380                } else {
381                    ScalarValue::Null
382                }
383            }
384        };
385        self.scalar(scalar)
386    }
387}
388
389/// Serialize a value to ASN.1 DER bytes.
390pub fn to_vec<'facet, T>(value: &T) -> Result<Vec<u8>, SerializeError<Asn1SerializeError>>
391where
392    T: facet_core::Facet<'facet>,
393{
394    let mut ser = Asn1Serializer::new();
395    facet_format::serialize_root(&mut ser, facet_reflect::Peek::new(value))?;
396    Ok(ser.finish())
397}