use crate::retry::Backoff;
use crate::util::collect_scopes;
use crate::{Credentials, Retry};
use anyhow::{anyhow, bail, Context, Error, Result};
use reqwest::header::HeaderValue;
use serde_json::json;
use serde_json::Value;
use std::iter::IntoIterator;
use std::str::FromStr;
use std::time::Duration;
/// Collects the settings needed to construct a [`Client`].
///
/// Create one with [`ClientBuilder::new`] (or `From<String>` / `From<&str>`), adjust it with
/// the chained setter methods, and finish with [`ClientBuilder::build`].
#[derive(Clone, Debug, Default)]
pub struct ClientBuilder {
    /// Root URL of the deployment; must contain a host and a scheme with a known port.
    root_url: String,
    /// Retry policy applied to transport errors and 5xx responses.
    retry: Retry,
    /// Credentials used to Hawk-sign requests; `None` sends unsigned requests.
    credentials: Option<Credentials>,
    /// Joined onto `root_url` to form the base URL that request paths are resolved against.
    path_prefix: Option<String>,
    /// Scopes sent as `authorizedScopes` in the Hawk `ext` payload, when set.
    authorized_scopes: Option<Vec<String>>,
    /// Timeout handed to the underlying HTTP client.
    timeout: Duration,
}
impl ClientBuilder {
    /// Start a builder for the deployment at `root_url`.
    ///
    /// The request timeout starts at 30 seconds; every other setting starts at its default.
    pub fn new<S: Into<String>>(root_url: S) -> Self {
        let mut builder = Self::default();
        builder.root_url = root_url.into();
        builder.timeout = Duration::from_secs(30);
        builder
    }

    /// Sign every request made by the built client with `credentials`.
    pub fn credentials(self, credentials: Credentials) -> Self {
        Self {
            credentials: Some(credentials),
            ..self
        }
    }

    /// Replace the retry policy.
    pub fn retry(self, retry: Retry) -> Self {
        Self { retry, ..self }
    }

    /// Replace the request timeout.
    pub fn timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Set a path (which must end in `/`) to be joined onto the root URL.
    pub(crate) fn path_prefix<S: Into<String>>(self, path_prefix: S) -> Self {
        let prefix: String = path_prefix.into();
        // `Url::join` only appends to a base whose path ends in `/`; without it, the last
        // segment of the prefix would be replaced by the request path.
        debug_assert!(prefix.ends_with('/'));
        Self {
            path_prefix: Some(prefix),
            ..self
        }
    }

    /// Restrict the client to the given scopes (sent as `authorizedScopes` in the Hawk ext).
    pub fn authorized_scopes(
        self,
        authorized_scopes: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> Self {
        Self {
            authorized_scopes: Some(collect_scopes(authorized_scopes)),
            ..self
        }
    }

    /// Consume the builder and construct the [`Client`].
    ///
    /// # Errors
    ///
    /// Propagates any error from client construction (bad root URL, bad certificate, ...).
    pub fn build(self) -> Result<Client> {
        Client::new(self)
    }
}
impl From<String> for ClientBuilder {
fn from(root_url: String) -> Self {
Self::new(root_url)
}
}
impl From<&str> for ClientBuilder {
fn from(root_url: &str) -> Self {
Self::new(root_url)
}
}
/// An HTTP client that resolves paths against a base URL, optionally Hawk-signs requests,
/// and retries transient failures.
pub struct Client {
    /// Root URL with the optional path prefix joined on; request paths resolve against it.
    base_url: reqwest::Url,
    /// Host of the root URL, as covered by Hawk signatures.
    host: String,
    /// Port of the root URL (explicit or scheme default), as covered by Hawk signatures.
    port: u16,
    /// Underlying HTTP client (redirects disabled, configured timeout).
    client: reqwest::Client,
    /// Retry policy for transport errors and 5xx responses.
    retry: Retry,
    /// Hawk credentials used for signing; `None` means requests are sent unsigned.
    credentials: Option<hawk::Credentials>,
    /// Base64url (unpadded) JSON carrying `certificate` and/or `authorizedScopes`, sent as
    /// the Hawk `ext` value.
    ext: Option<String>,
}
impl Client {
fn new(b: ClientBuilder) -> Result<Client> {
let mut base_url = reqwest::Url::parse(b.root_url.as_ref())
.context(format!("while parsing {}", b.root_url))?;
let host = base_url
.host_str()
.ok_or_else(|| anyhow!("The root URL {} doesn't contain a host", b.root_url))?
.to_owned();
let port = base_url
.port_or_known_default()
.ok_or_else(|| anyhow!("Unkown port for protocol {}", base_url.scheme()))?;
if let Some(path_prefix) = b.path_prefix {
base_url = base_url.join(path_prefix.as_ref()).context(format!(
"while adding path_prefix to root_url {}",
b.root_url
))?;
}
let retry = b.retry;
let timeout = b.timeout;
let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.timeout(timeout)
.build()?;
let mut certificate: Option<Value> = None;
if let Some(Credentials {
certificate: Some(ref cert_str),
..
}) = b.credentials
{
certificate = Some(
serde_json::from_str(cert_str)
.context("while parsing given certificate as JSON")?,
);
}
let mut authorized_scopes: Option<Value> = None;
if let Some(scopes) = b.authorized_scopes {
authorized_scopes = Some(scopes.into());
}
let ext_json = match (certificate, authorized_scopes) {
(Some(c), None) => Some(json!({ "certificate": c })),
(None, Some(s)) => Some(json!({ "authorizedScopes": s })),
(Some(c), Some(s)) => Some(json!({ "certificate": c, "authorizedScopes": s })),
(None, None) => None,
};
let ext = if let Some(ext) = ext_json {
let ext_str = serde_json::to_string(&ext)?;
Some(base64::encode_config(ext_str, base64::URL_SAFE_NO_PAD))
} else {
None
};
let credentials = match b.credentials {
None => None,
Some(c) => Some(hawk::Credentials {
id: c.client_id.clone(),
key: hawk::Key::new(&c.access_token, hawk::SHA256).context(c.client_id)?,
}),
};
Ok(Client {
credentials,
ext,
retry,
base_url,
host,
port,
client,
})
}
pub async fn request(
&self,
method: &str,
path: &str,
query: Option<Vec<(&str, &str)>>,
body: Option<&Value>,
) -> Result<reqwest::Response, Error> {
let mut backoff = Backoff::new(&self.retry);
let req = self.build_request(method, path, query, body)?;
let url = req.url().as_str();
let mut retries = self.retry.retries;
loop {
let req = req
.try_clone()
.ok_or_else(|| anyhow!("Cannot clone the request {}", url))?;
let retry_for;
match self.client.execute(req).await {
Err(e) => {
retry_for = e;
}
Ok(resp) if resp.status().is_server_error() => {
retry_for = resp.error_for_status().err().unwrap();
}
Ok(resp) if resp.status().is_client_error() => {
let err = resp.error_for_status_ref().err().unwrap();
if let Ok(json) = resp.json::<Value>().await {
if let Some(message) = json.get("message") {
if let Some(s) = message.as_str() {
return Err(Error::from(err).context(s.to_owned()));
}
}
}
return Err(err.into());
}
Ok(resp) => {
return Ok(resp);
}
};
if retries == 0 {
return Err(retry_for.into());
}
retries -= 1;
match backoff.next_backoff() {
Some(duration) => tokio::time::sleep(duration).await,
None => return Err(retry_for.into()),
}
}
}
fn build_request(
&self,
method: &str,
path: &str,
query: Option<Vec<(&str, &str)>>,
body: Option<&Value>,
) -> Result<reqwest::Request, Error> {
if path.starts_with('/') {
bail!("Request path must not begin with `/`");
}
let mut url = self.base_url.join(path)?;
if let Some(q) = query {
url.query_pairs_mut().extend_pairs(q);
}
let meth = reqwest::Method::from_str(method)?;
let req = self.client.request(meth, url);
let req = match body {
Some(b) => req.json(&b),
None => req,
};
let req = req.build()?;
match self.credentials {
Some(ref creds) => self.sign_request(creds, req),
None => Ok(req),
}
}
fn sign_request(
&self,
creds: &hawk::Credentials,
req: reqwest::Request,
) -> Result<reqwest::Request, Error> {
let mut signed_req_builder = hawk::RequestBuilder::new(
req.method().as_str(),
&self.host,
self.port,
req.url().path(),
);
let payload_hash;
if let Some(ref b) = req.body() {
let b = b
.as_bytes()
.ok_or_else(|| anyhow!("stream request bodies are not supported"))?;
payload_hash = hawk::PayloadHasher::hash("application/json", hawk::SHA256, b)?;
signed_req_builder = signed_req_builder.hash(&payload_hash[..])
}
signed_req_builder = signed_req_builder.ext(self.ext.as_ref().map(|s| s.as_ref()));
let header = signed_req_builder.request().make_header(&creds)?;
let token = HeaderValue::from_str(format!("Hawk {}", header).as_str()).context(header)?;
let mut req = req;
req.headers_mut().insert("Authorization", token);
Ok(req)
}
pub fn make_url(&self, path: &str, query: Option<Vec<(&str, &str)>>) -> Result<String> {
if path.starts_with('/') {
bail!("Request path must not begin with `/`");
}
let mut url = self.base_url.join(path)?;
if let Some(q) = query {
url.query_pairs_mut().extend_pairs(q);
}
Ok(url.as_ref().to_owned())
}
pub fn make_signed_url(
&self,
path: &str,
query: Option<Vec<(&str, &str)>>,
ttl: Duration,
) -> Result<String> {
if path.starts_with('/') {
bail!("Request path must not begin with `/`");
}
let creds = if let Some(ref creds) = self.credentials {
creds
} else {
return Err(anyhow!("Cannot sign a URL without credentials"));
};
let mut url = self.base_url.join(path)?;
if let Some(q) = query {
url.query_pairs_mut().extend_pairs(q);
}
let path_with_query = match url.query() {
Some(q) => format!("{}?{}", url.path(), q),
None => url.path().to_owned(),
};
let req = hawk::RequestBuilder::new("GET", &self.host, self.port, &path_with_query)
.ext(self.ext.as_ref().map(|s| s.as_ref()))
.request();
let bewit = req.make_bewit_with_ttl(creds, ttl)?;
url.query_pairs_mut().append_pair("bewit", &bewit.to_str());
Ok(url.as_ref().to_owned())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::err_status_code;
    use anyhow::bail;
    use httptest::{matchers::*, responders::*, Expectation, Server};
    use serde_json::json;
    use std::fmt;
    use std::net::SocketAddr;
    use std::time::Duration;

    /// Matcher that accepts requests carrying a valid Hawk header for `creds`, as computed
    /// for a server listening on `addr`.
    pub fn signed_with(creds: Credentials, addr: SocketAddr) -> SignedWith {
        SignedWith(creds, addr)
    }

    /// See [`signed_with`].
    #[derive(Debug)]
    pub struct SignedWith(Credentials, SocketAddr);

    impl<B> Matcher<httptest::http::Request<B>> for SignedWith {
        fn matches(
            &mut self,
            input: &httptest::http::Request<B>,
            _ctx: &mut ExecutionContext,
        ) -> bool {
            // A missing or malformed header is a non-match rather than a panic inside the
            // test server, so the failure shows up as an unmatched request.
            let auth_header = match input
                .headers()
                .get(httptest::http::header::AUTHORIZATION)
                .and_then(|value| value.to_str().ok())
            {
                Some(value) => value,
                None => {
                    println!("Authorization header is missing or not valid ASCII");
                    return false;
                }
            };
            let auth_header = match auth_header.strip_prefix("Hawk ") {
                Some(rest) => rest,
                None => {
                    println!("Authorization header does not start with Hawk");
                    return false;
                }
            };
            let auth_header: hawk::Header = match auth_header.parse() {
                Ok(header) => header,
                Err(_) => {
                    println!("Authorization header is not a valid Hawk header");
                    return false;
                }
            };
            let host = self.1.ip().to_string();
            // Hawk signs the request resource: the path plus any query string.
            let resource = input
                .uri()
                .path_and_query()
                .map(|pq| pq.as_str())
                .unwrap_or_else(|| input.uri().path());
            let hawk_req = hawk::RequestBuilder::new(
                input.method().as_str(),
                &host,
                self.1.port(),
                resource,
            )
            .request();
            let key = hawk::Key::new(&self.0.access_token, hawk::SHA256).unwrap();
            if !hawk_req.validate_header(&auth_header, &key, Duration::from_secs(60)) {
                println!("Validation failed");
                return false;
            }
            true
        }

        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            <Self as fmt::Debug>::fmt(self, f)
        }
    }

    /// Decode the client's Hawk `ext` and return its `authorizedScopes`.
    fn get_authorized_scopes(client: &Client) -> Result<Vec<String>> {
        let ext = if let Some(ref ext) = client.ext {
            ext
        } else {
            bail!("client has no ext")
        };
        // Decode with the same alphabet the client encodes with (URL-safe, unpadded).
        let ext = base64::decode_config(ext, base64::URL_SAFE_NO_PAD)?;
        #[derive(serde::Deserialize)]
        #[serde(rename_all = "camelCase")]
        struct Certificate {
            authorized_scopes: Vec<String>,
        }
        let ext = serde_json::from_slice::<Certificate>(&ext)?;
        Ok(ext.authorized_scopes)
    }

    #[test]
    fn test_authorized_scopes_vec() {
        let client = ClientBuilder::new("https://tc-tests.example.com")
            .authorized_scopes(vec!["a-scope"])
            .build()
            .unwrap();
        assert_eq!(get_authorized_scopes(&client).unwrap(), vec!["a-scope"]);
    }

    #[test]
    fn test_authorized_scopes_iter() {
        let nums = vec![1, 2, 3];
        let client = ClientBuilder::new("https://tc-tests.example.com")
            .authorized_scopes(nums.iter().map(|n| format!("scope:{}", n)))
            .build()
            .unwrap();
        assert_eq!(
            get_authorized_scopes(&client).unwrap(),
            vec!["scope:1", "scope:2", "scope:3"]
        );
    }

    #[tokio::test]
    async fn test_simple_request() -> Result<(), Error> {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::method_path("GET", "/api/queue/v1/ping"))
                .respond_with(status_code(200)),
        );
        let root_url = format!("http://{}", server.addr());
        let client = ClientBuilder::new(&root_url)
            .path_prefix("api/queue/v1/")
            .build()?;
        let resp = client.request("GET", "ping", None, None).await?;
        assert!(resp.status().is_success());
        Ok(())
    }

    #[tokio::test]
    async fn test_timeout() -> Result<(), Error> {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::method_path("GET", "/api/queue/v1/ping")).respond_with(
                delay_and_then(Duration::from_secs(30), status_code(200)),
            ),
        );
        let root_url = format!("http://{}", server.addr());
        let client = ClientBuilder::new(&root_url)
            .path_prefix("api/queue/v1/")
            .timeout(Duration::from_millis(5))
            .retry(Retry {
                retries: 0,
                ..Default::default()
            })
            .build()?;
        let err = client.request("GET", "ping", None, None).await.unwrap_err();
        let reqerr = err.downcast::<reqwest::Error>().unwrap();
        assert!(reqerr.is_timeout());
        Ok(())
    }

    #[tokio::test]
    async fn test_simple_request_with_perm_creds() -> Result<(), Error> {
        let creds = Credentials::new("clientId", "accessToken");
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![
                request::method_path("GET", "/api/queue/v1/ping"),
                signed_with(creds.clone(), server.addr()),
            ])
            .respond_with(status_code(200)),
        );
        let root_url = format!("http://{}", server.addr());
        let client = ClientBuilder::new(&root_url)
            .path_prefix("api/queue/v1/")
            .credentials(creds)
            .build()?;
        let resp = client.request("GET", "ping", None, None).await?;
        assert!(resp.status().is_success());
        Ok(())
    }

    #[tokio::test]
    async fn test_query() -> Result<(), Error> {
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![
                request::method_path("GET", "/api/queue/v1/test"),
                request::query(url_decoded(contains(("taskcluster", "test")))),
                request::query(url_decoded(contains(("client", "rust")))),
            ])
            .respond_with(status_code(200)),
        );
        let root_url = format!("http://{}", server.addr());
        let client = ClientBuilder::new(&root_url)
            .path_prefix("api/queue/v1/")
            .build()?;
        let resp = client
            .request(
                "GET",
                "test",
                Some(vec![("taskcluster", "test"), ("client", "rust")]),
                None,
            )
            .await?;
        assert!(resp.status().is_success());
        Ok(())
    }

    #[tokio::test]
    async fn test_body() -> Result<(), Error> {
        let body = json!({"hello": "world"});
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![
                request::method_path("POST", "/api/queue/v1/test"),
                request::body(json_decoded(eq(body.clone()))),
            ])
            .respond_with(status_code(200)),
        );
        let root_url = format!("http://{}", server.addr());
        let client = ClientBuilder::new(&root_url)
            .path_prefix("api/queue/v1/")
            .build()?;
        let resp = client.request("POST", "test", None, Some(&body)).await?;
        assert!(resp.status().is_success());
        Ok(())
    }

    #[test]
    fn make_url_simple() -> Result<(), Error> {
        let client = ClientBuilder::new("https://tc-test.example.com")
            .path_prefix("api/queue/v1/")
            .build()?;
        let url = client.make_url("ping", None)?;
        assert_eq!(url, "https://tc-test.example.com/api/queue/v1/ping");
        Ok(())
    }

    #[test]
    fn make_url_escapable_characters() -> Result<(), Error> {
        let client = ClientBuilder::new("https://tc-test.example.com")
            .path_prefix("api/queue/v1/")
            .build()?;
        let url = client.make_url("escape%2Fthis!", None)?;
        assert_eq!(
            url,
            "https://tc-test.example.com/api/queue/v1/escape%2Fthis!"
        );
        Ok(())
    }

    #[test]
    fn make_url_query() -> Result<(), Error> {
        let client = ClientBuilder::new("https://tc-test.example.com")
            .path_prefix("api/queue/v1/")
            .build()?;
        let url = client.make_url("a/b/c", Some(vec![("abc", "def"), ("x!z", "1/3")]))?;
        assert_eq!(
            url,
            "https://tc-test.example.com/api/queue/v1/a/b/c?abc=def&x%21z=1%2F3"
        );
        Ok(())
    }

    #[test]
    fn make_signed_url_simple() -> Result<(), Error> {
        let creds = Credentials::new("clientId", "accessToken");
        let client = ClientBuilder::new("https://tc-test.example.com")
            .path_prefix("api/queue/v1/")
            .credentials(creds)
            .build()?;
        let url = client.make_signed_url("a/b", None, Duration::from_secs(10))?;
        assert!(url.starts_with("https://tc-test.example.com/api/queue/v1/a/b?bewit="));
        Ok(())
    }

    #[test]
    fn make_signed_url_query() -> Result<(), Error> {
        let creds = Credentials::new("clientId", "accessToken");
        let client = ClientBuilder::new("https://tc-test.example.com")
            .path_prefix("api/queue/v1/")
            .credentials(creds)
            .build()?;
        let url = client.make_signed_url(
            "a/b/c",
            Some(vec![("abc", "def"), ("xyz", "1/3")]),
            Duration::from_secs(10),
        )?;
        assert!(url.starts_with(
            "https://tc-test.example.com/api/queue/v1/a/b/c?abc=def&xyz=1%2F3&bewit="
        ));
        Ok(())
    }

    /// A retry policy with six retries and near-zero delays, to keep retry tests fast.
    fn retry_fast() -> Retry {
        Retry {
            retries: 6,
            max_delay: Duration::from_millis(1),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_500_retry() -> Result<(), Error> {
        let server = Server::run();
        // One initial attempt plus six retries.
        server.expect(
            Expectation::matching(request::method_path("GET", "/api/queue/v1/test"))
                .times(7)
                .respond_with(status_code(500)),
        );
        let root_url = format!("http://{}", server.addr());
        let client = ClientBuilder::new(root_url)
            .path_prefix("api/queue/v1/")
            .retry(retry_fast())
            .build()?;
        let result = client.request("GET", "test", None, None).await;
        assert!(result.is_err());
        let reqw_err: reqwest::Error = result.err().unwrap().downcast()?;
        assert_eq!(reqw_err.status().unwrap(), 500);
        Ok(())
    }

    #[tokio::test]
    async fn test_400_no_retry() -> Result<(), Error> {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::method_path("GET", "/api/queue/v1/test"))
                .times(1)
                .respond_with(status_code(400)),
        );
        let root_url = format!("http://{}", server.addr());
        let client = ClientBuilder::new(root_url)
            .path_prefix("api/queue/v1/")
            .retry(retry_fast())
            .build()?;
        let result = client.request("GET", "test", None, None).await;
        assert!(result.is_err());
        assert_eq!(
            err_status_code(&result.err().unwrap()),
            Some(reqwest::StatusCode::BAD_REQUEST)
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_303_no_follow() -> Result<(), Error> {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::method_path("GET", "/api/queue/v1/test"))
                .times(1)
                .respond_with(
                    status_code(303)
                        .insert_header("location", "http://httpstat.us/404")
                        .insert_header("content-type", "application/json")
                        .body("{\"url\":\"http://httpstat.us/404\"}"),
                ),
        );
        let root_url = format!("http://{}", server.addr());
        let client = ClientBuilder::new(root_url)
            .path_prefix("api/queue/v1/")
            .retry(retry_fast())
            .build()?;
        let resp = client.request("GET", "test", None, None).await?;
        assert_eq!(resp.status(), 303);
        assert_eq!(
            resp.json::<serde_json::Value>().await?,
            json!({"url": "http://httpstat.us/404"})
        );
        Ok(())
    }
}