Skip to main content

s2_common/
encryption.rs

1//! Encryption spec parsing, header parsing, and key parsing.
2
3use core::str::FromStr;
4use std::sync::Arc;
5
6use base64ct::Encoding;
7use http::{HeaderName, HeaderValue};
8use secrecy::{ExposeSecret, SecretBox, zeroize::Zeroizing};
9use strum::{Display, EnumString};
10
11use crate::http::ParseableHeader;
12
13pub static S2_ENCRYPTION_HEADER: HeaderName = HeaderName::from_static("s2-encryption");
14
15type EncryptionKey<const N: usize> = Arc<SecretBox<[u8; N]>>;
16
17/// Encryption algorithm.
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Display, EnumString)]
19#[strum(ascii_case_insensitive)]
20pub enum EncryptionAlgorithm {
21    /// AEGIS-256
22    #[strum(serialize = "aegis-256")]
23    Aegis256,
24    /// AES-256-GCM
25    #[strum(serialize = "aes-256-gcm")]
26    Aes256Gcm,
27}
28
29/// Encryption mode, including plaintext.
30#[derive(
31    Debug,
32    Clone,
33    Copy,
34    PartialEq,
35    Eq,
36    Hash,
37    serde::Serialize,
38    serde::Deserialize,
39    Display,
40    EnumString,
41    enumset::EnumSetType,
42)]
43#[strum(ascii_case_insensitive)]
44#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
45#[enumset(no_super_impls)]
46pub enum EncryptionMode {
47    #[strum(serialize = "plain")]
48    #[serde(rename = "plain")]
49    Plain,
50    #[strum(serialize = "aegis-256")]
51    #[serde(rename = "aegis-256")]
52    #[cfg_attr(feature = "clap", value(name = "aegis-256"))]
53    Aegis256,
54    #[strum(serialize = "aes-256-gcm")]
55    #[serde(rename = "aes-256-gcm")]
56    #[cfg_attr(feature = "clap", value(name = "aes-256-gcm"))]
57    Aes256Gcm,
58}
59
60impl From<EncryptionAlgorithm> for EncryptionMode {
61    fn from(value: EncryptionAlgorithm) -> Self {
62        match value {
63            EncryptionAlgorithm::Aegis256 => Self::Aegis256,
64            EncryptionAlgorithm::Aes256Gcm => Self::Aes256Gcm,
65        }
66    }
67}
68
69#[derive(Debug, Clone)]
70pub struct Aegis256Key(EncryptionKey<32>);
71
72impl Aegis256Key {
73    pub fn new(key: [u8; 32]) -> Self {
74        Self(Arc::new(SecretBox::new(Box::new(key))))
75    }
76
77    pub fn from_base64(key_b64: &str) -> Result<Self, EncryptionSpecError> {
78        parse_encryption_key::<32>(key_b64).map(Self)
79    }
80
81    pub(crate) fn secret(&self) -> &[u8; 32] {
82        self.0.as_ref().expose_secret()
83    }
84}
85
86#[derive(Debug, Clone)]
87pub struct Aes256GcmKey(EncryptionKey<32>);
88
89impl Aes256GcmKey {
90    pub fn new(key: [u8; 32]) -> Self {
91        Self(Arc::new(SecretBox::new(Box::new(key))))
92    }
93
94    pub fn from_base64(key_b64: &str) -> Result<Self, EncryptionSpecError> {
95        parse_encryption_key::<32>(key_b64).map(Self)
96    }
97
98    pub(crate) fn secret(&self) -> &[u8; 32] {
99        self.0.as_ref().expose_secret()
100    }
101}
102
103#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
104pub enum EncryptionSpecError {
105    #[error("Invalid encryption spec: expected '<mode>; <key>' or 'plain'")]
106    InvalidSyntax,
107    #[error("Invalid encryption spec: missing encryption mode")]
108    MissingMode,
109    #[error(
110        "Invalid encryption spec: unknown encryption mode {mode:?}; expected 'plain', 'aegis-256', or 'aes-256-gcm'"
111    )]
112    UnknownMode { mode: String },
113    #[error("Invalid encryption spec: key is not allowed when mode is 'plain'")]
114    UnexpectedKeyForPlain,
115    #[error("Invalid encryption spec: missing key for '{mode}'")]
116    MissingKey { mode: EncryptionMode },
117    #[error("Invalid encryption spec: key is not valid base64")]
118    InvalidKeyBase64,
119    #[error("Invalid encryption spec: key must be exactly {expected} bytes, got {actual} bytes")]
120    InvalidKeyLength { expected: usize, actual: usize },
121}
122
123#[derive(Debug, Clone, Default)]
124pub enum EncryptionSpec {
125    #[default]
126    Plain,
127    Aegis256(Aegis256Key),
128    Aes256Gcm(Aes256GcmKey),
129}
130
131impl EncryptionSpec {
132    pub fn aegis256(key: [u8; 32]) -> Self {
133        Self::Aegis256(Aegis256Key::new(key))
134    }
135
136    pub fn aes256_gcm(key: [u8; 32]) -> Self {
137        Self::Aes256Gcm(Aes256GcmKey::new(key))
138    }
139
140    pub fn mode(&self) -> EncryptionMode {
141        match self {
142            Self::Plain => EncryptionMode::Plain,
143            Self::Aegis256(_) => EncryptionMode::Aegis256,
144            Self::Aes256Gcm(_) => EncryptionMode::Aes256Gcm,
145        }
146    }
147
148    pub fn to_header_value(&self) -> HeaderValue {
149        let mut value = match self {
150            Self::Plain => HeaderValue::from_static("plain"),
151            Self::Aegis256(key) => {
152                header_value_for_key(EncryptionAlgorithm::Aegis256, key.secret())
153            }
154            Self::Aes256Gcm(key) => {
155                header_value_for_key(EncryptionAlgorithm::Aes256Gcm, key.secret())
156            }
157        };
158        value.set_sensitive(true);
159        value
160    }
161}
162
163impl FromStr for EncryptionSpec {
164    type Err = EncryptionSpecError;
165
166    fn from_str(s: &str) -> Result<Self, Self::Err> {
167        let s = s.trim();
168        let mut parts = s.splitn(3, ';');
169        let mode_str = parts.next().unwrap_or_default().trim();
170        let key_b64 = parts.next().map(str::trim);
171        if parts.next().is_some() {
172            return Err(EncryptionSpecError::InvalidSyntax);
173        }
174
175        if mode_str.is_empty() {
176            return Err(EncryptionSpecError::MissingMode);
177        }
178
179        let key_b64 = key_b64.filter(|key| !key.is_empty());
180        match (parse_mode(mode_str)?, key_b64) {
181            (EncryptionMode::Plain, None) => Ok(Self::Plain),
182            (EncryptionMode::Plain, Some(_)) => Err(EncryptionSpecError::UnexpectedKeyForPlain),
183            (EncryptionMode::Aegis256, Some(key_b64)) => {
184                Ok(Self::Aegis256(Aegis256Key::from_base64(key_b64)?))
185            }
186            (EncryptionMode::Aegis256, None) => Err(EncryptionSpecError::MissingKey {
187                mode: EncryptionMode::Aegis256,
188            }),
189            (EncryptionMode::Aes256Gcm, Some(key_b64)) => {
190                Ok(Self::Aes256Gcm(Aes256GcmKey::from_base64(key_b64)?))
191            }
192            (EncryptionMode::Aes256Gcm, None) => Err(EncryptionSpecError::MissingKey {
193                mode: EncryptionMode::Aes256Gcm,
194            }),
195        }
196    }
197}
198
199impl ParseableHeader for EncryptionSpec {
200    fn name() -> &'static HeaderName {
201        &S2_ENCRYPTION_HEADER
202    }
203}
204
205fn parse_encryption_key<const N: usize>(
206    key_b64: &str,
207) -> Result<EncryptionKey<N>, EncryptionSpecError> {
208    use base64ct::{Base64, Encoding};
209    use secrecy::zeroize::Zeroize;
210
211    let mut key = Box::new([0u8; N]);
212    let decoded = match Base64::decode(key_b64, key.as_mut()) {
213        Ok(decoded) => decoded,
214        Err(_) => {
215            key.as_mut().zeroize();
216            return Err(EncryptionSpecError::InvalidKeyBase64);
217        }
218    };
219
220    if decoded.len() != N {
221        let len = decoded.len();
222        key.as_mut().zeroize();
223        return Err(EncryptionSpecError::InvalidKeyLength {
224            expected: N,
225            actual: len,
226        });
227    }
228
229    Ok(Arc::new(SecretBox::new(key)))
230}
231
232fn header_value_for_key(algorithm: EncryptionAlgorithm, key: &[u8; 32]) -> HeaderValue {
233    let algorithm = algorithm.to_string();
234    let encoded_len = base64ct::Base64::encoded_len(key);
235    let mut value = Zeroizing::new(vec![0u8; algorithm.len() + 2 + encoded_len]);
236    value[..algorithm.len()].copy_from_slice(algorithm.as_bytes());
237    value[algorithm.len()..algorithm.len() + 2].copy_from_slice(b"; ");
238    base64ct::Base64::encode(key, &mut value[algorithm.len() + 2..])
239        .expect("base64 output length should match buffer");
240
241    HeaderValue::from_bytes(&value).expect("encryption header value should be ASCII")
242}
243
244fn parse_mode(mode_str: &str) -> Result<EncryptionMode, EncryptionSpecError> {
245    mode_str
246        .parse::<EncryptionMode>()
247        .map_err(|_| EncryptionSpecError::UnknownMode {
248            mode: mode_str.to_owned(),
249        })
250}
251
#[cfg(test)]
mod tests {
    use http::header::HeaderValue;
    use rstest::rstest;

    use super::*;

    /// Base64 text used as the key fixture; the tests expect it to decode to `KEY_BYTES`.
    const KEY_B64: &str = "AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=";
    /// Key fixture: the bytes 1 through 32 in order.
    const KEY_BYTES: [u8; 32] = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
        26, 27, 28, 29, 30, 31, 32,
    ];

    /// Asserts that `spec` is the encrypted variant for `algorithm` and holds `expected`.
    fn assert_encrypted_spec(
        spec: EncryptionSpec,
        algorithm: EncryptionAlgorithm,
        expected: &[u8; 32],
    ) {
        let actual: &[u8; 32] = match (&spec, algorithm) {
            (EncryptionSpec::Aegis256(key), EncryptionAlgorithm::Aegis256) => key.secret(),
            (EncryptionSpec::Aes256Gcm(key), EncryptionAlgorithm::Aes256Gcm) => key.secret(),
            (EncryptionSpec::Plain, _) => panic!("expected encrypted spec"),
            (actual_spec, expected_algorithm) => {
                panic!("expected {expected_algorithm:?}, got {actual_spec:?}")
            }
        };
        assert_eq!(actual, expected);
    }

    /// Asserts that parsing `header` fails with exactly `expected`.
    fn assert_invalid_parse(header: &str, expected: EncryptionSpecError) {
        match header.parse::<EncryptionSpec>() {
            Ok(spec) => panic!("expected invalid spec for {header:?}, got {spec:?}"),
            Err(error) => assert_eq!(error, expected),
        }
    }

    #[rstest]
    #[case("aegis-256", EncryptionAlgorithm::Aegis256)]
    #[case("aes-256-gcm", EncryptionAlgorithm::Aes256Gcm)]
    #[case("AEGIS-256", EncryptionAlgorithm::Aegis256)]
    #[case("AES-256-GCM", EncryptionAlgorithm::Aes256Gcm)]
    fn parse_header_valid_encrypted(
        #[case] algorithm: &str,
        #[case] expected: EncryptionAlgorithm,
    ) {
        let header = format!("{algorithm}; {KEY_B64}");
        let spec: EncryptionSpec = header.parse().unwrap();
        assert_encrypted_spec(spec, expected, &KEY_BYTES);
    }

    #[test]
    fn parse_header_aes_with_whitespace() {
        let header = format!(" aes-256-gcm ; {KEY_B64} ");
        let spec: EncryptionSpec = header.parse().unwrap();
        assert_encrypted_spec(spec, EncryptionAlgorithm::Aes256Gcm, &KEY_BYTES);
    }

    #[rstest]
    #[case("plain")]
    #[case("PLAIN")]
    #[case("plain; ")]
    fn parse_header_plain_variants(#[case] header: &str) {
        let spec: EncryptionSpec = header.parse().unwrap();
        assert!(matches!(spec, EncryptionSpec::Plain));
    }

    #[test]
    fn spec_mode_matches_variant() {
        let plain = EncryptionSpec::Plain;
        let aegis = EncryptionSpec::aegis256(KEY_BYTES);
        let aes = EncryptionSpec::aes256_gcm(KEY_BYTES);

        assert_eq!(plain.mode(), EncryptionMode::Plain);
        assert_eq!(aegis.mode(), EncryptionMode::Aegis256);
        assert_eq!(aes.mode(), EncryptionMode::Aes256Gcm);
    }

    #[rstest]
    #[case(EncryptionMode::Plain, "\"plain\"")]
    #[case(EncryptionMode::Aegis256, "\"aegis-256\"")]
    #[case(EncryptionMode::Aes256Gcm, "\"aes-256-gcm\"")]
    fn mode_serde_roundtrip(#[case] mode: EncryptionMode, #[case] expected: &str) {
        let json = serde_json::to_string(&mode).unwrap();
        assert_eq!(json, expected);

        let restored: EncryptionMode = serde_json::from_str(expected).unwrap();
        assert_eq!(restored, mode);
    }

    #[rstest]
    #[case(EncryptionMode::Plain, "plain")]
    #[case(EncryptionMode::Aegis256, "aegis-256")]
    #[case(EncryptionMode::Aes256Gcm, "aes-256-gcm")]
    fn mode_display_matches_spec(#[case] mode: EncryptionMode, #[case] expected: &str) {
        let rendered = mode.to_string();
        assert_eq!(rendered, expected);
    }

    #[rstest]
    #[case("", EncryptionSpecError::MissingMode)]
    #[case(
        "aegis-256",
        EncryptionSpecError::MissingKey {
            mode: EncryptionMode::Aegis256
        }
    )]
    #[case(
        "aegis-256; AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=; extra",
        EncryptionSpecError::InvalidSyntax
    )]
    #[case(
        "aegis-256; 3q2+7w==",
        EncryptionSpecError::InvalidKeyLength {
            expected: 32,
            actual: 4
        }
    )]
    #[case(
        "aegis-256; not-valid-base64!!!",
        EncryptionSpecError::InvalidKeyBase64
    )]
    #[case(
        "bogus; AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=",
        EncryptionSpecError::UnknownMode {
            mode: "bogus".to_owned()
        }
    )]
    #[case(
        "plain; AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=",
        EncryptionSpecError::UnexpectedKeyForPlain
    )]
    fn parse_header_invalid_cases(#[case] header: &str, #[case] expected: EncryptionSpecError) {
        assert_invalid_parse(header, expected);
    }

    #[test]
    fn header_value_is_sensitive() {
        let spec = EncryptionSpec::aegis256([7; 32]);
        let value = spec.to_header_value();

        assert!(value.is_sensitive());
        assert_ne!(value, HeaderValue::from_static("plain"));
    }

    #[test]
    fn plain_header_value_roundtrips() {
        let value = EncryptionSpec::Plain.to_header_value();
        let text = value.to_str().unwrap();
        assert_eq!(text, "plain");
        assert!(value.is_sensitive());

        let parsed: EncryptionSpec = text.parse().unwrap();
        assert!(matches!(parsed, EncryptionSpec::Plain));
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn encrypted_header_value_roundtrips(#[case] algorithm: EncryptionAlgorithm) {
        let spec = match algorithm {
            EncryptionAlgorithm::Aegis256 => EncryptionSpec::aegis256(KEY_BYTES),
            EncryptionAlgorithm::Aes256Gcm => EncryptionSpec::aes256_gcm(KEY_BYTES),
        };
        let value = spec.to_header_value();
        let text = value.to_str().unwrap();
        assert_eq!(text, format!("{algorithm}; {KEY_B64}"));
        assert!(value.is_sensitive());

        let parsed: EncryptionSpec = text.parse().unwrap();
        assert_encrypted_spec(parsed, algorithm, &KEY_BYTES);
    }
}