http-signature-normalization-actix 0.3.0

An HTTP Signatures library that leaves the signing to you
Documentation
//! Types for setting up Digest middleware verification

use super::{DigestPart, DigestVerify};
use actix_web::{
    dev::{MessageBody, Payload, Service, ServiceRequest, ServiceResponse, Transform},
    error::PayloadError,
    http::{header::HeaderValue, StatusCode},
    FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
};
use bytes::{Bytes, BytesMut};
use futures::{
    channel::mpsc,
    future::{err, ok, ready, Ready},
    Stream, StreamExt,
};
use log::{debug, warn};
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

/// Marker extractor showing that the [`VerifyDigest`] middleware handled this request's `Digest`
/// header.
///
/// Take it as a handler argument to guard a route: extraction fails with a `400 Bad Request` when
/// the middleware did not mark the request. It is only needed when the middleware has been made
/// optional via [`VerifyDigest::optional`].
#[derive(Clone, Copy, Debug)]
pub struct DigestVerified;

/// The VerifyDigest middleware
///
/// Checks the request body against the request's `Digest` header using a user-provided
/// [`DigestVerify`] implementation. By default a `Digest` header is required; after calling
/// [`VerifyDigest::optional`] it is only verified when present, and handlers that still need a
/// verified body can take [`DigestVerified`] as an argument.
///
/// ```rust,ignore
/// let middleware = VerifyDigest::new(MyVerify::new())
///     .optional();
///
/// HttpServer::new(move || {
///     App::new()
///         .wrap(middleware.clone())
///         .route("/protected", web::post().to(|_: DigestVerified| "Verified Digest Header"))
///         .route("/unprotected", web::post().to(|| "No verification required"))
/// })
/// ```
// Clone is derived because the documented usage clones the middleware into every `App` built by
// the server factory. Debug is derived for diagnostics. Both impls exist only when `T` implements
// the corresponding trait.
// Fields: 0 = whether a Digest header is required, 1 = the digest verifier.
#[derive(Clone, Debug)]
pub struct VerifyDigest<T>(bool, T);

/// Service produced by [`VerifyDigest`]'s `Transform` impl; it wraps the inner service.
// Fields: 0 = the wrapped inner service, 1 = whether a Digest header is required
// (true unless `VerifyDigest::optional` was called), 2 = the digest verifier cloned from
// the VerifyDigest that created this service.
#[doc(hidden)]
pub struct VerifyMiddleware<T, S>(S, bool, T);

/// Error produced by every digest-verification failure in this module.
///
/// This covers a missing required `Digest` header, an unparseable header, a body that does not
/// match the digest, a dropped downstream payload, and a missing [`DigestVerified`] marker when
/// extracting it. It is rendered as an empty `400 Bad Request` response.
#[derive(Debug, thiserror::Error)]
#[error("Error verifying digest")]
#[doc(hidden)]
pub struct VerifyError;

impl<T> VerifyDigest<T>
where
    T: DigestVerify + Clone,
{
    /// Produce a new VerifyDigest with a user-provided [`DigestVerify`] type
    ///
    /// The returned middleware is in required mode: every request it guards must carry a
    /// `Digest` header, otherwise it is rejected with a `400 Bad Request`.
    pub fn new(verify_digest: T) -> Self {
        VerifyDigest(true, verify_digest)
    }

    /// Mark verifying the Digest as optional
    ///
    /// If a digest is present in the request, it will be verified, but it is not required to be
    /// present. Routes that must still insist on a verified body can take [`DigestVerified`] as a
    /// handler argument.
    pub fn optional(self) -> Self {
        VerifyDigest(false, self.1)
    }
}

impl FromRequest for DigestVerified {
    type Error = VerifyError;
    type Future = Ready<Result<Self, Self::Error>>;
    type Config = ();

    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
        let res = req.extensions().get::<Self>().copied().ok_or(VerifyError);

        if res.is_err() {
            debug!("Failed to fetch DigestVerified from request");
        }

        ready(res)
    }
}

impl<T, S, B> Transform<S> for VerifyDigest<T>
where
    T: DigestVerify + Clone + 'static,
    S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>
        + 'static,
    S::Error: 'static,
    B: MessageBody + 'static,
{
    type Request = ServiceRequest;
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type Transform = VerifyMiddleware<T, S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ok(VerifyMiddleware(service, self.0, self.1.clone()))
    }
}

// Unsized future trait object resolving to `Result<T, E>`; `VerifyMiddleware` pins and boxes it
// to form its `Service::Future` type.
type FutResult<T, E> = dyn Future<Output = Result<T, E>>;

impl<T, S, B> Service for VerifyMiddleware<T, S>
where
    T: DigestVerify + Clone + 'static,
    S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>
        + 'static,
    S::Error: 'static,
    B: MessageBody + 'static,
{
    type Request = ServiceRequest;
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type Future = Pin<Box<FutResult<Self::Response, Self::Error>>>;

    /// Readiness is entirely delegated to the wrapped service.
    fn poll_ready(&mut self, ctx: &mut Context) -> Poll<Result<(), Self::Error>> {
        self.0.poll_ready(ctx)
    }

    /// Verify the body against the `Digest` header (if any) before letting the inner service's
    /// response future run.
    fn call(&mut self, mut req: ServiceRequest) -> Self::Future {
        // No header: reject in required mode, otherwise pass the request through untouched.
        let parsed = match req.headers().get("Digest") {
            Some(header) => parse_digest(header),
            None if self.1 => return Box::pin(err(VerifyError.into())),
            None => return Box::pin(self.0.call(req)),
        };

        let parts = match parsed {
            Some(parts) => parts,
            None => {
                warn!("Digest header could not be parsed");
                return Box::pin(err(VerifyError.into()));
            }
        };

        // The verifier consumes the real body; the inner service instead reads from a channel
        // that only receives the bytes after they have matched the digest.
        let original_body = req.take_payload();
        let (sender, receiver) = mpsc::channel(1);
        let verification = verify_payload(parts, self.2.clone(), original_body, sender);

        let checked_body: Pin<Box<dyn Stream<Item = Result<Bytes, PayloadError>> + 'static>> =
            Box::pin(receiver.map(Ok));
        req.set_payload(checked_body.into());
        // Marked before calling the inner service so extractors invoked from `call` can see it.
        req.extensions_mut().insert(DigestVerified);

        let response = self.0.call(req);

        // The inner response future is only driven once verification has succeeded.
        Box::pin(async move {
            verification.await?;
            response.await
        })
    }
}

/// Buffer the whole request body, check it against the parsed `Digest` header parts, and on
/// success forward the buffered bytes to the inner service through `tx`.
///
/// # Errors
///
/// Returns the payload's own error if reading a chunk fails. Returns [`VerifyError`] if the
/// receiving side of `tx` has been dropped while reading, if the digest does not match the body,
/// or if the verified bytes cannot be sent.
async fn verify_payload<T>(
    vec: Vec<DigestPart>,
    mut verify_digest: T,
    mut payload: Payload,
    mut tx: mpsc::Sender<Bytes>,
) -> Result<(), actix_web::Error>
where
    T: DigestVerify + Clone + 'static,
{
    let mut output_bytes = BytesMut::new();

    while let Some(res) = payload.next().await {
        let bytes = res?;
        // Copy the chunk in one go; `extend(bytes)` would go through `IntoIterator<Item = u8>`
        // and push it byte by byte.
        output_bytes.extend_from_slice(&bytes);

        if tx.is_closed() {
            warn!("Payload dropped. If this was unexpected, it could be that the payload isn't required in the route this middleware is guarding");
            return Err(VerifyError.into());
        }
    }

    let bytes = output_bytes.freeze();

    if verify_digest.verify(&vec, &bytes[..]) {
        tx.try_send(bytes).map_err(|_| VerifyError.into())
    } else {
        warn!("Digest could not be verified");
        Err(VerifyError.into())
    }
}

/// Parse a `Digest` header value into its `algorithm=digest` parts.
///
/// Only the text before the first `;` is considered. Each comma-separated entry is split at its
/// first `=`, so `=` padding inside a base64 digest is kept. Surrounding whitespace is trimmed
/// from both halves, so the list form `SHA-256=abc=, MD5=def=` yields the algorithm `MD5` rather
/// than ` MD5`. Entries without an `=` are skipped.
///
/// Returns `None` if the header value is not visible ASCII or no entry could be parsed.
fn parse_digest(h: &HeaderValue) -> Option<Vec<DigestPart>> {
    let h = h.to_str().ok()?.split(';').next()?;
    let v: Vec<_> = h
        .split(',')
        .filter_map(|p| {
            let mut iter = p.splitn(2, '=');
            let alg = iter.next()?.trim();
            let value = iter.next()?.trim();
            Some(DigestPart {
                algorithm: alg.to_owned(),
                digest: value.to_owned(),
            })
        })
        .collect();

    if v.is_empty() {
        None
    } else {
        Some(v)
    }
}

impl ResponseError for VerifyError {
    /// Every digest verification failure is reported as `400 Bad Request`.
    fn status_code(&self) -> StatusCode {
        StatusCode::BAD_REQUEST
    }

    /// An empty-bodied response carrying [`ResponseError::status_code`].
    fn error_response(&self) -> HttpResponse {
        HttpResponse::build(self.status_code()).finish()
    }
}