mod harness;
use std::collections::HashMap;
use std::sync::Arc;
use agent_pay::{
did_key_from_public_key, fetch_with_l402, generate_key_pair, issue_token,
sign_invoice_envelope, FetchOptions, FetchResponse, InvoiceCreateRequest, LightningNode,
MemoryLedger, MemoryNode, Paywall, PaywallOptions, SignInvoiceOpts,
};
use chrono::{Duration, Utc};
use harness::{make_fetch, ok_handler, run_client, SECRET};
/// Builds a paywall for the `/report` resource, backed by an in-memory Lightning node
/// registered as `"server"` on a freshly created ledger.
///
/// * `price_msat` — price passed to `PaywallOptions::new`.
/// * `ttl` — value assigned to `invoice_ttl_seconds`.
///
/// Returns `(paywall, ledger, did, private_key)`:
/// * the paywall, wrapped in an `Arc`;
/// * the ledger the server node was created on (tests build `"client"` nodes on it);
/// * the server's DID, derived from the generated public key;
/// * the private half of that same key pair.
fn setup(price_msat: u64, ttl: u64) -> (Arc<Paywall>, Arc<MemoryLedger>, String, [u8; 32]) {
    let server_keys = generate_key_pair();
    let server_did = did_key_from_public_key(&server_keys.public_key).unwrap();
    let signing_key = server_keys.private_key;

    let ledger = MemoryLedger::new();
    let server_node: Arc<dyn LightningNode> =
        Arc::new(MemoryNode::new(ledger.clone(), "server"));

    let mut options = PaywallOptions::new(
        server_did.clone(),
        signing_key,
        price_msat,
        "/report",
        server_node,
        SECRET.to_vec(),
    );
    options.invoice_ttl_seconds = ttl;

    let paywall = Arc::new(Paywall::new(options));
    (paywall, ledger, server_did, signing_key)
}
/// Happy path: a client wallet on the same in-memory ledger pays the paywall's
/// challenge and ends with a 200 response. That response carries an
/// `x-payment-receipt` header and the echo handler's JSON body.
#[tokio::test]
async fn fetch_with_l402_pays_via_fake_node_retries_and_parses_200() {
    let (paywall, ledger, _, _) = setup(1000, 300);
    let wallet: Arc<dyn LightningNode> = Arc::new(MemoryNode::new(ledger.clone(), "client"));

    let response = run_client(wallet, 5000, make_fetch(paywall, harness::echo_handler()), None)
        .await
        .unwrap();

    assert_eq!(response.status, 200);
    assert!(response.header("x-payment-receipt").is_some());
    assert_eq!(response.json.unwrap(), serde_json::json!({"data": "hello"}));
}
#[tokio::test]
async fn overcharging_bolt11_must_equal_envelope_price() {
let kp = generate_key_pair();
let did = did_key_from_public_key(&kp.public_key).unwrap();
let ledger = MemoryLedger::new();
let node: Arc<dyn LightningNode> = Arc::new(MemoryNode::new(ledger.clone(), "server"));
let wallet: Arc<dyn LightningNode> = Arc::new(MemoryNode::new(ledger.clone(), "client"));
let did_for_fetch = did.clone();
let private_key = kp.private_key;
let lying: agent_pay::FetchFn = Arc::new(move |_url, _headers| {
let node = node.clone();
let did = did_for_fetch.clone();
Box::pin(async move {
let inv = node
.create_invoice(InvoiceCreateRequest {
amount_msat: 9999,
memo: None,
expiry_seconds: None,
})
.await
.unwrap();
let envelope = sign_invoice_envelope(SignInvoiceOpts {
bolt11: &inv.bolt11,
did: &did,
private_key: &private_key,
price_msat: 1000,
resource: "/lying",
expires_at: "2030-01-01T00:00:00Z",
nonce: &[0u8; 16],
})
.await
.unwrap();
let tok = issue_token(&inv.payment_hash, "2030-01-01T00:00:00Z", SECRET)
.await
.unwrap();
let mut headers = HashMap::new();
headers.insert(
"www-authenticate".into(),
format!("L402 macaroon=\"{tok}\", invoice=\"{}\"", inv.bolt11),
);
headers.insert("x-did-invoice".into(), envelope);
Ok(FetchResponse {
status: 402,
headers,
body: None,
json: None,
})
})
});
let opts = FetchOptions::new(wallet, 50_000, lying);
let err = fetch_with_l402("http://x/lying", opts).await.unwrap_err();
let msg = format!("{err}").to_lowercase();
assert!(
msg.contains("amount") || msg.contains("mismatch"),
"msg={msg}"
);
}
/// Corrupting one character of the receipt JWS signature on an otherwise valid paid
/// response must make the client fail with a receipt-related error.
#[tokio::test]
async fn throws_when_receipt_jws_is_tampered() {
    let (paywall, ledger, _did, _) = setup(1000, 300);
    let wallet: Arc<dyn LightningNode> = Arc::new(MemoryNode::new(ledger.clone(), "client"));
    let inner = make_fetch(paywall, ok_handler());
    let tamperer: agent_pay::FetchFn = Arc::new(move |url, headers| {
        let inner = inner.clone();
        Box::pin(async move {
            let mut res = (inner)(url, headers).await?;
            // Responses without a receipt pass through untouched.
            let receipt = match res.header("x-payment-receipt") {
                Some(r) => r.to_string(),
                None => return Ok(res),
            };
            // Only a three-segment compact JWS (header.payload.signature) is altered.
            let mut parts: Vec<String> = receipt.split('.').map(String::from).collect();
            if parts.len() != 3 {
                return Ok(res);
            }
            // Swap the first signature character ('A' <-> 'B'/other). Walking `chars()`
            // instead of slicing `sig[1..]` cannot panic on an empty signature or a
            // multi-byte leading character.
            let mut sig_chars = parts[2].chars();
            let first = sig_chars.next().unwrap_or('A');
            let flipped = if first == 'A' { 'B' } else { 'A' };
            let tampered = format!("{flipped}{}", sig_chars.as_str());
            parts[2] = tampered;
            // The response is owned here, so patch its headers in place rather than
            // cloning the map and rebuilding the struct.
            res.headers
                .insert("x-payment-receipt".into(), parts.join("."));
            Ok(res)
        })
    });
    let opts = FetchOptions::new(wallet, 5000, tamperer);
    let err = fetch_with_l402("http://x/tamper", opts).await.unwrap_err();
    let msg = format!("{err}").to_lowercase();
    assert!(msg.contains("receipt"), "msg={msg}");
}
#[tokio::test]
async fn client_enforces_max_price_msat_cap() {
let (paywall, ledger, _did, _) = setup(10_000, 300);
let wallet: Arc<dyn LightningNode> = Arc::new(MemoryNode::new(ledger.clone(), "client"));
let fetch = make_fetch(paywall, ok_handler());
let opts = FetchOptions::new(wallet, 5000, fetch);
let err = fetch_with_l402("http://x/r", opts).await.unwrap_err();
let msg = format!("{err}").to_lowercase();
assert!(msg.contains("cap") || msg.contains("exceeds"), "msg={msg}");
}
#[tokio::test]
async fn client_rejects_envelope_past_expires_at() {
let (paywall, ledger, _did, _) = setup(1000, 1);
let wallet: Arc<dyn LightningNode> = Arc::new(MemoryNode::new(ledger.clone(), "client"));
let fetch = make_fetch(paywall, ok_handler());
let mut opts = FetchOptions::new(wallet, 5000, fetch);
opts.now = Box::new(|| Utc::now() + Duration::seconds(10));
let err = fetch_with_l402("http://x/r", opts).await.unwrap_err();
let msg = format!("{err}").to_lowercase();
assert!(msg.contains("expired"), "msg={msg}");
}