1use std::collections::HashSet;
4use std::str::FromStr;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use strum::IntoEnumIterator;
9use strum_macros::EnumIter;
10use thiserror::Error;
11
12#[derive(Debug, Error)]
14pub enum Error {
15 #[error("Invalid regex pattern: {0}")]
17 InvalidRegex(#[from] regex::Error),
18}
19
20#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
22#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
23pub struct Settings {
24 pub openid_discovery: String,
26 pub client_id: String,
28 pub protected_endpoints: Vec<ProtectedEndpoint>,
30}
31
32impl Settings {
33 pub fn new(
35 openid_discovery: String,
36 client_id: String,
37 protected_endpoints: Vec<ProtectedEndpoint>,
38 ) -> Self {
39 Self {
40 openid_discovery,
41 client_id,
42 protected_endpoints,
43 }
44 }
45}
46
47impl<'de> Deserialize<'de> for Settings {
49 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
50 where
51 D: serde::Deserializer<'de>,
52 {
53 #[derive(Deserialize)]
55 struct RawSettings {
56 openid_discovery: String,
57 client_id: String,
58 protected_endpoints: Vec<RawProtectedEndpoint>,
59 }
60
61 #[derive(Deserialize)]
62 struct RawProtectedEndpoint {
63 method: Method,
64 path: String,
65 }
66
67 let raw = RawSettings::deserialize(deserializer)?;
69
70 let mut protected_endpoints = HashSet::new();
72
73 for raw_endpoint in raw.protected_endpoints {
74 let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
75 serde::de::Error::custom(format!(
76 "Invalid regex pattern '{}': {}",
77 raw_endpoint.path, e
78 ))
79 })?;
80
81 for path in expanded_paths {
82 protected_endpoints.insert(ProtectedEndpoint::new(raw_endpoint.method, path));
83 }
84 }
85
86 Ok(Settings {
88 openid_discovery: raw.openid_discovery,
89 client_id: raw.client_id,
90 protected_endpoints: protected_endpoints.into_iter().collect(),
91 })
92 }
93}
94
95#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
97#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
98pub struct ProtectedEndpoint {
99 pub method: Method,
101 pub path: RoutePath,
103}
104
105impl ProtectedEndpoint {
106 pub fn new(method: Method, path: RoutePath) -> Self {
108 Self { method, path }
109 }
110}
111
112#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
114#[serde(rename_all = "UPPERCASE")]
115#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
116pub enum Method {
117 Get,
119 Post,
121}
122
123#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, EnumIter)]
125#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
126#[serde(rename_all = "snake_case")]
127pub enum RoutePath {
128 #[serde(rename = "/v1/mint/quote/bolt11")]
130 MintQuoteBolt11,
131 #[serde(rename = "/v1/mint/bolt11")]
133 MintBolt11,
134 #[serde(rename = "/v1/melt/quote/bolt11")]
136 MeltQuoteBolt11,
137 #[serde(rename = "/v1/melt/bolt11")]
139 MeltBolt11,
140 #[serde(rename = "/v1/swap")]
142 Swap,
143 #[serde(rename = "/v1/checkstate")]
145 Checkstate,
146 #[serde(rename = "/v1/restore")]
148 Restore,
149 #[serde(rename = "/v1/auth/blind/mint")]
151 MintBlindAuth,
152}
153
154pub fn matching_route_paths(pattern: &str) -> Result<Vec<RoutePath>, Error> {
156 let regex = Regex::from_str(pattern)?;
157
158 Ok(RoutePath::iter()
159 .filter(|path| regex.is_match(&path.to_string()))
160 .collect())
161}
162
163impl std::fmt::Display for RoutePath {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 let json_str = match serde_json::to_string(self) {
167 Ok(s) => s,
168 Err(_) => return write!(f, "<error>"),
169 };
170 let path = json_str.trim_matches('"');
172 write!(f, "{path}")
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// `.*` matches every route, so the result covers the whole enum.
    #[test]
    fn test_matching_route_paths_all() {
        let matched = matching_route_paths(".*").unwrap();

        assert_eq!(matched.len(), RoutePath::iter().count());

        for expected in [
            RoutePath::MintQuoteBolt11,
            RoutePath::MintBolt11,
            RoutePath::MeltQuoteBolt11,
            RoutePath::MeltBolt11,
            RoutePath::Swap,
            RoutePath::Checkstate,
            RoutePath::Restore,
            RoutePath::MintBlindAuth,
        ] {
            assert!(matched.contains(&expected));
        }
    }

    /// An anchored `/v1/mint/` prefix selects only the two mint routes.
    #[test]
    fn test_matching_route_paths_mint_only() {
        let matched = matching_route_paths("^/v1/mint/.*").unwrap();

        assert_eq!(matched.len(), 2);
        for expected in [RoutePath::MintQuoteBolt11, RoutePath::MintBolt11] {
            assert!(matched.contains(&expected));
        }

        for excluded in [
            RoutePath::MeltQuoteBolt11,
            RoutePath::MeltBolt11,
            RoutePath::Swap,
        ] {
            assert!(!matched.contains(&excluded));
        }
    }

    /// A `/quote/` segment selects the mint and melt quote routes.
    #[test]
    fn test_matching_route_paths_quote_only() {
        let matched = matching_route_paths(".*/quote/.*").unwrap();

        assert_eq!(matched.len(), 2);
        for expected in [RoutePath::MintQuoteBolt11, RoutePath::MeltQuoteBolt11] {
            assert!(matched.contains(&expected));
        }

        for excluded in [RoutePath::MintBolt11, RoutePath::MeltBolt11] {
            assert!(!matched.contains(&excluded));
        }
    }

    /// A pattern that matches no route gives an empty list, not an error.
    #[test]
    fn test_matching_route_paths_no_match() {
        let matched = matching_route_paths("/nonexistent/path").unwrap();

        assert!(matched.is_empty());
    }

    /// A literal path selects exactly that route.
    #[test]
    fn test_matching_route_paths_quote_bolt11_only() {
        let matched = matching_route_paths("/v1/mint/quote/bolt11").unwrap();

        assert_eq!(matched.len(), 1);
        assert!(matched.contains(&RoutePath::MintQuoteBolt11));
    }

    /// A pattern that does not compile is reported as `Error::InvalidRegex`.
    #[test]
    fn test_matching_route_paths_invalid_regex() {
        let outcome = matching_route_paths("(unclosed parenthesis");

        assert!(matches!(outcome, Err(Error::InvalidRegex(_))));
    }

    /// `Display` prints the serde name of each variant.
    #[test]
    fn test_route_path_to_string() {
        let expectations = [
            (RoutePath::MintQuoteBolt11, "/v1/mint/quote/bolt11"),
            (RoutePath::MintBolt11, "/v1/mint/bolt11"),
            (RoutePath::MeltQuoteBolt11, "/v1/melt/quote/bolt11"),
            (RoutePath::MeltBolt11, "/v1/melt/bolt11"),
            (RoutePath::Swap, "/v1/swap"),
            (RoutePath::Checkstate, "/v1/checkstate"),
            (RoutePath::Restore, "/v1/restore"),
            (RoutePath::MintBlindAuth, "/v1/auth/blind/mint"),
        ];

        for (route, text) in expectations {
            assert_eq!(route.to_string(), text);
        }
    }

    /// Literal paths in the config map to one endpoint each.
    #[test]
    fn test_settings_deserialize_direct_paths() {
        let json = r#"{
 "openid_discovery": "https://example.com/.well-known/openid-configuration",
 "client_id": "client123",
 "protected_endpoints": [
 {
 "method": "GET",
 "path": "/v1/mint/bolt11"
 },
 {
 "method": "POST",
 "path": "/v1/swap"
 }
 ]
 }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.openid_discovery,
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(settings.client_id, "client123");
        assert_eq!(settings.protected_endpoints.len(), 2);

        let endpoints = &settings.protected_endpoints;
        assert!(endpoints.contains(&ProtectedEndpoint::new(
            Method::Get,
            RoutePath::MintBolt11
        )));
        assert!(endpoints.contains(&ProtectedEndpoint::new(Method::Post, RoutePath::Swap)));
    }

    /// A regex entry expands to every matching route, paired with its method.
    #[test]
    fn test_settings_deserialize_with_regex() {
        let json = r#"{
 "openid_discovery": "https://example.com/.well-known/openid-configuration",
 "client_id": "client123",
 "protected_endpoints": [
 {
 "method": "GET",
 "path": "^/v1/mint/.*"
 },
 {
 "method": "POST",
 "path": "/v1/swap"
 }
 ]
 }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.openid_discovery,
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(settings.client_id, "client123");
        // Two mint routes from the GET pattern plus the POST swap entry.
        assert_eq!(settings.protected_endpoints.len(), 3);

        let expected_protected: HashSet<ProtectedEndpoint> = [
            ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
        ]
        .into_iter()
        .collect();

        let deserialized_protected: HashSet<ProtectedEndpoint> =
            settings.protected_endpoints.into_iter().collect();

        assert_eq!(expected_protected, deserialized_protected);
    }

    /// An invalid pattern makes the whole document fail to deserialize.
    #[test]
    fn test_settings_deserialize_invalid_regex() {
        let json = r#"{
 "openid_discovery": "https://example.com/.well-known/openid-configuration",
 "client_id": "client123",
 "protected_endpoints": [
 {
 "method": "GET",
 "path": "(unclosed parenthesis"
 }
 ]
 }"#;

        assert!(serde_json::from_str::<Settings>(json).is_err());
    }

    /// A single literal path yields exactly one endpoint with that method and route.
    #[test]
    fn test_settings_deserialize_exact_path_match() {
        let json = r#"{
 "openid_discovery": "https://example.com/.well-known/openid-configuration",
 "client_id": "client123",
 "protected_endpoints": [
 {
 "method": "GET",
 "path": "/v1/mint/quote/bolt11"
 }
 ]
 }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.protected_endpoints,
            vec![ProtectedEndpoint::new(
                Method::Get,
                RoutePath::MintQuoteBolt11
            )]
        );
    }

    /// `.*` under one method expands to one endpoint per route.
    #[test]
    fn test_settings_deserialize_all_paths() {
        let json = r#"{
 "openid_discovery": "https://example.com/.well-known/openid-configuration",
 "client_id": "client123",
 "protected_endpoints": [
 {
 "method": "GET",
 "path": ".*"
 }
 ]
 }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        let route_count = RoutePath::iter().count();
        assert_eq!(settings.protected_endpoints.len(), route_count);
    }
}