1use alloc::{vec, vec::Vec};
27use core::fmt;
28
29use zeroize::{Zeroize, Zeroizing};
30
31pub use crate::params::DilithiumMode;
32use crate::params::*;
33use crate::sign;
34use crate::symmetric::shake256;
35
/// Errors reported by the key-pair and signature operations in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DilithiumError {
    /// The random source could not supply the requested bytes.
    RandomError,
    /// Encoded input has the wrong length or an unknown mode tag.
    FormatError,
    /// A signature is invalid.
    BadSignature,
    /// A caller-supplied argument was rejected (for example a context string
    /// longer than 255 bytes) or the signing routine reported failure.
    BadArgument,
    /// The private and public key halves are not consistent with each other.
    InvalidKey,
}
51
52impl fmt::Display for DilithiumError {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 match self {
55 Self::RandomError => write!(f, "random number generation failed"),
56 Self::FormatError => write!(f, "invalid format"),
57 Self::BadSignature => write!(f, "invalid signature"),
58 Self::BadArgument => write!(f, "invalid argument"),
59 Self::InvalidKey => write!(f, "key validation failed"),
60 }
61 }
62}
63
// `std::error::Error` is only implemented when the `std` feature is enabled;
// the default trait methods are sufficient because the error carries no source.
#[cfg(feature = "std")]
impl std::error::Error for DilithiumError {}
66
67#[derive(Debug, Clone)]
73#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
74pub struct DilithiumKeyPair {
75 #[cfg_attr(
76 feature = "serde",
77 serde(
78 serialize_with = "serde_zeroizing::serialize",
79 deserialize_with = "serde_zeroizing::deserialize"
80 )
81 )]
82 privkey: Zeroizing<Vec<u8>>,
83 pubkey: Vec<u8>,
84 mode: DilithiumMode,
85}
86
/// Serde adapters for the `Zeroizing<Vec<u8>>` private-key field.
///
/// The wrapper is (de)serialized exactly like the `Vec<u8>` it contains.
#[cfg(feature = "serde")]
mod serde_zeroizing {
    use super::*;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Serializes the wrapped byte vector as a plain `Vec<u8>`.
    pub fn serialize<S>(val: &Zeroizing<Vec<u8>>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Deref through the wrapper; no copy of the secret bytes is made here.
        let bytes: &Vec<u8> = val;
        bytes.serialize(s)
    }

    /// Deserializes a `Vec<u8>` and moves it into a zeroizing wrapper.
    pub fn deserialize<'de, D>(d: D) -> Result<Zeroizing<Vec<u8>>, D::Error>
    where
        D: Deserializer<'de>,
    {
        Vec::<u8>::deserialize(d).map(Zeroizing::new)
    }
}
104
/// Alias exposing [`DilithiumKeyPair`] under the `MlDsa` name.
pub type MlDsaKeyPair = DilithiumKeyPair;
107
/// An encoded signature, stored as an owned byte vector.
///
/// The type does not record which parameter set produced it; the expected
/// length is checked against a mode at verification time.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DilithiumSignature {
    /// Raw signature bytes.
    data: Vec<u8>,
}
114
/// Alias exposing [`DilithiumSignature`] under the `MlDsa` name.
pub type MlDsaSignature = DilithiumSignature;
117
118impl DilithiumKeyPair {
119 pub fn generate(mode: DilithiumMode) -> Result<Self, DilithiumError> {
121 let mut seed = [0u8; SEEDBYTES];
122 getrandom(&mut seed).map_err(|()| DilithiumError::RandomError)?;
123 let result = Self::generate_deterministic(mode, &seed);
124 seed.zeroize();
125 Ok(result)
126 }
127
128 #[must_use]
130 pub fn generate_deterministic(mode: DilithiumMode, seed: &[u8; SEEDBYTES]) -> Self {
131 let (pk, sk) = sign::keypair(mode, seed);
132 DilithiumKeyPair {
133 privkey: Zeroizing::new(sk),
134 pubkey: pk,
135 mode,
136 }
137 }
138
139 pub fn sign(&self, msg: &[u8], ctx: &[u8]) -> Result<DilithiumSignature, DilithiumError> {
143 if ctx.len() > 255 {
144 return Err(DilithiumError::BadArgument);
145 }
146
147 let mut rnd = [0u8; RNDBYTES];
148 getrandom(&mut rnd).map_err(|()| DilithiumError::RandomError)?;
149
150 let mut sig = vec![0u8; self.mode.signature_bytes()];
151 let ret = sign::sign_signature(self.mode, &mut sig, msg, ctx, &rnd, &self.privkey);
152 rnd.zeroize();
153
154 if ret != 0 {
155 return Err(DilithiumError::BadArgument);
156 }
157
158 Ok(DilithiumSignature { data: sig })
159 }
160
161 pub fn sign_prehash(
166 &self,
167 msg: &[u8],
168 ctx: &[u8],
169 ) -> Result<DilithiumSignature, DilithiumError> {
170 if ctx.len() > 255 {
171 return Err(DilithiumError::BadArgument);
172 }
173
174 let mut rnd = [0u8; RNDBYTES];
175 getrandom(&mut rnd).map_err(|()| DilithiumError::RandomError)?;
176
177 let mut sig = vec![0u8; self.mode.signature_bytes()];
178 let ret = sign::sign_hash(self.mode, &mut sig, msg, ctx, &rnd, &self.privkey);
179 rnd.zeroize();
180
181 if ret != 0 {
182 return Err(DilithiumError::BadArgument);
183 }
184
185 Ok(DilithiumSignature { data: sig })
186 }
187
188 pub fn sign_deterministic(
190 &self,
191 msg: &[u8],
192 ctx: &[u8],
193 rnd: &[u8; RNDBYTES],
194 ) -> Result<DilithiumSignature, DilithiumError> {
195 if ctx.len() > 255 {
196 return Err(DilithiumError::BadArgument);
197 }
198
199 let mut sig = vec![0u8; self.mode.signature_bytes()];
200 sign::sign_signature(self.mode, &mut sig, msg, ctx, rnd, &self.privkey);
201 Ok(DilithiumSignature { data: sig })
202 }
203
204 #[must_use]
206 pub fn verify(
207 pk: &[u8],
208 sig: &DilithiumSignature,
209 msg: &[u8],
210 ctx: &[u8],
211 mode: DilithiumMode,
212 ) -> bool {
213 if pk.len() != mode.public_key_bytes() {
214 return false;
215 }
216 if sig.data.len() != mode.signature_bytes() {
217 return false;
218 }
219 sign::verify(mode, &sig.data, msg, ctx, pk)
220 }
221
222 #[must_use]
224 pub fn verify_prehash(
225 pk: &[u8],
226 sig: &DilithiumSignature,
227 msg: &[u8],
228 ctx: &[u8],
229 mode: DilithiumMode,
230 ) -> bool {
231 if pk.len() != mode.public_key_bytes() {
232 return false;
233 }
234 if sig.data.len() != mode.signature_bytes() {
235 return false;
236 }
237 sign::verify_hash(mode, &sig.data, msg, ctx, pk)
238 }
239
240 #[must_use]
242 pub fn public_key(&self) -> &[u8] {
243 &self.pubkey
244 }
245
246 #[must_use]
248 pub fn private_key(&self) -> &[u8] {
249 &self.privkey
250 }
251
252 #[must_use]
254 pub fn mode(&self) -> DilithiumMode {
255 self.mode
256 }
257
258 pub fn from_keys(
265 privkey: &[u8],
266 pubkey: &[u8],
267 mode: DilithiumMode,
268 ) -> Result<Self, DilithiumError> {
269 if privkey.len() != mode.secret_key_bytes() {
271 return Err(DilithiumError::FormatError);
272 }
273 if pubkey.len() != mode.public_key_bytes() {
274 return Err(DilithiumError::FormatError);
275 }
276
277 let sk_rho = &privkey[..SEEDBYTES];
281 let pk_rho = &pubkey[..SEEDBYTES];
282 if sk_rho != pk_rho {
283 return Err(DilithiumError::InvalidKey);
284 }
285
286 let tr_offset = 2 * SEEDBYTES;
288 let sk_tr = &privkey[tr_offset..tr_offset + TRBYTES];
289 let mut expected_tr = [0u8; TRBYTES];
290 shake256(&mut expected_tr, pubkey);
291 if sk_tr != &expected_tr[..] {
292 return Err(DilithiumError::InvalidKey);
293 }
294
295 Ok(DilithiumKeyPair {
296 privkey: Zeroizing::new(privkey.to_vec()),
297 pubkey: pubkey.to_vec(),
298 mode,
299 })
300 }
301
302 #[must_use]
309 pub fn to_bytes(&self) -> Vec<u8> {
310 let mut buf = Vec::with_capacity(1 + self.pubkey.len() + self.privkey.len());
311 buf.push(self.mode.mode_tag());
312 buf.extend_from_slice(&self.pubkey);
313 buf.extend_from_slice(&self.privkey);
314 buf
315 }
316
317 pub fn from_bytes(data: &[u8]) -> Result<Self, DilithiumError> {
319 if data.is_empty() {
320 return Err(DilithiumError::FormatError);
321 }
322 let mode = DilithiumMode::from_tag(data[0]).ok_or(DilithiumError::FormatError)?;
323 let pk_len = mode.public_key_bytes();
324 let sk_len = mode.secret_key_bytes();
325 if data.len() != 1 + pk_len + sk_len {
326 return Err(DilithiumError::FormatError);
327 }
328 let pk = &data[1..=pk_len];
329 let sk = &data[1 + pk_len..];
330 Self::from_keys(sk, pk, mode)
331 }
332
333 #[must_use]
335 pub fn public_key_bytes(&self) -> Vec<u8> {
336 let mut buf = Vec::with_capacity(1 + self.pubkey.len());
337 buf.push(self.mode.mode_tag());
338 buf.extend_from_slice(&self.pubkey);
339 buf
340 }
341
342 pub fn from_public_key(data: &[u8]) -> Result<(DilithiumMode, Vec<u8>), DilithiumError> {
344 if data.is_empty() {
345 return Err(DilithiumError::FormatError);
346 }
347 let mode = DilithiumMode::from_tag(data[0]).ok_or(DilithiumError::FormatError)?;
348 if data.len() != 1 + mode.public_key_bytes() {
349 return Err(DilithiumError::FormatError);
350 }
351 Ok((mode, data[1..].to_vec()))
352 }
353}
354
355impl DilithiumSignature {
356 #[must_use]
358 pub fn as_bytes(&self) -> &[u8] {
359 &self.data
360 }
361
362 #[must_use]
364 pub fn from_bytes(data: Vec<u8>) -> Self {
365 Self { data }
366 }
367
368 #[must_use]
370 pub fn from_slice(data: &[u8]) -> Self {
371 Self {
372 data: data.to_vec(),
373 }
374 }
375
376 #[must_use]
378 pub fn len(&self) -> usize {
379 self.data.len()
380 }
381
382 #[must_use]
384 pub fn is_empty(&self) -> bool {
385 self.data.is_empty()
386 }
387}
388
389fn getrandom(buf: &mut [u8]) -> Result<(), ()> {
391 ::getrandom::getrandom(buf).map_err(|_| ())
392}