Skip to main content

roboticus_api/
auth.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4
5use subtle::ConstantTimeEq;
6
7use axum::body::Body;
8use axum::extract::connect_info::ConnectInfo;
9use axum::http::{Request, Response, StatusCode};
10use futures_util::future::BoxFuture;
11use tower::{Layer, Service};
12use tracing::warn;
13
/// Tower [`Layer`] that wraps a service in [`ApiKeyMiddleware`].
///
/// With a key configured, non-exempt requests must present it via `x-api-key`
/// or `Authorization: Bearer`; with no key configured, non-exempt requests are
/// only accepted from loopback peers.
#[derive(Clone)]
pub struct ApiKeyLayer {
    /// Expected API key; `None` selects loopback-only mode.
    key: Option<Arc<str>>,
}

impl ApiKeyLayer {
    /// Creates the layer from an optionally configured API key.
    ///
    /// An empty or whitespace-only key is treated as "no key configured"
    /// (loopback-only mode). Keeping it would be unsafe: an empty `x-api-key`
    /// header extracts as `""`, and the constant-time comparison of two empty
    /// byte strings succeeds, so any client sending that header would be
    /// authenticated.
    pub fn new(key: Option<String>) -> Self {
        Self {
            key: key.filter(|k| !k.trim().is_empty()).map(Arc::from),
        }
    }
}
26
27impl<S> Layer<S> for ApiKeyLayer {
28    type Service = ApiKeyMiddleware<S>;
29
30    fn layer(&self, inner: S) -> Self::Service {
31        ApiKeyMiddleware {
32            inner,
33            key: self.key.clone(),
34        }
35    }
36}
37
/// Service produced by [`ApiKeyLayer`]: checks each request's credentials (or
/// peer address, when no key is configured) before forwarding it to `inner`.
#[derive(Clone)]
pub struct ApiKeyMiddleware<S> {
    /// Expected API key shared with the layer; `None` means no key is configured.
    key: Option<Arc<str>>,
    /// Wrapped service that receives the requests allowed through.
    inner: S,
}
43
/// Returns `true` for paths that must be reachable without an API key.
///
/// Matching is exact: there is no prefix, wildcard, or trailing-slash
/// matching, so any path not listed below requires authentication. This
/// includes other `/api/webhooks/...` paths.
///
/// - `/` and `/api/health` -- uptime probes; read-only, no side-effects.
/// - `/api/webhooks/telegram` and `/api/webhooks/whatsapp` -- inbound from
///   Telegram/WhatsApp; these services cannot supply our API key, so they
///   authenticate via HMAC or provider token validation inside the handler
///   itself.
/// - `/.well-known/agent.json` -- public A2A agent-card discovery.
///
/// NOTE: All exempt paths are still subject to the global and per-IP rate
/// limiter. If you add a new exempt path, ensure it cannot be abused to
/// amplify work (e.g. trigger LLM calls) without its own auth check.
fn is_exempt(path: &str) -> bool {
    matches!(
        path,
        "/" | "/api/health"
            | "/api/webhooks/telegram"
            | "/api/webhooks/whatsapp"
            | "/.well-known/agent.json"
    )
}
62
63fn extract_api_key(req: &Request<Body>) -> Option<String> {
64    if let Some(val) = req.headers().get("x-api-key")
65        && let Ok(s) = val.to_str()
66    {
67        return Some(s.to_string());
68    }
69    if let Some(val) = req.headers().get("authorization")
70        && let Ok(s) = val.to_str()
71        && let Some(token) = s.strip_prefix("Bearer ")
72    {
73        return Some(token.to_string());
74    }
75    // S-HIGH-2: query-string ?token= removed — use POST /api/ws-ticket
76    // for short-lived, single-use tickets instead.
77    None
78}
79
80pub(crate) fn extract_auth_principal(req: &Request<Body>) -> Option<String> {
81    if req.headers().contains_key("x-api-key") {
82        return Some("api_key".to_string());
83    }
84    if let Some(val) = req.headers().get("authorization")
85        && let Ok(s) = val.to_str()
86        && s.starts_with("Bearer ")
87    {
88        return Some("bearer".to_string());
89    }
90    None
91}
92
93fn unauthorized_response() -> Response<Body> {
94    let body = serde_json::json!({"error": "unauthorized", "message": "Valid API key required"});
95    let bytes = serde_json::to_vec(&body).unwrap_or_else(|_| {
96        br#"{"error":"unauthorized","message":"Valid API key required"}"#.to_vec()
97    });
98    let mut response = Response::new(Body::from(bytes));
99    *response.status_mut() = StatusCode::UNAUTHORIZED;
100    response.headers_mut().insert(
101        axum::http::header::CONTENT_TYPE,
102        axum::http::HeaderValue::from_static("application/json"),
103    );
104    response
105}
106
107impl<S> Service<Request<Body>> for ApiKeyMiddleware<S>
108where
109    S: Service<Request<Body>, Response = Response<Body>> + Send + Clone + 'static,
110    S::Future: Send + 'static,
111{
112    type Response = S::Response;
113    type Error = S::Error;
114    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
115
116    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117        self.inner.poll_ready(cx)
118    }
119
120    fn call(&mut self, req: Request<Body>) -> Self::Future {
121        let key = self.key.clone();
122        let mut inner = self.inner.clone();
123
124        Box::pin(async move {
125            let path = req.uri().path().to_string();
126            if let Some(ref expected) = key {
127                if !is_exempt(&path) {
128                    match extract_api_key(&req) {
129                        Some(provided)
130                            if bool::from(provided.as_bytes().ct_eq(expected.as_bytes())) => {}
131                        _ => return Ok(unauthorized_response()),
132                    }
133                }
134            } else if !is_exempt(&path) {
135                // No API key configured — restrict to loopback addresses only.
136                let is_loopback = req
137                    .extensions()
138                    .get::<ConnectInfo<SocketAddr>>()
139                    .map(|ci| ci.0.ip().is_loopback())
140                    .unwrap_or(false);
141                if !is_loopback {
142                    warn!(
143                        path = %path,
144                        "rejected non-loopback request: no API key configured — set server.api_key"
145                    );
146                    return Ok(unauthorized_response());
147                }
148            }
149            inner.call(req).await
150        })
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Inner service that always answers `200 OK`. It stands in for the router
    /// so tests can observe whether the middleware let a request through.
    #[derive(Clone)]
    struct AlwaysOk;

    impl Service<Request<Body>> for AlwaysOk {
        type Response = Response<Body>;
        type Error = std::convert::Infallible;
        type Future = std::future::Ready<Result<Response<Body>, std::convert::Infallible>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: Request<Body>) -> Self::Future {
            std::future::ready(Ok(Response::new(Body::empty())))
        }
    }

    /// Drives `req` through an `ApiKeyMiddleware` configured with `key` and
    /// returns the status code the client would see.
    fn status_for(key: Option<&str>, req: Request<Body>) -> StatusCode {
        let mut svc = ApiKeyLayer::new(key.map(String::from)).layer(AlwaysOk);
        let mut cx = Context::from_waker(std::task::Waker::noop());
        assert!(svc.poll_ready(&mut cx).is_ready());
        let mut fut = svc.call(req);
        match std::future::Future::poll(fut.as_mut(), &mut cx) {
            Poll::Ready(Ok(response)) => response.status(),
            Poll::Ready(Err(never)) => match never {},
            Poll::Pending => panic!("middleware future should resolve without waiting"),
        }
    }

    /// Builds a bodiless request for `path`, optionally tagged with the peer
    /// address the server would record in `ConnectInfo`.
    fn request(path: &str, peer: Option<&str>) -> Request<Body> {
        let mut req = Request::builder().uri(path).body(Body::empty()).unwrap();
        if let Some(peer) = peer {
            let addr: SocketAddr = peer.parse().unwrap();
            req.extensions_mut().insert(ConnectInfo(addr));
        }
        req
    }

    #[test]
    fn exempt_paths() {
        assert!(is_exempt("/"));
        assert!(!is_exempt("/ws"));
        assert!(is_exempt("/api/health"));
        assert!(is_exempt("/api/webhooks/telegram"));
        assert!(is_exempt("/api/webhooks/whatsapp"));
        assert!(is_exempt("/.well-known/agent.json"));
        assert!(!is_exempt("/api/config"));
        assert!(!is_exempt("/api/sessions"));
        assert!(!is_exempt("/api/agent/message"));
    }

    #[test]
    fn extract_bearer_token() {
        let req = Request::builder()
            .header("authorization", "Bearer test-key-123")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_api_key(&req).as_deref(), Some("test-key-123"));
    }

    #[test]
    fn extract_x_api_key_header() {
        let req = Request::builder()
            .header("x-api-key", "test-key-789")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_api_key(&req).as_deref(), Some("test-key-789"));
    }

    #[test]
    fn query_token_no_longer_accepted() {
        // S-HIGH-2: ?token= in URL is removed — tickets replace it
        let req = Request::builder()
            .uri("/ws?token=query-key-456")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_api_key(&req), None);
    }

    #[test]
    fn query_token_not_accepted_for_non_ws_paths() {
        let req = Request::builder()
            .uri("/api/health?token=query-key-456")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_api_key(&req), None);
    }

    #[test]
    fn no_key_returns_none() {
        let req = Request::builder().body(Body::empty()).unwrap();
        assert_eq!(extract_api_key(&req), None);
    }

    #[test]
    fn x_api_key_takes_precedence() {
        let req = Request::builder()
            .header("x-api-key", "header-key")
            .header("authorization", "Bearer bearer-key")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_api_key(&req).as_deref(), Some("header-key"));
    }

    #[test]
    fn extract_auth_principal_prefers_api_key() {
        let req = Request::builder()
            .header("x-api-key", "abc")
            .header("authorization", "Bearer token")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_auth_principal(&req).as_deref(), Some("api_key"));
    }

    #[test]
    fn extract_auth_principal_bearer() {
        let req = Request::builder()
            .header("authorization", "Bearer token")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_auth_principal(&req).as_deref(), Some("bearer"));
    }

    #[test]
    fn unknown_webhook_not_exempt() {
        assert!(!is_exempt("/api/webhooks/unknown"));
        assert!(!is_exempt("/api/webhooks/"));
    }

    #[test]
    fn unauthorized_response_is_json_401() {
        let resp = unauthorized_response();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            resp.headers()
                .get(axum::http::header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok()),
            Some("application/json")
        );
    }

    #[test]
    fn key_accepts_matching_x_api_key() {
        let req = Request::builder()
            .uri("/api/sessions")
            .header("x-api-key", "secret")
            .body(Body::empty())
            .unwrap();
        assert_eq!(status_for(Some("secret"), req), StatusCode::OK);
    }

    #[test]
    fn key_accepts_matching_bearer() {
        let req = Request::builder()
            .uri("/api/sessions")
            .header("authorization", "Bearer secret")
            .body(Body::empty())
            .unwrap();
        assert_eq!(status_for(Some("secret"), req), StatusCode::OK);
    }

    #[test]
    fn key_rejects_wrong_key() {
        let req = Request::builder()
            .uri("/api/sessions")
            .header("x-api-key", "not-the-secret")
            .body(Body::empty())
            .unwrap();
        assert_eq!(status_for(Some("secret"), req), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn key_rejects_missing_credentials_even_from_loopback() {
        let req = request("/api/sessions", Some("127.0.0.1:12345"));
        assert_eq!(status_for(Some("secret"), req), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn key_allows_exempt_path_without_credentials() {
        let req = request("/api/health", Some("203.0.113.7:443"));
        assert_eq!(status_for(Some("secret"), req), StatusCode::OK);
    }

    #[test]
    fn no_key_rejects_non_loopback() {
        let req = request("/api/sessions", Some("192.168.1.5:12345"));
        assert_eq!(status_for(None, req), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn no_key_allows_loopback() {
        let req = request("/api/sessions", Some("127.0.0.1:12345"));
        assert_eq!(status_for(None, req), StatusCode::OK);
    }

    #[test]
    fn no_key_allows_ipv6_loopback() {
        let req = request("/api/sessions", Some("[::1]:12345"));
        assert_eq!(status_for(None, req), StatusCode::OK);
    }

    #[test]
    fn no_key_allows_exempt_paths() {
        // Exempt paths should be allowed regardless of loopback status
        for path in ["/", "/api/health", "/.well-known/agent.json"] {
            assert!(is_exempt(path));
            let req = request(path, Some("192.168.1.5:12345"));
            assert_eq!(status_for(None, req), StatusCode::OK, "path {path}");
        }
    }

    #[test]
    fn no_key_no_connect_info_rejects() {
        // Without ConnectInfo extension, default is non-loopback (fail closed)
        let req = request("/api/sessions", None);
        assert_eq!(
            status_for(None, req),
            StatusCode::UNAUTHORIZED,
            "missing ConnectInfo should default to reject"
        );
    }

    #[test]
    fn layer_none_key_creates_middleware() {
        let layer = ApiKeyLayer::new(None);
        assert!(layer.key.is_none());
    }

    #[test]
    fn layer_some_key_creates_middleware() {
        let layer = ApiKeyLayer::new(Some("test-layer-key".into()));
        assert_eq!(layer.key.as_deref(), Some("test-layer-key"));
    }
}