Skip to main content

dilithium/
safe_api.rs

1//! High-level safe Rust SDK for ML-DSA (FIPS 204) / CRYSTALS-Dilithium.
2//!
3//! Supports both **pure ML-DSA** (FIPS 204 `ML-DSA.Sign` / `ML-DSA.Verify`) and **HashML-DSA** (FIPS 204 `HashML-DSA.Sign` / `HashML-DSA.Verify`) modes.
4//!
5//! # Quick Start
6//!
7//! ```rust
8//! use dilithium::{MlDsaKeyPair, ML_DSA_44};
9//!
10//! let kp = MlDsaKeyPair::generate(ML_DSA_44).unwrap();
11//! let sig = kp.sign(b"Hello, post-quantum world!", b"").unwrap();
12//! assert!(MlDsaKeyPair::verify(
13//!     kp.public_key(), &sig, b"Hello, post-quantum world!", b"",
14//!     ML_DSA_44
15//! ));
16//! ```
17//!
18//! # Security Levels
19//!
20//! | FIPS 204 Name | NIST Level | Public Key | Secret Key | Signature |
21//! |---------------|------------|------------|------------|-----------|
22//! | ML-DSA-44     | 2          | 1312 B     | 2560 B     | 2420 B    |
23//! | ML-DSA-65     | 3          | 1952 B     | 4032 B     | 3309 B    |
24//! | ML-DSA-87     | 5          | 2592 B     | 4896 B     | 4627 B    |
25
26use alloc::{vec, vec::Vec};
27use core::fmt;
28
29use zeroize::Zeroizing;
30#[cfg(feature = "getrandom")]
31use zeroize::Zeroize;
32
33pub use crate::params::DilithiumMode;
34use crate::params::*;
35use crate::sign;
36use crate::symmetric::shake256;
37
38/// Errors returned by the ML-DSA API.
39#[derive(Debug, Clone, Copy, PartialEq, Eq)]
40#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
41pub enum DilithiumError {
42    /// Random number generation failed.
43    RandomError,
44    /// Invalid key or signature format.
45    FormatError,
46    /// Signature verification failed.
47    BadSignature,
48    /// An argument was invalid (e.g. context > 255 bytes).
49    BadArgument,
50    /// Key validation failed (§7.1).
51    InvalidKey,
52}
53
54impl fmt::Display for DilithiumError {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        match self {
57            Self::RandomError => write!(f, "random number generation failed"),
58            Self::FormatError => write!(f, "invalid format"),
59            Self::BadSignature => write!(f, "invalid signature"),
60            Self::BadArgument => write!(f, "invalid argument"),
61            Self::InvalidKey => write!(f, "key validation failed"),
62        }
63    }
64}
65
66#[cfg(feature = "std")]
67impl std::error::Error for DilithiumError {}
68
69/// An ML-DSA key pair (private key + public key).
70///
71/// The private key bytes are **automatically zeroized on drop** (FIPS 204 §7).
72///
73/// Type aliases: `MlDsaKeyPair` (FIPS 204 naming) = `DilithiumKeyPair` (legacy).
74#[derive(Debug, Clone)]
75#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
76pub struct DilithiumKeyPair {
77    #[cfg_attr(
78        feature = "serde",
79        serde(
80            serialize_with = "serde_zeroizing::serialize",
81            deserialize_with = "serde_zeroizing::deserialize"
82        )
83    )]
84    privkey: Zeroizing<Vec<u8>>,
85    pubkey: Vec<u8>,
86    mode: DilithiumMode,
87}
88
89/// Helper module for serde on `Zeroizing<Vec<u8>>`.
90#[cfg(feature = "serde")]
91mod serde_zeroizing {
92    use super::*;
93    use serde::{Deserialize, Deserializer, Serialize, Serializer};
94
95    pub fn serialize<S: Serializer>(val: &Zeroizing<Vec<u8>>, s: S) -> Result<S::Ok, S::Error> {
96        // Deref to &Vec<u8> which implements Serialize
97        let inner: &Vec<u8> = val;
98        inner.serialize(s)
99    }
100
101    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Zeroizing<Vec<u8>>, D::Error> {
102        let v = Vec::<u8>::deserialize(d)?;
103        Ok(Zeroizing::new(v))
104    }
105}
106
107/// FIPS 204 name alias for `DilithiumKeyPair`.
108pub type MlDsaKeyPair = DilithiumKeyPair;
109
/// Encoded ML-DSA / Dilithium signature bytes.
///
/// The bytes are stored as-is; no validation happens on construction.
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DilithiumSignature {
    // Raw encoded signature.
    data: Vec<u8>,
}

/// FIPS 204 name for [`DilithiumSignature`].
pub type MlDsaSignature = DilithiumSignature;
119
120impl DilithiumKeyPair {
121    /// Generate a new key pair using OS entropy (FIPS 204 §6.1 `KeyGen`).
122    ///
123    /// Requires the `std` or `getrandom` feature (enabled by default).
124    #[cfg(feature = "getrandom")]
125    pub fn generate(mode: DilithiumMode) -> Result<Self, DilithiumError> {
126        let mut seed = [0u8; SEEDBYTES];
127        getrandom(&mut seed).map_err(|()| DilithiumError::RandomError)?;
128        let result = Self::generate_deterministic(mode, &seed);
129        seed.zeroize();
130        Ok(result)
131    }
132
133    /// Generate a key pair deterministically from a seed.
134    #[must_use]
135    pub fn generate_deterministic(mode: DilithiumMode, seed: &[u8; SEEDBYTES]) -> Self {
136        let (pk, sk) = sign::keypair(mode, seed);
137        DilithiumKeyPair {
138            privkey: Zeroizing::new(sk),
139            pubkey: pk,
140            mode,
141        }
142    }
143
144    /// Sign a message using pure ML-DSA (FIPS 204 §6.1 ML-DSA.Sign).
145    ///
146    /// Context string `ctx` is optional (max 255 bytes).
147    /// Requires the `std` or `getrandom` feature for randomized signing.
148    #[cfg(feature = "getrandom")]
149    pub fn sign(&self, msg: &[u8], ctx: &[u8]) -> Result<DilithiumSignature, DilithiumError> {
150        if ctx.len() > 255 {
151            return Err(DilithiumError::BadArgument);
152        }
153
154        let mut rnd = [0u8; RNDBYTES];
155        getrandom(&mut rnd).map_err(|()| DilithiumError::RandomError)?;
156
157        let mut sig = vec![0u8; self.mode.signature_bytes()];
158        let ret = sign::sign_signature(self.mode, &mut sig, msg, ctx, &rnd, &self.privkey);
159        rnd.zeroize();
160
161        if ret != 0 {
162            return Err(DilithiumError::BadArgument);
163        }
164
165        Ok(DilithiumSignature { data: sig })
166    }
167
168    /// Sign a message using HashML-DSA (FIPS 204 §6.2 HashML-DSA.Sign).
169    ///
170    /// The message is internally hashed with SHA-512 before signing.
171    /// Context string `ctx` is optional (max 255 bytes).
172    /// Requires the `std` or `getrandom` feature for randomized signing.
173    #[cfg(feature = "getrandom")]
174    pub fn sign_prehash(
175        &self,
176        msg: &[u8],
177        ctx: &[u8],
178    ) -> Result<DilithiumSignature, DilithiumError> {
179        if ctx.len() > 255 {
180            return Err(DilithiumError::BadArgument);
181        }
182
183        let mut rnd = [0u8; RNDBYTES];
184        getrandom(&mut rnd).map_err(|()| DilithiumError::RandomError)?;
185
186        let mut sig = vec![0u8; self.mode.signature_bytes()];
187        let ret = sign::sign_hash(self.mode, &mut sig, msg, ctx, &rnd, &self.privkey);
188        rnd.zeroize();
189
190        if ret != 0 {
191            return Err(DilithiumError::BadArgument);
192        }
193
194        Ok(DilithiumSignature { data: sig })
195    }
196
197    /// Sign deterministically (for testing / reproducibility).
198    pub fn sign_deterministic(
199        &self,
200        msg: &[u8],
201        ctx: &[u8],
202        rnd: &[u8; RNDBYTES],
203    ) -> Result<DilithiumSignature, DilithiumError> {
204        if ctx.len() > 255 {
205            return Err(DilithiumError::BadArgument);
206        }
207
208        let mut sig = vec![0u8; self.mode.signature_bytes()];
209        sign::sign_signature(self.mode, &mut sig, msg, ctx, rnd, &self.privkey);
210        Ok(DilithiumSignature { data: sig })
211    }
212
213    /// Verify a pure ML-DSA signature (FIPS 204 §6.1 ML-DSA.Verify).
214    #[must_use]
215    pub fn verify(
216        pk: &[u8],
217        sig: &DilithiumSignature,
218        msg: &[u8],
219        ctx: &[u8],
220        mode: DilithiumMode,
221    ) -> bool {
222        if pk.len() != mode.public_key_bytes() {
223            return false;
224        }
225        if sig.data.len() != mode.signature_bytes() {
226            return false;
227        }
228        sign::verify(mode, &sig.data, msg, ctx, pk)
229    }
230
231    /// Verify a HashML-DSA signature (FIPS 204 §6.2 HashML-DSA.Verify).
232    #[must_use]
233    pub fn verify_prehash(
234        pk: &[u8],
235        sig: &DilithiumSignature,
236        msg: &[u8],
237        ctx: &[u8],
238        mode: DilithiumMode,
239    ) -> bool {
240        if pk.len() != mode.public_key_bytes() {
241            return false;
242        }
243        if sig.data.len() != mode.signature_bytes() {
244            return false;
245        }
246        sign::verify_hash(mode, &sig.data, msg, ctx, pk)
247    }
248
249    /// Get the encoded public key bytes.
250    #[must_use]
251    pub fn public_key(&self) -> &[u8] {
252        &self.pubkey
253    }
254
255    /// Get the encoded private key bytes.
256    #[must_use]
257    pub fn private_key(&self) -> &[u8] {
258        &self.privkey
259    }
260
261    /// Get the security mode.
262    #[must_use]
263    pub fn mode(&self) -> DilithiumMode {
264        self.mode
265    }
266
267    /// Reconstruct from private + public key bytes with validation (FIPS 204 §7.1).
268    ///
269    /// Validates that:
270    /// 1. Key sizes match the expected values for the given mode.
271    /// 2. The public key embedded in the secret key is consistent.
272    /// 3. The secret key's `tr = H(pk)` field is consistent.
273    pub fn from_keys(
274        privkey: &[u8],
275        pubkey: &[u8],
276        mode: DilithiumMode,
277    ) -> Result<Self, DilithiumError> {
278        // Check sizes
279        if privkey.len() != mode.secret_key_bytes() {
280            return Err(DilithiumError::FormatError);
281        }
282        if pubkey.len() != mode.public_key_bytes() {
283            return Err(DilithiumError::FormatError);
284        }
285
286        // FIPS 204 §7.1: Validate key consistency
287        // The secret key starts with rho (SEEDBYTES) which must match
288        // the public key's rho
289        let sk_rho = &privkey[..SEEDBYTES];
290        let pk_rho = &pubkey[..SEEDBYTES];
291        if sk_rho != pk_rho {
292            return Err(DilithiumError::InvalidKey);
293        }
294
295        // Validate tr = H(pk) — tr is at offset 2*SEEDBYTES in sk layout: (rho, key, tr, ...)
296        let tr_offset = 2 * SEEDBYTES;
297        let sk_tr = &privkey[tr_offset..tr_offset + TRBYTES];
298        let mut expected_tr = [0u8; TRBYTES];
299        shake256(&mut expected_tr, pubkey);
300        if sk_tr != &expected_tr[..] {
301            return Err(DilithiumError::InvalidKey);
302        }
303
304        Ok(DilithiumKeyPair {
305            privkey: Zeroizing::new(privkey.to_vec()),
306            pubkey: pubkey.to_vec(),
307            mode,
308        })
309    }
310
311    // ── Serialization ──────────────────────────────────────────────
312
313    /// Serialize the full key pair to bytes: `[mode_tag(1) | pk | sk]`.
314    ///
315    /// The mode tag encodes the security level so deserialization
316    /// can automatically select the correct parameters.
317    #[must_use]
318    pub fn to_bytes(&self) -> Vec<u8> {
319        let mut buf = Vec::with_capacity(1 + self.pubkey.len() + self.privkey.len());
320        buf.push(self.mode.mode_tag());
321        buf.extend_from_slice(&self.pubkey);
322        buf.extend_from_slice(&self.privkey);
323        buf
324    }
325
326    /// Deserialize a key pair from the format produced by [`to_bytes`](Self::to_bytes).
327    pub fn from_bytes(data: &[u8]) -> Result<Self, DilithiumError> {
328        if data.is_empty() {
329            return Err(DilithiumError::FormatError);
330        }
331        let mode = DilithiumMode::from_tag(data[0]).ok_or(DilithiumError::FormatError)?;
332        let pk_len = mode.public_key_bytes();
333        let sk_len = mode.secret_key_bytes();
334        if data.len() != 1 + pk_len + sk_len {
335            return Err(DilithiumError::FormatError);
336        }
337        let pk = &data[1..=pk_len];
338        let sk = &data[1 + pk_len..];
339        Self::from_keys(sk, pk, mode)
340    }
341
342    /// Export only the public key bytes with a mode tag: `[mode_tag(1) | pk]`.
343    #[must_use]
344    pub fn public_key_bytes(&self) -> Vec<u8> {
345        let mut buf = Vec::with_capacity(1 + self.pubkey.len());
346        buf.push(self.mode.mode_tag());
347        buf.extend_from_slice(&self.pubkey);
348        buf
349    }
350
351    /// Create a verify-only handle from tagged public key bytes.
352    pub fn from_public_key(data: &[u8]) -> Result<(DilithiumMode, Vec<u8>), DilithiumError> {
353        if data.is_empty() {
354            return Err(DilithiumError::FormatError);
355        }
356        let mode = DilithiumMode::from_tag(data[0]).ok_or(DilithiumError::FormatError)?;
357        if data.len() != 1 + mode.public_key_bytes() {
358            return Err(DilithiumError::FormatError);
359        }
360        Ok((mode, data[1..].to_vec()))
361    }
362}
363
364impl DilithiumSignature {
365    /// Get the raw signature bytes.
366    #[must_use]
367    pub fn as_bytes(&self) -> &[u8] {
368        &self.data
369    }
370
371    /// Create from raw bytes (no validation — use verify to check).
372    #[must_use]
373    pub fn from_bytes(data: Vec<u8>) -> Self {
374        Self { data }
375    }
376
377    /// Create from a byte slice (copies).
378    #[must_use]
379    pub fn from_slice(data: &[u8]) -> Self {
380        Self {
381            data: data.to_vec(),
382        }
383    }
384
385    /// Signature length in bytes.
386    #[must_use]
387    pub fn len(&self) -> usize {
388        self.data.len()
389    }
390
391    /// Returns true if the signature is empty.
392    #[must_use]
393    pub fn is_empty(&self) -> bool {
394        self.data.is_empty()
395    }
396}
397
/// Fill `buf` from the platform entropy source via the `getrandom` crate.
///
/// The crate's error detail is discarded; callers in this module translate
/// `Err(())` into `DilithiumError::RandomError`.
#[cfg(feature = "getrandom")]
fn getrandom(buf: &mut [u8]) -> Result<(), ()> {
    match ::getrandom::getrandom(buf) {
        Ok(()) => Ok(()),
        Err(_) => Err(()),
    }
}