1use super::encrypt_digest_into;
2#[cfg(not(feature = "alloc"))]
3use super::Label;
4use crate::traits::{modular::ModulusParams, PublicKeyParts, UnsignedModularInt};
5use crate::{traits::RandomizedEncryptor, GenericRsaPublicKey, Result};
6#[cfg(feature = "alloc")]
7use alloc::{boxed::Box, vec::Vec};
8use core::marker::PhantomData;
9#[cfg(feature = "alloc")]
10use crypto_bigint::{modular::BoxedMontyParams, BoxedUint};
11use digest::{Digest, FixedOutputReset};
12use rand_core::{CryptoRng, TryCryptoRng};
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone)]
20#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
21#[cfg_attr(
22 feature = "serde",
23 serde(bound(
24 serialize = "GenericRsaPublicKey<T, M>: Serialize",
25 deserialize = "GenericRsaPublicKey<T, M>: serde::de::DeserializeOwned"
26 ))
27)]
28pub struct GenericEncryptingKey<D, MGD, T, M>
29where
30 T: UnsignedModularInt,
31 M: ModulusParams<Modulus = T>,
32{
33 inner: GenericRsaPublicKey<T, M>,
34 #[cfg(feature = "alloc")]
35 label: Option<Box<[u8]>>,
36 #[cfg(not(feature = "alloc"))]
37 label: Option<Label>,
38 phantom: PhantomData<D>,
39 mg_phantom: PhantomData<MGD>,
40}
41
/// [`GenericEncryptingKey`] specialised to heap-backed big integers
/// ([`BoxedUint`] with [`BoxedMontyParams`]).
///
/// `MGD` defaults to `D` when it is not spelled out.
#[cfg(feature = "alloc")]
pub type EncryptingKey<D, MGD = D> =
    GenericEncryptingKey<D, MGD, BoxedUint, BoxedMontyParams>;
45
46impl<D, MGD, T, M> GenericEncryptingKey<D, MGD, T, M>
47where
48 T: UnsignedModularInt,
49 M: ModulusParams<Modulus = T>,
50{
51 pub fn new(key: GenericRsaPublicKey<T, M>) -> Self {
53 Self {
54 inner: key,
55 label: None,
56 phantom: Default::default(),
57 mg_phantom: Default::default(),
58 }
59 }
60
61 #[cfg(feature = "alloc")]
63 pub fn new_with_label<S: Into<Box<[u8]>>>(key: GenericRsaPublicKey<T, M>, label: S) -> Self {
64 Self {
65 inner: key,
66 label: Some(label.into()),
67 phantom: Default::default(),
68 mg_phantom: Default::default(),
69 }
70 }
71
72 #[cfg(not(feature = "alloc"))]
74 pub fn new_with_label(key: GenericRsaPublicKey<T, M>, label: Label) -> Self {
75 Self {
76 inner: key,
77 label: Some(label),
78 phantom: Default::default(),
79 mg_phantom: Default::default(),
80 }
81 }
82}
83
84impl<D, MGD, T, M> RandomizedEncryptor for GenericEncryptingKey<D, MGD, T, M>
85where
86 D: Digest,
87 MGD: Digest + FixedOutputReset,
88 T: UnsignedModularInt,
89 M: ModulusParams<Modulus = T>,
90{
91 fn encrypt_with_rng_into<'a, R: TryCryptoRng + ?Sized>(
92 &self,
93 rng: &mut R,
94 msg: &[u8],
95 storage: &'a mut [u8],
96 ) -> Result<&'a [u8]> {
97 let label = self.label.as_deref();
98 encrypt_digest_into::<_, D, MGD, _, T>(rng, &self.inner, msg, label, storage)
99 }
100
101 #[cfg(feature = "alloc")]
102 fn encrypt_with_rng<R: CryptoRng + ?Sized>(&self, rng: &mut R, msg: &[u8]) -> Result<Vec<u8>> {
103 let mut storage = vec![0u8; self.inner.size()];
104 let ciphertext = self.encrypt_with_rng_into(rng, msg, &mut storage)?;
105 Ok(ciphertext.to_vec())
106 }
107}
108
/// Two encrypting keys are equal when both their public keys and their labels
/// are equal; the digest markers carry no data and are not compared.
#[cfg(feature = "alloc")]
impl<D, MGD, T, M> PartialEq for GenericEncryptingKey<D, MGD, T, M>
where
    T: UnsignedModularInt,
    M: ModulusParams<Modulus = T>,
    GenericRsaPublicKey<T, M>: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        let Self {
            inner: lhs_key,
            label: lhs_label,
            ..
        } = self;
        let Self {
            inner: rhs_key,
            label: rhs_label,
            ..
        } = other;

        // Keys first, labels second; the label check is skipped when keys differ.
        lhs_key == rhs_key && lhs_label == rhs_label
    }
}
120
#[cfg(test)]
mod tests {

    /// Pins the human-readable serde token stream of a label-less
    /// `EncryptingKey` built from a deterministically seeded key.
    #[test]
    #[cfg(all(feature = "hazmat", feature = "serde", feature = "private-key"))]
    fn test_serde() {
        use super::*;
        use rand::rngs::ChaCha8Rng;
        use rand_core::SeedableRng;
        use serde_test::{assert_tokens, Configure, Token};

        // A fixed seed keeps the generated key, and so the expected encoding, stable.
        let mut seeded_rng = ChaCha8Rng::from_seed([42; 32]);
        let private_key = crate::RsaPrivateKey::new_unchecked(&mut seeded_rng, 64)
            .expect("failed to generate key");
        let public_key = private_key.to_public_key();
        let encrypting_key = EncryptingKey::<sha2::Sha256>::new(public_key);

        let expected = [
            Token::Struct {
                name: "GenericEncryptingKey",
                len: 4,
            },
            // Public key, as a hex string in readable mode.
            Token::Str("inner"),
            Token::Str(
                "3024300d06092a864886f70d01010105000313003010020900ab240c3361d02e370203010001",
            ),
            // No label was supplied.
            Token::Str("label"),
            Token::None,
            // Both digest markers serialize as unit structs.
            Token::Str("phantom"),
            Token::UnitStruct {
                name: "PhantomData",
            },
            Token::Str("mg_phantom"),
            Token::UnitStruct {
                name: "PhantomData",
            },
            Token::StructEnd,
        ];

        assert_tokens(&encrypting_key.readable(), &expected);
    }
}