use async_trait::async_trait;
use http::header;
use http::uri::{Scheme, Uri};
use log::error;
use module_utils::{RequestFilter, RequestFilterResult};
use pingora_core::upstreams::peer::HttpPeer;
use pingora_core::{Error, ErrorType};
use pingora_proxy::Session;
use serde::{
de::{Deserializer, Error as _},
Deserialize,
};
use std::net::{SocketAddr, ToSocketAddrs};
use structopt::StructOpt;
// Command-line options selecting the upstream server.
// NOTE: intentionally a plain `//` comment — structopt may treat a `///` doc comment on the
// struct as the application's about text.
#[derive(Debug, Default, StructOpt)]
pub struct UpstreamOpt {
    /// http:// or https:// URL of the upstream server to forward requests to
    #[structopt(long, parse(try_from_str))]
    pub upstream: Option<Uri>,
}
/// Deserializes an upstream URL string into `Some(Uri)`.
///
/// Intended for `#[serde(deserialize_with = "deserialize_uri")]`; serde only calls it when the
/// field is present, so this function never produces `None` itself.
///
/// # Errors
///
/// Fails if the value is not a string or the string is not a valid URI.
fn deserialize_uri<'de, D>(d: D) -> Result<Option<Uri>, D::Error>
where
    D: Deserializer<'de>,
{
    let text = String::deserialize(d)?;
    match text.parse::<Uri>() {
        Ok(parsed) => Ok(Some(parsed)),
        Err(err) => Err(D::Error::custom(format!(
            "URL {text} could not be parsed: {err}"
        ))),
    }
}
/// Configuration-file settings for the upstream handler.
///
/// Every field falls back to its `Default` value when absent from the configuration.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
pub struct UpstreamConf {
    /// URL of the upstream server; `None` means no upstream is configured.
    #[serde(deserialize_with = "deserialize_uri")]
    pub upstream: Option<Uri>,
}
impl UpstreamConf {
    /// Applies command-line overrides on top of this configuration.
    ///
    /// A `--upstream` value given on the command line replaces the configured one;
    /// when it is absent the configured value is kept as-is.
    pub fn merge_with_opt(&mut self, opt: UpstreamOpt) {
        if let Some(upstream) = opt.upstream {
            self.upstream = Some(upstream);
        }
    }
}
/// Connection details for the configured upstream, computed when the handler is built.
#[derive(Clone, Debug)]
pub struct UpstreamContext {
    /// Socket address the upstream host name resolved to.
    addr: SocketAddr,
    /// `true` when the upstream URL uses the `https` scheme.
    tls: bool,
    /// Host name handed to `HttpPeer::new` as its SNI argument.
    sni: String,
}
/// Request filter that sends requests to a single configured upstream, if there is one.
#[derive(Debug)]
pub struct UpstreamHandler {
    /// Value written into the `Host` request header: the upstream host, followed by
    /// `:port` only when the upstream URL spelled out a port.
    host_port: String,
    /// Upstream connection details, or `None` when no upstream is configured.
    context: Option<UpstreamContext>,
}
impl UpstreamHandler {
pub async fn upstream_peer(
_session: &mut Session,
ctx: &mut Option<UpstreamContext>,
) -> Result<Box<HttpPeer>, Box<Error>> {
if let Some(context) = ctx {
Ok(Box::new(HttpPeer::new(
context.addr,
context.tls,
context.sni.clone(),
)))
} else {
Err(Error::new(ErrorType::HTTPStatus(404)))
}
}
}
/// Removes the square brackets that `Uri::host` keeps around IPv6 literals
/// (`[::1]` becomes `::1`); any other host is returned unchanged.
///
/// The bracketed form is only valid inside a URL authority / `Host` header; socket address
/// resolution and SNI need the bare address.
fn strip_ipv6_brackets(host: &str) -> &str {
    host.strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host)
}

impl TryFrom<UpstreamConf> for UpstreamHandler {
    type Error = Box<Error>;

    /// Validates the configured upstream URL and resolves its address.
    ///
    /// Without a configured upstream this yields a handler whose `context` is `None`.
    ///
    /// # Errors
    ///
    /// Logs the reason and returns an `InternalError` when the URL has no scheme, a scheme
    /// other than `http`/`https`, no host, or a host that cannot be resolved to at least one
    /// socket address.
    fn try_from(conf: UpstreamConf) -> Result<Self, Self::Error> {
        let upstream = match conf.upstream {
            Some(upstream) => upstream,
            None => {
                return Ok(Self {
                    host_port: Default::default(),
                    context: None,
                })
            }
        };

        let scheme = upstream.scheme().ok_or_else(|| {
            error!("provided upstream URL has no scheme: {upstream}");
            Error::new(ErrorType::InternalError)
        })?;
        let tls = if scheme == &Scheme::HTTP {
            false
        } else if scheme == &Scheme::HTTPS {
            true
        } else {
            error!("provided upstream URL is neither HTTP nor HTTPS: {upstream}");
            return Err(Error::new(ErrorType::InternalError));
        };

        // `host` keeps the brackets of an IPv6 literal, which the Host header requires;
        // resolution and SNI must use the bare address instead.
        let host = upstream.host().ok_or_else(|| {
            error!("provided upstream URL has no host name: {upstream}");
            Error::new(ErrorType::InternalError)
        })?;
        let bare_host = strip_ipv6_brackets(host);

        let port = upstream.port_u16().unwrap_or(if tls { 443 } else { 80 });
        // Resolved once here (a blocking lookup), not per request; the first address wins.
        let addr = (bare_host, port)
            .to_socket_addrs()
            .map_err(|err| {
                error!("failed resolving upstream host name {host}: {err}");
                Error::new(ErrorType::InternalError)
            })?
            .next()
            .ok_or_else(|| {
                error!("DNS lookup of upstream host name {host} didn't produce any results");
                Error::new(ErrorType::InternalError)
            })?;

        // Only repeat the port in the Host header when the URL spelled it out explicitly.
        let mut host_port = host.to_owned();
        if let Some(port) = upstream.port() {
            host_port.push(':');
            host_port.push_str(port.as_str());
        }

        Ok(Self {
            host_port,
            context: Some(UpstreamContext {
                tls,
                addr,
                sni: bare_host.to_owned(),
            }),
        })
    }
}
#[async_trait]
impl RequestFilter for UpstreamHandler {
    type Conf = UpstreamConf;
    type CTX = Option<UpstreamContext>;

    /// Every request starts without an upstream context.
    fn new_ctx() -> Self::CTX {
        None
    }

    /// Rewrites the `Host` header and records the upstream context for this request.
    ///
    /// Returns `Unhandled` without touching the request when no upstream is configured.
    ///
    /// # Errors
    ///
    /// Propagates any failure from inserting the `Host` header; `ctx` is left unchanged then.
    async fn request_filter(
        &self,
        session: &mut Session,
        ctx: &mut Self::CTX,
    ) -> Result<RequestFilterResult, Box<Error>> {
        let context = match self.context.as_ref() {
            Some(context) => context,
            None => return Ok(RequestFilterResult::Unhandled),
        };

        let request = session.req_header_mut();
        request.insert_header(header::HOST, &self.host_port)?;

        // Presumably consumed later by `UpstreamHandler::upstream_peer`, which reads the
        // same context type — confirm against the proxy wiring.
        *ctx = Some(context.clone());
        Ok(RequestFilterResult::Handled)
    }
}