1use std::fmt;
4
5use bitcoin::base64::engine::general_purpose;
6use bitcoin::base64::Engine;
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9
10use super::nut21::ProtectedEndpoint;
11use crate::dhke::hash_to_curve;
12use crate::secret::Secret;
13use crate::util::hex;
14use crate::{BlindedMessage, Id, Proof, ProofDleq, PublicKey};
15
/// Errors produced by the clear/blind authentication types in this module.
#[derive(Debug, Error)]
pub enum Error {
    /// A blind auth token string did not start with the `authA` prefix.
    #[error("Invalid prefix")]
    InvalidPrefix,
    /// An auth proof was expected to carry a DLEQ proof but did not.
    #[error("Dleq Proof not included for auth proof")]
    DleqProofNotIncluded,
    /// Hex decoding error.
    #[error(transparent)]
    HexError(#[from] hex::Error),
    /// Base64 decoding of a token payload failed.
    #[error(transparent)]
    Base64Error(#[from] bitcoin::base64::DecodeError),
    /// JSON serialization or deserialization failed.
    #[error(transparent)]
    SerdeJsonError(#[from] serde_json::Error),
    /// Decoded token bytes were not valid UTF-8.
    #[error(transparent)]
    Utf8ParseError(#[from] std::string::FromUtf8Error),
    /// Error from the DHKE primitives (e.g. `hash_to_curve` in `AuthProof::y`).
    #[error(transparent)]
    DHKE(#[from] crate::dhke::Error),
}
41
/// Authentication settings: the blind-auth mint limit and the endpoints
/// that are protected.
///
/// `Serialize` is derived and writes `protected_endpoints` exactly as stored.
/// `Deserialize` is implemented by hand so that path patterns in the input
/// are expanded into concrete routes.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub struct Settings {
    /// Presumably the maximum number of blind auth tokens (BATs) that may be
    /// minted at once — TODO confirm against the NUT-22 spec.
    pub bat_max_mint: u64,
    /// Endpoints (method + route) that require authentication.
    pub protected_endpoints: Vec<ProtectedEndpoint>,
}
51
52impl Settings {
53 pub fn new(bat_max_mint: u64, protected_endpoints: Vec<ProtectedEndpoint>) -> Self {
55 Self {
56 bat_max_mint,
57 protected_endpoints,
58 }
59 }
60}
61
62impl<'de> Deserialize<'de> for Settings {
64 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
65 where
66 D: serde::Deserializer<'de>,
67 {
68 use std::collections::HashSet;
69
70 use super::nut21::matching_route_paths;
71
72 #[derive(Deserialize)]
74 struct RawSettings {
75 bat_max_mint: u64,
76 protected_endpoints: Vec<RawProtectedEndpoint>,
77 }
78
79 #[derive(Deserialize)]
80 struct RawProtectedEndpoint {
81 method: super::nut21::Method,
82 path: String,
83 }
84
85 let raw = RawSettings::deserialize(deserializer)?;
87
88 let mut protected_endpoints = HashSet::new();
90
91 for raw_endpoint in raw.protected_endpoints {
92 let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
93 serde::de::Error::custom(format!("Invalid pattern '{}': {}", raw_endpoint.path, e))
94 })?;
95
96 for path in expanded_paths {
97 protected_endpoints.insert(super::nut21::ProtectedEndpoint::new(
98 raw_endpoint.method,
99 path,
100 ));
101 }
102 }
103
104 Ok(Settings {
106 bat_max_mint: raw.bat_max_mint,
107 protected_endpoints: protected_endpoints.into_iter().collect(),
108 })
109 }
110}
111
/// An authentication token attached to a request.
///
/// The string form (`Display`) is the token value. [`AuthToken::header_key`]
/// gives the header name it is sent under.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuthToken {
    /// Clear auth token, carried as an opaque string and displayed verbatim.
    ClearAuth(String),
    /// Blind auth token, displayed in its `authA`-prefixed encoded form.
    BlindAuth(BlindAuthToken),
}
120
121impl fmt::Display for AuthToken {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 match self {
124 Self::ClearAuth(cat) => cat.fmt(f),
125 Self::BlindAuth(bat) => bat.fmt(f),
126 }
127 }
128}
129
130impl AuthToken {
131 pub fn header_key(&self) -> String {
133 match self {
134 Self::ClearAuth(_) => "Clear-auth".to_string(),
135 Self::BlindAuth(_) => "Blind-auth".to_string(),
136 }
137 }
138}
139
/// Which kind of authentication is required: clear or blind.
///
/// NOTE(review): presumably mirrors the `AuthToken::ClearAuth` /
/// `AuthToken::BlindAuth` variants — confirm at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AuthRequired {
    /// Clear authentication.
    Clear,
    /// Blind authentication.
    Blind,
}
148
/// Proof presented for blind authentication.
///
/// It has the same fields as a regular [`Proof`] minus `amount` and
/// `witness`. Converting it into a `Proof` fixes the amount at 1 and leaves
/// the witness empty.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub struct AuthProof {
    /// Keyset the proof belongs to; serialized as `id`.
    #[serde(rename = "id")]
    pub keyset_id: Id,
    /// Proof secret; `AuthProof::y` hashes it to a curve point.
    #[cfg_attr(feature = "swagger", schema(value_type = String))]
    pub secret: Secret,
    /// Signature point `C`; serialized as `C`.
    #[serde(rename = "C")]
    #[cfg_attr(feature = "swagger", schema(value_type = String))]
    pub c: PublicKey,
    /// Optional DLEQ proof.
    pub dleq: Option<ProofDleq>,
}
166
167impl AuthProof {
168 pub fn y(&self) -> Result<PublicKey, Error> {
170 Ok(hash_to_curve(self.secret.as_bytes())?)
171 }
172}
173
174impl From<AuthProof> for Proof {
175 fn from(value: AuthProof) -> Self {
176 Self {
177 amount: 1.into(),
178 keyset_id: value.keyset_id,
179 secret: value.secret,
180 c: value.c,
181 witness: None,
182 dleq: value.dleq,
183 }
184 }
185}
186
187impl TryFrom<Proof> for AuthProof {
188 type Error = Error;
189 fn try_from(value: Proof) -> Result<Self, Self::Error> {
190 Ok(Self {
191 keyset_id: value.keyset_id,
192 secret: value.secret,
193 c: value.c,
194 dleq: value.dleq,
195 })
196 }
197}
198
/// A blind auth token wrapping a single [`AuthProof`].
///
/// Its string form (`Display` / `FromStr`) is `authA` followed by the
/// URL-safe base64 encoding of the proof's JSON. The derived
/// `Serialize`/`Deserialize` instead use the plain struct shape
/// (`{"auth_proof": ...}`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BlindAuthToken {
    /// The wrapped auth proof.
    pub auth_proof: AuthProof,
}
205
206impl BlindAuthToken {
207 pub fn new(auth_proof: AuthProof) -> Self {
209 BlindAuthToken { auth_proof }
210 }
211
212 pub fn without_dleq(&self) -> Self {
216 Self {
217 auth_proof: AuthProof {
218 keyset_id: self.auth_proof.keyset_id,
219 secret: self.auth_proof.secret.clone(),
220 c: self.auth_proof.c,
221 dleq: None,
222 },
223 }
224 }
225}
226
227impl fmt::Display for BlindAuthToken {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 let json_string = serde_json::to_string(&self.auth_proof).map_err(|_| fmt::Error)?;
230 let encoded = general_purpose::URL_SAFE.encode(json_string);
231 write!(f, "authA{encoded}")
232 }
233}
234
235impl std::str::FromStr for BlindAuthToken {
236 type Err = Error;
237
238 fn from_str(s: &str) -> Result<Self, Self::Err> {
239 let encoded = s.strip_prefix("authA").ok_or(Error::InvalidPrefix)?;
241
242 let json_string = general_purpose::URL_SAFE.decode(encoded)?;
244
245 let json_str = String::from_utf8(json_string)?;
247
248 let auth_proof: AuthProof = serde_json::from_str(&json_str)?;
250
251 Ok(BlindAuthToken { auth_proof })
252 }
253}
254
/// Request to mint blind auth tokens from a list of blinded messages.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub struct MintAuthRequest {
    /// Blinded outputs to be signed.
    ///
    /// The 1,000-item limit appears only in the OpenAPI schema attribute;
    /// nothing in this type enforces it.
    #[cfg_attr(feature = "swagger", schema(max_items = 1_000))]
    pub outputs: Vec<BlindedMessage>,
}
263
264impl MintAuthRequest {
265 pub fn amount(&self) -> u64 {
267 self.outputs.len() as u64
268 }
269}
270
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::nut21::{Method, RoutePath};
    use super::*;
    use crate::nut00::KnownMethod;
    use crate::PaymentMethod;

    /// Literal (non-pattern) paths map one-to-one onto protected endpoints.
    #[test]
    fn test_settings_deserialize_direct_paths() {
        let json = r#"{
            "bat_max_mint": 10,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/bolt11"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(settings.bat_max_mint, 10);
        assert_eq!(settings.protected_endpoints.len(), 2);

        // Only membership is checked, not position.
        let paths = settings
            .protected_endpoints
            .iter()
            .map(|ep| (ep.method, ep.path.clone()))
            .collect::<Vec<_>>();
        assert!(paths.contains(&(
            Method::Get,
            RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt11).to_string())
        )));
        assert!(paths.contains(&(Method::Post, RoutePath::Swap)));
    }

    /// A `/v1/mint/*` pattern expands to the concrete mint routes, alongside
    /// the literal swap route.
    #[test]
    fn test_settings_deserialize_with_regex() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/*"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(settings.bat_max_mint, 5);
        assert_eq!(settings.protected_endpoints.len(), 5);

        // Swap plus the four GET mint / mint-quote routes (bolt11 + bolt12).
        let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
            ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ProtectedEndpoint::new(
                Method::Get,
                RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt11).to_string()),
            ),
            ProtectedEndpoint::new(
                Method::Get,
                RoutePath::MintQuote(PaymentMethod::Known(KnownMethod::Bolt11).to_string()),
            ),
            ProtectedEndpoint::new(
                Method::Get,
                RoutePath::MintQuote(PaymentMethod::Known(KnownMethod::Bolt12).to_string()),
            ),
            ProtectedEndpoint::new(
                Method::Get,
                RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt12).to_string()),
            ),
        ]);

        // Compare as sets so the test does not depend on endpoint order.
        let deserialized_protected = settings.protected_endpoints.into_iter().collect();

        assert_eq!(expected_protected, deserialized_protected);
    }

    /// A malformed path pattern must make deserialization fail.
    #[test]
    fn test_settings_deserialize_invalid_regex() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/*wildcard_start"
                }
            ]
        }"#;

        let result = serde_json::from_str::<Settings>(json);
        assert!(result.is_err());
    }

    /// `/v1/*` matches every known route: one endpoint per known path.
    #[test]
    fn test_settings_deserialize_all_paths() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/*"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(
            settings.protected_endpoints.len(),
            RoutePath::all_known_paths().len()
        );
    }
}