Skip to main content

pq_oid/
types.rs

1//! Type definitions for post-quantum algorithm OIDs.
2//!
3//! This module provides ergonomic, type-safe enums for all PQ algorithms.
4
5use core::fmt;
6use core::str::FromStr;
7
8use crate::error::{Error, Result};
9
10// =============================================================================
11// Algorithm Type & Family
12// =============================================================================
13
/// The type of cryptographic algorithm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AlgorithmType {
    /// Key Encapsulation Mechanism
    Kem,
    /// Digital Signature
    Sign,
}

impl fmt::Display for AlgorithmType {
    /// Writes the lowercase identifier: `"kem"` or `"sign"`.
    ///
    /// Uses [`fmt::Formatter::pad`] so width, fill, alignment and precision
    /// flags (e.g. `{:>6}`) are honoured, which `write!` would silently ignore.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            AlgorithmType::Kem => "kem",
            AlgorithmType::Sign => "sign",
        };
        f.pad(name)
    }
}
31
/// The algorithm family.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AlgorithmFamily {
    /// ML-KEM (FIPS 203), a key-encapsulation mechanism.
    MlKem,
    /// ML-DSA (FIPS 204), a digital signature algorithm.
    MlDsa,
    /// SLH-DSA (FIPS 205), a digital signature algorithm.
    SlhDsa,
}

impl fmt::Display for AlgorithmFamily {
    /// Writes the family name (`"ML-KEM"`, `"ML-DSA"` or `"SLH-DSA"`).
    ///
    /// Uses [`fmt::Formatter::pad`] so width/fill/alignment flags are honoured.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            AlgorithmFamily::MlKem => "ML-KEM",
            AlgorithmFamily::MlDsa => "ML-DSA",
            AlgorithmFamily::SlhDsa => "SLH-DSA",
        };
        f.pad(name)
    }
}
49
/// NIST security level.
///
/// NIST defines security levels 1, 2, 3, and 5 for post-quantum algorithms.
/// Level 4 is not used.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SecurityLevel {
    /// Level 1: At least as hard to break as AES-128
    Level1 = 1,
    /// Level 2: At least as hard to break as finding a SHA-256 collision
    Level2 = 2,
    /// Level 3: At least as hard to break as AES-192
    Level3 = 3,
    /// Level 5: At least as hard to break as AES-256
    Level5 = 5,
}

impl SecurityLevel {
    /// Returns the numeric value (1, 2, 3, or 5).
    #[inline]
    pub const fn as_u8(&self) -> u8 {
        // The explicit discriminants above make the cast yield the level number.
        *self as u8
    }
}

impl fmt::Display for SecurityLevel {
    /// Writes `"Level N"` (e.g. `"Level 3"`).
    ///
    /// Uses [`fmt::Formatter::pad`] so width/fill/alignment flags are honoured.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Static strings let `pad` apply formatting flags without building a
        // temporary `String`.
        let text = match self {
            SecurityLevel::Level1 => "Level 1",
            SecurityLevel::Level2 => "Level 2",
            SecurityLevel::Level3 => "Level 3",
            SecurityLevel::Level5 => "Level 5",
        };
        f.pad(text)
    }
}
79
80// =============================================================================
81// Algorithm Info
82// =============================================================================
83
/// Sizes that only make sense for one kind of algorithm.
///
/// A KEM has a ciphertext and a shared secret; a signature scheme has a
/// signature. Splitting them into separate variants means a KEM can never be
/// described with a signature size (and vice versa).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AlgorithmSizes {
    /// Sizes for a Key Encapsulation Mechanism.
    Kem {
        /// Length of an encapsulated ciphertext, in bytes.
        ciphertext_size: usize,
        /// Length of the derived shared secret, in bytes.
        shared_secret_size: usize,
    },
    /// Sizes for a Digital Signature algorithm.
    Sign {
        /// Length of a signature, in bytes.
        signature_size: usize,
    },
}
103
104/// Information about a specific algorithm variant.
105#[derive(Debug, Clone, Copy, PartialEq, Eq)]
106pub struct AlgorithmInfo {
107    /// The algorithm name string (e.g., "ML-KEM-512")
108    pub name: &'static str,
109    /// The OID in dotted notation
110    pub oid: &'static str,
111    /// The algorithm type (KEM or Sign)
112    pub algorithm_type: AlgorithmType,
113    /// The algorithm family
114    pub family: AlgorithmFamily,
115    /// NIST security level
116    pub security_level: SecurityLevel,
117    /// Public key size in bytes
118    pub public_key_size: usize,
119    /// Private/secret key size in bytes
120    pub private_key_size: usize,
121    /// Type-specific sizes (ciphertext for KEMs, signature for signing)
122    pub sizes: AlgorithmSizes,
123}
124
125// =============================================================================
126// ML-KEM (FIPS 203)
127// =============================================================================
128
/// The parameter sets of ML-KEM, the Module-Lattice-Based Key-Encapsulation
/// Mechanism standardised in FIPS 203.
///
/// # Example
/// ```
/// use pq_oid::MlKem;
/// use std::str::FromStr;
///
/// let alg = MlKem::from_str("ML-KEM-512").unwrap();
/// assert_eq!(alg.oid(), "2.16.840.1.101.3.4.4.1");
/// assert_eq!(alg.public_key_size(), 800);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MlKem {
    /// ML-KEM-512, NIST security level 1.
    Kem512,
    /// ML-KEM-768, NIST security level 3.
    Kem768,
    /// ML-KEM-1024, NIST security level 5.
    Kem1024,
}
149
150impl MlKem {
151    /// Number of ML-KEM variants.
152    pub const COUNT: usize = 3;
153
154    /// All ML-KEM variants.
155    pub const ALL: &'static [MlKem] = &[MlKem::Kem512, MlKem::Kem768, MlKem::Kem1024];
156
157    /// Returns the algorithm name string.
158    #[inline]
159    pub const fn as_str(&self) -> &'static str {
160        match self {
161            MlKem::Kem512 => "ML-KEM-512",
162            MlKem::Kem768 => "ML-KEM-768",
163            MlKem::Kem1024 => "ML-KEM-1024",
164        }
165    }
166
167    /// Returns the OID in dotted notation.
168    #[inline]
169    pub const fn oid(&self) -> &'static str {
170        match self {
171            MlKem::Kem512 => "2.16.840.1.101.3.4.4.1",
172            MlKem::Kem768 => "2.16.840.1.101.3.4.4.2",
173            MlKem::Kem1024 => "2.16.840.1.101.3.4.4.3",
174        }
175    }
176
177    /// Returns the NIST security level.
178    #[inline]
179    pub const fn security_level(&self) -> SecurityLevel {
180        match self {
181            MlKem::Kem512 => SecurityLevel::Level1,
182            MlKem::Kem768 => SecurityLevel::Level3,
183            MlKem::Kem1024 => SecurityLevel::Level5,
184        }
185    }
186
187    /// Returns the public key size in bytes.
188    #[inline]
189    pub const fn public_key_size(&self) -> usize {
190        match self {
191            MlKem::Kem512 => 800,
192            MlKem::Kem768 => 1184,
193            MlKem::Kem1024 => 1568,
194        }
195    }
196
197    /// Returns the private key size in bytes.
198    #[inline]
199    pub const fn private_key_size(&self) -> usize {
200        match self {
201            MlKem::Kem512 => 1632,
202            MlKem::Kem768 => 2400,
203            MlKem::Kem1024 => 3168,
204        }
205    }
206
207    /// Returns the ciphertext size in bytes.
208    #[inline]
209    pub const fn ciphertext_size(&self) -> usize {
210        match self {
211            MlKem::Kem512 => 768,
212            MlKem::Kem768 => 1088,
213            MlKem::Kem1024 => 1568,
214        }
215    }
216
217    /// Returns the shared secret size in bytes (always 32 for ML-KEM).
218    #[inline]
219    pub const fn shared_secret_size(&self) -> usize {
220        32
221    }
222
223    /// Returns the complete algorithm info.
224    pub const fn info(&self) -> AlgorithmInfo {
225        AlgorithmInfo {
226            name: self.as_str(),
227            oid: self.oid(),
228            algorithm_type: AlgorithmType::Kem,
229            family: AlgorithmFamily::MlKem,
230            security_level: self.security_level(),
231            public_key_size: self.public_key_size(),
232            private_key_size: self.private_key_size(),
233            sizes: AlgorithmSizes::Kem {
234                ciphertext_size: self.ciphertext_size(),
235                shared_secret_size: self.shared_secret_size(),
236            },
237        }
238    }
239
240    /// Parse from an OID string.
241    pub fn from_oid(oid: &str) -> Result<Self> {
242        match oid {
243            "2.16.840.1.101.3.4.4.1" => Ok(MlKem::Kem512),
244            "2.16.840.1.101.3.4.4.2" => Ok(MlKem::Kem768),
245            "2.16.840.1.101.3.4.4.3" => Ok(MlKem::Kem1024),
246            _ => Err(Error::UnknownOid),
247        }
248    }
249}
250
251impl fmt::Display for MlKem {
252    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253        write!(f, "{}", self.as_str())
254    }
255}
256
257impl AsRef<str> for MlKem {
258    fn as_ref(&self) -> &str {
259        self.as_str()
260    }
261}
262
263impl FromStr for MlKem {
264    type Err = Error;
265
266    fn from_str(s: &str) -> Result<Self> {
267        match s {
268            "ML-KEM-512" => Ok(MlKem::Kem512),
269            "ML-KEM-768" => Ok(MlKem::Kem768),
270            "ML-KEM-1024" => Ok(MlKem::Kem1024),
271            _ => Err(Error::UnknownAlgorithm),
272        }
273    }
274}
275
276impl TryFrom<&str> for MlKem {
277    type Error = Error;
278
279    fn try_from(s: &str) -> Result<Self> {
280        s.parse()
281    }
282}
283
284// =============================================================================
285// ML-DSA (FIPS 204)
286// =============================================================================
287
/// The parameter sets of ML-DSA, the Module-Lattice-Based Digital Signature
/// Algorithm standardised in FIPS 204.
///
/// # Example
/// ```
/// use pq_oid::MlDsa;
/// use std::str::FromStr;
///
/// let alg = MlDsa::from_str("ML-DSA-65").unwrap();
/// assert_eq!(alg.oid(), "2.16.840.1.101.3.4.3.18");
/// assert_eq!(alg.jose(), "ML-DSA-65");
/// assert_eq!(alg.cose(), -49);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MlDsa {
    /// ML-DSA-44, NIST security level 2.
    Dsa44,
    /// ML-DSA-65, NIST security level 3.
    Dsa65,
    /// ML-DSA-87, NIST security level 5.
    Dsa87,
}
309
310impl MlDsa {
311    /// Number of ML-DSA variants.
312    pub const COUNT: usize = 3;
313
314    /// All ML-DSA variants.
315    pub const ALL: &'static [MlDsa] = &[MlDsa::Dsa44, MlDsa::Dsa65, MlDsa::Dsa87];
316
317    /// Returns the algorithm name string.
318    #[inline]
319    pub const fn as_str(&self) -> &'static str {
320        match self {
321            MlDsa::Dsa44 => "ML-DSA-44",
322            MlDsa::Dsa65 => "ML-DSA-65",
323            MlDsa::Dsa87 => "ML-DSA-87",
324        }
325    }
326
327    /// Returns the OID in dotted notation.
328    #[inline]
329    pub const fn oid(&self) -> &'static str {
330        match self {
331            MlDsa::Dsa44 => "2.16.840.1.101.3.4.3.17",
332            MlDsa::Dsa65 => "2.16.840.1.101.3.4.3.18",
333            MlDsa::Dsa87 => "2.16.840.1.101.3.4.3.19",
334        }
335    }
336
337    /// Returns the NIST security level.
338    #[inline]
339    pub const fn security_level(&self) -> SecurityLevel {
340        match self {
341            MlDsa::Dsa44 => SecurityLevel::Level2,
342            MlDsa::Dsa65 => SecurityLevel::Level3,
343            MlDsa::Dsa87 => SecurityLevel::Level5,
344        }
345    }
346
347    /// Returns the public key size in bytes.
348    #[inline]
349    pub const fn public_key_size(&self) -> usize {
350        match self {
351            MlDsa::Dsa44 => 1312,
352            MlDsa::Dsa65 => 1952,
353            MlDsa::Dsa87 => 2592,
354        }
355    }
356
357    /// Returns the private key size in bytes.
358    #[inline]
359    pub const fn private_key_size(&self) -> usize {
360        match self {
361            MlDsa::Dsa44 => 2560,
362            MlDsa::Dsa65 => 4032,
363            MlDsa::Dsa87 => 4896,
364        }
365    }
366
367    /// Returns the signature size in bytes.
368    #[inline]
369    pub const fn signature_size(&self) -> usize {
370        match self {
371            MlDsa::Dsa44 => 2420,
372            MlDsa::Dsa65 => 3309,
373            MlDsa::Dsa87 => 4627,
374        }
375    }
376
377    /// Returns the JOSE algorithm identifier.
378    #[inline]
379    pub const fn jose(&self) -> &'static str {
380        self.as_str()
381    }
382
383    /// Returns the COSE algorithm number.
384    /// Ref: https://cose-wg.github.io/draft-ietf-cose-dilithium/draft-ietf-cose-dilithium.html#name-new-cose-algorithms
385    #[inline]
386    pub const fn cose(&self) -> i32 {
387        match self {
388            MlDsa::Dsa44 => -48,
389            MlDsa::Dsa65 => -49,
390            MlDsa::Dsa87 => -50,
391        }
392    }
393
394    /// Returns the complete algorithm info.
395    pub const fn info(&self) -> AlgorithmInfo {
396        AlgorithmInfo {
397            name: self.as_str(),
398            oid: self.oid(),
399            algorithm_type: AlgorithmType::Sign,
400            family: AlgorithmFamily::MlDsa,
401            security_level: self.security_level(),
402            public_key_size: self.public_key_size(),
403            private_key_size: self.private_key_size(),
404            sizes: AlgorithmSizes::Sign {
405                signature_size: self.signature_size(),
406            },
407        }
408    }
409
410    /// Parse from an OID string.
411    pub fn from_oid(oid: &str) -> Result<Self> {
412        match oid {
413            "2.16.840.1.101.3.4.3.17" => Ok(MlDsa::Dsa44),
414            "2.16.840.1.101.3.4.3.18" => Ok(MlDsa::Dsa65),
415            "2.16.840.1.101.3.4.3.19" => Ok(MlDsa::Dsa87),
416            _ => Err(Error::UnknownOid),
417        }
418    }
419
420    /// Parse from a JOSE algorithm identifier.
421    pub fn from_jose(jose: &str) -> Result<Self> {
422        match jose {
423            "ML-DSA-44" => Ok(MlDsa::Dsa44),
424            "ML-DSA-65" => Ok(MlDsa::Dsa65),
425            "ML-DSA-87" => Ok(MlDsa::Dsa87),
426            _ => Err(Error::UnknownJoseAlgorithm),
427        }
428    }
429
430    /// Parse from a COSE algorithm number.
431    #[inline]
432    pub const fn from_cose(cose: i32) -> Option<Self> {
433        match cose {
434            -48 => Some(MlDsa::Dsa44),
435            -49 => Some(MlDsa::Dsa65),
436            -50 => Some(MlDsa::Dsa87),
437            _ => None,
438        }
439    }
440}
441
442impl fmt::Display for MlDsa {
443    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
444        write!(f, "{}", self.as_str())
445    }
446}
447
448impl AsRef<str> for MlDsa {
449    fn as_ref(&self) -> &str {
450        self.as_str()
451    }
452}
453
454impl FromStr for MlDsa {
455    type Err = Error;
456
457    fn from_str(s: &str) -> Result<Self> {
458        match s {
459            "ML-DSA-44" => Ok(MlDsa::Dsa44),
460            "ML-DSA-65" => Ok(MlDsa::Dsa65),
461            "ML-DSA-87" => Ok(MlDsa::Dsa87),
462            _ => Err(Error::UnknownAlgorithm),
463        }
464    }
465}
466
467impl TryFrom<&str> for MlDsa {
468    type Error = Error;
469
470    fn try_from(s: &str) -> Result<Self> {
471        s.parse()
472    }
473}
474
475// =============================================================================
476// SLH-DSA (FIPS 205)
477// =============================================================================
478
/// SLH-DSA (Stateless Hash-Based Digital Signature Algorithm) variants.
///
/// Each variant combines a hash function (SHA2 or SHAKE), a security
/// parameter (128/192/256) and a mode (`s` = small signatures, `f` = fast
/// signing).
///
/// # Example
/// ```
/// use pq_oid::SlhDsa;
/// use std::str::FromStr;
///
/// let alg = SlhDsa::from_str("SLH-DSA-SHA2-128s").unwrap();
/// assert_eq!(alg.oid(), "2.16.840.1.101.3.4.3.20");
/// assert_eq!(alg.hash_function(), pq_oid::HashFunction::Sha2);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SlhDsa {
    // SHA2 variants
    /// SLH-DSA-SHA2-128s (NIST Level 1, small signatures)
    Sha2_128s,
    /// SLH-DSA-SHA2-128f (NIST Level 1, fast signing)
    Sha2_128f,
    /// SLH-DSA-SHA2-192s (NIST Level 3, small signatures)
    Sha2_192s,
    /// SLH-DSA-SHA2-192f (NIST Level 3, fast signing)
    Sha2_192f,
    /// SLH-DSA-SHA2-256s (NIST Level 5, small signatures)
    Sha2_256s,
    /// SLH-DSA-SHA2-256f (NIST Level 5, fast signing)
    Sha2_256f,
    // SHAKE variants
    /// SLH-DSA-SHAKE-128s (NIST Level 1, small signatures)
    Shake128s,
    /// SLH-DSA-SHAKE-128f (NIST Level 1, fast signing)
    Shake128f,
    /// SLH-DSA-SHAKE-192s (NIST Level 3, small signatures)
    Shake192s,
    /// SLH-DSA-SHAKE-192f (NIST Level 3, fast signing)
    Shake192f,
    /// SLH-DSA-SHAKE-256s (NIST Level 5, small signatures)
    Shake256s,
    /// SLH-DSA-SHAKE-256f (NIST Level 5, fast signing)
    Shake256f,
}
507
/// The hash function used by SLH-DSA.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashFunction {
    /// SHA-2 based instantiation (`SLH-DSA-SHA2-*`).
    Sha2,
    /// SHAKE based instantiation (`SLH-DSA-SHAKE-*`).
    Shake,
}

impl fmt::Display for HashFunction {
    /// Writes `"SHA2"` or `"SHAKE"`.
    ///
    /// Uses [`fmt::Formatter::pad`] so width/fill/alignment flags are honoured.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            HashFunction::Sha2 => "SHA2",
            HashFunction::Shake => "SHAKE",
        };
        f.pad(name)
    }
}
523
/// The speed/size tradeoff mode for SLH-DSA.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SlhDsaMode {
    /// Small signatures (slower)
    Small,
    /// Fast signing (larger signatures)
    Fast,
}

impl fmt::Display for SlhDsaMode {
    /// Writes the single-letter suffix used in algorithm names: `"s"` or `"f"`.
    ///
    /// Uses [`fmt::Formatter::pad`] so width/fill/alignment flags are honoured.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let suffix = match self {
            SlhDsaMode::Small => "s",
            SlhDsaMode::Fast => "f",
        };
        f.pad(suffix)
    }
}
541
542impl SlhDsa {
543    /// Number of SLH-DSA variants.
544    pub const COUNT: usize = 12;
545
546    /// All SLH-DSA variants.
547    pub const ALL: &'static [SlhDsa] = &[
548        SlhDsa::Sha2_128s,
549        SlhDsa::Sha2_128f,
550        SlhDsa::Sha2_192s,
551        SlhDsa::Sha2_192f,
552        SlhDsa::Sha2_256s,
553        SlhDsa::Sha2_256f,
554        SlhDsa::Shake128s,
555        SlhDsa::Shake128f,
556        SlhDsa::Shake192s,
557        SlhDsa::Shake192f,
558        SlhDsa::Shake256s,
559        SlhDsa::Shake256f,
560    ];
561
562    /// Returns the algorithm name string.
563    #[inline]
564    pub const fn as_str(&self) -> &'static str {
565        match self {
566            SlhDsa::Sha2_128s => "SLH-DSA-SHA2-128s",
567            SlhDsa::Sha2_128f => "SLH-DSA-SHA2-128f",
568            SlhDsa::Sha2_192s => "SLH-DSA-SHA2-192s",
569            SlhDsa::Sha2_192f => "SLH-DSA-SHA2-192f",
570            SlhDsa::Sha2_256s => "SLH-DSA-SHA2-256s",
571            SlhDsa::Sha2_256f => "SLH-DSA-SHA2-256f",
572            SlhDsa::Shake128s => "SLH-DSA-SHAKE-128s",
573            SlhDsa::Shake128f => "SLH-DSA-SHAKE-128f",
574            SlhDsa::Shake192s => "SLH-DSA-SHAKE-192s",
575            SlhDsa::Shake192f => "SLH-DSA-SHAKE-192f",
576            SlhDsa::Shake256s => "SLH-DSA-SHAKE-256s",
577            SlhDsa::Shake256f => "SLH-DSA-SHAKE-256f",
578        }
579    }
580
581    /// Returns the OID in dotted notation.
582    #[inline]
583    pub const fn oid(&self) -> &'static str {
584        match self {
585            SlhDsa::Sha2_128s => "2.16.840.1.101.3.4.3.20",
586            SlhDsa::Sha2_128f => "2.16.840.1.101.3.4.3.21",
587            SlhDsa::Sha2_192s => "2.16.840.1.101.3.4.3.22",
588            SlhDsa::Sha2_192f => "2.16.840.1.101.3.4.3.23",
589            SlhDsa::Sha2_256s => "2.16.840.1.101.3.4.3.24",
590            SlhDsa::Sha2_256f => "2.16.840.1.101.3.4.3.25",
591            SlhDsa::Shake128s => "2.16.840.1.101.3.4.3.26",
592            SlhDsa::Shake128f => "2.16.840.1.101.3.4.3.27",
593            SlhDsa::Shake192s => "2.16.840.1.101.3.4.3.28",
594            SlhDsa::Shake192f => "2.16.840.1.101.3.4.3.29",
595            SlhDsa::Shake256s => "2.16.840.1.101.3.4.3.30",
596            SlhDsa::Shake256f => "2.16.840.1.101.3.4.3.31",
597        }
598    }
599
600    /// Returns the hash function used.
601    #[inline]
602    pub const fn hash_function(&self) -> HashFunction {
603        match self {
604            SlhDsa::Sha2_128s
605            | SlhDsa::Sha2_128f
606            | SlhDsa::Sha2_192s
607            | SlhDsa::Sha2_192f
608            | SlhDsa::Sha2_256s
609            | SlhDsa::Sha2_256f => HashFunction::Sha2,
610            SlhDsa::Shake128s
611            | SlhDsa::Shake128f
612            | SlhDsa::Shake192s
613            | SlhDsa::Shake192f
614            | SlhDsa::Shake256s
615            | SlhDsa::Shake256f => HashFunction::Shake,
616        }
617    }
618
619    /// Returns the mode (small or fast).
620    #[inline]
621    pub const fn mode(&self) -> SlhDsaMode {
622        match self {
623            SlhDsa::Sha2_128s
624            | SlhDsa::Sha2_192s
625            | SlhDsa::Sha2_256s
626            | SlhDsa::Shake128s
627            | SlhDsa::Shake192s
628            | SlhDsa::Shake256s => SlhDsaMode::Small,
629            SlhDsa::Sha2_128f
630            | SlhDsa::Sha2_192f
631            | SlhDsa::Sha2_256f
632            | SlhDsa::Shake128f
633            | SlhDsa::Shake192f
634            | SlhDsa::Shake256f => SlhDsaMode::Fast,
635        }
636    }
637
638    /// Returns the NIST security level.
639    #[inline]
640    pub const fn security_level(&self) -> SecurityLevel {
641        match self {
642            SlhDsa::Sha2_128s | SlhDsa::Sha2_128f | SlhDsa::Shake128s | SlhDsa::Shake128f => {
643                SecurityLevel::Level1
644            }
645            SlhDsa::Sha2_192s | SlhDsa::Sha2_192f | SlhDsa::Shake192s | SlhDsa::Shake192f => {
646                SecurityLevel::Level3
647            }
648            SlhDsa::Sha2_256s | SlhDsa::Sha2_256f | SlhDsa::Shake256s | SlhDsa::Shake256f => {
649                SecurityLevel::Level5
650            }
651        }
652    }
653
654    /// Returns the public key size in bytes.
655    #[inline]
656    pub const fn public_key_size(&self) -> usize {
657        match self {
658            SlhDsa::Sha2_128s | SlhDsa::Sha2_128f | SlhDsa::Shake128s | SlhDsa::Shake128f => 32,
659            SlhDsa::Sha2_192s | SlhDsa::Sha2_192f | SlhDsa::Shake192s | SlhDsa::Shake192f => 48,
660            SlhDsa::Sha2_256s | SlhDsa::Sha2_256f | SlhDsa::Shake256s | SlhDsa::Shake256f => 64,
661        }
662    }
663
664    /// Returns the private key size in bytes.
665    #[inline]
666    pub const fn private_key_size(&self) -> usize {
667        match self {
668            SlhDsa::Sha2_128s | SlhDsa::Sha2_128f | SlhDsa::Shake128s | SlhDsa::Shake128f => 64,
669            SlhDsa::Sha2_192s | SlhDsa::Sha2_192f | SlhDsa::Shake192s | SlhDsa::Shake192f => 96,
670            SlhDsa::Sha2_256s | SlhDsa::Sha2_256f | SlhDsa::Shake256s | SlhDsa::Shake256f => 128,
671        }
672    }
673
674    /// Returns the signature size in bytes.
675    #[inline]
676    pub const fn signature_size(&self) -> usize {
677        match self {
678            SlhDsa::Sha2_128s | SlhDsa::Shake128s => 7856,
679            SlhDsa::Sha2_128f | SlhDsa::Shake128f => 17088,
680            SlhDsa::Sha2_192s | SlhDsa::Shake192s => 16224,
681            SlhDsa::Sha2_192f | SlhDsa::Shake192f => 35664,
682            SlhDsa::Sha2_256s | SlhDsa::Shake256s => 29792,
683            SlhDsa::Sha2_256f | SlhDsa::Shake256f => 49856,
684        }
685    }
686
687    /// Returns the complete algorithm info.
688    pub const fn info(&self) -> AlgorithmInfo {
689        AlgorithmInfo {
690            name: self.as_str(),
691            oid: self.oid(),
692            algorithm_type: AlgorithmType::Sign,
693            family: AlgorithmFamily::SlhDsa,
694            security_level: self.security_level(),
695            public_key_size: self.public_key_size(),
696            private_key_size: self.private_key_size(),
697            sizes: AlgorithmSizes::Sign {
698                signature_size: self.signature_size(),
699            },
700        }
701    }
702
703    /// Parse from an OID string.
704    pub fn from_oid(oid: &str) -> Result<Self> {
705        match oid {
706            "2.16.840.1.101.3.4.3.20" => Ok(SlhDsa::Sha2_128s),
707            "2.16.840.1.101.3.4.3.21" => Ok(SlhDsa::Sha2_128f),
708            "2.16.840.1.101.3.4.3.22" => Ok(SlhDsa::Sha2_192s),
709            "2.16.840.1.101.3.4.3.23" => Ok(SlhDsa::Sha2_192f),
710            "2.16.840.1.101.3.4.3.24" => Ok(SlhDsa::Sha2_256s),
711            "2.16.840.1.101.3.4.3.25" => Ok(SlhDsa::Sha2_256f),
712            "2.16.840.1.101.3.4.3.26" => Ok(SlhDsa::Shake128s),
713            "2.16.840.1.101.3.4.3.27" => Ok(SlhDsa::Shake128f),
714            "2.16.840.1.101.3.4.3.28" => Ok(SlhDsa::Shake192s),
715            "2.16.840.1.101.3.4.3.29" => Ok(SlhDsa::Shake192f),
716            "2.16.840.1.101.3.4.3.30" => Ok(SlhDsa::Shake256s),
717            "2.16.840.1.101.3.4.3.31" => Ok(SlhDsa::Shake256f),
718            _ => Err(Error::UnknownOid),
719        }
720    }
721}
722
723impl fmt::Display for SlhDsa {
724    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
725        write!(f, "{}", self.as_str())
726    }
727}
728
729impl AsRef<str> for SlhDsa {
730    fn as_ref(&self) -> &str {
731        self.as_str()
732    }
733}
734
735impl FromStr for SlhDsa {
736    type Err = Error;
737
738    fn from_str(s: &str) -> Result<Self> {
739        match s {
740            "SLH-DSA-SHA2-128s" => Ok(SlhDsa::Sha2_128s),
741            "SLH-DSA-SHA2-128f" => Ok(SlhDsa::Sha2_128f),
742            "SLH-DSA-SHA2-192s" => Ok(SlhDsa::Sha2_192s),
743            "SLH-DSA-SHA2-192f" => Ok(SlhDsa::Sha2_192f),
744            "SLH-DSA-SHA2-256s" => Ok(SlhDsa::Sha2_256s),
745            "SLH-DSA-SHA2-256f" => Ok(SlhDsa::Sha2_256f),
746            "SLH-DSA-SHAKE-128s" => Ok(SlhDsa::Shake128s),
747            "SLH-DSA-SHAKE-128f" => Ok(SlhDsa::Shake128f),
748            "SLH-DSA-SHAKE-192s" => Ok(SlhDsa::Shake192s),
749            "SLH-DSA-SHAKE-192f" => Ok(SlhDsa::Shake192f),
750            "SLH-DSA-SHAKE-256s" => Ok(SlhDsa::Shake256s),
751            "SLH-DSA-SHAKE-256f" => Ok(SlhDsa::Shake256f),
752            _ => Err(Error::UnknownAlgorithm),
753        }
754    }
755}
756
757impl TryFrom<&str> for SlhDsa {
758    type Error = Error;
759
760    fn try_from(s: &str) -> Result<Self> {
761        s.parse()
762    }
763}
764
765// =============================================================================
766// Unified Algorithm Enum
767// =============================================================================
768
769/// A unified enum representing any supported PQ algorithm.
770///
771/// This provides a single type that can represent any algorithm from any family,
772/// useful for generic processing.
773///
774/// # Example
775/// ```
776/// use pq_oid::{Algorithm, MlKem, MlDsa};
777/// use std::str::FromStr;
778///
779/// let alg = Algorithm::from_str("ML-KEM-512").unwrap();
780/// assert_eq!(alg.oid(), "2.16.840.1.101.3.4.4.1");
781///
782/// // Convert from family-specific enums
783/// let alg: Algorithm = MlDsa::Dsa65.into();
784/// assert_eq!(alg.as_str(), "ML-DSA-65");
785/// ```
786#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
787pub enum Algorithm {
788    MlKem(MlKem),
789    MlDsa(MlDsa),
790    SlhDsa(SlhDsa),
791}
792
793impl Algorithm {
794    /// Total number of supported algorithms.
795    pub const COUNT: usize = 18;
796
797    /// Number of KEM algorithms.
798    pub const KEM_COUNT: usize = MlKem::COUNT;
799
800    /// Number of signing algorithms.
801    pub const SIGNATURE_COUNT: usize = MlDsa::COUNT + SlhDsa::COUNT;
802
803    /// All supported algorithms (18 total).
804    pub const ALL: &'static [Algorithm] = &[
805        // ML-KEM (3)
806        Algorithm::MlKem(MlKem::Kem512),
807        Algorithm::MlKem(MlKem::Kem768),
808        Algorithm::MlKem(MlKem::Kem1024),
809        // ML-DSA (3)
810        Algorithm::MlDsa(MlDsa::Dsa44),
811        Algorithm::MlDsa(MlDsa::Dsa65),
812        Algorithm::MlDsa(MlDsa::Dsa87),
813        // SLH-DSA SHA2 (6)
814        Algorithm::SlhDsa(SlhDsa::Sha2_128s),
815        Algorithm::SlhDsa(SlhDsa::Sha2_128f),
816        Algorithm::SlhDsa(SlhDsa::Sha2_192s),
817        Algorithm::SlhDsa(SlhDsa::Sha2_192f),
818        Algorithm::SlhDsa(SlhDsa::Sha2_256s),
819        Algorithm::SlhDsa(SlhDsa::Sha2_256f),
820        // SLH-DSA SHAKE (6)
821        Algorithm::SlhDsa(SlhDsa::Shake128s),
822        Algorithm::SlhDsa(SlhDsa::Shake128f),
823        Algorithm::SlhDsa(SlhDsa::Shake192s),
824        Algorithm::SlhDsa(SlhDsa::Shake192f),
825        Algorithm::SlhDsa(SlhDsa::Shake256s),
826        Algorithm::SlhDsa(SlhDsa::Shake256f),
827    ];
828
829    /// All KEM algorithms.
830    pub const ALL_KEMS: &'static [Algorithm] = &[
831        Algorithm::MlKem(MlKem::Kem512),
832        Algorithm::MlKem(MlKem::Kem768),
833        Algorithm::MlKem(MlKem::Kem1024),
834    ];
835
836    /// All signing algorithms.
837    pub const ALL_SIGNATURES: &'static [Algorithm] = &[
838        Algorithm::MlDsa(MlDsa::Dsa44),
839        Algorithm::MlDsa(MlDsa::Dsa65),
840        Algorithm::MlDsa(MlDsa::Dsa87),
841        Algorithm::SlhDsa(SlhDsa::Sha2_128s),
842        Algorithm::SlhDsa(SlhDsa::Sha2_128f),
843        Algorithm::SlhDsa(SlhDsa::Sha2_192s),
844        Algorithm::SlhDsa(SlhDsa::Sha2_192f),
845        Algorithm::SlhDsa(SlhDsa::Sha2_256s),
846        Algorithm::SlhDsa(SlhDsa::Sha2_256f),
847        Algorithm::SlhDsa(SlhDsa::Shake128s),
848        Algorithm::SlhDsa(SlhDsa::Shake128f),
849        Algorithm::SlhDsa(SlhDsa::Shake192s),
850        Algorithm::SlhDsa(SlhDsa::Shake192f),
851        Algorithm::SlhDsa(SlhDsa::Shake256s),
852        Algorithm::SlhDsa(SlhDsa::Shake256f),
853    ];
854
855    /// Returns an iterator over all supported algorithms.
856    pub fn all() -> impl Iterator<Item = Algorithm> {
857        Self::ALL.iter().copied()
858    }
859
860    /// Returns an iterator over all KEM algorithms.
861    pub fn kems() -> impl Iterator<Item = Algorithm> {
862        Self::ALL_KEMS.iter().copied()
863    }
864
865    /// Returns an iterator over all signing algorithms.
866    pub fn signatures() -> impl Iterator<Item = Algorithm> {
867        Self::ALL_SIGNATURES.iter().copied()
868    }
869
870    /// Returns the algorithm name string.
871    #[inline]
872    pub const fn as_str(&self) -> &'static str {
873        match self {
874            Algorithm::MlKem(a) => a.as_str(),
875            Algorithm::MlDsa(a) => a.as_str(),
876            Algorithm::SlhDsa(a) => a.as_str(),
877        }
878    }
879
880    /// Returns the OID in dotted notation.
881    #[inline]
882    pub const fn oid(&self) -> &'static str {
883        match self {
884            Algorithm::MlKem(a) => a.oid(),
885            Algorithm::MlDsa(a) => a.oid(),
886            Algorithm::SlhDsa(a) => a.oid(),
887        }
888    }
889
890    /// Returns the algorithm type.
891    #[inline]
892    pub const fn algorithm_type(&self) -> AlgorithmType {
893        match self {
894            Algorithm::MlKem(_) => AlgorithmType::Kem,
895            Algorithm::MlDsa(_) | Algorithm::SlhDsa(_) => AlgorithmType::Sign,
896        }
897    }
898
899    /// Returns the algorithm family.
900    #[inline]
901    pub const fn family(&self) -> AlgorithmFamily {
902        match self {
903            Algorithm::MlKem(_) => AlgorithmFamily::MlKem,
904            Algorithm::MlDsa(_) => AlgorithmFamily::MlDsa,
905            Algorithm::SlhDsa(_) => AlgorithmFamily::SlhDsa,
906        }
907    }
908
909    /// Returns the NIST security level.
910    #[inline]
911    pub const fn security_level(&self) -> SecurityLevel {
912        match self {
913            Algorithm::MlKem(a) => a.security_level(),
914            Algorithm::MlDsa(a) => a.security_level(),
915            Algorithm::SlhDsa(a) => a.security_level(),
916        }
917    }
918
919    /// Returns the public key size in bytes.
920    #[inline]
921    pub const fn public_key_size(&self) -> usize {
922        match self {
923            Algorithm::MlKem(a) => a.public_key_size(),
924            Algorithm::MlDsa(a) => a.public_key_size(),
925            Algorithm::SlhDsa(a) => a.public_key_size(),
926        }
927    }
928
929    /// Returns the private key size in bytes.
930    #[inline]
931    pub const fn private_key_size(&self) -> usize {
932        match self {
933            Algorithm::MlKem(a) => a.private_key_size(),
934            Algorithm::MlDsa(a) => a.private_key_size(),
935            Algorithm::SlhDsa(a) => a.private_key_size(),
936        }
937    }
938
939    /// Returns the complete algorithm info.
940    pub const fn info(&self) -> AlgorithmInfo {
941        match self {
942            Algorithm::MlKem(a) => a.info(),
943            Algorithm::MlDsa(a) => a.info(),
944            Algorithm::SlhDsa(a) => a.info(),
945        }
946    }
947
948    /// Parse from an OID string.
949    pub fn from_oid(oid: &str) -> Result<Self> {
950        MlKem::from_oid(oid)
951            .map(Algorithm::MlKem)
952            .or_else(|_| MlDsa::from_oid(oid).map(Algorithm::MlDsa))
953            .or_else(|_| SlhDsa::from_oid(oid).map(Algorithm::SlhDsa))
954    }
955
956    /// Returns this as an MlKem if it is one.
957    #[inline]
958    pub const fn as_ml_kem(&self) -> Option<MlKem> {
959        match self {
960            Algorithm::MlKem(a) => Some(*a),
961            _ => None,
962        }
963    }
964
965    /// Returns this as an MlDsa if it is one.
966    #[inline]
967    pub const fn as_ml_dsa(&self) -> Option<MlDsa> {
968        match self {
969            Algorithm::MlDsa(a) => Some(*a),
970            _ => None,
971        }
972    }
973
974    /// Returns this as an SlhDsa if it is one.
975    #[inline]
976    pub const fn as_slh_dsa(&self) -> Option<SlhDsa> {
977        match self {
978            Algorithm::SlhDsa(a) => Some(*a),
979            _ => None,
980        }
981    }
982}
983
984impl fmt::Display for Algorithm {
985    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
986        write!(f, "{}", self.as_str())
987    }
988}
989
990impl AsRef<str> for Algorithm {
991    fn as_ref(&self) -> &str {
992        self.as_str()
993    }
994}
995
996impl FromStr for Algorithm {
997    type Err = Error;
998
999    fn from_str(s: &str) -> Result<Self> {
1000        MlKem::from_str(s)
1001            .map(Algorithm::MlKem)
1002            .or_else(|_| MlDsa::from_str(s).map(Algorithm::MlDsa))
1003            .or_else(|_| SlhDsa::from_str(s).map(Algorithm::SlhDsa))
1004    }
1005}
1006
1007impl TryFrom<&str> for Algorithm {
1008    type Error = Error;
1009
1010    fn try_from(s: &str) -> Result<Self> {
1011        s.parse()
1012    }
1013}
1014
1015impl From<MlKem> for Algorithm {
1016    fn from(a: MlKem) -> Self {
1017        Algorithm::MlKem(a)
1018    }
1019}
1020
1021impl From<MlDsa> for Algorithm {
1022    fn from(a: MlDsa) -> Self {
1023        Algorithm::MlDsa(a)
1024    }
1025}
1026
1027impl From<SlhDsa> for Algorithm {
1028    fn from(a: SlhDsa) -> Self {
1029        Algorithm::SlhDsa(a)
1030    }
1031}
1032
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ml_kem_from_str() {
        let alg: MlKem = "ML-KEM-512".parse().unwrap();
        assert_eq!(alg, MlKem::Kem512);
        assert_eq!(alg.oid(), "2.16.840.1.101.3.4.4.1");
        assert_eq!(alg.public_key_size(), 800);
    }

    #[test]
    fn test_ml_kem_try_from() {
        let alg: MlKem = "ML-KEM-768".try_into().unwrap();
        assert_eq!(alg, MlKem::Kem768);
    }

    #[test]
    fn test_ml_kem_display() {
        assert_eq!(MlKem::Kem1024.to_string(), "ML-KEM-1024");
    }

    #[test]
    fn test_ml_dsa_from_str() {
        let alg: MlDsa = "ML-DSA-65".parse().unwrap();
        assert_eq!(alg, MlDsa::Dsa65);
        assert_eq!(alg.jose(), "ML-DSA-65");
        assert_eq!(alg.cose(), -49);
    }

    #[test]
    fn test_ml_dsa_jose_cose() {
        let alg = MlDsa::from_jose("ML-DSA-44").unwrap();
        assert_eq!(alg, MlDsa::Dsa44);

        let alg = MlDsa::from_cose(-50).unwrap();
        assert_eq!(alg, MlDsa::Dsa87);
    }

    #[test]
    fn test_slh_dsa_from_str() {
        let alg: SlhDsa = "SLH-DSA-SHA2-128s".parse().unwrap();
        assert_eq!(alg, SlhDsa::Sha2_128s);
        assert_eq!(alg.hash_function(), HashFunction::Sha2);
        assert_eq!(alg.mode(), SlhDsaMode::Small);
    }

    #[test]
    fn test_algorithm_unified() {
        let alg: Algorithm = "ML-KEM-512".parse().unwrap();
        assert_eq!(alg.family(), AlgorithmFamily::MlKem);
        assert_eq!(alg.algorithm_type(), AlgorithmType::Kem);

        let alg: Algorithm = MlDsa::Dsa65.into();
        assert_eq!(alg.as_str(), "ML-DSA-65");

        // The last family in the FromStr fallback chain must also be reachable.
        let alg: Algorithm = "SLH-DSA-SHA2-128s".parse().unwrap();
        assert_eq!(alg, Algorithm::SlhDsa(SlhDsa::Sha2_128s));
        assert_eq!(alg.family(), AlgorithmFamily::SlhDsa);
        assert_eq!(alg.algorithm_type(), AlgorithmType::Sign);
    }

    #[test]
    fn test_algorithm_from_str_unknown() {
        assert!("NOT-AN-ALGORITHM".parse::<Algorithm>().is_err());
        assert!("".parse::<Algorithm>().is_err());
    }

    #[test]
    fn test_algorithm_try_from() {
        let alg = Algorithm::try_from("ML-KEM-768").unwrap();
        assert_eq!(alg, Algorithm::MlKem(MlKem::Kem768));
        assert!(Algorithm::try_from("NOT-AN-ALGORITHM").is_err());
    }

    #[test]
    fn test_algorithm_from_oid() {
        // Exercise every branch of the fallback chain through the unified type.
        let alg = Algorithm::from_oid("2.16.840.1.101.3.4.4.1").unwrap();
        assert_eq!(alg, Algorithm::MlKem(MlKem::Kem512));

        let alg = Algorithm::from_oid("2.16.840.1.101.3.4.3.17").unwrap();
        assert_eq!(alg, Algorithm::MlDsa(MlDsa::Dsa44));

        // NIST CSOR: id-slh-dsa-sha2-128s is sigAlgs arc 20.
        let alg = Algorithm::from_oid("2.16.840.1.101.3.4.3.20").unwrap();
        assert_eq!(alg, Algorithm::SlhDsa(SlhDsa::Sha2_128s));

        let alg = MlDsa::from_oid("2.16.840.1.101.3.4.3.17").unwrap();
        assert_eq!(alg, MlDsa::Dsa44);
    }

    #[test]
    fn test_algorithm_from_oid_unknown() {
        assert!(Algorithm::from_oid("1.2.3.4").is_err());
        assert!(Algorithm::from_oid("").is_err());
    }

    #[test]
    fn test_algorithm_all() {
        assert_eq!(Algorithm::all().count(), 18);
        assert_eq!(Algorithm::kems().count(), 3);
        assert_eq!(Algorithm::signatures().count(), 15);
    }

    #[test]
    fn test_algorithm_info() {
        let info = MlKem::Kem512.info();
        assert_eq!(info.name, "ML-KEM-512");
        assert_eq!(info.public_key_size, 800);
        assert!(matches!(
            info.sizes,
            AlgorithmSizes::Kem {
                ciphertext_size: 768,
                ..
            }
        ));
    }

    #[test]
    fn test_as_ref_str() {
        fn takes_str(s: &str) -> &str {
            s
        }

        assert_eq!(takes_str(MlKem::Kem512.as_ref()), "ML-KEM-512");
        assert_eq!(takes_str(MlDsa::Dsa44.as_ref()), "ML-DSA-44");

        let alg = Algorithm::from(MlDsa::Dsa65);
        assert_eq!(AsRef::<str>::as_ref(&alg), "ML-DSA-65");
    }

    #[test]
    fn test_algorithm_display() {
        let alg = Algorithm::from(MlDsa::Dsa65);
        assert_eq!(alg.to_string(), "ML-DSA-65");
        assert_eq!(alg.to_string(), alg.as_str());
    }

    #[test]
    fn test_algorithm_cast() {
        let alg = Algorithm::MlKem(MlKem::Kem512);
        assert!(alg.as_ml_kem().is_some());
        assert!(alg.as_ml_dsa().is_none());
        assert!(alg.as_slh_dsa().is_none());
        assert_eq!(alg.as_ml_kem(), Some(MlKem::Kem512));

        let alg: Algorithm = MlDsa::Dsa87.into();
        assert_eq!(alg.as_ml_dsa(), Some(MlDsa::Dsa87));
        assert!(alg.as_ml_kem().is_none());
        assert!(alg.as_slh_dsa().is_none());

        let alg: Algorithm = SlhDsa::Sha2_128s.into();
        assert_eq!(alg.as_slh_dsa(), Some(SlhDsa::Sha2_128s));
        assert!(alg.as_ml_kem().is_none());
        assert!(alg.as_ml_dsa().is_none());
    }
}