Skip to main content

cashu/nuts/auth/
nut22.rs

1//! 22 Blind Auth
2
3use std::fmt;
4
5use bitcoin::base64::engine::general_purpose::{self, GeneralPurposeConfig};
6use bitcoin::base64::engine::GeneralPurpose;
7use bitcoin::base64::{alphabet, Engine};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10
11use super::nut21::ProtectedEndpoint;
12use crate::dhke::hash_to_curve;
13use crate::secret::Secret;
14use crate::util::hex;
15use crate::{BlindedMessage, Id, Proof, ProofDleq, PublicKey};
16
17/// NUT22 Error
18#[derive(Debug, Error)]
19pub enum Error {
20    /// Invalid Prefix
21    #[error("Invalid prefix")]
22    InvalidPrefix,
23    /// Dleq proof not included
24    #[error("Dleq Proof not included for auth proof")]
25    DleqProofNotIncluded,
26    /// Hex Error
27    #[error(transparent)]
28    HexError(#[from] hex::Error),
29    /// Base64 error
30    #[error(transparent)]
31    Base64Error(#[from] bitcoin::base64::DecodeError),
32    /// Serde Json error
33    #[error(transparent)]
34    SerdeJsonError(#[from] serde_json::Error),
35    /// Utf8 parse error
36    #[error(transparent)]
37    Utf8ParseError(#[from] std::string::FromUtf8Error),
38    /// DHKE error
39    #[error(transparent)]
40    DHKE(#[from] crate::dhke::Error),
41}
42
43/// Blind auth settings
44#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
45pub struct Settings {
46    /// Max number of blind auth tokens that can be minted per request
47    pub bat_max_mint: u64,
48    /// Protected endpoints
49    pub protected_endpoints: Vec<ProtectedEndpoint>,
50}
51
52impl Settings {
53    /// Create new [`Settings`]
54    pub fn new(bat_max_mint: u64, protected_endpoints: Vec<ProtectedEndpoint>) -> Self {
55        Self {
56            bat_max_mint,
57            protected_endpoints,
58        }
59    }
60}
61
62// Custom deserializer for Settings to expand patterns in protected endpoints
63impl<'de> Deserialize<'de> for Settings {
64    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
65    where
66        D: serde::Deserializer<'de>,
67    {
68        use std::collections::HashSet;
69
70        use super::nut21::matching_route_paths;
71
72        // Define a temporary struct to deserialize the raw data
73        #[derive(Deserialize)]
74        struct RawSettings {
75            bat_max_mint: u64,
76            protected_endpoints: Vec<RawProtectedEndpoint>,
77        }
78
79        #[derive(Deserialize)]
80        struct RawProtectedEndpoint {
81            method: super::nut21::Method,
82            path: String,
83        }
84
85        // Deserialize into the temporary struct
86        let raw = RawSettings::deserialize(deserializer)?;
87
88        // Process protected endpoints, expanding patterns if present
89        let mut protected_endpoints = HashSet::new();
90
91        for raw_endpoint in raw.protected_endpoints {
92            let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
93                serde::de::Error::custom(format!("Invalid pattern '{}': {}", raw_endpoint.path, e))
94            })?;
95
96            for path in expanded_paths {
97                protected_endpoints.insert(super::nut21::ProtectedEndpoint::new(
98                    raw_endpoint.method,
99                    path,
100                ));
101            }
102        }
103
104        // Create the final Settings struct
105        Ok(Settings {
106            bat_max_mint: raw.bat_max_mint,
107            protected_endpoints: protected_endpoints.into_iter().collect(),
108        })
109    }
110}
111
112/// Auth Token
113#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
114pub enum AuthToken {
115    /// Clear Auth token
116    ClearAuth(String),
117    /// Blind Auth token
118    BlindAuth(BlindAuthToken),
119}
120
121impl fmt::Display for AuthToken {
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        match self {
124            Self::ClearAuth(cat) => cat.fmt(f),
125            Self::BlindAuth(bat) => bat.fmt(f),
126        }
127    }
128}
129
130impl AuthToken {
131    /// Header key for auth token type
132    pub fn header_key(&self) -> String {
133        match self {
134            Self::ClearAuth(_) => "Clear-auth".to_string(),
135            Self::BlindAuth(_) => "Blind-auth".to_string(),
136        }
137    }
138}
139
140/// Required Auth
141#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
142pub enum AuthRequired {
143    /// Clear Auth token
144    Clear,
145    /// Blind Auth token
146    Blind,
147}
148
149/// Auth Proofs
150#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
151pub struct AuthProof {
152    /// `Keyset id`
153    #[serde(rename = "id")]
154    pub keyset_id: Id,
155    /// Secret message
156    pub secret: Secret,
157    /// Unblinded signature
158    #[serde(rename = "C")]
159    pub c: PublicKey,
160    /// Auth Proof Dleq
161    pub dleq: Option<ProofDleq>,
162}
163
164impl AuthProof {
165    /// Y of AuthProof
166    pub fn y(&self) -> Result<PublicKey, Error> {
167        Ok(hash_to_curve(self.secret.as_bytes())?)
168    }
169}
170
171impl From<AuthProof> for Proof {
172    fn from(value: AuthProof) -> Self {
173        Self {
174            amount: 1.into(),
175            keyset_id: value.keyset_id,
176            secret: value.secret,
177            c: value.c,
178            witness: None,
179            dleq: value.dleq,
180            p2pk_e: None,
181        }
182    }
183}
184
185impl TryFrom<Proof> for AuthProof {
186    type Error = Error;
187    fn try_from(value: Proof) -> Result<Self, Self::Error> {
188        Ok(Self {
189            keyset_id: value.keyset_id,
190            secret: value.secret,
191            c: value.c,
192            dleq: value.dleq,
193        })
194    }
195}
196
197/// Blind Auth Token
198#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
199pub struct BlindAuthToken {
200    /// [AuthProof]
201    pub auth_proof: AuthProof,
202}
203
204impl BlindAuthToken {
205    /// Create new [ `BlindAuthToken`]
206    pub fn new(auth_proof: AuthProof) -> Self {
207        BlindAuthToken { auth_proof }
208    }
209
210    /// Remove DLEQ
211    ///
212    /// We do not send the DLEQ to the mint as it links redemption and creation
213    pub fn without_dleq(&self) -> Self {
214        Self {
215            auth_proof: AuthProof {
216                keyset_id: self.auth_proof.keyset_id,
217                secret: self.auth_proof.secret.clone(),
218                c: self.auth_proof.c,
219                dleq: None,
220            },
221        }
222    }
223}
224
225impl fmt::Display for BlindAuthToken {
226    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227        let json_string = serde_json::to_string(&self.auth_proof).map_err(|_| fmt::Error)?;
228        let encoded = general_purpose::URL_SAFE.encode(json_string);
229        write!(f, "authA{encoded}")
230    }
231}
232
233impl std::str::FromStr for BlindAuthToken {
234    type Err = Error;
235
236    fn from_str(s: &str) -> Result<Self, Self::Err> {
237        // Check prefix and extract the base64 encoded part in one step
238        let encoded = s.strip_prefix("authA").ok_or(Error::InvalidPrefix)?;
239
240        // Decode the base64 URL-safe string (accept with or without padding)
241        let decode_config = GeneralPurposeConfig::new()
242            .with_decode_padding_mode(bitcoin::base64::engine::DecodePaddingMode::Indifferent);
243        let json_string =
244            GeneralPurpose::new(&alphabet::URL_SAFE, decode_config).decode(encoded)?;
245
246        // Convert bytes to UTF-8 string
247        let json_str = String::from_utf8(json_string)?;
248
249        // Deserialize the JSON string into AuthProof
250        let auth_proof: AuthProof = serde_json::from_str(&json_str)?;
251
252        Ok(BlindAuthToken { auth_proof })
253    }
254}
255
256/// Mint auth request [NUT-XX]
257#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
258pub struct MintAuthRequest {
259    /// Outputs
260    pub outputs: Vec<BlindedMessage>,
261}
262
263impl MintAuthRequest {
264    /// Count of tokens
265    pub fn amount(&self) -> u64 {
266        self.outputs.len() as u64
267    }
268}
269
// Unit tests for blind auth token encoding and `Settings` deserialization
// (endpoint pattern expansion).
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::nut21::{Method, RoutePath};
    use super::*;
    use crate::nut00::KnownMethod;
    use crate::PaymentMethod;

    /// Round-trips a [`BlindAuthToken`] through `Display` / `FromStr`, both with and
    /// without trailing base64 `=` padding.
    #[test]
    fn test_blind_auth_token_padding() {
        use std::str::FromStr;

        use crate::SecretKey;

        // Build a valid BlindAuthToken programmatically
        let secret_key = SecretKey::generate();
        let public_key = secret_key.public_key();
        let secret = Secret::generate();
        let auth_proof = AuthProof {
            keyset_id: Id::from_bytes(&[0, 1, 2, 3, 4, 5, 6, 7]).expect("valid id"),
            secret,
            c: public_key,
            dleq: None,
        };
        let token = BlindAuthToken::new(auth_proof);

        // Serialize (Display impl produces padded base64)
        let token_str = token.to_string();
        assert!(token_str.starts_with("authA"));

        // Parse with padding
        let parsed =
            BlindAuthToken::from_str(&token_str).expect("Failed to parse token with padding");
        assert_eq!(token, parsed);

        // Strip padding and parse again. This is a no-op when the payload length
        // happens to need no padding.
        let token_no_pad = token_str.trim_end_matches('=');
        let parsed_no_pad =
            BlindAuthToken::from_str(token_no_pad).expect("Failed to parse token without padding");
        assert_eq!(token, parsed_no_pad);
    }

    /// Exact (non-pattern) paths each map to exactly one protected endpoint.
    #[test]
    fn test_settings_deserialize_direct_paths() {
        let json = r#"{
            "bat_max_mint": 10,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/bolt11"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(settings.bat_max_mint, 10);
        assert_eq!(settings.protected_endpoints.len(), 2);

        // Check that both paths are included
        let paths = settings
            .protected_endpoints
            .iter()
            .map(|ep| (ep.method, ep.path.clone()))
            .collect::<Vec<_>>();
        assert!(paths.contains(&(
            Method::Get,
            RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt11).to_string())
        )));
        assert!(paths.contains(&(Method::Post, RoutePath::Swap)));
    }

    /// A trailing `*` pattern (the test name says "regex", but the pattern is a wildcard)
    /// expands into the matching concrete mint routes plus a `RoutePath::Wildcard` entry.
    /// Results are compared as sets, so order is not asserted.
    #[test]
    fn test_settings_deserialize_with_regex() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/*"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(settings.bat_max_mint, 5);
        assert_eq!(settings.protected_endpoints.len(), 6); // 4 mint paths + wildcard + 1 swap path

        let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
            ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ProtectedEndpoint::new(
                Method::Get,
                RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt11).to_string()),
            ),
            ProtectedEndpoint::new(
                Method::Get,
                RoutePath::MintQuote(PaymentMethod::Known(KnownMethod::Bolt11).to_string()),
            ),
            ProtectedEndpoint::new(
                Method::Get,
                RoutePath::MintQuote(PaymentMethod::Known(KnownMethod::Bolt12).to_string()),
            ),
            ProtectedEndpoint::new(
                Method::Get,
                RoutePath::Mint(PaymentMethod::Known(KnownMethod::Bolt12).to_string()),
            ),
            ProtectedEndpoint::new(Method::Get, RoutePath::Wildcard("/v1/mint/".to_string())),
        ]);

        let deserialized_protected = settings.protected_endpoints.into_iter().collect();

        assert_eq!(expected_protected, deserialized_protected);
    }

    /// The malformed pattern `/*wildcard_start` must make deserialization fail.
    #[test]
    fn test_settings_deserialize_invalid_regex() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/*wildcard_start"
                }
            ]
        }"#;

        let result = serde_json::from_str::<Settings>(json);
        assert!(result.is_err());
    }

    /// An exact path that matches no known route (typo `/v1/swp`) must be rejected.
    #[test]
    fn test_settings_deserialize_unknown_exact_path() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "POST",
                    "path": "/v1/swp"
                }
            ]
        }"#;

        let result = serde_json::from_str::<Settings>(json);
        assert!(result.is_err());
    }

    /// `/v1/*` expands to every known route plus the wildcard entry itself (hence `+ 1`).
    #[test]
    fn test_settings_deserialize_all_paths() {
        let json = r#"{
            "bat_max_mint": 5,
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/*"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(
            settings.protected_endpoints.len(),
            RoutePath::all_known_paths().len() + 1
        );
    }
}