use anyhow::Context;
use async_trait::async_trait;
use reqwest::header::WWW_AUTHENTICATE;
use reqwest::{Request, Response, StatusCode};
use reqwest_middleware::{Middleware, Next};
use crate::client::provider::PaymentProvider;
use crate::protocol::core::{format_authorization, parse_www_authenticate, AUTHORIZATION_HEADER};
pub struct PaymentMiddleware<P> {
provider: P,
}
impl<P> PaymentMiddleware<P> {
pub fn new(provider: P) -> Self {
Self { provider }
}
}
#[async_trait]
impl<P> Middleware for PaymentMiddleware<P>
where
P: PaymentProvider + 'static,
{
async fn handle(
&self,
req: Request,
extensions: &mut http_types::Extensions,
next: Next<'_>,
) -> reqwest_middleware::Result<Response> {
let retry_req = req.try_clone();
let resp = next.clone().run(req, extensions).await?;
if resp.status() != StatusCode::PAYMENT_REQUIRED {
return Ok(resp);
}
let retry_req = retry_req
.context("request could not be cloned for payment retry")
.map_err(reqwest_middleware::Error::Middleware)?;
let www_auth = resp
.headers()
.get(WWW_AUTHENTICATE)
.context("402 response missing WWW-Authenticate header")
.map_err(reqwest_middleware::Error::Middleware)?
.to_str()
.context("invalid WWW-Authenticate header")
.map_err(reqwest_middleware::Error::Middleware)?;
let challenge = parse_www_authenticate(www_auth)
.context("invalid challenge")
.map_err(reqwest_middleware::Error::Middleware)?;
let credential = self
.provider
.pay(&challenge)
.await
.context("payment failed")
.map_err(reqwest_middleware::Error::Middleware)?;
let auth_header = format_authorization(&credential)
.context("failed to format credential")
.map_err(reqwest_middleware::Error::Middleware)?;
let mut retry_req = retry_req;
retry_req.headers_mut().insert(
AUTHORIZATION_HEADER,
auth_header
.parse()
.context("invalid authorization header")
.map_err(reqwest_middleware::Error::Middleware)?,
);
next.run(retry_req, extensions).await
}
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Clone)]
struct MockProvider;
impl PaymentProvider for MockProvider {
fn supports(&self, _method: &str, _intent: &str) -> bool {
true
}
async fn pay(
&self,
_challenge: &crate::protocol::core::PaymentChallenge,
) -> Result<crate::protocol::core::PaymentCredential, crate::error::MppError> {
unimplemented!("mock provider")
}
}
#[test]
fn test_middleware_new() {
let _middleware = PaymentMiddleware::new(MockProvider);
}
}