1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4
5use subtle::ConstantTimeEq;
6
7use axum::body::Body;
8use axum::extract::connect_info::ConnectInfo;
9use axum::http::{Request, Response, StatusCode};
10use futures_util::future::BoxFuture;
11use tower::{Layer, Service};
12use tracing::warn;
13
/// Tower [`Layer`] that wraps services in [`ApiKeyMiddleware`].
///
/// The configured key is held as an `Arc<str>`, so handing it to every produced
/// middleware is a reference-count bump rather than a string copy.
#[derive(Clone)]
pub struct ApiKeyLayer {
    /// Expected API key. `None` means no key is configured, in which case the middleware
    /// only admits loopback peers on non-exempt paths.
    key: Option<Arc<str>>,
}

impl ApiKeyLayer {
    /// Creates a layer that requires `key` on every non-exempt request.
    ///
    /// A key that is empty or whitespace-only is treated as "not configured" (`None`).
    /// Storing it verbatim would let a request carrying an empty `x-api-key:` header
    /// (whose value is `""`) compare equal to it and bypass authentication.
    pub fn new(key: Option<String>) -> Self {
        Self {
            key: key.filter(|k| !k.trim().is_empty()).map(Arc::from),
        }
    }
}
26
27impl<S> Layer<S> for ApiKeyLayer {
28 type Service = ApiKeyMiddleware<S>;
29
30 fn layer(&self, inner: S) -> Self::Service {
31 ApiKeyMiddleware {
32 inner,
33 key: self.key.clone(),
34 }
35 }
36}
37
/// Service produced by [`ApiKeyLayer`]: checks each request before handing it to `inner`.
#[derive(Clone)]
pub struct ApiKeyMiddleware<S> {
    /// Expected API key copied from the layer; `None` when no key is configured.
    key: Option<Arc<str>>,
    /// The wrapped service that receives authorized requests.
    inner: S,
}
43
/// Returns `true` when `path` skips the API-key / loopback check entirely.
///
/// Matching is exact: there is no prefix matching or trailing-slash normalisation, so
/// `/api/webhooks/` or `/api/health/` are *not* exempt.
// NOTE(review): presumably these endpoints are public or authenticate by other means
// (e.g. webhook signatures) — confirm before adding paths here.
fn is_exempt(path: &str) -> bool {
    matches!(
        path,
        "/" | "/api/health"
            | "/api/webhooks/telegram"
            | "/api/webhooks/whatsapp"
            | "/.well-known/agent.json"
    )
}
62
63fn extract_api_key(req: &Request<Body>) -> Option<String> {
64 if let Some(val) = req.headers().get("x-api-key")
65 && let Ok(s) = val.to_str()
66 {
67 return Some(s.to_string());
68 }
69 if let Some(val) = req.headers().get("authorization")
70 && let Ok(s) = val.to_str()
71 && let Some(token) = s.strip_prefix("Bearer ")
72 {
73 return Some(token.to_string());
74 }
75 None
78}
79
80pub(crate) fn extract_auth_principal(req: &Request<Body>) -> Option<String> {
81 if req.headers().contains_key("x-api-key") {
82 return Some("api_key".to_string());
83 }
84 if let Some(val) = req.headers().get("authorization")
85 && let Ok(s) = val.to_str()
86 && s.starts_with("Bearer ")
87 {
88 return Some("bearer".to_string());
89 }
90 None
91}
92
93fn unauthorized_response() -> Response<Body> {
94 let body = serde_json::json!({"error": "unauthorized", "message": "Valid API key required"});
95 let bytes = serde_json::to_vec(&body).unwrap_or_else(|_| {
96 br#"{"error":"unauthorized","message":"Valid API key required"}"#.to_vec()
97 });
98 let mut response = Response::new(Body::from(bytes));
99 *response.status_mut() = StatusCode::UNAUTHORIZED;
100 response.headers_mut().insert(
101 axum::http::header::CONTENT_TYPE,
102 axum::http::HeaderValue::from_static("application/json"),
103 );
104 response
105}
106
107impl<S> Service<Request<Body>> for ApiKeyMiddleware<S>
108where
109 S: Service<Request<Body>, Response = Response<Body>> + Send + Clone + 'static,
110 S::Future: Send + 'static,
111{
112 type Response = S::Response;
113 type Error = S::Error;
114 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
115
116 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117 self.inner.poll_ready(cx)
118 }
119
120 fn call(&mut self, req: Request<Body>) -> Self::Future {
121 let key = self.key.clone();
122 let mut inner = self.inner.clone();
123
124 Box::pin(async move {
125 let path = req.uri().path().to_string();
126 if let Some(ref expected) = key {
127 if !is_exempt(&path) {
128 match extract_api_key(&req) {
129 Some(provided)
130 if bool::from(provided.as_bytes().ct_eq(expected.as_bytes())) => {}
131 _ => return Ok(unauthorized_response()),
132 }
133 }
134 } else if !is_exempt(&path) {
135 let is_loopback = req
137 .extensions()
138 .get::<ConnectInfo<SocketAddr>>()
139 .map(|ci| ci.0.ip().is_loopback())
140 .unwrap_or(false);
141 if !is_loopback {
142 warn!(
143 path = %path,
144 "rejected non-loopback request: no API key configured — set server.api_key"
145 );
146 return Ok(unauthorized_response());
147 }
148 }
149 inner.call(req).await
150 })
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::future::{Future, Ready, ready};
    use std::task::Waker;

    const KEY: &str = "test-key-123";
    const REMOTE: &str = "192.168.1.5:12345";

    /// Inner service that answers every request with an empty `200 OK`.
    #[derive(Clone)]
    struct AlwaysOk;

    impl Service<Request<Body>> for AlwaysOk {
        type Response = Response<Body>;
        type Error = Infallible;
        type Future = Ready<Result<Response<Body>, Infallible>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: Request<Body>) -> Self::Future {
            ready(Ok(Response::new(Body::empty())))
        }
    }

    /// Sends `req` through `layer` wrapped around [`AlwaysOk`] and returns the response.
    ///
    /// Nothing on this path can return `Pending`, so a single poll with a no-op waker
    /// completes the future.
    fn respond(layer: &ApiKeyLayer, req: Request<Body>) -> Response<Body> {
        let mut svc = layer.layer(AlwaysOk);
        let mut cx = Context::from_waker(Waker::noop());
        assert!(svc.poll_ready(&mut cx).is_ready());
        let mut fut = svc.call(req);
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(Ok(response)) => response,
            Poll::Ready(Err(never)) => match never {},
            Poll::Pending => panic!("middleware future should complete on first poll"),
        }
    }

    fn status(layer: &ApiKeyLayer, req: Request<Body>) -> StatusCode {
        respond(layer, req).status()
    }

    fn keyed_layer() -> ApiKeyLayer {
        ApiKeyLayer::new(Some(KEY.to_string()))
    }

    /// Builds a body-less request for `uri` with `headers` and an optional peer address.
    fn request(uri: &str, headers: &[(&str, &str)], peer: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        let mut req = builder.body(Body::empty()).unwrap();
        if let Some(addr) = peer {
            let addr: SocketAddr = addr.parse().unwrap();
            req.extensions_mut().insert(ConnectInfo(addr));
        }
        req
    }

    #[test]
    fn exempt_paths() {
        assert!(is_exempt("/"));
        assert!(!is_exempt("/ws"));
        assert!(is_exempt("/api/health"));
        assert!(is_exempt("/api/webhooks/telegram"));
        assert!(is_exempt("/api/webhooks/whatsapp"));
        assert!(is_exempt("/.well-known/agent.json"));
        assert!(!is_exempt("/api/config"));
        assert!(!is_exempt("/api/sessions"));
        assert!(!is_exempt("/api/agent/message"));
    }

    #[test]
    fn unknown_webhook_not_exempt() {
        assert!(!is_exempt("/api/webhooks/unknown"));
        assert!(!is_exempt("/api/webhooks/"));
    }

    #[test]
    fn extract_bearer_token() {
        let req = request("/", &[("authorization", "Bearer test-key-123")], None);
        assert_eq!(extract_api_key(&req).as_deref(), Some("test-key-123"));
    }

    #[test]
    fn extract_x_api_key_header() {
        let req = request("/", &[("x-api-key", "test-key-789")], None);
        assert_eq!(extract_api_key(&req).as_deref(), Some("test-key-789"));
    }

    #[test]
    fn query_token_no_longer_accepted() {
        let req = request("/ws?token=query-key-456", &[], None);
        assert_eq!(extract_api_key(&req), None);
    }

    #[test]
    fn query_token_not_accepted_for_non_ws_paths() {
        let req = request("/api/health?token=query-key-456", &[], None);
        assert_eq!(extract_api_key(&req), None);
    }

    #[test]
    fn no_key_returns_none() {
        let req = request("/", &[], None);
        assert_eq!(extract_api_key(&req), None);
    }

    #[test]
    fn x_api_key_takes_precedence() {
        let req = request(
            "/",
            &[("x-api-key", "header-key"), ("authorization", "Bearer bearer-key")],
            None,
        );
        assert_eq!(extract_api_key(&req).as_deref(), Some("header-key"));
    }

    #[test]
    fn extract_auth_principal_prefers_api_key() {
        let req = request(
            "/",
            &[("x-api-key", "abc"), ("authorization", "Bearer token")],
            None,
        );
        assert_eq!(extract_auth_principal(&req).as_deref(), Some("api_key"));
    }

    #[test]
    fn extract_auth_principal_bearer() {
        let req = request("/", &[("authorization", "Bearer token")], None);
        assert_eq!(extract_auth_principal(&req).as_deref(), Some("bearer"));
    }

    #[test]
    fn configured_key_accepts_matching_x_api_key() {
        let req = request("/api/sessions", &[("x-api-key", KEY)], Some(REMOTE));
        assert_eq!(status(&keyed_layer(), req), StatusCode::OK);
    }

    #[test]
    fn configured_key_accepts_matching_bearer_token() {
        let bearer = format!("Bearer {KEY}");
        let req = request("/api/sessions", &[("authorization", bearer.as_str())], Some(REMOTE));
        assert_eq!(status(&keyed_layer(), req), StatusCode::OK);
    }

    #[test]
    fn configured_key_rejects_wrong_or_missing_key_even_from_loopback() {
        let layer = keyed_layer();
        let wrong = request("/api/sessions", &[("x-api-key", "nope")], Some("127.0.0.1:1"));
        assert_eq!(status(&layer, wrong), StatusCode::UNAUTHORIZED);
        let missing = request("/api/sessions", &[], Some("127.0.0.1:1"));
        assert_eq!(status(&layer, missing), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn configured_key_skips_exempt_paths() {
        let req = request("/api/health", &[], Some(REMOTE));
        assert_eq!(status(&keyed_layer(), req), StatusCode::OK);
    }

    #[test]
    fn rejection_is_json_401() {
        let response = respond(&keyed_layer(), request("/api/config", &[], None));
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            response
                .headers()
                .get(axum::http::header::CONTENT_TYPE)
                .expect("content-type set on 401"),
            "application/json"
        );
    }

    #[test]
    fn no_key_rejects_non_loopback() {
        let req = request("/api/sessions", &[], Some(REMOTE));
        assert_eq!(status(&ApiKeyLayer::new(None), req), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn no_key_allows_loopback() {
        let layer = ApiKeyLayer::new(None);
        for peer in ["127.0.0.1:12345", "[::1]:12345"] {
            let req = request("/api/sessions", &[], Some(peer));
            assert_eq!(status(&layer, req), StatusCode::OK, "peer {peer}");
        }
    }

    #[test]
    fn no_key_no_connect_info_rejects() {
        let req = request("/api/sessions", &[], None);
        assert_eq!(
            status(&ApiKeyLayer::new(None), req),
            StatusCode::UNAUTHORIZED,
            "missing ConnectInfo should default to reject"
        );
    }

    #[test]
    fn no_key_allows_exempt_paths() {
        let layer = ApiKeyLayer::new(None);
        for path in [
            "/",
            "/api/health",
            "/api/webhooks/telegram",
            "/api/webhooks/whatsapp",
            "/.well-known/agent.json",
        ] {
            let req = request(path, &[], Some(REMOTE));
            assert_eq!(status(&layer, req), StatusCode::OK, "path {path}");
        }
    }

    #[test]
    fn layer_none_key_creates_middleware() {
        let layer = ApiKeyLayer::new(None);
        assert!(layer.key.is_none());
    }

    #[test]
    fn layer_some_key_creates_middleware() {
        let layer = ApiKeyLayer::new(Some("test-layer-key".into()));
        assert_eq!(layer.key.as_deref(), Some("test-layer-key"));
    }
}