r402_http/client/
middleware.rs1use std::sync::Arc;
7
8use http::{Extensions, HeaderMap, StatusCode};
9use r402::hooks::{FailureRecovery, HookDecision};
10use r402::proto;
11use r402::proto::Base64Bytes;
12use r402::proto::v2;
13use r402::scheme::{
14 ClientError, FirstMatch, PaymentCandidate, PaymentPolicy, PaymentSelector, SchemeClient,
15};
16use reqwest::{Request, Response};
17use reqwest_middleware as rqm;
18#[cfg(feature = "telemetry")]
19use tracing::{debug, info, instrument, trace};
20
21use super::hooks::{ClientHooks, PaymentCreationContext};
22
23#[allow(missing_debug_implementations)] pub struct X402Client<TSelector> {
30 schemes: ClientSchemes,
31 selector: TSelector,
32 policies: Vec<Arc<dyn PaymentPolicy>>,
33 hooks: Arc<[Arc<dyn ClientHooks>]>,
34}
35
36impl X402Client<FirstMatch> {
37 #[must_use]
42 pub fn new() -> Self {
43 Self::default()
44 }
45}
46
47impl Default for X402Client<FirstMatch> {
48 fn default() -> Self {
49 Self {
50 schemes: ClientSchemes::default(),
51 selector: FirstMatch,
52 policies: Vec::new(),
53 hooks: Arc::from([]),
54 }
55 }
56}
57
58impl<TSelector> X402Client<TSelector> {
59 #[must_use]
72 pub fn register<S>(mut self, scheme: S) -> Self
73 where
74 S: SchemeClient + 'static,
75 {
76 self.schemes.push(scheme);
77 self
78 }
79
80 pub fn with_selector<P: PaymentSelector + 'static>(self, selector: P) -> X402Client<P> {
85 X402Client {
86 selector,
87 schemes: self.schemes,
88 policies: self.policies,
89 hooks: self.hooks,
90 }
91 }
92
93 #[must_use]
99 pub fn with_policy<P: PaymentPolicy + 'static>(mut self, policy: P) -> Self {
100 self.policies.push(Arc::new(policy));
101 self
102 }
103
104 #[must_use]
110 pub fn with_hook(mut self, hook: impl ClientHooks + 'static) -> Self {
111 let mut hooks = (*self.hooks).to_vec();
112 hooks.push(Arc::new(hook));
113 self.hooks = Arc::from(hooks);
114 self
115 }
116}
117
118impl<TSelector> X402Client<TSelector>
119where
120 TSelector: PaymentSelector,
121{
122 #[cfg_attr(
146 feature = "telemetry",
147 instrument(name = "x402.reqwest.make_payment_headers", skip_all, err)
148 )]
149 pub async fn make_payment_headers(&self, res: Response) -> Result<HeaderMap, ClientError> {
150 let payment_required = parse_payment_required(res)
151 .await
152 .ok_or_else(|| ClientError::ParseError("Invalid 402 response".to_string()))?;
153
154 let hook_ctx = PaymentCreationContext {
155 payment_required: payment_required.clone(),
156 };
157
158 for hook in self.hooks.iter() {
160 if let HookDecision::Abort { reason, .. } =
161 hook.before_payment_creation(&hook_ctx).await
162 {
163 return Err(ClientError::ParseError(reason));
164 }
165 }
166
167 let creation_result = self.create_payment_headers_inner(&payment_required).await;
168
169 match creation_result {
170 Ok(headers) => {
171 for hook in self.hooks.iter() {
173 hook.after_payment_creation(&hook_ctx, &headers).await;
174 }
175 Ok(headers)
176 }
177 Err(err) => {
178 let err_msg = err.to_string();
180 for hook in self.hooks.iter() {
181 if let FailureRecovery::Recovered(headers) =
182 hook.on_payment_creation_failure(&hook_ctx, &err_msg).await
183 {
184 return Ok(headers);
185 }
186 }
187 Err(err)
188 }
189 }
190 }
191
192 async fn create_payment_headers_inner(
194 &self,
195 payment_required: &proto::PaymentRequired,
196 ) -> Result<HeaderMap, ClientError> {
197 let candidates = self.schemes.candidates(payment_required);
198
199 let mut filtered: Vec<&PaymentCandidate> = candidates.iter().collect();
201 for policy in &self.policies {
202 filtered = policy.apply(filtered);
203 if filtered.is_empty() {
204 return Err(ClientError::NoMatchingPaymentOption);
205 }
206 }
207
208 let selected = self
210 .selector
211 .select(&filtered)
212 .ok_or(ClientError::NoMatchingPaymentOption)?;
213
214 #[cfg(feature = "telemetry")]
215 debug!(
216 scheme = %selected.scheme,
217 chain_id = %selected.chain_id,
218 "Selected payment scheme"
219 );
220
221 let signed_payload = selected.sign().await?;
222 let headers = {
223 let mut headers = HeaderMap::new();
224 headers.insert(
225 "Payment-Signature",
226 signed_payload
227 .parse()
228 .expect("signed payload is valid header value"),
229 );
230 headers
231 };
232
233 Ok(headers)
234 }
235}
236
237#[derive(Default)]
239#[allow(missing_debug_implementations)] pub struct ClientSchemes(Vec<Arc<dyn SchemeClient>>);
241
242impl ClientSchemes {
243 pub fn push<T: SchemeClient + 'static>(&mut self, client: T) {
245 self.0.push(Arc::new(client));
246 }
247
248 #[must_use]
250 pub fn candidates(&self, payment_required: &proto::PaymentRequired) -> Vec<PaymentCandidate> {
251 let mut candidates = vec![];
252 for client in &self.0 {
253 let accepted = client.accept(payment_required);
254 candidates.extend(accepted);
255 }
256 candidates
257 }
258}
259
260#[cfg_attr(
262 feature = "telemetry",
263 instrument(name = "x402.reqwest.next", skip_all)
264)]
265async fn run_next(
266 next: rqm::Next<'_>,
267 req: Request,
268 extensions: &mut Extensions,
269) -> rqm::Result<Response> {
270 next.run(req, extensions).await
271}
272
273#[async_trait::async_trait]
274impl<TSelector> rqm::Middleware for X402Client<TSelector>
275where
276 TSelector: PaymentSelector + Send + Sync + 'static,
277{
278 #[cfg_attr(
289 feature = "telemetry",
290 instrument(name = "x402.reqwest.handle", skip_all, err)
291 )]
292 async fn handle(
293 &self,
294 req: Request,
295 extensions: &mut Extensions,
296 next: rqm::Next<'_>,
297 ) -> rqm::Result<Response> {
298 let retry_req = req.try_clone();
299 let res = run_next(next.clone(), req, extensions).await?;
300
301 if res.status() != StatusCode::PAYMENT_REQUIRED {
302 #[cfg(feature = "telemetry")]
303 trace!(status = ?res.status(), "No payment required, returning response");
304 return Ok(res);
305 }
306
307 #[cfg(feature = "telemetry")]
308 info!(url = ?res.url(), "Received 402 Payment Required, processing payment");
309
310 let Some(mut retry) = retry_req else {
313 #[cfg(feature = "telemetry")]
314 tracing::warn!("Cannot auto-retry 402: request body not cloneable, returning raw 402");
315 return Ok(res);
316 };
317
318 let headers = self
319 .make_payment_headers(res)
320 .await
321 .map_err(|e| rqm::Error::Middleware(e.into()))?;
322
323 retry.headers_mut().extend(headers);
324
325 #[cfg(feature = "telemetry")]
326 trace!(url = ?retry.url(), "Retrying request with payment headers");
327
328 run_next(next, retry, extensions).await
329 }
330}
331
332#[cfg_attr(
339 feature = "telemetry",
340 instrument(name = "x402.reqwest.parse_payment_required", skip(response))
341)]
342pub async fn parse_payment_required(response: Response) -> Option<proto::PaymentRequired> {
343 let v2_from_header = response
344 .headers()
345 .get("Payment-Required")
346 .and_then(|h| Base64Bytes::from(h.as_bytes()).decode().ok())
347 .and_then(|b| serde_json::from_slice::<v2::PaymentRequired>(&b).ok());
348
349 if let Some(v2_payment_required) = v2_from_header {
350 #[cfg(feature = "telemetry")]
351 debug!("Parsed V2 payment required from header");
352 return Some(v2_payment_required);
353 }
354
355 if let Ok(body_bytes) = response.bytes().await
357 && let Ok(v2_from_body) = serde_json::from_slice::<v2::PaymentRequired>(&body_bytes)
358 {
359 #[cfg(feature = "telemetry")]
360 debug!("Parsed V2 payment required from response body");
361 return Some(v2_from_body);
362 }
363
364 #[cfg(feature = "telemetry")]
365 debug!("Could not parse payment required from response");
366
367 None
368}