greentic_runner_host/http/
auth.rs1use std::net::SocketAddr;
2
3use axum::Json;
4use axum::extract::connect_info::ConnectInfo;
5use axum::extract::{FromRef, FromRequestParts};
6use axum::http::StatusCode;
7use axum::http::header::AUTHORIZATION;
8use axum::http::request::Parts;
9use serde_json::json;
10
11use crate::runner::ServerState;
12
13#[derive(Clone, Default)]
14pub struct AdminAuth {
15 token: Option<String>,
16}
17
18impl AdminAuth {
19 pub fn new(token: Option<String>) -> Self {
20 Self {
21 token: token.filter(|v| !v.is_empty()),
22 }
23 }
24
25 fn authorize(&self, addr: SocketAddr, bearer: Option<&str>) -> Result<(), StatusCode> {
26 if let Some(expected) = &self.token {
27 let token = bearer.ok_or(StatusCode::UNAUTHORIZED)?;
28 if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
29 Ok(())
30 } else {
31 Err(StatusCode::UNAUTHORIZED)
32 }
33 } else if addr.ip().is_loopback() {
34 Ok(())
35 } else {
36 Err(StatusCode::FORBIDDEN)
37 }
38 }
39}
40
41pub struct AdminGuard;
42
43impl<S> FromRequestParts<S> for AdminGuard
44where
45 ServerState: FromRef<S>,
46 S: Send + Sync,
47{
48 type Rejection = (StatusCode, Json<serde_json::Value>);
49
50 fn from_request_parts(
51 parts: &mut Parts,
52 state: &S,
53 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
54 let server_state = ServerState::from_ref(state);
55 let admin = server_state.admin.clone();
56 let addr = parts
57 .extensions
58 .get::<ConnectInfo<SocketAddr>>()
59 .map(|info| info.0);
60 let bearer = extract_bearer(parts);
61
62 async move {
63 let addr = addr.ok_or((
64 StatusCode::INTERNAL_SERVER_ERROR,
65 Json(json!({ "error": "connect info unavailable" })),
66 ))?;
67 admin.authorize(addr, bearer.as_deref()).map_err(|status| {
68 (
69 status,
70 Json(json!({
71 "error": if status == StatusCode::UNAUTHORIZED {
72 "admin token required"
73 } else {
74 "admin access restricted"
75 }
76 })),
77 )
78 })?;
79 Ok(AdminGuard)
80 }
81 }
82}
83
84fn extract_bearer(parts: &Parts) -> Option<String> {
85 let header = parts.headers.get(AUTHORIZATION)?.to_str().ok()?;
86 let (scheme, value) = header.split_once(' ')?;
87 if !scheme.eq_ignore_ascii_case("Bearer") {
88 return None;
89 }
90 Some(value.trim().to_string())
91}
92
/// Compares two byte strings without branching on their contents.
///
/// Only the lengths are checked up front. Timing may therefore reveal whether the lengths
/// match, but the byte loop itself never exits early at the first difference.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        // OR together the XOR of every byte pair; any mismatch leaves at least one bit set.
        let accumulated = a
            .iter()
            .zip(b)
            .fold(0u8, |bits, (&x, &y)| bits | (x ^ y));
        accumulated == 0
    } else {
        false
    }
}
103
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use axum::http::Request;

    use super::*;
    use crate::http::health::HealthState;
    use crate::routing::{RoutingConfig, TenantRouting};
    use crate::runtime::ActivePacks;

    /// Loopback address used by the policy-level tests.
    const LOOPBACK: &str = "127.0.0.1:0";

    /// Router-style state wrapper so `AdminGuard` can reach `ServerState` through `FromRef`.
    #[derive(Clone)]
    struct AppState {
        server: ServerState,
    }

    impl FromRef<AppState> for ServerState {
        fn from_ref(input: &AppState) -> Self {
            input.server.clone()
        }
    }

    /// Builds test state with default routing/health and the given admin policy.
    fn server_state(admin: AdminAuth) -> AppState {
        let server = ServerState {
            active: Arc::new(ActivePacks::new()),
            routing: TenantRouting::new(RoutingConfig::default()),
            health: Arc::new(HealthState::new()),
            reload: None,
            admin,
        };
        AppState { server }
    }

    /// Parses a socket-address literal.
    fn sock(text: &str) -> SocketAddr {
        text.parse().unwrap()
    }

    /// Request parts carrying the given `Authorization` header (if any) and no extensions.
    fn parts_with(authorization: Option<&str>) -> Parts {
        let mut builder = Request::builder();
        if let Some(value) = authorization {
            builder = builder.header(AUTHORIZATION, value);
        }
        let (parts, ()) = builder.body(()).expect("request").into_parts();
        parts
    }

    /// Runs the guard and returns its rejection, panicking with `reason` if it admits the request.
    async fn rejection_of(
        parts: &mut Parts,
        state: &AppState,
        reason: &str,
    ) -> (StatusCode, Json<serde_json::Value>) {
        match AdminGuard::from_request_parts(parts, state).await {
            Ok(_) => panic!("{reason}"),
            Err(rejection) => rejection,
        }
    }

    #[test]
    fn loopback_without_token_is_allowed() {
        let auth = AdminAuth::new(None);
        assert!(auth.authorize(sock(LOOPBACK), None).is_ok());
    }

    #[test]
    fn remote_without_token_is_forbidden() {
        let auth = AdminAuth::new(None);
        let verdict = auth.authorize(sock("10.0.0.1:0"), None);
        assert_eq!(verdict, Err(StatusCode::FORBIDDEN));
    }

    #[test]
    fn token_requires_bearer() {
        let auth = AdminAuth {
            token: Some("secret".into()),
        };
        let without = auth.authorize(sock(LOOPBACK), None);
        assert_eq!(without, Err(StatusCode::UNAUTHORIZED));
        assert!(auth.authorize(sock(LOOPBACK), Some("secret")).is_ok());
    }

    #[test]
    fn bearer_scheme_is_case_insensitive() {
        let parts = parts_with(Some("bearer secret"));
        assert_eq!(extract_bearer(&parts).as_deref(), Some("secret"));
    }

    #[test]
    fn non_bearer_authorization_header_is_rejected() {
        let parts = parts_with(Some("Basic dXNlcjpzZWNyZXQ="));
        assert_eq!(extract_bearer(&parts), None);
    }

    #[test]
    fn empty_admin_token_is_treated_as_disabled() {
        let auth = AdminAuth::new(Some(String::new()));
        assert!(auth.authorize(sock(LOOPBACK), None).is_ok());
    }

    #[test]
    fn wrong_bearer_token_is_rejected() {
        let auth = AdminAuth::new(Some("secret".into()));
        let verdict = auth.authorize(sock(LOOPBACK), Some("wrong"));
        assert_eq!(verdict, Err(StatusCode::UNAUTHORIZED));
    }

    #[test]
    fn constant_time_eq_rejects_length_mismatch() {
        assert!(!constant_time_eq(b"short", b"longer"));
    }

    #[test]
    fn malformed_authorization_header_is_rejected() {
        let parts = parts_with(Some("Bearer"));
        assert_eq!(extract_bearer(&parts), None);
    }

    #[tokio::test]
    async fn admin_guard_rejects_missing_connect_info() {
        let mut parts = parts_with(None);
        let state = server_state(AdminAuth::default());

        let (status, Json(body)) =
            rejection_of(&mut parts, &state, "missing connect info should reject").await;

        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body["error"], "connect info unavailable");
    }

    #[tokio::test]
    async fn admin_guard_rejects_wrong_remote_token() {
        let mut parts = parts_with(Some("Bearer wrong"));
        parts.extensions.insert(ConnectInfo(sock("10.0.0.2:8080")));
        let state = server_state(AdminAuth::new(Some("secret".into())));

        let (status, Json(body)) =
            rejection_of(&mut parts, &state, "wrong token should reject").await;

        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body["error"], "admin token required");
    }

    #[tokio::test]
    async fn admin_guard_allows_loopback_without_token_when_disabled() {
        let mut parts = parts_with(None);
        parts.extensions.insert(ConnectInfo(sock("127.0.0.1:8080")));
        let state = server_state(AdminAuth::default());

        AdminGuard::from_request_parts(&mut parts, &state)
            .await
            .expect("loopback should pass without token");
    }
}