rustolio_utils/crypto/
encryption.rs1use aes_gcm::{
12 aead::{Aead as _, Payload},
13 Aes256Gcm, KeyInit as _,
14};
15
16use crate::bytes::Bytes;
17use crate::prelude::*;
18
19use super::rand;
20
21pub type Result<T> = std::result::Result<T, Error>;
22
23#[derive(Debug, Clone, PartialEq)]
24pub enum Error {
25 Rand(rand::Error),
26 InvalidKey,
27 Encryption,
28 Decryption,
29}
30
31impl From<rand::Error> for Error {
32 fn from(value: rand::Error) -> Self {
33 Self::Rand(value)
34 }
35}
36
37impl std::fmt::Display for Error {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 write!(f, "{self:?}")
40 }
41}
42
43impl std::error::Error for Error {
44 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
45 match self {
46 Self::Rand(e) => Some(e),
47 _ => None,
48 }
49 }
50}
51
52#[derive(Debug, Clone, Encode, Decode)]
53#[repr(transparent)]
54pub struct Key([u8; 32]);
55
56impl Key {
57 pub fn generate() -> Result<Self> {
58 Ok(Self(rand::array()?))
59 }
60
61 pub fn to_bytes(&self) -> [u8; 32] {
62 self.0
63 }
64
65 pub fn from_bytes(bytes: &[u8]) -> Result<&Self> {
66 if bytes.len() != 32 {
67 return Err(Error::InvalidKey);
68 }
69 Ok(unsafe {
70 &*bytes.as_ptr().cast()
72 })
73 }
74
75 pub fn from_array(bytes: [u8; 32]) -> Self {
76 Self(bytes)
77 }
78}
79
80#[derive(Clone)]
81pub struct Cipher(Aes256Gcm);
82
83impl Cipher {
84 pub fn new(key: &Key) -> Self {
85 Self(Aes256Gcm::new(&key.0.into()))
86 }
87
88 pub fn encrypt(&self, msg: impl AsRef<[u8]>) -> Result<Encypted> {
89 self.encrypt_with_aad(b"", msg)
90 }
91
92 pub fn encrypt_with_aad(
93 &self,
94 aad: impl AsRef<[u8]>,
95 msg: impl AsRef<[u8]>,
96 ) -> Result<Encypted> {
97 let nonce = rand::array()?;
98
99 let payload = Payload {
100 msg: msg.as_ref(),
101 aad: aad.as_ref(),
102 };
103 let Ok(msg) = self.0.encrypt(&nonce.into(), payload) else {
104 return Err(Error::Encryption);
105 };
106
107 Ok(Encypted {
108 msg: Bytes::from(msg),
109 nonce,
110 })
111 }
112
113 pub fn decrypt(&self, encrypted: &Encypted) -> Result<Bytes> {
114 self.decrypt_with_aad(b"", encrypted)
115 }
116
117 pub fn decrypt_with_aad(&self, aad: impl AsRef<[u8]>, encrypted: &Encypted) -> Result<Bytes> {
118 let Encypted { msg, nonce } = encrypted;
119 let payload = Payload {
120 msg,
121 aad: aad.as_ref(),
122 };
123 let Ok(msg) = self.0.decrypt(nonce.into(), payload) else {
124 return Err(Error::Decryption);
125 };
126 Ok(Bytes::from_owner(msg))
127 }
128}
129
130#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)]
131pub struct Encypted {
132 nonce: [u8; 12],
133 msg: Bytes,
134}
135
136impl Encypted {
137 pub fn size(&self) -> usize {
138 12 + self.msg.len()
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    const MSG: &[u8] = b"Some message";

    #[test]
    fn test_encryption() {
        let key = Key::generate().unwrap();
        let cipher = Cipher::new(&key);

        let encrypted = cipher.encrypt(MSG).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();

        assert_eq!(*decrypted, *MSG);
    }

    #[test]
    fn test_encryption_encoding_decoding() {
        let key = Key::generate().unwrap();

        let encoded = key.to_bytes();
        let key = Key::from_bytes(&encoded).unwrap();

        let cipher = Cipher::new(key);

        let encrypted = cipher.encrypt(MSG).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();

        assert_eq!(*decrypted, *MSG);
    }

    #[test]
    fn test_from_bytes_rejects_wrong_length() {
        assert!(matches!(Key::from_bytes(&[0u8; 0]), Err(Error::InvalidKey)));
        assert!(matches!(Key::from_bytes(&[0u8; 31]), Err(Error::InvalidKey)));
        assert!(matches!(Key::from_bytes(&[0u8; 33]), Err(Error::InvalidKey)));
    }

    #[test]
    fn test_from_bytes_preserves_key_bytes() {
        // Exercises the unsafe reinterpretation in `Key::from_bytes`.
        let mut raw = [0u8; 32];
        for (i, b) in raw.iter_mut().enumerate() {
            *b = i as u8;
        }
        let key = Key::from_bytes(&raw).unwrap();
        assert_eq!(key.to_bytes(), raw);
    }

    #[test]
    fn test_encryption_with_aad() {
        let cipher = Cipher::new(&Key::generate().unwrap());

        let encrypted = cipher.encrypt_with_aad(b"header", MSG).unwrap();
        let decrypted = cipher.decrypt_with_aad(b"header", &encrypted).unwrap();
        assert_eq!(*decrypted, *MSG);

        assert!(matches!(
            cipher.decrypt_with_aad(b"other", &encrypted),
            Err(Error::Decryption)
        ));
        assert!(matches!(cipher.decrypt(&encrypted), Err(Error::Decryption)));
    }

    #[test]
    fn test_decryption_with_wrong_key_fails() {
        let encrypted = Cipher::new(&Key::generate().unwrap())
            .encrypt(MSG)
            .unwrap();
        let other = Cipher::new(&Key::generate().unwrap());

        assert!(matches!(other.decrypt(&encrypted), Err(Error::Decryption)));
    }

    #[test]
    fn test_tampered_ciphertext_is_rejected() {
        let cipher = Cipher::new(&Key::generate().unwrap());
        let encrypted = cipher.encrypt(MSG).unwrap();

        let mut raw: Vec<u8> = encrypted.msg.to_vec();
        raw[0] ^= 0x01;
        let tampered = Encypted {
            nonce: encrypted.nonce,
            msg: Bytes::from(raw),
        };

        assert!(matches!(cipher.decrypt(&tampered), Err(Error::Decryption)));
    }

    #[test]
    fn test_empty_message_round_trips() {
        let cipher = Cipher::new(&Key::generate().unwrap());

        let encrypted = cipher.encrypt(b"").unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();

        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_nonces_are_fresh_and_size_accounts_for_tag() {
        let cipher = Cipher::new(&Key::generate().unwrap());

        let a = cipher.encrypt(MSG).unwrap();
        let b = cipher.encrypt(MSG).unwrap();

        assert_ne!(a.nonce, b.nonce);
        // 12-byte nonce + plaintext length + 16-byte GCM tag.
        assert_eq!(a.size(), 12 + MSG.len() + 16);
    }
}