use std::sync::Arc;
use http::{Extensions, HeaderMap, StatusCode};
use r402::hooks::{FailureRecovery, HookDecision};
use r402::proto;
use r402::proto::Base64Bytes;
use r402::proto::v2;
use r402::scheme::{
ClientError, FirstMatch, PaymentCandidate, PaymentPolicy, PaymentSelector, SchemeClient,
};
use reqwest::{Request, Response};
use reqwest_middleware as rqm;
#[cfg(feature = "telemetry")]
use tracing::{debug, info, instrument, trace};
use super::hooks::{ClientHooks, PaymentCreationContext};
/// `reqwest-middleware` client that answers `402 Payment Required` responses
/// by building a payment header and replaying the request.
///
/// `TSelector` decides which of the acceptable payment candidates is signed.
#[allow(
missing_debug_implementations,
reason = "ClientSchemes contains dyn trait objects"
)]
pub struct X402Client<TSelector> {
/// Registered scheme clients; each is asked which payment options it can satisfy.
schemes: ClientSchemes,
/// Picks a single candidate out of the ones left after the policies run.
selector: TSelector,
/// Candidate filters, applied in registration order.
policies: Vec<Arc<dyn PaymentPolicy>>,
/// Lifecycle hooks invoked before, after, and on failure of payment creation.
hooks: Arc<[Arc<dyn ClientHooks>]>,
}
impl X402Client<FirstMatch> {
    /// Creates an empty client that uses the [`FirstMatch`] selector.
    ///
    /// Identical to calling `Default::default()`; add payment schemes with
    /// `register` before sending requests.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
impl Default for X402Client<FirstMatch> {
fn default() -> Self {
Self {
schemes: ClientSchemes::default(),
selector: FirstMatch,
policies: Vec::new(),
hooks: Arc::from([]),
}
}
}
impl<TSelector> X402Client<TSelector> {
    /// Registers a payment scheme client.
    ///
    /// When a 402 arrives, every registered scheme is asked, in registration
    /// order, which of the server's payment options it can satisfy.
    #[must_use]
    pub fn register<S>(mut self, scheme: S) -> Self
    where
        S: SchemeClient + 'static,
    {
        self.schemes.push(scheme);
        self
    }
    /// Replaces the selector that picks one candidate among those surviving
    /// the policies. Schemes, policies, and hooks are carried over unchanged.
    #[must_use]
    pub fn with_selector<P: PaymentSelector + 'static>(self, selector: P) -> X402Client<P> {
        X402Client {
            selector,
            schemes: self.schemes,
            policies: self.policies,
            hooks: self.hooks,
        }
    }
    /// Appends a candidate policy.
    ///
    /// Policies run in the order they were added, each receiving the
    /// candidates left by the previous one.
    #[must_use]
    pub fn with_policy<P: PaymentPolicy + 'static>(mut self, policy: P) -> Self {
        self.policies.push(Arc::new(policy));
        self
    }
    /// Appends a lifecycle hook; hooks are invoked in the order they were added.
    #[must_use]
    pub fn with_hook(mut self, hook: impl ClientHooks + 'static) -> Self {
        // The hook list is stored as an immutable shared slice, so appending
        // rebuilds it from a copy of the existing `Arc`s.
        let mut hooks = (*self.hooks).to_vec();
        hooks.push(Arc::new(hook));
        self.hooks = Arc::from(hooks);
        self
    }
}
impl<TSelector> X402Client<TSelector>
where
TSelector: PaymentSelector,
{
#[cfg_attr(
feature = "telemetry",
instrument(name = "x402.reqwest.make_payment_headers", skip_all, err)
)]
pub async fn make_payment_headers(&self, res: Response) -> Result<HeaderMap, ClientError> {
let payment_required = parse_payment_required(res)
.await
.ok_or_else(|| ClientError::ParseError("Invalid 402 response".to_owned()))?;
let hook_ctx = PaymentCreationContext {
payment_required: payment_required.clone(),
};
for hook in self.hooks.iter() {
if let HookDecision::Abort { reason, .. } =
hook.before_payment_creation(&hook_ctx).await
{
return Err(ClientError::ParseError(reason));
}
}
let creation_result = self.create_payment_headers_inner(&payment_required).await;
match creation_result {
Ok(headers) => {
for hook in self.hooks.iter() {
hook.after_payment_creation(&hook_ctx, &headers).await;
}
Ok(headers)
}
Err(err) => {
let err_msg = err.to_string();
for hook in self.hooks.iter() {
if let FailureRecovery::Recovered(headers) =
hook.on_payment_creation_failure(&hook_ctx, &err_msg).await
{
return Ok(headers);
}
}
Err(err)
}
}
}
async fn create_payment_headers_inner(
&self,
payment_required: &proto::PaymentRequired,
) -> Result<HeaderMap, ClientError> {
let candidates = self.schemes.candidates(payment_required);
let mut filtered: Vec<&PaymentCandidate> = candidates.iter().collect();
for policy in &self.policies {
filtered = policy.apply(filtered);
if filtered.is_empty() {
return Err(ClientError::NoMatchingPaymentOption);
}
}
let selected = self
.selector
.select(&filtered)
.ok_or(ClientError::NoMatchingPaymentOption)?;
#[cfg(feature = "telemetry")]
debug!(
scheme = %selected.scheme,
chain_id = %selected.chain_id,
"Selected payment scheme"
);
let signed_payload = selected.sign().await?;
let headers = {
let mut headers = HeaderMap::new();
#[allow(
clippy::expect_used,
reason = "base64-encoded payload is always valid ASCII header"
)]
headers.insert(
"Payment-Signature",
signed_payload
.parse()
.expect("signed payload is valid header value"),
);
headers
};
Ok(headers)
}
}
/// Ordered collection of registered payment scheme clients.
#[derive(Default)]
#[allow(
    missing_debug_implementations,
    reason = "dyn trait objects do not impl Debug"
)]
pub(super) struct ClientSchemes(Vec<Arc<dyn SchemeClient>>);
impl ClientSchemes {
    /// Appends `client`, storing it as a shared trait object.
    pub(super) fn push<T: SchemeClient + 'static>(&mut self, client: T) {
        let shared: Arc<dyn SchemeClient> = Arc::new(client);
        self.0.push(shared);
    }
    /// Collects every candidate that the registered clients accept for
    /// `payment_required`, in registration order.
    #[must_use]
    pub(super) fn candidates(
        &self,
        payment_required: &proto::PaymentRequired,
    ) -> Vec<PaymentCandidate> {
        self.0
            .iter()
            .flat_map(|scheme| scheme.accept(payment_required))
            .collect()
    }
}
/// Forwards `req` to the rest of the middleware chain.
///
/// This is a separate function so the downstream call gets its own tracing
/// span when the `telemetry` feature is enabled.
#[cfg_attr(
    feature = "telemetry",
    instrument(name = "x402.reqwest.next", skip_all)
)]
async fn run_next(
    next: rqm::Next<'_>,
    req: Request,
    extensions: &mut Extensions,
) -> rqm::Result<Response> {
    let downstream = next.run(req, extensions);
    downstream.await
}
#[async_trait::async_trait]
impl<TSelector> rqm::Middleware for X402Client<TSelector>
where
TSelector: PaymentSelector + Send + Sync + 'static,
{
#[cfg_attr(
feature = "telemetry",
instrument(name = "x402.reqwest.handle", skip_all, err)
)]
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: rqm::Next<'_>,
) -> rqm::Result<Response> {
let retry_req = req.try_clone();
let res = run_next(next.clone(), req, extensions).await?;
if res.status() != StatusCode::PAYMENT_REQUIRED {
#[cfg(feature = "telemetry")]
trace!(status = ?res.status(), "No payment required, returning response");
return Ok(res);
}
#[cfg(feature = "telemetry")]
info!(url = ?res.url(), "Received 402 Payment Required, processing payment");
let Some(mut retry) = retry_req else {
#[cfg(feature = "telemetry")]
tracing::warn!("Cannot auto-retry 402: request body not cloneable, returning raw 402");
return Ok(res);
};
let headers = self
.make_payment_headers(res)
.await
.map_err(|e| rqm::Error::Middleware(e.into()))?;
retry.headers_mut().extend(headers);
#[cfg(feature = "telemetry")]
trace!(url = ?retry.url(), "Retrying request with payment headers");
run_next(next, retry, extensions).await
}
}
/// Extracts the payment requirements from a `402 Payment Required` response.
///
/// The base64-encoded JSON in the `Payment-Required` header is tried first.
/// If that is missing or malformed, the response body is read and parsed as
/// JSON. Returns `None` when neither source yields a `v2::PaymentRequired`.
#[cfg_attr(
    feature = "telemetry",
    instrument(name = "x402.reqwest.parse_payment_required", skip(response))
)]
pub async fn parse_payment_required(response: Response) -> Option<proto::PaymentRequired> {
    let from_header = response
        .headers()
        .get("Payment-Required")
        .map(|value| Base64Bytes::from(value.as_bytes()))
        .and_then(|encoded| encoded.decode().ok())
        .and_then(|json| serde_json::from_slice::<v2::PaymentRequired>(&json).ok());
    if let Some(parsed) = from_header {
        #[cfg(feature = "telemetry")]
        debug!("Parsed V2 payment required from header");
        return Some(parsed);
    }

    // Header absent or unusable: fall back to the response body.
    let from_body = response
        .bytes()
        .await
        .ok()
        .and_then(|body| serde_json::from_slice::<v2::PaymentRequired>(&body).ok());
    match from_body {
        Some(parsed) => {
            #[cfg(feature = "telemetry")]
            debug!("Parsed V2 payment required from response body");
            Some(parsed)
        }
        None => {
            #[cfg(feature = "telemetry")]
            debug!("Could not parse payment required from response");
            None
        }
    }
}