1use std::collections::HashSet;
4use std::str::FromStr;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use strum::IntoEnumIterator;
9use strum_macros::EnumIter;
10use thiserror::Error;
11
12#[derive(Debug, Error)]
14pub enum Error {
15 #[error("Invalid regex pattern: {0}")]
17 InvalidRegex(#[from] regex::Error),
18}
19
20#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
22#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
23pub struct Settings {
24 pub openid_discovery: String,
26 pub client_id: String,
28 pub protected_endpoints: Vec<ProtectedEndpoint>,
30}
31
32impl Settings {
33 pub fn new(
35 openid_discovery: String,
36 client_id: String,
37 protected_endpoints: Vec<ProtectedEndpoint>,
38 ) -> Self {
39 Self {
40 openid_discovery,
41 client_id,
42 protected_endpoints,
43 }
44 }
45}
46
47impl<'de> Deserialize<'de> for Settings {
49 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
50 where
51 D: serde::Deserializer<'de>,
52 {
53 #[derive(Deserialize)]
55 struct RawSettings {
56 openid_discovery: String,
57 client_id: String,
58 protected_endpoints: Vec<RawProtectedEndpoint>,
59 }
60
61 #[derive(Deserialize)]
62 struct RawProtectedEndpoint {
63 method: Method,
64 path: String,
65 }
66
67 let raw = RawSettings::deserialize(deserializer)?;
69
70 let mut protected_endpoints = HashSet::new();
72
73 for raw_endpoint in raw.protected_endpoints {
74 let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
75 serde::de::Error::custom(format!(
76 "Invalid regex pattern '{}': {}",
77 raw_endpoint.path, e
78 ))
79 })?;
80
81 for path in expanded_paths {
82 protected_endpoints.insert(ProtectedEndpoint::new(raw_endpoint.method, path));
83 }
84 }
85
86 Ok(Settings {
88 openid_discovery: raw.openid_discovery,
89 client_id: raw.client_id,
90 protected_endpoints: protected_endpoints.into_iter().collect(),
91 })
92 }
93}
94
95#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
97#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
98pub struct ProtectedEndpoint {
99 pub method: Method,
101 pub path: RoutePath,
103}
104
105impl ProtectedEndpoint {
106 pub fn new(method: Method, path: RoutePath) -> Self {
108 Self { method, path }
109 }
110}
111
112#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
114#[serde(rename_all = "UPPERCASE")]
115#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
116pub enum Method {
117 Get,
119 Post,
121}
122
123#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, EnumIter)]
125#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
126#[serde(rename_all = "snake_case")]
127pub enum RoutePath {
128 #[serde(rename = "/v1/mint/quote/bolt11")]
130 MintQuoteBolt11,
131 #[serde(rename = "/v1/mint/bolt11")]
133 MintBolt11,
134 #[serde(rename = "/v1/melt/quote/bolt11")]
136 MeltQuoteBolt11,
137 #[serde(rename = "/v1/melt/bolt11")]
139 MeltBolt11,
140 #[serde(rename = "/v1/swap")]
142 Swap,
143 #[serde(rename = "/v1/checkstate")]
145 Checkstate,
146 #[serde(rename = "/v1/restore")]
148 Restore,
149 #[serde(rename = "/v1/auth/blind/mint")]
151 MintBlindAuth,
152}
153
154pub fn matching_route_paths(pattern: &str) -> Result<Vec<RoutePath>, Error> {
155 let regex = Regex::from_str(pattern)?;
156
157 Ok(RoutePath::iter()
158 .filter(|path| regex.is_match(&path.to_string()))
159 .collect())
160}
161
162impl std::fmt::Display for RoutePath {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 let json_str = match serde_json::to_string(self) {
166 Ok(s) => s,
167 Err(_) => return write!(f, "<error>"),
168 };
169 let path = json_str.trim_matches('"');
171 write!(f, "{}", path)
172 }
173}
174
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_matching_route_paths_all() {
        let paths = matching_route_paths(".*").unwrap();

        assert_eq!(paths.len(), RoutePath::iter().count());

        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MintBolt11));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt11));
        assert!(paths.contains(&RoutePath::MeltBolt11));
        assert!(paths.contains(&RoutePath::Swap));
        assert!(paths.contains(&RoutePath::Checkstate));
        assert!(paths.contains(&RoutePath::Restore));
        assert!(paths.contains(&RoutePath::MintBlindAuth));
    }

    #[test]
    fn test_matching_route_paths_mint_only() {
        let paths = matching_route_paths("^/v1/mint/.*").unwrap();

        assert_eq!(paths.len(), 2);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MintBolt11));

        assert!(!paths.contains(&RoutePath::MeltQuoteBolt11));
        assert!(!paths.contains(&RoutePath::MeltBolt11));
        assert!(!paths.contains(&RoutePath::Swap));
    }

    #[test]
    fn test_matching_route_paths_quote_only() {
        let paths = matching_route_paths(".*/quote/.*").unwrap();

        assert_eq!(paths.len(), 2);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
        assert!(paths.contains(&RoutePath::MeltQuoteBolt11));

        assert!(!paths.contains(&RoutePath::MintBolt11));
        assert!(!paths.contains(&RoutePath::MeltBolt11));
    }

    #[test]
    fn test_matching_route_paths_no_match() {
        let paths = matching_route_paths("/nonexistent/path").unwrap();

        assert!(paths.is_empty());
    }

    #[test]
    fn test_matching_route_paths_quote_bolt11_only() {
        let paths = matching_route_paths("/v1/mint/quote/bolt11").unwrap();

        assert_eq!(paths.len(), 1);
        assert!(paths.contains(&RoutePath::MintQuoteBolt11));
    }

    #[test]
    fn test_matching_route_paths_invalid_regex() {
        let result = matching_route_paths("(unclosed parenthesis");

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::InvalidRegex(_)));
    }

    #[test]
    fn test_route_path_to_string() {
        assert_eq!(
            RoutePath::MintQuoteBolt11.to_string(),
            "/v1/mint/quote/bolt11"
        );
        assert_eq!(RoutePath::MintBolt11.to_string(), "/v1/mint/bolt11");
        assert_eq!(
            RoutePath::MeltQuoteBolt11.to_string(),
            "/v1/melt/quote/bolt11"
        );
        assert_eq!(RoutePath::MeltBolt11.to_string(), "/v1/melt/bolt11");
        assert_eq!(RoutePath::Swap.to_string(), "/v1/swap");
        assert_eq!(RoutePath::Checkstate.to_string(), "/v1/checkstate");
        assert_eq!(RoutePath::Restore.to_string(), "/v1/restore");
        assert_eq!(RoutePath::MintBlindAuth.to_string(), "/v1/auth/blind/mint");
    }

    #[test]
    fn test_settings_deserialize_direct_paths() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/bolt11"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.openid_discovery,
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(settings.client_id, "client123");
        assert_eq!(settings.protected_endpoints.len(), 2);

        let paths = settings
            .protected_endpoints
            .iter()
            .map(|ep| (ep.method, ep.path))
            .collect::<Vec<_>>();
        assert!(paths.contains(&(Method::Get, RoutePath::MintBolt11)));
        assert!(paths.contains(&(Method::Post, RoutePath::Swap)));
    }

    #[test]
    fn test_settings_deserialize_with_regex() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "^/v1/mint/.*"
                },
                {
                    "method": "POST",
                    "path": "/v1/swap"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        assert_eq!(
            settings.openid_discovery,
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(settings.client_id, "client123");
        assert_eq!(settings.protected_endpoints.len(), 3);

        let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
            ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
        ]);

        let deserialized_protected: HashSet<ProtectedEndpoint> =
            settings.protected_endpoints.into_iter().collect();

        assert_eq!(expected_protected, deserialized_protected);
    }

    #[test]
    fn test_settings_deserialize_invalid_regex() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "(unclosed parenthesis"
                }
            ]
        }"#;

        let result = serde_json::from_str::<Settings>(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_settings_deserialize_exact_path_match() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "/v1/mint/quote/bolt11"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(settings.protected_endpoints.len(), 1);
        assert_eq!(settings.protected_endpoints[0].method, Method::Get);
        assert_eq!(
            settings.protected_endpoints[0].path,
            RoutePath::MintQuoteBolt11
        );
    }

    #[test]
    fn test_settings_deserialize_all_paths() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": ".*"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();
        assert_eq!(
            settings.protected_endpoints.len(),
            RoutePath::iter().count()
        );
    }

    /// Overlapping patterns for the same method must not yield duplicate
    /// endpoints, while the same path under different methods stays distinct.
    #[test]
    fn test_settings_deserialize_overlapping_patterns_are_deduplicated() {
        let json = r#"{
            "openid_discovery": "https://example.com/.well-known/openid-configuration",
            "client_id": "client123",
            "protected_endpoints": [
                {
                    "method": "GET",
                    "path": "^/v1/mint/.*"
                },
                {
                    "method": "GET",
                    "path": "/v1/mint/bolt11"
                },
                {
                    "method": "POST",
                    "path": "/v1/mint/bolt11"
                }
            ]
        }"#;

        let settings: Settings = serde_json::from_str(json).unwrap();

        let expected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
            ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
            ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
            ProtectedEndpoint::new(Method::Post, RoutePath::MintBolt11),
        ]);

        assert_eq!(settings.protected_endpoints.len(), expected.len());
        let actual: HashSet<ProtectedEndpoint> =
            settings.protected_endpoints.into_iter().collect();
        assert_eq!(actual, expected);
    }

    /// Serialized settings use concrete paths, which deserialize back to the
    /// same set of endpoints.
    #[test]
    fn test_settings_serialize_roundtrip() {
        let original = Settings::new(
            "https://example.com/.well-known/openid-configuration".to_string(),
            "client123".to_string(),
            vec![
                ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
                ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
            ],
        );

        let json = serde_json::to_string(&original).unwrap();
        let decoded: Settings = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.openid_discovery, original.openid_discovery);
        assert_eq!(decoded.client_id, original.client_id);

        let decoded_set: HashSet<ProtectedEndpoint> =
            decoded.protected_endpoints.into_iter().collect();
        let original_set: HashSet<ProtectedEndpoint> =
            original.protected_endpoints.into_iter().collect();
        assert_eq!(decoded_set, original_set);
    }
}