1use std::fmt;
4
5use bitcoin::base64::engine::general_purpose::{self, GeneralPurposeConfig};
6use bitcoin::base64::engine::GeneralPurpose;
7use bitcoin::base64::{alphabet, Engine};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10
11use super::nut21::ProtectedEndpoint;
12use crate::dhke::hash_to_curve;
13use crate::secret::Secret;
14use crate::util::hex;
15use crate::{BlindedMessage, Id, Proof, ProofDleq, PublicKey};
16
17#[derive(Debug, Error)]
19pub enum Error {
20 #[error("Invalid prefix")]
22 InvalidPrefix,
23 #[error("Dleq Proof not included for auth proof")]
25 DleqProofNotIncluded,
26 #[error(transparent)]
28 HexError(#[from] hex::Error),
29 #[error(transparent)]
31 Base64Error(#[from] bitcoin::base64::DecodeError),
32 #[error(transparent)]
34 SerdeJsonError(#[from] serde_json::Error),
35 #[error(transparent)]
37 Utf8ParseError(#[from] std::string::FromUtf8Error),
38 #[error(transparent)]
40 DHKE(#[from] crate::dhke::Error),
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
45pub struct Settings {
46 pub bat_max_mint: u64,
48 pub protected_endpoints: Vec<ProtectedEndpoint>,
50}
51
52impl Settings {
53 pub fn new(bat_max_mint: u64, protected_endpoints: Vec<ProtectedEndpoint>) -> Self {
55 Self {
56 bat_max_mint,
57 protected_endpoints,
58 }
59 }
60}
61
62impl<'de> Deserialize<'de> for Settings {
64 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
65 where
66 D: serde::Deserializer<'de>,
67 {
68 use std::collections::HashSet;
69
70 use super::nut21::matching_route_paths;
71
72 #[derive(Deserialize)]
74 struct RawSettings {
75 bat_max_mint: u64,
76 protected_endpoints: Vec<RawProtectedEndpoint>,
77 }
78
79 #[derive(Deserialize)]
80 struct RawProtectedEndpoint {
81 method: super::nut21::Method,
82 path: String,
83 }
84
85 let raw = RawSettings::deserialize(deserializer)?;
87
88 let mut protected_endpoints = HashSet::new();
90
91 for raw_endpoint in raw.protected_endpoints {
92 let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
93 serde::de::Error::custom(format!("Invalid pattern '{}': {}", raw_endpoint.path, e))
94 })?;
95
96 for path in expanded_paths {
97 protected_endpoints.insert(super::nut21::ProtectedEndpoint::new(
98 raw_endpoint.method,
99 path,
100 ));
101 }
102 }
103
104 Ok(Settings {
106 bat_max_mint: raw.bat_max_mint,
107 protected_endpoints: protected_endpoints.into_iter().collect(),
108 })
109 }
110}
111
112#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
114pub enum AuthToken {
115 ClearAuth(String),
117 BlindAuth(BlindAuthToken),
119}
120
121impl fmt::Display for AuthToken {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 match self {
124 Self::ClearAuth(cat) => cat.fmt(f),
125 Self::BlindAuth(bat) => bat.fmt(f),
126 }
127 }
128}
129
130impl AuthToken {
131 pub fn header_key(&self) -> String {
133 match self {
134 Self::ClearAuth(_) => "Clear-auth".to_string(),
135 Self::BlindAuth(_) => "Blind-auth".to_string(),
136 }
137 }
138}
139
140#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
142pub enum AuthRequired {
143 Clear,
145 Blind,
147}
148
149#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
151pub struct AuthProof {
152 #[serde(rename = "id")]
154 pub keyset_id: Id,
155 pub secret: Secret,
157 #[serde(rename = "C")]
159 pub c: PublicKey,
160 pub dleq: Option<ProofDleq>,
162}
163
164impl AuthProof {
165 pub fn y(&self) -> Result<PublicKey, Error> {
167 Ok(hash_to_curve(self.secret.as_bytes())?)
168 }
169}
170
171impl From<AuthProof> for Proof {
172 fn from(value: AuthProof) -> Self {
173 Self {
174 amount: 1.into(),
175 keyset_id: value.keyset_id,
176 secret: value.secret,
177 c: value.c,
178 witness: None,
179 dleq: value.dleq,
180 p2pk_e: None,
181 }
182 }
183}
184
185impl TryFrom<Proof> for AuthProof {
186 type Error = Error;
187 fn try_from(value: Proof) -> Result<Self, Self::Error> {
188 Ok(Self {
189 keyset_id: value.keyset_id,
190 secret: value.secret,
191 c: value.c,
192 dleq: value.dleq,
193 })
194 }
195}
196
197#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
199pub struct BlindAuthToken {
200 pub auth_proof: AuthProof,
202}
203
204impl BlindAuthToken {
205 pub fn new(auth_proof: AuthProof) -> Self {
207 BlindAuthToken { auth_proof }
208 }
209
210 pub fn without_dleq(&self) -> Self {
214 Self {
215 auth_proof: AuthProof {
216 keyset_id: self.auth_proof.keyset_id,
217 secret: self.auth_proof.secret.clone(),
218 c: self.auth_proof.c,
219 dleq: None,
220 },
221 }
222 }
223}
224
225impl fmt::Display for BlindAuthToken {
226 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227 let json_string = serde_json::to_string(&self.auth_proof).map_err(|_| fmt::Error)?;
228 let encoded = general_purpose::URL_SAFE.encode(json_string);
229 write!(f, "authA{encoded}")
230 }
231}
232
233impl std::str::FromStr for BlindAuthToken {
234 type Err = Error;
235
236 fn from_str(s: &str) -> Result<Self, Self::Err> {
237 let encoded = s.strip_prefix("authA").ok_or(Error::InvalidPrefix)?;
239
240 let decode_config = GeneralPurposeConfig::new()
242 .with_decode_padding_mode(bitcoin::base64::engine::DecodePaddingMode::Indifferent);
243 let json_string =
244 GeneralPurpose::new(&alphabet::URL_SAFE, decode_config).decode(encoded)?;
245
246 let json_str = String::from_utf8(json_string)?;
248
249 let auth_proof: AuthProof = serde_json::from_str(&json_str)?;
251
252 Ok(BlindAuthToken { auth_proof })
253 }
254}
255
256#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
258pub struct MintAuthRequest {
259 pub outputs: Vec<BlindedMessage>,
261}
262
263impl MintAuthRequest {
264 pub fn amount(&self) -> u64 {
266 self.outputs.len() as u64
267 }
268}
269
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::nut21::{Method, RoutePath};
    use super::*;
    use crate::nut00::KnownMethod;
    use crate::PaymentMethod;

    /// A token must round-trip through its string form, both as emitted and
    /// with any trailing `=` padding removed.
    #[test]
    fn test_blind_auth_token_padding() {
        use crate::SecretKey;

        let token = BlindAuthToken::new(AuthProof {
            keyset_id: Id::from_bytes(&[0, 1, 2, 3, 4, 5, 6, 7]).expect("valid id"),
            secret: Secret::generate(),
            c: SecretKey::generate().public_key(),
            dleq: None,
        });

        let token_str = token.to_string();
        assert!(token_str.starts_with("authA"));

        // As emitted by `Display`.
        let parsed = token_str
            .parse::<BlindAuthToken>()
            .expect("Failed to parse token with padding");
        assert_eq!(token, parsed);

        // Same token with trailing padding stripped.
        let token_no_pad = token_str.trim_end_matches('=');
        let parsed_no_pad = token_no_pad
            .parse::<BlindAuthToken>()
            .expect("Failed to parse token without padding");
        assert_eq!(token, parsed_no_pad);
    }

    /// Exact paths deserialize to exactly one endpoint each.
    #[test]
    fn test_settings_deserialize_direct_paths() {
        let json = r#"{
            "bat_max_mint": 10,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/bolt11"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(settings.bat_max_mint, 10);
        assert_eq!(settings.protected_endpoints.len(), 2);

        let has_endpoint = |method: Method, path: RoutePath| {
            settings
                .protected_endpoints
                .iter()
                .any(|ep| ep.method == method && ep.path == path)
        };

        assert!(has_endpoint(
            Method::Get,
            RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt11).to_string())
        ));
        assert!(has_endpoint(Method::Post, RoutePath::Swap));
    }

    /// A trailing-wildcard path expands to every matching route path.
    #[test]
    fn test_settings_deserialize_with_regex() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/*"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(settings.bat_max_mint, 5);
        assert_eq!(settings.protected_endpoints.len(), 6);

        let bolt11 = || PaymentMethod::Known(KnownMethod::Bolt11).to_string();
        let bolt12 = || PaymentMethod::Known(KnownMethod::Bolt12).to_string();

        let expected_protected: HashSet<ProtectedEndpoint> = vec![
            ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ProtectedEndpoint::new(Method::Get, RoutePath::Mint(bolt11())),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuote(bolt11())),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuote(bolt12())),
            ProtectedEndpoint::new(Method::Get, RoutePath::Mint(bolt12())),
            ProtectedEndpoint::new(Method::Get, RoutePath::Wildcard("/v1/mint/".to_string())),
        ]
        .into_iter()
        .collect();

        // Compare as sets so the assertion does not depend on element order.
        let deserialized_protected: HashSet<ProtectedEndpoint> =
            settings.protected_endpoints.into_iter().collect();

        assert_eq!(expected_protected, deserialized_protected);
    }

    /// A wildcard that is not at the end of the path is rejected.
    #[test]
    fn test_settings_deserialize_invalid_regex() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/*wildcard_start"
                }
            ]
        }"#;

        assert!(serde_json::from_str::<Settings>(json).is_err());
    }

    /// An exact path that matches no known route is rejected.
    #[test]
    fn test_settings_deserialize_unknown_exact_path() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "POST",
                    "path": "/v1/swp"
                }
            ]
        }"#;

        assert!(serde_json::from_str::<Settings>(json).is_err());
    }

    /// `/v1/*` yields every known path plus one additional entry.
    #[test]
    fn test_settings_deserialize_all_paths() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/*"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        let expected_len = RoutePath::all_known_paths().len() + 1;

        assert_eq!(settings.protected_endpoints.len(), expected_len);
    }
}