use super::challenge::{PaymentChallenge, PaymentCredential, Receipt};
use super::types::{base64url_decode, base64url_encode, Base64UrlJson, IntentName, MethodName};
use crate::error::{MppError, Result};
use std::collections::HashMap;
const MAX_TOKEN_LEN: usize = 16 * 1024;
macro_rules! require_param {
($params:expr, $key:literal) => {
$params.get($key).ok_or_else(|| {
MppError::invalid_challenge_reason(format!("Missing '{}' field", $key))
})?
};
}
fn strip_payment_scheme(header: &str) -> Option<&str> {
let header = header.trim_start();
let scheme_len = PAYMENT_SCHEME.len();
if header.len() >= scheme_len
&& header
.get(..scheme_len)
.is_some_and(|s| s.eq_ignore_ascii_case(PAYMENT_SCHEME))
{
header.get(scheme_len..)
} else {
None
}
}
fn escape_quoted_value(s: &str) -> Result<String> {
if s.contains('\r') || s.contains('\n') {
return Err(MppError::invalid_challenge_reason(
"Header value contains invalid CRLF characters",
));
}
Ok(s.replace('\\', "\\\\").replace('"', "\\\""))
}
pub const WWW_AUTHENTICATE_HEADER: &str = "www-authenticate";
pub const AUTHORIZATION_HEADER: &str = "authorization";
pub const PAYMENT_RECEIPT_HEADER: &str = "payment-receipt";
pub const PAYMENT_SCHEME: &str = "Payment";
fn parse_auth_params(params_str: &str) -> HashMap<String, String> {
let mut params = HashMap::new();
let chars: Vec<char> = params_str.chars().collect();
let mut i = 0;
while i < chars.len() {
while i < chars.len() && (chars[i].is_whitespace() || chars[i] == ',') {
i += 1;
}
if i >= chars.len() {
break;
}
let key_start = i;
while i < chars.len() && chars[i] != '=' && !chars[i].is_whitespace() {
i += 1;
}
if i >= chars.len() || chars[i] != '=' {
while i < chars.len() && !chars[i].is_whitespace() && chars[i] != ',' {
i += 1;
}
continue;
}
let key: String = chars[key_start..i].iter().collect();
i += 1;
if i >= chars.len() {
break;
}
let value = if chars[i] == '"' {
i += 1;
let mut value = String::new();
while i < chars.len() && chars[i] != '"' {
if chars[i] == '\\' && i + 1 < chars.len() {
i += 1;
value.push(chars[i]);
} else {
value.push(chars[i]);
}
i += 1;
}
if i < chars.len() {
i += 1;
}
value
} else {
let value_start = i;
while i < chars.len() && !chars[i].is_whitespace() && chars[i] != ',' {
i += 1;
}
chars[value_start..i].iter().collect()
};
params.insert(key, value);
}
params
}
pub fn parse_www_authenticate(header: &str) -> Result<PaymentChallenge> {
let rest = strip_payment_scheme(header).ok_or_else(|| {
MppError::invalid_challenge_reason("Expected 'Payment' scheme".to_string())
})?;
let params_str = rest
.strip_prefix(' ')
.or_else(|| rest.strip_prefix('\t'))
.ok_or_else(|| {
MppError::invalid_challenge_reason("Expected space after 'Payment' scheme".to_string())
})?
.trim_start();
let params = parse_auth_params(params_str);
let id = require_param!(params, "id").clone();
let realm = require_param!(params, "realm").clone();
let method = MethodName::new(require_param!(params, "method"));
let intent = IntentName::new(require_param!(params, "intent"));
let request_b64 = require_param!(params, "request").clone();
let _ = base64url_decode(&request_b64)?;
let request = Base64UrlJson::from_raw(request_b64);
Ok(PaymentChallenge {
id,
realm,
method,
intent,
request,
expires: params.get("expires").cloned(),
description: params.get("description").cloned(),
digest: params.get("digest").cloned(),
})
}
pub fn parse_www_authenticate_all<'a>(
headers: impl IntoIterator<Item = &'a str>,
) -> Vec<Result<PaymentChallenge>> {
headers
.into_iter()
.filter(|h| {
h.trim_start()
.get(..8)
.is_some_and(|s| s.eq_ignore_ascii_case("payment "))
})
.map(parse_www_authenticate)
.collect()
}
pub fn format_www_authenticate(challenge: &PaymentChallenge) -> Result<String> {
let mut parts = vec![
format!("id=\"{}\"", escape_quoted_value(&challenge.id)?),
format!("realm=\"{}\"", escape_quoted_value(&challenge.realm)?),
format!(
"method=\"{}\"",
escape_quoted_value(challenge.method.as_str())?
),
format!(
"intent=\"{}\"",
escape_quoted_value(challenge.intent.as_str())?
),
format!(
"request=\"{}\"",
escape_quoted_value(challenge.request.raw())?
),
];
if let Some(ref expires) = challenge.expires {
parts.push(format!("expires=\"{}\"", escape_quoted_value(expires)?));
}
if let Some(ref description) = challenge.description {
parts.push(format!(
"description=\"{}\"",
escape_quoted_value(description)?
));
}
if let Some(ref digest) = challenge.digest {
parts.push(format!("digest=\"{}\"", escape_quoted_value(digest)?));
}
Ok(format!("Payment {}", parts.join(", ")))
}
pub fn format_www_authenticate_many(challenges: &[PaymentChallenge]) -> Result<Vec<String>> {
challenges.iter().map(format_www_authenticate).collect()
}
pub fn parse_authorization(header: &str) -> Result<PaymentCredential> {
let rest = strip_payment_scheme(header).ok_or_else(|| {
MppError::invalid_challenge_reason("Expected 'Payment' scheme".to_string())
})?;
let token = rest.trim();
if token.len() > MAX_TOKEN_LEN {
return Err(MppError::invalid_challenge_reason(format!(
"Token exceeds maximum length of {} bytes",
MAX_TOKEN_LEN
)));
}
let decoded = base64url_decode(token)?;
let credential: PaymentCredential = serde_json::from_slice(&decoded).map_err(|e| {
MppError::invalid_challenge_reason(format!("Invalid credential JSON: {}", e))
})?;
Ok(credential)
}
pub fn format_authorization(credential: &PaymentCredential) -> Result<String> {
let json = serde_json::to_string(credential)?;
let encoded = base64url_encode(json.as_bytes());
Ok(format!("Payment {}", encoded))
}
pub fn parse_receipt(header: &str) -> Result<Receipt> {
let token = header.trim();
if token.len() > MAX_TOKEN_LEN {
return Err(MppError::invalid_challenge_reason(format!(
"Receipt exceeds maximum length of {} bytes",
MAX_TOKEN_LEN
)));
}
let decoded = base64url_decode(token)?;
let receipt: Receipt = serde_json::from_slice(&decoded)
.map_err(|e| MppError::invalid_challenge_reason(format!("Invalid receipt JSON: {}", e)))?;
Ok(receipt)
}
pub fn format_receipt(receipt: &Receipt) -> Result<String> {
let json = serde_json::to_string(receipt)?;
Ok(base64url_encode(json.as_bytes()))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::core::types::{PayloadType, ReceiptStatus};
use crate::protocol::core::PaymentPayload;
fn test_challenge() -> PaymentChallenge {
PaymentChallenge {
id: "abc123".to_string(),
realm: "api".to_string(),
method: "tempo".into(),
intent: "charge".into(),
request: Base64UrlJson::from_value(&serde_json::json!({
"amount": "10000",
"currency": "0x123"
}))
.unwrap(),
expires: Some("2024-01-01T00:00:00Z".to_string()),
description: None,
digest: None,
}
}
#[test]
fn test_parse_www_authenticate() {
let challenge = test_challenge();
let header = format_www_authenticate(&challenge).unwrap();
let parsed = parse_www_authenticate(&header).unwrap();
assert_eq!(parsed.id, "abc123");
assert_eq!(parsed.realm, "api");
assert_eq!(parsed.method.as_str(), "tempo");
assert_eq!(parsed.intent.as_str(), "charge");
assert_eq!(parsed.expires, Some("2024-01-01T00:00:00Z".to_string()));
let request: serde_json::Value = parsed.request.decode_value().unwrap();
assert_eq!(request["amount"], "10000");
}
#[test]
fn test_parse_www_authenticate_case_insensitive() {
let header =
r#"payment id="test", realm="api", method="tempo", intent="charge", request="e30""#;
let parsed = parse_www_authenticate(header).unwrap();
assert_eq!(parsed.id, "test");
let header2 =
r#"PAYMENT id="test2", realm="api", method="tempo", intent="charge", request="e30""#;
let parsed2 = parse_www_authenticate(header2).unwrap();
assert_eq!(parsed2.id, "test2");
}
#[test]
fn test_parse_www_authenticate_leading_whitespace() {
let header =
r#" Payment id="test", realm="api", method="tempo", intent="charge", request="e30""#;
let parsed = parse_www_authenticate(header).unwrap();
assert_eq!(parsed.id, "test");
}
#[test]
fn test_parse_www_authenticate_with_description() {
let mut challenge = test_challenge();
challenge.description = Some("Pay \"here\" now".to_string());
let header = format_www_authenticate(&challenge).unwrap();
assert!(header.contains("description=\"Pay \\\"here\\\" now\""));
let parsed = parse_www_authenticate(&header).unwrap();
assert_eq!(parsed.description, Some("Pay \"here\" now".to_string()));
}
#[test]
fn test_parse_www_authenticate_all() {
let headers = vec![
"Bearer token",
r#"Payment id="a", realm="api", method="tempo", intent="charge", request="e30""#,
"Basic xyz",
r#"Payment id="b", realm="api", method="base", intent="charge", request="e30""#,
];
let results = parse_www_authenticate_all(headers);
assert_eq!(results.len(), 2);
let first = results[0].as_ref().unwrap();
assert_eq!(first.id, "a");
let second = results[1].as_ref().unwrap();
assert_eq!(second.id, "b");
}
#[test]
fn test_format_www_authenticate_many() {
let c1 = test_challenge();
let mut c2 = test_challenge();
c2.id = "def456".to_string();
c2.method = "base".into();
let headers = format_www_authenticate_many(&[c1, c2]).unwrap();
assert_eq!(headers.len(), 2);
assert!(headers[0].contains("abc123"));
assert!(headers[1].contains("def456"));
}
#[test]
fn test_parse_authorization() {
let challenge = test_challenge();
let credential = PaymentCredential::with_source(
challenge.to_echo(),
"did:pkh:eip155:42431:0x123",
PaymentPayload::transaction("0xabc"),
);
let header = format_authorization(&credential).unwrap();
let parsed = parse_authorization(&header).unwrap();
assert_eq!(parsed.challenge.id, "abc123");
assert_eq!(
parsed.source,
Some("did:pkh:eip155:42431:0x123".to_string())
);
assert_eq!(parsed.payload.signed_tx(), Some("0xabc"));
assert_eq!(parsed.payload.payload_type(), PayloadType::Transaction);
}
#[test]
fn test_parse_receipt() {
let receipt = Receipt {
status: ReceiptStatus::Success,
method: "tempo".into(),
timestamp: "2024-01-01T00:00:00Z".to_string(),
reference: "0xabc123".to_string(),
error: None,
};
let header = format_receipt(&receipt).unwrap();
let parsed = parse_receipt(&header).unwrap();
assert_eq!(parsed.status, ReceiptStatus::Success);
assert_eq!(parsed.method.as_str(), "tempo");
assert_eq!(parsed.reference, "0xabc123");
}
#[test]
fn test_parse_invalid_scheme() {
let result = parse_www_authenticate("Basic realm=\"test\"");
assert!(result.is_err());
}
#[test]
fn test_parse_missing_required_field() {
let result = parse_www_authenticate("Payment id=\"abc\", realm=\"api\"");
assert!(result.is_err());
}
#[test]
fn test_roundtrip_preserves_request() {
let original_request = serde_json::json!({
"amount": "5000",
"currency": "0xabc",
"nested": {"key": "value"}
});
let mut challenge = test_challenge();
challenge.request = Base64UrlJson::from_value(&original_request).unwrap();
let header = format_www_authenticate(&challenge).unwrap();
let parsed = parse_www_authenticate(&header).unwrap();
assert_eq!(parsed.request.raw(), challenge.request.raw());
let decoded: serde_json::Value = parsed.request.decode_value().unwrap();
assert_eq!(decoded, original_request);
}
}