1use std::sync::OnceLock;
2
3use utoipa::{
4 openapi::security::{ApiKey, ApiKeyValue, HttpAuthScheme, HttpBuilder, SecurityScheme},
5 Modify,
6};
7
8use loco_rs::{app::AppContext, config::JWTLocation as LocoJWTLocation};
10
/// Where clients are expected to send the JWT, as advertised in the generated OpenAPI
/// security scheme.
///
/// This is a local counterpart of `loco_rs::config::JWTLocation`, with the named variants
/// flattened into tuple variants that carry just the parameter/cookie name.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum JWTLocation {
    /// HTTP bearer authentication; this is the default location.
    #[default]
    Bearer,
    /// Token passed in the query parameter with the given name.
    Query(String),
    /// Token passed in the cookie with the given name.
    Cookie(String),
}
19
20impl From<&LocoJWTLocation> for JWTLocation {
22 fn from(loco_location: &LocoJWTLocation) -> Self {
23 match loco_location {
24 LocoJWTLocation::Bearer => Self::Bearer,
25 LocoJWTLocation::Query { name } => Self::Query(name.clone()),
26 LocoJWTLocation::Cookie { name } => Self::Cookie(name.clone()),
27 }
28 }
29}
30
31impl From<&AppContext> for JWTLocation {
33 fn from(ctx: &AppContext) -> Self {
34 ctx.config
35 .auth
36 .as_ref()
37 .and_then(|auth| auth.jwt.as_ref())
38 .and_then(|jwt| jwt.location.as_ref())
39 .map_or(Self::Bearer, std::convert::Into::into)
40 }
41}
42
43static JWT_LOCATION: OnceLock<Option<JWTLocation>> = OnceLock::new();
44
45pub fn set_jwt_location(jwt_location: JWTLocation) -> &'static Option<JWTLocation> {
47 JWT_LOCATION.get_or_init(|| Some(jwt_location))
48}
49
50pub fn get_jwt_location() -> Option<&'static JWTLocation> {
51 JWT_LOCATION.get().unwrap_or(&None).as_ref()
52}
53
54pub struct SecurityAddon;
56
57impl Modify for SecurityAddon {
58 fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
59 if let Some(jwt_location) = get_jwt_location() {
60 if let Some(components) = openapi.components.as_mut() {
61 components.add_security_schemes_from_iter([
62 (
63 "jwt_token",
64 match jwt_location {
65 JWTLocation::Bearer => SecurityScheme::Http(
66 HttpBuilder::new()
67 .scheme(HttpAuthScheme::Bearer)
68 .bearer_format("JWT")
69 .build(),
70 ),
71 JWTLocation::Query(name) => {
72 SecurityScheme::ApiKey(ApiKey::Query(ApiKeyValue::new(name)))
73 }
74 JWTLocation::Cookie(name) => {
75 SecurityScheme::ApiKey(ApiKey::Cookie(ApiKeyValue::new(name)))
76 }
77 },
78 ),
79 (
80 "api_key",
81 SecurityScheme::ApiKey(ApiKey::Header(ApiKeyValue::new("apikey"))),
82 ),
83 ]);
84 }
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_jwt_location() {
        let fallback = JWTLocation::default();
        assert_eq!(fallback, JWTLocation::Bearer);
    }

    #[test]
    fn test_set_get_jwt_location() {
        // `JWT_LOCATION` is a process-wide `OnceLock`. This is the only test in this module
        // that writes it, so the value read back is the one set here.
        set_jwt_location(JWTLocation::Bearer);
        let current = get_jwt_location();
        assert_eq!(current, Some(&JWTLocation::Bearer));
    }

    #[test]
    fn test_from_loco_jwt_location() {
        // Each Loco variant paired with the local variant it should convert into.
        let cases = [
            (LocoJWTLocation::Bearer, JWTLocation::Bearer),
            (
                LocoJWTLocation::Query {
                    name: "token".to_string(),
                },
                JWTLocation::Query("token".to_string()),
            ),
            (
                LocoJWTLocation::Cookie {
                    name: "auth".to_string(),
                },
                JWTLocation::Cookie("auth".to_string()),
            ),
        ];

        for (loco, expected) in &cases {
            assert_eq!(&JWTLocation::from(loco), expected);
        }
    }
}