httpsig_hyper/
hyper_content_digest.rs

1use super::{ContentDigestType, CONTENT_DIGEST_HEADER};
2use crate::error::{HyperDigestError, HyperDigestResult};
3use base64::{engine::general_purpose, Engine as _};
4use bytes::Bytes;
5use http::{Request, Response};
6use http_body::Body;
7use http_body_util::{combinators::BoxBody, BodyExt, Full};
8use sha2::Digest;
9use std::future::Future;
10use std::str::FromStr;
11
// hyper-specific HTTP extensions to generate and verify the `Content-Digest` header
13
14/* --------------------------------------- */
15pub trait ContentDigest: http_body::Body {
16  /// Returns the bytes object of the body
17  fn into_bytes(self) -> impl Future<Output = Result<Bytes, Self::Error>> + Send
18  where
19    Self: Sized + Send,
20    Self::Data: Send,
21  {
22    async { Ok(self.collect().await?.to_bytes()) }
23  }
24
25  /// Returns the content digest in base64
26  fn into_bytes_with_digest(
27    self,
28    cd_type: &ContentDigestType,
29  ) -> impl Future<Output = Result<(Bytes, String), Self::Error>> + Send
30  where
31    Self: Sized + Send,
32    Self::Data: Send,
33  {
34    async move {
35      let body_bytes = self.into_bytes().await?;
36      let digest = derive_digest(&body_bytes, cd_type);
37
38      Ok((body_bytes, general_purpose::STANDARD.encode(digest)))
39    }
40  }
41}
42
43/// Returns the digest of the given body in Vec<u8>
44fn derive_digest(body_bytes: &Bytes, cd_type: &ContentDigestType) -> Vec<u8> {
45  match cd_type {
46    ContentDigestType::Sha256 => {
47      let mut hasher = sha2::Sha256::new();
48      hasher.update(body_bytes);
49      hasher.finalize().to_vec()
50    }
51
52    ContentDigestType::Sha512 => {
53      let mut hasher = sha2::Sha512::new();
54      hasher.update(body_bytes);
55      hasher.finalize().to_vec()
56    }
57  }
58}
59
60impl<T: ?Sized> ContentDigest for T where T: http_body::Body {}
61
62/* --------------------------------------- */
63/// A trait to set the http content digest in request in base64
64pub trait RequestContentDigest {
65  type Error;
66  type PassthroughRequest;
67
68  /// Set the content digest in the request
69  fn set_content_digest(
70    self,
71    cd_type: &ContentDigestType,
72  ) -> impl Future<Output = Result<Self::PassthroughRequest, Self::Error>> + Send
73  where
74    Self: Sized;
75
76  /// Verify the content digest in the request and returns self if it's valid otherwise returns error
77  fn verify_content_digest(self) -> impl Future<Output = Result<Self::PassthroughRequest, Self::Error>> + Send
78  where
79    Self: Sized;
80}
81
82/// A trait to set the http content digest in response in base64
83pub trait ResponseContentDigest {
84  type Error;
85  type PassthroughResponse;
86
87  /// Set the content digest in the response
88  fn set_content_digest(
89    self,
90    cd_type: &ContentDigestType,
91  ) -> impl Future<Output = Result<Self::PassthroughResponse, Self::Error>> + Send
92  where
93    Self: Sized;
94
95  /// Verify the content digest in the response and returns self if it's valid otherwise returns error
96  fn verify_content_digest(self) -> impl Future<Output = Result<Self::PassthroughResponse, Self::Error>> + Send
97  where
98    Self: Sized;
99}
100
101impl<B> RequestContentDigest for Request<B>
102where
103  B: Body + Send,
104  <B as Body>::Data: Send,
105{
106  type Error = HyperDigestError;
107  type PassthroughRequest = Request<BoxBody<Bytes, Self::Error>>;
108
109  /// Set the content digest in the request
110  async fn set_content_digest(self, cd_type: &ContentDigestType) -> HyperDigestResult<Self::PassthroughRequest>
111  where
112    Self: Sized,
113  {
114    let (mut parts, body) = self.into_parts();
115    let (body_bytes, digest) = body
116      .into_bytes_with_digest(cd_type)
117      .await
118      .map_err(|_e| HyperDigestError::HttpBodyError("Failed to generate digest".to_string()))?;
119    let new_body = Full::new(body_bytes).map_err(|never| match never {}).boxed();
120
121    parts
122      .headers
123      .insert(CONTENT_DIGEST_HEADER, format!("{cd_type}=:{digest}:").parse().unwrap());
124
125    let new_req = Request::from_parts(parts, new_body);
126    Ok(new_req)
127  }
128
129  /// Verifies the consistency between self and given content-digest in &[u8]
130  /// Returns self in Bytes if it's valid otherwise returns error
131  async fn verify_content_digest(self) -> Result<Self::PassthroughRequest, Self::Error>
132  where
133    Self: Sized,
134  {
135    let header_map = self.headers();
136    let (cd_type, _expected_digest) = extract_content_digest(header_map).await?;
137    let (header, body) = self.into_parts();
138    let body_bytes = body
139      .into_bytes()
140      .await
141      .map_err(|_e| HyperDigestError::HttpBodyError("Failed to get body bytes".to_string()))?;
142    let digest = derive_digest(&body_bytes, &cd_type);
143
144    if matches!(digest, _expected_digest) {
145      let new_body = Full::new(body_bytes).map_err(|never| match never {}).boxed();
146      let res = Request::from_parts(header, new_body);
147      Ok(res)
148    } else {
149      Err(HyperDigestError::InvalidContentDigest(
150        "Content-Digest verification failed".to_string(),
151      ))
152    }
153  }
154}
155
156impl<B> ResponseContentDigest for Response<B>
157where
158  B: Body + Send,
159  <B as Body>::Data: Send,
160{
161  type Error = HyperDigestError;
162  type PassthroughResponse = Response<BoxBody<Bytes, Self::Error>>;
163
164  async fn set_content_digest(self, cd_type: &ContentDigestType) -> HyperDigestResult<Self::PassthroughResponse>
165  where
166    Self: Sized,
167  {
168    let (mut parts, body) = self.into_parts();
169    let (body_bytes, digest) = body
170      .into_bytes_with_digest(cd_type)
171      .await
172      .map_err(|_e| HyperDigestError::HttpBodyError("Failed to generate digest".to_string()))?;
173    let new_body = Full::new(body_bytes).map_err(|never| match never {}).boxed();
174
175    parts
176      .headers
177      .insert(CONTENT_DIGEST_HEADER, format!("{cd_type}=:{digest}:").parse().unwrap());
178
179    let new_req = Response::from_parts(parts, new_body);
180    Ok(new_req)
181  }
182  async fn verify_content_digest(self) -> HyperDigestResult<Self::PassthroughResponse>
183  where
184    Self: Sized,
185  {
186    let header_map = self.headers();
187    let (cd_type, _expected_digest) = extract_content_digest(header_map).await?;
188    let (header, body) = self.into_parts();
189    let body_bytes = body
190      .into_bytes()
191      .await
192      .map_err(|_e| HyperDigestError::HttpBodyError("Failed to get body bytes".to_string()))?;
193    let digest = derive_digest(&body_bytes, &cd_type);
194
195    if matches!(digest, _expected_digest) {
196      let new_body = Full::new(body_bytes).map_err(|never| match never {}).boxed();
197      let res = Response::from_parts(header, new_body);
198      Ok(res)
199    } else {
200      Err(HyperDigestError::InvalidContentDigest(
201        "Content-Digest verification failed".to_string(),
202      ))
203    }
204  }
205}
206
207async fn extract_content_digest(header_map: &http::HeaderMap) -> HyperDigestResult<(ContentDigestType, Vec<u8>)> {
208  let content_digest_header = header_map
209    .get(CONTENT_DIGEST_HEADER)
210    .ok_or(HyperDigestError::NoDigestHeader("No content-digest header".to_string()))?
211    .to_str()?;
212  let indexmap = sfv::Parser::new(content_digest_header)
213    .parse::<sfv::Dictionary>()
214    .map_err(|e| HyperDigestError::InvalidHeaderValue(e.to_string()))?;
215  if indexmap.len() != 1 {
216    return Err(HyperDigestError::InvalidHeaderValue(
217      "Content-Digest header should have only one value".to_string(),
218    ));
219  };
220  let (cd_type, cd) = indexmap.iter().next().unwrap();
221  let cd_type = ContentDigestType::from_str(cd_type.as_str())
222    .map_err(|e| HyperDigestError::InvalidHeaderValue(format!("Invalid Content-Digest type: {e}")))?;
223  if !matches!(
224    cd,
225    sfv::ListEntry::Item(sfv::Item {
226      bare_item: sfv::BareItem::ByteSequence(_),
227      ..
228    })
229  ) {
230    return Err(HyperDigestError::InvalidHeaderValue(
231      "Invalid Content-Digest value".to_string(),
232    ));
233  }
234
235  let cd = match cd {
236    sfv::ListEntry::Item(sfv::Item {
237      bare_item: sfv::BareItem::ByteSequence(cd),
238      ..
239    }) => cd,
240    _ => unreachable!(),
241  };
242  Ok((cd_type, cd.to_owned()))
243}
244
245/* --------------------------------------- */
#[cfg(test)]
mod tests {
  use super::*;

  /// JSON payload shared by every test.
  const BODY: &[u8] = b"{\"hello\": \"world\"}";
  /// Base64 SHA-256 digest of [`BODY`].
  const BODY_SHA256: &str = "X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE=";
  /// Base64 SHA-512 digest of [`BODY`].
  const BODY_SHA512: &str = "WZDPaVn/7XgHaAy8pmojAkGWoRx2UFChF41A2svX+TaPm+AbwAgBWnrIiYllu7BNNyealdVLvRwEmTHWXvJwew==";
  /// `Content-Digest` header value expected for [`BODY`] with SHA-256.
  const BODY_SHA256_HEADER: &str = "sha-256=:X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE=:";

  /// Builds a JSON GET request carrying [`BODY`], optionally with a raw `Content-Digest` header.
  fn sample_request(content_digest: Option<&str>) -> Request<Full<&'static [u8]>> {
    let mut builder = Request::builder()
      .method("GET")
      .uri("https://example.com/")
      .header("date", "Sun, 09 May 2021 18:30:00 GMT")
      .header("content-type", "application/json");
    if let Some(value) = content_digest {
      builder = builder.header(CONTENT_DIGEST_HEADER, value);
    }
    builder.body(Full::new(BODY)).unwrap()
  }

  #[tokio::test]
  async fn content_digest() {
    let (_body_bytes, digest) = Full::new(BODY)
      .into_bytes_with_digest(&ContentDigestType::Sha256)
      .await
      .unwrap();
    assert_eq!(digest, BODY_SHA256);

    let (_body_bytes, digest) = Full::new(BODY)
      .into_bytes_with_digest(&ContentDigestType::Sha512)
      .await
      .unwrap();
    assert_eq!(digest, BODY_SHA512);
  }

  #[tokio::test]
  async fn hyper_request_test() {
    let req = sample_request(None)
      .set_content_digest(&ContentDigestType::Sha256)
      .await
      .unwrap();

    assert!(req.headers().contains_key(CONTENT_DIGEST_HEADER));
    let digest = req.headers().get(CONTENT_DIGEST_HEADER).unwrap().to_str().unwrap();
    assert_eq!(digest, BODY_SHA256_HEADER);

    let verified = req.verify_content_digest().await;
    assert!(verified.is_ok());
  }

  #[tokio::test]
  async fn hyper_response_test() {
    let res = Response::builder()
      .status(200)
      .header("date", "Sun, 09 May 2021 18:30:00 GMT")
      .header("content-type", "application/json")
      .body(Full::new(BODY))
      .unwrap();
    let res = res.set_content_digest(&ContentDigestType::Sha256).await.unwrap();

    assert!(res.headers().contains_key(CONTENT_DIGEST_HEADER));
    let digest = res.headers().get(CONTENT_DIGEST_HEADER).unwrap().to_str().unwrap();
    assert_eq!(digest, BODY_SHA256_HEADER);

    let verified = res.verify_content_digest().await;
    assert!(verified.is_ok());
  }

  #[tokio::test]
  async fn verify_rejects_missing_header() {
    assert!(sample_request(None).verify_content_digest().await.is_err());
  }

  #[tokio::test]
  async fn verify_rejects_non_byte_sequence_value() {
    assert!(sample_request(Some("sha-256=42")).verify_content_digest().await.is_err());
  }

  #[tokio::test]
  async fn verify_rejects_multiple_digests() {
    let header = format!("sha-256=:{BODY_SHA256}:, sha-512=:{BODY_SHA512}:");
    assert!(sample_request(Some(header.as_str())).verify_content_digest().await.is_err());
  }
}