httpsig_hyper/
hyper_content_digest.rs

1use super::{ContentDigestType, CONTENT_DIGEST_HEADER};
2use crate::error::{HyperDigestError, HyperDigestResult};
3use base64::{engine::general_purpose, Engine as _};
4use bytes::Bytes;
5use http::{Request, Response};
6use http_body::Body;
7use http_body_util::{combinators::BoxBody, BodyExt, Full};
8use sfv::FromStr;
9use sha2::Digest;
10use std::future::Future;
11
// hyper-specific extensions to compute, set, and verify the HTTP content digest (`content-digest` header)
13
14/* --------------------------------------- */
15pub trait ContentDigest: http_body::Body {
16  /// Returns the bytes object of the body
17  fn into_bytes(self) -> impl Future<Output = Result<Bytes, Self::Error>> + Send
18  where
19    Self: Sized + Send,
20    Self::Data: Send,
21  {
22    async { Ok(self.collect().await?.to_bytes()) }
23  }
24
25  /// Returns the content digest in base64
26  fn into_bytes_with_digest(
27    self,
28    cd_type: &ContentDigestType,
29  ) -> impl Future<Output = Result<(Bytes, String), Self::Error>> + Send
30  where
31    Self: Sized + Send,
32    Self::Data: Send,
33  {
34    async move {
35      let body_bytes = self.into_bytes().await?;
36      let digest = derive_digest(&body_bytes, cd_type);
37
38      Ok((body_bytes, general_purpose::STANDARD.encode(digest)))
39    }
40  }
41}
42
43/// Returns the digest of the given body in Vec<u8>
44fn derive_digest(body_bytes: &Bytes, cd_type: &ContentDigestType) -> Vec<u8> {
45  match cd_type {
46    ContentDigestType::Sha256 => {
47      let mut hasher = sha2::Sha256::new();
48      hasher.update(body_bytes);
49      hasher.finalize().to_vec()
50    }
51
52    ContentDigestType::Sha512 => {
53      let mut hasher = sha2::Sha512::new();
54      hasher.update(body_bytes);
55      hasher.finalize().to_vec()
56    }
57  }
58}
59
60impl<T: ?Sized> ContentDigest for T where T: http_body::Body {}
61
62/* --------------------------------------- */
63/// A trait to set the http content digest in request in base64
64pub trait RequestContentDigest {
65  type Error;
66  type PassthroughRequest;
67
68  /// Set the content digest in the request
69  fn set_content_digest(
70    self,
71    cd_type: &ContentDigestType,
72  ) -> impl Future<Output = Result<Self::PassthroughRequest, Self::Error>> + Send
73  where
74    Self: Sized;
75
76  /// Verify the content digest in the request and returns self if it's valid otherwise returns error
77  fn verify_content_digest(self) -> impl Future<Output = Result<Self::PassthroughRequest, Self::Error>> + Send
78  where
79    Self: Sized;
80}
81
82/// A trait to set the http content digest in response in base64
83pub trait ResponseContentDigest {
84  type Error;
85  type PassthroughResponse;
86
87  /// Set the content digest in the response
88  fn set_content_digest(
89    self,
90    cd_type: &ContentDigestType,
91  ) -> impl Future<Output = Result<Self::PassthroughResponse, Self::Error>> + Send
92  where
93    Self: Sized;
94
95  /// Verify the content digest in the response and returns self if it's valid otherwise returns error
96  fn verify_content_digest(self) -> impl Future<Output = Result<Self::PassthroughResponse, Self::Error>> + Send
97  where
98    Self: Sized;
99}
100
101impl<B> RequestContentDigest for Request<B>
102where
103  B: Body + Send,
104  <B as Body>::Data: Send,
105{
106  type Error = HyperDigestError;
107  type PassthroughRequest = Request<BoxBody<Bytes, Self::Error>>;
108
109  /// Set the content digest in the request
110  async fn set_content_digest(self, cd_type: &ContentDigestType) -> HyperDigestResult<Self::PassthroughRequest>
111  where
112    Self: Sized,
113  {
114    let (mut parts, body) = self.into_parts();
115    let (body_bytes, digest) = body
116      .into_bytes_with_digest(cd_type)
117      .await
118      .map_err(|_e| HyperDigestError::HttpBodyError("Failed to generate digest".to_string()))?;
119    let new_body = Full::new(body_bytes).map_err(|never| match never {}).boxed();
120
121    parts
122      .headers
123      .insert(CONTENT_DIGEST_HEADER, format!("{cd_type}=:{digest}:").parse().unwrap());
124
125    let new_req = Request::from_parts(parts, new_body);
126    Ok(new_req)
127  }
128
129  /// Verifies the consistency between self and given content-digest in &[u8]
130  /// Returns self in Bytes if it's valid otherwise returns error
131  async fn verify_content_digest(self) -> Result<Self::PassthroughRequest, Self::Error>
132  where
133    Self: Sized,
134  {
135    let header_map = self.headers();
136    let (cd_type, _expected_digest) = extract_content_digest(header_map).await?;
137    let (header, body) = self.into_parts();
138    let body_bytes = body
139      .into_bytes()
140      .await
141      .map_err(|_e| HyperDigestError::HttpBodyError("Failed to get body bytes".to_string()))?;
142    let digest = derive_digest(&body_bytes, &cd_type);
143
144    if matches!(digest, _expected_digest) {
145      let new_body = Full::new(body_bytes).map_err(|never| match never {}).boxed();
146      let res = Request::from_parts(header, new_body);
147      Ok(res)
148    } else {
149      Err(HyperDigestError::InvalidContentDigest(
150        "Content-Digest verification failed".to_string(),
151      ))
152    }
153  }
154}
155
156impl<B> ResponseContentDigest for Response<B>
157where
158  B: Body + Send,
159  <B as Body>::Data: Send,
160{
161  type Error = HyperDigestError;
162  type PassthroughResponse = Response<BoxBody<Bytes, Self::Error>>;
163
164  async fn set_content_digest(self, cd_type: &ContentDigestType) -> HyperDigestResult<Self::PassthroughResponse>
165  where
166    Self: Sized,
167  {
168    let (mut parts, body) = self.into_parts();
169    let (body_bytes, digest) = body
170      .into_bytes_with_digest(cd_type)
171      .await
172      .map_err(|_e| HyperDigestError::HttpBodyError("Failed to generate digest".to_string()))?;
173    let new_body = Full::new(body_bytes).map_err(|never| match never {}).boxed();
174
175    parts
176      .headers
177      .insert(CONTENT_DIGEST_HEADER, format!("{cd_type}=:{digest}:").parse().unwrap());
178
179    let new_req = Response::from_parts(parts, new_body);
180    Ok(new_req)
181  }
182  async fn verify_content_digest(self) -> HyperDigestResult<Self::PassthroughResponse>
183  where
184    Self: Sized,
185  {
186    let header_map = self.headers();
187    let (cd_type, _expected_digest) = extract_content_digest(header_map).await?;
188    let (header, body) = self.into_parts();
189    let body_bytes = body
190      .into_bytes()
191      .await
192      .map_err(|_e| HyperDigestError::HttpBodyError("Failed to get body bytes".to_string()))?;
193    let digest = derive_digest(&body_bytes, &cd_type);
194
195    if matches!(digest, _expected_digest) {
196      let new_body = Full::new(body_bytes).map_err(|never| match never {}).boxed();
197      let res = Response::from_parts(header, new_body);
198      Ok(res)
199    } else {
200      Err(HyperDigestError::InvalidContentDigest(
201        "Content-Digest verification failed".to_string(),
202      ))
203    }
204  }
205}
206
207async fn extract_content_digest(header_map: &http::HeaderMap) -> HyperDigestResult<(ContentDigestType, Vec<u8>)> {
208  let content_digest_header = header_map
209    .get(CONTENT_DIGEST_HEADER)
210    .ok_or(HyperDigestError::NoDigestHeader("No content-digest header".to_string()))?
211    .to_str()?;
212  let indexmap = sfv::Parser::parse_dictionary(content_digest_header.as_bytes())
213    .map_err(|e| HyperDigestError::InvalidHeaderValue(e.to_string()))?;
214  if indexmap.len() != 1 {
215    return Err(HyperDigestError::InvalidHeaderValue(
216      "Content-Digest header should have only one value".to_string(),
217    ));
218  };
219  let (cd_type, cd) = indexmap.iter().next().unwrap();
220  let cd_type = ContentDigestType::from_str(cd_type)
221    .map_err(|e| HyperDigestError::InvalidHeaderValue(format!("Invalid Content-Digest type: {e}")))?;
222  if !matches!(
223    cd,
224    sfv::ListEntry::Item(sfv::Item {
225      bare_item: sfv::BareItem::ByteSeq(_),
226      ..
227    })
228  ) {
229    return Err(HyperDigestError::InvalidHeaderValue(
230      "Invalid Content-Digest value".to_string(),
231    ));
232  }
233
234  let cd = match cd {
235    sfv::ListEntry::Item(sfv::Item {
236      bare_item: sfv::BareItem::ByteSeq(cd),
237      ..
238    }) => cd,
239    _ => unreachable!(),
240  };
241  Ok((cd_type, cd.to_owned()))
242}
243
244/* --------------------------------------- */
#[cfg(test)]
mod tests {
  use super::*;

  /// Known-answer digests of a small JSON body for both supported algorithms.
  #[tokio::test]
  async fn content_digest() {
    let body = Full::new(&b"{\"hello\": \"world\"}"[..]);
    let (_body_bytes, digest) = body.into_bytes_with_digest(&ContentDigestType::Sha256).await.unwrap();

    assert_eq!(digest, "X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE=");

    let (_body_bytes, digest) = body.into_bytes_with_digest(&ContentDigestType::Sha512).await.unwrap();
    assert_eq!(
      digest,
      "WZDPaVn/7XgHaAy8pmojAkGWoRx2UFChF41A2svX+TaPm+AbwAgBWnrIiYllu7BNNyealdVLvRwEmTHWXvJwew=="
    );
  }

  /// Setting then verifying the digest on a request round-trips.
  #[tokio::test]
  async fn hyper_request_test() {
    let body = Full::new(&b"{\"hello\": \"world\"}"[..]);

    let req = Request::builder()
      .method("GET")
      .uri("https://example.com/")
      .header("date", "Sun, 09 May 2021 18:30:00 GMT")
      .header("content-type", "application/json")
      .body(body)
      .unwrap();
    let req = req.set_content_digest(&ContentDigestType::Sha256).await.unwrap();

    assert!(req.headers().contains_key(CONTENT_DIGEST_HEADER));
    let digest = req.headers().get(CONTENT_DIGEST_HEADER).unwrap().to_str().unwrap();
    assert_eq!(digest, "sha-256=:X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE=:");

    let verified = req.verify_content_digest().await;
    assert!(verified.is_ok());
  }

  /// Setting then verifying the digest on a response round-trips.
  #[tokio::test]
  async fn hyper_response_test() {
    let body = Full::new(&b"{\"hello\": \"world\"}"[..]);

    let res = Response::builder()
      .status(200)
      .header("date", "Sun, 09 May 2021 18:30:00 GMT")
      .header("content-type", "application/json")
      .body(body)
      .unwrap();
    let res = res.set_content_digest(&ContentDigestType::Sha256).await.unwrap();

    assert!(res.headers().contains_key(CONTENT_DIGEST_HEADER));
    let digest = res.headers().get(CONTENT_DIGEST_HEADER).unwrap().to_str().unwrap();
    assert_eq!(digest, "sha-256=:X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE=:");

    let verified = res.verify_content_digest().await;
    assert!(verified.is_ok());
  }

  /// The header parser rejects missing, multi-member, and non-byte-sequence headers,
  /// and decodes a well-formed one.
  #[tokio::test]
  async fn extract_content_digest_parsing() {
    let empty = http::HeaderMap::new();
    assert!(extract_content_digest(&empty).await.is_err());

    for value in ["sha-256=1", "sha-256=:AA==:, sha-512=:AA==:"] {
      let mut headers = http::HeaderMap::new();
      headers.insert(CONTENT_DIGEST_HEADER, value.parse().unwrap());
      assert!(extract_content_digest(&headers).await.is_err(), "accepted {value}");
    }

    let mut headers = http::HeaderMap::new();
    headers.insert(CONTENT_DIGEST_HEADER, "sha-256=:AA==:".parse().unwrap());
    let Ok((cd_type, digest)) = extract_content_digest(&headers).await else {
      panic!("a well-formed content-digest header must parse");
    };
    assert!(matches!(cd_type, ContentDigestType::Sha256));
    assert_eq!(digest, vec![0u8]);
  }
}