Skip to main content

greentic_runner_host/http/
auth.rs

1use std::net::SocketAddr;
2
3use axum::Json;
4use axum::extract::connect_info::ConnectInfo;
5use axum::extract::{FromRef, FromRequestParts};
6use axum::http::StatusCode;
7use axum::http::header::AUTHORIZATION;
8use axum::http::request::Parts;
9use serde_json::json;
10
11use crate::runner::ServerState;
12
13#[derive(Clone, Default)]
14pub struct AdminAuth {
15    token: Option<String>,
16}
17
18impl AdminAuth {
19    pub fn new(token: Option<String>) -> Self {
20        Self {
21            token: token.filter(|v| !v.is_empty()),
22        }
23    }
24
25    fn authorize(&self, addr: SocketAddr, bearer: Option<&str>) -> Result<(), StatusCode> {
26        if let Some(expected) = &self.token {
27            let token = bearer.ok_or(StatusCode::UNAUTHORIZED)?;
28            if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
29                Ok(())
30            } else {
31                Err(StatusCode::UNAUTHORIZED)
32            }
33        } else if addr.ip().is_loopback() {
34            Ok(())
35        } else {
36            Err(StatusCode::FORBIDDEN)
37        }
38    }
39}
40
41pub struct AdminGuard;
42
43impl<S> FromRequestParts<S> for AdminGuard
44where
45    ServerState: FromRef<S>,
46    S: Send + Sync,
47{
48    type Rejection = (StatusCode, Json<serde_json::Value>);
49
50    fn from_request_parts(
51        parts: &mut Parts,
52        state: &S,
53    ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
54        let server_state = ServerState::from_ref(state);
55        let admin = server_state.admin.clone();
56        let addr = parts
57            .extensions
58            .get::<ConnectInfo<SocketAddr>>()
59            .map(|info| info.0);
60        let bearer = extract_bearer(parts);
61
62        async move {
63            let addr = addr.ok_or((
64                StatusCode::INTERNAL_SERVER_ERROR,
65                Json(json!({ "error": "connect info unavailable" })),
66            ))?;
67            admin.authorize(addr, bearer.as_deref()).map_err(|status| {
68                (
69                    status,
70                    Json(json!({
71                        "error": if status == StatusCode::UNAUTHORIZED {
72                            "admin token required"
73                        } else {
74                            "admin access restricted"
75                        }
76                    })),
77                )
78            })?;
79            Ok(AdminGuard)
80        }
81    }
82}
83
84fn extract_bearer(parts: &Parts) -> Option<String> {
85    let header = parts.headers.get(AUTHORIZATION)?.to_str().ok()?;
86    let (scheme, value) = header.split_once(' ')?;
87    if !scheme.eq_ignore_ascii_case("Bearer") {
88        return None;
89    }
90    Some(value.trim().to_string())
91}
92
/// Compares two byte strings without short-circuiting on the first mismatch.
///
/// Inputs of different length are rejected immediately. Response timing can
/// therefore reveal whether the lengths match, but the comparison of byte
/// contents always visits every byte.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    // OR together the XOR of every byte pair; any difference leaves a set bit.
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (&x, &y)| acc | (x ^ y)) == 0
}
103
#[cfg(test)]
mod tests {
    use super::*;
    use axum::extract::FromRef;
    use axum::extract::connect_info::ConnectInfo;
    use axum::http::Request;
    use std::sync::Arc;

    use crate::http::health::HealthState;
    use crate::routing::{RoutingConfig, TenantRouting};
    use crate::runner::ServerState;
    use crate::runtime::ActivePacks;

    /// Minimal router state from which the guard can derive a `ServerState`.
    #[derive(Clone)]
    struct AppState {
        server: ServerState,
    }

    impl FromRef<AppState> for ServerState {
        fn from_ref(input: &AppState) -> Self {
            input.server.clone()
        }
    }

    /// Builds an `AppState` whose only interesting component is the admin policy.
    fn server_state(admin: AdminAuth) -> AppState {
        let server = ServerState {
            active: Arc::new(ActivePacks::new()),
            routing: TenantRouting::new(RoutingConfig::default()),
            health: Arc::new(HealthState::new()),
            reload: None,
            admin,
        };
        AppState { server }
    }

    /// Parses a socket-address literal used as the simulated peer.
    fn peer(addr: &str) -> std::net::SocketAddr {
        addr.parse().unwrap()
    }

    /// Builds request parts, optionally carrying an `Authorization` header.
    fn request_parts(authorization: Option<&str>) -> Parts {
        let mut builder = Request::builder();
        if let Some(value) = authorization {
            builder = builder.header(AUTHORIZATION, value);
        }
        let (parts, _) = builder.body(()).expect("request").into_parts();
        parts
    }

    #[test]
    fn loopback_without_token_is_allowed() {
        let outcome = AdminAuth::new(None).authorize(peer("127.0.0.1:0"), None);
        assert!(outcome.is_ok());
    }

    #[test]
    fn remote_without_token_is_forbidden() {
        let outcome = AdminAuth::new(None).authorize(peer("10.0.0.1:0"), None);
        assert_eq!(outcome, Err(StatusCode::FORBIDDEN));
    }

    #[test]
    fn token_requires_bearer() {
        let auth = AdminAuth {
            token: Some("secret".into()),
        };
        let local = peer("127.0.0.1:0");

        assert_eq!(auth.authorize(local, None), Err(StatusCode::UNAUTHORIZED));
        assert!(auth.authorize(local, Some("secret")).is_ok());
    }

    #[test]
    fn bearer_scheme_is_case_insensitive() {
        let parts = request_parts(Some("bearer secret"));
        assert_eq!(extract_bearer(&parts).as_deref(), Some("secret"));
    }

    #[test]
    fn non_bearer_authorization_header_is_rejected() {
        let parts = request_parts(Some("Basic dXNlcjpzZWNyZXQ="));
        assert_eq!(extract_bearer(&parts), None);
    }

    #[test]
    fn empty_admin_token_is_treated_as_disabled() {
        let auth = AdminAuth::new(Some(String::new()));
        assert!(auth.authorize(peer("127.0.0.1:0"), None).is_ok());
    }

    #[test]
    fn wrong_bearer_token_is_rejected() {
        let auth = AdminAuth::new(Some("secret".into()));
        let outcome = auth.authorize(peer("127.0.0.1:0"), Some("wrong"));
        assert_eq!(outcome, Err(StatusCode::UNAUTHORIZED));
    }

    #[test]
    fn constant_time_eq_rejects_length_mismatch() {
        assert!(!constant_time_eq(b"short", b"longer"));
    }

    #[test]
    fn malformed_authorization_header_is_rejected() {
        let parts = request_parts(Some("Bearer"));
        assert_eq!(extract_bearer(&parts), None);
    }

    #[tokio::test]
    async fn admin_guard_rejects_missing_connect_info() {
        let mut parts = request_parts(None);
        let state = server_state(AdminAuth::default());

        let outcome = AdminGuard::from_request_parts(&mut parts, &state).await;
        let Err((status, Json(body))) = outcome else {
            panic!("missing connect info should reject");
        };

        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body["error"], "connect info unavailable");
    }

    #[tokio::test]
    async fn admin_guard_rejects_wrong_remote_token() {
        let mut parts = request_parts(Some("Bearer wrong"));
        parts.extensions.insert(ConnectInfo(peer("10.0.0.2:8080")));
        let state = server_state(AdminAuth::new(Some("secret".into())));

        let outcome = AdminGuard::from_request_parts(&mut parts, &state).await;
        let Err((status, Json(body))) = outcome else {
            panic!("wrong token should reject");
        };

        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body["error"], "admin token required");
    }

    #[tokio::test]
    async fn admin_guard_allows_loopback_without_token_when_disabled() {
        let mut parts = request_parts(None);
        parts.extensions.insert(ConnectInfo(peer("127.0.0.1:8080")));
        let state = server_state(AdminAuth::default());

        AdminGuard::from_request_parts(&mut parts, &state)
            .await
            .expect("loopback should pass without token");
    }
}