Skip to main content

cashu/nuts/auth/
nut22.rs

//! NUT-22: Blind Auth
2
3use std::fmt;
4
5use bitcoin::base64::engine::general_purpose;
6use bitcoin::base64::Engine;
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9
10use super::nut21::ProtectedEndpoint;
11use crate::dhke::hash_to_curve;
12use crate::secret::Secret;
13use crate::util::hex;
14use crate::{BlindedMessage, Id, Proof, ProofDleq, PublicKey};
15
/// Errors produced while building, encoding or parsing NUT-22 blind auth data.
///
/// The wrapping variants use `#[from]`, so the underlying hex, base64, JSON,
/// UTF-8 and DHKE errors can be propagated with `?`.
#[derive(Debug, Error)]
pub enum Error {
    /// The token string does not start with the expected `authA` prefix.
    #[error("Invalid prefix")]
    InvalidPrefix,
    /// An auth proof was expected to carry a DLEQ proof but did not.
    #[error("Dleq Proof not included for auth proof")]
    DleqProofNotIncluded,
    /// Hex decoding error (from `crate::util::hex`).
    #[error(transparent)]
    HexError(#[from] hex::Error),
    /// Base64 decoding error, e.g. a malformed `authA` token payload.
    #[error(transparent)]
    Base64Error(#[from] bitcoin::base64::DecodeError),
    /// JSON (de)serialization error.
    #[error(transparent)]
    SerdeJsonError(#[from] serde_json::Error),
    /// Decoded bytes were not valid UTF-8.
    #[error(transparent)]
    Utf8ParseError(#[from] std::string::FromUtf8Error),
    /// Error from the DHKE module (e.g. `hash_to_curve`).
    #[error(transparent)]
    DHKE(#[from] crate::dhke::Error),
}
41
/// Blind auth (NUT-22) settings.
///
/// `Serialize` is derived, but `Deserialize` is implemented by hand: wildcard
/// patterns in each protected endpoint's `path` are expanded into concrete
/// route paths while deserializing. Values built with [`Settings::new`] are
/// stored as given, without expansion.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub struct Settings {
    /// Max number of blind auth tokens that can be minted per request
    pub bat_max_mint: u64,
    /// Endpoints protected by blind auth. After deserialization these are
    /// concrete (pattern-expanded, de-duplicated) paths.
    pub protected_endpoints: Vec<ProtectedEndpoint>,
}
51
52impl Settings {
53    /// Create new [`Settings`]
54    pub fn new(bat_max_mint: u64, protected_endpoints: Vec<ProtectedEndpoint>) -> Self {
55        Self {
56            bat_max_mint,
57            protected_endpoints,
58        }
59    }
60}
61
62// Custom deserializer for Settings to expand patterns in protected endpoints
63impl<'de> Deserialize<'de> for Settings {
64    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
65    where
66        D: serde::Deserializer<'de>,
67    {
68        use std::collections::HashSet;
69
70        use super::nut21::matching_route_paths;
71
72        // Define a temporary struct to deserialize the raw data
73        #[derive(Deserialize)]
74        struct RawSettings {
75            bat_max_mint: u64,
76            protected_endpoints: Vec<RawProtectedEndpoint>,
77        }
78
79        #[derive(Deserialize)]
80        struct RawProtectedEndpoint {
81            method: super::nut21::Method,
82            path: String,
83        }
84
85        // Deserialize into the temporary struct
86        let raw = RawSettings::deserialize(deserializer)?;
87
88        // Process protected endpoints, expanding patterns if present
89        let mut protected_endpoints = HashSet::new();
90
91        for raw_endpoint in raw.protected_endpoints {
92            let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
93                serde::de::Error::custom(format!("Invalid pattern '{}': {}", raw_endpoint.path, e))
94            })?;
95
96            for path in expanded_paths {
97                protected_endpoints.insert(super::nut21::ProtectedEndpoint::new(
98                    raw_endpoint.method,
99                    path,
100                ));
101            }
102        }
103
104        // Create the final Settings struct
105        Ok(Settings {
106            bat_max_mint: raw.bat_max_mint,
107            protected_endpoints: protected_endpoints.into_iter().collect(),
108        })
109    }
110}
111
/// Auth token: either a clear auth token (a plain string) or a blind auth token.
///
/// `Display` renders the token value; [`AuthToken::header_key`] gives the
/// header name it is sent under.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuthToken {
    /// Clear auth token, kept as the raw string.
    ClearAuth(String),
    /// Blind auth token.
    BlindAuth(BlindAuthToken),
}
120
121impl fmt::Display for AuthToken {
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        match self {
124            Self::ClearAuth(cat) => cat.fmt(f),
125            Self::BlindAuth(bat) => bat.fmt(f),
126        }
127    }
128}
129
130impl AuthToken {
131    /// Header key for auth token type
132    pub fn header_key(&self) -> String {
133        match self {
134            Self::ClearAuth(_) => "Clear-auth".to_string(),
135            Self::BlindAuth(_) => "Blind-auth".to_string(),
136        }
137    }
138}
139
/// Which kind of auth token is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AuthRequired {
    /// A clear auth token is required.
    Clear,
    /// A blind auth token is required.
    Blind,
}
148
/// Proof carried inside a blind auth token.
///
/// Mirrors a regular [`Proof`] minus `amount` and `witness`: converting to a
/// [`Proof`] sets the amount to `1` and the witness to `None`. The serde
/// field names (`id`, `secret`, `C`, `dleq`) are part of the encoded token.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub struct AuthProof {
    /// `Keyset id` (serialized as `id`)
    #[serde(rename = "id")]
    pub keyset_id: Id,
    /// Secret message
    #[cfg_attr(feature = "swagger", schema(value_type = String))]
    pub secret: Secret,
    /// Unblinded signature (serialized as `C`)
    #[serde(rename = "C")]
    #[cfg_attr(feature = "swagger", schema(value_type = String))]
    pub c: PublicKey,
    /// Optional DLEQ proof for this auth proof
    pub dleq: Option<ProofDleq>,
}
166
167impl AuthProof {
168    /// Y of AuthProof
169    pub fn y(&self) -> Result<PublicKey, Error> {
170        Ok(hash_to_curve(self.secret.as_bytes())?)
171    }
172}
173
174impl From<AuthProof> for Proof {
175    fn from(value: AuthProof) -> Self {
176        Self {
177            amount: 1.into(),
178            keyset_id: value.keyset_id,
179            secret: value.secret,
180            c: value.c,
181            witness: None,
182            dleq: value.dleq,
183        }
184    }
185}
186
187impl TryFrom<Proof> for AuthProof {
188    type Error = Error;
189    fn try_from(value: Proof) -> Result<Self, Self::Error> {
190        Ok(Self {
191            keyset_id: value.keyset_id,
192            secret: value.secret,
193            c: value.c,
194            dleq: value.dleq,
195        })
196    }
197}
198
/// Blind auth token wrapping a single [`AuthProof`].
///
/// Its string form is `authA` followed by URL-safe base64 of the JSON-encoded
/// [`AuthProof`] (see the `Display` and `FromStr` impls).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BlindAuthToken {
    /// The wrapped [`AuthProof`]
    pub auth_proof: AuthProof,
}
205
206impl BlindAuthToken {
207    /// Create new [ `BlindAuthToken`]
208    pub fn new(auth_proof: AuthProof) -> Self {
209        BlindAuthToken { auth_proof }
210    }
211
212    /// Remove DLEQ
213    ///
214    /// We do not send the DLEQ to the mint as it links redemption and creation
215    pub fn without_dleq(&self) -> Self {
216        Self {
217            auth_proof: AuthProof {
218                keyset_id: self.auth_proof.keyset_id,
219                secret: self.auth_proof.secret.clone(),
220                c: self.auth_proof.c,
221                dleq: None,
222            },
223        }
224    }
225}
226
227impl fmt::Display for BlindAuthToken {
228    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229        let json_string = serde_json::to_string(&self.auth_proof).map_err(|_| fmt::Error)?;
230        let encoded = general_purpose::URL_SAFE.encode(json_string);
231        write!(f, "authA{encoded}")
232    }
233}
234
235impl std::str::FromStr for BlindAuthToken {
236    type Err = Error;
237
238    fn from_str(s: &str) -> Result<Self, Self::Err> {
239        // Check prefix and extract the base64 encoded part in one step
240        let encoded = s.strip_prefix("authA").ok_or(Error::InvalidPrefix)?;
241
242        // Decode the base64 URL-safe string
243        let json_string = general_purpose::URL_SAFE.decode(encoded)?;
244
245        // Convert bytes to UTF-8 string
246        let json_str = String::from_utf8(json_string)?;
247
248        // Deserialize the JSON string into AuthProof
249        let auth_proof: AuthProof = serde_json::from_str(&json_str)?;
250
251        Ok(BlindAuthToken { auth_proof })
252    }
253}
254
/// Request to mint blind auth tokens [NUT-22].
///
/// Each blinded output corresponds to one requested token
/// (see [`MintAuthRequest::amount`]).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub struct MintAuthRequest {
    /// Blinded outputs, one per requested auth token
    #[cfg_attr(feature = "swagger", schema(max_items = 1_000))]
    pub outputs: Vec<BlindedMessage>,
}
263
264impl MintAuthRequest {
265    /// Count of tokens
266    pub fn amount(&self) -> u64 {
267        self.outputs.len() as u64
268    }
269}
270
// Unit tests for the custom `Settings` deserializer, which expands patterns in
// `protected_endpoints[].path` via `nut21::matching_route_paths`.
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::nut21::{Method, RoutePath};
    use super::*;
    use crate::nut00::KnownMethod;
    use crate::PaymentMethod;

    // Concrete (non-pattern) paths map one-to-one onto `RoutePath` variants.
    #[test]
    fn test_settings_deserialize_direct_paths() {
        let json = r#"{
            "bat_max_mint": 10,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/bolt11"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(settings.bat_max_mint, 10);
        assert_eq!(settings.protected_endpoints.len(), 2);

        // Check that both (method, path) pairs are present; order is not asserted.
        let paths = settings
            .protected_endpoints
            .iter()
            .map(|ep| (ep.method, ep.path.clone()))
            .collect::<Vec<_>>();
        assert!(paths.contains(&(
            Method::Get,
            RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt11).to_string())
        )));
        assert!(paths.contains(&(Method::Post, RoutePath::Swap)));
    }

    // `/v1/mint/*` expands to every matching mint route; per the expected set
    // below that is mint and mint-quote for bolt11 and bolt12.
    #[test]
    fn test_settings_deserialize_with_regex() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/*"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(settings.bat_max_mint, 5);
        assert_eq!(settings.protected_endpoints.len(), 5); // 4 mint paths + 1 swap path

        let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
            ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ProtectedEndpoint::new(
                Method::Get,
                RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt11).to_string()),
            ),
            ProtectedEndpoint::new(
                Method::Get,
                RoutePath::MintQuote(PaymentMethod::Known(KnownMethod::Bolt11).to_string()),
            ),
            ProtectedEndpoint::new(
                Method::Get,
                RoutePath::MintQuote(PaymentMethod::Known(KnownMethod::Bolt12).to_string()),
            ),
            ProtectedEndpoint::new(
                Method::Get,
                RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt12).to_string()),
            ),
        ]);

        // Compare as sets so the test does not depend on endpoint order.
        let deserialized_protected = settings.protected_endpoints.into_iter().collect();

        assert_eq!(expected_protected, deserialized_protected);
    }

    // A malformed pattern must make deserialization fail rather than be ignored.
    #[test]
    fn test_settings_deserialize_invalid_regex() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/*wildcard_start"
                }
            ]
        }"#;

        let result = serde_json::from_str::<Settings>(json);
        assert!(result.is_err());
    }

    // `/v1/*` expands to exactly the set of known route paths.
    #[test]
    fn test_settings_deserialize_all_paths() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/*"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(
            settings.protected_endpoints.len(),
            RoutePath::all_known_paths().len()
        );
    }
}