brainwires_proxy/middleware/
auth.rs1use crate::error::ProxyResult;
4use crate::middleware::{LayerAction, ProxyLayer};
5use crate::types::{ProxyRequest, ProxyResponse};
6use http::StatusCode;
7use http::header::{AUTHORIZATION, HeaderValue};
8
/// How [`AuthLayer`] treats the `Authorization` header of each incoming request.
pub enum AuthStrategy {
    /// Overwrite `Authorization` with `Bearer <token>` before forwarding.
    StaticBearer(String),
    /// Forward the request with its `Authorization` header untouched.
    Passthrough,
    /// Forward only requests whose `Authorization` header is exactly `Bearer <token>`;
    /// every other request is answered with `401 Unauthorized`.
    Validate(String),
    /// Remove any `Authorization` header before forwarding.
    Strip,
}

// Written by hand instead of derived so the configured token can never end up in logs,
// panic messages, or `{:?}` output of structs that contain this strategy.
impl std::fmt::Debug for AuthStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::StaticBearer(_) => f
                .debug_tuple("StaticBearer")
                .field(&format_args!("<redacted>"))
                .finish(),
            Self::Passthrough => f.write_str("Passthrough"),
            Self::Validate(_) => f
                .debug_tuple("Validate")
                .field(&format_args!("<redacted>"))
                .finish(),
            Self::Strip => f.write_str("Strip"),
        }
    }
}
20
21pub struct AuthLayer {
23 strategy: AuthStrategy,
24}
25
26impl AuthLayer {
27 pub fn new(strategy: AuthStrategy) -> Self {
28 Self { strategy }
29 }
30
31 pub fn static_bearer(token: impl Into<String>) -> Self {
33 Self::new(AuthStrategy::StaticBearer(token.into()))
34 }
35
36 pub fn passthrough() -> Self {
38 Self::new(AuthStrategy::Passthrough)
39 }
40
41 pub fn validate(expected: impl Into<String>) -> Self {
43 Self::new(AuthStrategy::Validate(expected.into()))
44 }
45
46 pub fn strip() -> Self {
48 Self::new(AuthStrategy::Strip)
49 }
50}
51
52#[async_trait::async_trait]
53impl ProxyLayer for AuthLayer {
54 async fn on_request(&self, mut request: ProxyRequest) -> ProxyResult<LayerAction> {
55 match &self.strategy {
56 AuthStrategy::StaticBearer(token) => {
57 let value = HeaderValue::from_str(&format!("Bearer {token}"))
58 .map_err(|e| crate::error::ProxyError::Config(e.to_string()))?;
59 request.headers.insert(AUTHORIZATION, value);
60 Ok(LayerAction::Forward(request))
61 }
62 AuthStrategy::Passthrough => Ok(LayerAction::Forward(request)),
63 AuthStrategy::Validate(expected) => {
64 let expected_val = format!("Bearer {expected}");
65 match request.headers.get(AUTHORIZATION) {
66 Some(val) if val.as_bytes() == expected_val.as_bytes() => {
67 Ok(LayerAction::Forward(request))
68 }
69 _ => {
70 tracing::warn!(request_id = %request.id, "auth validation failed");
71 Ok(LayerAction::Respond(
72 ProxyResponse::for_request(request.id, StatusCode::UNAUTHORIZED)
73 .with_body("Unauthorized"),
74 ))
75 }
76 }
77 }
78 AuthStrategy::Strip => {
79 request.headers.remove(AUTHORIZATION);
80 Ok(LayerAction::Forward(request))
81 }
82 }
83 }
84
85 fn name(&self) -> &str {
86 "auth"
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;

    /// Builds a bare `GET /api` request with no `Authorization` header.
    fn make_request() -> ProxyRequest {
        ProxyRequest::new(Method::GET, "/api".parse().unwrap())
    }

    /// Builds a request carrying the raw `Authorization` value `value`.
    fn make_request_with_header(value: &str) -> ProxyRequest {
        let mut req = make_request();
        req.headers
            .insert(AUTHORIZATION, HeaderValue::from_str(value).unwrap());
        req
    }

    /// Builds a request carrying `Authorization: Bearer <token>`.
    fn make_request_with_auth(token: &str) -> ProxyRequest {
        make_request_with_header(&format!("Bearer {token}"))
    }

    /// Asserts that `layer` answers `req` with `401 Unauthorized` instead of forwarding it.
    async fn assert_rejected(layer: &AuthLayer, req: ProxyRequest) {
        match layer.on_request(req).await.unwrap() {
            LayerAction::Respond(resp) => assert_eq!(resp.status, StatusCode::UNAUTHORIZED),
            _ => panic!("expected reject"),
        }
    }

    #[tokio::test]
    async fn static_bearer_injects_token() {
        let layer = AuthLayer::static_bearer("sk-test-123");
        match layer.on_request(make_request()).await.unwrap() {
            LayerAction::Forward(req) => {
                assert_eq!(
                    req.headers.get(AUTHORIZATION).unwrap(),
                    "Bearer sk-test-123"
                );
            }
            _ => panic!("expected forward"),
        }
    }

    #[tokio::test]
    async fn static_bearer_overwrites_client_header() {
        let layer = AuthLayer::static_bearer("sk-test-123");
        let req = make_request_with_auth("client-token");
        match layer.on_request(req).await.unwrap() {
            LayerAction::Forward(req) => {
                assert_eq!(
                    req.headers.get(AUTHORIZATION).unwrap(),
                    "Bearer sk-test-123"
                );
            }
            _ => panic!("expected forward"),
        }
    }

    #[tokio::test]
    async fn passthrough_preserves_header() {
        let layer = AuthLayer::passthrough();
        let req = make_request_with_auth("my-token");
        match layer.on_request(req).await.unwrap() {
            LayerAction::Forward(req) => {
                assert_eq!(req.headers.get(AUTHORIZATION).unwrap(), "Bearer my-token");
            }
            _ => panic!("expected forward"),
        }
    }

    #[tokio::test]
    async fn validate_accepts_correct_token() {
        let layer = AuthLayer::validate("valid-token");
        let req = make_request_with_auth("valid-token");
        let result = layer.on_request(req).await.unwrap();
        assert!(matches!(result, LayerAction::Forward(_)));
    }

    #[tokio::test]
    async fn validate_rejects_wrong_token() {
        let layer = AuthLayer::validate("valid-token");
        assert_rejected(&layer, make_request_with_auth("wrong-token")).await;
    }

    #[tokio::test]
    async fn validate_rejects_missing_header() {
        let layer = AuthLayer::validate("valid-token");
        assert_rejected(&layer, make_request()).await;
    }

    #[tokio::test]
    async fn validate_rejects_token_with_extra_suffix() {
        let layer = AuthLayer::validate("valid-token");
        assert_rejected(&layer, make_request_with_auth("valid-token-extra")).await;
    }

    #[tokio::test]
    async fn validate_rejects_truncated_token() {
        let layer = AuthLayer::validate("valid-token");
        assert_rejected(&layer, make_request_with_auth("valid")).await;
    }

    #[tokio::test]
    async fn validate_rejects_non_bearer_scheme() {
        let layer = AuthLayer::validate("valid-token");
        assert_rejected(&layer, make_request_with_header("Basic valid-token")).await;
    }

    #[tokio::test]
    async fn strip_removes_auth_header() {
        let layer = AuthLayer::strip();
        let req = make_request_with_auth("remove-me");
        match layer.on_request(req).await.unwrap() {
            LayerAction::Forward(req) => {
                assert!(req.headers.get(AUTHORIZATION).is_none());
            }
            _ => panic!("expected forward"),
        }
    }

    #[test]
    fn layer_name_is_auth() {
        assert_eq!(AuthLayer::passthrough().name(), "auth");
    }
}