r402_http/client/
middleware.rs1use std::sync::Arc;
7
8use http::{Extensions, HeaderMap, StatusCode};
9use r402::hooks::{FailureRecovery, HookDecision};
10use r402::proto;
11use r402::proto::Base64Bytes;
12use r402::proto::v2;
13use r402::scheme::{
14 ClientError, FirstMatch, PaymentCandidate, PaymentPolicy, PaymentSelector, SchemeClient,
15};
16use reqwest::{Request, Response};
17use reqwest_middleware as rqm;
18#[cfg(feature = "telemetry")]
19use tracing::{debug, info, instrument, trace};
20
21use super::hooks::{ClientHooks, PaymentCreationContext};
22
23#[allow(
29 missing_debug_implementations,
30 reason = "ClientSchemes contains dyn trait objects"
31)]
32pub struct X402Client<TSelector> {
33 schemes: ClientSchemes,
34 selector: TSelector,
35 policies: Vec<Arc<dyn PaymentPolicy>>,
36 hooks: Arc<[Arc<dyn ClientHooks>]>,
37}
38
39impl X402Client<FirstMatch> {
40 #[must_use]
45 pub fn new() -> Self {
46 Self::default()
47 }
48}
49
50impl Default for X402Client<FirstMatch> {
51 fn default() -> Self {
52 Self {
53 schemes: ClientSchemes::default(),
54 selector: FirstMatch,
55 policies: Vec::new(),
56 hooks: Arc::from([]),
57 }
58 }
59}
60
61impl<TSelector> X402Client<TSelector> {
62 #[must_use]
75 pub fn register<S>(mut self, scheme: S) -> Self
76 where
77 S: SchemeClient + 'static,
78 {
79 self.schemes.push(scheme);
80 self
81 }
82
83 pub fn with_selector<P: PaymentSelector + 'static>(self, selector: P) -> X402Client<P> {
88 X402Client {
89 selector,
90 schemes: self.schemes,
91 policies: self.policies,
92 hooks: self.hooks,
93 }
94 }
95
96 #[must_use]
102 pub fn with_policy<P: PaymentPolicy + 'static>(mut self, policy: P) -> Self {
103 self.policies.push(Arc::new(policy));
104 self
105 }
106
107 #[must_use]
113 pub fn with_hook(mut self, hook: impl ClientHooks + 'static) -> Self {
114 let mut hooks = (*self.hooks).to_vec();
115 hooks.push(Arc::new(hook));
116 self.hooks = Arc::from(hooks);
117 self
118 }
119}
120
121impl<TSelector> X402Client<TSelector>
122where
123 TSelector: PaymentSelector,
124{
125 #[cfg_attr(
149 feature = "telemetry",
150 instrument(name = "x402.reqwest.make_payment_headers", skip_all, err)
151 )]
152 pub async fn make_payment_headers(&self, res: Response) -> Result<HeaderMap, ClientError> {
153 let payment_required = parse_payment_required(res)
154 .await
155 .ok_or_else(|| ClientError::ParseError("Invalid 402 response".to_owned()))?;
156
157 let hook_ctx = PaymentCreationContext {
158 payment_required: payment_required.clone(),
159 };
160
161 for hook in self.hooks.iter() {
163 if let HookDecision::Abort { reason, .. } =
164 hook.before_payment_creation(&hook_ctx).await
165 {
166 return Err(ClientError::ParseError(reason));
167 }
168 }
169
170 let creation_result = self.create_payment_headers_inner(&payment_required).await;
171
172 match creation_result {
173 Ok(headers) => {
174 for hook in self.hooks.iter() {
176 hook.after_payment_creation(&hook_ctx, &headers).await;
177 }
178 Ok(headers)
179 }
180 Err(err) => {
181 let err_msg = err.to_string();
183 for hook in self.hooks.iter() {
184 if let FailureRecovery::Recovered(headers) =
185 hook.on_payment_creation_failure(&hook_ctx, &err_msg).await
186 {
187 return Ok(headers);
188 }
189 }
190 Err(err)
191 }
192 }
193 }
194
195 async fn create_payment_headers_inner(
197 &self,
198 payment_required: &proto::PaymentRequired,
199 ) -> Result<HeaderMap, ClientError> {
200 let candidates = self.schemes.candidates(payment_required);
201
202 let mut filtered: Vec<&PaymentCandidate> = candidates.iter().collect();
204 for policy in &self.policies {
205 filtered = policy.apply(filtered);
206 if filtered.is_empty() {
207 return Err(ClientError::NoMatchingPaymentOption);
208 }
209 }
210
211 let selected = self
213 .selector
214 .select(&filtered)
215 .ok_or(ClientError::NoMatchingPaymentOption)?;
216
217 #[cfg(feature = "telemetry")]
218 debug!(
219 scheme = %selected.scheme,
220 chain_id = %selected.chain_id,
221 "Selected payment scheme"
222 );
223
224 let signed_payload = selected.sign().await?;
225 let headers = {
226 let mut headers = HeaderMap::new();
227 #[allow(
228 clippy::expect_used,
229 reason = "base64-encoded payload is always valid ASCII header"
230 )]
231 headers.insert(
232 "Payment-Signature",
233 signed_payload
234 .parse()
235 .expect("signed payload is valid header value"),
236 );
237 headers
238 };
239
240 Ok(headers)
241 }
242}
243
244#[derive(Default)]
246#[allow(
247 missing_debug_implementations,
248 reason = "dyn trait objects do not impl Debug"
249)]
250pub(super) struct ClientSchemes(Vec<Arc<dyn SchemeClient>>);
251
252impl ClientSchemes {
253 pub(super) fn push<T: SchemeClient + 'static>(&mut self, client: T) {
255 self.0.push(Arc::new(client));
256 }
257
258 #[must_use]
260 pub(super) fn candidates(
261 &self,
262 payment_required: &proto::PaymentRequired,
263 ) -> Vec<PaymentCandidate> {
264 let mut candidates = vec![];
265 for client in &self.0 {
266 let accepted = client.accept(payment_required);
267 candidates.extend(accepted);
268 }
269 candidates
270 }
271}
272
273#[cfg_attr(
275 feature = "telemetry",
276 instrument(name = "x402.reqwest.next", skip_all)
277)]
278async fn run_next(
279 next: rqm::Next<'_>,
280 req: Request,
281 extensions: &mut Extensions,
282) -> rqm::Result<Response> {
283 next.run(req, extensions).await
284}
285
286#[async_trait::async_trait]
287impl<TSelector> rqm::Middleware for X402Client<TSelector>
288where
289 TSelector: PaymentSelector + Send + Sync + 'static,
290{
291 #[cfg_attr(
302 feature = "telemetry",
303 instrument(name = "x402.reqwest.handle", skip_all, err)
304 )]
305 async fn handle(
306 &self,
307 req: Request,
308 extensions: &mut Extensions,
309 next: rqm::Next<'_>,
310 ) -> rqm::Result<Response> {
311 let retry_req = req.try_clone();
312 let res = run_next(next.clone(), req, extensions).await?;
313
314 if res.status() != StatusCode::PAYMENT_REQUIRED {
315 #[cfg(feature = "telemetry")]
316 trace!(status = ?res.status(), "No payment required, returning response");
317 return Ok(res);
318 }
319
320 #[cfg(feature = "telemetry")]
321 info!(url = ?res.url(), "Received 402 Payment Required, processing payment");
322
323 let Some(mut retry) = retry_req else {
326 #[cfg(feature = "telemetry")]
327 tracing::warn!("Cannot auto-retry 402: request body not cloneable, returning raw 402");
328 return Ok(res);
329 };
330
331 let headers = self
332 .make_payment_headers(res)
333 .await
334 .map_err(|e| rqm::Error::Middleware(e.into()))?;
335
336 retry.headers_mut().extend(headers);
337
338 #[cfg(feature = "telemetry")]
339 trace!(url = ?retry.url(), "Retrying request with payment headers");
340
341 run_next(next, retry, extensions).await
342 }
343}
344
345#[cfg_attr(
352 feature = "telemetry",
353 instrument(name = "x402.reqwest.parse_payment_required", skip(response))
354)]
355pub async fn parse_payment_required(response: Response) -> Option<proto::PaymentRequired> {
356 let v2_from_header = response
357 .headers()
358 .get("Payment-Required")
359 .and_then(|h| Base64Bytes::from(h.as_bytes()).decode().ok())
360 .and_then(|b| serde_json::from_slice::<v2::PaymentRequired>(&b).ok());
361
362 if let Some(v2_payment_required) = v2_from_header {
363 #[cfg(feature = "telemetry")]
364 debug!("Parsed V2 payment required from header");
365 return Some(v2_payment_required);
366 }
367
368 if let Ok(body_bytes) = response.bytes().await
370 && let Ok(v2_from_body) = serde_json::from_slice::<v2::PaymentRequired>(&body_bytes)
371 {
372 #[cfg(feature = "telemetry")]
373 debug!("Parsed V2 payment required from response body");
374 return Some(v2_from_body);
375 }
376
377 #[cfg(feature = "telemetry")]
378 debug!("Could not parse payment required from response");
379
380 None
381}