Skip to main content

agent_pay/
client.rs

1//! L402-aware HTTP client. Mirrors fetchWithL402 from the TS reference.
2
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use chrono::{DateTime, Utc};
9use once_cell::sync::Lazy;
10use regex::Regex;
11
12use crate::bolt11::parse_invoice;
13use crate::envelope::{verify_invoice_envelope, verify_receipt};
14use crate::error::Error;
15use crate::jws::ResolveKey;
16use crate::keys::public_key_from_did_key;
17use crate::lightning::LightningNode;
18
19static CHALLENGE_RE: Lazy<Regex> =
20    Lazy::new(|| Regex::new(r#"macaroon="([^"]+)",\s*invoice="([^"]+)""#).unwrap());
21
22/// Minimal response carrier (mirrors PaywallResponse intentionally to allow
23/// trivial adapter between server and client for tests).
24#[derive(Debug, Clone, Default)]
25pub struct FetchResponse {
26    pub status: u16,
27    pub headers: HashMap<String, String>,
28    pub body: Option<Vec<u8>>,
29    pub json: Option<serde_json::Value>,
30}
31
32impl FetchResponse {
33    pub fn header(&self, name: &str) -> Option<&str> {
34        let lower = name.to_ascii_lowercase();
35        self.headers
36            .iter()
37            .find(|(k, _)| k.to_ascii_lowercase() == lower)
38            .map(|(_, v)| v.as_str())
39    }
40}
41
42pub type FetchFn = Arc<
43    dyn (Fn(
44            String,
45            HashMap<String, String>,
46        ) -> Pin<Box<dyn Future<Output = Result<FetchResponse, Error>> + Send>>)
47        + Send
48        + Sync,
49>;
50
51pub struct FetchOptions {
52    pub wallet: Arc<dyn LightningNode>,
53    pub max_price_msat: u64,
54    pub fetch: FetchFn,
55    pub expected_did: Option<String>,
56    pub verify_receipt_flag: bool,
57    pub now: Box<dyn Fn() -> DateTime<Utc> + Send + Sync>,
58    pub request_headers: HashMap<String, String>,
59    pub method: String,
60}
61
62impl FetchOptions {
63    pub fn new(wallet: Arc<dyn LightningNode>, max_price_msat: u64, fetch: FetchFn) -> Self {
64        Self {
65            wallet,
66            max_price_msat,
67            fetch,
68            expected_did: None,
69            verify_receipt_flag: true,
70            now: Box::new(Utc::now),
71            request_headers: HashMap::new(),
72            method: "GET".into(),
73        }
74    }
75}
76
77pub async fn fetch_with_l402(url: &str, opts: FetchOptions) -> Result<FetchResponse, Error> {
78    let first = (opts.fetch)(url.to_string(), opts.request_headers.clone()).await?;
79    if first.status != 402 {
80        return Ok(first);
81    }
82    let www_auth = first.header("www-authenticate").unwrap_or("").to_string();
83    let captures = CHALLENGE_RE
84        .captures(&www_auth)
85        .ok_or_else(|| Error::fetch("no L402 challenge", "missing-challenge"))?;
86    let token = captures.get(1).unwrap().as_str().to_string();
87    let bolt11 = captures.get(2).unwrap().as_str().to_string();
88    let envelope_jws = first
89        .header("x-did-invoice")
90        .ok_or_else(|| Error::fetch("missing X-Did-Invoice", "missing-x-did-invoice"))?
91        .to_string();
92
93    let resolver = make_did_key_resolver(opts.expected_did.clone());
94    let env = verify_invoice_envelope(&envelope_jws, &bolt11, &resolver)
95        .await
96        .map_err(|e| {
97            Error::fetch(
98                format!("X-Did-Invoice verification failed: {e}"),
99                "jws-invalid",
100            )
101        })?;
102
103    let price: u64 = env
104        .price_msat
105        .parse()
106        .map_err(|e| Error::fetch(format!("price_msat: {e}"), "jws-invalid"))?;
107    if price > opts.max_price_msat {
108        return Err(Error::fetch(
109            format!("price {price} exceeds cap {}", opts.max_price_msat),
110            "price-cap",
111        ));
112    }
113    let expires_ms = parse_iso_ms(&env.expires_at)?;
114    let now_ms = (opts.now)().timestamp_millis() as u64;
115    if expires_ms <= now_ms {
116        return Err(Error::fetch(
117            format!("invoice expired ({})", env.expires_at),
118            "expired",
119        ));
120    }
121    let parsed = parse_invoice(&bolt11)
122        .map_err(|e| Error::fetch(format!("bolt11 parse: {e}"), "jws-invalid"))?;
123    if parsed.amount_msat != price {
124        return Err(Error::fetch(
125            format!(
126                "BOLT11 amount {} mismatches envelope price {}",
127                parsed.amount_msat, price
128            ),
129            "amount-mismatch",
130        ));
131    }
132
133    let pay = opts.wallet.pay_invoice(&bolt11).await?;
134    let preimage_hex = hex::encode(&pay.preimage);
135    let mut second_headers = opts.request_headers.clone();
136    second_headers.insert(
137        "authorization".into(),
138        format!("L402 {token}:{preimage_hex}"),
139    );
140    let second = (opts.fetch)(url.to_string(), second_headers).await?;
141    if second.status != 200 {
142        return Ok(second);
143    }
144    if opts.verify_receipt_flag {
145        if let Some(receipt) = second.header("x-payment-receipt") {
146            verify_receipt(receipt, &bolt11, &resolver)
147                .await
148                .map_err(|e| {
149                    Error::fetch(
150                        format!("receipt verification failed: {e}"),
151                        "receipt-invalid",
152                    )
153                })?;
154        }
155    }
156    Ok(second)
157}
158
159fn make_did_key_resolver(pinned: Option<String>) -> ResolveKey {
160    Arc::new(move |kid: String| {
161        let pinned = pinned.clone();
162        Box::pin(async move {
163            let did = kid.split('#').next().unwrap_or(&kid).to_string();
164            if let Some(p) = pinned.as_ref() {
165                if &did != p {
166                    return Err(Error::fetch(format!("unexpected DID {did}"), "jws-invalid"));
167                }
168            }
169            public_key_from_did_key(&did)
170        })
171    })
172}
173
174fn parse_iso_ms(s: &str) -> Result<u64, Error> {
175    let dt = DateTime::parse_from_rfc3339(s)
176        .map_err(|e| Error::fetch(format!("iso: {e}"), "jws-invalid"))?;
177    Ok(dt.with_timezone(&Utc).timestamp_millis() as u64)
178}