rust_x402/
middleware.rs

1//! Middleware implementations for web frameworks
2
3use crate::types::{Network, *};
4use crate::{Result, X402Error};
5use axum::{
6    extract::{Request, State},
7    http::{HeaderValue, StatusCode},
8    middleware::Next,
9    response::{IntoResponse, Response},
10    Json,
11};
12use rust_decimal::Decimal;
13use std::sync::Arc;
14use tower::ServiceBuilder;
15use tower_http::trace::TraceLayer;
16
17/// Configuration for payment middleware
18#[derive(Debug, Clone)]
19pub struct PaymentMiddlewareConfig {
20    /// Payment amount in decimal units (e.g., 0.0001 for 1/10th of a cent)
21    pub amount: Decimal,
22    /// Recipient wallet address
23    pub pay_to: String,
24    /// Payment description
25    pub description: Option<String>,
26    /// MIME type of the expected response
27    pub mime_type: Option<String>,
28    /// Maximum timeout in seconds
29    pub max_timeout_seconds: u32,
30    /// JSON schema for response format
31    pub output_schema: Option<serde_json::Value>,
32    /// Facilitator configuration
33    pub facilitator_config: FacilitatorConfig,
34    /// Whether this is a testnet
35    pub testnet: bool,
36    /// Custom paywall HTML for web browsers
37    pub custom_paywall_html: Option<String>,
38    /// Resource URL (if different from request URL)
39    pub resource: Option<String>,
40    /// Resource root URL for constructing full resource URLs
41    pub resource_root_url: Option<String>,
42}
43
44impl PaymentMiddlewareConfig {
45    /// Create a new payment middleware config
46    pub fn new(amount: Decimal, pay_to: impl Into<String>) -> Self {
47        Self {
48            amount,
49            pay_to: pay_to.into(),
50            description: None,
51            mime_type: None,
52            max_timeout_seconds: 60,
53            output_schema: None,
54            facilitator_config: FacilitatorConfig::default(),
55            testnet: true,
56            custom_paywall_html: None,
57            resource: None,
58            resource_root_url: None,
59        }
60    }
61
62    /// Set the payment description
63    pub fn with_description(mut self, description: impl Into<String>) -> Self {
64        self.description = Some(description.into());
65        self
66    }
67
68    /// Set the MIME type
69    pub fn with_mime_type(mut self, mime_type: impl Into<String>) -> Self {
70        self.mime_type = Some(mime_type.into());
71        self
72    }
73
74    /// Set the maximum timeout
75    pub fn with_max_timeout_seconds(mut self, max_timeout_seconds: u32) -> Self {
76        self.max_timeout_seconds = max_timeout_seconds;
77        self
78    }
79
80    /// Set the output schema
81    pub fn with_output_schema(mut self, output_schema: serde_json::Value) -> Self {
82        self.output_schema = Some(output_schema);
83        self
84    }
85
86    /// Set the facilitator configuration
87    pub fn with_facilitator_config(mut self, facilitator_config: FacilitatorConfig) -> Self {
88        self.facilitator_config = facilitator_config;
89        self
90    }
91
92    /// Set whether this is a testnet
93    pub fn with_testnet(mut self, testnet: bool) -> Self {
94        self.testnet = testnet;
95        self
96    }
97
98    /// Set custom paywall HTML
99    pub fn with_custom_paywall_html(mut self, html: impl Into<String>) -> Self {
100        self.custom_paywall_html = Some(html.into());
101        self
102    }
103
104    /// Set the resource URL
105    pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
106        self.resource = Some(resource.into());
107        self
108    }
109
110    /// Set the resource root URL
111    pub fn with_resource_root_url(mut self, url: impl Into<String>) -> Self {
112        self.resource_root_url = Some(url.into());
113        self
114    }
115
116    /// Create payment requirements from this config
117    pub fn create_payment_requirements(&self, request_uri: &str) -> Result<PaymentRequirements> {
118        let network = if self.testnet {
119            networks::BASE_SEPOLIA
120        } else {
121            networks::BASE_MAINNET
122        };
123
124        let usdc_address =
125            networks::get_usdc_address(network).ok_or_else(|| X402Error::NetworkNotSupported {
126                network: network.to_string(),
127            })?;
128
129        let resource = if let Some(ref resource_url) = self.resource {
130            resource_url.clone()
131        } else if let Some(ref root_url) = self.resource_root_url {
132            format!("{}{}", root_url, request_uri)
133        } else {
134            request_uri.to_string()
135        };
136
137        let max_amount_required = (self.amount * Decimal::from(1_000_000u64))
138            .normalize()
139            .to_string();
140
141        let mut requirements = PaymentRequirements::new(
142            schemes::EXACT,
143            network,
144            max_amount_required,
145            usdc_address,
146            &self.pay_to,
147            resource,
148            self.description.as_deref().unwrap_or("Payment required"),
149        );
150
151        requirements.mime_type = self.mime_type.clone();
152        requirements.output_schema = self.output_schema.clone();
153        requirements.max_timeout_seconds = self.max_timeout_seconds;
154
155        let network = if self.testnet {
156            Network::Testnet
157        } else {
158            Network::Mainnet
159        };
160        requirements.set_usdc_info(network)?;
161
162        Ok(requirements)
163    }
164}
165
/// Axum middleware for x402 payments
///
/// Holds the shared payment configuration plus optional pre-built collaborators.
#[derive(Debug, Clone)]
pub struct PaymentMiddleware {
    /// Shared payment configuration; the `with_*` builders update it via `Arc::make_mut`.
    pub config: Arc<PaymentMiddlewareConfig>,
    /// Pre-built facilitator client. When `None`, a client is created from
    /// `config.facilitator_config` each time one is needed.
    pub facilitator: Option<crate::facilitator::FacilitatorClient>,
    /// Paywall template settings for browser requests; when `None`, a default template
    /// config (app name "x402 Service") is used.
    pub template_config: Option<crate::template::PaywallConfig>,
}
173
/// Payment processing result
///
/// Every variant carries the HTTP response that should be sent back to the client.
#[derive(Debug)]
pub enum PaymentResult {
    /// Payment verified and settled successfully
    Success {
        /// Handler response, with an `X-PAYMENT-RESPONSE` header when it could be encoded.
        response: axum::response::Response,
        /// Settlement result returned by the facilitator.
        settlement: crate::types::SettleResponse,
    },
    /// Payment required (402 response)
    PaymentRequired { response: axum::response::Response },
    /// Payment verification failed
    VerificationFailed { response: axum::response::Response },
    /// Payment settlement failed
    SettlementFailed { response: axum::response::Response },
}
189
190impl PaymentMiddleware {
191    /// Create a new payment middleware
192    pub fn new(amount: Decimal, pay_to: impl Into<String>) -> Self {
193        Self {
194            config: Arc::new(PaymentMiddlewareConfig::new(amount, pay_to)),
195            facilitator: None,
196            template_config: None,
197        }
198    }
199
200    /// Set the payment description
201    pub fn with_description(mut self, description: impl Into<String>) -> Self {
202        Arc::make_mut(&mut self.config).description = Some(description.into());
203        self
204    }
205
206    /// Set the MIME type
207    pub fn with_mime_type(mut self, mime_type: impl Into<String>) -> Self {
208        Arc::make_mut(&mut self.config).mime_type = Some(mime_type.into());
209        self
210    }
211
212    /// Set the maximum timeout
213    pub fn with_max_timeout_seconds(mut self, max_timeout_seconds: u32) -> Self {
214        Arc::make_mut(&mut self.config).max_timeout_seconds = max_timeout_seconds;
215        self
216    }
217
218    /// Set the output schema
219    pub fn with_output_schema(mut self, output_schema: serde_json::Value) -> Self {
220        Arc::make_mut(&mut self.config).output_schema = Some(output_schema);
221        self
222    }
223
224    /// Set the facilitator configuration
225    pub fn with_facilitator_config(mut self, facilitator_config: FacilitatorConfig) -> Self {
226        Arc::make_mut(&mut self.config).facilitator_config = facilitator_config;
227        self
228    }
229
230    /// Set whether this is a testnet
231    pub fn with_testnet(mut self, testnet: bool) -> Self {
232        Arc::make_mut(&mut self.config).testnet = testnet;
233        self
234    }
235
236    /// Set custom paywall HTML
237    pub fn with_custom_paywall_html(mut self, html: impl Into<String>) -> Self {
238        Arc::make_mut(&mut self.config).custom_paywall_html = Some(html.into());
239        self
240    }
241
242    /// Set the resource URL
243    pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
244        Arc::make_mut(&mut self.config).resource = Some(resource.into());
245        self
246    }
247
248    /// Set the resource root URL
249    pub fn with_resource_root_url(mut self, url: impl Into<String>) -> Self {
250        Arc::make_mut(&mut self.config).resource_root_url = Some(url.into());
251        self
252    }
253
254    /// Get the middleware configuration
255    pub fn config(&self) -> &PaymentMiddlewareConfig {
256        &self.config
257    }
258
259    /// Set the facilitator client
260    pub fn with_facilitator(mut self, facilitator: crate::facilitator::FacilitatorClient) -> Self {
261        self.facilitator = Some(facilitator);
262        self
263    }
264
265    /// Set the template configuration
266    pub fn with_template_config(mut self, template_config: crate::template::PaywallConfig) -> Self {
267        self.template_config = Some(template_config);
268        self
269    }
270
271    /// Verify a payment payload
272    pub async fn verify(&self, payment_payload: &PaymentPayload) -> bool {
273        // Create facilitator if not already configured
274        let facilitator = if let Some(facilitator) = &self.facilitator {
275            facilitator.clone()
276        } else {
277            match crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())
278            {
279                Ok(facilitator) => facilitator,
280                Err(_) => return false,
281            }
282        };
283
284        if let Ok(requirements) = self.config.create_payment_requirements("/") {
285            if let Ok(response) = facilitator.verify(payment_payload, &requirements).await {
286                return response.is_valid;
287            }
288        }
289        false
290    }
291
292    /// Settle a payment
293    pub async fn settle(&self, payment_payload: &PaymentPayload) -> crate::Result<SettleResponse> {
294        // Create facilitator if not already configured
295        let facilitator = if let Some(facilitator) = &self.facilitator {
296            facilitator.clone()
297        } else {
298            crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())?
299        };
300
301        let requirements = self.config.create_payment_requirements("/")?;
302        facilitator.settle(payment_payload, &requirements).await
303    }
304
305    /// Verify payment with specific requirements
306    pub async fn verify_with_requirements(
307        &self,
308        payment_payload: &PaymentPayload,
309        requirements: &PaymentRequirements,
310    ) -> crate::Result<bool> {
311        let facilitator = if let Some(facilitator) = &self.facilitator {
312            facilitator.clone()
313        } else {
314            crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())?
315        };
316
317        let response = facilitator.verify(payment_payload, requirements).await?;
318        Ok(response.is_valid)
319    }
320
321    /// Settle payment with specific requirements
322    pub async fn settle_with_requirements(
323        &self,
324        payment_payload: &PaymentPayload,
325        requirements: &PaymentRequirements,
326    ) -> crate::Result<SettleResponse> {
327        let facilitator = if let Some(facilitator) = &self.facilitator {
328            facilitator.clone()
329        } else {
330            crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())?
331        };
332
333        facilitator.settle(payment_payload, requirements).await
334    }
335
336    /// Process payment with unified flow
337    pub async fn process_payment(
338        &self,
339        request: Request,
340        next: Next,
341    ) -> crate::Result<PaymentResult> {
342        let headers = request.headers();
343        let uri = request.uri().to_string();
344
345        // Check if this is a web browser request
346        let user_agent = headers
347            .get("User-Agent")
348            .and_then(|v| v.to_str().ok())
349            .unwrap_or("");
350        let accept = headers
351            .get("Accept")
352            .and_then(|v| v.to_str().ok())
353            .unwrap_or("");
354
355        let is_web_browser = accept.contains("text/html") && user_agent.contains("Mozilla");
356
357        // Create payment requirements
358        let payment_requirements = self.config.create_payment_requirements(&uri)?;
359
360        // Check for payment header
361        let payment_header = headers.get("X-PAYMENT").and_then(|v| v.to_str().ok());
362
363        match payment_header {
364            Some(payment_b64) => {
365                // Decode payment payload
366                let payment_payload = PaymentPayload::from_base64(payment_b64).map_err(|e| {
367                    X402Error::invalid_payment_payload(format!("Failed to decode payment: {}", e))
368                })?;
369
370                // Get facilitator client
371                let facilitator = if let Some(facilitator) = &self.facilitator {
372                    facilitator.clone()
373                } else {
374                    crate::facilitator::FacilitatorClient::new(
375                        self.config.facilitator_config.clone(),
376                    )?
377                };
378
379                // Verify payment
380                let verify_response = facilitator
381                    .verify(&payment_payload, &payment_requirements)
382                    .await
383                    .map_err(|e| {
384                        X402Error::facilitator_error(format!("Payment verification failed: {}", e))
385                    })?;
386
387                if !verify_response.is_valid {
388                    let error_response = self.create_payment_required_response(
389                        "Payment verification failed",
390                        &payment_requirements,
391                        is_web_browser,
392                    )?;
393                    return Ok(PaymentResult::VerificationFailed {
394                        response: error_response,
395                    });
396                }
397
398                // Execute the handler
399                let mut response = next.run(request).await;
400
401                // Settle the payment
402                let settle_response = facilitator
403                    .settle(&payment_payload, &payment_requirements)
404                    .await
405                    .map_err(|e| {
406                        X402Error::facilitator_error(format!("Payment settlement failed: {}", e))
407                    })?;
408
409                // Add settlement header
410                let settlement_header = settle_response.to_base64().map_err(|e| {
411                    X402Error::config(format!("Failed to encode settlement response: {}", e))
412                })?;
413
414                if let Ok(header_value) = HeaderValue::from_str(&settlement_header) {
415                    response
416                        .headers_mut()
417                        .insert("X-PAYMENT-RESPONSE", header_value);
418                }
419
420                Ok(PaymentResult::Success {
421                    response,
422                    settlement: settle_response,
423                })
424            }
425            None => {
426                // No payment provided, return 402 with requirements
427                let response = self.create_payment_required_response(
428                    "X-PAYMENT header is required",
429                    &payment_requirements,
430                    is_web_browser,
431                )?;
432                Ok(PaymentResult::PaymentRequired { response })
433            }
434        }
435    }
436
437    /// Create payment required response
438    fn create_payment_required_response(
439        &self,
440        error: &str,
441        payment_requirements: &PaymentRequirements,
442        is_web_browser: bool,
443    ) -> crate::Result<axum::response::Response> {
444        if is_web_browser {
445            let html = if let Some(custom_html) = &self.config.custom_paywall_html {
446                custom_html.clone()
447            } else {
448                // Use the template system
449                let paywall_config = self.template_config.clone().unwrap_or_else(|| {
450                    crate::template::PaywallConfig::new()
451                        .with_app_name("x402 Service")
452                        .with_app_logo("💰")
453                });
454
455                crate::template::generate_paywall_html(
456                    error,
457                    std::slice::from_ref(payment_requirements),
458                    Some(&paywall_config),
459                )
460            };
461
462            let response = Response::builder()
463                .status(StatusCode::PAYMENT_REQUIRED)
464                .header("Content-Type", "text/html")
465                .body(html.into())
466                .map_err(|e| X402Error::config(format!("Failed to create HTML response: {}", e)))?;
467
468            Ok(response)
469        } else {
470            let payment_response =
471                PaymentRequirementsResponse::new(error, vec![payment_requirements.clone()]);
472
473            Ok(Json(payment_response).into_response())
474        }
475    }
476}
477
478/// Axum middleware function for handling x402 payments
479pub async fn payment_middleware(
480    State(middleware): State<PaymentMiddleware>,
481    request: Request,
482    next: Next,
483) -> crate::Result<impl IntoResponse> {
484    match middleware.process_payment(request, next).await? {
485        PaymentResult::Success { response, .. } => Ok(response),
486        PaymentResult::PaymentRequired { response } => Ok(response),
487        PaymentResult::VerificationFailed { response } => Ok(response),
488        PaymentResult::SettlementFailed { response } => Ok(response),
489    }
490}
491
492/// Create a service builder with x402 payment middleware
493pub fn create_payment_service(
494    middleware: PaymentMiddleware,
495) -> impl tower::Layer<tower::ServiceBuilder<tower::layer::util::Identity>> + Clone {
496    ServiceBuilder::new()
497        .layer(TraceLayer::new_for_http())
498        .layer(tower::layer::util::Stack::new(
499            tower::layer::util::Identity::new(),
500            PaymentServiceLayer::new(middleware),
501        ))
502}
503
/// Tower service layer for x402 payment middleware
///
/// Wraps inner services in a [`PaymentService`] that carries a clone of `middleware`.
#[derive(Clone)]
pub struct PaymentServiceLayer {
    /// Middleware cloned into every service produced by this layer.
    middleware: PaymentMiddleware,
}
509
510impl PaymentServiceLayer {
511    pub fn new(middleware: PaymentMiddleware) -> Self {
512        Self { middleware }
513    }
514}
515
516impl<S> tower::Layer<S> for PaymentServiceLayer {
517    type Service = PaymentService<S>;
518
519    fn layer(&self, inner: S) -> Self::Service {
520        PaymentService {
521            inner,
522            middleware: self.middleware.clone(),
523        }
524    }
525}
526
/// Tower service for x402 payment middleware
///
/// Produced by [`PaymentServiceLayer`]; checks the `X-PAYMENT` header before
/// forwarding to `inner`.
#[derive(Clone)]
pub struct PaymentService<S> {
    /// The wrapped service that handles paid requests.
    inner: S,
    /// Payment configuration and optional facilitator client.
    middleware: PaymentMiddleware,
}
533
534impl<S, ReqBody, ResBody> tower::Service<http::Request<ReqBody>> for PaymentService<S>
535where
536    S: tower::Service<
537            http::Request<ReqBody>,
538            Response = http::Response<ResBody>,
539            Error = Box<dyn std::error::Error + Send + Sync>,
540        > + Send
541        + 'static,
542    S::Future: Send + 'static,
543    ReqBody: Send + 'static,
544    ResBody: Send + 'static,
545{
546    type Response = S::Response;
547    type Error = S::Error;
548    type Future = std::pin::Pin<
549        Box<
550            dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>>
551                + Send,
552        >,
553    >;
554
555    fn poll_ready(
556        &mut self,
557        cx: &mut std::task::Context<'_>,
558    ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
559        self.inner.poll_ready(cx)
560    }
561
562    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
563        let middleware = self.middleware.clone();
564
565        // Extract payment header before moving the request
566        let payment_header = req
567            .headers()
568            .get("X-PAYMENT")
569            .and_then(|h| h.to_str().ok())
570            .map(|s| s.to_string());
571        let uri_path = req.uri().path().to_string();
572
573        let future = self.inner.call(req);
574
575        Box::pin(async move {
576            match payment_header {
577                Some(payment_b64) => {
578                    // Parse payment payload
579                    match crate::types::PaymentPayload::from_base64(&payment_b64) {
580                        Ok(payment_payload) => {
581                            // Create payment requirements
582                            let requirements =
583                                match middleware.config.create_payment_requirements(&uri_path) {
584                                    Ok(req) => req,
585                                    Err(e) => {
586                                        // Return 500 error if we can't create requirements
587                                        return Err(
588                                            Box::new(e) as Box<dyn std::error::Error + Send + Sync>
589                                        );
590                                    }
591                                };
592
593                            // Verify payment
594                            match middleware
595                                .verify_with_requirements(&payment_payload, &requirements)
596                                .await
597                            {
598                                Ok(true) => {
599                                    // Payment is valid, proceed with request
600                                    let response = future.await?;
601
602                                    // Settle payment after successful response
603                                    if let Ok(settlement) = middleware
604                                        .settle_with_requirements(&payment_payload, &requirements)
605                                        .await
606                                    {
607                                        // Note: In a real implementation, we would need to modify the response
608                                        // to add the X-PAYMENT-RESPONSE header, but this requires
609                                        // more complex response handling in Tower
610                                        let _ = settlement; // Acknowledge settlement
611                                    }
612
613                                    Ok(response)
614                                }
615                                Ok(false) => {
616                                    // Payment verification failed
617                                    Err(Box::new(crate::X402Error::payment_verification_failed(
618                                        "Payment verification failed",
619                                    ))
620                                        as Box<dyn std::error::Error + Send + Sync>)
621                                }
622                                Err(e) => {
623                                    // Error during verification
624                                    Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
625                                }
626                            }
627                        }
628                        Err(e) => {
629                            // Invalid payment payload
630                            Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
631                        }
632                    }
633                }
634                None => {
635                    // No payment header provided
636                    Err(Box::new(crate::X402Error::payment_verification_failed(
637                        "X-PAYMENT header is required",
638                    ))
639                        as Box<dyn std::error::Error + Send + Sync>)
640                }
641            }
642        })
643    }
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Recipient address shared by all tests.
    const PAY_TO: &str = "0x209693Bc6afc0C5328bA36FaF03C514EF312287C";

    /// Parse a decimal literal, panicking on malformed test input.
    fn dec(value: &str) -> Decimal {
        Decimal::from_str(value).unwrap()
    }

    #[test]
    fn test_payment_middleware_config() {
        let config = PaymentMiddlewareConfig::new(dec("0.0001"), PAY_TO)
            .with_description("Test payment")
            .with_testnet(true);

        assert_eq!(config.amount, dec("0.0001"));
        assert_eq!(config.pay_to, PAY_TO);
        assert_eq!(config.description, Some("Test payment".to_string()));
        assert!(config.testnet);
    }

    #[test]
    fn test_payment_middleware_creation() {
        let middleware =
            PaymentMiddleware::new(dec("0.0001"), PAY_TO).with_description("Test payment");

        assert_eq!(middleware.config().amount, dec("0.0001"));
        assert_eq!(middleware.config().pay_to, PAY_TO);
    }

    #[test]
    fn test_payment_requirements_creation() {
        let config = PaymentMiddlewareConfig::new(dec("0.0001"), PAY_TO).with_testnet(true);

        let requirements = config.create_payment_requirements("/test").unwrap();

        assert_eq!(requirements.scheme, "exact");
        assert_eq!(requirements.network, "base-sepolia");
        assert_eq!(requirements.max_amount_required, "100");
        assert_eq!(requirements.pay_to, PAY_TO);
    }

    #[test]
    fn test_payment_middleware_config_builder() {
        let config = PaymentMiddlewareConfig::new(dec("0.01"), PAY_TO)
            .with_description("Test payment")
            .with_mime_type("application/json")
            .with_max_timeout_seconds(120)
            .with_testnet(false)
            .with_resource("https://example.com/test");

        assert_eq!(config.amount, dec("0.01"));
        assert_eq!(config.pay_to, PAY_TO);
        assert_eq!(config.description, Some("Test payment".to_string()));
        assert_eq!(config.mime_type, Some("application/json".to_string()));
        assert_eq!(config.max_timeout_seconds, 120);
        assert!(!config.testnet);
        assert_eq!(
            config.resource,
            Some("https://example.com/test".to_string())
        );
    }

    #[test]
    fn test_payment_middleware_creation_with_description() {
        let middleware =
            PaymentMiddleware::new(dec("0.001"), PAY_TO).with_description("Test middleware");

        assert_eq!(middleware.config().amount, dec("0.001"));
        assert_eq!(
            middleware.config().description,
            Some("Test middleware".to_string())
        );
    }

    #[test]
    fn test_max_amount_required_is_in_atomic_units() {
        // USDC has 6 decimals, so amounts are scaled by 1_000_000 and normalized.
        for &(amount, expected) in &[("1", "1000000"), ("0.01", "10000"), ("0.000001", "1")] {
            let requirements = PaymentMiddlewareConfig::new(dec(amount), PAY_TO)
                .create_payment_requirements("/test")
                .unwrap();
            assert_eq!(requirements.max_amount_required, expected, "amount {}", amount);
        }
    }

    #[test]
    fn test_payment_requirements_copy_optional_fields() {
        let requirements = PaymentMiddlewareConfig::new(dec("0.0001"), PAY_TO)
            .with_mime_type("application/json")
            .with_max_timeout_seconds(120)
            .create_payment_requirements("/test")
            .unwrap();

        assert_eq!(requirements.mime_type, Some("application/json".to_string()));
        assert_eq!(requirements.max_timeout_seconds, 120);
    }

    #[test]
    fn test_payment_middleware_builders_update_config() {
        let middleware = PaymentMiddleware::new(dec("0.0001"), PAY_TO)
            .with_mime_type("application/json")
            .with_max_timeout_seconds(30)
            .with_testnet(false)
            .with_resource_root_url("https://example.com");
        let config = middleware.config();

        assert_eq!(config.mime_type, Some("application/json".to_string()));
        assert_eq!(config.max_timeout_seconds, 30);
        assert!(!config.testnet);
        assert_eq!(
            config.resource_root_url,
            Some("https://example.com".to_string())
        );
    }

    #[test]
    fn test_builders_do_not_mutate_existing_clones() {
        // `Arc::make_mut` must copy the shared config instead of mutating it in place.
        let original = PaymentMiddleware::new(dec("0.0001"), PAY_TO);
        let modified = original.clone().with_description("changed");

        assert_eq!(original.config().description, None);
        assert_eq!(modified.config().description, Some("changed".to_string()));
    }
}