cashu/nuts/auth/
nut21.rs

1//! 21 Clear Auth
2
3use std::collections::HashSet;
4use std::str::FromStr;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use strum::IntoEnumIterator;
9use strum_macros::EnumIter;
10use thiserror::Error;
11
12/// NUT21 Error
13#[derive(Debug, Error)]
14pub enum Error {
15    /// Invalid regex pattern
16    #[error("Invalid regex pattern: {0}")]
17    InvalidRegex(#[from] regex::Error),
18}
19
20/// Clear Auth Settings
21#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
22#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
23pub struct Settings {
24    /// Openid discovery
25    pub openid_discovery: String,
26    /// Client ID
27    pub client_id: String,
28    /// Protected endpoints
29    pub protected_endpoints: Vec<ProtectedEndpoint>,
30}
31
32impl Settings {
33    /// Create new [`Settings`]
34    pub fn new(
35        openid_discovery: String,
36        client_id: String,
37        protected_endpoints: Vec<ProtectedEndpoint>,
38    ) -> Self {
39        Self {
40            openid_discovery,
41            client_id,
42            protected_endpoints,
43        }
44    }
45}
46
47// Custom deserializer for Settings to expand regex patterns in protected endpoints
48impl<'de> Deserialize<'de> for Settings {
49    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
50    where
51        D: serde::Deserializer<'de>,
52    {
53        // Define a temporary struct to deserialize the raw data
54        #[derive(Deserialize)]
55        struct RawSettings {
56            openid_discovery: String,
57            client_id: String,
58            protected_endpoints: Vec<RawProtectedEndpoint>,
59        }
60
61        #[derive(Deserialize)]
62        struct RawProtectedEndpoint {
63            method: Method,
64            path: String,
65        }
66
67        // Deserialize into the temporary struct
68        let raw = RawSettings::deserialize(deserializer)?;
69
70        // Process protected endpoints, expanding regex patterns if present
71        let mut protected_endpoints = HashSet::new();
72
73        for raw_endpoint in raw.protected_endpoints {
74            let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
75                serde::de::Error::custom(format!(
76                    "Invalid regex pattern '{}': {}",
77                    raw_endpoint.path, e
78                ))
79            })?;
80
81            for path in expanded_paths {
82                protected_endpoints.insert(ProtectedEndpoint::new(raw_endpoint.method, path));
83            }
84        }
85
86        // Create the final Settings struct
87        Ok(Settings {
88            openid_discovery: raw.openid_discovery,
89            client_id: raw.client_id,
90            protected_endpoints: protected_endpoints.into_iter().collect(),
91        })
92    }
93}
94
95/// List of the methods and paths that are protected
96#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
97#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
98pub struct ProtectedEndpoint {
99    /// HTTP Method
100    pub method: Method,
101    /// Route path
102    pub path: RoutePath,
103}
104
105impl ProtectedEndpoint {
106    /// Create [`ProtectedEndpoint`]
107    pub fn new(method: Method, path: RoutePath) -> Self {
108        Self { method, path }
109    }
110}
111
112/// HTTP method
113#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
114#[serde(rename_all = "UPPERCASE")]
115#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
116pub enum Method {
117    /// Get
118    Get,
119    /// POST
120    Post,
121}
122
123/// Route path
124#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, EnumIter)]
125#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
126#[serde(rename_all = "snake_case")]
127pub enum RoutePath {
128    /// Bolt11 Mint Quote
129    #[serde(rename = "/v1/mint/quote/bolt11")]
130    MintQuoteBolt11,
131    /// Bolt11 Mint
132    #[serde(rename = "/v1/mint/bolt11")]
133    MintBolt11,
134    /// Bolt11 Melt Quote
135    #[serde(rename = "/v1/melt/quote/bolt11")]
136    MeltQuoteBolt11,
137    /// Bolt11 Melt
138    #[serde(rename = "/v1/melt/bolt11")]
139    MeltBolt11,
140    /// Swap
141    #[serde(rename = "/v1/swap")]
142    Swap,
143    /// Checkstate
144    #[serde(rename = "/v1/checkstate")]
145    Checkstate,
146    /// Restore
147    #[serde(rename = "/v1/restore")]
148    Restore,
149    /// Mint Blind Auth
150    #[serde(rename = "/v1/auth/blind/mint")]
151    MintBlindAuth,
152    /// Bolt12 Mint Quote
153    #[serde(rename = "/v1/mint/quote/bolt12")]
154    MintQuoteBolt12,
155    /// Bolt12 Mint
156    #[serde(rename = "/v1/mint/bolt12")]
157    MintBolt12,
158    /// Bolt12 Melt Quote
159    #[serde(rename = "/v1/melt/quote/bolt12")]
160    MeltQuoteBolt12,
161    /// Bolt12 Quote
162    #[serde(rename = "/v1/melt/bolt12")]
163    MeltBolt12,
164}
165
166/// Returns [`RoutePath`]s that match regex
167pub fn matching_route_paths(pattern: &str) -> Result<Vec<RoutePath>, Error> {
168    let regex = Regex::from_str(pattern)?;
169
170    Ok(RoutePath::iter()
171        .filter(|path| regex.is_match(&path.to_string()))
172        .collect())
173}
174
175impl std::fmt::Display for RoutePath {
176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177        // Use serde to serialize to a JSON string, then extract the value without quotes
178        let json_str = match serde_json::to_string(self) {
179            Ok(s) => s,
180            Err(_) => return write!(f, "<error>"),
181        };
182        // Remove the quotes from the JSON string
183        let path = json_str.trim_matches('"');
184        write!(f, "{path}")
185    }
186}
187
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_matching_route_paths_all() {
        // Regex that matches all paths
        let paths = matching_route_paths(".*").unwrap();

        // Should match all variants
        assert_eq!(paths.len(), RoutePath::iter().count());

        // Verify every variant is included
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MintBolt11));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt11));
        assert!(paths.contains(&RoutePath::MeltBolt11));
        assert!(paths.contains(&RoutePath::Swap));
        assert!(paths.contains(&RoutePath::Checkstate));
        assert!(paths.contains(&RoutePath::Restore));
        assert!(paths.contains(&RoutePath::MintBlindAuth));
        assert!(paths.contains(&RoutePath::MintQuoteBolt12));
        assert!(paths.contains(&RoutePath::MintBolt12));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt12));
        assert!(paths.contains(&RoutePath::MeltBolt12));
    }

    #[test]
    fn test_matching_route_paths_mint_only() {
        // Regex that matches only mint paths
        let paths = matching_route_paths("^/v1/mint/.*").unwrap();

        // Should match only mint paths
        assert_eq!(paths.len(), 4);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MintBolt11));
        assert!(paths.contains(&RoutePath::MintQuoteBolt12));
        assert!(paths.contains(&RoutePath::MintBolt12));

        // Should not match other paths
        assert!(!paths.contains(&RoutePath::MeltQuoteBolt11));
        assert!(!paths.contains(&RoutePath::MeltBolt11));
        assert!(!paths.contains(&RoutePath::MeltQuoteBolt12));
        assert!(!paths.contains(&RoutePath::MeltBolt12));
        assert!(!paths.contains(&RoutePath::Swap));
        // "/v1/auth/blind/mint" ends in "mint" but does not start with "/v1/mint/"
        assert!(!paths.contains(&RoutePath::MintBlindAuth));
    }

    #[test]
    fn test_matching_route_paths_quote_only() {
        // Regex that matches only quote paths
        let paths = matching_route_paths(".*/quote/.*").unwrap();

        // Should match only quote paths
        assert_eq!(paths.len(), 4);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt11));
        assert!(paths.contains(&RoutePath::MintQuoteBolt12));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt12));

        // Should not match non-quote paths
        assert!(!paths.contains(&RoutePath::MintBolt11));
        assert!(!paths.contains(&RoutePath::MeltBolt11));
    }

    #[test]
    fn test_matching_route_paths_no_match() {
        // Regex that matches nothing
        let paths = matching_route_paths("/nonexistent/path").unwrap();

        // Should match nothing
        assert!(paths.is_empty());
    }

    #[test]
    fn test_matching_route_paths_quote_bolt11_only() {
        // A literal path matches exactly one route
        let paths = matching_route_paths("/v1/mint/quote/bolt11").unwrap();

        // Should match only the bolt11 mint quote path
        assert_eq!(paths.len(), 1);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
    }

    #[test]
    fn test_matching_route_paths_invalid_regex() {
        // Invalid regex pattern
        let result = matching_route_paths("(unclosed parenthesis");

        // Should return an error for invalid regex
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::InvalidRegex(_)));
    }

    #[test]
    fn test_route_path_to_string() {
        // Test that to_string() returns the correct path strings
        assert_eq!(
            RoutePath::MintQuoteBolt11.to_string(),
            "/v1/mint/quote/bolt11"
        );
        assert_eq!(RoutePath::MintBolt11.to_string(), "/v1/mint/bolt11");
        assert_eq!(
            RoutePath::MeltQuoteBolt11.to_string(),
            "/v1/melt/quote/bolt11"
        );
        assert_eq!(RoutePath::MeltBolt11.to_string(), "/v1/melt/bolt11");
        assert_eq!(RoutePath::Swap.to_string(), "/v1/swap");
        assert_eq!(RoutePath::Checkstate.to_string(), "/v1/checkstate");
        assert_eq!(RoutePath::Restore.to_string(), "/v1/restore");
        assert_eq!(RoutePath::MintBlindAuth.to_string(), "/v1/auth/blind/mint");
        assert_eq!(
            RoutePath::MintQuoteBolt12.to_string(),
            "/v1/mint/quote/bolt12"
        );
        assert_eq!(RoutePath::MintBolt12.to_string(), "/v1/mint/bolt12");
        assert_eq!(
            RoutePath::MeltQuoteBolt12.to_string(),
            "/v1/melt/quote/bolt12"
        );
        assert_eq!(RoutePath::MeltBolt12.to_string(), "/v1/melt/bolt12");
    }

    #[test]
    fn test_settings_deserialize_direct_paths() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/bolt11"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.openid_discovery,
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(settings.client_id, "client123");
        assert_eq!(settings.protected_endpoints.len(), 2);

        // Check that both paths are included
        let paths = settings
            .protected_endpoints
            .iter()
            .map(|ep| (ep.method, ep.path))
            .collect::<Vec<_>>();
        assert!(paths.contains(&(Method::Get, RoutePath::MintBolt11)));
        assert!(paths.contains(&(Method::Post, RoutePath::Swap)));
    }

    #[test]
    fn test_settings_deserialize_with_regex() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "^/v1/mint/.*"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.openid_discovery,
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(settings.client_id, "client123");
        assert_eq!(settings.protected_endpoints.len(), 5); // 4 mint paths + 1 swap path

        let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
            ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt12),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt12),
        ]);

        let deserialized_protected: HashSet<ProtectedEndpoint> =
            settings.protected_endpoints.into_iter().collect();

        assert_eq!(expected_protected, deserialized_protected);
    }

    #[test]
    fn test_settings_deserialize_invalid_regex() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "(unclosed parenthesis"
                }
            ]
        }"#;

        let result = serde_json::from_str::<Settings>(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_settings_deserialize_exact_path_match() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/quote/bolt11"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(settings.protected_endpoints.len(), 1);
        assert_eq!(settings.protected_endpoints[0].method, Method::Get);
        assert_eq!(
            settings.protected_endpoints[0].path,
            RoutePath::MintQuoteBolt11
        );
    }

    #[test]
    fn test_settings_deserialize_all_paths() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": ".*"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(
            settings.protected_endpoints.len(),
            RoutePath::iter().count()
        );
    }

    #[test]
    fn test_settings_deserialize_deduplicates_overlapping_patterns() {
        // GET /v1/swap is covered both by ".*" and by the literal entry; the
        // POST entry differs in method and must be kept.
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                { "method": "GET", "path": ".*" },
                { "method": "GET", "path": "/v1/swap" },
                { "method": "POST", "path": "/v1/swap" }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(
            settings.protected_endpoints.len(),
            RoutePath::iter().count() + 1
        );

        let unique: HashSet<ProtectedEndpoint> =
            settings.protected_endpoints.iter().copied().collect();
        assert_eq!(unique.len(), settings.protected_endpoints.len());
        assert!(unique.contains(&ProtectedEndpoint::new(Method::Post, RoutePath::Swap)));
    }

    #[test]
    fn test_settings_serialize_round_trip() {
        // Serialized paths are concrete route strings, each of which must
        // expand back to exactly itself on deserialization.
        let settings = Settings::new(
            "https://example.com/.well-known/openid-configuration".to_string(),
            "client123".to_string(),
            RoutePath::iter()
                .map(|path| ProtectedEndpoint::new(Method::Post, path))
                .collect(),
        );

        let json = serde_json::to_string(&settings).unwrap();
        let decoded: Settings = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.openid_discovery, settings.openid_discovery);
        assert_eq!(decoded.client_id, settings.client_id);

        let expected: HashSet<ProtectedEndpoint> =
            settings.protected_endpoints.iter().copied().collect();
        let actual: HashSet<ProtectedEndpoint> =
            decoded.protected_endpoints.iter().copied().collect();
        assert_eq!(expected, actual);
    }
}