Skip to main content

agent_pay/
server.rs

//! Paywall server: emits HTTP 402 challenges and validates L402 `Authorization` headers.
2
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::sync::Mutex;
8
9use chrono::{DateTime, SecondsFormat, Utc};
10use once_cell::sync::Lazy;
11use rand::RngCore;
12use regex::Regex;
13
14use crate::envelope::{sign_invoice_envelope, sign_receipt, SignInvoiceOpts, SignReceiptOpts};
15use crate::error::Error;
16use crate::lightning::{InvoiceCreateRequest, LightningNode};
17use crate::replay::ReplayCache;
18use crate::token::{issue_token, verify_token};
19
20static AUTH_RE: Lazy<Regex> =
21    Lazy::new(|| Regex::new(r"^L402\s+([^:\s]+):([0-9a-fA-F]+)$").unwrap());
22
/// Framework-agnostic response carried back through the middleware.
///
/// Produced both by the paywall itself (402 challenges, 401 rejections) and by
/// the wrapped inner handler.
#[derive(Debug, Clone, Default)]
pub struct PaywallResponse {
    /// HTTP status code (the paywall itself produces 200, 401 and 402).
    pub status: u16,
    /// Response headers. The paywall inserts lower-case names such as
    /// `www-authenticate`, `x-did-invoice` and `x-payment-receipt`.
    pub headers: HashMap<String, String>,
    /// Optional raw body bytes.
    /// NOTE(review): how this interacts with `json` when both are set is decided
    /// by the framework adapter, which is not visible here — confirm.
    pub body: Option<Vec<u8>>,
    /// Optional JSON body; the paywall's 401 responses use `{"error": "..."}`.
    pub json: Option<serde_json::Value>,
}
31
32impl PaywallResponse {
33    pub fn header(&self, name: &str) -> Option<&str> {
34        let lower = name.to_ascii_lowercase();
35        self.headers
36            .iter()
37            .find(|(k, _)| k.to_ascii_lowercase() == lower)
38            .map(|(_, v)| v.as_str())
39    }
40}
41
/// Async handler signature: receives (path, headers) and returns a response.
///
/// `Paywall::process_request` invokes it only after the payment has been
/// verified, passing the request path and the original request headers
/// (including `Authorization`). The returned future must be `Send`; its error
/// is propagated out of `process_request` with `?`.
pub type InnerHandler = Arc<
    dyn (Fn(
            String,
            HashMap<String, String>,
        ) -> Pin<Box<dyn Future<Output = Result<PaywallResponse, Error>> + Send>>)
        + Send
        + Sync,
>;
51
52pub struct PaywallOptions {
53    pub server_did: String,
54    pub server_private_key: [u8; 32],
55    pub price_msat: u64,
56    pub resource: String,
57    pub lightning: Arc<dyn LightningNode>,
58    pub token_secret: Vec<u8>,
59    pub invoice_ttl_seconds: u64,
60    pub now: Box<dyn Fn() -> DateTime<Utc> + Send + Sync>,
61    pub replay: Option<Arc<ReplayCache>>,
62}
63
64impl PaywallOptions {
65    pub fn new(
66        server_did: impl Into<String>,
67        server_private_key: [u8; 32],
68        price_msat: u64,
69        resource: impl Into<String>,
70        lightning: Arc<dyn LightningNode>,
71        token_secret: Vec<u8>,
72    ) -> Self {
73        Self {
74            server_did: server_did.into(),
75            server_private_key,
76            price_msat,
77            resource: resource.into(),
78            lightning,
79            token_secret,
80            invoice_ttl_seconds: 300,
81            now: Box::new(Utc::now),
82            replay: None,
83        }
84    }
85}
86
87pub struct Paywall {
88    opts: PaywallOptions,
89    replay: Arc<ReplayCache>,
90    issued: Mutex<HashMap<String, String>>, // payment_hash -> bolt11
91}
92
93impl Paywall {
94    pub fn new(opts: PaywallOptions) -> Self {
95        let replay = opts
96            .replay
97            .clone()
98            .unwrap_or_else(|| Arc::new(ReplayCache::default()));
99        Self {
100            opts,
101            replay,
102            issued: Mutex::new(HashMap::new()),
103        }
104    }
105
106    pub async fn process_request(
107        &self,
108        path: &str,
109        headers: HashMap<String, String>,
110        inner: Option<InnerHandler>,
111    ) -> Result<PaywallResponse, Error> {
112        let auth = headers
113            .iter()
114            .find(|(k, _)| k.eq_ignore_ascii_case("authorization"))
115            .map(|(_, v)| v.clone());
116        let Some(auth) = auth else {
117            return self.challenge().await;
118        };
119        let Some(captures) = AUTH_RE.captures(&auth) else {
120            return self.challenge().await;
121        };
122        let token = captures.get(1).unwrap().as_str();
123        let preimage_hex = captures.get(2).unwrap().as_str();
124        let payload = match verify_token(token, &self.opts.token_secret).await {
125            Ok(p) => p,
126            Err(_) => return self.challenge().await,
127        };
128        if self.replay.is_used(&payload.payment_hash) {
129            return Ok(PaywallResponse {
130                status: 401,
131                json: Some(serde_json::json!({ "error": "preimage replayed" })),
132                ..Default::default()
133            });
134        }
135        let lookup = self
136            .opts
137            .lightning
138            .lookup_invoice(&payload.payment_hash)
139            .await?;
140        if !lookup.settled || lookup.preimage.is_none() {
141            return Ok(PaywallResponse {
142                status: 401,
143                json: Some(serde_json::json!({ "error": "invoice not settled" })),
144                ..Default::default()
145            });
146        }
147        let presented =
148            hex::decode(preimage_hex).map_err(|e| Error::Paywall(format!("preimage hex: {e}")))?;
149        let stored = lookup.preimage.unwrap();
150        if !constant_time_eq(&presented, &stored) {
151            return Ok(PaywallResponse {
152                status: 401,
153                json: Some(
154                    serde_json::json!({ "error": "preimage does not match settled invoice" }),
155                ),
156                ..Default::default()
157            });
158        }
159        let expires_ms = parse_iso_ms(&payload.expires_at)?;
160        self.replay.mark_used(&payload.payment_hash, expires_ms);
161
162        let mut inner_resp = if let Some(inner) = inner {
163            (inner)(path.to_string(), headers).await?
164        } else {
165            PaywallResponse {
166                status: 200,
167                ..Default::default()
168            }
169        };
170
171        let bolt11 = {
172            let guard = self.issued.lock().unwrap();
173            guard.get(&payload.payment_hash).cloned()
174        };
175        if let Some(bolt11) = bolt11 {
176            let paid_at = iso((self.opts.now)());
177            let receipt = sign_receipt(SignReceiptOpts {
178                bolt11: &bolt11,
179                did: &self.opts.server_did,
180                private_key: &self.opts.server_private_key,
181                preimage: &presented,
182                resource: &self.opts.resource,
183                paid_at: &paid_at,
184            })
185            .await?;
186            inner_resp
187                .headers
188                .insert("x-payment-receipt".into(), receipt);
189        }
190        Ok(inner_resp)
191    }
192
193    async fn challenge(&self) -> Result<PaywallResponse, Error> {
194        let ttl = self.opts.invoice_ttl_seconds;
195        let invoice = self
196            .opts
197            .lightning
198            .create_invoice(InvoiceCreateRequest {
199                amount_msat: self.opts.price_msat,
200                memo: None,
201                expiry_seconds: Some(ttl),
202            })
203            .await?;
204        self.issued
205            .lock()
206            .unwrap()
207            .insert(invoice.payment_hash.clone(), invoice.bolt11.clone());
208        let now = (self.opts.now)();
209        let expires_at_dt = now + chrono::Duration::seconds(ttl as i64);
210        let expires_at = iso(expires_at_dt);
211        let mut nonce = [0u8; 16];
212        rand::thread_rng().fill_bytes(&mut nonce);
213        let envelope = sign_invoice_envelope(SignInvoiceOpts {
214            bolt11: &invoice.bolt11,
215            did: &self.opts.server_did,
216            private_key: &self.opts.server_private_key,
217            price_msat: self.opts.price_msat,
218            resource: &self.opts.resource,
219            expires_at: &expires_at,
220            nonce: &nonce,
221        })
222        .await?;
223        let token =
224            issue_token(&invoice.payment_hash, &expires_at, &self.opts.token_secret).await?;
225        let mut headers = HashMap::new();
226        headers.insert(
227            "www-authenticate".into(),
228            format!("L402 macaroon=\"{token}\", invoice=\"{}\"", invoice.bolt11),
229        );
230        headers.insert("x-did-invoice".into(), envelope);
231        Ok(PaywallResponse {
232            status: 402,
233            headers,
234            body: None,
235            json: None,
236        })
237    }
238}
239
240fn iso(dt: DateTime<Utc>) -> String {
241    dt.to_rfc3339_opts(SecondsFormat::Millis, true)
242}
243
244fn parse_iso_ms(s: &str) -> Result<u64, Error> {
245    let dt = DateTime::parse_from_rfc3339(s).map_err(|e| Error::Paywall(format!("iso: {e}")))?;
246    Ok(dt.with_timezone(&Utc).timestamp_millis() as u64)
247}
248
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different length compare unequal immediately, so timing can reveal
/// the length but not where the contents differ. For equal lengths every byte
/// pair is XOR-ed and OR-folded into one accumulator before the single final
/// test. This is best-effort: the optimizer is not formally barred from adding
/// an early exit.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}