greentic_runner_host/http/
auth.rs1use std::net::SocketAddr;
2
3use axum::Json;
4use axum::extract::connect_info::ConnectInfo;
5use axum::extract::{FromRef, FromRequestParts};
6use axum::http::StatusCode;
7use axum::http::header::AUTHORIZATION;
8use axum::http::request::Parts;
9use serde_json::json;
10
11use crate::runner::ServerState;
12
13#[derive(Clone, Default)]
14pub struct AdminAuth {
15 token: Option<String>,
16}
17
18impl AdminAuth {
19 pub fn new(token: Option<String>) -> Self {
20 Self {
21 token: token.filter(|v| !v.is_empty()),
22 }
23 }
24
25 fn authorize(&self, addr: SocketAddr, bearer: Option<&str>) -> Result<(), StatusCode> {
26 if let Some(expected) = &self.token {
27 let token = bearer.ok_or(StatusCode::UNAUTHORIZED)?;
28 if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
29 Ok(())
30 } else {
31 Err(StatusCode::UNAUTHORIZED)
32 }
33 } else if addr.ip().is_loopback() {
34 Ok(())
35 } else {
36 Err(StatusCode::FORBIDDEN)
37 }
38 }
39}
40
41pub struct AdminGuard;
42
43impl<S> FromRequestParts<S> for AdminGuard
44where
45 ServerState: FromRef<S>,
46 S: Send + Sync,
47{
48 type Rejection = (StatusCode, Json<serde_json::Value>);
49
50 fn from_request_parts(
51 parts: &mut Parts,
52 state: &S,
53 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
54 let server_state = ServerState::from_ref(state);
55 let admin = server_state.admin.clone();
56 let addr = parts
57 .extensions
58 .get::<ConnectInfo<SocketAddr>>()
59 .map(|info| info.0);
60 let bearer = extract_bearer(parts);
61
62 async move {
63 let addr = addr.ok_or((
64 StatusCode::INTERNAL_SERVER_ERROR,
65 Json(json!({ "error": "connect info unavailable" })),
66 ))?;
67 admin.authorize(addr, bearer.as_deref()).map_err(|status| {
68 (
69 status,
70 Json(json!({
71 "error": if status == StatusCode::UNAUTHORIZED {
72 "admin token required"
73 } else {
74 "admin access restricted"
75 }
76 })),
77 )
78 })?;
79 Ok(AdminGuard)
80 }
81 }
82}
83
84fn extract_bearer(parts: &Parts) -> Option<String> {
85 let header = parts.headers.get(AUTHORIZATION)?.to_str().ok()?;
86 header
87 .strip_prefix("Bearer ")
88 .map(|value| value.trim().to_string())
89}
90
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different length compare unequal immediately. For equal-length
/// slices every byte pair is XOR-ed and OR-ed into one accumulator, and the
/// slices are equal only if that accumulator ends up zero.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mismatch = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    mismatch == 0
}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a socket-address literal used as the simulated peer.
    fn peer(literal: &str) -> SocketAddr {
        literal.parse().unwrap()
    }

    // No token configured: a loopback peer is let through.
    #[test]
    fn loopback_without_token_is_allowed() {
        let policy = AdminAuth::new(None);
        assert_eq!(policy.authorize(peer("127.0.0.1:0"), None), Ok(()));
    }

    // No token configured: a non-loopback peer is refused.
    #[test]
    fn remote_without_token_is_forbidden() {
        let policy = AdminAuth::new(None);
        let outcome = policy.authorize(peer("10.0.0.1:0"), None);
        assert_eq!(outcome, Err(StatusCode::FORBIDDEN));
    }

    // Token configured: a missing bearer is rejected even from loopback,
    // and the matching bearer is accepted.
    #[test]
    fn token_requires_bearer() {
        let policy = AdminAuth {
            token: Some("secret".into()),
        };
        let local = peer("127.0.0.1:0");

        assert_eq!(
            policy.authorize(local, None),
            Err(StatusCode::UNAUTHORIZED)
        );
        assert_eq!(policy.authorize(local, Some("secret")), Ok(()));
    }
}
135}