1use std::path::Path;
42use std::sync::Arc;
43
44use aes_gcm::aead::{Aead, KeyInit, Payload};
45use aes_gcm::{Aes256Gcm, Key, Nonce};
46use bytes::Bytes;
47use rand::RngCore;
48use thiserror::Error;
49
/// Four-byte magic that opens every SSE-framed body.
pub const SSE_MAGIC: &[u8; 4] = b"S4E1";

/// Algorithm tag stored in the header byte that follows the magic.
pub const ALGO_AES_256_GCM: u8 = 1;

/// Total header length in bytes: magic (4) + algo tag (1) + three zero bytes
/// + nonce (12) + GCM tag (16) = 36.
pub const SSE_HEADER_BYTES: usize = SSE_MAGIC.len() + 1 + 3 + NONCE_LEN + TAG_LEN;

/// Length of the AES-GCM nonce stored in the header.
const NONCE_LEN: usize = 12;

/// Length of the AES-GCM authentication tag stored in the header.
const TAG_LEN: usize = 16;

/// Length of an AES-256 key.
const KEY_LEN: usize = 32;
56
57#[derive(Debug, Error)]
58pub enum SseError {
59 #[error("SSE key file {path:?}: {source}")]
60 KeyFileIo {
61 path: std::path::PathBuf,
62 source: std::io::Error,
63 },
64 #[error(
65 "SSE key file must be exactly 32 raw bytes (or 64-char hex / 44-char base64); got {got} bytes after parse"
66 )]
67 BadKeyLength { got: usize },
68 #[error("SSE-encrypted body too short ({got} bytes; need at least {SSE_HEADER_BYTES})")]
69 TooShort { got: usize },
70 #[error("SSE bad magic: expected S4E1, got {got:?}")]
71 BadMagic { got: [u8; 4] },
72 #[error("SSE unsupported algo tag: {tag} (this build only knows AES-256-GCM = 1)")]
73 UnsupportedAlgo { tag: u8 },
74 #[error("SSE decryption / authentication failed (key mismatch or ciphertext tampered with)")]
75 DecryptFailed,
76}
77
78#[derive(Clone)]
81pub struct SseKey(Arc<[u8; KEY_LEN]>);
82
83impl SseKey {
84 pub fn from_path(path: &Path) -> Result<Self, SseError> {
88 let raw = std::fs::read(path).map_err(|source| SseError::KeyFileIo {
89 path: path.to_path_buf(),
90 source,
91 })?;
92 Self::from_bytes(&raw)
93 }
94
95 pub fn from_bytes(bytes: &[u8]) -> Result<Self, SseError> {
96 if bytes.len() == KEY_LEN {
98 let mut k = [0u8; KEY_LEN];
99 k.copy_from_slice(bytes);
100 return Ok(Self(Arc::new(k)));
101 }
102 let s = std::str::from_utf8(bytes).unwrap_or("").trim();
104 if s.len() == KEY_LEN * 2 && s.chars().all(|c| c.is_ascii_hexdigit()) {
105 let mut k = [0u8; KEY_LEN];
106 for (i, k_byte) in k.iter_mut().enumerate() {
107 *k_byte = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16)
108 .map_err(|_| SseError::BadKeyLength { got: bytes.len() })?;
109 }
110 return Ok(Self(Arc::new(k)));
111 }
112 if let Ok(decoded) =
113 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, s.as_bytes())
114 && decoded.len() == KEY_LEN
115 {
116 let mut k = [0u8; KEY_LEN];
117 k.copy_from_slice(&decoded);
118 return Ok(Self(Arc::new(k)));
119 }
120 Err(SseError::BadKeyLength { got: bytes.len() })
121 }
122
123 fn as_aes_key(&self) -> &Key<Aes256Gcm> {
124 Key::<Aes256Gcm>::from_slice(self.0.as_ref())
125 }
126}
127
128impl std::fmt::Debug for SseKey {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.debug_struct("SseKey")
131 .field("len", &KEY_LEN)
132 .field("key", &"<redacted>")
133 .finish()
134 }
135}
136
137pub fn encrypt(key: &SseKey, plaintext: &[u8]) -> Bytes {
140 let cipher = Aes256Gcm::new(key.as_aes_key());
141 let mut nonce_bytes = [0u8; NONCE_LEN];
142 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
143 let nonce = Nonce::from_slice(&nonce_bytes);
144 let mut aad = [0u8; 8];
148 aad[..4].copy_from_slice(SSE_MAGIC);
149 aad[4] = ALGO_AES_256_GCM;
150 let ct_with_tag = cipher
151 .encrypt(
152 nonce,
153 Payload {
154 msg: plaintext,
155 aad: &aad,
156 },
157 )
158 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
159 debug_assert!(ct_with_tag.len() >= TAG_LEN);
161 let split = ct_with_tag.len() - TAG_LEN;
162 let (ct, tag) = ct_with_tag.split_at(split);
163
164 let mut out = Vec::with_capacity(SSE_HEADER_BYTES + ct.len());
165 out.extend_from_slice(SSE_MAGIC);
166 out.push(ALGO_AES_256_GCM);
167 out.extend_from_slice(&[0u8; 3]); out.extend_from_slice(&nonce_bytes);
169 out.extend_from_slice(tag);
170 out.extend_from_slice(ct);
171 Bytes::from(out)
172}
173
174pub fn decrypt(key: &SseKey, body: &[u8]) -> Result<Bytes, SseError> {
179 if body.len() < SSE_HEADER_BYTES {
180 return Err(SseError::TooShort { got: body.len() });
181 }
182 let mut magic = [0u8; 4];
183 magic.copy_from_slice(&body[..4]);
184 if &magic != SSE_MAGIC {
185 return Err(SseError::BadMagic { got: magic });
186 }
187 let algo = body[4];
188 if algo != ALGO_AES_256_GCM {
189 return Err(SseError::UnsupportedAlgo { tag: algo });
190 }
191 let mut nonce_bytes = [0u8; NONCE_LEN];
193 nonce_bytes.copy_from_slice(&body[8..8 + NONCE_LEN]);
194 let mut tag_bytes = [0u8; TAG_LEN];
195 tag_bytes.copy_from_slice(&body[8 + NONCE_LEN..SSE_HEADER_BYTES]);
196 let ct = &body[SSE_HEADER_BYTES..];
197
198 let cipher = Aes256Gcm::new(key.as_aes_key());
199 let nonce = Nonce::from_slice(&nonce_bytes);
200 let mut aad = [0u8; 8];
201 aad[..4].copy_from_slice(SSE_MAGIC);
202 aad[4] = ALGO_AES_256_GCM;
203 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
204 ct_with_tag.extend_from_slice(ct);
205 ct_with_tag.extend_from_slice(&tag_bytes);
206 let plain = cipher
207 .decrypt(
208 nonce,
209 Payload {
210 msg: &ct_with_tag,
211 aad: &aad,
212 },
213 )
214 .map_err(|_| SseError::DecryptFailed)?;
215 Ok(Bytes::from(plain))
216}
217
218pub fn looks_encrypted(body: &[u8]) -> bool {
222 body.len() >= SSE_HEADER_BYTES && &body[..4] == SSE_MAGIC
223}
224
225pub type SharedSseKey = Arc<SseKey>;
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed key (every byte 0x07) for tests that just need any valid key.
    fn key32() -> SseKey {
        SseKey::from_bytes(&[7u8; 32]).unwrap()
    }

    #[test]
    fn roundtrip_basic() {
        let k = key32();
        let pt = b"the quick brown fox jumps over the lazy dog";
        let ct = encrypt(&k, pt);
        assert!(looks_encrypted(&ct));
        assert_eq!(&ct[..4], SSE_MAGIC);
        assert_eq!(ct[4], ALGO_AES_256_GCM);
        assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
        let pt2 = decrypt(&k, &ct).unwrap();
        assert_eq!(pt2.as_ref(), pt);
    }

    #[test]
    fn roundtrip_empty_plaintext() {
        let k = key32();
        let ct = encrypt(&k, b"");
        // An empty message is framed as a bare header.
        assert_eq!(ct.len(), SSE_HEADER_BYTES);
        assert!(looks_encrypted(&ct));
        assert!(decrypt(&k, &ct).unwrap().is_empty());
    }

    #[test]
    fn wrong_key_fails() {
        let k1 = SseKey::from_bytes(&[1u8; 32]).unwrap();
        let k2 = SseKey::from_bytes(&[2u8; 32]).unwrap();
        let ct = encrypt(&k1, b"secret");
        let err = decrypt(&k2, &ct).unwrap_err();
        assert!(matches!(err, SseError::DecryptFailed));
    }

    #[test]
    fn tampered_ciphertext_fails() {
        let k = key32();
        let mut ct = encrypt(&k, b"secret message").to_vec();
        let last = ct.len() - 1;
        ct[last] ^= 0x01;
        let err = decrypt(&k, &ct).unwrap_err();
        assert!(matches!(err, SseError::DecryptFailed));
    }

    #[test]
    fn tampered_nonce_fails() {
        let k = key32();
        let mut ct = encrypt(&k, b"secret message").to_vec();
        // The nonce starts right after the 8-byte magic/algo/reserved prefix.
        ct[8] ^= 0x01;
        let err = decrypt(&k, &ct).unwrap_err();
        assert!(matches!(err, SseError::DecryptFailed));
    }

    #[test]
    fn tampered_tag_fails() {
        let k = key32();
        let mut ct = encrypt(&k, b"secret message").to_vec();
        // The tag occupies the last TAG_LEN bytes of the header.
        ct[SSE_HEADER_BYTES - 1] ^= 0x01;
        let err = decrypt(&k, &ct).unwrap_err();
        assert!(matches!(err, SseError::DecryptFailed));
    }

    #[test]
    fn tampered_algo_byte_fails() {
        let k = key32();
        let mut ct = encrypt(&k, b"secret").to_vec();
        ct[4] = 99;
        let err = decrypt(&k, &ct).unwrap_err();
        assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
    }

    #[test]
    fn rejects_bad_magic() {
        let k = key32();
        let mut ct = encrypt(&k, b"secret").to_vec();
        ct[..4].copy_from_slice(b"S4F2");
        let err = decrypt(&k, &ct).unwrap_err();
        assert!(matches!(err, SseError::BadMagic { got } if &got == b"S4F2"));
    }

    #[test]
    fn rejects_short_body() {
        let k = key32();
        let err = decrypt(&k, b"short").unwrap_err();
        assert!(matches!(err, SseError::TooShort { got: 5 }));
    }

    #[test]
    fn looks_encrypted_passthrough_returns_false() {
        assert!(!looks_encrypted(b"S4F2\x01\x00\x00\x00........"));
        assert!(!looks_encrypted(b""));

        // Header-sized input with the wrong magic: the length check passes, so
        // this is the case that actually exercises the magic comparison.
        let mut wrong_magic = vec![0u8; SSE_HEADER_BYTES];
        wrong_magic[..4].copy_from_slice(b"S4F2");
        assert!(!looks_encrypted(&wrong_magic));

        // Correct magic but shorter than a header.
        assert!(!looks_encrypted(b"S4E1"));
    }

    #[test]
    fn key_from_hex_string() {
        // 62 hex digits: one byte short of a full key.
        let err =
            SseKey::from_bytes(b"0102030405060708090a0b0c0d0e0f10111213141516171819202122232425")
                .unwrap_err();
        assert!(matches!(err, SseError::BadKeyLength { .. }));

        let good = b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
        let parsed = SseKey::from_bytes(good).expect("64-char hex should parse");
        let mut expected = [0u8; 32];
        for chunk in expected.chunks_exact_mut(8) {
            chunk.copy_from_slice(&[0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef]);
        }
        assert_eq!(*parsed.0, expected);

        // Surrounding whitespace (e.g. a trailing newline in the key file) is ignored.
        let padded = b"  0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n";
        let parsed = SseKey::from_bytes(padded).expect("hex with whitespace should parse");
        assert_eq!(*parsed.0, expected);
    }

    #[test]
    fn key_from_base64_string() {
        // 43 'A's plus one '=' is the standard base64 encoding of 32 zero bytes.
        let encoded = format!("{}=\n", "A".repeat(43));
        let parsed = SseKey::from_bytes(encoded.as_bytes()).expect("44-char base64 should parse");
        assert_eq!(*parsed.0, [0u8; 32]);

        // Valid base64 that decodes to the wrong number of bytes is rejected.
        let err = SseKey::from_bytes(b"AAAA").unwrap_err();
        assert!(matches!(err, SseError::BadKeyLength { got: 4 }));
    }

    #[test]
    fn debug_output_redacts_key() {
        assert_eq!(
            format!("{:?}", key32()),
            "SseKey { len: 32, key: \"<redacted>\" }"
        );
    }

    #[test]
    fn encrypt_uses_random_nonce() {
        let k = key32();
        let pt = b"deterministic input";
        let a = encrypt(&k, pt);
        let b = encrypt(&k, pt);
        assert_ne!(a, b, "nonce must be random per-call");
    }
}