1use alloc::{vec, vec::Vec};
27use core::fmt;
28
29use zeroize::Zeroizing;
30#[cfg(feature = "getrandom")]
31use zeroize::Zeroize;
32
33pub use crate::params::DilithiumMode;
34use crate::params::*;
35use crate::sign;
36use crate::symmetric::shake256;
37
/// Errors reported by the key-pair and signing API in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DilithiumError {
    /// The random number generator could not supply bytes.
    RandomError,
    /// An encoded key or key pair had an unknown mode tag or the wrong length.
    FormatError,
    /// A signature was rejected (the `verify*` methods in this module report
    /// rejection as `false` rather than with this variant).
    BadSignature,
    /// An argument was out of range (e.g. an over-long context string) or the
    /// underlying signer reported failure.
    BadArgument,
    /// A secret key and public key were individually well-formed but not consistent
    /// with each other.
    InvalidKey,
}

impl fmt::Display for DilithiumError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed, human-readable text per variant; formatter width/fill flags are
        // deliberately not applied.
        let description = match *self {
            DilithiumError::RandomError => "random number generation failed",
            DilithiumError::FormatError => "invalid format",
            DilithiumError::BadSignature => "invalid signature",
            DilithiumError::BadArgument => "invalid argument",
            DilithiumError::InvalidKey => "key validation failed",
        };
        f.write_str(description)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for DilithiumError {}
68
69#[derive(Debug, Clone)]
75#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
76pub struct DilithiumKeyPair {
77 #[cfg_attr(
78 feature = "serde",
79 serde(
80 serialize_with = "serde_zeroizing::serialize",
81 deserialize_with = "serde_zeroizing::deserialize"
82 )
83 )]
84 privkey: Zeroizing<Vec<u8>>,
85 pubkey: Vec<u8>,
86 mode: DilithiumMode,
87}
88
89#[cfg(feature = "serde")]
91mod serde_zeroizing {
92 use super::*;
93 use serde::{Deserialize, Deserializer, Serialize, Serializer};
94
95 pub fn serialize<S: Serializer>(val: &Zeroizing<Vec<u8>>, s: S) -> Result<S::Ok, S::Error> {
96 let inner: &Vec<u8> = val;
98 inner.serialize(s)
99 }
100
101 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Zeroizing<Vec<u8>>, D::Error> {
102 let v = Vec::<u8>::deserialize(d)?;
103 Ok(Zeroizing::new(v))
104 }
105}
106
107pub type MlDsaKeyPair = DilithiumKeyPair;
109
/// A detached signature as produced by the `DilithiumKeyPair` signing methods.
///
/// The bytes are stored as given; their length is only checked against a mode
/// when the signature is verified.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DilithiumSignature {
    /// Raw encoded signature bytes.
    data: Vec<u8>,
}

/// ML-DSA name for [`DilithiumSignature`].
pub type MlDsaSignature = self::DilithiumSignature;
119
120impl DilithiumKeyPair {
121 #[cfg(feature = "getrandom")]
125 pub fn generate(mode: DilithiumMode) -> Result<Self, DilithiumError> {
126 let mut seed = [0u8; SEEDBYTES];
127 getrandom(&mut seed).map_err(|()| DilithiumError::RandomError)?;
128 let result = Self::generate_deterministic(mode, &seed);
129 seed.zeroize();
130 Ok(result)
131 }
132
133 #[must_use]
135 pub fn generate_deterministic(mode: DilithiumMode, seed: &[u8; SEEDBYTES]) -> Self {
136 let (pk, sk) = sign::keypair(mode, seed);
137 DilithiumKeyPair {
138 privkey: Zeroizing::new(sk),
139 pubkey: pk,
140 mode,
141 }
142 }
143
144 #[cfg(feature = "getrandom")]
149 pub fn sign(&self, msg: &[u8], ctx: &[u8]) -> Result<DilithiumSignature, DilithiumError> {
150 if ctx.len() > 255 {
151 return Err(DilithiumError::BadArgument);
152 }
153
154 let mut rnd = [0u8; RNDBYTES];
155 getrandom(&mut rnd).map_err(|()| DilithiumError::RandomError)?;
156
157 let mut sig = vec![0u8; self.mode.signature_bytes()];
158 let ret = sign::sign_signature(self.mode, &mut sig, msg, ctx, &rnd, &self.privkey);
159 rnd.zeroize();
160
161 if ret != 0 {
162 return Err(DilithiumError::BadArgument);
163 }
164
165 Ok(DilithiumSignature { data: sig })
166 }
167
168 #[cfg(feature = "getrandom")]
174 pub fn sign_prehash(
175 &self,
176 msg: &[u8],
177 ctx: &[u8],
178 ) -> Result<DilithiumSignature, DilithiumError> {
179 if ctx.len() > 255 {
180 return Err(DilithiumError::BadArgument);
181 }
182
183 let mut rnd = [0u8; RNDBYTES];
184 getrandom(&mut rnd).map_err(|()| DilithiumError::RandomError)?;
185
186 let mut sig = vec![0u8; self.mode.signature_bytes()];
187 let ret = sign::sign_hash(self.mode, &mut sig, msg, ctx, &rnd, &self.privkey);
188 rnd.zeroize();
189
190 if ret != 0 {
191 return Err(DilithiumError::BadArgument);
192 }
193
194 Ok(DilithiumSignature { data: sig })
195 }
196
197 pub fn sign_deterministic(
199 &self,
200 msg: &[u8],
201 ctx: &[u8],
202 rnd: &[u8; RNDBYTES],
203 ) -> Result<DilithiumSignature, DilithiumError> {
204 if ctx.len() > 255 {
205 return Err(DilithiumError::BadArgument);
206 }
207
208 let mut sig = vec![0u8; self.mode.signature_bytes()];
209 sign::sign_signature(self.mode, &mut sig, msg, ctx, rnd, &self.privkey);
210 Ok(DilithiumSignature { data: sig })
211 }
212
213 #[must_use]
215 pub fn verify(
216 pk: &[u8],
217 sig: &DilithiumSignature,
218 msg: &[u8],
219 ctx: &[u8],
220 mode: DilithiumMode,
221 ) -> bool {
222 if pk.len() != mode.public_key_bytes() {
223 return false;
224 }
225 if sig.data.len() != mode.signature_bytes() {
226 return false;
227 }
228 sign::verify(mode, &sig.data, msg, ctx, pk)
229 }
230
231 #[must_use]
233 pub fn verify_prehash(
234 pk: &[u8],
235 sig: &DilithiumSignature,
236 msg: &[u8],
237 ctx: &[u8],
238 mode: DilithiumMode,
239 ) -> bool {
240 if pk.len() != mode.public_key_bytes() {
241 return false;
242 }
243 if sig.data.len() != mode.signature_bytes() {
244 return false;
245 }
246 sign::verify_hash(mode, &sig.data, msg, ctx, pk)
247 }
248
249 #[must_use]
251 pub fn public_key(&self) -> &[u8] {
252 &self.pubkey
253 }
254
255 #[must_use]
257 pub fn private_key(&self) -> &[u8] {
258 &self.privkey
259 }
260
261 #[must_use]
263 pub fn mode(&self) -> DilithiumMode {
264 self.mode
265 }
266
267 pub fn from_keys(
274 privkey: &[u8],
275 pubkey: &[u8],
276 mode: DilithiumMode,
277 ) -> Result<Self, DilithiumError> {
278 if privkey.len() != mode.secret_key_bytes() {
280 return Err(DilithiumError::FormatError);
281 }
282 if pubkey.len() != mode.public_key_bytes() {
283 return Err(DilithiumError::FormatError);
284 }
285
286 let sk_rho = &privkey[..SEEDBYTES];
290 let pk_rho = &pubkey[..SEEDBYTES];
291 if sk_rho != pk_rho {
292 return Err(DilithiumError::InvalidKey);
293 }
294
295 let tr_offset = 2 * SEEDBYTES;
297 let sk_tr = &privkey[tr_offset..tr_offset + TRBYTES];
298 let mut expected_tr = [0u8; TRBYTES];
299 shake256(&mut expected_tr, pubkey);
300 if sk_tr != &expected_tr[..] {
301 return Err(DilithiumError::InvalidKey);
302 }
303
304 Ok(DilithiumKeyPair {
305 privkey: Zeroizing::new(privkey.to_vec()),
306 pubkey: pubkey.to_vec(),
307 mode,
308 })
309 }
310
311 #[must_use]
318 pub fn to_bytes(&self) -> Vec<u8> {
319 let mut buf = Vec::with_capacity(1 + self.pubkey.len() + self.privkey.len());
320 buf.push(self.mode.mode_tag());
321 buf.extend_from_slice(&self.pubkey);
322 buf.extend_from_slice(&self.privkey);
323 buf
324 }
325
326 pub fn from_bytes(data: &[u8]) -> Result<Self, DilithiumError> {
328 if data.is_empty() {
329 return Err(DilithiumError::FormatError);
330 }
331 let mode = DilithiumMode::from_tag(data[0]).ok_or(DilithiumError::FormatError)?;
332 let pk_len = mode.public_key_bytes();
333 let sk_len = mode.secret_key_bytes();
334 if data.len() != 1 + pk_len + sk_len {
335 return Err(DilithiumError::FormatError);
336 }
337 let pk = &data[1..=pk_len];
338 let sk = &data[1 + pk_len..];
339 Self::from_keys(sk, pk, mode)
340 }
341
342 #[must_use]
344 pub fn public_key_bytes(&self) -> Vec<u8> {
345 let mut buf = Vec::with_capacity(1 + self.pubkey.len());
346 buf.push(self.mode.mode_tag());
347 buf.extend_from_slice(&self.pubkey);
348 buf
349 }
350
351 pub fn from_public_key(data: &[u8]) -> Result<(DilithiumMode, Vec<u8>), DilithiumError> {
353 if data.is_empty() {
354 return Err(DilithiumError::FormatError);
355 }
356 let mode = DilithiumMode::from_tag(data[0]).ok_or(DilithiumError::FormatError)?;
357 if data.len() != 1 + mode.public_key_bytes() {
358 return Err(DilithiumError::FormatError);
359 }
360 Ok((mode, data[1..].to_vec()))
361 }
362}
363
364impl DilithiumSignature {
365 #[must_use]
367 pub fn as_bytes(&self) -> &[u8] {
368 &self.data
369 }
370
371 #[must_use]
373 pub fn from_bytes(data: Vec<u8>) -> Self {
374 Self { data }
375 }
376
377 #[must_use]
379 pub fn from_slice(data: &[u8]) -> Self {
380 Self {
381 data: data.to_vec(),
382 }
383 }
384
385 #[must_use]
387 pub fn len(&self) -> usize {
388 self.data.len()
389 }
390
391 #[must_use]
393 pub fn is_empty(&self) -> bool {
394 self.data.is_empty()
395 }
396}
397
/// Fills `buf` with random bytes via the `getrandom` crate, collapsing any
/// failure into `Err(())`.
#[cfg(feature = "getrandom")]
fn getrandom(buf: &mut [u8]) -> Result<(), ()> {
    match ::getrandom::getrandom(buf) {
        Ok(()) => Ok(()),
        Err(_) => Err(()),
    }
}