use std::fmt;
use actix_http::BoxedPayloadStream;
use actix_web::{Error, FromRequest, HttpRequest, dev, web::Bytes};
use derive_more::Display;
use futures_core::future::LocalBoxFuture;
use futures_util::{FutureExt as _, StreamExt as _, TryFutureExt as _};
use local_channel::mpsc;
use tokio::try_join;
use tracing::trace;
/// Defines how a signature is computed over an incoming request and its body.
///
/// The [`RequestSignature`] extractor drives a scheme through these steps, in order:
/// 1. [`init`](Self::init) is called once with the request.
/// 2. [`consume_chunk`](Self::consume_chunk) is called for each successfully read body chunk,
///    in the order the inner extractor reads them.
/// 3. [`finalize`](Self::finalize) consumes the scheme and produces the signature.
/// 4. [`verify`](Self::verify) receives the finished signature and may reject the request.
pub trait RequestSignatureScheme: Sized {
    /// Value produced once the whole body has been consumed.
    type Signature;

    /// Error returned by any step; converted into an Actix Web [`Error`] on failure.
    type Error: Into<Error>;

    /// Creates the scheme state for a request, before any body chunk is handed to it.
    fn init(req: &HttpRequest) -> impl Future<Output = Result<Self, Self::Error>>;

    /// Feeds the next chunk of the request body into the scheme.
    fn consume_chunk(
        &mut self,
        req: &HttpRequest,
        chunk: Bytes,
    ) -> impl Future<Output = Result<(), Self::Error>>;

    /// Consumes the scheme after the last chunk and produces the signature.
    fn finalize(
        self,
        req: &HttpRequest,
    ) -> impl Future<Output = Result<Self::Signature, Self::Error>>;

    /// Checks the finished signature against the request.
    ///
    /// The default implementation accepts every signature and returns it unchanged.
    #[inline]
    fn verify(
        signature: Self::Signature,
        req: &HttpRequest,
    ) -> Result<Self::Signature, Self::Error> {
        // The default performs no check, so the request is deliberately unused.
        let _ = req;
        Ok(signature)
    }
}
#[allow(missing_debug_implementations)]
#[derive(Clone)]
pub struct RequestSignature<T, S: RequestSignatureScheme> {
extractor: T,
signature: S::Signature,
}
impl<T, S: RequestSignatureScheme> RequestSignature<T, S> {
    /// Consumes the wrapper, returning the inner extractor's value and the signature.
    pub fn into_parts(self) -> (T, S::Signature) {
        let Self {
            extractor,
            signature,
        } = self;

        (extractor, signature)
    }
}
/// Errors that can occur while extracting a [`RequestSignature`].
///
/// Converts into an Actix Web error by forwarding to the wrapped error's own conversion.
#[derive(Display)]
#[non_exhaustive]
pub enum RequestSignatureError<T, S>
where
    T: FromRequest,
    T::Error: fmt::Debug + fmt::Display,
    S: RequestSignatureScheme,
    S::Error: fmt::Debug + fmt::Display,
{
    /// The inner extractor `T` failed.
    #[display("Inner extractor error: {_0}")]
    Extractor(T::Error),

    /// One of the signature scheme's steps (`init`, `consume_chunk`, `finalize` or `verify`)
    /// failed.
    #[display("Signature calculation error: {_0}")]
    Signature(S::Error),
}
// Implemented by hand so that only the wrapped error types need `Debug`, not `T` or `S`.
impl<T, S> fmt::Debug for RequestSignatureError<T, S>
where
    T: FromRequest,
    T::Error: fmt::Debug + fmt::Display,
    S: RequestSignatureScheme,
    S::Error: fmt::Debug + fmt::Display,
{
    /// Formats as a tuple named after the qualified variant, containing the inner error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (variant, inner) = match self {
            Self::Extractor(err) => ("RequestSignatureError::Extractor", err as &dyn fmt::Debug),
            Self::Signature(err) => ("RequestSignatureError::Signature", err as &dyn fmt::Debug),
        };

        f.debug_tuple(variant).field(inner).finish()
    }
}
impl<T, S> From<RequestSignatureError<T, S>> for actix_web::Error
where
    T: FromRequest,
    T::Error: fmt::Debug + fmt::Display,
    S: RequestSignatureScheme,
    S::Error: fmt::Debug + fmt::Display,
{
    /// Unwraps the variant and delegates to the inner error's own `Into<actix_web::Error>`
    /// conversion, so that error's response (e.g. its status code) is preserved.
    fn from(err: RequestSignatureError<T, S>) -> Self {
        match err {
            RequestSignatureError::Signature(scheme_err) => scheme_err.into(),
            RequestSignatureError::Extractor(extractor_err) => extractor_err.into(),
        }
    }
}
impl<T, S> FromRequest for RequestSignature<T, S>
where
    T: FromRequest + 'static,
    T::Error: fmt::Debug + fmt::Display,
    S: RequestSignatureScheme + 'static,
    S::Error: fmt::Debug + fmt::Display,
{
    type Error = RequestSignatureError<T, S>;
    type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;

    /// Extracts `T` from the request while computing the signature of the body bytes it reads.
    ///
    /// The request payload is wrapped in a proxy stream. That stream forwards a clone of every
    /// successfully read chunk over a local channel to a spawned task, which feeds the chunks to
    /// the scheme `S` and then finalizes it. The inner extraction and the signature task are
    /// driven concurrently with `try_join!`. On success the signature is passed through
    /// [`RequestSignatureScheme::verify`].
    ///
    /// # Errors
    /// - [`RequestSignatureError::Extractor`] if the inner extractor `T` fails.
    /// - [`RequestSignatureError::Signature`] if `init`, `consume_chunk`, `finalize` or `verify`
    ///   fails.
    ///
    /// # Panics
    /// Panics if the spawned signature task panics or is cancelled.
    fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
        let req = req.clone();
        let payload = payload.take();

        Box::pin(async move {
            let (tx, mut rx) = mpsc::channel();

            // Forward a clone of each chunk the inner extractor reads to the signature task.
            // Cloning `Bytes` is cheap (shared buffer), so the body data is not copied.
            let proxy_stream: BoxedPayloadStream = Box::pin(payload.inspect(move |res| {
                if let Ok(chunk) = res {
                    trace!("yielding {} byte chunk", chunk.len());

                    // The receiver only goes away once the signature task has stopped, e.g.
                    // after `consume_chunk` returned an error. That error is reported through
                    // `hash_fut`, so a failed send must not panic the body extractor.
                    if tx.send(chunk.clone()).is_err() {
                        trace!("signature task has stopped; not forwarding chunk");
                    }
                }
            }));

            trace!("creating proxy payload");
            let mut proxy_payload = dev::Payload::from(proxy_stream);
            let body_fut =
                T::from_request(&req, &mut proxy_payload).map_err(RequestSignatureError::Extractor);

            // `T::Future` cannot borrow the payload, so an extractor that reads the body has
            // already taken it. Dropping what is left releases the channel sender when `T` did
            // not take the payload. Otherwise the signature task would wait for chunks forever
            // and the request would never complete.
            drop(proxy_payload);

            trace!("initializing signature scheme");
            let mut sig_scheme = S::init(&req)
                .await
                .map_err(RequestSignatureError::Signature)?;

            let hash_fut = actix_web::rt::spawn({
                let req = req.clone();
                async move {
                    // `recv` yields `None` once the sender (owned by the body stream) is dropped.
                    while let Some(chunk) = rx.recv().await {
                        trace!("digesting chunk");
                        sig_scheme.consume_chunk(&req, chunk).await?;
                    }

                    trace!("finalizing signature");
                    sig_scheme.finalize(&req).await
                }
            })
            .map(|res| res.expect("signature task should not panic or be cancelled"))
            .map_err(RequestSignatureError::Signature);

            trace!("driving both futures");
            let (body, signature) = try_join!(body_fut, hash_fut)?;

            trace!("verifying signature");
            let signature = S::verify(signature, &req).map_err(RequestSignatureError::Signature)?;

            Ok(Self {
                extractor: body,
                signature,
            })
        })
    }
}
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use actix_web::{
        App,
        http::StatusCode,
        test,
        web::{self, Bytes},
    };
    use digest::{CtOutput, Digest as _};
    use hex_literal::hex;
    use sha2::Sha256;

    use super::*;
    use crate::extract::Json;

    /// Test scheme: a plain SHA-256 over the request's path-and-query followed by the body
    /// chunks. It never fails (`Error = Infallible`) and keeps the default `verify`.
    #[derive(Debug, Default)]
    struct JustHash(sha2::Sha256);

    impl RequestSignatureScheme for JustHash {
        type Signature = CtOutput<sha2::Sha256>;
        type Error = Infallible;

        // Seeds the hasher with the path and query (when present), so the digest covers the URI
        // as well as the body.
        async fn init(head: &HttpRequest) -> Result<Self, Self::Error> {
            let mut hasher = Sha256::new();

            if let Some(path) = head.uri().path_and_query() {
                hasher.update(path.as_str().as_bytes())
            }

            Ok(Self(hasher))
        }

        // Appends each body chunk to the running digest.
        async fn consume_chunk(
            &mut self,
            _req: &HttpRequest,
            chunk: Bytes,
        ) -> Result<(), Self::Error> {
            self.0.update(&chunk);
            Ok(())
        }

        // Wraps the final digest in `CtOutput` (the `digest` crate's output wrapper with a
        // constant-time equality check).
        async fn finalize(self, _req: &HttpRequest) -> Result<Self::Signature, Self::Error> {
            let hash = self.0.finalize();
            Ok(CtOutput::new(hash))
        }
    }

    /// Checks the returned digest for an empty body and for a small body on the same path.
    #[actix_web::test]
    async fn correctly_hashes_payload() {
        // The handler returns the raw signature bytes so the test can compare them directly.
        let app = test::init_service(App::new().route(
            "/service/path",
            web::get().to(|body: RequestSignature<Bytes, JustHash>| async move {
                let (_, sig) = body.into_parts();
                sig.into_bytes().to_vec()
            }),
        ))
        .await;

        // No body: the digest covers only "/service/path".
        let req = test::TestRequest::with_uri("/service/path").to_request();
        let body = test::call_and_read_body(&app, req).await;
        assert_eq!(
            body,
            hex!("a5441a3d ec265f82 3758d164 1188ab1d d1093972 45012a45 fa66df70 32d02177")
                .as_ref()
        );

        // Body "abc": the digest covers "/service/path" followed by "abc".
        let req = test::TestRequest::with_uri("/service/path")
            .set_payload("abc")
            .to_request();
        let body = test::call_and_read_body(&app, req).await;
        assert_eq!(
            body,
            hex!("555290a8 9e75260d fb0afead 2d5d3d70 f058c85d 1ff98bf3 06807301 7ce4c847")
                .as_ref()
        );
    }

    /// Checks that inner extractor failures surface as that extractor's error responses.
    #[actix_web::test]
    async fn respects_inner_extractor_errors() {
        // NOTE(review): `Json<u64, 4>` appears to be the crate's `Json` with a 4-byte body
        // limit; confirm against `crate::extract::Json`.
        let app = test::init_service(App::new().route(
            "/",
            web::get().to(
                |body: RequestSignature<Json<u64, 4>, JustHash>| async move {
                    let (_, sig) = body.into_parts();
                    sig.into_bytes().to_vec()
                },
            ),
        ))
        .await;

        // `1234` serializes to a 4-byte body, within the limit, so the handler returns the
        // digest.
        let req = test::TestRequest::default().set_json(1234).to_request();
        let body = test::call_and_read_body(&app, req).await;
        assert_eq!(
            body,
            hex!("4f373f6c cadfaba3 1a32cf52 04cf3db9 367609ee 6a7d7113 8e4f28ef 7c1a87a9")
                .as_ref()
        );

        // No JSON body or content type: the inner extractor's 406 is expected to pass through.
        let req = test::TestRequest::default().to_request();
        let body = test::call_service(&app, req).await;
        assert_eq!(body.status(), StatusCode::NOT_ACCEPTABLE);

        // `12345` is 5 bytes, over the limit: the inner extractor's 413 is expected.
        let req = test::TestRequest::default().set_json(12345).to_request();
        let body = test::call_service(&app, req).await;
        assert_eq!(body.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}