actix_hash/body_hash.rs
1use std::{
2 future::Future,
3 mem,
4 pin::Pin,
5 task::{ready, Context, Poll},
6};
7
8use actix_web::{dev, FromRequest, HttpRequest};
9use actix_web_lab::util::fork_request_payload;
10use digest::{generic_array::GenericArray, Digest};
11use futures_core::Stream as _;
12use pin_project_lite::pin_project;
13use tracing::trace;
14
/// Owned pieces of a [`BodyHash`] extractor, as returned by [`BodyHash::into_parts`].
pub struct BodyHashParts<T> {
    /// The value produced by the wrapped extractor.
    pub inner: T,

    /// The calculated body hash, as raw bytes.
    pub hash_bytes: Vec<u8>,
}
23
24/// Wraps an extractor and calculates a body checksum hash alongside.
25///
26/// If your extractor would usually be `T` and you want to create a hash of type `D` then you need
27/// to use `BodyHash<T, D>`. E.g., `BodyHash<String, Sha256>`.
28///
29/// Any hasher that implements [`Digest`] can be used. Type aliases for common hashing algorithms
30/// are available at the crate root.
31///
32/// # Errors
33/// This extractor produces no errors of its own and all errors from the underlying extractor are
34/// propagated correctly; for example, if the payload limits are exceeded.
35///
36/// # When Used On The Wrong Extractor
37/// Use on a non-body extractor is tolerated unless it is used after a different extractor that
38/// _takes_ the payload. In this case, the resulting hash will be as if an empty input was given to
39/// the hasher.
40///
41/// # Example
42/// ```
43/// use actix_web::{Responder, web};
44/// use actix_hash::BodyHash;
45/// use sha2::Sha256;
46///
47/// # type T = u64;
48/// async fn hash_payload(form: BodyHash<web::Json<T>, Sha256>) -> impl Responder {
49/// if !form.verify_slice(b"correct-signature") {
50/// // return unauthorized error
51/// }
52///
53/// "Ok"
54/// }
55/// ```
56#[derive(Debug, Clone)]
57pub struct BodyHash<T, D: Digest> {
58 inner: T,
59 hash: GenericArray<u8, D::OutputSize>,
60}
61
62impl<T, D: Digest> BodyHash<T, D> {
63 /// Returns hash slice.
64 pub fn hash(&self) -> &[u8] {
65 self.hash.as_slice()
66 }
67
68 /// Returns hash output size.
69 pub fn hash_size(&self) -> usize {
70 self.hash.len()
71 }
72
73 /// Verifies HMAC hash against provided `tag` using constant-time equality.
74 pub fn verify_slice(&self, tag: &[u8]) -> bool {
75 use subtle::ConstantTimeEq as _;
76 self.hash.ct_eq(tag).into()
77 }
78
79 /// Returns body type parts, including extracted body type, raw body bytes, and hash bytes.
80 pub fn into_parts(self) -> BodyHashParts<T> {
81 let hash = self.hash().to_vec();
82
83 BodyHashParts {
84 inner: self.inner,
85 hash_bytes: hash,
86 }
87 }
88}
89
90impl<T, D> FromRequest for BodyHash<T, D>
91where
92 T: FromRequest + 'static,
93 D: Digest + 'static,
94{
95 type Error = T::Error;
96 type Future = BodyHashFut<T, D>;
97
98 fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
99 if matches!(payload, dev::Payload::None) {
100 trace!("inner request payload is none");
101 BodyHashFut::PayloadNone {
102 inner_fut: T::from_request(req, payload),
103 hash: D::new().finalize(),
104 }
105 } else {
106 trace!("forking request payload");
107 let forked_payload = fork_request_payload(payload);
108
109 let inner_fut = T::from_request(req, payload);
110 let hasher = D::new();
111
112 BodyHashFut::Inner {
113 inner_fut,
114 hasher,
115 forked_payload,
116 }
117 }
118 }
119}
120
121pin_project! {
122 #[project = BodyHashFutProj]
123 pub enum BodyHashFut<T: FromRequest, D: Digest> {
124 PayloadNone {
125 #[pin]
126 inner_fut: T::Future,
127 hash: GenericArray<u8, D::OutputSize>,
128 },
129
130 Inner {
131 #[pin]
132 inner_fut: T::Future,
133 hasher: D,
134 forked_payload: dev::Payload,
135 },
136
137 InnerDone {
138 inner: Option<T>,
139 hasher: D,
140 forked_payload: dev::Payload,
141 }
142 }
143}
144
145impl<T: FromRequest, D: Digest> Future for BodyHashFut<T, D> {
146 type Output = Result<BodyHash<T, D>, T::Error>;
147
148 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
149 match self.as_mut().project() {
150 BodyHashFutProj::PayloadNone { inner_fut, hash } => {
151 let inner = ready!(inner_fut.poll(cx))?;
152 Poll::Ready(Ok(BodyHash {
153 inner,
154 hash: mem::take(hash),
155 }))
156 }
157
158 BodyHashFutProj::Inner {
159 inner_fut,
160 hasher,
161 mut forked_payload,
162 } => {
163 // poll original extractor
164 match inner_fut.poll(cx)? {
165 Poll::Ready(inner) => {
166 trace!("inner extractor complete");
167
168 let next = BodyHashFut::InnerDone {
169 inner: Some(inner),
170 hasher: mem::replace(hasher, D::new()),
171 forked_payload: mem::replace(forked_payload, dev::Payload::None),
172 };
173 self.set(next);
174
175 // re-enter poll in done state
176 self.poll(cx)
177 }
178 Poll::Pending => {
179 // drain forked payload
180 loop {
181 match Pin::new(&mut forked_payload).poll_next(cx) {
182 // update hasher with chunks
183 Poll::Ready(Some(Ok(chunk))) => hasher.update(&chunk),
184
185 Poll::Ready(None) => {
186 unreachable!(
187 "not possible to poll end of payload before inner stream \
188 completes"
189 )
190 }
191
192 // Ignore Pending because its possible the inner extractor never
193 // polls the payload stream and ignore errors because they will be
194 // propagated by original payload polls.
195 Poll::Ready(Some(Err(_))) | Poll::Pending => break,
196 }
197 }
198
199 Poll::Pending
200 }
201 }
202 }
203
204 BodyHashFutProj::InnerDone {
205 inner,
206 hasher,
207 forked_payload,
208 } => {
209 let mut pl = Pin::new(forked_payload);
210
211 // drain forked payload
212 loop {
213 match pl.as_mut().poll_next(cx) {
214 // update hasher with chunks
215 Poll::Ready(Some(Ok(chunk))) => hasher.update(&chunk),
216
217 // when drain is complete, finalize hash and return parts
218 Poll::Ready(None) => {
219 trace!("payload hashing complete");
220
221 let hasher = mem::replace(hasher, D::new());
222 let hash = hasher.finalize();
223
224 return Poll::Ready(Ok(BodyHash {
225 inner: inner.take().unwrap(),
226 hash,
227 }));
228 }
229
230 // Ignore Pending because its possible the inner extractor never polls the
231 // payload stream and ignore errors because they will be propagated by
232 // original payload polls
233 Poll::Ready(Some(Err(_))) | Poll::Pending => return Poll::Pending,
234 }
235 }
236 }
237 }
238 }
239}