Skip to main content

dilithium/
safe_api.rs

//! High-level safe Rust SDK for ML-DSA (FIPS 204) / CRYSTALS-Dilithium.
//!
//! Supports both **pure ML-DSA** (§6.1) and **HashML-DSA** (§6.2) modes.
//!
//! # Quick Start
//!
//! ```rust
//! use dilithium::{MlDsaKeyPair, ML_DSA_44};
//!
//! let kp = MlDsaKeyPair::generate(ML_DSA_44).unwrap();
//! let sig = kp.sign(b"Hello, post-quantum world!", b"").unwrap();
//! assert!(MlDsaKeyPair::verify(
//!     kp.public_key(), &sig, b"Hello, post-quantum world!", b"",
//!     ML_DSA_44
//! ));
//! ```
//!
//! # Security Levels
//!
//! | FIPS 204 Name | NIST Level | Public Key | Secret Key | Signature |
//! |---------------|------------|------------|------------|-----------|
//! | ML-DSA-44     | 2          | 1312 B     | 2560 B     | 2420 B    |
//! | ML-DSA-65     | 3          | 1952 B     | 4032 B     | 3309 B    |
//! | ML-DSA-87     | 5          | 2592 B     | 4896 B     | 4627 B    |

26use alloc::{vec, vec::Vec};
27use core::fmt;
28
29use zeroize::{Zeroize, Zeroizing};
30
31pub use crate::params::DilithiumMode;
32use crate::params::*;
33use crate::sign;
34use crate::symmetric::shake256;
35
/// Failure modes reported by the ML-DSA API.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DilithiumError {
    /// The random number generator could not supply bytes.
    RandomError,
    /// A key or signature was not in the expected format.
    FormatError,
    /// A signature failed verification.
    BadSignature,
    /// A caller-supplied argument was rejected (e.g. a context longer than 255 bytes).
    BadArgument,
    /// Key consistency validation failed (§7.1).
    InvalidKey,
}

impl fmt::Display for DilithiumError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // One fixed, human-readable message per variant.
        let text = match *self {
            DilithiumError::RandomError => "random number generation failed",
            DilithiumError::FormatError => "invalid format",
            DilithiumError::BadSignature => "invalid signature",
            DilithiumError::BadArgument => "invalid argument",
            DilithiumError::InvalidKey => "key validation failed",
        };
        f.write_str(text)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for DilithiumError {}
66
67/// An ML-DSA key pair (private key + public key).
68///
69/// The private key bytes are **automatically zeroized on drop** (FIPS 204 §7).
70///
71/// Type aliases: `MlDsaKeyPair` (FIPS 204 naming) = `DilithiumKeyPair` (legacy).
72#[derive(Debug, Clone)]
73#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
74pub struct DilithiumKeyPair {
75    #[cfg_attr(
76        feature = "serde",
77        serde(
78            serialize_with = "serde_zeroizing::serialize",
79            deserialize_with = "serde_zeroizing::deserialize"
80        )
81    )]
82    privkey: Zeroizing<Vec<u8>>,
83    pubkey: Vec<u8>,
84    mode: DilithiumMode,
85}
86
/// Serde adapters that let a `Zeroizing<Vec<u8>>` field be (de)serialized as a
/// plain byte vector.
#[cfg(feature = "serde")]
mod serde_zeroizing {
    use super::*;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Serialize the wrapped bytes exactly as the inner `Vec<u8>` would be.
    pub fn serialize<S: Serializer>(val: &Zeroizing<Vec<u8>>, s: S) -> Result<S::Ok, S::Error> {
        // `val` deref-coerces to `&Vec<u8>`, whose `Serialize` impl does the work.
        <Vec<u8> as Serialize>::serialize(val, s)
    }

    /// Deserialize a byte vector and move it straight into a `Zeroizing` wrapper.
    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Zeroizing<Vec<u8>>, D::Error> {
        Vec::<u8>::deserialize(d).map(Zeroizing::new)
    }
}
104
/// FIPS 204 name alias for `DilithiumKeyPair`.
///
/// `MlDsaKeyPair` and `DilithiumKeyPair` are the same type and can be used
/// interchangeably.
pub type MlDsaKeyPair = DilithiumKeyPair;
107
/// An ML-DSA / Dilithium signature, stored as its encoded bytes.
///
/// The message is not embedded; it is passed separately to the verify functions.
#[derive(Clone, Debug)]
#[derive(Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DilithiumSignature {
    /// Encoded signature bytes.
    data: Vec<u8>,
}

/// `MlDsaSignature` is the FIPS 204 spelling of `DilithiumSignature`; both
/// names denote the same type.
pub type MlDsaSignature = DilithiumSignature;
117
118impl DilithiumKeyPair {
119    /// Generate a new key pair using OS entropy (FIPS 204 §6.1 `KeyGen`).
120    pub fn generate(mode: DilithiumMode) -> Result<Self, DilithiumError> {
121        let mut seed = [0u8; SEEDBYTES];
122        getrandom(&mut seed).map_err(|()| DilithiumError::RandomError)?;
123        let result = Self::generate_deterministic(mode, &seed);
124        seed.zeroize();
125        Ok(result)
126    }
127
128    /// Generate a key pair deterministically from a seed.
129    #[must_use]
130    pub fn generate_deterministic(mode: DilithiumMode, seed: &[u8; SEEDBYTES]) -> Self {
131        let (pk, sk) = sign::keypair(mode, seed);
132        DilithiumKeyPair {
133            privkey: Zeroizing::new(sk),
134            pubkey: pk,
135            mode,
136        }
137    }
138
139    /// Sign a message using pure ML-DSA (FIPS 204 §6.1 ML-DSA.Sign).
140    ///
141    /// Context string `ctx` is optional (max 255 bytes).
142    pub fn sign(&self, msg: &[u8], ctx: &[u8]) -> Result<DilithiumSignature, DilithiumError> {
143        if ctx.len() > 255 {
144            return Err(DilithiumError::BadArgument);
145        }
146
147        let mut rnd = [0u8; RNDBYTES];
148        getrandom(&mut rnd).map_err(|()| DilithiumError::RandomError)?;
149
150        let mut sig = vec![0u8; self.mode.signature_bytes()];
151        let ret = sign::sign_signature(self.mode, &mut sig, msg, ctx, &rnd, &self.privkey);
152        rnd.zeroize();
153
154        if ret != 0 {
155            return Err(DilithiumError::BadArgument);
156        }
157
158        Ok(DilithiumSignature { data: sig })
159    }
160
161    /// Sign a message using HashML-DSA (FIPS 204 §6.2 HashML-DSA.Sign).
162    ///
163    /// The message is internally hashed with SHA-512 before signing.
164    /// Context string `ctx` is optional (max 255 bytes).
165    pub fn sign_prehash(
166        &self,
167        msg: &[u8],
168        ctx: &[u8],
169    ) -> Result<DilithiumSignature, DilithiumError> {
170        if ctx.len() > 255 {
171            return Err(DilithiumError::BadArgument);
172        }
173
174        let mut rnd = [0u8; RNDBYTES];
175        getrandom(&mut rnd).map_err(|()| DilithiumError::RandomError)?;
176
177        let mut sig = vec![0u8; self.mode.signature_bytes()];
178        let ret = sign::sign_hash(self.mode, &mut sig, msg, ctx, &rnd, &self.privkey);
179        rnd.zeroize();
180
181        if ret != 0 {
182            return Err(DilithiumError::BadArgument);
183        }
184
185        Ok(DilithiumSignature { data: sig })
186    }
187
188    /// Sign deterministically (for testing / reproducibility).
189    pub fn sign_deterministic(
190        &self,
191        msg: &[u8],
192        ctx: &[u8],
193        rnd: &[u8; RNDBYTES],
194    ) -> Result<DilithiumSignature, DilithiumError> {
195        if ctx.len() > 255 {
196            return Err(DilithiumError::BadArgument);
197        }
198
199        let mut sig = vec![0u8; self.mode.signature_bytes()];
200        sign::sign_signature(self.mode, &mut sig, msg, ctx, rnd, &self.privkey);
201        Ok(DilithiumSignature { data: sig })
202    }
203
204    /// Verify a pure ML-DSA signature (FIPS 204 §6.1 ML-DSA.Verify).
205    #[must_use]
206    pub fn verify(
207        pk: &[u8],
208        sig: &DilithiumSignature,
209        msg: &[u8],
210        ctx: &[u8],
211        mode: DilithiumMode,
212    ) -> bool {
213        if pk.len() != mode.public_key_bytes() {
214            return false;
215        }
216        if sig.data.len() != mode.signature_bytes() {
217            return false;
218        }
219        sign::verify(mode, &sig.data, msg, ctx, pk)
220    }
221
222    /// Verify a HashML-DSA signature (FIPS 204 §6.2 HashML-DSA.Verify).
223    #[must_use]
224    pub fn verify_prehash(
225        pk: &[u8],
226        sig: &DilithiumSignature,
227        msg: &[u8],
228        ctx: &[u8],
229        mode: DilithiumMode,
230    ) -> bool {
231        if pk.len() != mode.public_key_bytes() {
232            return false;
233        }
234        if sig.data.len() != mode.signature_bytes() {
235            return false;
236        }
237        sign::verify_hash(mode, &sig.data, msg, ctx, pk)
238    }
239
240    /// Get the encoded public key bytes.
241    #[must_use]
242    pub fn public_key(&self) -> &[u8] {
243        &self.pubkey
244    }
245
246    /// Get the encoded private key bytes.
247    #[must_use]
248    pub fn private_key(&self) -> &[u8] {
249        &self.privkey
250    }
251
252    /// Get the security mode.
253    #[must_use]
254    pub fn mode(&self) -> DilithiumMode {
255        self.mode
256    }
257
258    /// Reconstruct from private + public key bytes with validation (FIPS 204 §7.1).
259    ///
260    /// Validates that:
261    /// 1. Key sizes match the expected values for the given mode.
262    /// 2. The public key embedded in the secret key is consistent.
263    /// 3. The secret key's `tr = H(pk)` field is consistent.
264    pub fn from_keys(
265        privkey: &[u8],
266        pubkey: &[u8],
267        mode: DilithiumMode,
268    ) -> Result<Self, DilithiumError> {
269        // Check sizes
270        if privkey.len() != mode.secret_key_bytes() {
271            return Err(DilithiumError::FormatError);
272        }
273        if pubkey.len() != mode.public_key_bytes() {
274            return Err(DilithiumError::FormatError);
275        }
276
277        // FIPS 204 §7.1: Validate key consistency
278        // The secret key starts with rho (SEEDBYTES) which must match
279        // the public key's rho
280        let sk_rho = &privkey[..SEEDBYTES];
281        let pk_rho = &pubkey[..SEEDBYTES];
282        if sk_rho != pk_rho {
283            return Err(DilithiumError::InvalidKey);
284        }
285
286        // Validate tr = H(pk) — tr is at offset 2*SEEDBYTES in sk layout: (rho, key, tr, ...)
287        let tr_offset = 2 * SEEDBYTES;
288        let sk_tr = &privkey[tr_offset..tr_offset + TRBYTES];
289        let mut expected_tr = [0u8; TRBYTES];
290        shake256(&mut expected_tr, pubkey);
291        if sk_tr != &expected_tr[..] {
292            return Err(DilithiumError::InvalidKey);
293        }
294
295        Ok(DilithiumKeyPair {
296            privkey: Zeroizing::new(privkey.to_vec()),
297            pubkey: pubkey.to_vec(),
298            mode,
299        })
300    }
301
302    // ── Serialization ──────────────────────────────────────────────
303
304    /// Serialize the full key pair to bytes: `[mode_tag(1) | pk | sk]`.
305    ///
306    /// The mode tag encodes the security level so deserialization
307    /// can automatically select the correct parameters.
308    #[must_use]
309    pub fn to_bytes(&self) -> Vec<u8> {
310        let mut buf = Vec::with_capacity(1 + self.pubkey.len() + self.privkey.len());
311        buf.push(self.mode.mode_tag());
312        buf.extend_from_slice(&self.pubkey);
313        buf.extend_from_slice(&self.privkey);
314        buf
315    }
316
317    /// Deserialize a key pair from the format produced by [`to_bytes`](Self::to_bytes).
318    pub fn from_bytes(data: &[u8]) -> Result<Self, DilithiumError> {
319        if data.is_empty() {
320            return Err(DilithiumError::FormatError);
321        }
322        let mode = DilithiumMode::from_tag(data[0]).ok_or(DilithiumError::FormatError)?;
323        let pk_len = mode.public_key_bytes();
324        let sk_len = mode.secret_key_bytes();
325        if data.len() != 1 + pk_len + sk_len {
326            return Err(DilithiumError::FormatError);
327        }
328        let pk = &data[1..=pk_len];
329        let sk = &data[1 + pk_len..];
330        Self::from_keys(sk, pk, mode)
331    }
332
333    /// Export only the public key bytes with a mode tag: `[mode_tag(1) | pk]`.
334    #[must_use]
335    pub fn public_key_bytes(&self) -> Vec<u8> {
336        let mut buf = Vec::with_capacity(1 + self.pubkey.len());
337        buf.push(self.mode.mode_tag());
338        buf.extend_from_slice(&self.pubkey);
339        buf
340    }
341
342    /// Create a verify-only handle from tagged public key bytes.
343    pub fn from_public_key(data: &[u8]) -> Result<(DilithiumMode, Vec<u8>), DilithiumError> {
344        if data.is_empty() {
345            return Err(DilithiumError::FormatError);
346        }
347        let mode = DilithiumMode::from_tag(data[0]).ok_or(DilithiumError::FormatError)?;
348        if data.len() != 1 + mode.public_key_bytes() {
349            return Err(DilithiumError::FormatError);
350        }
351        Ok((mode, data[1..].to_vec()))
352    }
353}
354
355impl DilithiumSignature {
356    /// Get the raw signature bytes.
357    #[must_use]
358    pub fn as_bytes(&self) -> &[u8] {
359        &self.data
360    }
361
362    /// Create from raw bytes (no validation — use verify to check).
363    #[must_use]
364    pub fn from_bytes(data: Vec<u8>) -> Self {
365        Self { data }
366    }
367
368    /// Create from a byte slice (copies).
369    #[must_use]
370    pub fn from_slice(data: &[u8]) -> Self {
371        Self {
372            data: data.to_vec(),
373        }
374    }
375
376    /// Signature length in bytes.
377    #[must_use]
378    pub fn len(&self) -> usize {
379        self.data.len()
380    }
381
382    /// Returns true if the signature is empty.
383    #[must_use]
384    pub fn is_empty(&self) -> bool {
385        self.data.is_empty()
386    }
387}
388
389/// Fill buffer with random bytes via `getrandom` crate (WASM compatible).
390fn getrandom(buf: &mut [u8]) -> Result<(), ()> {
391    ::getrandom::getrandom(buf).map_err(|_| ())
392}