Skip to main content

pq_oid/
encoding.rs

1//! OID DER encoding and decoding functions.
2//!
3//! DER encoding for OIDs:
4//! - First two arcs are combined: (first * 40) + second
5//! - Each subsequent arc is encoded in base-128 with high bit set on continuation bytes
6
7use crate::error::{Error, Result};
8
/// Append the base-128 subidentifier encoding of `value` to `output`.
///
/// The value is split into 7-bit groups emitted most-significant first; every
/// byte except the final one carries the 0x80 continuation flag. Zero encodes
/// as the single byte 0x00. A `u64` needs at most ceil(64/7) = 10 bytes.
fn encode_arc(value: u64, output: &mut Vec<u8>) {
    // Significant bits in `value`; zero still needs one (empty) group.
    let significant_bits = 64 - value.leading_zeros();
    let group_count = ((significant_bits + 6) / 7).max(1);

    for group in (0..group_count).rev() {
        let seven_bits = ((value >> (7 * group)) & 0x7f) as u8;
        let continuation = if group == 0 { 0x00 } else { 0x80 };
        output.push(seven_bits | continuation);
    }
}
34
35/// Parse and validate an arc string, returning the numeric value.
36fn parse_arc(part: &str) -> Result<u64> {
37    // Verify no leading zeros (e.g., "01" should fail)
38    if part.len() > 1 && part.starts_with('0') {
39        return Err(Error::InvalidOid("invalid arc with leading zero"));
40    }
41
42    part.parse()
43        .map_err(|_| Error::InvalidOid("non-numeric arc"))
44}
45
46/// Encode an OID string to DER bytes, writing to the provided buffer.
47///
48/// This is the low-allocation version that writes directly to `out`.
49///
50/// # Arguments
51/// * `oid` - OID string in dotted notation (e.g., "2.16.840.1.101.3.4.4.1")
52/// * `out` - Output buffer to write encoded bytes to
53///
54/// # Errors
55/// Returns an error if the OID format is invalid
56pub fn encode_oid_to(oid: &str, out: &mut Vec<u8>) -> Result<()> {
57    if oid.is_empty() || oid.trim().is_empty() {
58        return Err(Error::InvalidOid("empty string"));
59    }
60
61    let mut parts = oid.split('.');
62
63    // Parse first arc
64    let first_str = parts.next().ok_or(Error::InvalidOid("empty string"))?;
65    let first = parse_arc(first_str)?;
66
67    // Parse second arc
68    let second_str = parts
69        .next()
70        .ok_or(Error::InvalidOid("must have at least 2 arcs"))?;
71    let second = parse_arc(second_str)?;
72
73    // First arc must be 0, 1, or 2
74    if first > 2 {
75        return Err(Error::InvalidOid("first arc must be 0, 1, or 2"));
76    }
77
78    // When first arc is 0 or 1, second arc must be < 40
79    if first < 2 && second > 39 {
80        return Err(Error::InvalidOid(
81            "second arc must be <= 39 when first arc is 0 or 1",
82        ));
83    }
84
85    // Encode combined first two arcs (use checked arithmetic to prevent overflow)
86    let combined = first
87        .checked_mul(40)
88        .and_then(|v| v.checked_add(second))
89        .ok_or(Error::InvalidOid("arc value overflow"))?;
90    encode_arc(combined, out);
91
92    // Encode remaining arcs
93    for part in parts {
94        let arc = parse_arc(part)?;
95        encode_arc(arc, out);
96    }
97
98    Ok(())
99}
100
101/// Encode an OID string to DER bytes (without the tag and length).
102///
103/// # Arguments
104/// * `oid` - OID string in dotted notation (e.g., "2.16.840.1.101.3.4.4.1")
105///
106/// # Returns
107/// DER-encoded OID bytes
108///
109/// # Errors
110/// Returns an error if the OID format is invalid
111pub fn encode_oid(oid: &str) -> Result<Vec<u8>> {
112    let mut result = Vec::new();
113    encode_oid_to(oid, &mut result)?;
114    Ok(result)
115}
116
117/// Decode DER bytes to an OID string.
118///
119/// # Arguments
120/// * `bytes` - DER-encoded OID bytes (without tag and length)
121///
122/// # Returns
123/// OID string in dotted notation
124///
125/// # Errors
126/// Returns an error if the bytes are invalid
127pub fn decode_oid(bytes: &[u8]) -> Result<String> {
128    if bytes.is_empty() {
129        return Err(Error::InvalidOidBytes("empty"));
130    }
131
132    let mut arcs = Vec::new();
133    let mut i = 0;
134
135    // Max continuation bytes for u64: ceil(64/7) = 10
136    const MAX_ARC_BYTES: usize = 10;
137
138    // Decode first byte(s) (combined first two arcs)
139    let mut value: u64 = 0;
140    let mut arc_bytes = 0;
141    while i < bytes.len() {
142        arc_bytes += 1;
143        if arc_bytes > MAX_ARC_BYTES {
144            return Err(Error::InvalidOidBytes("arc value too large"));
145        }
146
147        let byte = bytes[i];
148        value = (value << 7) | ((byte & 0x7f) as u64);
149        i += 1;
150
151        if byte & 0x80 == 0 {
152            // End of this arc
153            break;
154        }
155    }
156
157    // Check if we ended in the middle of a multi-byte value
158    if i > 0 && bytes[i - 1] & 0x80 != 0 {
159        return Err(Error::InvalidOidBytes("incomplete multi-byte encoding"));
160    }
161
162    // Split combined value into first two arcs
163    let (first, second) = if value < 40 {
164        (0, value)
165    } else if value < 80 {
166        (1, value - 40)
167    } else {
168        (2, value - 80)
169    };
170
171    arcs.push(first);
172    arcs.push(second);
173
174    // Decode remaining arcs
175    while i < bytes.len() {
176        value = 0;
177        arc_bytes = 0;
178        let start_index = i;
179
180        while i < bytes.len() {
181            arc_bytes += 1;
182            if arc_bytes > MAX_ARC_BYTES {
183                return Err(Error::InvalidOidBytes("arc value too large"));
184            }
185
186            let byte = bytes[i];
187            value = (value << 7) | ((byte & 0x7f) as u64);
188            i += 1;
189
190            if byte & 0x80 == 0 {
191                // End of this arc
192                break;
193            }
194        }
195
196        // Check if we ended in the middle of a multi-byte value
197        if i > start_index && bytes[i - 1] & 0x80 != 0 {
198            return Err(Error::InvalidOidBytes("incomplete multi-byte encoding"));
199        }
200
201        arcs.push(value);
202    }
203
204    Ok(arcs
205        .iter()
206        .map(|a| a.to_string())
207        .collect::<Vec<_>>()
208        .join("."))
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Algorithm, MlDsa, MlKem};

    // Every algorithm OID checked below sits under the arc 2.16.840.1.101.3.4,
    // whose DER content bytes are:
    // - first two arcs combined: 2*40 + 16 = 96      -> 0x60
    // - 840 in base-128 (6 * 128 + 72)               -> 0x86, 0x48
    // - remaining prefix arcs 1, 101, 3, 4           -> 0x01, 0x65, 0x03, 0x04
    const PREFIX_DER: [u8; 7] = [0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04];

    /// Arc under the prefix used by the `MlDsa` OIDs in these tests.
    const DSA_ARC: u8 = 0x03;
    /// Arc under the prefix used by the `MlKem` OIDs in these tests.
    const KEM_ARC: u8 = 0x04;

    /// Expected DER bytes for `<prefix>.<family>.<leaf>`; both arcs must be < 128
    /// so each fits in a single byte.
    fn expected_der(family: u8, leaf: u8) -> Vec<u8> {
        let mut der = PREFIX_DER.to_vec();
        der.extend_from_slice(&[family, leaf]);
        der
    }

    fn assert_rejects_oid(oid: &str) {
        assert!(
            matches!(encode_oid(oid), Err(Error::InvalidOid(_))),
            "expected InvalidOid for {:?}",
            oid
        );
    }

    fn assert_rejects_bytes(bytes: &[u8]) {
        assert!(
            matches!(decode_oid(bytes), Err(Error::InvalidOidBytes(_))),
            "expected InvalidOidBytes for {:?}",
            bytes
        );
    }

    #[test]
    fn test_encode_ml_kem_512_exact_bytes() {
        assert_eq!(encode_oid(MlKem::Kem512.oid()).unwrap(), expected_der(KEM_ARC, 0x01));
    }

    #[test]
    fn test_encode_ml_kem_768_exact_bytes() {
        assert_eq!(encode_oid(MlKem::Kem768.oid()).unwrap(), expected_der(KEM_ARC, 0x02));
    }

    #[test]
    fn test_encode_ml_kem_1024_exact_bytes() {
        assert_eq!(encode_oid(MlKem::Kem1024.oid()).unwrap(), expected_der(KEM_ARC, 0x03));
    }

    #[test]
    fn test_encode_ml_dsa_44_exact_bytes() {
        assert_eq!(encode_oid(MlDsa::Dsa44.oid()).unwrap(), expected_der(DSA_ARC, 0x11));
    }

    #[test]
    fn test_encode_ml_dsa_65_exact_bytes() {
        assert_eq!(encode_oid(MlDsa::Dsa65.oid()).unwrap(), expected_der(DSA_ARC, 0x12));
    }

    #[test]
    fn test_encode_ml_dsa_87_exact_bytes() {
        assert_eq!(encode_oid(MlDsa::Dsa87.oid()).unwrap(), expected_der(DSA_ARC, 0x13));
    }

    #[test]
    fn test_decode_ml_kem_512_exact_bytes() {
        let decoded = decode_oid(&expected_der(KEM_ARC, 0x01)).unwrap();
        assert_eq!(decoded, MlKem::Kem512.oid());
    }

    #[test]
    fn test_decode_ml_dsa_44_exact_bytes() {
        let decoded = decode_oid(&expected_der(DSA_ARC, 0x11)).unwrap();
        assert_eq!(decoded, MlDsa::Dsa44.oid());
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        const SAMPLES: [&str; 4] = [
            "2.16.840.1.101.3.4.4.1",
            "2.16.840.1.101.3.4.3.17",
            "1.2.3.4.5",
            "0.9.2342",
        ];

        for &oid in SAMPLES.iter() {
            let decoded = decode_oid(&encode_oid(oid).unwrap()).unwrap();
            assert_eq!(oid, decoded);
        }
    }

    #[test]
    fn test_roundtrip_all_algorithm_oids() {
        for alg in Algorithm::all() {
            let expected = alg.oid();
            let decoded = decode_oid(&encode_oid(expected).unwrap()).unwrap();
            assert_eq!(expected, decoded);
        }
    }

    #[test]
    fn test_encode_invalid_empty() {
        assert_rejects_oid("");
    }

    #[test]
    fn test_encode_invalid_single_arc() {
        assert_rejects_oid("2");
    }

    #[test]
    fn test_encode_invalid_first_arc() {
        assert_rejects_oid("3.5.6");
    }

    #[test]
    fn test_encode_invalid_second_arc() {
        assert_rejects_oid("1.50.6");
    }

    #[test]
    fn test_encode_invalid_non_numeric() {
        assert_rejects_oid("2.16.abc.1");
    }

    #[test]
    fn test_decode_empty() {
        assert_rejects_bytes(&[]);
    }

    #[test]
    fn test_decode_incomplete() {
        assert_rejects_bytes(&[0x86, 0x48, 0x80]);
    }

    #[test]
    fn test_decode_incomplete_at_start() {
        assert_rejects_bytes(&[0x60, 0x86]);
    }
}