// mountpoint_s3_fs/fs/sse.rs

1use mountpoint_s3_client::checksums::crc32c::{self, Crc32c};
2use thiserror::Error;
3
4/// Server-side encryption configuration for newly created objects
5#[derive(Debug, Clone)]
6pub struct ServerSideEncryption {
7    sse_type: Option<String>,
8    sse_kms_key_id: Option<String>,
9    checksum: Crc32c,
10}
11
12#[derive(Debug, Error)]
13pub enum SseCorruptedError {
14    #[error("Checksum mismatch. expected: {0:?}, actual: {1:?}")]
15    ChecksumMismatch(Crc32c, Crc32c),
16    #[error("SSE type mismatch. expected: {0:?}, actual: {1:?}")]
17    TypeMismatch(String, Option<String>),
18    #[error("SSE KMS key ID mismatch. expected: {0:?}, actual: {1:?}")]
19    KeyMismatch(String, Option<String>),
20}
21
22impl Default for ServerSideEncryption {
23    fn default() -> Self {
24        Self {
25            sse_type: Default::default(),
26            sse_kms_key_id: Default::default(),
27            checksum: Crc32c::new(0),
28        }
29    }
30}
31
32impl ServerSideEncryption {
33    /// Construct SSE settings from raw values provided via CLI
34    pub fn new(sse_type: Option<String>, sse_kms_key_id: Option<String>) -> Self {
35        let checksum = Self::compute_checksum(sse_type.as_deref(), sse_kms_key_id.as_deref());
36        Self {
37            sse_type,
38            sse_kms_key_id,
39            checksum,
40        }
41    }
42
43    /// Computes the checksum of SSE settings by combining two strings containing the type and the key
44    /// Note, that this implementation yields the same result for Some("") and None, but we may safely
45    /// assume that it will never be called with an empty string as one of its parameters.
46    fn compute_checksum(sse_type: Option<&str>, sse_kms_key_id: Option<&str>) -> Crc32c {
47        let mut hasher = crc32c::Hasher::new();
48        if let Some(maybe_sse_type) = sse_type {
49            hasher.update(maybe_sse_type.as_bytes());
50        }
51        if let Some(maybe_sse_kms_key_id) = sse_kms_key_id {
52            hasher.update(maybe_sse_kms_key_id.as_bytes());
53        }
54        hasher.finalize()
55    }
56
57    fn validate(&self) -> Result<(), SseCorruptedError> {
58        let computed = Self::compute_checksum(self.sse_type.as_deref(), self.sse_kms_key_id.as_deref());
59        if computed == self.checksum {
60            Ok(())
61        } else {
62            Err(SseCorruptedError::ChecksumMismatch(self.checksum, computed))
63        }
64    }
65
66    /// Checks that SSE settings still match the checksum and returns the string representations of:
67    /// 1. the SSE type as it is expected by S3 API;
68    /// 2. and AWS KMS Key ID, if provided.
69    pub fn into_inner(self) -> Result<(Option<String>, Option<String>), SseCorruptedError> {
70        self.validate()?;
71        Ok((self.sse_type, self.sse_kms_key_id))
72    }
73
74    /// Checks that values provided as arguments to this function match the values stored in the object.
75    /// S3 will return some values for sse type and key even if they were not set on our side.
76    /// We want to check only the values which we set.
77    pub fn verify_response(
78        &self,
79        sse_type: Option<&str>,
80        sse_kms_key_id: Option<&str>,
81    ) -> Result<(), SseCorruptedError> {
82        self.validate()?; // validate in-memory values, as we are using them to decide whether to skip the response check or not
83        if self.sse_type.is_some() && self.sse_type.as_deref() != sse_type {
84            return Err(SseCorruptedError::TypeMismatch(
85                self.sse_type.as_ref().unwrap().clone(),
86                sse_type.map(str::to_string),
87            ));
88        }
89        if self.sse_kms_key_id.is_some() && self.sse_kms_key_id.as_deref() != sse_kms_key_id {
90            return Err(SseCorruptedError::KeyMismatch(
91                self.sse_kms_key_id.as_ref().unwrap().clone(),
92                sse_kms_key_id.map(str::to_string),
93            ));
94        }
95        Ok(())
96    }
97
98    #[cfg(test)]
99    pub fn corrupt_data(&mut self, sse_type: Option<String>, sse_kms_key_id: Option<String>) {
100        self.sse_type = sse_type;
101        self.sse_kms_key_id = sse_kms_key_id;
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kmr"), Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kms"), Some("some_key_ali`s"))]
    #[test_case(Some("aws:kms"), Some("some_key_alias"), None, Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kms"), None)]
    #[test_case(Some("aws:kms"), None, Some("aws:kmr"), None)]
    #[test_case(None, None, Some("garbage"), None)]
    fn test_sse_corrupted_on_into_inner(
        sse_type: Option<&str>,
        key_id: Option<&str>,
        sse_type_corrupted: Option<&str>,
        key_id_corrupted: Option<&str>,
    ) {
        let mut sse = ServerSideEncryption::new(sse_type.map(String::from), key_id.map(String::from));
        sse.corrupt_data(sse_type_corrupted.map(String::from), key_id_corrupted.map(String::from));
        sse.into_inner()
            .expect_err("into_inner() should produce an error when values do not match the checksum");
    }

    #[test_case(Some("aws:kms"), Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), None)]
    #[test_case(None, None)]
    fn test_sse_into_inner_ok(sse_type: Option<&str>, key_id: Option<&str>) {
        let sse = ServerSideEncryption::new(sse_type.map(String::from), key_id.map(String::from));
        let (returned_sse_type, returned_key_id) = sse
            .into_inner()
            .expect("into_inner() should return values when they match the checksum");
        assert_eq!(sse_type, returned_sse_type.as_deref());
        assert_eq!(key_id, returned_key_id.as_deref());
    }

    #[test]
    fn test_sse_default_is_valid() {
        let (sse_type, key_id) = ServerSideEncryption::default()
            .into_inner()
            .expect("default SSE settings should match their checksum");
        assert_eq!(sse_type, None);
        assert_eq!(key_id, None);
    }

    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kmr"), Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kms"), Some("some_key_ali`s"))]
    #[test_case(Some("aws:kms"), Some("some_key_alias"), None, Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kms"), None)]
    fn test_sse_response_corrupted_on_verify_response(
        sse_type: Option<&str>,
        key_id: Option<&str>,
        sse_type_corrupted: Option<&str>,
        key_id_corrupted: Option<&str>,
    ) {
        let sse = ServerSideEncryption::new(sse_type.map(String::from), key_id.map(String::from));
        sse.verify_response(sse_type_corrupted, key_id_corrupted)
            .expect_err("verify_response() should produce an error when response values do not match the stored settings");
    }

    #[test]
    fn test_sse_in_memory_corrupted_on_verify_response() {
        // Corrupting the stored type to `None` would make verify_response skip the type check,
        // so it must be caught by the checksum validation instead.
        let mut sse = ServerSideEncryption::new(Some("aws:kms".to_string()), None);
        sse.corrupt_data(None, None);
        let err = sse
            .verify_response(Some("aws:kms"), None)
            .expect_err("verify_response() should produce an error when in-memory values do not match the checksum");
        assert!(matches!(err, SseCorruptedError::ChecksumMismatch(_, _)));
    }

    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kms"), Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), None, Some("aws:kms"), Some("some_key_alias"))]
    #[test_case(None, None, Some("aws:kms"), Some("some_key_alias"))]
    fn test_sse_verify_response_ok(
        sse_type: Option<&str>,
        key_id: Option<&str>,
        sse_type_response: Option<&str>,
        key_id_response: Option<&str>,
    ) {
        let sse = ServerSideEncryption::new(sse_type.map(String::from), key_id.map(String::from));
        sse.verify_response(sse_type_response, key_id_response)
            .expect("verify_response() should return Ok(()) when values match the checksum");
    }
}