1#![doc = include_str!("../examples/set_header.rs")]
13use std::{
18 fmt,
19 task::{Context, Poll},
20};
21
22use http::{HeaderName, HeaderValue};
23use tower_layer::Layer;
24use tower_service::Service;
25
26pub trait MakeHeaderValue<T> {
34 fn make_header_value(&mut self, message: &T) -> Option<HeaderValue>;
36}
37
38impl<F, T> MakeHeaderValue<T> for F
39where
40 F: FnMut(&T) -> Option<HeaderValue>,
41{
42 fn make_header_value(&mut self, message: &T) -> Option<HeaderValue> {
43 self(message)
44 }
45}
46
47impl<T> MakeHeaderValue<T> for HeaderValue {
48 fn make_header_value(&mut self, _message: &T) -> Option<HeaderValue> {
49 Some(self.clone())
50 }
51}
52
53impl<T> MakeHeaderValue<T> for Option<HeaderValue> {
54 fn make_header_value(&mut self, _message: &T) -> Option<HeaderValue> {
55 self.clone()
56 }
57}
58
/// Strategy used when writing a header value into a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InsertHeaderMode {
    /// Replace any values already stored under the header name.
    Override,
    /// Add the value alongside any values already stored under the header name.
    Append,
    /// Write the value only when the header name is not present yet.
    IfNotPresent,
}
65
66impl InsertHeaderMode {
67 fn apply<M>(self, header_name: &HeaderName, target: &mut reqwest::Request, make: &mut M)
68 where
69 M: MakeHeaderValue<reqwest::Request>,
70 {
71 match self {
72 InsertHeaderMode::Override => {
73 if let Some(value) = make.make_header_value(target) {
74 target.headers_mut().insert(header_name.clone(), value);
75 }
76 }
77 InsertHeaderMode::IfNotPresent => {
78 if !target.headers().contains_key(header_name)
79 && let Some(value) = make.make_header_value(target)
80 {
81 target.headers_mut().insert(header_name.clone(), value);
82 }
83 }
84 InsertHeaderMode::Append => {
85 if let Some(value) = make.make_header_value(target) {
86 target.headers_mut().append(header_name.clone(), value);
87 }
88 }
89 }
90 }
91}
92
93#[doc = include_str!("../examples/set_header.rs")]
101pub struct SetRequestHeaderLayer<M> {
103 header_name: HeaderName,
104 make: M,
105 mode: InsertHeaderMode,
106}
107
108impl<M> fmt::Debug for SetRequestHeaderLayer<M> {
109 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110 f.debug_struct("SetRequestHeaderLayer")
111 .field("header_name", &self.header_name)
112 .field("mode", &self.mode)
113 .field("make", &std::any::type_name::<M>())
114 .finish()
115 }
116}
117
118impl<M> SetRequestHeaderLayer<M>
119where
120 M: MakeHeaderValue<reqwest::Request>,
121{
122 pub fn overriding(header_name: HeaderName, make: M) -> Self {
127 Self::new(header_name, make, InsertHeaderMode::Override)
128 }
129
130 pub fn appending(header_name: HeaderName, make: M) -> Self {
135 Self::new(header_name, make, InsertHeaderMode::Append)
136 }
137
138 pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
142 Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
143 }
144
145 fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
146 Self {
147 header_name,
148 make,
149 mode,
150 }
151 }
152}
153
154impl<S, M> Layer<S> for SetRequestHeaderLayer<M>
155where
156 M: Clone,
157{
158 type Service = SetRequestHeader<S, M>;
159
160 fn layer(&self, inner: S) -> Self::Service {
161 SetRequestHeader {
162 inner,
163 header_name: self.header_name.clone(),
164 make: self.make.clone(),
165 mode: self.mode,
166 }
167 }
168}
169
170impl<M> Clone for SetRequestHeaderLayer<M>
171where
172 M: Clone,
173{
174 fn clone(&self) -> Self {
175 Self {
176 make: self.make.clone(),
177 header_name: self.header_name.clone(),
178 mode: self.mode,
179 }
180 }
181}
182
183#[derive(Clone)]
185pub struct SetRequestHeader<S, M> {
186 inner: S,
187 header_name: HeaderName,
188 make: M,
189 mode: InsertHeaderMode,
190}
191
192impl<S, M> SetRequestHeader<S, M> {
193 pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
198 Self::new(inner, header_name, make, InsertHeaderMode::Override)
199 }
200
201 pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
206 Self::new(inner, header_name, make, InsertHeaderMode::Append)
207 }
208
209 pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
213 Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
214 }
215
216 fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
217 Self {
218 inner,
219 header_name,
220 make,
221 mode,
222 }
223 }
224}
225
226impl<S, M> fmt::Debug for SetRequestHeader<S, M>
227where
228 S: fmt::Debug,
229{
230 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231 f.debug_struct("SetRequestHeader")
232 .field("inner", &self.inner)
233 .field("header_name", &self.header_name)
234 .field("mode", &self.mode)
235 .field("make", &std::any::type_name::<M>())
236 .finish()
237 }
238}
239
240impl<S, M> Service<reqwest::Request> for SetRequestHeader<S, M>
241where
242 S: Service<reqwest::Request, Response = reqwest::Response>,
243 M: MakeHeaderValue<reqwest::Request>,
244{
245 type Response = S::Response;
246 type Error = S::Error;
247 type Future = S::Future;
248
249 #[inline]
250 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
251 self.inner.poll_ready(cx)
252 }
253
254 fn call(&mut self, mut req: reqwest::Request) -> Self::Future {
255 self.mode.apply(&self.header_name, &mut req, &mut self.make);
256 self.inner.call(req)
257 }
258}
259
#[cfg(test)]
mod tests {
    use std::future::poll_fn;

    use http::{HeaderName, HeaderValue};
    use tower_layer::Layer;
    use tower_service::Service;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use crate::set_header::{InsertHeaderMode, SetRequestHeaderLayer};

    /// Header name shared by the offline `InsertHeaderMode` tests.
    fn test_header() -> HeaderName {
        HeaderName::from_static("x-test-header")
    }

    /// Builds a body-less GET request with no extra headers.
    fn bare_request() -> reqwest::Request {
        let url = reqwest::Url::parse("http://localhost/").expect("static URL is valid");
        reqwest::Request::new(reqwest::Method::GET, url)
    }

    /// Builds a body-less GET request that already carries `x-test-header: old`.
    fn request_with_old_value() -> reqwest::Request {
        let mut req = bare_request();
        req.headers_mut()
            .insert(test_header(), HeaderValue::from_static("old"));
        req
    }

    #[test]
    fn override_replaces_existing_values() {
        let name = test_header();
        let mut req = request_with_old_value();

        InsertHeaderMode::Override.apply(&name, &mut req, &mut HeaderValue::from_static("new"));

        let values: Vec<_> = req.headers().get_all(&name).iter().collect();
        assert_eq!(values, [&HeaderValue::from_static("new")]);
    }

    #[test]
    fn none_leaves_existing_value_untouched() {
        let name = test_header();
        let mut req = request_with_old_value();

        InsertHeaderMode::Override.apply(&name, &mut req, &mut None::<HeaderValue>);

        assert_eq!(
            req.headers().get(&name),
            Some(&HeaderValue::from_static("old"))
        );
    }

    #[test]
    fn append_keeps_existing_values() {
        let name = test_header();
        let mut req = request_with_old_value();

        InsertHeaderMode::Append.apply(&name, &mut req, &mut HeaderValue::from_static("new"));

        let values: Vec<_> = req.headers().get_all(&name).iter().collect();
        assert_eq!(
            values,
            [
                &HeaderValue::from_static("old"),
                &HeaderValue::from_static("new")
            ]
        );
    }

    #[test]
    fn if_not_present_skips_make_when_header_exists() {
        let name = test_header();
        let mut req = request_with_old_value();

        let mut calls = 0;
        let mut make = |_: &reqwest::Request| {
            calls += 1;
            Some(HeaderValue::from_static("new"))
        };
        InsertHeaderMode::IfNotPresent.apply(&name, &mut req, &mut make);

        assert_eq!(calls, 0);
        assert_eq!(
            req.headers().get(&name),
            Some(&HeaderValue::from_static("old"))
        );
    }

    #[test]
    fn if_not_present_inserts_when_absent() {
        let name = test_header();
        let mut req = bare_request();

        InsertHeaderMode::IfNotPresent.apply(
            &name,
            &mut req,
            &mut HeaderValue::from_static("new"),
        );

        assert_eq!(
            req.headers().get(&name),
            Some(&HeaderValue::from_static("new"))
        );
    }

    #[tokio::test]
    async fn test_set_headers() -> anyhow::Result<()> {
        let mock_server = MockServer::start().await;
        let mock_uri = mock_server.uri();

        let header_name = HeaderName::from_static("x-test-header");
        let header_value = HeaderValue::from_static("test-value");

        Mock::given(method("GET"))
            .and(path("/test"))
            .and(wiremock::matchers::header(&header_name, &header_value))
            .respond_with(ResponseTemplate::new(200))
            .mount(&mock_server)
            .await;

        let uri = format!("{mock_uri}/test");
        let request = reqwest::Request::new(reqwest::Method::GET, uri.parse()?);

        let client = reqwest::Client::new();
        // Without the layer the header is missing, the mock does not match, and
        // wiremock falls back to its default 404.
        let response = client
            .execute(
                request
                    .try_clone()
                    .expect("a body-less GET request is always cloneable"),
            )
            .await?;
        assert_eq!(response.status(), 404);

        let mut service =
            SetRequestHeaderLayer::overriding(header_name, header_value).layer(client);
        // The `Service` contract requires readiness before `call`.
        poll_fn(|cx| service.poll_ready(cx)).await?;
        let response = service.call(request).await?;
        assert_eq!(response.status(), 200);

        Ok(())
    }
}