// cashu/nuts/auth/nut21.rs

//! NUT-21: Clear Auth
2
3use std::collections::HashSet;
4use std::str::FromStr;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use strum::IntoEnumIterator;
9use strum_macros::EnumIter;
10use thiserror::Error;
11
12/// NUT21 Error
13#[derive(Debug, Error)]
14pub enum Error {
15    /// Invalid regex pattern
16    #[error("Invalid regex pattern: {0}")]
17    InvalidRegex(#[from] regex::Error),
18}
19
20/// Clear Auth Settings
21#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
22#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
23pub struct Settings {
24    /// Openid discovery
25    pub openid_discovery: String,
26    /// Client ID
27    pub client_id: String,
28    /// Protected endpoints
29    pub protected_endpoints: Vec<ProtectedEndpoint>,
30}
31
32impl Settings {
33    /// Create new [`Settings`]
34    pub fn new(
35        openid_discovery: String,
36        client_id: String,
37        protected_endpoints: Vec<ProtectedEndpoint>,
38    ) -> Self {
39        Self {
40            openid_discovery,
41            client_id,
42            protected_endpoints,
43        }
44    }
45}
46
47// Custom deserializer for Settings to expand regex patterns in protected endpoints
48impl<'de> Deserialize<'de> for Settings {
49    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
50    where
51        D: serde::Deserializer<'de>,
52    {
53        // Define a temporary struct to deserialize the raw data
54        #[derive(Deserialize)]
55        struct RawSettings {
56            openid_discovery: String,
57            client_id: String,
58            protected_endpoints: Vec<RawProtectedEndpoint>,
59        }
60
61        #[derive(Deserialize)]
62        struct RawProtectedEndpoint {
63            method: Method,
64            path: String,
65        }
66
67        // Deserialize into the temporary struct
68        let raw = RawSettings::deserialize(deserializer)?;
69
70        // Process protected endpoints, expanding regex patterns if present
71        let mut protected_endpoints = HashSet::new();
72
73        for raw_endpoint in raw.protected_endpoints {
74            let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
75                serde::de::Error::custom(format!(
76                    "Invalid regex pattern '{}': {}",
77                    raw_endpoint.path, e
78                ))
79            })?;
80
81            for path in expanded_paths {
82                protected_endpoints.insert(ProtectedEndpoint::new(raw_endpoint.method, path));
83            }
84        }
85
86        // Create the final Settings struct
87        Ok(Settings {
88            openid_discovery: raw.openid_discovery,
89            client_id: raw.client_id,
90            protected_endpoints: protected_endpoints.into_iter().collect(),
91        })
92    }
93}
94
95/// List of the methods and paths that are protected
96#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
97#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
98pub struct ProtectedEndpoint {
99    /// HTTP Method
100    pub method: Method,
101    /// Route path
102    pub path: RoutePath,
103}
104
105impl ProtectedEndpoint {
106    /// Create [`CachedEndpoint`]
107    pub fn new(method: Method, path: RoutePath) -> Self {
108        Self { method, path }
109    }
110}
111
112/// HTTP method
113#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
114#[serde(rename_all = "UPPERCASE")]
115#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
116pub enum Method {
117    /// Get
118    Get,
119    /// POST
120    Post,
121}
122
123/// Route path
124#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, EnumIter)]
125#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
126#[serde(rename_all = "snake_case")]
127pub enum RoutePath {
128    /// Bolt11 Mint Quote
129    #[serde(rename = "/v1/mint/quote/bolt11")]
130    MintQuoteBolt11,
131    /// Bolt11 Mint
132    #[serde(rename = "/v1/mint/bolt11")]
133    MintBolt11,
134    /// Bolt11 Melt Quote
135    #[serde(rename = "/v1/melt/quote/bolt11")]
136    MeltQuoteBolt11,
137    /// Bolt11 Melt
138    #[serde(rename = "/v1/melt/bolt11")]
139    MeltBolt11,
140    /// Swap
141    #[serde(rename = "/v1/swap")]
142    Swap,
143    /// Checkstate
144    #[serde(rename = "/v1/checkstate")]
145    Checkstate,
146    /// Restore
147    #[serde(rename = "/v1/restore")]
148    Restore,
149    /// Mint Blind Auth
150    #[serde(rename = "/v1/auth/blind/mint")]
151    MintBlindAuth,
152}
153
154pub fn matching_route_paths(pattern: &str) -> Result<Vec<RoutePath>, Error> {
155    let regex = Regex::from_str(pattern)?;
156
157    Ok(RoutePath::iter()
158        .filter(|path| regex.is_match(&path.to_string()))
159        .collect())
160}
161
162impl std::fmt::Display for RoutePath {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        // Use serde to serialize to a JSON string, then extract the value without quotes
165        let json_str = match serde_json::to_string(self) {
166            Ok(s) => s,
167            Err(_) => return write!(f, "<error>"),
168        };
169        // Remove the quotes from the JSON string
170        let path = json_str.trim_matches('"');
171        write!(f, "{}", path)
172    }
173}
174
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_matching_route_paths_all() {
        // Regex that matches all paths
        let paths = matching_route_paths(".*").unwrap();

        // Should match all variants
        assert_eq!(paths.len(), RoutePath::iter().count());

        // Verify all variants are included
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MintBolt11));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt11));
        assert!(paths.contains(&RoutePath::MeltBolt11));
        assert!(paths.contains(&RoutePath::Swap));
        assert!(paths.contains(&RoutePath::Checkstate));
        assert!(paths.contains(&RoutePath::Restore));
        assert!(paths.contains(&RoutePath::MintBlindAuth));
    }

    #[test]
    fn test_matching_route_paths_mint_only() {
        // Regex that matches only mint paths
        let paths = matching_route_paths("^/v1/mint/.*").unwrap();

        // Should match only mint paths
        assert_eq!(paths.len(), 2);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MintBolt11));

        // Should not match other paths
        assert!(!paths.contains(&RoutePath::MeltQuoteBolt11));
        assert!(!paths.contains(&RoutePath::MeltBolt11));
        assert!(!paths.contains(&RoutePath::Swap));
    }

    #[test]
    fn test_matching_route_paths_quote_only() {
        // Regex that matches only quote paths
        let paths = matching_route_paths(".*/quote/.*").unwrap();

        // Should match only quote paths
        assert_eq!(paths.len(), 2);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt11));

        // Should not match non-quote paths
        assert!(!paths.contains(&RoutePath::MintBolt11));
        assert!(!paths.contains(&RoutePath::MeltBolt11));
    }

    #[test]
    fn test_matching_route_paths_no_match() {
        // Regex that matches nothing
        let paths = matching_route_paths("/nonexistent/path").unwrap();

        // Should match nothing
        assert!(paths.is_empty());
    }

    #[test]
    fn test_matching_route_paths_quote_bolt11_only() {
        // A literal route path selects exactly that route
        let paths = matching_route_paths("/v1/mint/quote/bolt11").unwrap();

        // Should match only the Bolt11 mint-quote route
        assert_eq!(paths.len(), 1);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
    }

    #[test]
    fn test_matching_route_paths_invalid_regex() {
        // Invalid regex pattern
        let result = matching_route_paths("(unclosed parenthesis");

        // Should return an error for invalid regex
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::InvalidRegex(_)));
    }

    #[test]
    fn test_route_path_to_string() {
        // Test that to_string() returns the correct path strings
        assert_eq!(
            RoutePath::MintQuoteBolt11.to_string(),
            "/v1/mint/quote/bolt11"
        );
        assert_eq!(RoutePath::MintBolt11.to_string(), "/v1/mint/bolt11");
        assert_eq!(
            RoutePath::MeltQuoteBolt11.to_string(),
            "/v1/melt/quote/bolt11"
        );
        assert_eq!(RoutePath::MeltBolt11.to_string(), "/v1/melt/bolt11");
        assert_eq!(RoutePath::Swap.to_string(), "/v1/swap");
        assert_eq!(RoutePath::Checkstate.to_string(), "/v1/checkstate");
        assert_eq!(RoutePath::Restore.to_string(), "/v1/restore");
        assert_eq!(RoutePath::MintBlindAuth.to_string(), "/v1/auth/blind/mint");
    }

    #[test]
    fn test_method_serialization() {
        // Methods are serialized in upper case
        assert_eq!(serde_json::to_string(&Method::Get).unwrap(), "\"GET\"");
        assert_eq!(serde_json::to_string(&Method::Post).unwrap(), "\"POST\"");
    }

    #[test]
    fn test_settings_deserialize_direct_paths() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/bolt11"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.openid_discovery,
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(settings.client_id, "client123");
        assert_eq!(settings.protected_endpoints.len(), 2);

        // Check that both paths are included
        let paths = settings
            .protected_endpoints
            .iter()
            .map(|ep| (ep.method, ep.path))
            .collect::<Vec<_>>();
        assert!(paths.contains(&(Method::Get, RoutePath::MintBolt11)));
        assert!(paths.contains(&(Method::Post, RoutePath::Swap)));
    }

    #[test]
    fn test_settings_deserialize_with_regex() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "^/v1/mint/.*"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.openid_discovery,
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(settings.client_id, "client123");
        assert_eq!(settings.protected_endpoints.len(), 3); // 2 mint paths + 1 swap path

        let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
            ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
        ]);

        let deserialized_protected: HashSet<ProtectedEndpoint> =
            settings.protected_endpoints.into_iter().collect();

        assert_eq!(expected_protected, deserialized_protected);
    }

    #[test]
    fn test_settings_deserialize_deduplicates_endpoints() {
        // The same method/route produced by two different patterns is kept once
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "POST",
                    "path": "/v1/swap"
                },
                {
                    "method": "POST",
                    "path": "^/v1/swap$"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(
            settings.protected_endpoints,
            vec![ProtectedEndpoint::new(Method::Post, RoutePath::Swap)]
        );
    }

    #[test]
    fn test_settings_deserialize_unmatched_pattern_is_ignored() {
        // A pattern matching no known route yields no endpoints rather than an error
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/nonexistent"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert!(settings.protected_endpoints.is_empty());
    }

    #[test]
    fn test_settings_deserialize_invalid_regex() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "(unclosed parenthesis"
                }
            ]
        }"#;

        let result = serde_json::from_str::<Settings>(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_settings_deserialize_exact_path_match() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/quote/bolt11"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(settings.protected_endpoints.len(), 1);
        assert_eq!(settings.protected_endpoints[0].method, Method::Get);
        assert_eq!(
            settings.protected_endpoints[0].path,
            RoutePath::MintQuoteBolt11
        );
    }

    #[test]
    fn test_settings_deserialize_all_paths() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": ".*"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(
            settings.protected_endpoints.len(),
            RoutePath::iter().count()
        );
    }
}