1use std::fmt;
4
5use bitcoin::base64::engine::general_purpose;
6use bitcoin::base64::Engine;
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9
10use super::nut21::ProtectedEndpoint;
11use crate::dhke::hash_to_curve;
12use crate::secret::Secret;
13use crate::util::hex;
14use crate::{BlindedMessage, Id, Proof, ProofDleq, PublicKey};
15
16#[derive(Debug, Error)]
18pub enum Error {
19 #[error("Invalid prefix")]
21 InvalidPrefix,
22 #[error("Dleq Proof not included for auth proof")]
24 DleqProofNotIncluded,
25 #[error(transparent)]
27 HexError(#[from] hex::Error),
28 #[error(transparent)]
30 Base64Error(#[from] bitcoin::base64::DecodeError),
31 #[error(transparent)]
33 SerdeJsonError(#[from] serde_json::Error),
34 #[error(transparent)]
36 Utf8ParseError(#[from] std::string::FromUtf8Error),
37 #[error(transparent)]
39 DHKE(#[from] crate::dhke::Error),
40}
41
42#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
44#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
45pub struct Settings {
46 pub bat_max_mint: u64,
48 pub protected_endpoints: Vec<ProtectedEndpoint>,
50}
51
52impl Settings {
53 pub fn new(bat_max_mint: u64, protected_endpoints: Vec<ProtectedEndpoint>) -> Self {
55 Self {
56 bat_max_mint,
57 protected_endpoints,
58 }
59 }
60}
61
62impl<'de> Deserialize<'de> for Settings {
64 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
65 where
66 D: serde::Deserializer<'de>,
67 {
68 use std::collections::HashSet;
69
70 use super::nut21::matching_route_paths;
71
72 #[derive(Deserialize)]
74 struct RawSettings {
75 bat_max_mint: u64,
76 protected_endpoints: Vec<RawProtectedEndpoint>,
77 }
78
79 #[derive(Deserialize)]
80 struct RawProtectedEndpoint {
81 method: super::nut21::Method,
82 path: String,
83 }
84
85 let raw = RawSettings::deserialize(deserializer)?;
87
88 let mut protected_endpoints = HashSet::new();
90
91 for raw_endpoint in raw.protected_endpoints {
92 let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
93 serde::de::Error::custom(format!(
94 "Invalid regex pattern '{}': {}",
95 raw_endpoint.path, e
96 ))
97 })?;
98
99 for path in expanded_paths {
100 protected_endpoints.insert(super::nut21::ProtectedEndpoint::new(
101 raw_endpoint.method,
102 path,
103 ));
104 }
105 }
106
107 Ok(Settings {
109 bat_max_mint: raw.bat_max_mint,
110 protected_endpoints: protected_endpoints.into_iter().collect(),
111 })
112 }
113}
114
115#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
117pub enum AuthToken {
118 ClearAuth(String),
120 BlindAuth(BlindAuthToken),
122}
123
124impl fmt::Display for AuthToken {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 match self {
127 Self::ClearAuth(cat) => cat.fmt(f),
128 Self::BlindAuth(bat) => bat.fmt(f),
129 }
130 }
131}
132
133impl AuthToken {
134 pub fn header_key(&self) -> String {
136 match self {
137 Self::ClearAuth(_) => "Clear-auth".to_string(),
138 Self::BlindAuth(_) => "Blind-auth".to_string(),
139 }
140 }
141}
142
143#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
145pub enum AuthRequired {
146 Clear,
148 Blind,
150}
151
152#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
154#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
155pub struct AuthProof {
156 #[serde(rename = "id")]
158 pub keyset_id: Id,
159 #[cfg_attr(feature = "swagger", schema(value_type = String))]
161 pub secret: Secret,
162 #[serde(rename = "C")]
164 #[cfg_attr(feature = "swagger", schema(value_type = String))]
165 pub c: PublicKey,
166 pub dleq: Option<ProofDleq>,
168}
169
170impl AuthProof {
171 pub fn y(&self) -> Result<PublicKey, Error> {
173 Ok(hash_to_curve(self.secret.as_bytes())?)
174 }
175}
176
177impl From<AuthProof> for Proof {
178 fn from(value: AuthProof) -> Self {
179 Self {
180 amount: 1.into(),
181 keyset_id: value.keyset_id,
182 secret: value.secret,
183 c: value.c,
184 witness: None,
185 dleq: value.dleq,
186 }
187 }
188}
189
190impl TryFrom<Proof> for AuthProof {
191 type Error = Error;
192 fn try_from(value: Proof) -> Result<Self, Self::Error> {
193 Ok(Self {
194 keyset_id: value.keyset_id,
195 secret: value.secret,
196 c: value.c,
197 dleq: value.dleq,
198 })
199 }
200}
201
202#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
204pub struct BlindAuthToken {
205 pub auth_proof: AuthProof,
207}
208
209impl BlindAuthToken {
210 pub fn new(auth_proof: AuthProof) -> Self {
212 BlindAuthToken { auth_proof }
213 }
214
215 pub fn without_dleq(&self) -> Self {
219 Self {
220 auth_proof: AuthProof {
221 keyset_id: self.auth_proof.keyset_id,
222 secret: self.auth_proof.secret.clone(),
223 c: self.auth_proof.c,
224 dleq: None,
225 },
226 }
227 }
228}
229
230impl fmt::Display for BlindAuthToken {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 let json_string = serde_json::to_string(&self.auth_proof).map_err(|_| fmt::Error)?;
233 let encoded = general_purpose::URL_SAFE.encode(json_string);
234 write!(f, "authA{encoded}")
235 }
236}
237
238impl std::str::FromStr for BlindAuthToken {
239 type Err = Error;
240
241 fn from_str(s: &str) -> Result<Self, Self::Err> {
242 let encoded = s.strip_prefix("authA").ok_or(Error::InvalidPrefix)?;
244
245 let json_string = general_purpose::URL_SAFE.decode(encoded)?;
247
248 let json_str = String::from_utf8(json_string)?;
250
251 let auth_proof: AuthProof = serde_json::from_str(&json_str)?;
253
254 Ok(BlindAuthToken { auth_proof })
255 }
256}
257
258#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
260#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
261pub struct MintAuthRequest {
262 #[cfg_attr(feature = "swagger", schema(max_items = 1_000))]
264 pub outputs: Vec<BlindedMessage>,
265}
266
267impl MintAuthRequest {
268 pub fn amount(&self) -> u64 {
270 self.outputs.len() as u64
271 }
272}
273
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use strum::IntoEnumIterator;

    use super::super::nut21::{Method, RoutePath};
    use super::*;

    /// Literal paths map one-to-one onto their routes.
    #[test]
    fn test_settings_deserialize_direct_paths() {
        let json = r#"{
            "bat_max_mint": 10,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/bolt11"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(settings.bat_max_mint, 10);
        assert_eq!(settings.protected_endpoints.len(), 2);

        let paths = settings
            .protected_endpoints
            .iter()
            .map(|ep| (ep.method, ep.path))
            .collect::<Vec<_>>();
        assert!(paths.contains(&(Method::Get, RoutePath::MintBolt11)));
        assert!(paths.contains(&(Method::Post, RoutePath::Swap)));
    }

    /// A regex path expands into every matching route.
    #[test]
    fn test_settings_deserialize_with_regex() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "^/v1/mint/.*"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(settings.bat_max_mint, 5);
        assert_eq!(settings.protected_endpoints.len(), 3);

        let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
            ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
        ]);

        let deserialized_protected = settings.protected_endpoints.into_iter().collect();

        assert_eq!(expected_protected, deserialized_protected);
    }

    /// An invalid regex is rejected rather than silently ignored.
    #[test]
    fn test_settings_deserialize_invalid_regex() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "(unclosed parenthesis"
                }
            ]
        }"#;

        let result = serde_json::from_str::<Settings>(json);
        assert!(result.is_err());
    }

    /// `.*` protects every known route.
    #[test]
    fn test_settings_deserialize_all_paths() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": ".*"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(
            settings.protected_endpoints.len(),
            RoutePath::iter().count()
        );
    }

    /// Repeated entries and overlapping patterns yield each endpoint once.
    #[test]
    fn test_settings_deserialize_deduplicates_endpoints() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                { "method": "GET", "path": "^/v1/mint/.*" },
                { "method": "GET", "path": "/v1/mint/bolt11" },
                { "method": "POST", "path": "/v1/swap" },
                { "method": "POST", "path": "/v1/swap" }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(settings.protected_endpoints.len(), 3);

        let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
            ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
        ]);
        let deserialized_protected: HashSet<ProtectedEndpoint> =
            settings.protected_endpoints.into_iter().collect();

        assert_eq!(expected_protected, deserialized_protected);
    }
}