use http::{Extensions, HeaderMap, StatusCode};
use reqwest::{Request, Response};
use reqwest_middleware as rqm;
use std::sync::Arc;
use x402_types::proto;
use x402_types::proto::{OriginalJson, v1, v2};
use x402_types::scheme::client::{
FirstMatch, PaymentCandidate, PaymentSelector, X402Error, X402SchemeClient,
};
use x402_types::util::Base64Bytes;
#[cfg(feature = "telemetry")]
use tracing::{debug, info, instrument, trace};
/// Client-side x402 payment handler, usable as `reqwest_middleware` middleware.
///
/// Holds the registered payment scheme clients and the selector used to pick a
/// single payment candidate when a server answers `402 Payment Required`.
pub struct X402Client<TSelector> {
    // Scheme clients asked for payment candidates; populated via `register`.
    schemes: ClientSchemes,
    // Strategy that chooses one candidate out of those the schemes produce.
    selector: TSelector,
}
impl X402Client<FirstMatch> {
    /// Creates a client with no registered schemes that pays with the first
    /// matching candidate ([`FirstMatch`]).
    pub fn new() -> Self {
        Self {
            selector: FirstMatch,
            schemes: ClientSchemes::default(),
        }
    }
}
impl Default for X402Client<FirstMatch> {
    /// Equivalent to [`X402Client::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<TSelector> X402Client<TSelector> {
    /// Registers a payment scheme client and returns the updated client.
    ///
    /// Schemes are appended, so they are consulted in registration order when
    /// payment candidates are collected.
    #[must_use = "`register` consumes the client and returns the updated one"]
    pub fn register<S>(mut self, scheme: S) -> Self
    where
        S: X402SchemeClient + 'static,
    {
        self.schemes.push(scheme);
        self
    }
    /// Replaces the payment selector while keeping every registered scheme.
    ///
    /// The previous selector is dropped; the returned client is an
    /// `X402Client<P>`.
    #[must_use = "`with_selector` consumes the client and returns the updated one"]
    pub fn with_selector<P: PaymentSelector + 'static>(self, selector: P) -> X402Client<P> {
        X402Client {
            selector,
            schemes: self.schemes,
        }
    }
}
impl<TSelector> X402Client<TSelector>
where
    TSelector: PaymentSelector,
{
    /// Builds the payment headers that answer a `402 Payment Required` response.
    ///
    /// Steps:
    /// 1. Parse the response with [`parse_payment_required`].
    /// 2. Ask every registered scheme for candidates.
    /// 3. Let the selector pick one candidate.
    /// 4. Sign that candidate.
    ///
    /// The signed payload goes in `X-Payment` for V1 requirements and in
    /// `Payment-Signature` for V2 requirements.
    ///
    /// # Errors
    ///
    /// - [`X402Error::ParseError`] if the response carries no parseable payment
    ///   requirements, or if the signed payload is not a valid HTTP header value.
    /// - [`X402Error::NoMatchingPaymentOption`] if the selector picks nothing.
    /// - Any error produced while signing the selected candidate.
    #[cfg_attr(
        feature = "telemetry",
        instrument(name = "x402.reqwest.make_payment_headers", skip_all, err)
    )]
    pub async fn make_payment_headers(&self, res: Response) -> Result<HeaderMap, X402Error> {
        // Closure form: only allocate the message when parsing actually failed.
        let payment_required = parse_payment_required(res)
            .await
            .ok_or_else(|| X402Error::ParseError("Invalid 402 response".to_string()))?;
        let candidates = self.schemes.candidates(&payment_required);
        let selected = self
            .selector
            .select(&candidates)
            .ok_or(X402Error::NoMatchingPaymentOption)?;
        #[cfg(feature = "telemetry")]
        debug!(
            scheme = %selected.scheme,
            chain_id = %selected.chain_id,
            "Selected payment scheme"
        );
        let signed_payload = selected.sign().await?;
        // The header name depends on the protocol version the server spoke.
        let header_name = match &payment_required {
            proto::PaymentRequired::V1(_) => "X-Payment",
            proto::PaymentRequired::V2(_) => "Payment-Signature",
        };
        // A payload with bytes that are illegal in a header value (e.g. a
        // newline) must surface as an error instead of panicking in middleware.
        // The value type (`HeaderValue`) is inferred from `headers.insert`.
        let header_value = signed_payload.parse().map_err(|e| {
            X402Error::ParseError(format!(
                "Signed payment payload is not a valid header value: {e}"
            ))
        })?;
        let mut headers = HeaderMap::new();
        headers.insert(header_name, header_value);
        Ok(headers)
    }
}
/// Ordered registry of the payment scheme clients an `X402Client` can use.
///
/// Each client is stored as a shared trait object (`Arc<dyn X402SchemeClient>`).
#[derive(Default)]
pub struct ClientSchemes(Vec<Arc<dyn X402SchemeClient>>);
impl ClientSchemes {
    /// Registers `client` after all previously registered clients.
    pub fn push<T: X402SchemeClient + 'static>(&mut self, client: T) {
        let shared: Arc<dyn X402SchemeClient> = Arc::new(client);
        self.0.push(shared);
    }
    /// Gathers the candidates every registered client accepts for
    /// `payment_required`, concatenated in registration order.
    pub fn candidates(&self, payment_required: &proto::PaymentRequired) -> Vec<PaymentCandidate> {
        self.0
            .iter()
            .flat_map(|scheme| scheme.accept(payment_required))
            .collect()
    }
}
/// Forwards `req` to the next middleware (or the underlying client) in the chain.
///
/// With the `telemetry` feature enabled, each call is wrapped in an
/// `x402.reqwest.next` tracing span, so the original request and the paid retry
/// show up as separate downstream calls.
#[cfg_attr(
    feature = "telemetry",
    instrument(name = "x402.reqwest.next", skip_all)
)]
async fn run_next(
    next: rqm::Next<'_>,
    req: Request,
    extensions: &mut Extensions,
) -> rqm::Result<Response> {
    next.run(req, extensions).await
}
#[async_trait::async_trait]
impl<TSelector> rqm::Middleware for X402Client<TSelector>
where
TSelector: PaymentSelector + Send + Sync + 'static,
{
#[cfg_attr(
feature = "telemetry",
instrument(name = "x402.reqwest.handle", skip_all, err)
)]
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: rqm::Next<'_>,
) -> rqm::Result<Response> {
let retry_req = req.try_clone();
let res = run_next(next.clone(), req, extensions).await?;
if res.status() != StatusCode::PAYMENT_REQUIRED {
#[cfg(feature = "telemetry")]
trace!(status = ?res.status(), "No payment required, returning response");
return Ok(res);
}
#[cfg(feature = "telemetry")]
info!(url = ?res.url(), "Received 402 Payment Required, processing payment");
let headers = self
.make_payment_headers(res)
.await
.map_err(|e| rqm::Error::Middleware(e.into()))?;
let mut retry = retry_req.ok_or(rqm::Error::Middleware(
X402Error::RequestNotCloneable.into(),
))?;
retry.headers_mut().extend(headers);
#[cfg(feature = "telemetry")]
trace!(url = ?retry.url(), "Retrying request with payment headers");
run_next(next, retry, extensions).await
}
}
/// Extracts the payment requirements from a `402 Payment Required` response.
///
/// The V2 form is tried first: base64-encoded JSON in the `Payment-Required`
/// header. If that header is absent or fails to decode or deserialize, the body
/// is read and parsed as V1 JSON. The response is consumed because reading the
/// body needs ownership.
///
/// Returns `None` when neither form parses, including when the body cannot be
/// read.
#[cfg_attr(
    feature = "telemetry",
    instrument(name = "x402.reqwest.parse_payment_required", skip(response))
)]
pub async fn parse_payment_required(response: Response) -> Option<proto::PaymentRequired> {
    let header_value = response.headers().get("Payment-Required");
    let from_header = header_value
        .and_then(|raw| Base64Bytes::from(raw.as_bytes()).decode().ok())
        .and_then(|json| {
            serde_json::from_slice::<v2::PaymentRequired<OriginalJson>>(&json).ok()
        });
    if let Some(parsed) = from_header {
        #[cfg(feature = "telemetry")]
        debug!("Parsed V2 payment required from header");
        return Some(proto::PaymentRequired::V2(parsed));
    }

    let from_body = match response.bytes().await {
        Ok(body) => serde_json::from_slice::<v1::PaymentRequired<OriginalJson>>(&body).ok(),
        Err(_) => None,
    };
    match from_body {
        Some(parsed) => {
            #[cfg(feature = "telemetry")]
            debug!("Parsed V1 payment required from body");
            Some(proto::PaymentRequired::V1(parsed))
        }
        None => {
            #[cfg(feature = "telemetry")]
            debug!("Could not parse payment required from response");
            None
        }
    }
}