use std::sync::Arc;
use http::{HeaderName, HeaderValue, Request, Response, header::USER_AGENT};
use hyper::body::{Body, Incoming};
use crate::Error;
/// Default `User-Agent` header value for requests built by [`ClientExt`]:
/// `tailscale-rs/` followed by this crate's package version
/// (`CARGO_PKG_VERSION`, resolved at compile time).
const DEFAULT_USER_AGENT: &str = concat!("tailscale-rs/", env!("CARGO_PKG_VERSION"));
/// Builds the origin-form request target for `url`: its path, followed by
/// `?` and the query string when `url` has one.
///
/// Scheme, authority and fragment are never part of the result, so the
/// target is suitable for a request line sent directly to the origin server.
fn origin_form_target(url: &url::Url) -> String {
    let mut target = url.path().to_owned();
    if let Some(query) = url.query() {
        target.push('?');
        target.push_str(query);
    }
    target
}
pub trait Client<B>
where
B: Body + Send + 'static,
<B as Body>::Data: Send,
B::Error: Send + Sync + 'static,
{
fn send(
&self,
req: Request<B>,
) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send;
}
pub trait ClientExt<B>: Client<B>
where
B: Body + Send + 'static,
<B as Body>::Data: Send,
B::Error: Send + Sync + 'static,
{
fn get(
&self,
url: &url::Url,
headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
) -> impl Future<Output = Result<Response<Incoming>, Error>>
where
B: Default,
{
let mut req = Request::get(origin_form_target(url));
if let Some(hdrs) = req.headers_mut() {
hdrs.append(USER_AGENT, HeaderValue::from_static(DEFAULT_USER_AGENT));
hdrs.extend(crate::host_header(url));
hdrs.extend(headers);
}
async move {
let req = req.body(Default::default()).map_err(|e| {
tracing::error!(error = %e, "constructing request");
Error::InvalidInput
})?;
self.send(req).await
}
}
fn post(
&self,
url: &url::Url,
headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
body: B,
) -> impl Future<Output = Result<Response<Incoming>, Error>> {
let mut req = Request::post(origin_form_target(url));
if let Some(hdrs) = req.headers_mut() {
hdrs.append(USER_AGENT, HeaderValue::from_static(DEFAULT_USER_AGENT));
hdrs.extend(crate::host_header(url));
hdrs.extend(headers);
}
async move {
let req = req.body(body).map_err(|e| {
tracing::error!(error = %e, "constructing request");
Error::InvalidInput
})?;
self.send(req).await
}
}
}
/// Every [`Client`] automatically gets the [`ClientExt`] convenience methods.
impl<T, B> ClientExt<B> for T
where
    B: Body + Send + 'static,
    B::Data: Send,
    B::Error: Send + Sync + 'static,
    T: Client<B>,
{
}
impl<T, B> Client<B> for Arc<T>
where
T: Client<B>,
B: Body + Send + 'static,
<B as Body>::Data: Send,
B::Error: Send + Sync + 'static,
{
fn send(
&self,
req: Request<B>,
) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send {
self.as_ref().send(req)
}
}
impl<T, B> Client<B> for &T
where
T: Client<B>,
B: Body + Send + 'static,
<B as Body>::Data: Send,
B::Error: Send + Sync + 'static,
{
fn send(
&self,
req: Request<B>,
) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send {
(**self).send(req)
}
}
impl<T, B> Client<B> for &mut T
where
T: Client<B>,
B: Body + Send + 'static,
<B as Body>::Data: Send,
B::Error: Send + Sync + 'static,
{
fn send(
&self,
req: Request<B>,
) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send {
(**self).send(req)
}
}
#[cfg(test)]
mod tests {
    use std::{
        cell::RefCell,
        pin::pin,
        rc::Rc,
        task::{Context, Poll, Waker},
    };

    use bytes::Bytes;
    use http_body_util::Empty;

    use super::*;

    /// Parses a URL literal used as test input; a malformed literal is a test bug.
    fn parse_url(text: &str) -> url::Url {
        url::Url::parse(text).unwrap()
    }

    #[test]
    fn origin_form_target_no_query() {
        let target = origin_form_target(&parse_url("https://h/dir"));
        assert_eq!(target, "/dir");
    }

    #[test]
    fn origin_form_target_with_query() {
        let target = origin_form_target(&parse_url("https://h/dir?x=1&y=2"));
        assert_eq!(target, "/dir?x=1&y=2");
    }

    #[test]
    fn origin_form_target_root_path() {
        let target = origin_form_target(&parse_url("https://h"));
        assert_eq!(target, "/");
    }

    #[test]
    fn origin_form_target_excludes_fragment() {
        let target = origin_form_target(&parse_url("https://h/p#frag"));
        assert_eq!(target, "/p");
    }

    #[test]
    fn origin_form_target_is_never_absolute_form() {
        const CASES: [&str; 6] = [
            "https://host.example/dir",
            "https://host.example/dir?x=1&y=2",
            "https://host.example",
            "https://host.example/p#frag",
            "https://host.example:14000/path?q=1",
            "http://host.example/",
        ];
        CASES.iter().copied().for_each(|u| {
            let target = origin_form_target(&parse_url(u));
            assert!(
                target.starts_with('/'),
                "origin-form target must start with '/': {u} -> {target}"
            );
            let absolute = target.starts_with("https://") || target.starts_with("http://");
            assert!(
                !absolute,
                "origin-form target must not be absolute-form: {u} -> {target}"
            );
            assert!(
                !target.contains("host.example"),
                "origin-form target must not contain the host: {u} -> {target}"
            );
        });
    }

    #[test]
    fn default_user_agent_is_crate_versioned_and_nonempty() {
        let expected = concat!("tailscale-rs/", env!("CARGO_PKG_VERSION"));
        assert_eq!(DEFAULT_USER_AGENT, expected);
        assert!(!DEFAULT_USER_AGENT.is_empty());
        assert_eq!(HeaderValue::from_static(DEFAULT_USER_AGENT), DEFAULT_USER_AGENT);
    }

    /// Test double that records the head of the request it is asked to send
    /// and then resolves to `Err(Error::Io)`.
    struct CapturingClient {
        captured: Rc<RefCell<Option<http::request::Parts>>>,
    }

    impl Client<Empty<Bytes>> for CapturingClient {
        fn send(
            &self,
            req: Request<Empty<Bytes>>,
        ) -> impl Future<Output = Result<Response<Incoming>, Error>> + Send {
            let (head, _) = req.into_parts();
            *self.captured.borrow_mut() = Some(head);
            async { Err(Error::Io) }
        }
    }

    /// Polls `fut` once with a no-op waker and returns its output.
    ///
    /// Panics if the future is still pending after that single poll.
    fn drive_ready<F: Future>(fut: F) -> F::Output {
        let mut fut = pin!(fut);
        let mut cx = Context::from_waker(Waker::noop());
        let Poll::Ready(out) = fut.as_mut().poll(&mut cx) else {
            panic!("future did not complete on first poll");
        };
        out
    }

    /// Runs `send` against a fresh [`CapturingClient`], checks that the call
    /// failed (as the fake always does), and returns the recorded `User-Agent`.
    fn sent_user_agent(
        send: impl FnOnce(&CapturingClient) -> Result<Response<Incoming>, Error>,
    ) -> HeaderValue {
        let captured = Rc::new(RefCell::new(None));
        let client = CapturingClient {
            captured: Rc::clone(&captured),
        };
        assert!(send(&client).is_err());
        let guard = captured.borrow();
        let head = guard.as_ref().expect("request was sent");
        head.headers.get(USER_AGENT).unwrap().clone()
    }

    #[test]
    fn get_appends_default_user_agent_header() {
        let agent = sent_user_agent(|client| {
            drive_ready(client.get(&parse_url("https://h/dir"), std::iter::empty()))
        });
        assert_eq!(agent, concat!("tailscale-rs/", env!("CARGO_PKG_VERSION")));
    }

    #[test]
    fn post_appends_default_user_agent_header() {
        let agent = sent_user_agent(|client| {
            drive_ready(client.post(
                &parse_url("https://h/dir"),
                std::iter::empty(),
                Empty::new(),
            ))
        });
        assert_eq!(agent, concat!("tailscale-rs/", env!("CARGO_PKG_VERSION")));
    }
}