1use std::collections::HashSet;
4use std::str::FromStr;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use strum::IntoEnumIterator;
9use strum_macros::EnumIter;
10use thiserror::Error;
11
/// Errors returned by the route-pattern helpers in this file.
#[derive(Debug, Error)]
pub enum Error {
    /// The supplied pattern could not be compiled as a regular expression.
    ///
    /// Wraps the underlying [`regex::Error`]; the `#[from]` attribute lets `?`
    /// convert a `regex::Error` into this variant automatically.
    #[error("Invalid regex pattern: {0}")]
    InvalidRegex(#[from] regex::Error),
}
19
/// Authentication settings: an OpenID discovery location, a client id and the
/// list of protected endpoints.
///
/// `Serialize` is derived, while `Deserialize` is implemented by hand elsewhere
/// in this file: on input, each endpoint `path` is treated as a regular
/// expression and expanded into the concrete [`RoutePath`]s it matches.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub struct Settings {
    /// OpenID discovery location.
    // NOTE(review): the tests use a `.well-known/openid-configuration` URL;
    // no format is enforced here — it is stored as an opaque string.
    pub openid_discovery: String,
    /// Client identifier (presumably the OpenID client id — confirm with callers).
    pub client_id: String,
    /// Concrete `(method, path)` pairs that are protected.
    pub protected_endpoints: Vec<ProtectedEndpoint>,
}
31
32impl Settings {
33 pub fn new(
35 openid_discovery: String,
36 client_id: String,
37 protected_endpoints: Vec<ProtectedEndpoint>,
38 ) -> Self {
39 Self {
40 openid_discovery,
41 client_id,
42 protected_endpoints,
43 }
44 }
45}
46
47impl<'de> Deserialize<'de> for Settings {
49 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
50 where
51 D: serde::Deserializer<'de>,
52 {
53 #[derive(Deserialize)]
55 struct RawSettings {
56 openid_discovery: String,
57 client_id: String,
58 protected_endpoints: Vec<RawProtectedEndpoint>,
59 }
60
61 #[derive(Deserialize)]
62 struct RawProtectedEndpoint {
63 method: Method,
64 path: String,
65 }
66
67 let raw = RawSettings::deserialize(deserializer)?;
69
70 let mut protected_endpoints = HashSet::new();
72
73 for raw_endpoint in raw.protected_endpoints {
74 let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
75 serde::de::Error::custom(format!(
76 "Invalid regex pattern '{}': {}",
77 raw_endpoint.path, e
78 ))
79 })?;
80
81 for path in expanded_paths {
82 protected_endpoints.insert(ProtectedEndpoint::new(raw_endpoint.method, path));
83 }
84 }
85
86 Ok(Settings {
88 openid_discovery: raw.openid_discovery,
89 client_id: raw.client_id,
90 protected_endpoints: protected_endpoints.into_iter().collect(),
91 })
92 }
93}
94
/// A single protected endpoint: a request [`Method`] paired with a [`RoutePath`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub struct ProtectedEndpoint {
    /// Request method of the endpoint.
    pub method: Method,
    /// Route path of the endpoint.
    pub path: RoutePath,
}
104
105impl ProtectedEndpoint {
106 pub fn new(method: Method, path: RoutePath) -> Self {
108 Self { method, path }
109 }
110}
111
/// Request method of a protected endpoint.
///
/// Because of `rename_all = "UPPERCASE"`, the variants are (de)serialized as
/// `"GET"` and `"POST"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
pub enum Method {
    /// Serialized as `"GET"`.
    Get,
    /// Serialized as `"POST"`.
    Post,
}
122
/// The set of route paths that can be referenced by a [`ProtectedEndpoint`].
///
/// Every variant is (de)serialized as the literal URL path given in its
/// `#[serde(rename = ...)]` attribute. `EnumIter` is derived so that all
/// variants can be enumerated with `RoutePath::iter()`.
// NOTE(review): every variant has an explicit `rename`, so the container-level
// `rename_all = "snake_case"` appears to have no effect at present — confirm
// before relying on it for newly added variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, EnumIter)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum RoutePath {
    /// `/v1/mint/quote/bolt11`
    #[serde(rename = "/v1/mint/quote/bolt11")]
    MintQuoteBolt11,
    /// `/v1/mint/bolt11`
    #[serde(rename = "/v1/mint/bolt11")]
    MintBolt11,
    /// `/v1/melt/quote/bolt11`
    #[serde(rename = "/v1/melt/quote/bolt11")]
    MeltQuoteBolt11,
    /// `/v1/melt/bolt11`
    #[serde(rename = "/v1/melt/bolt11")]
    MeltBolt11,
    /// `/v1/swap`
    #[serde(rename = "/v1/swap")]
    Swap,
    /// `/v1/checkstate`
    #[serde(rename = "/v1/checkstate")]
    Checkstate,
    /// `/v1/restore`
    #[serde(rename = "/v1/restore")]
    Restore,
    /// `/v1/auth/blind/mint`
    #[serde(rename = "/v1/auth/blind/mint")]
    MintBlindAuth,
    /// `/v1/mint/quote/bolt12`
    #[serde(rename = "/v1/mint/quote/bolt12")]
    MintQuoteBolt12,
    /// `/v1/mint/bolt12`
    #[serde(rename = "/v1/mint/bolt12")]
    MintBolt12,
    /// `/v1/melt/quote/bolt12`
    #[serde(rename = "/v1/melt/quote/bolt12")]
    MeltQuoteBolt12,
    /// `/v1/melt/bolt12`
    #[serde(rename = "/v1/melt/bolt12")]
    MeltBolt12,
}
165
166pub fn matching_route_paths(pattern: &str) -> Result<Vec<RoutePath>, Error> {
168 let regex = Regex::from_str(pattern)?;
169
170 Ok(RoutePath::iter()
171 .filter(|path| regex.is_match(&path.to_string()))
172 .collect())
173}
174
175impl std::fmt::Display for RoutePath {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 let json_str = match serde_json::to_string(self) {
179 Ok(s) => s,
180 Err(_) => return write!(f, "<error>"),
181 };
182 let path = json_str.trim_matches('"');
184 write!(f, "{path}")
185 }
186}
187
#[cfg(test)]
mod tests {

    use super::*;

    /// `.*` must select every variant, including the bolt12 melt routes.
    #[test]
    fn test_matching_route_paths_all() {
        let paths = matching_route_paths(".*").unwrap();

        assert_eq!(paths.len(), RoutePath::iter().count());

        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MintBolt11));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt11));
        assert!(paths.contains(&RoutePath::MeltBolt11));
        assert!(paths.contains(&RoutePath::Swap));
        assert!(paths.contains(&RoutePath::Checkstate));
        assert!(paths.contains(&RoutePath::Restore));
        assert!(paths.contains(&RoutePath::MintBlindAuth));
        assert!(paths.contains(&RoutePath::MintQuoteBolt12));
        assert!(paths.contains(&RoutePath::MintBolt12));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt12));
        assert!(paths.contains(&RoutePath::MeltBolt12));
    }

    /// An anchored prefix pattern selects only the mint routes.
    #[test]
    fn test_matching_route_paths_mint_only() {
        let paths = matching_route_paths("^/v1/mint/.*").unwrap();

        assert_eq!(paths.len(), 4);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MintBolt11));
        assert!(paths.contains(&RoutePath::MintQuoteBolt12));
        assert!(paths.contains(&RoutePath::MintBolt12));

        assert!(!paths.contains(&RoutePath::MeltQuoteBolt11));
        assert!(!paths.contains(&RoutePath::MeltBolt11));
        assert!(!paths.contains(&RoutePath::MeltQuoteBolt12));
        assert!(!paths.contains(&RoutePath::MeltBolt12));
        assert!(!paths.contains(&RoutePath::Swap));
    }

    /// A pattern on a middle path segment selects only the quote routes.
    #[test]
    fn test_matching_route_paths_quote_only() {
        let paths = matching_route_paths(".*/quote/.*").unwrap();

        assert_eq!(paths.len(), 4);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt11));
        assert!(paths.contains(&RoutePath::MintQuoteBolt12));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt12));

        assert!(!paths.contains(&RoutePath::MintBolt11));
        assert!(!paths.contains(&RoutePath::MeltBolt11));
    }

    /// A valid pattern that matches nothing yields an empty list, not an error.
    #[test]
    fn test_matching_route_paths_no_match() {
        let paths = matching_route_paths("/nonexistent/path").unwrap();

        assert!(paths.is_empty());
    }

    /// A literal path used as a pattern matches exactly that route.
    #[test]
    fn test_matching_route_paths_quote_bolt11_only() {
        let paths = matching_route_paths("/v1/mint/quote/bolt11").unwrap();

        assert_eq!(paths.len(), 1);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
    }

    /// A malformed pattern surfaces as `Error::InvalidRegex`.
    #[test]
    fn test_matching_route_paths_invalid_regex() {
        let result = matching_route_paths("(unclosed parenthesis");

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::InvalidRegex(_)));
    }

    /// `Display` must agree with the serde rename of every variant.
    #[test]
    fn test_route_path_to_string() {
        assert_eq!(
            RoutePath::MintQuoteBolt11.to_string(),
            "/v1/mint/quote/bolt11"
        );
        assert_eq!(RoutePath::MintBolt11.to_string(), "/v1/mint/bolt11");
        assert_eq!(
            RoutePath::MeltQuoteBolt11.to_string(),
            "/v1/melt/quote/bolt11"
        );
        assert_eq!(RoutePath::MeltBolt11.to_string(), "/v1/melt/bolt11");
        assert_eq!(RoutePath::Swap.to_string(), "/v1/swap");
        assert_eq!(RoutePath::Checkstate.to_string(), "/v1/checkstate");
        assert_eq!(RoutePath::Restore.to_string(), "/v1/restore");
        assert_eq!(RoutePath::MintBlindAuth.to_string(), "/v1/auth/blind/mint");
        assert_eq!(
            RoutePath::MintQuoteBolt12.to_string(),
            "/v1/mint/quote/bolt12"
        );
        assert_eq!(RoutePath::MintBolt12.to_string(), "/v1/mint/bolt12");
        assert_eq!(
            RoutePath::MeltQuoteBolt12.to_string(),
            "/v1/melt/quote/bolt12"
        );
        assert_eq!(RoutePath::MeltBolt12.to_string(), "/v1/melt/bolt12");
    }

    /// Literal paths in the input map to exactly one endpoint each.
    #[test]
    fn test_settings_deserialize_direct_paths() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/bolt11"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.openid_discovery,
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(settings.client_id, "client123");
        assert_eq!(settings.protected_endpoints.len(), 2);

        let paths = settings
            .protected_endpoints
            .iter()
            .map(|ep| (ep.method, ep.path))
            .collect::<Vec<_>>();
        assert!(paths.contains(&(Method::Get, RoutePath::MintBolt11)));
        assert!(paths.contains(&(Method::Post, RoutePath::Swap)));
    }

    /// A regex path is expanded into every route it matches.
    #[test]
    fn test_settings_deserialize_with_regex() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "^/v1/mint/.*"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.openid_discovery,
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(settings.client_id, "client123");
        // Four mint routes for GET plus the single POST swap route.
        assert_eq!(settings.protected_endpoints.len(), 5);

        let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
            ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt12),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt12),
        ]);

        // Compare as sets so the test does not depend on element order.
        let deserialized_protected: HashSet<ProtectedEndpoint> =
            settings.protected_endpoints.into_iter().collect();

        assert_eq!(expected_protected, deserialized_protected);
    }

    /// An invalid regex in the input makes deserialization fail.
    #[test]
    fn test_settings_deserialize_invalid_regex() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "(unclosed parenthesis"
                }
            ]
        }"#;

        let result = serde_json::from_str::<Settings>(json);
        assert!(result.is_err());
    }

    /// A literal path does not accidentally match longer or sibling routes.
    #[test]
    fn test_settings_deserialize_exact_path_match() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/quote/bolt11"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(settings.protected_endpoints.len(), 1);
        assert_eq!(settings.protected_endpoints[0].method, Method::Get);
        assert_eq!(
            settings.protected_endpoints[0].path,
            RoutePath::MintQuoteBolt11
        );
    }

    /// `.*` in the input protects every known route for that method.
    #[test]
    fn test_settings_deserialize_all_paths() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": ".*"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(
            settings.protected_endpoints.len(),
            RoutePath::iter().count()
        );
    }

    /// Overlapping patterns for the same method must not yield duplicates.
    #[test]
    fn test_settings_deserialize_overlapping_patterns_deduplicated() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "^/v1/mint/.*"
                },
                {
                    "method": "GET",
                    "path": "/v1/mint/bolt11"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        // The second entry is already covered by the first, so only the four
        // mint routes remain.
        assert_eq!(settings.protected_endpoints.len(), 4);

        let unique: HashSet<ProtectedEndpoint> =
            settings.protected_endpoints.iter().copied().collect();
        assert_eq!(unique.len(), settings.protected_endpoints.len());
    }
}