agent_proxy_rust_core/
auth.rs1use std::collections::HashMap;
11
12use axum::{extract::State, middleware::Next, response::Response};
13use http::{Request, StatusCode};
14use secrecy::ExposeSecret;
15
16use crate::config::AuthKeyEntry;
17
/// Role name of an authenticated agent.
///
/// `auth_middleware` builds this from the `role` of the matching
/// `proxy_auth_keys` entry and inserts it into the request's extensions.
#[derive(Clone, Debug)]
pub struct AgentRole(pub String);
21
22#[derive(Debug, Clone)]
26pub struct AuthState {
27 pub proxy_api_key: Option<secrecy::SecretString>,
29 pub proxy_token: Option<secrecy::SecretString>,
31 pub proxy_auth_keys: HashMap<String, AuthKeyEntry>,
33}
34
35impl AuthState {
36 #[must_use]
38 pub fn from_config(config: &crate::config::ProxyConfig) -> Self {
39 Self {
40 proxy_api_key: config.proxy_api_key.clone(),
41 proxy_token: config.proxy_token.clone(),
42 proxy_auth_keys: config.proxy_auth_keys.clone(),
43 }
44 }
45
46 #[must_use]
48 pub fn has_auth(&self) -> bool {
49 self.proxy_api_key.is_some()
50 || self.proxy_token.is_some()
51 || !self.proxy_auth_keys.is_empty()
52 }
53}
54
55pub async fn auth_middleware(
65 State(auth_state): State<AuthState>,
66 mut req: Request<axum::body::Body>,
67 next: Next,
68) -> Result<Response, StatusCode> {
69 if !auth_state.has_auth() {
70 return Ok(next.run(req).await);
71 }
72
73 if !auth_state.proxy_auth_keys.is_empty() {
75 if let Some(entry) =
76 extract_api_key(req.headers()).and_then(|k| auth_state.proxy_auth_keys.get(&k))
77 {
78 req.extensions_mut().insert(AgentRole(entry.role.clone()));
79 return Ok(next.run(req).await);
80 }
81 return Err(StatusCode::UNAUTHORIZED);
82 }
83
84 if let Some(ref expected) = auth_state.proxy_api_key {
86 let provided = req
87 .headers()
88 .get("authorization")
89 .and_then(|v| v.to_str().ok())
90 .and_then(|v| v.strip_prefix("Bearer "));
91 if provided == Some(expected.expose_secret()) {
92 return Ok(next.run(req).await);
93 }
94 return Err(StatusCode::UNAUTHORIZED);
95 }
96
97 if let Some(ref expected) = auth_state.proxy_token {
99 let provided = req
100 .headers()
101 .get("x-proxy-token")
102 .and_then(|v| v.to_str().ok());
103 if provided == Some(expected.expose_secret()) {
104 return Ok(next.run(req).await);
105 }
106 return Err(StatusCode::UNAUTHORIZED);
107 }
108
109 Ok(next.run(req).await)
110}
111
112pub fn extract_api_key(headers: &http::HeaderMap) -> Option<String> {
116 if let Some(key) = headers.get("x-api-key").and_then(|v| v.to_str().ok()) {
117 return Some(key.to_string());
118 }
119 headers
120 .get("authorization")
121 .and_then(|v| v.to_str().ok())
122 .and_then(|v| v.strip_prefix("Bearer "))
123 .map(std::string::ToString::to_string)
124}
125
#[cfg(test)]
// Fixtures are built from literals known to be valid, so unwrap cannot fail
// here and keeps the tests terse.
#[allow(clippy::unwrap_used)]
mod tests {
    use http::HeaderMap;

    use super::*;

    /// Builds a header map from `(name, value)` pairs, inserted in order.
    fn header_map(pairs: &[(&'static str, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    /// Builds an `AuthState` with only an optional single API key set.
    fn auth_state_with(api_key: Option<&str>) -> AuthState {
        AuthState {
            proxy_api_key: api_key.map(|k| secrecy::SecretString::new(k.into())),
            proxy_token: None,
            proxy_auth_keys: HashMap::new(),
        }
    }

    #[test]
    fn test_extract_api_key_x_api_key() {
        let map = header_map(&[("x-api-key", "sk-test-key")]);
        assert_eq!(extract_api_key(&map), Some("sk-test-key".to_owned()));
    }

    #[test]
    fn test_extract_api_key_bearer() {
        let map = header_map(&[("authorization", "Bearer sk-test-key")]);
        assert_eq!(extract_api_key(&map), Some("sk-test-key".to_owned()));
    }

    #[test]
    fn test_extract_api_key_x_api_key_priority() {
        let map = header_map(&[
            ("x-api-key", "sk-x-api"),
            ("authorization", "Bearer sk-bearer"),
        ]);
        assert_eq!(extract_api_key(&map), Some("sk-x-api".to_owned()));
    }

    #[test]
    fn test_extract_api_key_none() {
        let map = header_map(&[]);
        assert_eq!(extract_api_key(&map), None);
    }

    #[test]
    fn test_extract_api_key_malformed_bearer() {
        let map = header_map(&[("authorization", "Basic sk-test")]);
        assert_eq!(extract_api_key(&map), None);
    }

    #[test]
    fn test_auth_state_has_auth_empty() {
        assert!(!auth_state_with(None).has_auth());
    }

    #[test]
    fn test_auth_state_has_auth_with_key() {
        assert!(auth_state_with(Some("sk-test")).has_auth());
    }
}