use std::{fs, io};
use std::path::Path;
use std::time::Duration;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use log::error;
use reqwest::{header, redirect};
use reqwest::{Certificate, Proxy, StatusCode};
use reqwest::blocking::{Client, ClientBuilder, RequestBuilder, Response};
use rpki::uri;
use crate::config::Config;
use crate::error::Fatal;
use crate::utils::date::{format_http_date, parse_http_date};
/// Blocking HTTP client configured from the `rrdp_*` settings of `Config`.
///
/// The underlying `reqwest` client is created lazily: [`HttpClient::new`]
/// only prepares a `ClientBuilder`, and [`HttpClient::ignite`] builds the
/// actual client from it.
#[derive(Debug)]
pub struct HttpClient {
/// The ignited client (`Ok`), or the state before/after ignition (`Err`).
///
/// `Err(Some(builder))` holds a builder that has not been built yet;
/// `Err(None)` means an earlier `ignite` took the builder and building it
/// failed, so the client can no longer be created.
client: Result<Client, Option<ClientBuilder>>,
/// Timeout applied to each individual request, if set (taken from
/// `config.rrdp_timeout`).
timeout: Option<Duration>,
}
impl HttpClient {
pub fn new(config: &Config) -> Result<Self, Fatal> {
#[cfg(not(feature = "native-tls"))]
fn create_builder() -> ClientBuilder {
Client::builder().use_rustls_tls()
}
#[cfg(feature = "native-tls")]
fn create_builder() -> ClientBuilder {
Client::builder().use_native_tls()
}
let mut builder = create_builder();
builder = builder.user_agent(&config.rrdp_user_agent);
builder = builder.tcp_keepalive(config.rrdp_tcp_keepalive);
builder = builder.timeout(None); builder = builder.gzip(true);
builder = builder.redirect(
redirect::Policy::custom(Self::redirect_policy)
);
builder = builder.timeout(config.rrdp_read_timeout);
if let Some(timeout) = config.rrdp_connect_timeout {
builder = builder.connect_timeout(timeout);
}
if let Some(addr) = config.rrdp_local_addr {
builder = builder.local_address(addr)
}
for path in &config.rrdp_root_certs {
builder = builder.add_root_certificate(
Self::load_cert(path)?
);
}
for proxy in &config.rrdp_proxies {
let proxy = match Proxy::all(proxy) {
Ok(proxy) => proxy,
Err(err) => {
error!(
"Invalid rrdp-proxy '{proxy}': {err}"
);
return Err(Fatal)
}
};
builder = builder.proxy(proxy);
}
Ok(HttpClient {
client: Err(Some(builder)),
timeout: config.rrdp_timeout,
})
}
pub fn ignite(&mut self) -> Result<(), Fatal> {
let builder = match self.client.as_mut() {
Ok(_) => return Ok(()),
Err(builder) => match builder.take() {
Some(builder) => builder,
None => {
error!("Previously failed to initialize HTTP client.");
return Err(Fatal)
}
}
};
let client = match builder.build() {
Ok(client) => client,
Err(err) => {
error!("Failed to initialize HTTP client: {err}.");
return Err(Fatal)
}
};
self.client = Ok(client);
Ok(())
}
fn load_cert(path: &Path) -> Result<Certificate, Fatal> {
let mut file = match fs::File::open(path) {
Ok(file) => file,
Err(err) => {
error!(
"Cannot open rrdp-root-cert file '{}': {}'",
path.display(), err
);
return Err(Fatal);
}
};
let mut data = Vec::new();
if let Err(err) = io::Read::read_to_end(&mut file, &mut data) {
error!(
"Cannot read rrdp-root-cert file '{}': {}'",
path.display(), err
);
return Err(Fatal);
}
Certificate::from_pem(&data).map_err(|err| {
error!(
"Cannot decode rrdp-root-cert file '{}': {}'",
path.display(), err
);
Fatal
})
}
fn client(&self) -> &Client {
self.client.as_ref().expect("HTTP client has not been ignited")
}
pub fn response(
&self,
uri: &uri::Https,
) -> Result<HttpResponse, reqwest::Error> {
self._response(self.client().get(uri.as_str()))
}
pub fn conditional_response(
&self,
uri: &uri::Https,
etag: Option<&Bytes>,
last_modified: Option<DateTime<Utc>>,
) -> Result<HttpResponse, reqwest::Error> {
let mut request = self.client().get(uri.as_str());
if let Some(etag) = etag {
request = request.header(
header::IF_NONE_MATCH, etag.as_ref()
);
}
if let Some(last_modified) = last_modified {
request = request.header(
header::IF_MODIFIED_SINCE,
format_http_date(last_modified)
);
}
self._response(request)
}
fn _response(
&self,
mut request: RequestBuilder,
) -> Result<HttpResponse, reqwest::Error> {
if let Some(timeout) = self.timeout {
request = request.timeout(timeout);
}
request.send().and_then(|response| {
response.error_for_status()
}).map(|response| {
HttpResponse::create(response)
})
}
fn redirect_policy(attempt: redirect::Attempt) -> redirect::Action {
if attempt.previous().len() > 9 {
return attempt.stop();
}
let orig = match attempt.previous().first() {
Some(url) => url,
None => return attempt.follow() };
let new = attempt.url();
let orig = (orig.scheme(), orig.host(), orig.port());
let new = (new.scheme(), new.host(), new.port());
if orig == new {
attempt.follow()
}
else {
attempt.stop()
}
}
}
/// Wrapper around a blocking `reqwest` response.
///
/// The body can be consumed through the `io::Read` implementation or via
/// `copy_to`; selected headers are exposed through accessor methods.
pub struct HttpResponse {
/// The wrapped `reqwest` response.
response: Response,
}
impl HttpResponse {
    /// Wraps a blocking `reqwest` response.
    pub fn create(response: Response) -> Self {
        Self { response }
    }

    /// Returns the body length as reported by `Response::content_length`.
    pub fn content_length(&self) -> Option<u64> {
        self.response.content_length()
    }

    /// Streams the remaining response body into `w`.
    ///
    /// Returns the number of bytes copied, as reported by `io::copy`.
    pub fn copy_to<W: io::Write + ?Sized>(
        &mut self, w: &mut W
    ) -> Result<u64, io::Error> {
        io::copy(self, w)
    }

    /// Returns the HTTP status code of the response.
    pub fn status(&self) -> StatusCode {
        self.response.status()
    }

    /// Returns the entity tag if the response carries exactly one `ETag`.
    ///
    /// A missing or repeated header, or a value not shaped like `"…"` or
    /// `W/"…"`, yields `None`. The returned bytes are the complete header
    /// value, including any weak prefix and the quotes.
    pub fn etag(&self) -> Option<Bytes> {
        let mut values = self.response.headers()
            .get_all(header::ETAG)
            .into_iter();
        match (values.next(), values.next()) {
            (Some(only), None) => Self::parse_etag(only.as_bytes()),
            _ => None,
        }
    }

    /// Checks the outer shape of an entity tag and copies it unchanged.
    ///
    /// The value must start with `"` or `W/"`, contain at least one more
    /// byte after that prefix, and end with `"`. The tag contents are not
    /// inspected.
    fn parse_etag(etag: &[u8]) -> Option<Bytes> {
        // Length of the opening `W/"` or `"`.
        let prefix_len = match etag {
            [b'W', b'/', b'"', ..] => 3,
            [b'"', ..] => 1,
            _ => return None,
        };
        // Requiring a byte after the prefix rejects a lone `"` while still
        // accepting the empty tag `""`.
        let closed = etag.len() > prefix_len && etag.ends_with(b"\"");
        closed.then(|| Bytes::copy_from_slice(etag))
    }

    /// Returns the `Last-Modified` time if exactly one such header is
    /// present and it parses via `parse_http_date`.
    pub fn last_modified(&self) -> Option<DateTime<Utc>> {
        let mut values = self.response.headers()
            .get_all(header::LAST_MODIFIED)
            .into_iter();
        match (values.next(), values.next()) {
            (Some(only), None) => parse_http_date(only.to_str().ok()?),
            _ => None,
        }
    }
}
impl io::Read for HttpResponse {
    /// Reads body bytes by delegating to the wrapped `reqwest` response.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
        io::Read::read(&mut self.response, buf)
    }
}
/// Status of an HTTP request attempt, encodable as an `i16` via `into_i16`.
#[derive(Clone, Copy, Debug)]
pub enum HttpStatus {
/// A response was received with this status code.
Response(StatusCode),
/// Encoded as `-2` by `into_i16`.
/// NOTE(review): presumably the request was refused rather than answered
/// (e.g. by local policy) — confirm against callers.
Rejected,
/// Encoded as `-1` by `into_i16`.
/// NOTE(review): presumably the request failed without a usable response
/// (e.g. network or TLS error) — confirm against callers.
Error
}
impl HttpStatus {
    /// Encodes the status as a single `i16`.
    ///
    /// A received response maps to its numeric status code, `Error` maps to
    /// `-1`, and `Rejected` maps to `-2`. The `http` crate restricts status
    /// codes to 100–999, so the `as` cast cannot overflow.
    pub fn into_i16(self) -> i16 {
        match self {
            Self::Error => -1,
            Self::Rejected => -2,
            Self::Response(code) => code.as_u16() as i16,
        }
    }

    /// Returns whether this is a `304 Not Modified` response.
    pub fn is_not_modified(self) -> bool {
        if let Self::Response(code) = self {
            code == StatusCode::NOT_MODIFIED
        }
        else {
            false
        }
    }

    /// Returns whether this is a response with a success (2xx) status.
    pub fn is_success(self) -> bool {
        match self {
            Self::Response(code) => code.is_success(),
            Self::Rejected | Self::Error => false,
        }
    }
}
impl From<StatusCode> for HttpStatus {
fn from(code: StatusCode) -> Self {
HttpStatus::Response(code)
}
}