Skip to main content

rustolio_utils/crypto/
encryption.rs

1//
2// SPDX-License-Identifier: MPL-2.0
3//
4// Copyright (c) 2026 Tobias Binnewies. All rights reserved.
5//
6// This Source Code Form is subject to the terms of the Mozilla Public
7// License, v. 2.0. If a copy of the MPL was not distributed with this
8// file, You can obtain one at http://mozilla.org/MPL/2.0/.
9//
10
11use aes_gcm::{
12    aead::{Aead as _, Payload},
13    Aes256Gcm, KeyInit as _,
14};
15
16use crate::bytes::Bytes;
17use crate::prelude::*;
18
19use super::rand;
20
21pub type Result<T> = std::result::Result<T, Error>;
22
23#[derive(Debug, Clone, PartialEq)]
24pub enum Error {
25    Rand(rand::Error),
26    InvalidKey,
27    Encryption,
28    Decryption,
29}
30
31impl From<rand::Error> for Error {
32    fn from(value: rand::Error) -> Self {
33        Self::Rand(value)
34    }
35}
36
37impl std::fmt::Display for Error {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        write!(f, "{self:?}")
40    }
41}
42
43impl std::error::Error for Error {
44    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
45        match self {
46            Self::Rand(e) => Some(e),
47            _ => None,
48        }
49    }
50}
51
52#[derive(Debug, Clone, Encode, Decode)]
53#[repr(transparent)]
54pub struct Key([u8; 32]);
55
56impl Key {
57    pub fn generate() -> Result<Self> {
58        Ok(Self(rand::array()?))
59    }
60
61    pub fn to_bytes(&self) -> [u8; 32] {
62        self.0
63    }
64
65    pub fn from_bytes(bytes: &[u8]) -> Result<&Self> {
66        if bytes.len() != 32 {
67            return Err(Error::InvalidKey);
68        }
69        Ok(unsafe {
70            // SAFETY: #[repr(transparent)] & length checked
71            &*bytes.as_ptr().cast()
72        })
73    }
74
75    pub fn from_array(bytes: [u8; 32]) -> Self {
76        Self(bytes)
77    }
78}
79
80#[derive(Clone)]
81pub struct Cipher(Aes256Gcm);
82
83impl Cipher {
84    pub fn new(key: &Key) -> Self {
85        Self(Aes256Gcm::new(&key.0.into()))
86    }
87
88    pub fn encrypt(&self, msg: impl AsRef<[u8]>) -> Result<Encypted> {
89        self.encrypt_with_aad(b"", msg)
90    }
91
92    pub fn encrypt_with_aad(
93        &self,
94        aad: impl AsRef<[u8]>,
95        msg: impl AsRef<[u8]>,
96    ) -> Result<Encypted> {
97        let nonce = rand::array()?;
98
99        let payload = Payload {
100            msg: msg.as_ref(),
101            aad: aad.as_ref(),
102        };
103        let Ok(msg) = self.0.encrypt(&nonce.into(), payload) else {
104            return Err(Error::Encryption);
105        };
106
107        Ok(Encypted {
108            msg: Bytes::from(msg),
109            nonce,
110        })
111    }
112
113    pub fn decrypt(&self, encrypted: &Encypted) -> Result<Bytes> {
114        self.decrypt_with_aad(b"", encrypted)
115    }
116
117    pub fn decrypt_with_aad(&self, aad: impl AsRef<[u8]>, encrypted: &Encypted) -> Result<Bytes> {
118        let Encypted { msg, nonce } = encrypted;
119        let payload = Payload {
120            msg,
121            aad: aad.as_ref(),
122        };
123        let Ok(msg) = self.0.decrypt(nonce.into(), payload) else {
124            return Err(Error::Decryption);
125        };
126        Ok(Bytes::from_owner(msg))
127    }
128}
129
130#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)]
131pub struct Encypted {
132    nonce: [u8; 12],
133    msg: Bytes,
134}
135
136impl Encypted {
137    pub fn size(&self) -> usize {
138        12 + self.msg.len()
139    }
140}
141
#[cfg(test)]
mod tests {

    use super::*;

    /// Encrypts a fixed message with a cipher built from `key` and checks that
    /// decrypting the result yields the original bytes.
    fn assert_roundtrip(key: &Key) {
        let msg = b"Some message";
        let cipher = Cipher::new(key);

        let ciphertext = cipher.encrypt(&msg).unwrap();
        let plaintext = cipher.decrypt(&ciphertext).unwrap();

        assert_eq!(*plaintext, *msg);
    }

    /// Round trip with a freshly generated key.
    #[test]
    fn test_encryption() {
        let key = Key::generate().unwrap();
        assert_roundtrip(&key);
    }

    /// Round trip with a key that was exported to bytes and re-imported.
    #[test]
    fn test_encryption_encoding_decoding() {
        let raw = Key::generate().unwrap().to_bytes();
        let key = Key::from_bytes(&raw).unwrap();
        assert_roundtrip(key);
    }
}