Skip to main content

by_loco_openapi/
auth.rs

1use std::sync::OnceLock;
2
3use utoipa::{
4    openapi::security::{ApiKey, ApiKeyValue, HttpAuthScheme, HttpBuilder, SecurityScheme},
5    Modify,
6};
7
8// Import Loco types for conversion
9use loco_rs::{app::AppContext, config::JWTLocation as LocoJWTLocation};
10
/// Where an authentication JWT is carried, mirrored from Loco's configuration
/// so the OpenAPI helpers in this module do not have to hold Loco types.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum JWTLocation {
    /// Token sent using the HTTP `Bearer` authentication scheme (the default).
    #[default]
    Bearer,
    /// Token sent in the query-string parameter with the given name.
    Query(String),
    /// Token sent in the cookie with the given name.
    Cookie(String),
}
19
20// Implement From trait for conversion from Loco type to our type
21impl From<&LocoJWTLocation> for JWTLocation {
22    fn from(loco_location: &LocoJWTLocation) -> Self {
23        match loco_location {
24            LocoJWTLocation::Bearer => Self::Bearer,
25            LocoJWTLocation::Query { name } => Self::Query(name.clone()),
26            LocoJWTLocation::Cookie { name } => Self::Cookie(name.clone()),
27        }
28    }
29}
30
31// Direct conversion from AppContext to JWTLocation for ease of use
32impl From<&AppContext> for JWTLocation {
33    fn from(ctx: &AppContext) -> Self {
34        ctx.config
35            .auth
36            .as_ref()
37            .and_then(|auth| auth.jwt.as_ref())
38            .and_then(|jwt| jwt.location.as_ref())
39            .map_or(Self::Bearer, std::convert::Into::into)
40    }
41}
42
43static JWT_LOCATION: OnceLock<Option<JWTLocation>> = OnceLock::new();
44
45// Main API for working with JWT location - independent from Loco
46pub fn set_jwt_location(jwt_location: JWTLocation) -> &'static Option<JWTLocation> {
47    JWT_LOCATION.get_or_init(|| Some(jwt_location))
48}
49
50pub fn get_jwt_location() -> Option<&'static JWTLocation> {
51    JWT_LOCATION.get().unwrap_or(&None).as_ref()
52}
53
54// Security implementation using our JWTLocation
55pub struct SecurityAddon;
56
57impl Modify for SecurityAddon {
58    fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
59        if let Some(jwt_location) = get_jwt_location() {
60            if let Some(components) = openapi.components.as_mut() {
61                components.add_security_schemes_from_iter([
62                    (
63                        "jwt_token",
64                        match jwt_location {
65                            JWTLocation::Bearer => SecurityScheme::Http(
66                                HttpBuilder::new()
67                                    .scheme(HttpAuthScheme::Bearer)
68                                    .bearer_format("JWT")
69                                    .build(),
70                            ),
71                            JWTLocation::Query(name) => {
72                                SecurityScheme::ApiKey(ApiKey::Query(ApiKeyValue::new(name)))
73                            }
74                            JWTLocation::Cookie(name) => {
75                                SecurityScheme::ApiKey(ApiKey::Cookie(ApiKeyValue::new(name)))
76                            }
77                        },
78                    ),
79                    (
80                        "api_key",
81                        SecurityScheme::ApiKey(ApiKey::Header(ApiKeyValue::new("apikey"))),
82                    ),
83                ]);
84            }
85        }
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_jwt_location() {
        let location = JWTLocation::default();
        assert_eq!(location, JWTLocation::Bearer);
    }

    #[test]
    fn test_set_get_jwt_location() {
        // `JWT_LOCATION` is a process-wide `OnceLock`: this is the only test in this
        // module that writes it. If other tests in the crate also call
        // `set_jwt_location`, whichever runs first determines the stored value.
        set_jwt_location(JWTLocation::Bearer);
        let stored = get_jwt_location();
        assert_eq!(stored, Some(&JWTLocation::Bearer));
    }

    #[test]
    fn test_from_loco_jwt_location() {
        // Each Loco variant paired with the value it must convert to.
        let cases = [
            (LocoJWTLocation::Bearer, JWTLocation::Bearer),
            (
                LocoJWTLocation::Query {
                    name: "token".to_string(),
                },
                JWTLocation::Query("token".to_string()),
            ),
            (
                LocoJWTLocation::Cookie {
                    name: "auth".to_string(),
                },
                JWTLocation::Cookie("auth".to_string()),
            ),
        ];
        for (loco, expected) in cases {
            assert_eq!(JWTLocation::from(&loco), expected);
        }
    }
}