1use super::{ContentDigestType, CONTENT_DIGEST_HEADER};
2use crate::error::{HyperDigestError, HyperDigestResult};
3use base64::{engine::general_purpose, Engine as _};
4use bytes::Bytes;
5use http::{Request, Response};
6use http_body::Body;
7use http_body_util::{combinators::BoxBody, BodyExt, Full};
8use sha2::Digest;
9use std::future::Future;
10use std::str::FromStr;
11use subtle::ConstantTimeEq;
12
13pub trait ContentDigest: http_body::Body {
17 fn into_bytes(self) -> impl Future<Output = Result<Bytes, Self::Error>> + Send
19 where
20 Self: Sized + Send,
21 Self::Data: Send,
22 {
23 async { Ok(self.collect().await?.to_bytes()) }
24 }
25
26 fn into_bytes_with_digest(
28 self,
29 cd_type: &ContentDigestType,
30 ) -> impl Future<Output = Result<(Bytes, String), Self::Error>> + Send
31 where
32 Self: Sized + Send,
33 Self::Data: Send,
34 {
35 async move {
36 let body_bytes = self.into_bytes().await?;
37 let digest = derive_digest(&body_bytes, cd_type);
38
39 Ok((body_bytes, general_purpose::STANDARD.encode(digest)))
40 }
41 }
42}
43
44fn derive_digest(body_bytes: &Bytes, cd_type: &ContentDigestType) -> Vec<u8> {
46 match cd_type {
47 ContentDigestType::Sha256 => {
48 let mut hasher = sha2::Sha256::new();
49 hasher.update(body_bytes);
50 hasher.finalize().to_vec()
51 }
52
53 ContentDigestType::Sha512 => {
54 let mut hasher = sha2::Sha512::new();
55 hasher.update(body_bytes);
56 hasher.finalize().to_vec()
57 }
58 }
59}
60
61impl<T: ?Sized> ContentDigest for T where T: http_body::Body {}
62
63pub trait RequestContentDigest {
66 type Error;
67 type PassthroughRequest;
68
69 fn set_content_digest(
71 self,
72 cd_type: &ContentDigestType,
73 ) -> impl Future<Output = Result<Self::PassthroughRequest, Self::Error>> + Send
74 where
75 Self: Sized;
76
77 fn verify_content_digest(self) -> impl Future<Output = Result<Self::PassthroughRequest, Self::Error>> + Send
79 where
80 Self: Sized;
81}
82
83pub trait ResponseContentDigest {
85 type Error;
86 type PassthroughResponse;
87
88 fn set_content_digest(
90 self,
91 cd_type: &ContentDigestType,
92 ) -> impl Future<Output = Result<Self::PassthroughResponse, Self::Error>> + Send
93 where
94 Self: Sized;
95
96 fn verify_content_digest(self) -> impl Future<Output = Result<Self::PassthroughResponse, Self::Error>> + Send
98 where
99 Self: Sized;
100}
101
102impl<B> RequestContentDigest for Request<B>
103where
104 B: Body + Send,
105 <B as Body>::Data: Send,
106{
107 type Error = HyperDigestError;
108 type PassthroughRequest = Request<BoxBody<Bytes, Self::Error>>;
109
110 async fn set_content_digest(self, cd_type: &ContentDigestType) -> HyperDigestResult<Self::PassthroughRequest>
112 where
113 Self: Sized,
114 {
115 let (mut parts, body) = self.into_parts();
116 let (body_bytes, digest) = body
117 .into_bytes_with_digest(cd_type)
118 .await
119 .map_err(|_e| HyperDigestError::HttpBodyError("Failed to generate digest".to_string()))?;
120 let new_body = Full::new(body_bytes).map_err(|never| match never {}).boxed();
121
122 parts
123 .headers
124 .insert(CONTENT_DIGEST_HEADER, format!("{cd_type}=:{digest}:").parse().unwrap());
125
126 let new_req = Request::from_parts(parts, new_body);
127 Ok(new_req)
128 }
129
130 async fn verify_content_digest(self) -> Result<Self::PassthroughRequest, Self::Error>
133 where
134 Self: Sized,
135 {
136 let header_map = self.headers();
137 let (cd_type, expected_digest) = extract_content_digest(header_map).await?;
138 let (header, body) = self.into_parts();
139 let body_bytes = body
140 .into_bytes()
141 .await
142 .map_err(|_e| HyperDigestError::HttpBodyError("Failed to get body bytes".to_string()))?;
143 let digest = derive_digest(&body_bytes, &cd_type);
144
145 if is_equal_digest(&digest, &expected_digest) {
147 let new_body = Full::new(body_bytes).map_err(|never| match never {}).boxed();
148 let res = Request::from_parts(header, new_body);
149 Ok(res)
150 } else {
151 Err(HyperDigestError::InvalidContentDigest(
152 "Content-Digest verification failed".to_string(),
153 ))
154 }
155 }
156}
157
158impl<B> ResponseContentDigest for Response<B>
159where
160 B: Body + Send,
161 <B as Body>::Data: Send,
162{
163 type Error = HyperDigestError;
164 type PassthroughResponse = Response<BoxBody<Bytes, Self::Error>>;
165
166 async fn set_content_digest(self, cd_type: &ContentDigestType) -> HyperDigestResult<Self::PassthroughResponse>
167 where
168 Self: Sized,
169 {
170 let (mut parts, body) = self.into_parts();
171 let (body_bytes, digest) = body
172 .into_bytes_with_digest(cd_type)
173 .await
174 .map_err(|_e| HyperDigestError::HttpBodyError("Failed to generate digest".to_string()))?;
175 let new_body = Full::new(body_bytes).map_err(|never| match never {}).boxed();
176
177 parts
178 .headers
179 .insert(CONTENT_DIGEST_HEADER, format!("{cd_type}=:{digest}:").parse().unwrap());
180
181 let new_req = Response::from_parts(parts, new_body);
182 Ok(new_req)
183 }
184 async fn verify_content_digest(self) -> HyperDigestResult<Self::PassthroughResponse>
185 where
186 Self: Sized,
187 {
188 let header_map = self.headers();
189 let (cd_type, expected_digest) = extract_content_digest(header_map).await?;
190 let (header, body) = self.into_parts();
191 let body_bytes = body
192 .into_bytes()
193 .await
194 .map_err(|_e| HyperDigestError::HttpBodyError("Failed to get body bytes".to_string()))?;
195 let digest = derive_digest(&body_bytes, &cd_type);
196
197 if is_equal_digest(&digest, &expected_digest) {
199 let new_body = Full::new(body_bytes).map_err(|never| match never {}).boxed();
200 let res = Response::from_parts(header, new_body);
201 Ok(res)
202 } else {
203 Err(HyperDigestError::InvalidContentDigest(
204 "Content-Digest verification failed".to_string(),
205 ))
206 }
207 }
208}
209
210fn is_equal_digest(digest1: &[u8], digest2: &[u8]) -> bool {
212 if digest1.len() != digest2.len() {
215 return false;
216 }
217 digest1.ct_eq(digest2).into()
218}
219
220async fn extract_content_digest(header_map: &http::HeaderMap) -> HyperDigestResult<(ContentDigestType, Vec<u8>)> {
221 let content_digest_header = header_map
222 .get(CONTENT_DIGEST_HEADER)
223 .ok_or(HyperDigestError::NoDigestHeader("No content-digest header".to_string()))?
224 .to_str()?;
225 let indexmap = sfv::Parser::new(content_digest_header)
226 .parse::<sfv::Dictionary>()
227 .map_err(|e| HyperDigestError::InvalidHeaderValue(e.to_string()))?;
228 if indexmap.len() != 1 {
229 return Err(HyperDigestError::InvalidHeaderValue(
230 "Content-Digest header should have only one value".to_string(),
231 ));
232 };
233 let (cd_type, cd) = indexmap.iter().next().unwrap();
234 let cd_type = ContentDigestType::from_str(cd_type.as_str())
235 .map_err(|e| HyperDigestError::InvalidHeaderValue(format!("Invalid Content-Digest type: {e}")))?;
236 if !matches!(
237 cd,
238 sfv::ListEntry::Item(sfv::Item {
239 bare_item: sfv::BareItem::ByteSequence(_),
240 ..
241 })
242 ) {
243 return Err(HyperDigestError::InvalidHeaderValue(
244 "Invalid Content-Digest value".to_string(),
245 ));
246 }
247
248 let cd = match cd {
249 sfv::ListEntry::Item(sfv::Item {
250 bare_item: sfv::BareItem::ByteSequence(cd),
251 ..
252 }) => cd,
253 _ => unreachable!(),
254 };
255 Ok((cd_type, cd.to_owned()))
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON payload used by every test in this module.
    const BODY: &[u8] = b"{\"hello\": \"world\"}";

    /// Standard-base64 SHA-256 digest of [`BODY`].
    const BODY_SHA256_B64: &str = "X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE=";

    /// Builds the GET request shared by the request-side tests.
    fn sample_request() -> Request<Full<&'static [u8]>> {
        Request::builder()
            .method("GET")
            .uri("https://example.com/")
            .header("date", "Sun, 09 May 2021 18:30:00 GMT")
            .header("content-type", "application/json")
            .body(Full::new(BODY))
            .unwrap()
    }

    /// Builds the 200 response shared by the response-side tests.
    fn sample_response() -> Response<Full<&'static [u8]>> {
        Response::builder()
            .status(200)
            .header("date", "Sun, 09 May 2021 18:30:00 GMT")
            .header("content-type", "application/json")
            .body(Full::new(BODY))
            .unwrap()
    }

    /// Panics unless `result` is a digest-mismatch error.
    fn assert_digest_mismatch<T>(result: Result<T, HyperDigestError>) {
        match result {
            Err(HyperDigestError::InvalidContentDigest(_)) => {}
            Err(e) => panic!("unexpected error: {e:?}"),
            Ok(_) => panic!("verification unexpectedly succeeded"),
        }
    }

    #[tokio::test]
    async fn content_digest() {
        // `body` is digested twice below, which relies on `Full<&[u8]>` being `Copy`.
        let body = Full::new(BODY);
        let (_body_bytes, digest) = body.into_bytes_with_digest(&ContentDigestType::Sha256).await.unwrap();
        assert_eq!(digest, BODY_SHA256_B64);

        let (_body_bytes, digest) = body.into_bytes_with_digest(&ContentDigestType::Sha512).await.unwrap();
        assert_eq!(
            digest,
            "WZDPaVn/7XgHaAy8pmojAkGWoRx2UFChF41A2svX+TaPm+AbwAgBWnrIiYllu7BNNyealdVLvRwEmTHWXvJwew=="
        );
    }

    #[tokio::test]
    async fn hyper_request_test() {
        let req = sample_request()
            .set_content_digest(&ContentDigestType::Sha256)
            .await
            .unwrap();

        assert!(req.headers().contains_key(CONTENT_DIGEST_HEADER));
        let digest = req.headers().get(CONTENT_DIGEST_HEADER).unwrap().to_str().unwrap();
        assert_eq!(digest, format!("sha-256=:{BODY_SHA256_B64}:"));

        let verified = req.verify_content_digest().await;
        assert!(verified.is_ok());
    }

    #[tokio::test]
    async fn hyper_response_test() {
        let res = sample_response()
            .set_content_digest(&ContentDigestType::Sha256)
            .await
            .unwrap();

        assert!(res.headers().contains_key(CONTENT_DIGEST_HEADER));
        let digest = res.headers().get(CONTENT_DIGEST_HEADER).unwrap().to_str().unwrap();
        assert_eq!(digest, format!("sha-256=:{BODY_SHA256_B64}:"));

        let verified = res.verify_content_digest().await;
        assert!(verified.is_ok());
    }

    #[tokio::test]
    async fn hyper_request_digest_mismatch_by_body_tamper_should_fail() {
        let req = sample_request()
            .set_content_digest(&ContentDigestType::Sha256)
            .await
            .unwrap();
        assert!(req.headers().contains_key(CONTENT_DIGEST_HEADER));

        // Keep the signed header, but swap in a different body.
        let (parts, _old_body) = req.into_parts();
        let tampered_body = Full::new(&b"{\"hello\": \"pwned\"}"[..]).boxed();
        let tampered_req = Request::from_parts(parts, tampered_body);

        assert_digest_mismatch(tampered_req.verify_content_digest().await);
    }

    #[tokio::test]
    async fn hyper_response_digest_mismatch_by_header_tamper_should_fail() {
        let res = sample_response()
            .set_content_digest(&ContentDigestType::Sha256)
            .await
            .unwrap();
        let (mut parts, body) = res.into_parts();

        // Same length as a real SHA-256 digest, but one byte differs.
        parts.headers.insert(
            CONTENT_DIGEST_HEADER,
            "sha-256=:Y48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE=:".parse().unwrap(),
        );
        let tampered_res = Response::from_parts(parts, body);

        assert_digest_mismatch(tampered_res.verify_content_digest().await);
    }

    #[tokio::test]
    async fn hyper_request_missing_content_digest_header_should_fail() {
        let verified = sample_request().verify_content_digest().await;

        match verified {
            Err(HyperDigestError::NoDigestHeader(_)) => {}
            Err(e) => panic!("unexpected error: {e:?}"),
            Ok(_) => panic!("verification unexpectedly succeeded"),
        }
    }

    #[tokio::test]
    async fn hyper_request_digest_length_mismatch_should_fail() {
        let req = sample_request()
            .set_content_digest(&ContentDigestType::Sha256)
            .await
            .unwrap();
        let (mut parts, body) = req.into_parts();

        // "AAAA" is well-formed base64 for 3 zero bytes. The header therefore
        // parses, and the failure must come from comparing a 3-byte value with
        // a 32-byte SHA-256 digest rather than from header parsing.
        parts
            .headers
            .insert(CONTENT_DIGEST_HEADER, "sha-256=:AAAA:".parse().unwrap());
        let tampered_req = Request::from_parts(parts, body);

        assert_digest_mismatch(tampered_req.verify_content_digest().await);
    }
}