1use std::sync::Arc;
21
22use async_trait::async_trait;
23use reqwest::RequestBuilder;
24
25use crate::error::AuthError;
26
27#[async_trait]
36pub trait OutboundAuthProvider: Send + Sync {
37 async fn authorize(
42 &self,
43 request: RequestBuilder,
44 audience: &str,
45 ) -> Result<RequestBuilder, AuthError>;
46}
47
/// Provider that adds no credentials at all.
///
/// Used when the downstream service does not require authentication; it is
/// the fallback chosen by `provider_from_token` when no usable token exists.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoOpOutboundAuthProvider;
54
55#[async_trait]
56impl OutboundAuthProvider for NoOpOutboundAuthProvider {
57 async fn authorize(
58 &self,
59 request: RequestBuilder,
60 _audience: &str,
61 ) -> Result<RequestBuilder, AuthError> {
62 Ok(request)
63 }
64}
65
/// Provider that attaches one fixed bearer token to every request.
///
/// The token is stored as a plain `String`. `Debug` is implemented by hand so
/// that formatting this provider (directly, or as part of a larger config or
/// state struct that gets logged) never reveals the credential.
#[derive(Clone)]
pub struct StaticBearerOutboundAuthProvider {
    token: String,
}

impl std::fmt::Debug for StaticBearerOutboundAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A derived Debug would print the secret verbatim; show a placeholder.
        f.debug_struct("StaticBearerOutboundAuthProvider")
            .field("token", &"<redacted>")
            .finish()
    }
}
77
78impl StaticBearerOutboundAuthProvider {
79 pub fn new(token: impl Into<String>) -> Self {
81 Self {
82 token: token.into(),
83 }
84 }
85}
86
87#[async_trait]
88impl OutboundAuthProvider for StaticBearerOutboundAuthProvider {
89 async fn authorize(
90 &self,
91 request: RequestBuilder,
92 _audience: &str,
93 ) -> Result<RequestBuilder, AuthError> {
94 Ok(request.header("Authorization", format!("Bearer {}", self.token)))
95 }
96}
97
98pub fn provider_from_token(token: Option<&str>) -> Arc<dyn OutboundAuthProvider> {
103 match token.filter(|t| !t.trim().is_empty()) {
104 Some(t) => Arc::new(StaticBearerOutboundAuthProvider::new(t.to_string())),
105 None => Arc::new(NoOpOutboundAuthProvider),
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::Client;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[tokio::test]
    async fn noop_provider_does_not_modify_request() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/x"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let client = Client::new();
        let request = client.post(format!("{}/x", server.uri()));
        let request = NoOpOutboundAuthProvider
            .authorize(request, &server.uri())
            .await
            .unwrap();

        let response = request.send().await.unwrap();
        assert!(response.status().is_success());

        let received = &server.received_requests().await.unwrap()[0];
        assert!(received.headers.get("authorization").is_none());
    }

    #[tokio::test]
    async fn static_bearer_provider_adds_authorization_header() {
        let server = MockServer::start().await;
        // The mock only answers 200 when the exact bearer header is present;
        // otherwise wiremock falls back to 404.
        Mock::given(method("POST"))
            .and(path("/x"))
            .and(header("Authorization", "Bearer test-token"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let provider = StaticBearerOutboundAuthProvider::new("test-token");
        let client = Client::new();
        let request = client.post(format!("{}/x", server.uri()));
        let request = provider.authorize(request, &server.uri()).await.unwrap();

        let response = request.send().await.unwrap();
        assert!(
            response.status().is_success(),
            "request reached the matcher with bearer token"
        );
    }

    #[tokio::test]
    async fn static_bearer_appends_alongside_existing_headers() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/x"))
            .and(header("X-Custom", "value"))
            .and(header("Authorization", "Bearer abc"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let provider = StaticBearerOutboundAuthProvider::new("abc");
        let client = Client::new();
        let request = client
            .post(format!("{}/x", server.uri()))
            .header("X-Custom", "value");
        let request = provider.authorize(request, &server.uri()).await.unwrap();

        let response = request.send().await.unwrap();
        assert!(response.status().is_success());
    }

    #[tokio::test]
    async fn provider_from_token_returns_noop_when_none() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/y"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let provider = provider_from_token(None);
        let client = Client::new();
        let request = client.post(format!("{}/y", server.uri()));
        let request = provider.authorize(request, &server.uri()).await.unwrap();

        let response = request.send().await.unwrap();
        assert!(response.status().is_success());

        // A no-op provider must not have added any credentials.
        let received = &server.received_requests().await.unwrap()[0];
        assert!(received.headers.get("authorization").is_none());
    }

    #[tokio::test]
    async fn provider_from_token_returns_static_when_some() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/y"))
            .and(header("Authorization", "Bearer xyz"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let provider = provider_from_token(Some("xyz"));
        let client = Client::new();
        let request = client.post(format!("{}/y", server.uri()));
        let request = provider.authorize(request, &server.uri()).await.unwrap();

        assert!(request.send().await.unwrap().status().is_success());
    }

    #[tokio::test]
    async fn provider_from_token_treats_empty_string_as_none() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/y"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        // Both a truly empty token and a whitespace-only one must be ignored.
        let client = Client::new();
        for token in ["", " "] {
            let provider = provider_from_token(Some(token));
            let request = client.post(format!("{}/y", server.uri()));
            let request = provider.authorize(request, &server.uri()).await.unwrap();
            request.send().await.unwrap();
        }

        let received = server.received_requests().await.unwrap();
        assert_eq!(received.len(), 2);
        assert!(received
            .iter()
            .all(|r| r.headers.get("authorization").is_none()));
    }
}