1use axum::{
15 body::{to_bytes, Body},
16 extract::State,
17 http::{HeaderMap, Request, StatusCode, Uri},
18 middleware::Next,
19 response::IntoResponse,
20};
21use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
22use serde_json::Value;
23
24use crate::{
25 config::PROXY_AUTH_HEADER,
26 error::{AuthenticationError, ServerError, ValidationError},
27 management::API_KEY_PREFIX,
28 state::AppState,
29 Claims,
30};
31
32pub async fn proxy_middleware(
38 State(_state): State<AppState>,
39 req: Request<Body>,
40 next: Next,
41) -> impl IntoResponse {
42 next.run(req).await
45}
46
47pub fn proxy_uri(original_uri: Uri, namespace: &str, sandbox_name: &str) -> Uri {
49 let target_host = format!("sandbox-{}.{}.internal", sandbox_name, namespace);
57
58 let uri_string = if let Some(path_and_query) = original_uri.path_and_query() {
59 format!("http://{}:{}{}", target_host, 8080, path_and_query)
60 } else {
61 format!("http://{}:{}/", target_host, 8080)
62 };
63
64 uri_string
67 .parse()
68 .unwrap_or_else(|_| "http://localhost:8080/".parse().unwrap())
69}
70
71pub async fn logging_middleware(
73 req: Request<Body>,
74 next: Next,
75) -> Result<impl IntoResponse, (StatusCode, String)> {
76 let method = req.method().clone();
77 let uri = req.uri().clone();
78
79 tracing::info!("Request: {} {}", method, uri);
81
82 let response = next.run(req).await;
84
85 tracing::info!("Response: {} {}: {}", method, uri, response.status());
87
88 Ok(response)
89}
90
91pub async fn auth_middleware(
93 State(state): State<AppState>,
94 req: Request<Body>,
95 next: Next,
96) -> Result<impl IntoResponse, ServerError> {
97 if *state.get_config().get_dev_mode() {
99 return Ok(next.run(req).await);
100 }
101
102 let api_key = extract_api_key_from_headers(req.headers())?;
104
105 let claims = validate_token(&api_key, &state)?;
107
108 if claims.namespace == "*" {
110 return Ok(next.run(req).await);
111 }
112
113 let (parts, body) = req.into_parts();
116
117 let bytes = to_bytes(body, usize::MAX)
119 .await
120 .map_err(|e| ServerError::InternalError(format!("Failed to read request body: {}", e)))?;
121
122 let namespace_from_request = extract_namespace_from_json_rpc(&bytes)?;
124
125 if claims.namespace != namespace_from_request {
127 return Err(ServerError::AuthorizationError(
128 crate::error::AuthorizationError::AccessDenied(format!(
129 "Token does not have access to namespace '{}'",
130 namespace_from_request
131 )),
132 ));
133 }
134
135 let body = Body::from(bytes);
137 let req = Request::from_parts(parts, body);
138
139 Ok(next.run(req).await)
141}
142
143fn extract_namespace_from_json_rpc(bytes: &[u8]) -> Result<String, ServerError> {
149 let json_value: Value = serde_json::from_slice(bytes).map_err(|e| {
151 ServerError::ValidationError(ValidationError::InvalidInput(format!(
152 "Invalid JSON-RPC request: {}",
153 e
154 )))
155 })?;
156
157 let method = json_value
159 .get("method")
160 .and_then(Value::as_str)
161 .unwrap_or("unknown");
162
163 let params = json_value.get("params").ok_or_else(|| {
165 ServerError::ValidationError(ValidationError::InvalidInput(
166 "Missing 'params' field in JSON-RPC request".to_string(),
167 ))
168 })?;
169
170 params
172 .get("namespace")
173 .and_then(Value::as_str)
174 .map(String::from)
175 .ok_or_else(|| {
176 ServerError::ValidationError(ValidationError::InvalidInput(format!(
177 "Missing or invalid 'namespace' in params for method '{}'",
178 method
179 )))
180 })
181}
182
183fn extract_api_key_from_headers(headers: &HeaderMap) -> Result<String, ServerError> {
185 if let Some(auth_header) = headers.get(PROXY_AUTH_HEADER) {
187 let auth_value = auth_header.to_str().map_err(|_| {
188 ServerError::Authentication(AuthenticationError::InvalidCredentials(
189 "Invalid authorization header format".to_string(),
190 ))
191 })?;
192
193 if let Some(token) = auth_value.strip_prefix("Bearer ") {
195 return Ok(token.to_string());
196 }
197
198 return Ok(auth_value.to_string());
200 }
201
202 if let Some(auth_header) = headers.get("Authorization") {
204 let auth_value = auth_header.to_str().map_err(|_| {
205 ServerError::Authentication(AuthenticationError::InvalidCredentials(
206 "Invalid authorization header format".to_string(),
207 ))
208 })?;
209
210 if let Some(token) = auth_value.strip_prefix("Bearer ") {
212 return Ok(token.to_string());
213 }
214
215 return Ok(auth_value.to_string());
217 }
218
219 Err(ServerError::Authentication(
220 AuthenticationError::InvalidCredentials("Missing authorization header".to_string()),
221 ))
222}
223
224fn convert_api_key_to_jwt(api_key: &str) -> Result<String, ServerError> {
226 if !api_key.starts_with(API_KEY_PREFIX) {
228 return Err(ServerError::Authentication(
229 AuthenticationError::InvalidCredentials(
230 "Invalid API key format: missing prefix".to_string(),
231 ),
232 ));
233 }
234
235 let key_without_prefix = &api_key[API_KEY_PREFIX.len()..];
237
238 let parts: Vec<&str> = key_without_prefix.split('.').collect();
240 if parts.len() != 2 {
241 return Err(ServerError::Authentication(
242 AuthenticationError::InvalidCredentials("Invalid API key format".to_string()),
243 ));
244 }
245
246 let header_value = crate::config::DEFAULT_JWT_HEADER.clone();
249 let jwt_header = header_value.as_str();
250 let payload = parts[0];
251 let signature = parts[1];
252
253 Ok(format!("{}.{}.{}", jwt_header, payload, signature))
254}
255
256fn get_server_key(state: &AppState) -> Result<String, ServerError> {
258 match state.get_config().get_key() {
261 Some(key) => Ok(key.clone()),
262 None => Err(ServerError::Authentication(
263 AuthenticationError::InvalidCredentials(
264 "Server key not found in configuration".to_string(),
265 ),
266 )),
267 }
268}
269
270fn validate_token(api_key: &str, state: &AppState) -> Result<Claims, ServerError> {
272 let jwt = convert_api_key_to_jwt(api_key)?;
274
275 let server_key = get_server_key(state)?;
277
278 let token_data = decode::<Claims>(
280 &jwt,
281 &DecodingKey::from_secret(server_key.as_bytes()),
282 &Validation::new(Algorithm::HS256),
283 )
284 .map_err(|e| {
285 let error_message = match e.kind() {
286 jsonwebtoken::errors::ErrorKind::ExpiredSignature => "Token expired".to_string(),
287 jsonwebtoken::errors::ErrorKind::InvalidSignature => {
288 "Invalid token signature".to_string()
289 }
290 _ => format!("Token validation error: {}", e),
291 };
292 ServerError::Authentication(AuthenticationError::InvalidToken(error_message))
293 })?;
294
295 Ok(token_data.claims)
296}
297
298pub fn validate_token_and_namespace(
300 api_key: &str,
301 requested_namespace: &str,
302 state: &AppState,
303) -> Result<Claims, ServerError> {
304 let claims = validate_token(api_key, state)?;
306
307 if claims.namespace != requested_namespace && claims.namespace != "*" {
309 return Err(ServerError::Authentication(
310 AuthenticationError::InvalidCredentials(format!(
311 "Token does not have access to namespace '{}'",
312 requested_namespace
313 )),
314 ));
315 }
316
317 Ok(claims)
318}