use std::fmt;
use bitcoin::base64::engine::general_purpose::{self, GeneralPurposeConfig};
use bitcoin::base64::engine::GeneralPurpose;
use bitcoin::base64::{alphabet, Engine};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use super::nut21::ProtectedEndpoint;
use crate::dhke::hash_to_curve;
use crate::secret::Secret;
use crate::util::hex;
use crate::{BlindedMessage, Id, Proof, ProofDleq, PublicKey};
/// Errors produced while handling auth proofs and blind auth tokens in this module.
#[derive(Debug, Error)]
pub enum Error {
    /// A blind auth token string did not start with the `authA` prefix.
    #[error("Invalid prefix")]
    InvalidPrefix,
    /// An auth proof was expected to carry a DLEQ proof but did not.
    // NOTE(review): not raised anywhere in this file; presumably used by callers
    // that require DLEQ verification — confirm.
    #[error("Dleq Proof not included for auth proof")]
    DleqProofNotIncluded,
    /// Hex decoding failed.
    #[error(transparent)]
    HexError(#[from] hex::Error),
    /// Base64 decoding of a token payload failed.
    #[error(transparent)]
    Base64Error(#[from] bitcoin::base64::DecodeError),
    /// JSON (de)serialization failed.
    #[error(transparent)]
    SerdeJsonError(#[from] serde_json::Error),
    /// Decoded bytes were not valid UTF-8.
    #[error(transparent)]
    Utf8ParseError(#[from] std::string::FromUtf8Error),
    /// Error from the DHKE primitives (e.g. `hash_to_curve`).
    #[error(transparent)]
    DHKE(#[from] crate::dhke::Error),
}
/// Blind authentication settings.
///
/// `Serialize` is derived, but `Deserialize` is hand-written: on input, path
/// patterns in `protected_endpoints` are expanded into concrete routes.
// NOTE(review): presumably this is the blind-auth section of the mint info — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub struct Settings {
    /// Presumably the maximum number of blind auth tokens (BATs) that may be
    /// minted in one request — confirm against the mint implementation.
    pub bat_max_mint: u64,
    /// Endpoints (method + concrete route) that require authentication.
    pub protected_endpoints: Vec<ProtectedEndpoint>,
}
impl Settings {
pub fn new(bat_max_mint: u64, protected_endpoints: Vec<ProtectedEndpoint>) -> Self {
Self {
bat_max_mint,
protected_endpoints,
}
}
}
impl<'de> Deserialize<'de> for Settings {
    /// Deserializes [`Settings`], expanding path patterns into concrete routes.
    ///
    /// Each raw `protected_endpoints` entry has a `method` and a `path`. The path
    /// is passed to `matching_route_paths`, which may expand it into several
    /// routes. Every resulting `(method, route)` pair becomes a
    /// [`ProtectedEndpoint`].
    ///
    /// Duplicates, such as those produced by overlapping patterns, are dropped.
    /// The first occurrence wins, so the resulting list has a deterministic order
    /// that follows the input.
    ///
    /// # Errors
    ///
    /// Returns a deserializer error if the raw structure does not match, or if a
    /// path pattern is rejected by `matching_route_paths`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use std::collections::HashSet;

        use super::nut21::matching_route_paths;

        #[derive(Deserialize)]
        struct RawSettings {
            bat_max_mint: u64,
            protected_endpoints: Vec<RawProtectedEndpoint>,
        }

        #[derive(Deserialize)]
        struct RawProtectedEndpoint {
            method: super::nut21::Method,
            path: String,
        }

        let raw = RawSettings::deserialize(deserializer)?;

        // `seen` is only used for de-duplication; the Vec carries the output order.
        // Collecting straight out of a `HashSet` (default `RandomState` hasher)
        // would yield a different order per instance. Two deserializations of the
        // same JSON could then compare unequal, and re-serialization would be
        // non-deterministic.
        let mut seen = HashSet::new();
        let mut protected_endpoints = Vec::new();

        for raw_endpoint in raw.protected_endpoints {
            let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
                serde::de::Error::custom(format!("Invalid pattern '{}': {}", raw_endpoint.path, e))
            })?;

            for path in expanded_paths {
                let endpoint = super::nut21::ProtectedEndpoint::new(raw_endpoint.method, path);
                if seen.insert(endpoint.clone()) {
                    protected_endpoints.push(endpoint);
                }
            }
        }

        Ok(Settings {
            bat_max_mint: raw.bat_max_mint,
            protected_endpoints,
        })
    }
}
/// An authentication token attached to a request.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuthToken {
    /// Clear auth token, carried as an opaque string.
    // NOTE(review): presumably an OAuth/OIDC access token — confirm.
    ClearAuth(String),
    /// Blind auth token wrapping a single [`AuthProof`].
    BlindAuth(BlindAuthToken),
}
impl fmt::Display for AuthToken {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::ClearAuth(cat) => cat.fmt(f),
Self::BlindAuth(bat) => bat.fmt(f),
}
}
}
impl AuthToken {
    /// Returns the HTTP header name under which this token is sent.
    ///
    /// Returns `"Clear-auth"` for clear tokens and `"Blind-auth"` for blind tokens.
    pub fn header_key(&self) -> String {
        let name = if let Self::BlindAuth(_) = self {
            "Blind-auth"
        } else {
            "Clear-auth"
        };
        String::from(name)
    }
}
/// Kind of authentication: clear (`Clear`) or blind (`Blind`).
// NOTE(review): presumably what a protected endpoint requires — confirm at use sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AuthRequired {
    /// Clear authentication (see `AuthToken::ClearAuth`).
    Clear,
    /// Blind authentication (see `AuthToken::BlindAuth`).
    Blind,
}
/// Proof used as a blind auth token.
///
/// Unlike [`Proof`], it has no amount, witness or `p2pk_e` field. Converting
/// into a [`Proof`] fixes the amount to 1.
// Field declaration order determines the JSON key order, and therefore the exact
// bytes of the `authA…` token string; do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub struct AuthProof {
    /// Keyset the proof was signed with (serialized as `id`).
    #[serde(rename = "id")]
    pub keyset_id: Id,
    /// Proof secret; `y()` hashes it to a curve point.
    #[cfg_attr(feature = "swagger", schema(value_type = String))]
    pub secret: Secret,
    /// Unblinded signature point (serialized as `C`).
    #[serde(rename = "C")]
    #[cfg_attr(feature = "swagger", schema(value_type = String))]
    pub c: PublicKey,
    /// Optional DLEQ proof; `BlindAuthToken::without_dleq` strips it.
    pub dleq: Option<ProofDleq>,
}
impl AuthProof {
pub fn y(&self) -> Result<PublicKey, Error> {
Ok(hash_to_curve(self.secret.as_bytes())?)
}
}
impl From<AuthProof> for Proof {
fn from(value: AuthProof) -> Self {
Self {
amount: 1.into(),
keyset_id: value.keyset_id,
secret: value.secret,
c: value.c,
witness: None,
dleq: value.dleq,
p2pk_e: None,
}
}
}
impl TryFrom<Proof> for AuthProof {
type Error = Error;
fn try_from(value: Proof) -> Result<Self, Self::Error> {
Ok(Self {
keyset_id: value.keyset_id,
secret: value.secret,
c: value.c,
dleq: value.dleq,
})
}
}
/// Blind auth token: a single [`AuthProof`].
///
/// Its string form (`Display` / `FromStr`) is `authA` followed by the URL-safe
/// base64 encoding of the JSON-serialized proof.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BlindAuthToken {
    /// The wrapped auth proof.
    pub auth_proof: AuthProof,
}
impl BlindAuthToken {
    /// Wraps `auth_proof` in a token.
    pub fn new(auth_proof: AuthProof) -> Self {
        Self { auth_proof }
    }

    /// Returns a copy of this token whose proof has its DLEQ proof removed.
    ///
    /// The keyset id, secret and `C` are copied unchanged; `self` is not modified.
    pub fn without_dleq(&self) -> Self {
        let AuthProof {
            keyset_id, secret, c, ..
        } = &self.auth_proof;

        Self::new(AuthProof {
            keyset_id: *keyset_id,
            secret: secret.clone(),
            c: *c,
            dleq: None,
        })
    }
}
impl fmt::Display for BlindAuthToken {
    /// Writes `authA` followed by the padded, URL-safe base64 encoding of the
    /// JSON-serialized [`AuthProof`].
    ///
    /// A JSON serialization failure is reported as [`fmt::Error`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let payload = match serde_json::to_string(&self.auth_proof) {
            Ok(json) => general_purpose::URL_SAFE.encode(json),
            Err(_) => return Err(fmt::Error),
        };
        f.write_str("authA")?;
        f.write_str(&payload)
    }
}
impl std::str::FromStr for BlindAuthToken {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let encoded = s.strip_prefix("authA").ok_or(Error::InvalidPrefix)?;
let decode_config = GeneralPurposeConfig::new()
.with_decode_padding_mode(bitcoin::base64::engine::DecodePaddingMode::Indifferent);
let json_string =
GeneralPurpose::new(&alphabet::URL_SAFE, decode_config).decode(encoded)?;
let json_str = String::from_utf8(json_string)?;
let auth_proof: AuthProof = serde_json::from_str(&json_str)?;
Ok(BlindAuthToken { auth_proof })
}
}
/// Request carrying blinded messages to be signed as blind auth tokens.
// NOTE(review): presumably the body of the blind-auth mint endpoint — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub struct MintAuthRequest {
    /// Blinded messages to sign.
    // The 1_000 limit is only advertised in the OpenAPI schema; this type does
    // not enforce it.
    #[cfg_attr(feature = "swagger", schema(max_items = 1_000))]
    pub outputs: Vec<BlindedMessage>,
}
impl MintAuthRequest {
    /// Returns the number of outputs in the request, as a `u64`.
    ///
    /// Auth proofs are converted to `Proof`s with amount 1, so this is presumably
    /// the total token value requested — confirm.
    pub fn amount(&self) -> u64 {
        let output_count: usize = self.outputs.len();
        // usize -> u64 is lossless on all supported (<= 64-bit) targets.
        output_count as u64
    }
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use super::super::nut21::{Method, RoutePath};
use super::*;
use crate::nut00::KnownMethod;
use crate::PaymentMethod;
#[test]
fn test_blind_auth_token_padding() {
use std::str::FromStr;
use crate::SecretKey;
let secret_key = SecretKey::generate();
let public_key = secret_key.public_key();
let secret = Secret::generate();
let auth_proof = AuthProof {
keyset_id: Id::from_bytes(&[0, 1, 2, 3, 4, 5, 6, 7]).expect("valid id"),
secret,
c: public_key,
dleq: None,
};
let token = BlindAuthToken::new(auth_proof);
let token_str = token.to_string();
assert!(token_str.starts_with("authA"));
let parsed =
BlindAuthToken::from_str(&token_str).expect("Failed to parse token with padding");
assert_eq!(token, parsed);
let token_no_pad = token_str.trim_end_matches('=');
let parsed_no_pad =
BlindAuthToken::from_str(token_no_pad).expect("Failed to parse token without padding");
assert_eq!(token, parsed_no_pad);
}
#[test]
fn test_settings_deserialize_direct_paths() {
let json = r#"{
"bat_max_mint": 10,
"protected_endpoints": [
{
"method": "GET",
"path": "/v1/mint/bolt11"
},
{
"method": "POST",
"path": "/v1/swap"
}
]
}"#;
let settings: Settings = serde_json::from_str(json).unwrap();
assert_eq!(settings.bat_max_mint, 10);
assert_eq!(settings.protected_endpoints.len(), 2);
let paths = settings
.protected_endpoints
.iter()
.map(|ep| (ep.method, ep.path.clone()))
.collect::<Vec<_>>();
assert!(paths.contains(&(
Method::Get,
RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt11).to_string())
)));
assert!(paths.contains(&(Method::Post, RoutePath::Swap)));
}
#[test]
fn test_settings_deserialize_with_regex() {
let json = r#"{
"bat_max_mint": 5,
"protected_endpoints": [
{
"method": "GET",
"path": "/v1/mint/*"
},
{
"method": "POST",
"path": "/v1/swap"
}
]
}"#;
let settings: Settings = serde_json::from_str(json).unwrap();
assert_eq!(settings.bat_max_mint, 5);
assert_eq!(settings.protected_endpoints.len(), 5);
let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
ProtectedEndpoint::new(
Method::Get,
RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt11).to_string()),
),
ProtectedEndpoint::new(
Method::Get,
RoutePath::MintQuote(PaymentMethod::Known(KnownMethod::Bolt11).to_string()),
),
ProtectedEndpoint::new(
Method::Get,
RoutePath::MintQuote(PaymentMethod::Known(KnownMethod::Bolt12).to_string()),
),
ProtectedEndpoint::new(
Method::Get,
RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt12).to_string()),
),
]);
let deserialized_protected = settings.protected_endpoints.into_iter().collect();
assert_eq!(expected_protected, deserialized_protected);
}
#[test]
fn test_settings_deserialize_invalid_regex() {
let json = r#"{
"bat_max_mint": 5,
"protected_endpoints": [
{
"method": "GET",
"path": "/*wildcard_start"
}
]
}"#;
let result = serde_json::from_str::<Settings>(json);
assert!(result.is_err());
}
#[test]
fn test_settings_deserialize_all_paths() {
let json = r#"{
"bat_max_mint": 5,
"protected_endpoints": [
{
"method": "GET",
"path": "/v1/*"
}
]
}"#;
let settings: Settings = serde_json::from_str(json).unwrap();
assert_eq!(
settings.protected_endpoints.len(),
RoutePath::all_known_paths().len()
);
}
}