http_signature_normalization_actix/digest/middleware.rs

1//! Types for setting up Digest middleware verification
2
3use crate::{Canceled, DefaultSpawner, Spawn};
4
5use super::{DigestPart, DigestVerify};
6use actix_web::{
7    body::MessageBody,
8    dev::{Payload, Service, ServiceRequest, ServiceResponse, Transform},
9    error::PayloadError,
10    http::{header::HeaderValue, StatusCode},
11    web, FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
12};
13use futures_core::{future::LocalBoxFuture, Stream};
14use std::{
15    future::{ready, Ready},
16    pin::Pin,
17    task::{Context, Poll},
18};
19use streem::{from_fn::Yielder, IntoStreamer};
20use tokio::sync::{mpsc, oneshot};
21use tracing::{debug, Span};
22use tracing_error::SpanTrace;
23
/// Route-handler guard proving that the request body matched its `Digest` header.
///
/// Extracting this type waits for the body to be fully read and verified. It is only needed when
/// the [`VerifyDigest`] middleware has been made optional; in required mode, requests lacking a
/// `Digest` header are already rejected by the middleware.
#[derive(Copy, Clone, Debug)]
pub struct DigestVerified;
30
31#[derive(Clone, Debug)]
32/// The VerifyDigest middleware
33///
34/// ```rust,ignore
35/// let middleware = VerifyDigest::new(MyVerify::new())
36///     .optional();
37///
38/// HttpServer::new(move || {
39///     App::new()
40///         .wrap(middleware.clone())
41///         .route("/protected", web::post().to(|_: DigestVerified| "Verified Digest Header"))
42///         .route("/unprotected", web::post().to(|| "No verification required"))
43/// })
44/// ```
45pub struct VerifyDigest<T, Spawner = DefaultSpawner>(Spawner, bool, T);
46
/// Service produced by [`VerifyDigest`] wrapping the inner service.
///
/// Positional fields: the wrapped service, the spawner, the "digest required" flag, and the
/// verifier that each request clones.
#[doc(hidden)]
pub struct VerifyMiddleware<T, Spawner, S>(
    S,       // wrapped inner service
    Spawner, // runs the hashing / verification closures
    bool,    // true when a missing `Digest` header must be rejected
    T,       // verifier template, cloned per request
);
49
50#[derive(Debug, thiserror::Error)]
51#[error("Error verifying digest")]
52#[doc(hidden)]
53pub struct VerifyError {
54    context: String,
55    kind: VerifyErrorKind,
56}
57
58impl VerifyError {
59    fn new(span: &Span, kind: VerifyErrorKind) -> Self {
60        span.in_scope(|| VerifyError {
61            context: SpanTrace::capture().to_string(),
62            kind,
63        })
64    }
65}
66
67#[derive(Debug, thiserror::Error)]
68enum VerifyErrorKind {
69    #[error("Missing request extension")]
70    Extension,
71
72    #[error("Digest header missing")]
73    MissingDigest,
74
75    #[error("Digest header is empty")]
76    Empty,
77
78    #[error("Failed to verify digest")]
79    Verify,
80
81    #[error("Payload dropped. If this was unexpected, it could be that the payload isn't required in the route this middleware is guarding")]
82    Dropped,
83}
84
85struct RxStream<T>(mpsc::Receiver<T>);
86
87impl<T> Stream for RxStream<T> {
88    type Item = Result<T, PayloadError>;
89
90    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
91        Pin::new(&mut self.0).poll_recv(cx).map(|opt| opt.map(Ok))
92    }
93}
94
95impl<T> VerifyDigest<T>
96where
97    T: DigestVerify + Clone,
98{
99    /// Produce a new VerifyDigest with a user-provided [`Digestverify`] type
100    pub fn new(verify_digest: T) -> Self {
101        VerifyDigest(DefaultSpawner, true, verify_digest)
102    }
103}
104
105impl<T, Spawner> VerifyDigest<T, Spawner>
106where
107    T: DigestVerify + Clone,
108{
109    /// Set the spawner used for verifying bytes in the request
110    ///
111    /// By default this value uses `actix_web::web::block`
112    pub fn spawner<NewSpawner>(self, spawner: NewSpawner) -> VerifyDigest<T, NewSpawner>
113    where
114        NewSpawner: Spawn,
115    {
116        VerifyDigest(spawner, self.1, self.2)
117    }
118
119    /// Mark verifying the Digest as optional
120    ///
121    /// If a digest is present in the request, it will be verified, but it is not required to be
122    /// present
123    pub fn optional(self) -> Self {
124        VerifyDigest(self.0, false, self.2)
125    }
126}
127
128struct VerifiedReceiver {
129    rx: Option<oneshot::Receiver<()>>,
130}
131
132impl FromRequest for DigestVerified {
133    type Error = VerifyError;
134    type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
135
136    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
137        let res = req
138            .extensions_mut()
139            .get_mut::<VerifiedReceiver>()
140            .and_then(|r| r.rx.take())
141            .ok_or_else(|| VerifyError::new(&Span::current(), VerifyErrorKind::Extension));
142
143        if res.is_err() {
144            debug!("Failed to fetch DigestVerified from request");
145        }
146
147        Box::pin(async move {
148            res?.await
149                .map_err(|_| VerifyError::new(&Span::current(), VerifyErrorKind::Dropped))
150                .map(|()| DigestVerified)
151        })
152    }
153}
154
155impl<T, Spawner, S, B> Transform<S, ServiceRequest> for VerifyDigest<T, Spawner>
156where
157    T: DigestVerify + Clone + Send + 'static,
158    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
159    S::Error: 'static,
160    B: MessageBody + 'static,
161    Spawner: Spawn + Clone + 'static,
162{
163    type Response = ServiceResponse<B>;
164    type Error = actix_web::Error;
165    type Transform = VerifyMiddleware<T, Spawner, S>;
166    type InitError = ();
167    type Future = Ready<Result<Self::Transform, Self::InitError>>;
168
169    fn new_transform(&self, service: S) -> Self::Future {
170        ready(Ok(VerifyMiddleware(
171            service,
172            self.0.clone(),
173            self.1,
174            self.2.clone(),
175        )))
176    }
177}
178
179impl<T, Spawner, S, B> Service<ServiceRequest> for VerifyMiddleware<T, Spawner, S>
180where
181    T: DigestVerify + Clone + Send + 'static,
182    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
183    S::Error: 'static,
184    B: MessageBody + 'static,
185    Spawner: Spawn + Clone + 'static,
186{
187    type Response = ServiceResponse<B>;
188    type Error = actix_web::Error;
189    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
190
191    fn poll_ready(&self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
192        self.0.poll_ready(cx)
193    }
194
195    fn call(&self, mut req: ServiceRequest) -> Self::Future {
196        let span = tracing::info_span!(
197            "Verify digest",
198            digest.required = tracing::field::display(&self.2),
199        );
200
201        if let Some(digest) = req.headers().get("Digest") {
202            let vec = match parse_digest(digest) {
203                Some(vec) => vec,
204                None => {
205                    return Box::pin(ready(Err(
206                        VerifyError::new(&span, VerifyErrorKind::Empty).into()
207                    )));
208                }
209            };
210
211            let spawner = self.1.clone();
212            let digest = self.3.clone();
213            let (verify_tx, verify_rx) = oneshot::channel();
214
215            let payload = req.take_payload();
216            let payload: Pin<Box<dyn Stream<Item = Result<web::Bytes, PayloadError>> + 'static>> =
217                Box::pin(streem::try_from_fn(|yielder| async move {
218                    verify_payload(yielder, spawner, vec, digest, payload, verify_tx).await
219                }));
220            req.set_payload(payload.into());
221
222            req.extensions_mut().insert(VerifiedReceiver {
223                rx: Some(verify_rx),
224            });
225
226            Box::pin(self.0.call(req))
227        } else if self.2 {
228            Box::pin(ready(Err(VerifyError::new(
229                &span,
230                VerifyErrorKind::MissingDigest,
231            )
232            .into())))
233        } else {
234            Box::pin(self.0.call(req))
235        }
236    }
237}
238
239fn canceled_error(error: Canceled) -> PayloadError {
240    PayloadError::Io(std::io::Error::new(std::io::ErrorKind::Other, error))
241}
242
243fn verified_error(error: VerifyError) -> PayloadError {
244    PayloadError::Io(std::io::Error::new(std::io::ErrorKind::Other, error))
245}
246
247async fn verify_payload<T, Spawner>(
248    yielder: Yielder<Result<web::Bytes, PayloadError>>,
249    spawner: Spawner,
250    vec: Vec<DigestPart>,
251    mut verify_digest: T,
252    payload: Payload,
253    verify_tx: oneshot::Sender<()>,
254) -> Result<(), PayloadError>
255where
256    T: DigestVerify + Clone + Send + 'static,
257    Spawner: Spawn,
258{
259    let mut payload = payload.into_streamer();
260
261    let mut error = None;
262
263    while let Some(bytes) = payload.try_next().await? {
264        if error.is_none() {
265            let bytes2 = bytes.clone();
266            let mut verify_digest2 = verify_digest.clone();
267
268            let task = spawner.spawn_blocking(move || {
269                verify_digest2.update(bytes2.as_ref());
270                Ok(verify_digest2) as Result<T, VerifyError>
271            });
272
273            yielder.yield_ok(bytes).await;
274
275            match task.await {
276                Ok(Ok(digest)) => {
277                    verify_digest = digest;
278                }
279                Ok(Err(e)) => {
280                    error = Some(verified_error(e));
281                }
282                Err(e) => {
283                    error = Some(canceled_error(e));
284                }
285            }
286        } else {
287            yielder.yield_ok(bytes).await;
288        }
289    }
290
291    if let Some(error) = error {
292        return Err(error);
293    }
294
295    let verified = spawner
296        .spawn_blocking(move || Ok(verify_digest.verify(&vec)) as Result<_, VerifyError>)
297        .await
298        .map_err(canceled_error)?
299        .map_err(verified_error)?;
300
301    if verified {
302        if verify_tx.send(()).is_err() {
303            debug!("handler dropped");
304        }
305
306        Ok(())
307    } else {
308        Err(verified_error(VerifyError::new(
309            &Span::current(),
310            VerifyErrorKind::Verify,
311        )))
312    }
313}
314
315fn parse_digest(h: &HeaderValue) -> Option<Vec<DigestPart>> {
316    let h = h.to_str().ok()?.split(';').next()?;
317    let v: Vec<_> = h
318        .split(',')
319        .filter_map(|p| {
320            let mut iter = p.splitn(2, '=');
321            iter.next()
322                .and_then(|alg| iter.next().map(|value| (alg, value)))
323        })
324        .map(|(alg, value)| DigestPart {
325            algorithm: alg.to_owned(),
326            digest: value.to_owned(),
327        })
328        .collect();
329
330    if v.is_empty() {
331        None
332    } else {
333        Some(v)
334    }
335}
336
337impl ResponseError for VerifyError {
338    fn status_code(&self) -> StatusCode {
339        StatusCode::BAD_REQUEST
340    }
341
342    fn error_response(&self) -> HttpResponse {
343        HttpResponse::new(self.status_code())
344    }
345}