cashu/nuts/auth/
nut21.rs

//! NUT-21: Clear Auth
2
3use std::collections::HashSet;
4use std::str::FromStr;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use strum::IntoEnumIterator;
9use strum_macros::EnumIter;
10use thiserror::Error;
11
/// NUT21 Error
///
/// Errors returned by [`matching_route_paths`].
#[derive(Debug, Error)]
pub enum Error {
    /// Invalid regex pattern
    ///
    /// The supplied protected-endpoint path could not be compiled by the `regex` crate.
    #[error("Invalid regex pattern: {0}")]
    InvalidRegex(#[from] regex::Error),
}
19
20/// Clear Auth Settings
21#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
22#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
23pub struct Settings {
24    /// Openid discovery
25    pub openid_discovery: String,
26    /// Client ID
27    pub client_id: String,
28    /// Protected endpoints
29    pub protected_endpoints: Vec<ProtectedEndpoint>,
30}
31
32impl Settings {
33    /// Create new [`Settings`]
34    pub fn new(
35        openid_discovery: String,
36        client_id: String,
37        protected_endpoints: Vec<ProtectedEndpoint>,
38    ) -> Self {
39        Self {
40            openid_discovery,
41            client_id,
42            protected_endpoints,
43        }
44    }
45}
46
47// Custom deserializer for Settings to expand regex patterns in protected endpoints
48impl<'de> Deserialize<'de> for Settings {
49    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
50    where
51        D: serde::Deserializer<'de>,
52    {
53        // Define a temporary struct to deserialize the raw data
54        #[derive(Deserialize)]
55        struct RawSettings {
56            openid_discovery: String,
57            client_id: String,
58            protected_endpoints: Vec<RawProtectedEndpoint>,
59        }
60
61        #[derive(Deserialize)]
62        struct RawProtectedEndpoint {
63            method: Method,
64            path: String,
65        }
66
67        // Deserialize into the temporary struct
68        let raw = RawSettings::deserialize(deserializer)?;
69
70        // Process protected endpoints, expanding regex patterns if present
71        let mut protected_endpoints = HashSet::new();
72
73        for raw_endpoint in raw.protected_endpoints {
74            let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
75                serde::de::Error::custom(format!(
76                    "Invalid regex pattern '{}': {}",
77                    raw_endpoint.path, e
78                ))
79            })?;
80
81            for path in expanded_paths {
82                protected_endpoints.insert(ProtectedEndpoint::new(raw_endpoint.method, path));
83            }
84        }
85
86        // Create the final Settings struct
87        Ok(Settings {
88            openid_discovery: raw.openid_discovery,
89            client_id: raw.client_id,
90            protected_endpoints: protected_endpoints.into_iter().collect(),
91        })
92    }
93}
94
95/// List of the methods and paths that are protected
96#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
97#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
98pub struct ProtectedEndpoint {
99    /// HTTP Method
100    pub method: Method,
101    /// Route path
102    pub path: RoutePath,
103}
104
105impl ProtectedEndpoint {
106    /// Create [`ProtectedEndpoint`]
107    pub fn new(method: Method, path: RoutePath) -> Self {
108        Self { method, path }
109    }
110}
111
/// HTTP method
///
/// Serialized in upper case, i.e. `"GET"` and `"POST"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub enum Method {
    /// GET (serialized as `"GET"`)
    Get,
    /// POST (serialized as `"POST"`)
    Post,
}
122
/// Route path
///
/// Each variant serializes (and deserializes) as the literal URL path in its
/// `#[serde(rename = ...)]` attribute, e.g. [`RoutePath::Swap`] <-> `"/v1/swap"`.
/// `EnumIter` lets [`matching_route_paths`] test a pattern against every variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, EnumIter)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
// Every variant below has an explicit `rename`, which takes precedence over this rule.
#[serde(rename_all = "snake_case")]
pub enum RoutePath {
    /// Bolt11 Mint Quote
    #[serde(rename = "/v1/mint/quote/bolt11")]
    MintQuoteBolt11,
    /// Bolt11 Mint
    #[serde(rename = "/v1/mint/bolt11")]
    MintBolt11,
    /// Bolt11 Melt Quote
    #[serde(rename = "/v1/melt/quote/bolt11")]
    MeltQuoteBolt11,
    /// Bolt11 Melt
    #[serde(rename = "/v1/melt/bolt11")]
    MeltBolt11,
    /// Swap
    #[serde(rename = "/v1/swap")]
    Swap,
    /// Checkstate
    #[serde(rename = "/v1/checkstate")]
    Checkstate,
    /// Restore
    #[serde(rename = "/v1/restore")]
    Restore,
    /// Mint Blind Auth
    #[serde(rename = "/v1/auth/blind/mint")]
    MintBlindAuth,
}
153
154/// Returns [`RoutePath`]s that match regex
155pub fn matching_route_paths(pattern: &str) -> Result<Vec<RoutePath>, Error> {
156    let regex = Regex::from_str(pattern)?;
157
158    Ok(RoutePath::iter()
159        .filter(|path| regex.is_match(&path.to_string()))
160        .collect())
161}
162
163impl std::fmt::Display for RoutePath {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        // Use serde to serialize to a JSON string, then extract the value without quotes
166        let json_str = match serde_json::to_string(self) {
167            Ok(s) => s,
168            Err(_) => return write!(f, "<error>"),
169        };
170        // Remove the quotes from the JSON string
171        let path = json_str.trim_matches('"');
172        write!(f, "{path}")
173    }
174}
175
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_matching_route_paths_all() {
        // Regex that matches all paths
        let paths = matching_route_paths(".*").unwrap();

        // Should match all variants
        assert_eq!(paths.len(), RoutePath::iter().count());

        // Verify all variants are included
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MintBolt11));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt11));
        assert!(paths.contains(&RoutePath::MeltBolt11));
        assert!(paths.contains(&RoutePath::Swap));
        assert!(paths.contains(&RoutePath::Checkstate));
        assert!(paths.contains(&RoutePath::Restore));
        assert!(paths.contains(&RoutePath::MintBlindAuth));
    }

    #[test]
    fn test_matching_route_paths_mint_only() {
        // Regex that matches only paths under /v1/mint/
        let paths = matching_route_paths("^/v1/mint/.*").unwrap();

        // Should match only mint paths
        assert_eq!(paths.len(), 2);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MintBolt11));

        // Should not match other paths, including the blind-auth mint route,
        // which lives under /v1/auth/ rather than /v1/mint/
        assert!(!paths.contains(&RoutePath::MeltQuoteBolt11));
        assert!(!paths.contains(&RoutePath::MeltBolt11));
        assert!(!paths.contains(&RoutePath::Swap));
        assert!(!paths.contains(&RoutePath::MintBlindAuth));
    }

    #[test]
    fn test_matching_route_paths_quote_only() {
        // Regex that matches only quote paths
        let paths = matching_route_paths(".*/quote/.*").unwrap();

        // Should match only quote paths
        assert_eq!(paths.len(), 2);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt11));

        // Should not match non-quote paths
        assert!(!paths.contains(&RoutePath::MintBolt11));
        assert!(!paths.contains(&RoutePath::MeltBolt11));
    }

    #[test]
    fn test_matching_route_paths_no_match() {
        // Regex that matches nothing
        let paths = matching_route_paths("/nonexistent/path").unwrap();

        // Should match nothing
        assert!(paths.is_empty());
    }

    #[test]
    fn test_matching_route_paths_quote_bolt11_only() {
        // A literal path used as a pattern selects exactly that route
        let paths = matching_route_paths("/v1/mint/quote/bolt11").unwrap();

        // Should match only the bolt11 mint quote route
        assert_eq!(paths.len(), 1);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
    }

    #[test]
    fn test_matching_route_paths_invalid_regex() {
        // Invalid regex pattern
        let result = matching_route_paths("(unclosed parenthesis");

        // Should return an InvalidRegex error
        assert!(matches!(result, Err(Error::InvalidRegex(_))));
    }

    #[test]
    fn test_route_path_to_string() {
        // Test that to_string() returns the correct path strings
        assert_eq!(
            RoutePath::MintQuoteBolt11.to_string(),
            "/v1/mint/quote/bolt11"
        );
        assert_eq!(RoutePath::MintBolt11.to_string(), "/v1/mint/bolt11");
        assert_eq!(
            RoutePath::MeltQuoteBolt11.to_string(),
            "/v1/melt/quote/bolt11"
        );
        assert_eq!(RoutePath::MeltBolt11.to_string(), "/v1/melt/bolt11");
        assert_eq!(RoutePath::Swap.to_string(), "/v1/swap");
        assert_eq!(RoutePath::Checkstate.to_string(), "/v1/checkstate");
        assert_eq!(RoutePath::Restore.to_string(), "/v1/restore");
        assert_eq!(RoutePath::MintBlindAuth.to_string(), "/v1/auth/blind/mint");
    }

    #[test]
    fn test_settings_deserialize_direct_paths() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/bolt11"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.openid_discovery,
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(settings.client_id, "client123");
        assert_eq!(settings.protected_endpoints.len(), 2);

        // Check that both paths are included
        let paths = settings
            .protected_endpoints
            .iter()
            .map(|ep| (ep.method, ep.path))
            .collect::<Vec<_>>();
        assert!(paths.contains(&(Method::Get, RoutePath::MintBolt11)));
        assert!(paths.contains(&(Method::Post, RoutePath::Swap)));
    }

    #[test]
    fn test_settings_deserialize_with_regex() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "^/v1/mint/.*"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.openid_discovery,
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(settings.client_id, "client123");
        assert_eq!(settings.protected_endpoints.len(), 3); // 2 mint paths + 1 swap path

        let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
            ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
        ]);

        let deserialized_protected = settings.protected_endpoints.into_iter().collect();

        assert_eq!(expected_protected, deserialized_protected);
    }

    #[test]
    fn test_settings_deserialize_overlapping_patterns_deduplicated() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                { "method": "GET", "path": "^/v1/mint/.*" },
                { "method": "GET", "path": "/v1/mint/bolt11" },
                { "method": "POST", "path": "/v1/mint/bolt11" }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        // GET /v1/mint/bolt11 is produced by both GET patterns but must appear once;
        // POST on the same path is a distinct endpoint.
        assert_eq!(settings.protected_endpoints.len(), 3);

        let expected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
            ProtectedEndpoint::new(Method::Post, RoutePath::MintBolt11),
        ]);
        let got: HashSet<ProtectedEndpoint> = settings.protected_endpoints.into_iter().collect();
        assert_eq!(expected, got);
    }

    #[test]
    fn test_settings_serialize_roundtrip() {
        let settings = Settings::new(
            "https://example.com/.well-known/openid-configuration".to_string(),
            "client123".to_string(),
            vec![
                ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
                ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ],
        );

        let json = serde_json::to_string(&settings).unwrap();
        let decoded: Settings = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.openid_discovery, settings.openid_discovery);
        assert_eq!(decoded.client_id, settings.client_id);

        // Serialized paths are literal route paths, so re-expanding them as patterns
        // must yield exactly the original endpoints.
        let original: HashSet<ProtectedEndpoint> =
            settings.protected_endpoints.into_iter().collect();
        let roundtripped: HashSet<ProtectedEndpoint> =
            decoded.protected_endpoints.into_iter().collect();
        assert_eq!(original, roundtripped);
    }

    #[test]
    fn test_settings_deserialize_invalid_regex() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "(unclosed parenthesis"
                }
            ]
        }"#;

        let result = serde_json::from_str::<Settings>(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_settings_deserialize_exact_path_match() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/quote/bolt11"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(settings.protected_endpoints.len(), 1);
        assert_eq!(settings.protected_endpoints[0].method, Method::Get);
        assert_eq!(
            settings.protected_endpoints[0].path,
            RoutePath::MintQuoteBolt11
        );
    }

    #[test]
    fn test_settings_deserialize_all_paths() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": ".*"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(
            settings.protected_endpoints.len(),
            RoutePath::iter().count()
        );
    }
}