1use crate::error::{Error, Result};
8
/// Appends `value` to `output` as a base-128 subidentifier.
///
/// Septets are written most-significant first, and every octet except the final one has its
/// high (continuation) bit set. Zero encodes as the single octet `0x00`.
fn encode_arc(value: u64, output: &mut Vec<u8>) {
    // A u64 needs at most ceil(64 / 7) = 10 septets.
    let mut septets = [0u8; 10];
    let mut count = 0;
    let mut remaining = value;

    // Collect septets least-significant first. The do-while shape guarantees that zero still
    // produces one septet.
    loop {
        septets[count] = (remaining & 0x7f) as u8;
        count += 1;
        remaining >>= 7;
        if remaining == 0 {
            break;
        }
    }

    // Emit most-significant first; only the last emitted octet lacks the continuation bit.
    let last = count - 1;
    output.extend(
        septets[..count]
            .iter()
            .rev()
            .enumerate()
            .map(|(idx, &septet)| if idx < last { septet | 0x80 } else { septet }),
    );
}
34
35fn parse_arc(part: &str) -> Result<u64> {
37 if part.len() > 1 && part.starts_with('0') {
39 return Err(Error::InvalidOid("invalid arc with leading zero"));
40 }
41
42 part.parse()
43 .map_err(|_| Error::InvalidOid("non-numeric arc"))
44}
45
46pub fn encode_oid_to(oid: &str, out: &mut Vec<u8>) -> Result<()> {
57 if oid.is_empty() || oid.trim().is_empty() {
58 return Err(Error::InvalidOid("empty string"));
59 }
60
61 let mut parts = oid.split('.');
62
63 let first_str = parts.next().ok_or(Error::InvalidOid("empty string"))?;
65 let first = parse_arc(first_str)?;
66
67 let second_str = parts
69 .next()
70 .ok_or(Error::InvalidOid("must have at least 2 arcs"))?;
71 let second = parse_arc(second_str)?;
72
73 if first > 2 {
75 return Err(Error::InvalidOid("first arc must be 0, 1, or 2"));
76 }
77
78 if first < 2 && second > 39 {
80 return Err(Error::InvalidOid(
81 "second arc must be <= 39 when first arc is 0 or 1",
82 ));
83 }
84
85 let combined = first
87 .checked_mul(40)
88 .and_then(|v| v.checked_add(second))
89 .ok_or(Error::InvalidOid("arc value overflow"))?;
90 encode_arc(combined, out);
91
92 for part in parts {
94 let arc = parse_arc(part)?;
95 encode_arc(arc, out);
96 }
97
98 Ok(())
99}
100
101pub fn encode_oid(oid: &str) -> Result<Vec<u8>> {
112 let mut result = Vec::new();
113 encode_oid_to(oid, &mut result)?;
114 Ok(result)
115}
116
117pub fn decode_oid(bytes: &[u8]) -> Result<String> {
128 if bytes.is_empty() {
129 return Err(Error::InvalidOidBytes("empty"));
130 }
131
132 let mut arcs = Vec::new();
133 let mut i = 0;
134
135 const MAX_ARC_BYTES: usize = 10;
137
138 let mut value: u64 = 0;
140 let mut arc_bytes = 0;
141 while i < bytes.len() {
142 arc_bytes += 1;
143 if arc_bytes > MAX_ARC_BYTES {
144 return Err(Error::InvalidOidBytes("arc value too large"));
145 }
146
147 let byte = bytes[i];
148 value = (value << 7) | ((byte & 0x7f) as u64);
149 i += 1;
150
151 if byte & 0x80 == 0 {
152 break;
154 }
155 }
156
157 if i > 0 && bytes[i - 1] & 0x80 != 0 {
159 return Err(Error::InvalidOidBytes("incomplete multi-byte encoding"));
160 }
161
162 let (first, second) = if value < 40 {
164 (0, value)
165 } else if value < 80 {
166 (1, value - 40)
167 } else {
168 (2, value - 80)
169 };
170
171 arcs.push(first);
172 arcs.push(second);
173
174 while i < bytes.len() {
176 value = 0;
177 arc_bytes = 0;
178 let start_index = i;
179
180 while i < bytes.len() {
181 arc_bytes += 1;
182 if arc_bytes > MAX_ARC_BYTES {
183 return Err(Error::InvalidOidBytes("arc value too large"));
184 }
185
186 let byte = bytes[i];
187 value = (value << 7) | ((byte & 0x7f) as u64);
188 i += 1;
189
190 if byte & 0x80 == 0 {
191 break;
193 }
194 }
195
196 if i > start_index && bytes[i - 1] & 0x80 != 0 {
198 return Err(Error::InvalidOidBytes("incomplete multi-byte encoding"));
199 }
200
201 arcs.push(value);
202 }
203
204 Ok(arcs
205 .iter()
206 .map(|a| a.to_string())
207 .collect::<Vec<_>>()
208 .join("."))
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Algorithm, MlDsa, MlKem};

    #[test]
    fn test_encode_ml_kem_512_exact_bytes() {
        let bytes = encode_oid(MlKem::Kem512.oid()).unwrap();
        assert_eq!(
            bytes,
            vec![0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x04, 0x01]
        );
    }

    #[test]
    fn test_encode_ml_kem_768_exact_bytes() {
        let bytes = encode_oid(MlKem::Kem768.oid()).unwrap();
        assert_eq!(
            bytes,
            vec![0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x04, 0x02]
        );
    }

    #[test]
    fn test_encode_ml_kem_1024_exact_bytes() {
        let bytes = encode_oid(MlKem::Kem1024.oid()).unwrap();
        assert_eq!(
            bytes,
            vec![0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x04, 0x03]
        );
    }

    #[test]
    fn test_encode_ml_dsa_44_exact_bytes() {
        let bytes = encode_oid(MlDsa::Dsa44.oid()).unwrap();
        assert_eq!(
            bytes,
            vec![0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x03, 0x11]
        );
    }

    #[test]
    fn test_encode_ml_dsa_65_exact_bytes() {
        let bytes = encode_oid(MlDsa::Dsa65.oid()).unwrap();
        assert_eq!(
            bytes,
            vec![0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x03, 0x12]
        );
    }

    #[test]
    fn test_encode_ml_dsa_87_exact_bytes() {
        let bytes = encode_oid(MlDsa::Dsa87.oid()).unwrap();
        assert_eq!(
            bytes,
            vec![0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x03, 0x13]
        );
    }

    #[test]
    fn test_decode_ml_kem_512_exact_bytes() {
        let bytes = [0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x04, 0x01];
        assert_eq!(decode_oid(&bytes).unwrap(), MlKem::Kem512.oid());
    }

    #[test]
    fn test_decode_ml_dsa_44_exact_bytes() {
        let bytes = [0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x03, 0x11];
        assert_eq!(decode_oid(&bytes).unwrap(), MlDsa::Dsa44.oid());
    }

    #[test]
    fn test_encode_x690_example() {
        // Worked example from ITU-T X.690: {2 999 3} has content octets 88 37 03.
        assert_eq!(encode_oid("2.999.3").unwrap(), vec![0x88, 0x37, 0x03]);
        assert_eq!(decode_oid(&[0x88, 0x37, 0x03]).unwrap(), "2.999.3");
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let test_oids = [
            "2.16.840.1.101.3.4.4.1",
            "2.16.840.1.101.3.4.3.17",
            "1.2.3.4.5",
            "0.9.2342",
            "0.0",
            "1.39",
            "2.999.3",
        ];

        for oid in &test_oids {
            let encoded = encode_oid(oid).unwrap();
            let decoded = decode_oid(&encoded).unwrap();
            assert_eq!(*oid, decoded);
        }
    }

    #[test]
    fn test_roundtrip_u64_max_arc() {
        // The largest representable arc needs the full 10-octet subidentifier.
        let oid = "1.2.18446744073709551615";
        let encoded = encode_oid(oid).unwrap();
        assert_eq!(encoded.len(), 11);
        assert_eq!(decode_oid(&encoded).unwrap(), oid);
    }

    #[test]
    fn test_roundtrip_all_algorithm_oids() {
        for alg in Algorithm::all() {
            let oid_str = alg.oid();
            let encoded = encode_oid(oid_str).unwrap();
            let decoded = decode_oid(&encoded).unwrap();
            assert_eq!(oid_str, decoded);
        }
    }

    #[test]
    fn test_encode_oid_to_appends() {
        let mut out = vec![0xaa];
        encode_oid_to("1.2", &mut out).unwrap();
        assert_eq!(out, vec![0xaa, 0x2a]);
    }

    #[test]
    fn test_encode_invalid_empty() {
        assert!(matches!(encode_oid(""), Err(Error::InvalidOid(_))));
        assert!(matches!(encode_oid("   "), Err(Error::InvalidOid(_))));
    }

    #[test]
    fn test_encode_invalid_single_arc() {
        assert!(matches!(encode_oid("2"), Err(Error::InvalidOid(_))));
    }

    #[test]
    fn test_encode_invalid_first_arc() {
        assert!(matches!(encode_oid("3.5.6"), Err(Error::InvalidOid(_))));
    }

    #[test]
    fn test_encode_invalid_second_arc() {
        assert!(matches!(encode_oid("1.50.6"), Err(Error::InvalidOid(_))));
    }

    #[test]
    fn test_encode_invalid_non_numeric() {
        assert!(matches!(
            encode_oid("2.16.abc.1"),
            Err(Error::InvalidOid(_))
        ));
    }

    #[test]
    fn test_encode_invalid_leading_zero() {
        assert!(matches!(encode_oid("1.02.3"), Err(Error::InvalidOid(_))));
    }

    #[test]
    fn test_encode_invalid_trailing_dot() {
        assert!(matches!(encode_oid("1.2."), Err(Error::InvalidOid(_))));
    }

    #[test]
    fn test_encode_invalid_combined_overflow() {
        assert!(matches!(
            encode_oid("2.18446744073709551615"),
            Err(Error::InvalidOid(_))
        ));
    }

    #[test]
    fn test_decode_empty() {
        assert!(matches!(decode_oid(&[]), Err(Error::InvalidOidBytes(_))));
    }

    #[test]
    fn test_decode_incomplete() {
        assert!(matches!(
            decode_oid(&[0x86, 0x48, 0x80]),
            Err(Error::InvalidOidBytes(_))
        ));
    }

    #[test]
    fn test_decode_incomplete_at_start() {
        assert!(matches!(
            decode_oid(&[0x60, 0x86]),
            Err(Error::InvalidOidBytes(_))
        ));
    }
}