1#[cfg(test)]
5mod tests;
6
7use aes_gcm::{
8 aead::{consts::U12, generic_array::typenum::Unsigned, Tag},
9 AeadInPlace, Aes128Gcm, KeyInit, Nonce,
10};
11use hkdf::Hkdf;
12use sha2::Sha256;
13
/// Failures reported while encoding or decoding an `aes128gcm` message.
#[derive(Debug)]
pub enum Error {
    /// The message is too short to contain the fixed-size part of the header.
    HeaderLengthInvalid,
    /// The key id does not fit its one-byte length field, or the message is
    /// shorter than the key id length announced in the header.
    KeyIdLengthInvalid,
    /// A record has a length that is incompatible with the record size.
    RecordLengthInvalid,
    /// The padding delimiter of a decrypted record is missing or does not
    /// match the record's position in the message.
    PaddingInvalid,
    /// The AES-128-GCM operation (or growing the buffer it works on) failed.
    Aes128Gcm,
}

impl std::error::Error for Error {}

impl std::fmt::Display for Error {
    /// Writes the variant name, i.e. the same text the derived `Debug` emits.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant_name = match self {
            Error::HeaderLengthInvalid => "HeaderLengthInvalid",
            Error::KeyIdLengthInvalid => "KeyIdLengthInvalid",
            Error::RecordLengthInvalid => "RecordLengthInvalid",
            Error::PaddingInvalid => "PaddingInvalid",
            Error::Aes128Gcm => "Aes128Gcm",
        };
        f.write_str(variant_name)
    }
}
36
37fn derive_key<IKM: AsRef<[u8]>>(salt: [u8; 16], ikm: IKM) -> aes_gcm::Key<Aes128Gcm> {
38 let info = b"Content-Encoding: aes128gcm\0";
39 let mut okm = [0u8; 16];
40 let hk = Hkdf::<Sha256>::new(Some(&salt), ikm.as_ref());
41 hk.expand(info, &mut okm)
42 .expect("okm length is always 16, impossile for it to be too large");
43
44 aes_gcm::Key::<Aes128Gcm>::from(okm)
45}
46
47fn derive_nonce<IKM: AsRef<[u8]>>(salt: [u8; 16], ikm: IKM, seq: [u8; 12]) -> Nonce<U12> {
48 let info = b"Content-Encoding: nonce\0";
49 let mut okm = [0u8; 12];
50 let hk = Hkdf::<Sha256>::new(Some(salt.as_ref()), ikm.as_ref());
51 hk.expand(info, &mut okm)
52 .expect("okm length is always 12, impossile for it to be too large");
53
54 for i in 0..12 {
55 okm[i] ^= seq[i]
56 }
57
58 Nonce::from(okm)
59}
60
61fn generate_encryption_header<KI: AsRef<[u8]>>(
62 salt: [u8; 16],
63 record_size: u32,
64 keyid: KI,
65) -> Result<Vec<u8>, Error> {
66 let mut header = Vec::new();
67 header.extend_from_slice(&salt[..]);
68 header.extend_from_slice(&record_size.to_be_bytes());
69 let keyid = keyid.as_ref();
70 header.push(
71 keyid
72 .len()
73 .try_into()
74 .map_err(|_| Error::KeyIdLengthInvalid)?,
75 );
76 header.extend_from_slice(keyid);
77
78 Ok(header)
79}
80
81fn encrypt_record<B: aes_gcm::aead::Buffer>(
82 key: &aes_gcm::Key<Aes128Gcm>,
83 nonce: &Nonce<U12>,
84 mut record: B,
85 encrypted_record_size: u32,
86 is_last: bool,
87) -> Result<B, Error> {
88 let plain_record_size: u32 = record
89 .len()
90 .try_into()
91 .map_err(|_| Error::RecordLengthInvalid)?;
92
93 if plain_record_size >= encrypted_record_size - 16 {
94 return Err(Error::RecordLengthInvalid);
95 }
96
97 if is_last {
98 record
99 .extend_from_slice(b"\x02")
100 .map_err(|_| Error::Aes128Gcm)?;
101 } else {
102 let pad_len = encrypted_record_size - plain_record_size - 16;
103 record
104 .extend_from_slice(b"\x01")
105 .map_err(|_| Error::Aes128Gcm)?;
106 record
107 .extend_from_slice(
108 &b"\x00".repeat(
109 (pad_len - 1).try_into().expect(
110 "padding length is between 0 and 15 which will always fit into usize",
111 ),
112 ),
113 )
114 .map_err(|_| Error::Aes128Gcm)?;
115 }
116
117 Aes128Gcm::new(key)
118 .encrypt_in_place(nonce, b"", &mut record)
119 .map_err(|_| Error::Aes128Gcm)?;
120
121 Ok(record)
122}
123
124pub fn encrypt<IKM: AsRef<[u8]>, KI: AsRef<[u8]>, R: Iterator<Item = Vec<u8>>>(
126 ikm: IKM,
127 salt: [u8; 16],
128 keyid: KI,
129 records: R,
130 encrypted_record_size: u32,
131) -> Result<Vec<u8>, Error> {
132 let header = generate_encryption_header(salt, encrypted_record_size, keyid.as_ref())?;
133
134 let records = records.enumerate().map(|(n, record)| {
135 let mut seq = [0u8; 12];
136 seq[4..].copy_from_slice(&n.to_be_bytes());
137 let key = derive_key(salt, ikm.as_ref());
138 let nonce = derive_nonce(salt, ikm.as_ref(), seq);
139 (key, nonce, record)
140 });
141
142 let mut output = Vec::new();
143 output.extend_from_slice(&header);
144
145 let mut peekable = records.peekable();
146 while let Some((key, nonce, record)) = peekable.next() {
147 let is_last_record = peekable.peek().is_none();
148 let record = encrypt_record(&key, &nonce, record, encrypted_record_size, is_last_record)?;
149 output.extend_from_slice(&record);
150 }
151
152 Ok(output)
153}
154
155fn decrypt_record<'a>(
156 key: &aes_gcm::Key<Aes128Gcm>,
157 nonce: &Nonce<U12>,
158 record: &'a mut [u8],
159 is_last: bool,
160) -> Result<&'a [u8], Error> {
161 if record.len() < <Aes128Gcm as aes_gcm::AeadCore>::TagSize::to_usize() {
162 return Err(Error::RecordLengthInvalid);
163 }
164 let tag_pos = record.len() - <Aes128Gcm as aes_gcm::AeadCore>::TagSize::to_usize();
165 let (msg, tag) = record.as_mut().split_at_mut(tag_pos);
166
167 Aes128Gcm::new(key)
168 .decrypt_in_place_detached(nonce, b"", msg, Tag::<Aes128Gcm>::from_slice(tag))
169 .map_err(|_| Error::Aes128Gcm)?;
170
171 let pad_index = msg
172 .as_ref()
173 .iter()
174 .rposition(|it| *it != 0)
175 .ok_or(Error::PaddingInvalid)?;
176 match msg[pad_index] {
177 2 if !is_last => Err(Error::PaddingInvalid),
178 1 if is_last => Err(Error::PaddingInvalid),
179 _ => Ok(&msg[..pad_index]),
180 }
181}
182
183pub fn decrypt<IKM: AsRef<[u8]>>(
185 ikm: IKM,
186 mut encrypted_message: Vec<u8>,
187) -> Result<Vec<u8>, Error> {
188 if encrypted_message.len() < 21 {
189 return Err(Error::HeaderLengthInvalid);
190 }
191
192 let (header, keyid_and_records) = encrypted_message.split_at_mut(21);
193 let salt = header[..16].try_into().expect(
194 "casting a slice of fixed length to an array of the same length will always succeed",
195 );
196 let encrypted_record_size = u32::from_be_bytes(header[16..16 + 4].try_into().expect(
197 "casting a slice of fixed length to an array of the same length will always succeed",
198 ));
199 let idlen = header[20].into();
200
201 if keyid_and_records.len() < idlen {
202 return Err(Error::KeyIdLengthInvalid);
203 }
204
205 let (_, records) = keyid_and_records.split_at_mut(idlen);
206 let all_records_len = records.len();
207 let records = records
208 .chunks_mut(
209 encrypted_record_size
210 .try_into()
211 .map_err(|_| Error::RecordLengthInvalid)?,
212 )
213 .enumerate()
214 .map(|(n, record)| {
215 let mut seq = [0u8; 12];
216 seq[4..].copy_from_slice(&n.to_be_bytes());
217 let key = derive_key(salt, ikm.as_ref());
218 let nonce = derive_nonce(salt, ikm.as_ref(), seq);
219 (key, nonce, record)
220 });
221
222 let mut output = Vec::with_capacity(all_records_len);
223
224 let mut peekable = records.peekable();
225 while let Some((key, nonce, record)) = peekable.next() {
226 let is_last_record = peekable.peek().is_none();
227 let plaintext = decrypt_record(&key, &nonce, record, is_last_record)?;
228 output.extend_from_slice(plaintext)
229 }
230
231 Ok(output)
232}