// agent_proxy_rust_core/auth.rs

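//! Authentication middleware for the proxy.
//!
//! Supports two modes:
//! - **Simple mode**: a single `proxy_api_key` (sent as `Authorization: Bearer <key>`)
//!   or `proxy_token` (sent as `X-Proxy-Token: <token>`).
//! - **Role mapping mode**: multiple keys, each mapped to an agent role.
//!
//! When role mapping keys are configured, role mapping mode takes precedence and
//! the simple-mode credentials are not consulted.
//!
//! The role (if any) is injected into the request extensions as [`AgentRole`]
//! before the request reaches the handler.
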
10use std::collections::HashMap;
11
12use axum::{extract::State, middleware::Next, response::Response};
13use http::{Request, StatusCode};
14use secrecy::ExposeSecret;
15
16use crate::config::AuthKeyEntry;
17
/// Authenticated agent role, carried in the request extensions.
///
/// The auth middleware inserts it only in role mapping mode. Handlers can read
/// it with `req.extensions().get::<AgentRole>()`.
///
/// `PartialEq`/`Eq`/`Hash` are derived so roles can be compared directly and
/// used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentRole(pub String);

22/// Auth configuration extracted from [`ProxyConfig`](crate::config::ProxyConfig).
23///
24/// Used as axum state for the auth middleware.
25#[derive(Debug, Clone)]
26pub struct AuthState {
27    /// Optional simple auth key.
28    pub proxy_api_key: Option<secrecy::SecretString>,
29    /// Optional simple token auth.
30    pub proxy_token: Option<secrecy::SecretString>,
31    /// Role-based auth mapping from API keys to roles.
32    pub proxy_auth_keys: HashMap<String, AuthKeyEntry>,
33}
34
35impl AuthState {
36    /// Creates an [`AuthState`] from a [`ProxyConfig`](crate::config::ProxyConfig).
37    #[must_use]
38    pub fn from_config(config: &crate::config::ProxyConfig) -> Self {
39        Self {
40            proxy_api_key: config.proxy_api_key.clone(),
41            proxy_token: config.proxy_token.clone(),
42            proxy_auth_keys: config.proxy_auth_keys.clone(),
43        }
44    }
45
46    /// Returns `true` if any authentication mechanism is configured.
47    #[must_use]
48    pub fn has_auth(&self) -> bool {
49        self.proxy_api_key.is_some()
50            || self.proxy_token.is_some()
51            || !self.proxy_auth_keys.is_empty()
52    }
53}
54
55/// Axum middleware that authenticates every request.
56///
57/// On success, injects [`AgentRole`] into request extensions (for role mapping mode).
58/// On failure, returns `401 Unauthorized`.
59///
60/// # Errors
61///
62/// Returns `StatusCode::UNAUTHORIZED` if authentication is required and the
63/// request does not provide valid credentials.
64pub async fn auth_middleware(
65    State(auth_state): State<AuthState>,
66    mut req: Request<axum::body::Body>,
67    next: Next,
68) -> Result<Response, StatusCode> {
69    if !auth_state.has_auth() {
70        return Ok(next.run(req).await);
71    }
72
73    // Role mapping mode: check against proxy_auth_keys
74    if !auth_state.proxy_auth_keys.is_empty() {
75        if let Some(entry) =
76            extract_api_key(req.headers()).and_then(|k| auth_state.proxy_auth_keys.get(&k))
77        {
78            req.extensions_mut().insert(AgentRole(entry.role.clone()));
79            return Ok(next.run(req).await);
80        }
81        return Err(StatusCode::UNAUTHORIZED);
82    }
83
84    // Simple mode: check proxy_api_key via Authorization header
85    if let Some(ref expected) = auth_state.proxy_api_key {
86        let provided = req
87            .headers()
88            .get("authorization")
89            .and_then(|v| v.to_str().ok())
90            .and_then(|v| v.strip_prefix("Bearer "));
91        if provided == Some(expected.expose_secret()) {
92            return Ok(next.run(req).await);
93        }
94        return Err(StatusCode::UNAUTHORIZED);
95    }
96
97    // Simple mode: check proxy_token via X-Proxy-Token header
98    if let Some(ref expected) = auth_state.proxy_token {
99        let provided = req
100            .headers()
101            .get("x-proxy-token")
102            .and_then(|v| v.to_str().ok());
103        if provided == Some(expected.expose_secret()) {
104            return Ok(next.run(req).await);
105        }
106        return Err(StatusCode::UNAUTHORIZED);
107    }
108
109    Ok(next.run(req).await)
110}
111
112/// Extracts the API key from request headers.
113///
114/// Checks `x-api-key` first, then `Authorization: Bearer <key>`.
115pub fn extract_api_key(headers: &http::HeaderMap) -> Option<String> {
116    if let Some(key) = headers.get("x-api-key").and_then(|v| v.to_str().ok()) {
117        return Some(key.to_string());
118    }
119    headers
120        .get("authorization")
121        .and_then(|v| v.to_str().ok())
122        .and_then(|v| v.strip_prefix("Bearer "))
123        .map(std::string::ToString::to_string)
124}
125
#[cfg(test)]
#[allow(clippy::unwrap_used)] // header literals in these tests are known-valid
mod tests {
    use http::HeaderMap;

    use super::*;

    /// Builds a header map from `(name, value)` pairs, panicking on invalid values.
    fn headers_from(pairs: &[(&'static str, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, value.parse().unwrap());
        }
        headers
    }

    /// Builds an [`AuthState`] without role mapping or token, using the given key.
    fn simple_state(proxy_api_key: Option<secrecy::SecretString>) -> AuthState {
        AuthState {
            proxy_api_key,
            proxy_token: None,
            proxy_auth_keys: HashMap::new(),
        }
    }

    #[test]
    fn test_extract_api_key_x_api_key() {
        let headers = headers_from(&[("x-api-key", "sk-test-key")]);
        assert_eq!(extract_api_key(&headers).as_deref(), Some("sk-test-key"));
    }

    #[test]
    fn test_extract_api_key_bearer() {
        let headers = headers_from(&[("authorization", "Bearer sk-test-key")]);
        assert_eq!(extract_api_key(&headers).as_deref(), Some("sk-test-key"));
    }

    #[test]
    fn test_extract_api_key_x_api_key_priority() {
        let headers = headers_from(&[
            ("x-api-key", "sk-x-api"),
            ("authorization", "Bearer sk-bearer"),
        ]);
        assert_eq!(extract_api_key(&headers).as_deref(), Some("sk-x-api"));
    }

    #[test]
    fn test_extract_api_key_none() {
        let headers = headers_from(&[]);
        assert!(extract_api_key(&headers).is_none());
    }

    #[test]
    fn test_extract_api_key_malformed_bearer() {
        let headers = headers_from(&[("authorization", "Basic sk-test")]);
        assert!(extract_api_key(&headers).is_none());
    }

    #[test]
    fn test_auth_state_has_auth_empty() {
        assert!(!simple_state(None).has_auth());
    }

    #[test]
    fn test_auth_state_has_auth_with_key() {
        let key = secrecy::SecretString::new("sk-test".into());
        assert!(simple_state(Some(key)).has_auth());
    }
}