1use std::future::Future;
48use std::pin::Pin;
49use std::sync::Arc;
50use std::task::{Context, Poll};
51
52use axum_core::body::Body;
53use http::{Request, Response, StatusCode};
54use r402::config::ResourceConfig;
55use r402::proto::{PaymentPayload, PaymentRequirements, ResourceInfo};
56use r402::server::X402ResourceServer;
57use tower::{Layer, Service};
58
59use crate::constants::{PAYMENT_REQUIRED_HEADER, PAYMENT_SIGNATURE_HEADER};
60use crate::headers::{decode_payment_payload, encode_payment_required, encode_payment_response};
61use crate::types::RouteConfig;
62
63#[derive(Clone, Debug)]
68pub struct PaymentGate {
69 server: Arc<X402ResourceServer>,
70}
71
72impl PaymentGate {
73 #[must_use]
75 pub fn new(server: Arc<X402ResourceServer>) -> Self {
76 Self { server }
77 }
78
79 #[must_use]
84 pub fn route(&self, config: RouteConfig) -> PaymentRouteLayer {
85 PaymentRouteLayer {
86 shared: Arc::new(PaymentRouteShared {
87 server: Arc::clone(&self.server),
88 config,
89 }),
90 }
91 }
92}
93
94struct PaymentRouteShared {
96 server: Arc<X402ResourceServer>,
97 config: RouteConfig,
98}
99
100impl std::fmt::Debug for PaymentRouteShared {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 f.debug_struct("PaymentRouteShared")
103 .field("server", &self.server)
104 .field("accepts_count", &self.config.accepts.len())
105 .finish_non_exhaustive()
106 }
107}
108
109#[derive(Clone, Debug)]
114pub struct PaymentRouteLayer {
115 shared: Arc<PaymentRouteShared>,
116}
117
118impl PaymentRouteLayer {
119 #[must_use]
121 pub fn with_description(self, desc: impl Into<String>) -> Self {
122 let shared = (*self.shared).clone_with_description(Some(desc.into()));
123 Self {
124 shared: Arc::new(shared),
125 }
126 }
127
128 #[must_use]
130 pub fn with_mime_type(self, mime: impl Into<String>) -> Self {
131 let shared = (*self.shared).clone_with_mime_type(Some(mime.into()));
132 Self {
133 shared: Arc::new(shared),
134 }
135 }
136
137 #[must_use]
139 pub fn with_resource(self, url: impl Into<String>) -> Self {
140 let shared = (*self.shared).clone_with_resource(Some(url.into()));
141 Self {
142 shared: Arc::new(shared),
143 }
144 }
145}
146
147impl PaymentRouteShared {
148 fn clone_with_description(&self, desc: Option<String>) -> Self {
149 let mut config = self.config.clone();
150 config.description = desc;
151 Self {
152 server: Arc::clone(&self.server),
153 config,
154 }
155 }
156
157 fn clone_with_mime_type(&self, mime: Option<String>) -> Self {
158 let mut config = self.config.clone();
159 config.mime_type = mime;
160 Self {
161 server: Arc::clone(&self.server),
162 config,
163 }
164 }
165
166 fn clone_with_resource(&self, url: Option<String>) -> Self {
167 let mut config = self.config.clone();
168 config.resource = url;
169 Self {
170 server: Arc::clone(&self.server),
171 config,
172 }
173 }
174}
175
176impl<S> Layer<S> for PaymentRouteLayer {
177 type Service = PaymentRouteService<S>;
178
179 fn layer(&self, inner: S) -> Self::Service {
180 PaymentRouteService {
181 inner,
182 shared: Arc::clone(&self.shared),
183 }
184 }
185}
186
187#[derive(Clone)]
191pub struct PaymentRouteService<S> {
192 inner: S,
193 shared: Arc<PaymentRouteShared>,
194}
195
196impl<S> std::fmt::Debug for PaymentRouteService<S> {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 f.debug_struct("PaymentRouteService")
199 .field("shared", &self.shared)
200 .finish_non_exhaustive()
201 }
202}
203
204impl<S> Service<Request<Body>> for PaymentRouteService<S>
205where
206 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
207 S::Future: Send + 'static,
208 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
209{
210 type Response = Response<Body>;
211 type Error = Box<dyn std::error::Error + Send + Sync>;
212 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
213
214 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215 self.inner.poll_ready(cx).map_err(Into::into)
216 }
217
218 fn call(&mut self, req: Request<Body>) -> Self::Future {
219 let shared = Arc::clone(&self.shared);
220 let mut inner = self.inner.clone();
221
222 Box::pin(async move {
223 let path = req.uri().path().to_owned();
224
225 let payment_payload = extract_payment_payload(&req);
227
228 let requirements = match build_requirements(&shared.server, &shared.config) {
230 Ok(reqs) => reqs,
231 Err(e) => {
232 return Ok(error_response(
233 StatusCode::INTERNAL_SERVER_ERROR,
234 &format!("Failed to build payment requirements: {e}"),
235 ));
236 }
237 };
238
239 let resource_info = ResourceInfo {
241 url: shared
242 .config
243 .resource
244 .clone()
245 .unwrap_or_else(|| path.clone()),
246 description: shared.config.description.clone(),
247 mime_type: shared.config.mime_type.clone(),
248 };
249
250 let payload = match payment_payload {
252 Some(p) => p,
253 None => {
254 let payment_required = shared.server.create_payment_required(
255 requirements,
256 Some(resource_info),
257 Some("Payment required".to_owned()),
258 None,
259 );
260 return Ok(payment_required_response(&payment_required));
261 }
262 };
263
264 let matching_reqs = match shared
266 .server
267 .find_matching_requirements(&requirements, &payload)
268 {
269 Some(reqs) => reqs.clone(),
270 None => {
271 let payment_required = shared.server.create_payment_required(
272 requirements,
273 Some(resource_info),
274 Some("No matching payment requirements".to_owned()),
275 None,
276 );
277 return Ok(payment_required_response(&payment_required));
278 }
279 };
280
281 let verify_result = shared.server.verify_payment(&payload, &matching_reqs).await;
283
284 match verify_result {
285 Ok(ref vr) if vr.is_valid => {
286 let mut response = inner.call(req).await.map_err(Into::into)?;
287 settle_and_add_headers(&shared.server, &payload, &matching_reqs, &mut response)
288 .await;
289 Ok(response)
290 }
291 Ok(vr) => {
292 let payment_required = shared.server.create_payment_required(
293 requirements,
294 Some(resource_info),
295 vr.invalid_reason.clone(),
296 None,
297 );
298 Ok(payment_required_response(&payment_required))
299 }
300 Err(e) => {
301 let payment_required = shared.server.create_payment_required(
302 requirements,
303 Some(resource_info),
304 Some(e.to_string()),
305 None,
306 );
307 Ok(payment_required_response(&payment_required))
308 }
309 }
310 })
311 }
312}
313
314fn extract_payment_payload(req: &Request<Body>) -> Option<PaymentPayload> {
316 let header_value = req.headers().get(PAYMENT_SIGNATURE_HEADER).or_else(|| {
317 req.headers()
318 .get(PAYMENT_SIGNATURE_HEADER.to_lowercase().as_str())
319 })?;
320 let value_str = header_value.to_str().ok()?;
321 let parsed = decode_payment_payload(value_str).ok()?;
322 match parsed {
323 r402::proto::helpers::PaymentPayloadEnum::V2(p) => Some(*p),
324 r402::proto::helpers::PaymentPayloadEnum::V1(_) => None,
325 }
326}
327
328fn build_requirements(
330 server: &X402ResourceServer,
331 route_config: &RouteConfig,
332) -> Result<Vec<PaymentRequirements>, r402::scheme::SchemeError> {
333 let mut all = Vec::new();
334 for option in &route_config.accepts {
335 let config = ResourceConfig {
336 scheme: option.scheme.clone(),
337 pay_to: option.pay_to.clone(),
338 price: option.price.clone(),
339 network: option.network.clone(),
340 max_timeout_seconds: option.max_timeout_seconds,
341 };
342 let reqs = server.build_payment_requirements(&config)?;
343 all.extend(reqs);
344 }
345 Ok(all)
346}
347
348fn payment_required_response(payment_required: &r402::proto::PaymentRequired) -> Response<Body> {
350 let encoded = encode_payment_required(payment_required).unwrap_or_default();
351 let body_json = serde_json::to_string(payment_required).unwrap_or_default();
352
353 Response::builder()
354 .status(StatusCode::PAYMENT_REQUIRED)
355 .header(PAYMENT_REQUIRED_HEADER, &encoded)
356 .header(http::header::CONTENT_TYPE, "application/json")
357 .header(
358 http::header::ACCESS_CONTROL_EXPOSE_HEADERS,
359 PAYMENT_REQUIRED_HEADER,
360 )
361 .body(Body::from(body_json))
362 .expect("valid 402 response")
363}
364
365fn error_response(status: StatusCode, message: &str) -> Response<Body> {
367 let body = serde_json::json!({ "error": message });
368 Response::builder()
369 .status(status)
370 .header(http::header::CONTENT_TYPE, "application/json")
371 .body(Body::from(body.to_string()))
372 .expect("valid error response")
373}
374
375async fn settle_and_add_headers(
377 server: &X402ResourceServer,
378 payload: &PaymentPayload,
379 requirements: &PaymentRequirements,
380 response: &mut Response<Body>,
381) {
382 match server.settle_payment(payload, requirements).await {
383 Ok(settle_response) if settle_response.success => {
384 if let Ok(encoded) = encode_payment_response(&settle_response) {
385 response.headers_mut().insert(
386 http::header::HeaderName::from_static("payment-response"),
387 http::header::HeaderValue::from_str(&encoded)
388 .unwrap_or_else(|_| http::header::HeaderValue::from_static("")),
389 );
390 response.headers_mut().insert(
391 http::header::HeaderName::from_static("access-control-expose-headers"),
392 http::header::HeaderValue::from_static("PAYMENT-RESPONSE"),
393 );
394 }
395 }
396 Ok(_) | Err(_) => {}
397 }
398}