1use std::collections::HashMap;
11use std::future::Future;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15
16use axum_core::body::Body;
17use http::{Request, Response, StatusCode};
18use r402::config::ResourceConfig;
19use r402::proto::{PaymentPayload, PaymentRequirements, ResourceInfo};
20use r402::server::X402ResourceServer;
21use tower::{Layer, Service};
22
23use crate::constants::{PAYMENT_REQUIRED_HEADER, PAYMENT_SIGNATURE_HEADER};
24use crate::headers::{decode_payment_payload, encode_payment_required, encode_payment_response};
25use crate::types::{
26 CompiledRoute, PaywallConfig, RouteConfig, RouteValidationError, parse_route_pattern,
27};
28
/// Route table accepted by [`PaymentGateLayer::new`] and [`PaymentGateLayer::with_paywall`].
///
/// Each key is a route pattern that `parse_route_pattern` splits into a method and a
/// path pattern (presumably of the form `"METHOD /path"` — confirm against
/// `parse_route_pattern`). Each value is that route's payment configuration: accepted
/// payment options, description, MIME type and optional resource URL.
pub type RoutesConfig = HashMap<String, RouteConfig>;
33
/// Tower [`Layer`] that wraps a service in a [`PaymentGateService`], which gates the
/// configured routes behind x402 payment verification and settlement.
///
/// Cloning is cheap: all state lives behind a single shared [`Arc`].
#[derive(Clone)]
pub struct PaymentGateLayer {
    /// State shared with every service this layer produces.
    shared: Arc<PaymentGateShared>,
}
65
/// Immutable state shared by a [`PaymentGateLayer`] and every [`PaymentGateService`] it creates.
struct PaymentGateShared {
    /// Resource server used to build requirements and to verify and settle payments.
    server: Arc<X402ResourceServer>,
    /// Routes compiled from the caller's [`RoutesConfig`]. Requests are matched against
    /// them in vector order and the first match wins.
    compiled_routes: Vec<CompiledRoute>,
    /// Paywall settings supplied through `PaymentGateLayer::with_paywall`; `None` otherwise.
    // Not read anywhere in the visible code (the `Debug` impl deliberately skips it),
    // so the dead-code lint is silenced. NOTE(review): presumably reserved for serving
    // a paywall page to browsers — confirm, or wire it up and drop the allow.
    #[allow(dead_code)]
    paywall_config: Option<PaywallConfig>,
}
73
74impl std::fmt::Debug for PaymentGateShared {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 f.debug_struct("PaymentGateShared")
77 .field("server", &self.server)
78 .field("routes_count", &self.compiled_routes.len())
79 .finish_non_exhaustive()
80 }
81}
82
83impl std::fmt::Debug for PaymentGateLayer {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 f.debug_struct("PaymentGateLayer")
86 .field("shared", &self.shared)
87 .finish()
88 }
89}
90
91impl PaymentGateLayer {
92 #[must_use]
94 pub fn new(server: Arc<X402ResourceServer>, routes: RoutesConfig) -> Self {
95 let compiled_routes = routes
96 .into_iter()
97 .map(|(pattern, config)| {
98 let (method, path) = parse_route_pattern(&pattern);
99 CompiledRoute {
100 method,
101 path_pattern: path,
102 config,
103 }
104 })
105 .collect();
106
107 Self {
108 shared: Arc::new(PaymentGateShared {
109 server,
110 compiled_routes,
111 paywall_config: None,
112 }),
113 }
114 }
115
116 #[must_use]
118 pub fn with_paywall(
119 server: Arc<X402ResourceServer>,
120 routes: RoutesConfig,
121 paywall_config: PaywallConfig,
122 ) -> Self {
123 let compiled_routes = routes
124 .into_iter()
125 .map(|(pattern, config)| {
126 let (method, path) = parse_route_pattern(&pattern);
127 CompiledRoute {
128 method,
129 path_pattern: path,
130 config,
131 }
132 })
133 .collect();
134
135 Self {
136 shared: Arc::new(PaymentGateShared {
137 server,
138 compiled_routes,
139 paywall_config: Some(paywall_config),
140 }),
141 }
142 }
143}
144
145impl PaymentGateLayer {
146 #[must_use]
157 pub fn validate_routes(&self) -> Vec<RouteValidationError> {
158 let server = &self.shared.server;
159 let mut errors = Vec::new();
160
161 for route in &self.shared.compiled_routes {
162 let pattern = format!("{} {}", route.method, route.path_pattern);
163
164 for option in &route.config.accepts {
165 if !server.has_registered_scheme(&option.network, &option.scheme) {
166 errors.push(RouteValidationError {
167 route_pattern: pattern.clone(),
168 scheme: option.scheme.clone(),
169 network: option.network.clone(),
170 reason: "missing_scheme".to_owned(),
171 message: format!(
172 "Route \"{pattern}\": No scheme for \"{}\" on \"{}\"",
173 option.scheme, option.network,
174 ),
175 });
176 continue;
177 }
178
179 if server
180 .get_supported_kind(2, &option.network, &option.scheme)
181 .is_none()
182 {
183 errors.push(RouteValidationError {
184 route_pattern: pattern.clone(),
185 scheme: option.scheme.clone(),
186 network: option.network.clone(),
187 reason: "missing_facilitator".to_owned(),
188 message: format!(
189 "Route \"{pattern}\": Facilitator doesn't support \"{}\" on \"{}\"",
190 option.scheme, option.network,
191 ),
192 });
193 }
194 }
195 }
196
197 errors
198 }
199}
200
201impl<S> Layer<S> for PaymentGateLayer {
202 type Service = PaymentGateService<S>;
203
204 fn layer(&self, inner: S) -> Self::Service {
205 PaymentGateService {
206 inner,
207 shared: Arc::clone(&self.shared),
208 }
209 }
210}
211
/// Tower [`Service`] produced by [`PaymentGateLayer`].
///
/// Forwards requests to `inner` unchanged when no configured route matches. Otherwise
/// it requires, verifies and settles an x402 payment around the inner call.
#[derive(Clone)]
pub struct PaymentGateService<S> {
    /// The wrapped service that produces the actual response.
    inner: S,
    /// State shared with the originating layer.
    shared: Arc<PaymentGateShared>,
}
220
221impl<S> std::fmt::Debug for PaymentGateService<S> {
222 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223 f.debug_struct("PaymentGateService")
224 .field("shared", &self.shared)
225 .finish_non_exhaustive()
226 }
227}
228
229impl<S> Service<Request<Body>> for PaymentGateService<S>
230where
231 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
232 S::Future: Send + 'static,
233 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
234{
235 type Response = Response<Body>;
236 type Error = Box<dyn std::error::Error + Send + Sync>;
237 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
238
239 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
240 self.inner.poll_ready(cx).map_err(Into::into)
241 }
242
243 fn call(&mut self, req: Request<Body>) -> Self::Future {
244 let shared = Arc::clone(&self.shared);
245 let mut inner = self.inner.clone();
246
247 Box::pin(async move {
248 let method = req.method().as_str().to_uppercase();
249 let path = req.uri().path().to_owned();
250
251 let route = shared
253 .compiled_routes
254 .iter()
255 .find(|r| r.matches(&method, &path));
256
257 let route_config = match route {
258 Some(r) => &r.config,
259 None => {
260 return inner.call(req).await.map_err(Into::into);
262 }
263 };
264
265 let payment_payload = extract_payment_payload(&req);
267
268 let requirements = match build_requirements(&shared.server, route_config, &path) {
270 Ok(reqs) => reqs,
271 Err(e) => {
272 return Ok(error_response(
273 StatusCode::INTERNAL_SERVER_ERROR,
274 &format!("Failed to build payment requirements: {e}"),
275 ));
276 }
277 };
278
279 let resource_info = ResourceInfo {
281 url: route_config
282 .resource
283 .clone()
284 .unwrap_or_else(|| path.clone()),
285 description: route_config.description.clone(),
286 mime_type: route_config.mime_type.clone(),
287 };
288
289 let payload = match payment_payload {
291 Some(p) => p,
292 None => {
293 let payment_required = shared.server.create_payment_required(
294 requirements,
295 Some(resource_info),
296 Some("Payment required".to_owned()),
297 None,
298 );
299
300 return Ok(payment_required_response(&payment_required));
301 }
302 };
303
304 let matching_reqs = match shared
306 .server
307 .find_matching_requirements(&requirements, &payload)
308 {
309 Some(reqs) => reqs.clone(),
310 None => {
311 let payment_required = shared.server.create_payment_required(
312 requirements,
313 Some(resource_info),
314 Some("No matching payment requirements".to_owned()),
315 None,
316 );
317
318 return Ok(payment_required_response(&payment_required));
319 }
320 };
321
322 let verify_result = shared.server.verify_payment(&payload, &matching_reqs).await;
324
325 match verify_result {
326 Ok(ref vr) if vr.is_valid => {
327 let mut response = inner.call(req).await.map_err(Into::into)?;
329
330 settle_and_add_headers(&shared.server, &payload, &matching_reqs, &mut response)
332 .await;
333
334 Ok(response)
335 }
336 Ok(vr) => {
337 let payment_required = shared.server.create_payment_required(
339 requirements,
340 Some(resource_info),
341 vr.invalid_reason.clone(),
342 None,
343 );
344
345 Ok(payment_required_response(&payment_required))
346 }
347 Err(e) => {
348 let payment_required = shared.server.create_payment_required(
349 requirements,
350 Some(resource_info),
351 Some(e.to_string()),
352 None,
353 );
354
355 Ok(payment_required_response(&payment_required))
356 }
357 }
358 })
359 }
360}
361
362fn extract_payment_payload(req: &Request<Body>) -> Option<PaymentPayload> {
364 let header_value = req.headers().get(PAYMENT_SIGNATURE_HEADER).or_else(|| {
365 req.headers()
366 .get(PAYMENT_SIGNATURE_HEADER.to_lowercase().as_str())
367 })?;
368
369 let value_str = header_value.to_str().ok()?;
370 let parsed = decode_payment_payload(value_str).ok()?;
371
372 match parsed {
373 r402::proto::helpers::PaymentPayloadEnum::V2(p) => Some(*p),
374 r402::proto::helpers::PaymentPayloadEnum::V1(_) => None,
375 }
376}
377
378fn build_requirements(
380 server: &X402ResourceServer,
381 route_config: &RouteConfig,
382 _path: &str,
383) -> Result<Vec<PaymentRequirements>, r402::scheme::SchemeError> {
384 let mut all_requirements = Vec::new();
385
386 for option in &route_config.accepts {
387 let config = ResourceConfig {
388 scheme: option.scheme.clone(),
389 pay_to: option.pay_to.clone(),
390 price: option.price.clone(),
391 network: option.network.clone(),
392 max_timeout_seconds: option.max_timeout_seconds,
393 };
394
395 let reqs = server.build_payment_requirements(&config)?;
396 all_requirements.extend(reqs);
397 }
398
399 Ok(all_requirements)
400}
401
402fn payment_required_response(payment_required: &r402::proto::PaymentRequired) -> Response<Body> {
404 let encoded = encode_payment_required(payment_required).unwrap_or_default();
405
406 let body_json = serde_json::to_string(payment_required).unwrap_or_default();
407
408 Response::builder()
409 .status(StatusCode::PAYMENT_REQUIRED)
410 .header(PAYMENT_REQUIRED_HEADER, &encoded)
411 .header(http::header::CONTENT_TYPE, "application/json")
412 .header(
413 http::header::ACCESS_CONTROL_EXPOSE_HEADERS,
414 PAYMENT_REQUIRED_HEADER,
415 )
416 .body(Body::from(body_json))
417 .expect("valid 402 response")
418}
419
420fn error_response(status: StatusCode, message: &str) -> Response<Body> {
422 let body = serde_json::json!({ "error": message });
423
424 Response::builder()
425 .status(status)
426 .header(http::header::CONTENT_TYPE, "application/json")
427 .body(Body::from(body.to_string()))
428 .expect("valid error response")
429}
430
431async fn settle_and_add_headers(
433 server: &X402ResourceServer,
434 payload: &PaymentPayload,
435 requirements: &PaymentRequirements,
436 response: &mut Response<Body>,
437) {
438 match server.settle_payment(payload, requirements).await {
439 Ok(settle_response) if settle_response.success => {
440 if let Ok(encoded) = encode_payment_response(&settle_response) {
441 response.headers_mut().insert(
442 http::header::HeaderName::from_static("payment-response"),
443 http::header::HeaderValue::from_str(&encoded)
444 .unwrap_or_else(|_| http::header::HeaderValue::from_static("")),
445 );
446 response.headers_mut().insert(
447 http::header::HeaderName::from_static("access-control-expose-headers"),
448 http::header::HeaderValue::from_static("PAYMENT-RESPONSE"),
449 );
450 }
451 }
452 Ok(_) | Err(_) => {
453 }
456 }
457}