http_signature_normalization_actix/digest/
middleware.rs1use crate::{Canceled, DefaultSpawner, Spawn};
4
5use super::{DigestPart, DigestVerify};
6use actix_web::{
7 body::MessageBody,
8 dev::{Payload, Service, ServiceRequest, ServiceResponse, Transform},
9 error::PayloadError,
10 http::{header::HeaderValue, StatusCode},
11 web, FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
12};
13use futures_core::{future::LocalBoxFuture, Stream};
14use std::{
15 future::{ready, Ready},
16 pin::Pin,
17 task::{Context, Poll},
18};
19use streem::{from_fn::Yielder, IntoStreamer};
20use tokio::sync::{mpsc, oneshot};
21use tracing::{debug, Span};
22use tracing_error::SpanTrace;
23
/// Marker value a handler can extract to require that the request body's
/// `Digest` header was checked and matched.
///
/// It is produced by the `FromRequest` impl in this module, which waits on the
/// signal sent by the payload-verifying stream.
#[derive(Clone, Copy, Debug)]
pub struct DigestVerified;
30
31#[derive(Clone, Debug)]
32pub struct VerifyDigest<T, Spawner = DefaultSpawner>(Spawner, bool, T);
46
/// Service produced by `VerifyDigest::new_transform`.
///
/// The positional fields are, in order: the wrapped service, the spawner, the
/// "header required" flag and the digest verifier.
#[doc(hidden)]
pub struct VerifyMiddleware<T, Spawner, S>(
    // wrapped inner service
    S,
    // spawner for blocking hash work
    Spawner,
    // whether a missing `Digest` header is an error
    bool,
    // digest verifier template, cloned per request
    T,
);
49
50#[derive(Debug, thiserror::Error)]
51#[error("Error verifying digest")]
52#[doc(hidden)]
53pub struct VerifyError {
54 context: String,
55 kind: VerifyErrorKind,
56}
57
58impl VerifyError {
59 fn new(span: &Span, kind: VerifyErrorKind) -> Self {
60 span.in_scope(|| VerifyError {
61 context: SpanTrace::capture().to_string(),
62 kind,
63 })
64 }
65}
66
67#[derive(Debug, thiserror::Error)]
68enum VerifyErrorKind {
69 #[error("Missing request extension")]
70 Extension,
71
72 #[error("Digest header missing")]
73 MissingDigest,
74
75 #[error("Digest header is empty")]
76 Empty,
77
78 #[error("Failed to verify digest")]
79 Verify,
80
81 #[error("Payload dropped. If this was unexpected, it could be that the payload isn't required in the route this middleware is guarding")]
82 Dropped,
83}
84
85struct RxStream<T>(mpsc::Receiver<T>);
86
87impl<T> Stream for RxStream<T> {
88 type Item = Result<T, PayloadError>;
89
90 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
91 Pin::new(&mut self.0).poll_recv(cx).map(|opt| opt.map(Ok))
92 }
93}
94
95impl<T> VerifyDigest<T>
96where
97 T: DigestVerify + Clone,
98{
99 pub fn new(verify_digest: T) -> Self {
101 VerifyDigest(DefaultSpawner, true, verify_digest)
102 }
103}
104
105impl<T, Spawner> VerifyDigest<T, Spawner>
106where
107 T: DigestVerify + Clone,
108{
109 pub fn spawner<NewSpawner>(self, spawner: NewSpawner) -> VerifyDigest<T, NewSpawner>
113 where
114 NewSpawner: Spawn,
115 {
116 VerifyDigest(spawner, self.1, self.2)
117 }
118
119 pub fn optional(self) -> Self {
124 VerifyDigest(self.0, false, self.2)
125 }
126}
127
128struct VerifiedReceiver {
129 rx: Option<oneshot::Receiver<()>>,
130}
131
132impl FromRequest for DigestVerified {
133 type Error = VerifyError;
134 type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
135
136 fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
137 let res = req
138 .extensions_mut()
139 .get_mut::<VerifiedReceiver>()
140 .and_then(|r| r.rx.take())
141 .ok_or_else(|| VerifyError::new(&Span::current(), VerifyErrorKind::Extension));
142
143 if res.is_err() {
144 debug!("Failed to fetch DigestVerified from request");
145 }
146
147 Box::pin(async move {
148 res?.await
149 .map_err(|_| VerifyError::new(&Span::current(), VerifyErrorKind::Dropped))
150 .map(|()| DigestVerified)
151 })
152 }
153}
154
155impl<T, Spawner, S, B> Transform<S, ServiceRequest> for VerifyDigest<T, Spawner>
156where
157 T: DigestVerify + Clone + Send + 'static,
158 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
159 S::Error: 'static,
160 B: MessageBody + 'static,
161 Spawner: Spawn + Clone + 'static,
162{
163 type Response = ServiceResponse<B>;
164 type Error = actix_web::Error;
165 type Transform = VerifyMiddleware<T, Spawner, S>;
166 type InitError = ();
167 type Future = Ready<Result<Self::Transform, Self::InitError>>;
168
169 fn new_transform(&self, service: S) -> Self::Future {
170 ready(Ok(VerifyMiddleware(
171 service,
172 self.0.clone(),
173 self.1,
174 self.2.clone(),
175 )))
176 }
177}
178
179impl<T, Spawner, S, B> Service<ServiceRequest> for VerifyMiddleware<T, Spawner, S>
180where
181 T: DigestVerify + Clone + Send + 'static,
182 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
183 S::Error: 'static,
184 B: MessageBody + 'static,
185 Spawner: Spawn + Clone + 'static,
186{
187 type Response = ServiceResponse<B>;
188 type Error = actix_web::Error;
189 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
190
191 fn poll_ready(&self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
192 self.0.poll_ready(cx)
193 }
194
195 fn call(&self, mut req: ServiceRequest) -> Self::Future {
196 let span = tracing::info_span!(
197 "Verify digest",
198 digest.required = tracing::field::display(&self.2),
199 );
200
201 if let Some(digest) = req.headers().get("Digest") {
202 let vec = match parse_digest(digest) {
203 Some(vec) => vec,
204 None => {
205 return Box::pin(ready(Err(
206 VerifyError::new(&span, VerifyErrorKind::Empty).into()
207 )));
208 }
209 };
210
211 let spawner = self.1.clone();
212 let digest = self.3.clone();
213 let (verify_tx, verify_rx) = oneshot::channel();
214
215 let payload = req.take_payload();
216 let payload: Pin<Box<dyn Stream<Item = Result<web::Bytes, PayloadError>> + 'static>> =
217 Box::pin(streem::try_from_fn(|yielder| async move {
218 verify_payload(yielder, spawner, vec, digest, payload, verify_tx).await
219 }));
220 req.set_payload(payload.into());
221
222 req.extensions_mut().insert(VerifiedReceiver {
223 rx: Some(verify_rx),
224 });
225
226 Box::pin(self.0.call(req))
227 } else if self.2 {
228 Box::pin(ready(Err(VerifyError::new(
229 &span,
230 VerifyErrorKind::MissingDigest,
231 )
232 .into())))
233 } else {
234 Box::pin(self.0.call(req))
235 }
236 }
237}
238
239fn canceled_error(error: Canceled) -> PayloadError {
240 PayloadError::Io(std::io::Error::new(std::io::ErrorKind::Other, error))
241}
242
243fn verified_error(error: VerifyError) -> PayloadError {
244 PayloadError::Io(std::io::Error::new(std::io::ErrorKind::Other, error))
245}
246
247async fn verify_payload<T, Spawner>(
248 yielder: Yielder<Result<web::Bytes, PayloadError>>,
249 spawner: Spawner,
250 vec: Vec<DigestPart>,
251 mut verify_digest: T,
252 payload: Payload,
253 verify_tx: oneshot::Sender<()>,
254) -> Result<(), PayloadError>
255where
256 T: DigestVerify + Clone + Send + 'static,
257 Spawner: Spawn,
258{
259 let mut payload = payload.into_streamer();
260
261 let mut error = None;
262
263 while let Some(bytes) = payload.try_next().await? {
264 if error.is_none() {
265 let bytes2 = bytes.clone();
266 let mut verify_digest2 = verify_digest.clone();
267
268 let task = spawner.spawn_blocking(move || {
269 verify_digest2.update(bytes2.as_ref());
270 Ok(verify_digest2) as Result<T, VerifyError>
271 });
272
273 yielder.yield_ok(bytes).await;
274
275 match task.await {
276 Ok(Ok(digest)) => {
277 verify_digest = digest;
278 }
279 Ok(Err(e)) => {
280 error = Some(verified_error(e));
281 }
282 Err(e) => {
283 error = Some(canceled_error(e));
284 }
285 }
286 } else {
287 yielder.yield_ok(bytes).await;
288 }
289 }
290
291 if let Some(error) = error {
292 return Err(error);
293 }
294
295 let verified = spawner
296 .spawn_blocking(move || Ok(verify_digest.verify(&vec)) as Result<_, VerifyError>)
297 .await
298 .map_err(canceled_error)?
299 .map_err(verified_error)?;
300
301 if verified {
302 if verify_tx.send(()).is_err() {
303 debug!("handler dropped");
304 }
305
306 Ok(())
307 } else {
308 Err(verified_error(VerifyError::new(
309 &Span::current(),
310 VerifyErrorKind::Verify,
311 )))
312 }
313}
314
315fn parse_digest(h: &HeaderValue) -> Option<Vec<DigestPart>> {
316 let h = h.to_str().ok()?.split(';').next()?;
317 let v: Vec<_> = h
318 .split(',')
319 .filter_map(|p| {
320 let mut iter = p.splitn(2, '=');
321 iter.next()
322 .and_then(|alg| iter.next().map(|value| (alg, value)))
323 })
324 .map(|(alg, value)| DigestPart {
325 algorithm: alg.to_owned(),
326 digest: value.to_owned(),
327 })
328 .collect();
329
330 if v.is_empty() {
331 None
332 } else {
333 Some(v)
334 }
335}
336
337impl ResponseError for VerifyError {
338 fn status_code(&self) -> StatusCode {
339 StatusCode::BAD_REQUEST
340 }
341
342 fn error_response(&self) -> HttpResponse {
343 HttpResponse::new(self.status_code())
344 }
345}