1use mountpoint_s3_client::checksums::crc32c::{self, Crc32c};
2use thiserror::Error;
3
4#[derive(Debug, Clone)]
6pub struct ServerSideEncryption {
7 sse_type: Option<String>,
8 sse_kms_key_id: Option<String>,
9 checksum: Crc32c,
10}
11
12#[derive(Debug, Error)]
13pub enum SseCorruptedError {
14 #[error("Checksum mismatch. expected: {0:?}, actual: {1:?}")]
15 ChecksumMismatch(Crc32c, Crc32c),
16 #[error("SSE type mismatch. expected: {0:?}, actual: {1:?}")]
17 TypeMismatch(String, Option<String>),
18 #[error("SSE KMS key ID mismatch. expected: {0:?}, actual: {1:?}")]
19 KeyMismatch(String, Option<String>),
20}
21
22impl Default for ServerSideEncryption {
23 fn default() -> Self {
24 Self {
25 sse_type: Default::default(),
26 sse_kms_key_id: Default::default(),
27 checksum: Crc32c::new(0),
28 }
29 }
30}
31
32impl ServerSideEncryption {
33 pub fn new(sse_type: Option<String>, sse_kms_key_id: Option<String>) -> Self {
35 let checksum = Self::compute_checksum(sse_type.as_deref(), sse_kms_key_id.as_deref());
36 Self {
37 sse_type,
38 sse_kms_key_id,
39 checksum,
40 }
41 }
42
43 fn compute_checksum(sse_type: Option<&str>, sse_kms_key_id: Option<&str>) -> Crc32c {
47 let mut hasher = crc32c::Hasher::new();
48 if let Some(maybe_sse_type) = sse_type {
49 hasher.update(maybe_sse_type.as_bytes());
50 }
51 if let Some(maybe_sse_kms_key_id) = sse_kms_key_id {
52 hasher.update(maybe_sse_kms_key_id.as_bytes());
53 }
54 hasher.finalize()
55 }
56
57 fn validate(&self) -> Result<(), SseCorruptedError> {
58 let computed = Self::compute_checksum(self.sse_type.as_deref(), self.sse_kms_key_id.as_deref());
59 if computed == self.checksum {
60 Ok(())
61 } else {
62 Err(SseCorruptedError::ChecksumMismatch(self.checksum, computed))
63 }
64 }
65
66 pub fn into_inner(self) -> Result<(Option<String>, Option<String>), SseCorruptedError> {
70 self.validate()?;
71 Ok((self.sse_type, self.sse_kms_key_id))
72 }
73
74 pub fn verify_response(
78 &self,
79 sse_type: Option<&str>,
80 sse_kms_key_id: Option<&str>,
81 ) -> Result<(), SseCorruptedError> {
82 self.validate()?; if self.sse_type.is_some() && self.sse_type.as_deref() != sse_type {
84 return Err(SseCorruptedError::TypeMismatch(
85 self.sse_type.as_ref().unwrap().clone(),
86 sse_type.map(str::to_string),
87 ));
88 }
89 if self.sse_kms_key_id.is_some() && self.sse_kms_key_id.as_deref() != sse_kms_key_id {
90 return Err(SseCorruptedError::KeyMismatch(
91 self.sse_kms_key_id.as_ref().unwrap().clone(),
92 sse_kms_key_id.map(str::to_string),
93 ));
94 }
95 Ok(())
96 }
97
98 #[cfg(test)]
99 pub fn corrupt_data(&mut self, sse_type: Option<String>, sse_kms_key_id: Option<String>) {
100 self.sse_type = sse_type;
101 self.sse_kms_key_id = sse_kms_key_id;
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    // Stored values are overwritten after construction, so the checksum no longer matches them.
    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kmr"), Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kms"), Some("some_key_ali`s"))]
    #[test_case(Some("aws:kms"), Some("some_key_alias"), None, Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kms"), None)]
    #[test_case(Some("aws:kms"), None, Some("aws:kmr"), None)]
    #[test_case(None, None, Some("garbage"), None)]
    fn test_sse_corrupted_on_into_inner(
        sse_type: Option<&str>,
        key_id: Option<&str>,
        sse_type_corrupted: Option<&str>,
        key_id_corrupted: Option<&str>,
    ) {
        let mut sse = ServerSideEncryption::new(sse_type.map(String::from), key_id.map(String::from));
        sse.sse_type = sse_type_corrupted.map(String::from);
        sse.sse_kms_key_id = key_id_corrupted.map(String::from);
        sse.into_inner()
            .expect_err("into_inner() should produce an error when values do not match the checksum");
    }

    #[test_case(Some("aws:kms"), Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), None)]
    #[test_case(None, None)]
    fn test_sse_into_inner_ok(sse_type: Option<&str>, key_id: Option<&str>) {
        let sse = ServerSideEncryption::new(sse_type.map(String::from), key_id.map(String::from));
        let (returned_sse_type, returned_key_id) = sse
            .into_inner()
            .expect("into_inner() should return values when they match the checksum");
        assert_eq!(sse_type, returned_sse_type.as_deref());
        assert_eq!(key_id, returned_key_id.as_deref());
    }

    // The checksum hard-coded in `Default` must agree with the one computed for empty settings.
    #[test]
    fn test_sse_default_into_inner_ok() {
        let (returned_sse_type, returned_key_id) = ServerSideEncryption::default()
            .into_inner()
            .expect("into_inner() should return values for default settings");
        assert_eq!(None, returned_sse_type);
        assert_eq!(None, returned_key_id);
    }

    // The stored settings are intact here; it is the values passed in that differ from them.
    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kmr"), Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kms"), Some("some_key_ali`s"))]
    #[test_case(Some("aws:kms"), Some("some_key_alias"), None, Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kms"), None)]
    fn test_sse_response_corrupted_on_verify_response(
        sse_type: Option<&str>,
        key_id: Option<&str>,
        sse_type_corrupted: Option<&str>,
        key_id_corrupted: Option<&str>,
    ) {
        let sse = ServerSideEncryption::new(sse_type.map(String::from), key_id.map(String::from));
        sse.verify_response(sse_type_corrupted, key_id_corrupted)
            .expect_err("verify_response() should produce an error when response values do not match the expected values");
    }

    // Even if the response equals the (corrupted) stored values, the checksum check must fail first.
    #[test]
    fn test_sse_corrupted_on_verify_response() {
        let mut sse = ServerSideEncryption::new(Some("aws:kms".to_string()), Some("some_key_alias".to_string()));
        sse.corrupt_data(Some("aws:kmr".to_string()), Some("some_key_alias".to_string()));
        let err = sse
            .verify_response(Some("aws:kmr"), Some("some_key_alias"))
            .expect_err("verify_response() should produce an error when stored values do not match the checksum");
        assert!(matches!(err, SseCorruptedError::ChecksumMismatch(_, _)));
    }

    #[test_case(Some("aws:kms"), Some("some_key_alias"), Some("aws:kms"), Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), None, Some("aws:kms"), Some("some_key_alias"))]
    #[test_case(None, None, Some("aws:kms"), Some("some_key_alias"))]
    fn test_sse_verify_response_ok(
        sse_type: Option<&str>,
        key_id: Option<&str>,
        sse_type_response: Option<&str>,
        key_id_response: Option<&str>,
    ) {
        let sse = ServerSideEncryption::new(sse_type.map(String::from), key_id.map(String::from));
        sse.verify_response(sse_type_response, key_id_response)
            .expect("verify_response() should return Ok(()) when values match the checksum");
    }
}