use crate::url::UrlExt;
use anyhow::{Context, Result};
use reqwest::header::HeaderMap;
use reqwest::{redirect::Policy, Client, Proxy};
use std::collections::HashMap;
use std::convert::TryInto;
use std::path::Path;
use std::time::Duration;
use url::Url;
/// Settings consumed by [`initialize`] to build a `reqwest` client.
///
/// `I` is any iterable of filesystem paths; it carries the additional server
/// (root) certificates that should be loaded.
pub struct ClientConfig<'a, I>
where
I: IntoIterator,
I::Item: AsRef<Path> + std::fmt::Debug,
{
/// Request timeout, in whole seconds.
pub timeout: u64,
/// Value handed to the client builder as the user agent.
pub user_agent: &'a str,
/// Whether redirects are followed at all; when `false` no redirect is followed.
pub redirects: bool,
/// Passed to `danger_accept_invalid_certs`; `true` accepts invalid certificates.
pub insecure: bool,
/// Name/value pairs converted into the client's default headers.
pub headers: &'a HashMap<String, String>,
/// Proxy address applied to all traffic; `None` or an empty string means no proxy.
pub proxy: Option<&'a str>,
/// Paths of certificate files (PEM is tried first, then DER) added as root certificates.
pub server_certs: Option<I>,
/// Path of the PEM client certificate; only used when `client_key` is also set
/// and both are non-empty.
pub client_cert: Option<&'a str>,
/// Path of the PEM (PKCS#8) client key; only used when `client_cert` is also set
/// and both are non-empty.
pub client_key: Option<&'a str>,
/// URLs that redirects are checked against when `redirects` is enabled; an
/// empty slice leaves redirects unrestricted (up to the redirect limit).
pub scope: &'a [Url],
}
/// Selects the redirect policy that matches the given configuration.
///
/// * redirects disabled: no redirect is ever followed;
/// * redirects enabled with an empty scope: up to 10 redirects are followed;
/// * redirects enabled with a non-empty scope: a redirect is followed only
///   when `UrlExt::is_in_scope` accepts its target for the configured scope
///   URLs, otherwise the redirect chain stops there.
fn create_redirect_policy<I>(config: &ClientConfig<'_, I>) -> Policy
where
    I: IntoIterator,
    I::Item: AsRef<Path> + std::fmt::Debug,
{
    if !config.redirects {
        return Policy::none();
    }

    if config.scope.is_empty() {
        return Policy::limited(10);
    }

    // The policy closure is moved into the returned `Policy`, so it needs its
    // own copy of the scope rather than a borrow of `config`.
    let allowed = config.scope.to_vec();

    Policy::custom(move |attempt| {
        if attempt.url().is_in_scope(&allowed) {
            attempt.follow()
        } else {
            attempt.stop()
        }
    })
}
pub fn initialize<I>(config: ClientConfig<'_, I>) -> Result<Client>
where
I: IntoIterator,
I::Item: AsRef<Path> + std::fmt::Debug,
{
let policy = create_redirect_policy(&config);
let header_map: HeaderMap = config.headers.try_into()?;
let mut client = Client::builder()
.timeout(Duration::new(config.timeout, 0))
.user_agent(config.user_agent)
.danger_accept_invalid_certs(config.insecure)
.default_headers(header_map)
.redirect(policy)
.http1_title_case_headers();
if let Some(some_proxy) = config.proxy {
if !some_proxy.is_empty() {
let proxy_obj = Proxy::all(some_proxy)?;
client = client.proxy(proxy_obj);
}
}
for cert_path in config.server_certs.into_iter().flatten() {
let buf = std::fs::read(&cert_path)?;
let cert = match reqwest::Certificate::from_pem(&buf) {
Ok(cert) => cert,
Err(err) => reqwest::Certificate::from_der(&buf).with_context(|| {
format!(
"{:?} does not contain a valid PEM or DER certificate\n{}",
&cert_path, err
)
})?,
};
client = client.add_root_certificate(cert);
}
if let (Some(cert_path), Some(key_path)) = (config.client_cert, config.client_key) {
if !cert_path.is_empty() && !key_path.is_empty() {
let cert = std::fs::read(cert_path)?;
let key = std::fs::read(key_path)?;
let identity = reqwest::Identity::from_pkcs8_pem(&cert, &key).with_context(|| {
format!(
"either {cert_path} or {key_path} are invalid; expecting PEM encoded certificate and key")
})?;
client = client.identity(identity);
}
}
Ok(client.build()?)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration shared by the tests below: zero timeout,
    /// redirects enabled, invalid certificates accepted, and no proxy,
    /// certificates or scope. Individual tests override only the fields they
    /// care about via struct-update syntax.
    fn base_config(headers: &HashMap<String, String>) -> ClientConfig<'_, Vec<String>> {
        ClientConfig {
            timeout: 0,
            user_agent: "stuff",
            redirects: true,
            insecure: true,
            headers,
            proxy: None,
            server_certs: None,
            client_cert: None,
            client_key: None,
            scope: &[],
        }
    }

    /// A proxy string that is not a valid URL must be reported as an error.
    #[test]
    fn client_with_bad_proxy() {
        let headers = HashMap::new();
        let client_config = ClientConfig {
            insecure: false,
            proxy: Some("not a valid proxy"),
            ..base_config(&headers)
        };

        // Assert on the returned error instead of relying on `#[should_panic]`,
        // which would also pass if `initialize` itself panicked.
        assert!(initialize(client_config).is_err());
    }

    /// A well-formed proxy URL is accepted.
    #[test]
    fn client_with_good_proxy() {
        let headers = HashMap::new();
        let proxy = "http://127.0.0.1:8080";
        let client_config = ClientConfig {
            proxy: Some(proxy),
            ..base_config(&headers)
        };

        initialize(client_config).unwrap();
    }

    /// A PEM-encoded server certificate is accepted.
    #[test]
    fn client_with_valid_server_pem() {
        let headers = HashMap::new();
        let server_certs = vec!["tests/mutual-auth/certs/server/server.crt.1".to_string()];
        let client_config = ClientConfig {
            server_certs: Some(server_certs),
            ..base_config(&headers)
        };

        initialize(client_config).unwrap();
    }

    /// A DER-encoded server certificate is accepted.
    #[test]
    fn client_with_valid_server_der() {
        let headers = HashMap::new();
        let server_certs = vec!["tests/mutual-auth/certs/server/server.der".to_string()];
        let client_config = ClientConfig {
            server_certs: Some(server_certs),
            ..base_config(&headers)
        };

        initialize(client_config).unwrap();
    }

    /// PEM and DER server certificates can be mixed in one configuration.
    #[test]
    fn client_with_valid_server_pem_and_der() {
        let headers = HashMap::new();
        let server_certs = vec![
            "tests/mutual-auth/certs/server/server.crt.1".to_string(),
            "tests/mutual-auth/certs/server/server.der".to_string(),
        ];
        let client_config = ClientConfig {
            server_certs: Some(server_certs),
            ..base_config(&headers)
        };

        initialize(client_config).unwrap();
    }

    /// A file that is not a certificate (here: a private key) is rejected.
    #[test]
    fn client_with_invalid_server_cert() {
        let headers = HashMap::new();
        let server_certs = vec!["tests/mutual-auth/certs/client/client.key".to_string()];
        let client_config = ClientConfig {
            server_certs: Some(server_certs),
            ..base_config(&headers)
        };

        assert!(initialize(client_config).is_err());
    }

    /// A non-empty scope (scoped redirect policy) still yields a client.
    #[test]
    fn initialize_with_scope_creates_client() {
        let headers = HashMap::new();
        let scope = vec![
            Url::parse("https://api.example.com").unwrap(),
            Url::parse("https://cdn.example.com").unwrap(),
        ];
        let client_config = ClientConfig {
            timeout: 5,
            user_agent: "test-agent",
            insecure: false,
            scope: &scope,
            ..base_config(&headers)
        };

        let client = initialize(client_config);
        assert!(client.is_ok());
    }

    /// An explicitly empty scope (limited redirect policy) yields a client.
    #[test]
    fn initialize_with_scope_empty_scope() {
        let headers = HashMap::new();
        let scope: Vec<Url> = Vec::new();
        let client_config = ClientConfig {
            timeout: 5,
            user_agent: "test-agent",
            insecure: false,
            scope: &scope,
            ..base_config(&headers)
        };

        let client = initialize(client_config);
        assert!(client.is_ok());
    }
}