greentic_runner_host/http/
auth.rs

1use std::net::SocketAddr;
2
3use axum::Json;
4use axum::extract::connect_info::ConnectInfo;
5use axum::extract::{FromRef, FromRequestParts};
6use axum::http::StatusCode;
7use axum::http::header::AUTHORIZATION;
8use axum::http::request::Parts;
9use serde_json::json;
10
11use crate::runner::ServerState;
12
13#[derive(Clone, Default)]
14pub struct AdminAuth {
15    token: Option<String>,
16}
17
18impl AdminAuth {
19    pub fn from_env() -> Self {
20        let token = std::env::var("ADMIN_TOKEN")
21            .ok()
22            .map(|value| value.trim().to_string())
23            .filter(|value| !value.is_empty());
24        Self { token }
25    }
26
27    fn authorize(&self, addr: SocketAddr, bearer: Option<&str>) -> Result<(), StatusCode> {
28        if let Some(expected) = &self.token {
29            let token = bearer.ok_or(StatusCode::UNAUTHORIZED)?;
30            if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
31                Ok(())
32            } else {
33                Err(StatusCode::UNAUTHORIZED)
34            }
35        } else if addr.ip().is_loopback() {
36            Ok(())
37        } else {
38            Err(StatusCode::FORBIDDEN)
39        }
40    }
41}
42
43pub struct AdminGuard;
44
45impl<S> FromRequestParts<S> for AdminGuard
46where
47    ServerState: FromRef<S>,
48    S: Send + Sync,
49{
50    type Rejection = (StatusCode, Json<serde_json::Value>);
51
52    fn from_request_parts(
53        parts: &mut Parts,
54        state: &S,
55    ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
56        let server_state = ServerState::from_ref(state);
57        let admin = server_state.admin.clone();
58        let addr = parts
59            .extensions
60            .get::<ConnectInfo<SocketAddr>>()
61            .map(|info| info.0);
62        let bearer = extract_bearer(parts);
63
64        async move {
65            let addr = addr.ok_or((
66                StatusCode::INTERNAL_SERVER_ERROR,
67                Json(json!({ "error": "connect info unavailable" })),
68            ))?;
69            admin.authorize(addr, bearer.as_deref()).map_err(|status| {
70                (
71                    status,
72                    Json(json!({
73                        "error": if status == StatusCode::UNAUTHORIZED {
74                            "admin token required"
75                        } else {
76                            "admin access restricted"
77                        }
78                    })),
79                )
80            })?;
81            Ok(AdminGuard)
82        }
83    }
84}
85
86fn extract_bearer(parts: &Parts) -> Option<String> {
87    let header = parts.headers.get(AUTHORIZATION)?.to_str().ok()?;
88    header
89        .strip_prefix("Bearer ")
90        .map(|value| value.trim().to_string())
91}
92
/// Compares two byte slices for equality without stopping at the first
/// differing byte.
///
/// Slices of different lengths are rejected up front. Equal-length slices are
/// always scanned in full, with every byte pair folded into one accumulator.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        // OR together the XOR of every byte pair; any mismatch leaves a bit set.
        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        diff == 0
    } else {
        false
    }
}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a socket-address literal used in these tests.
    fn addr(literal: &str) -> SocketAddr {
        literal
            .parse()
            .expect("test socket address literal must parse")
    }

    #[test]
    fn loopback_without_token_is_allowed() {
        // `default()` rather than `from_env()`: the outcome must not depend on whether
        // `ADMIN_TOKEN` happens to be set in the environment running the tests.
        let auth = AdminAuth::default();
        assert!(auth.authorize(addr("127.0.0.1:0"), None).is_ok());
    }

    #[test]
    fn remote_without_token_is_forbidden() {
        let auth = AdminAuth::default();
        assert_eq!(
            auth.authorize(addr("10.0.0.1:0"), None),
            Err(StatusCode::FORBIDDEN)
        );
    }

    #[test]
    fn token_requires_bearer() {
        let auth = AdminAuth {
            token: Some("secret".into()),
        };
        assert_eq!(
            auth.authorize(addr("127.0.0.1:0"), None),
            Err(StatusCode::UNAUTHORIZED)
        );
        assert!(auth.authorize(addr("127.0.0.1:0"), Some("secret")).is_ok());
    }

    #[test]
    fn wrong_token_is_unauthorized() {
        let auth = AdminAuth {
            token: Some("secret".into()),
        };
        assert_eq!(
            auth.authorize(addr("127.0.0.1:0"), Some("wrong")),
            Err(StatusCode::UNAUTHORIZED)
        );
        // A prefix of the real token must not be accepted.
        assert_eq!(
            auth.authorize(addr("127.0.0.1:0"), Some("secre")),
            Err(StatusCode::UNAUTHORIZED)
        );
    }

    #[test]
    fn token_admits_remote_callers() {
        let auth = AdminAuth {
            token: Some("secret".into()),
        };
        assert!(auth.authorize(addr("10.0.0.1:0"), Some("secret")).is_ok());
    }

    #[test]
    fn constant_time_eq_matches_only_identical_bytes() {
        assert!(constant_time_eq(b"secret", b"secret"));
        assert!(!constant_time_eq(b"secret", b"secreT"));
        assert!(!constant_time_eq(b"secret", b"secrets"));
        assert!(constant_time_eq(b"", b""));
    }
}
137}