1use std::sync::Arc;
2
3use http::{HeaderName, HeaderValue, Request, Response, header::USER_AGENT};
4use hyper::body::{Body, Incoming};
5
6use crate::Error;
7
8const DEFAULT_USER_AGENT: &str = concat!("tailscale-rs/", env!("CARGO_PKG_VERSION"));
14
15fn origin_form_target(url: &url::Url) -> String {
23 match url.query() {
24 Some(query) => format!("{}?{}", url.path(), query),
25 None => url.path().to_owned(),
26 }
27}
28
/// A transport that sends a fully built HTTP request and resolves to the
/// response.
///
/// `B` is the request body type. It must be a `Send + 'static` [`Body`] whose
/// data chunks are `Send` and whose error type is `Send + Sync + 'static`.
pub trait Client<B>
where
    B: Body + Send + 'static,
    <B as Body>::Data: Send,
    B::Error: Send + Sync + 'static,
{
    /// Send `req`, resolving to a response whose body is hyper's streaming
    /// [`Incoming`] type, or to an [`Error`].
    ///
    /// The returned future is `Send`.
    // NOTE(review): requests built by `ClientExt` carry only an origin-form
    // target (path + query) as their URI. Implementations presumably pick the
    // destination connection some other way; confirm.
    fn send(
        &self,
        req: Request<B>,
    ) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send;
}
51
52pub trait ClientExt<B>: Client<B>
55where
56 B: Body + Send + 'static,
57 <B as Body>::Data: Send,
58 B::Error: Send + Sync + 'static,
59{
60 fn get(
66 &self,
67 url: &url::Url,
68 headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
69 ) -> impl Future<Output = Result<Response<Incoming>, Error>>
70 where
71 B: Default,
72 {
73 let mut req = Request::get(origin_form_target(url));
74
75 if let Some(hdrs) = req.headers_mut() {
76 hdrs.append(USER_AGENT, HeaderValue::from_static(DEFAULT_USER_AGENT));
77 hdrs.extend(crate::host_header(url));
78 hdrs.extend(headers);
79 }
80
81 async move {
82 let req = req.body(Default::default()).map_err(|e| {
83 tracing::error!(error = %e, "constructing request");
84 Error::InvalidInput
85 })?;
86
87 self.send(req).await
88 }
89 }
90
91 fn post(
97 &self,
98 url: &url::Url,
99 headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
100 body: B,
101 ) -> impl Future<Output = Result<Response<Incoming>, Error>> {
102 let mut req = Request::post(origin_form_target(url));
103
104 if let Some(hdrs) = req.headers_mut() {
105 hdrs.append(USER_AGENT, HeaderValue::from_static(DEFAULT_USER_AGENT));
106 hdrs.extend(crate::host_header(url));
107 hdrs.extend(headers);
108 }
109
110 async move {
111 let req = req.body(body).map_err(|e| {
112 tracing::error!(error = %e, "constructing request");
113 Error::InvalidInput
114 })?;
115
116 self.send(req).await
117 }
118 }
119}
120
// Blanket impl: every `Client<B>` gets the `ClientExt<B>` helpers.
// The body is empty because all of `ClientExt`'s methods have default bodies.
impl<T, B> ClientExt<B> for T
where
    T: Client<B>,
    B: Body + Send + 'static,
    <B as Body>::Data: Send,
    B::Error: Send + Sync + 'static,
{
}
129
130impl<T, B> Client<B> for Arc<T>
131where
132 T: Client<B>,
133 B: Body + Send + 'static,
134 <B as Body>::Data: Send,
135 B::Error: Send + Sync + 'static,
136{
137 fn send(
138 &self,
139 req: Request<B>,
140 ) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send {
141 self.as_ref().send(req)
142 }
143}
144
145impl<T, B> Client<B> for &T
146where
147 T: Client<B>,
148 B: Body + Send + 'static,
149 <B as Body>::Data: Send,
150 B::Error: Send + Sync + 'static,
151{
152 fn send(
153 &self,
154 req: Request<B>,
155 ) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send {
156 (**self).send(req)
157 }
158}
159
160impl<T, B> Client<B> for &mut T
161where
162 T: Client<B>,
163 B: Body + Send + 'static,
164 <B as Body>::Data: Send,
165 B::Error: Send + Sync + 'static,
166{
167 fn send(
168 &self,
169 req: Request<B>,
170 ) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send {
171 (**self).send(req)
172 }
173}
174
#[cfg(test)]
mod tests {
    use std::{
        cell::RefCell,
        pin::pin,
        rc::Rc,
        task::{Context, Poll, Waker},
    };

    use bytes::Bytes;
    use http_body_util::Empty;

    use super::*;

    /// The exact `User-Agent` value the request helpers are expected to send.
    const EXPECTED_USER_AGENT: &str = concat!("tailscale-rs/", env!("CARGO_PKG_VERSION"));

    /// Shared slot through which `CapturingClient` hands back the request head.
    type Captured = Rc<RefCell<Option<http::request::Parts>>>;

    /// Parse a URL literal; test inputs are always well-formed.
    fn parse_url(s: &str) -> url::Url {
        url::Url::parse(s).unwrap()
    }

    #[test]
    fn origin_form_target_no_query() {
        let target = origin_form_target(&parse_url("https://h/dir"));
        assert_eq!(target, "/dir");
    }

    #[test]
    fn origin_form_target_with_query() {
        let target = origin_form_target(&parse_url("https://h/dir?x=1&y=2"));
        assert_eq!(target, "/dir?x=1&y=2");
    }

    #[test]
    fn origin_form_target_root_path() {
        // A URL written without a path still targets the root.
        let target = origin_form_target(&parse_url("https://h"));
        assert_eq!(target, "/");
    }

    #[test]
    fn origin_form_target_excludes_fragment() {
        // The fragment must never appear in the request target.
        let target = origin_form_target(&parse_url("https://h/p#frag"));
        assert_eq!(target, "/p");
    }

    #[test]
    fn origin_form_target_is_never_absolute_form() {
        let cases = [
            "https://host.example/dir",
            "https://host.example/dir?x=1&y=2",
            "https://host.example",
            "https://host.example/p#frag",
            "https://host.example:14000/path?q=1",
            "http://host.example/",
        ];

        for u in cases {
            let target = origin_form_target(&parse_url(u));

            assert!(
                target.starts_with('/'),
                "origin-form target must start with '/': {u} -> {target}"
            );

            let absolute = target.starts_with("https://") || target.starts_with("http://");
            assert!(
                !absolute,
                "origin-form target must not be absolute-form: {u} -> {target}"
            );

            assert!(
                !target.contains("host.example"),
                "origin-form target must not contain the host: {u} -> {target}"
            );
        }
    }

    #[test]
    fn default_user_agent_is_crate_versioned_and_nonempty() {
        assert_eq!(DEFAULT_USER_AGENT, EXPECTED_USER_AGENT);
        assert!(!DEFAULT_USER_AGENT.is_empty());

        // The constant must also be usable verbatim as a header value.
        let header = HeaderValue::from_static(DEFAULT_USER_AGENT);
        assert_eq!(header, DEFAULT_USER_AGENT);
    }

    /// Test double: records the head of whatever request it receives, then
    /// fails immediately without doing any I/O.
    struct CapturingClient {
        seen: Captured,
    }

    impl Client<Empty<Bytes>> for CapturingClient {
        fn send(
            &self,
            req: Request<Empty<Bytes>>,
        ) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send {
            let (head, _body) = req.into_parts();
            *self.seen.borrow_mut() = Some(head);
            async { Err(Error::Io) }
        }
    }

    /// Poll `fut` once with a no-op waker and return its output.
    ///
    /// Panics if the future is still pending, so it is only suitable for
    /// futures that never suspend.
    fn drive_ready<F: Future>(fut: F) -> F::Output {
        let mut pinned = pin!(fut);
        let mut cx = Context::from_waker(Waker::noop());

        let Poll::Ready(out) = pinned.as_mut().poll(&mut cx) else {
            panic!("future did not complete on first poll");
        };
        out
    }

    /// Assert that a request was captured and that its first `User-Agent`
    /// value is the crate default.
    fn assert_default_user_agent(seen: &Captured) {
        let slot = seen.borrow();
        let parts = slot.as_ref().expect("request was sent");
        let user_agent = parts.headers.get(USER_AGENT).unwrap();
        assert_eq!(user_agent, EXPECTED_USER_AGENT);
    }

    #[test]
    fn get_appends_default_user_agent_header() {
        let seen = Captured::default();
        let client = CapturingClient {
            seen: Rc::clone(&seen),
        };
        let target = parse_url("https://h/dir");

        let outcome = drive_ready(client.get(&target, std::iter::empty()));
        assert!(outcome.is_err());

        assert_default_user_agent(&seen);
    }

    #[test]
    fn post_appends_default_user_agent_header() {
        let seen = Captured::default();
        let client = CapturingClient {
            seen: Rc::clone(&seen),
        };
        let target = parse_url("https://h/dir");

        let outcome = drive_ready(client.post(&target, std::iter::empty(), Empty::new()));
        assert!(outcome.is_err());

        assert_default_user_agent(&seen);
    }
}