1use std::future::Future;
11use std::sync::Arc;
12
13use r402::client::X402Client;
14use r402::proto::{PaymentPayload, PaymentPayloadV1, PaymentRequired, PaymentRequiredV1};
15use reqwest::{Request, Response};
16use reqwest_middleware::{Middleware, Next};
17
18use crate::constants::{PAYMENT_REQUIRED_HEADER, PAYMENT_SIGNATURE_HEADER, X_PAYMENT_HEADER};
19use crate::error::HttpError;
20use crate::headers::{decode_payment_required, encode_payment_signature, encode_x_payment};
21
22#[derive(Debug, Clone)]
47pub struct X402HttpClient {
48 client: Arc<X402Client>,
49}
50
51impl X402HttpClient {
52 #[must_use]
54 pub fn new(client: Arc<X402Client>) -> Self {
55 Self { client }
56 }
57
58 async fn extract_payment_required(response: &Response) -> Option<PaymentRequiredVersion> {
62 if let Some(header_value) = response.headers().get(PAYMENT_REQUIRED_HEADER) {
64 if let Ok(s) = header_value.to_str() {
65 if let Ok(parsed) = decode_payment_required(s) {
66 return match parsed {
67 r402::proto::helpers::PaymentRequiredEnum::V2(pr) => {
68 Some(PaymentRequiredVersion::V2(*pr))
69 }
70 r402::proto::helpers::PaymentRequiredEnum::V1(pr) => {
71 Some(PaymentRequiredVersion::V1(*pr))
72 }
73 };
74 }
75 }
76 }
77
78 None
79 }
80
81 fn encode_payment_header(
83 payload: &PaymentPayloadVersion,
84 ) -> Result<(String, String), HttpError> {
85 match payload {
86 PaymentPayloadVersion::V2(p) => {
87 let encoded = encode_payment_signature(p)?;
88 Ok((PAYMENT_SIGNATURE_HEADER.to_owned(), encoded))
89 }
90 PaymentPayloadVersion::V1(p) => {
91 let encoded = encode_x_payment(p)?;
92 Ok((X_PAYMENT_HEADER.to_owned(), encoded))
93 }
94 }
95 }
96}
97
98enum PaymentRequiredVersion {
100 V2(PaymentRequired),
101 V1(PaymentRequiredV1),
102}
103
104enum PaymentPayloadVersion {
106 V2(PaymentPayload),
107 V1(PaymentPayloadV1),
108}
109
110impl Middleware for X402HttpClient {
111 fn handle<'life0, 'life1, 'life2, 'async_trait>(
112 &'life0 self,
113 req: Request,
114 extensions: &'life1 mut http::Extensions,
115 next: Next<'life2>,
116 ) -> core::pin::Pin<
117 Box<dyn Future<Output = Result<Response, reqwest_middleware::Error>> + Send + 'async_trait>,
118 >
119 where
120 'life0: 'async_trait,
121 'life1: 'async_trait,
122 'life2: 'async_trait,
123 Self: 'async_trait,
124 {
125 Box::pin(async move {
126 let method = req.method().clone();
128 let url = req.url().clone();
129 let original_headers = req.headers().clone();
130
131 let response = next.clone().run(req, extensions).await?;
133
134 if response.status().as_u16() != 402 {
136 return Ok(response);
137 }
138
139 let payment_required = match Self::extract_payment_required(&response).await {
141 Some(pr) => pr,
142 None => return Ok(response),
143 };
144
145 let payment_payload = match &payment_required {
147 PaymentRequiredVersion::V2(pr) => {
148 match self.client.create_payment_payload(pr).await {
149 Ok(p) => PaymentPayloadVersion::V2(p),
150 Err(_) => return Ok(response),
151 }
152 }
153 PaymentRequiredVersion::V1(pr) => {
154 match self.client.create_payment_payload_v1(pr).await {
155 Ok(p) => PaymentPayloadVersion::V1(p),
156 Err(_) => return Ok(response),
157 }
158 }
159 };
160
161 let (header_name, header_value) = match Self::encode_payment_header(&payment_payload) {
163 Ok(h) => h,
164 Err(_) => return Ok(response),
165 };
166
167 let mut retry_req = Request::new(method, url);
169 *retry_req.headers_mut() = original_headers;
170 retry_req.headers_mut().insert(
171 reqwest::header::HeaderName::from_bytes(header_name.as_bytes())
172 .expect("valid header name"),
173 reqwest::header::HeaderValue::from_str(&header_value).expect("valid header value"),
174 );
175
176 next.run(retry_req, extensions).await
178 })
179 }
180}