Skip to main content

brainwires_proxy/middleware/
auth.rs

1//! Auth token forwarding/validation middleware.
2
3use crate::error::ProxyResult;
4use crate::middleware::{LayerAction, ProxyLayer};
5use crate::types::{ProxyRequest, ProxyResponse};
6use http::StatusCode;
7use http::header::{AUTHORIZATION, HeaderValue};
8
/// Strategy for handling the `Authorization` header of proxied requests.
///
/// The token-carrying variants hold secrets, so `Debug` is implemented by hand
/// (below) to redact them instead of being derived.
#[derive(Clone)]
pub enum AuthStrategy {
    /// Forward a static bearer token to upstream, replacing any client-supplied
    /// `Authorization` header.
    StaticBearer(String),
    /// Pass through the client's Authorization header unchanged.
    Passthrough,
    /// Require a specific bearer token from the client; mismatching or missing
    /// credentials are answered with `401 Unauthorized`.
    Validate(String),
    /// Strip the Authorization header before forwarding.
    Strip,
}

/// Redacting `Debug`: prints the variant name but never the token, so strategies
/// can be logged safely.
impl std::fmt::Debug for AuthStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::StaticBearer(_) => f.debug_tuple("StaticBearer").field(&"<redacted>").finish(),
            Self::Passthrough => f.write_str("Passthrough"),
            Self::Validate(_) => f.debug_tuple("Validate").field(&"<redacted>").finish(),
            Self::Strip => f.write_str("Strip"),
        }
    }
}
20
21/// Auth middleware that manages Authorization headers.
22pub struct AuthLayer {
23    strategy: AuthStrategy,
24}
25
26impl AuthLayer {
27    pub fn new(strategy: AuthStrategy) -> Self {
28        Self { strategy }
29    }
30
31    /// Create an auth layer that injects a static bearer token.
32    pub fn static_bearer(token: impl Into<String>) -> Self {
33        Self::new(AuthStrategy::StaticBearer(token.into()))
34    }
35
36    /// Create an auth layer that passes through client auth.
37    pub fn passthrough() -> Self {
38        Self::new(AuthStrategy::Passthrough)
39    }
40
41    /// Create an auth layer that validates a required token.
42    pub fn validate(expected: impl Into<String>) -> Self {
43        Self::new(AuthStrategy::Validate(expected.into()))
44    }
45
46    /// Create an auth layer that strips auth headers.
47    pub fn strip() -> Self {
48        Self::new(AuthStrategy::Strip)
49    }
50}
51
52#[async_trait::async_trait]
53impl ProxyLayer for AuthLayer {
54    async fn on_request(&self, mut request: ProxyRequest) -> ProxyResult<LayerAction> {
55        match &self.strategy {
56            AuthStrategy::StaticBearer(token) => {
57                let value = HeaderValue::from_str(&format!("Bearer {token}"))
58                    .map_err(|e| crate::error::ProxyError::Config(e.to_string()))?;
59                request.headers.insert(AUTHORIZATION, value);
60                Ok(LayerAction::Forward(request))
61            }
62            AuthStrategy::Passthrough => Ok(LayerAction::Forward(request)),
63            AuthStrategy::Validate(expected) => {
64                let expected_val = format!("Bearer {expected}");
65                match request.headers.get(AUTHORIZATION) {
66                    Some(val) if val.as_bytes() == expected_val.as_bytes() => {
67                        Ok(LayerAction::Forward(request))
68                    }
69                    _ => {
70                        tracing::warn!(request_id = %request.id, "auth validation failed");
71                        Ok(LayerAction::Respond(
72                            ProxyResponse::for_request(request.id, StatusCode::UNAUTHORIZED)
73                                .with_body("Unauthorized"),
74                        ))
75                    }
76                }
77            }
78            AuthStrategy::Strip => {
79                request.headers.remove(AUTHORIZATION);
80                Ok(LayerAction::Forward(request))
81            }
82        }
83    }
84
85    fn name(&self) -> &str {
86        "auth"
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;

    fn make_request() -> ProxyRequest {
        ProxyRequest::new(Method::GET, "/api".parse().unwrap())
    }

    /// Builds a request whose `Authorization` header is exactly `value`.
    fn make_request_with_header(value: &str) -> ProxyRequest {
        let mut req = make_request();
        req.headers
            .insert(AUTHORIZATION, HeaderValue::from_str(value).unwrap());
        req
    }

    /// Builds a request carrying `Authorization: Bearer <token>`.
    fn make_request_with_auth(token: &str) -> ProxyRequest {
        make_request_with_header(&format!("Bearer {token}"))
    }

    /// Unwraps a `Forward` action, failing the test on anything else.
    fn expect_forward(action: LayerAction) -> ProxyRequest {
        match action {
            LayerAction::Forward(req) => req,
            _ => panic!("expected forward"),
        }
    }

    /// Unwraps a `Respond` action, failing the test on anything else.
    fn expect_respond(action: LayerAction) -> ProxyResponse {
        match action {
            LayerAction::Respond(resp) => resp,
            _ => panic!("expected reject"),
        }
    }

    #[tokio::test]
    async fn static_bearer_injects_token() {
        let layer = AuthLayer::static_bearer("sk-test-123");
        let req = expect_forward(layer.on_request(make_request()).await.unwrap());
        assert_eq!(
            req.headers.get(AUTHORIZATION).unwrap(),
            "Bearer sk-test-123"
        );
    }

    #[tokio::test]
    async fn static_bearer_overwrites_client_header() {
        let layer = AuthLayer::static_bearer("sk-test-123");
        let req = make_request_with_auth("client-token");
        let req = expect_forward(layer.on_request(req).await.unwrap());
        assert_eq!(req.headers.get_all(AUTHORIZATION).iter().count(), 1);
        assert_eq!(
            req.headers.get(AUTHORIZATION).unwrap(),
            "Bearer sk-test-123"
        );
    }

    #[tokio::test]
    async fn passthrough_preserves_header() {
        let layer = AuthLayer::passthrough();
        let req = make_request_with_auth("my-token");
        let req = expect_forward(layer.on_request(req).await.unwrap());
        assert_eq!(req.headers.get(AUTHORIZATION).unwrap(), "Bearer my-token");
    }

    #[tokio::test]
    async fn validate_accepts_correct_token() {
        let layer = AuthLayer::validate("valid-token");
        let req = make_request_with_auth("valid-token");
        let result = layer.on_request(req).await.unwrap();
        assert!(matches!(result, LayerAction::Forward(_)));
    }

    #[tokio::test]
    async fn validate_rejects_wrong_token() {
        let layer = AuthLayer::validate("valid-token");
        let req = make_request_with_auth("wrong-token");
        let resp = expect_respond(layer.on_request(req).await.unwrap());
        assert_eq!(resp.status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn validate_rejects_token_with_extra_suffix() {
        // A correct prefix must not be enough.
        let layer = AuthLayer::validate("valid-token");
        let req = make_request_with_auth("valid-token-extra");
        let resp = expect_respond(layer.on_request(req).await.unwrap());
        assert_eq!(resp.status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn validate_rejects_header_without_scheme() {
        let layer = AuthLayer::validate("valid-token");
        let req = make_request_with_header("valid-token");
        let resp = expect_respond(layer.on_request(req).await.unwrap());
        assert_eq!(resp.status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn validate_rejects_missing_header() {
        let layer = AuthLayer::validate("valid-token");
        let resp = expect_respond(layer.on_request(make_request()).await.unwrap());
        assert_eq!(resp.status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn strip_removes_auth_header() {
        let layer = AuthLayer::strip();
        let req = make_request_with_auth("remove-me");
        let req = expect_forward(layer.on_request(req).await.unwrap());
        assert!(req.headers.get(AUTHORIZATION).is_none());
    }

    #[tokio::test]
    async fn strip_without_header_is_noop() {
        let layer = AuthLayer::strip();
        let req = expect_forward(layer.on_request(make_request()).await.unwrap());
        assert!(req.headers.get(AUTHORIZATION).is_none());
    }
}