greentic_runner_host/http/
auth.rs1use std::net::SocketAddr;
2
3use axum::Json;
4use axum::extract::connect_info::ConnectInfo;
5use axum::extract::{FromRef, FromRequestParts};
6use axum::http::StatusCode;
7use axum::http::header::AUTHORIZATION;
8use axum::http::request::Parts;
9use serde_json::json;
10
11use crate::runner::ServerState;
12
13#[derive(Clone, Default)]
14pub struct AdminAuth {
15 token: Option<String>,
16}
17
18impl AdminAuth {
19 pub fn from_env() -> Self {
20 let token = std::env::var("ADMIN_TOKEN")
21 .ok()
22 .map(|value| value.trim().to_string())
23 .filter(|value| !value.is_empty());
24 Self { token }
25 }
26
27 fn authorize(&self, addr: SocketAddr, bearer: Option<&str>) -> Result<(), StatusCode> {
28 if let Some(expected) = &self.token {
29 let token = bearer.ok_or(StatusCode::UNAUTHORIZED)?;
30 if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
31 Ok(())
32 } else {
33 Err(StatusCode::UNAUTHORIZED)
34 }
35 } else if addr.ip().is_loopback() {
36 Ok(())
37 } else {
38 Err(StatusCode::FORBIDDEN)
39 }
40 }
41}
42
43pub struct AdminGuard;
44
45impl<S> FromRequestParts<S> for AdminGuard
46where
47 ServerState: FromRef<S>,
48 S: Send + Sync,
49{
50 type Rejection = (StatusCode, Json<serde_json::Value>);
51
52 fn from_request_parts(
53 parts: &mut Parts,
54 state: &S,
55 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
56 let server_state = ServerState::from_ref(state);
57 let admin = server_state.admin.clone();
58 let addr = parts
59 .extensions
60 .get::<ConnectInfo<SocketAddr>>()
61 .map(|info| info.0);
62 let bearer = extract_bearer(parts);
63
64 async move {
65 let addr = addr.ok_or((
66 StatusCode::INTERNAL_SERVER_ERROR,
67 Json(json!({ "error": "connect info unavailable" })),
68 ))?;
69 admin.authorize(addr, bearer.as_deref()).map_err(|status| {
70 (
71 status,
72 Json(json!({
73 "error": if status == StatusCode::UNAUTHORIZED {
74 "admin token required"
75 } else {
76 "admin access restricted"
77 }
78 })),
79 )
80 })?;
81 Ok(AdminGuard)
82 }
83 }
84}
85
86fn extract_bearer(parts: &Parts) -> Option<String> {
87 let header = parts.headers.get(AUTHORIZATION)?.to_str().ok()?;
88 header
89 .strip_prefix("Bearer ")
90 .map(|value| value.trim().to_string())
91}
92
/// Compares two byte strings without short-circuiting on the first mismatch.
///
/// The loop always walks the full length of `b`, which is the configured secret at the
/// call site in `AdminAuth::authorize`. As a result, the running time does not reveal:
/// - how long a matching prefix of `a` is;
/// - whether `a` has the same length as the secret (an early return on a length
///   mismatch would reveal that).
///
/// This is a best-effort guard. The optimizer gives no formal constant-time guarantee.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    // Seed the accumulator with the length difference so inputs of unequal length can
    // never compare equal, even when `a` is a zero-padded prefix of `b`.
    let mut diff = a.len() ^ b.len();
    for (index, &right) in b.iter().enumerate() {
        let left = a.get(index).copied().unwrap_or(0);
        diff |= usize::from(left ^ right);
    }
    diff == 0
}
103
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn loopback_without_token_is_allowed() {
        // Build the policy directly rather than via `from_env`, so an `ADMIN_TOKEN`
        // present in the test environment cannot change the outcome.
        let auth = AdminAuth::default();
        assert!(auth.authorize("127.0.0.1:0".parse().unwrap(), None).is_ok());
        assert!(auth.authorize("[::1]:0".parse().unwrap(), None).is_ok());
    }

    #[test]
    fn remote_without_token_is_forbidden() {
        let auth = AdminAuth::default();
        assert_eq!(
            auth.authorize("10.0.0.1:0".parse().unwrap(), None),
            Err(StatusCode::FORBIDDEN)
        );
    }

    #[test]
    fn token_requires_bearer() {
        let auth = AdminAuth {
            token: Some("secret".into()),
        };
        assert_eq!(
            auth.authorize("127.0.0.1:0".parse().unwrap(), None),
            Err(StatusCode::UNAUTHORIZED)
        );
        assert!(
            auth.authorize("127.0.0.1:0".parse().unwrap(), Some("secret"))
                .is_ok()
        );
    }

    #[test]
    fn wrong_token_is_rejected() {
        let auth = AdminAuth {
            token: Some("secret".into()),
        };
        assert_eq!(
            auth.authorize("127.0.0.1:0".parse().unwrap(), Some("secreT")),
            Err(StatusCode::UNAUTHORIZED)
        );
        assert_eq!(
            auth.authorize("127.0.0.1:0".parse().unwrap(), Some("secret2")),
            Err(StatusCode::UNAUTHORIZED)
        );
    }

    #[test]
    fn token_admits_remote_peers() {
        let auth = AdminAuth {
            token: Some("secret".into()),
        };
        assert!(
            auth.authorize("10.0.0.1:0".parse().unwrap(), Some("secret"))
                .is_ok()
        );
    }

    #[test]
    fn constant_time_eq_compares_contents_and_length() {
        assert!(constant_time_eq(b"", b""));
        assert!(constant_time_eq(b"secret", b"secret"));
        assert!(!constant_time_eq(b"secret", b"secreT"));
        assert!(!constant_time_eq(b"secret", b"secrets"));
        assert!(!constant_time_eq(b"", b"x"));
    }
}
137}