1use std::sync::Arc;
2
3use http::{HeaderName, HeaderValue, Request, Response, header::USER_AGENT};
4use hyper::body::{Body, Incoming};
5
6use crate::Error;
7
8const DEFAULT_USER_AGENT: &str = concat!("tailscale-rs/", env!("CARGO_PKG_VERSION"));
14
15fn origin_form_target(url: &url::Url) -> String {
23 match url.query() {
24 Some(query) => format!("{}?{}", url.path(), query),
25 None => url.path().to_owned(),
26 }
27}
28
29pub trait Client<B>
36where
37 B: Body + Send + 'static,
38 <B as Body>::Data: Send,
39 B::Error: Send + Sync + 'static,
40{
41 fn send(
47 &self,
48 req: Request<B>,
49 ) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send;
50}
51
52pub trait ClientExt<B>: Client<B>
55where
56 B: Body + Send + 'static,
57 <B as Body>::Data: Send,
58 B::Error: Send + Sync + 'static,
59{
60 fn get(
66 &self,
67 url: &url::Url,
68 headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
69 ) -> impl Future<Output = Result<Response<Incoming>, Error>>
70 where
71 B: Default,
72 {
73 let mut req = Request::get(origin_form_target(url));
74
75 if let Some(hdrs) = req.headers_mut() {
76 hdrs.append(USER_AGENT, HeaderValue::from_static(DEFAULT_USER_AGENT));
77 hdrs.extend(crate::host_header(url));
78 hdrs.extend(headers);
79 }
80
81 async move {
82 let req = req.body(Default::default()).map_err(|e| {
83 tracing::error!(error = %e, "constructing request");
84 Error::InvalidInput
85 })?;
86
87 self.send(req).await
88 }
89 }
90
91 fn post(
97 &self,
98 url: &url::Url,
99 headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
100 body: B,
101 ) -> impl Future<Output = Result<Response<Incoming>, Error>> {
102 let mut req = Request::post(origin_form_target(url));
103
104 if let Some(hdrs) = req.headers_mut() {
105 hdrs.append(USER_AGENT, HeaderValue::from_static(DEFAULT_USER_AGENT));
106 hdrs.extend(crate::host_header(url));
107 hdrs.extend(headers);
108 }
109
110 async move {
111 let req = req.body(body).map_err(|e| {
112 tracing::error!(error = %e, "constructing request");
113 Error::InvalidInput
114 })?;
115
116 self.send(req).await
117 }
118 }
119
120 fn get_with_body(
128 &self,
129 url: &url::Url,
130 headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
131 body: B,
132 ) -> impl Future<Output = Result<Response<Incoming>, Error>> {
133 let mut req = Request::get(origin_form_target(url));
134
135 if let Some(hdrs) = req.headers_mut() {
136 hdrs.append(USER_AGENT, HeaderValue::from_static(DEFAULT_USER_AGENT));
137 hdrs.extend(crate::host_header(url));
138 hdrs.extend(headers);
139 }
140
141 async move {
142 let req = req.body(body).map_err(|e| {
143 tracing::error!(error = %e, "constructing request");
144 Error::InvalidInput
145 })?;
146
147 self.send(req).await
148 }
149 }
150}
151
152impl<T, B> ClientExt<B> for T
153where
154 T: Client<B>,
155 B: Body + Send + 'static,
156 <B as Body>::Data: Send,
157 B::Error: Send + Sync + 'static,
158{
159}
160
161impl<T, B> Client<B> for Arc<T>
162where
163 T: Client<B>,
164 B: Body + Send + 'static,
165 <B as Body>::Data: Send,
166 B::Error: Send + Sync + 'static,
167{
168 fn send(
169 &self,
170 req: Request<B>,
171 ) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send {
172 self.as_ref().send(req)
173 }
174}
175
176impl<T, B> Client<B> for &T
177where
178 T: Client<B>,
179 B: Body + Send + 'static,
180 <B as Body>::Data: Send,
181 B::Error: Send + Sync + 'static,
182{
183 fn send(
184 &self,
185 req: Request<B>,
186 ) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send {
187 (**self).send(req)
188 }
189}
190
191impl<T, B> Client<B> for &mut T
192where
193 T: Client<B>,
194 B: Body + Send + 'static,
195 <B as Body>::Data: Send,
196 B::Error: Send + Sync + 'static,
197{
198 fn send(
199 &self,
200 req: Request<B>,
201 ) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send {
202 (**self).send(req)
203 }
204}
205
#[cfg(test)]
mod tests {
    use std::{
        cell::RefCell,
        pin::pin,
        rc::Rc,
        task::{Context, Poll, Waker},
    };

    use bytes::Bytes;
    use http_body_util::Empty;

    use super::*;

    /// The `User-Agent` value requests are expected to carry by default.
    const EXPECTED_USER_AGENT: &str = concat!("tailscale-rs/", env!("CARGO_PKG_VERSION"));

    /// Parses a known-good test URL.
    fn url(s: &str) -> url::Url {
        url::Url::parse(s).unwrap()
    }

    #[test]
    fn origin_form_target_no_query() {
        let target = origin_form_target(&url("https://h/dir"));
        assert_eq!(target, "/dir");
    }

    #[test]
    fn origin_form_target_with_query() {
        let target = origin_form_target(&url("https://h/dir?x=1&y=2"));
        assert_eq!(target, "/dir?x=1&y=2");
    }

    #[test]
    fn origin_form_target_root_path() {
        // A URL with no explicit path still yields the root target.
        let target = origin_form_target(&url("https://h"));
        assert_eq!(target, "/");
    }

    #[test]
    fn origin_form_target_excludes_fragment() {
        let target = origin_form_target(&url("https://h/p#frag"));
        assert_eq!(target, "/p");
    }

    #[test]
    fn origin_form_target_is_never_absolute_form() {
        let inputs = [
            "https://host.example/dir",
            "https://host.example/dir?x=1&y=2",
            "https://host.example",
            "https://host.example/p#frag",
            "https://host.example:14000/path?q=1",
            "http://host.example/",
        ];

        for u in inputs {
            let target = origin_form_target(&url(u));

            assert!(
                target.starts_with('/'),
                "origin-form target must start with '/': {u} -> {target}"
            );

            let absolute = target.starts_with("https://") || target.starts_with("http://");
            assert!(
                !absolute,
                "origin-form target must not be absolute-form: {u} -> {target}"
            );

            assert!(
                !target.contains("host.example"),
                "origin-form target must not contain the host: {u} -> {target}"
            );
        }
    }

    #[test]
    fn default_user_agent_is_crate_versioned_and_nonempty() {
        assert_eq!(DEFAULT_USER_AGENT, EXPECTED_USER_AGENT);
        assert!(!DEFAULT_USER_AGENT.is_empty());

        // The constant must be usable as a static header value without panicking.
        let header = HeaderValue::from_static(DEFAULT_USER_AGENT);
        assert_eq!(header, DEFAULT_USER_AGENT);
    }

    /// A [`Client`] that records the head of the request it receives, then fails
    /// immediately with [`Error::Io`], so no network I/O takes place.
    struct CapturingClient {
        seen: Rc<RefCell<Option<http::request::Parts>>>,
    }

    impl Client<Empty<Bytes>> for CapturingClient {
        fn send(
            &self,
            req: Request<Empty<Bytes>>,
        ) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send {
            let (parts, _) = req.into_parts();
            *self.seen.borrow_mut() = Some(parts);
            async { Err(Error::Io) }
        }
    }

    /// Polls `fut` exactly once with a no-op waker and returns its output, panicking if it is
    /// not immediately ready.
    fn drive_ready<F: Future>(fut: F) -> F::Output {
        let mut fut = pin!(fut);
        let mut cx = Context::from_waker(Waker::noop());
        let Poll::Ready(out) = fut.as_mut().poll(&mut cx) else {
            panic!("future did not complete on first poll");
        };
        out
    }

    /// Asserts that the request captured in `seen` carries [`EXPECTED_USER_AGENT`].
    fn assert_default_user_agent(seen: &RefCell<Option<http::request::Parts>>) {
        let captured = seen.borrow();
        let parts = captured.as_ref().expect("request was sent");
        let user_agent = parts.headers.get(USER_AGENT).unwrap();
        assert_eq!(user_agent, EXPECTED_USER_AGENT);
    }

    #[test]
    fn get_appends_default_user_agent_header() {
        let seen = Rc::new(RefCell::new(None));
        let client = CapturingClient {
            seen: Rc::clone(&seen),
        };

        let result = drive_ready(client.get(&url("https://h/dir"), std::iter::empty()));
        assert!(result.is_err());

        assert_default_user_agent(&seen);
    }

    #[test]
    fn post_appends_default_user_agent_header() {
        let seen = Rc::new(RefCell::new(None));
        let client = CapturingClient {
            seen: Rc::clone(&seen),
        };

        let result = drive_ready(client.post(
            &url("https://h/dir"),
            std::iter::empty(),
            Empty::new(),
        ));
        assert!(result.is_err());

        assert_default_user_agent(&seen);
    }
}