Skip to main content

greentic_runner_host/http/
auth.rs

1use std::net::SocketAddr;
2
3use axum::Json;
4use axum::extract::connect_info::ConnectInfo;
5use axum::extract::{FromRef, FromRequestParts};
6use axum::http::StatusCode;
7use axum::http::header::AUTHORIZATION;
8use axum::http::request::Parts;
9use serde_json::json;
10
11use crate::runner::ServerState;
12
13#[derive(Clone, Default)]
14pub struct AdminAuth {
15    token: Option<String>,
16}
17
18impl AdminAuth {
19    pub fn new(token: Option<String>) -> Self {
20        Self {
21            token: token.filter(|v| !v.is_empty()),
22        }
23    }
24
25    fn authorize(&self, addr: SocketAddr, bearer: Option<&str>) -> Result<(), StatusCode> {
26        if let Some(expected) = &self.token {
27            let token = bearer.ok_or(StatusCode::UNAUTHORIZED)?;
28            if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
29                Ok(())
30            } else {
31                Err(StatusCode::UNAUTHORIZED)
32            }
33        } else if addr.ip().is_loopback() {
34            Ok(())
35        } else {
36            Err(StatusCode::FORBIDDEN)
37        }
38    }
39}
40
41pub struct AdminGuard;
42
43impl<S> FromRequestParts<S> for AdminGuard
44where
45    ServerState: FromRef<S>,
46    S: Send + Sync,
47{
48    type Rejection = (StatusCode, Json<serde_json::Value>);
49
50    fn from_request_parts(
51        parts: &mut Parts,
52        state: &S,
53    ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
54        let server_state = ServerState::from_ref(state);
55        let admin = server_state.admin.clone();
56        let addr = parts
57            .extensions
58            .get::<ConnectInfo<SocketAddr>>()
59            .map(|info| info.0);
60        let bearer = extract_bearer(parts);
61
62        async move {
63            let addr = addr.ok_or((
64                StatusCode::INTERNAL_SERVER_ERROR,
65                Json(json!({ "error": "connect info unavailable" })),
66            ))?;
67            admin.authorize(addr, bearer.as_deref()).map_err(|status| {
68                (
69                    status,
70                    Json(json!({
71                        "error": if status == StatusCode::UNAUTHORIZED {
72                            "admin token required"
73                        } else {
74                            "admin access restricted"
75                        }
76                    })),
77                )
78            })?;
79            Ok(AdminGuard)
80        }
81    }
82}
83
84fn extract_bearer(parts: &Parts) -> Option<String> {
85    let header = parts.headers.get(AUTHORIZATION)?.to_str().ok()?;
86    header
87        .strip_prefix("Bearer ")
88        .map(|value| value.trim().to_string())
89}
90
/// Returns `true` iff `a` and `b` hold the same bytes, without exiting early.
///
/// `a` is the untrusted candidate and `b` the secret. There is no early exit
/// on the first mismatching byte, nor on a length mismatch: the length
/// difference is folded into the accumulator instead. The loop always runs
/// `a.len()` times, which the caller controls, so neither the matching prefix
/// nor the secret's length is signalled by when the function returns.
///
/// This is a best-effort mitigation: nothing here formally prevents the
/// optimizer from reintroducing data-dependent branches.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    // Non-zero whenever the lengths differ.
    let mut diff = a.len() ^ b.len();
    // `max(1)` avoids a modulo by zero; an empty `b` then yields `None` -> 0.
    let wrap = b.len().max(1);
    for (i, &left) in a.iter().enumerate() {
        // When the lengths match this is exactly `b[i]`. Otherwise wrapping
        // keeps every iteration touching `b`, and `diff` is already non-zero.
        let right = b.get(i % wrap).copied().unwrap_or(0);
        diff |= usize::from(left ^ right);
    }
    diff == 0
}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a socket address literal used as the simulated peer.
    fn peer(addr: &str) -> SocketAddr {
        addr.parse().unwrap()
    }

    /// Policy with a fixed admin token, built without going through `new`.
    fn with_token() -> AdminAuth {
        AdminAuth {
            token: Some("secret".into()),
        }
    }

    #[test]
    fn loopback_without_token_is_allowed() {
        let auth = AdminAuth::new(None);
        assert!(auth.authorize(peer("127.0.0.1:0"), None).is_ok());
    }

    #[test]
    fn remote_without_token_is_forbidden() {
        let auth = AdminAuth::new(None);
        assert_eq!(
            auth.authorize(peer("10.0.0.1:0"), None),
            Err(StatusCode::FORBIDDEN)
        );
    }

    #[test]
    fn empty_configured_token_behaves_like_no_token() {
        let auth = AdminAuth::new(Some(String::new()));
        assert!(auth.authorize(peer("127.0.0.1:0"), None).is_ok());
        assert_eq!(
            auth.authorize(peer("10.0.0.1:0"), None),
            Err(StatusCode::FORBIDDEN)
        );
    }

    #[test]
    fn token_requires_bearer() {
        let auth = with_token();
        assert_eq!(
            auth.authorize(peer("127.0.0.1:0"), None),
            Err(StatusCode::UNAUTHORIZED)
        );
        assert!(auth.authorize(peer("127.0.0.1:0"), Some("secret")).is_ok());
    }

    #[test]
    fn wrong_or_partial_bearer_is_unauthorized() {
        let auth = with_token();
        for candidate in ["wrong!", "secre", "secrets", "Secret", ""] {
            assert_eq!(
                auth.authorize(peer("127.0.0.1:0"), Some(candidate)),
                Err(StatusCode::UNAUTHORIZED),
                "candidate {candidate:?} must be rejected"
            );
        }
    }

    #[test]
    fn token_admits_remote_peers() {
        let auth = with_token();
        assert!(auth.authorize(peer("10.0.0.1:0"), Some("secret")).is_ok());
        assert_eq!(
            auth.authorize(peer("10.0.0.1:0"), None),
            Err(StatusCode::UNAUTHORIZED)
        );
    }

    #[test]
    fn constant_time_eq_agrees_with_slice_equality() {
        let cases = [
            ("secret", "secret"),
            ("secreT", "secret"),
            ("sec", "secret"),
            ("secretsecret", "secret"),
            ("", ""),
            ("", "x"),
            ("x", ""),
        ];
        for (a, b) in cases {
            assert_eq!(constant_time_eq(a.as_bytes(), b.as_bytes()), a == b);
        }
    }
}