#[macro_use]
extern crate log;
mod acme;
mod binding;
mod body;
mod device;
mod error;
mod response;
mod shutdown;
mod tls;
mod utils;
#[cfg(feature = "c-api")]
mod exports;
pub mod auth;
use std::{
future::Future,
io,
net::SocketAddr,
pin::Pin,
str::FromStr,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use futures::{
channel::mpsc::{self, UnboundedSender},
future::{AbortHandle, Either},
FutureExt, StreamExt,
};
use hyper::{body::Incoming, Request, Response};
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use uuid::Uuid;
pub use async_trait;
pub use hyper;
pub use hyper::http;
use self::{
acme::{ChallengeRegistrations, Watchdog},
auth::BasicAuthorization,
binding::Bindings,
device::{DeviceConnection, DeviceManager},
error::{HttpError, Unauthorized},
shutdown::{GracefulShutdown, UpgradeableConnectionExt},
tls::{Identity, TlsAcceptor, TlsMode},
utils::{AbortOnDrop, RequestExt},
};
pub use self::{binding::ConnectionInfo, body::Body, error::Error};
const DEVICE_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(60);
pub enum DeviceHandlerResult {
Accept,
Unauthorized,
Redirect(String),
}
impl DeviceHandlerResult {
pub fn accept() -> Self {
Self::Accept
}
pub fn unauthorized() -> Self {
Self::Unauthorized
}
pub fn redirect<T>(location: T) -> Self
where
T: ToString,
{
Self::Redirect(location.to_string())
}
}
pub enum ClientHandlerResult {
Forward(String, Request<Body>),
Block(Response<Body>),
}
impl ClientHandlerResult {
pub fn forward<T>(device_id: T, request: Request<Body>) -> Self
where
T: ToString,
{
Self::Forward(device_id.to_string(), request)
}
pub fn block(response: Response<Body>) -> Self {
Self::Block(response)
}
}
#[async_trait::async_trait]
pub trait RequestHandler {
async fn handle_device_request(
&self,
authorization: BasicAuthorization,
) -> Result<DeviceHandlerResult, Error>;
async fn handle_client_request(
&self,
request: Request<Body>,
) -> Result<ClientHandlerResult, Error>;
}
pub trait BlockingRequestHandler {
fn handle_device_request(
&self,
authorization: BasicAuthorization,
) -> Result<DeviceHandlerResult, Error>;
fn handle_client_request(&self, request: Request<Body>) -> Result<ClientHandlerResult, Error>;
}
pub struct RequestHandlerAdapter<T> {
inner: Arc<T>,
}
#[async_trait::async_trait]
impl<T> RequestHandler for RequestHandlerAdapter<T>
where
T: BlockingRequestHandler + Send + Sync + 'static,
{
async fn handle_device_request(
&self,
authorization: BasicAuthorization,
) -> Result<DeviceHandlerResult, Error> {
let inner = self.inner.clone();
let blocking =
tokio::task::spawn_blocking(move || inner.handle_device_request(authorization));
blocking
.await
.map_err(|_| Error::from_static_msg("terminating"))?
}
async fn handle_client_request(
&self,
request: Request<Body>,
) -> Result<ClientHandlerResult, Error> {
let inner = self.inner.clone();
let blocking = tokio::task::spawn_blocking(move || inner.handle_client_request(request));
blocking
.await
.map_err(|_| Error::from_static_msg("terminating"))?
}
}
impl<T> From<T> for RequestHandlerAdapter<T> {
fn from(handler: T) -> Self {
Self {
inner: Arc::new(handler),
}
}
}
pub struct ProxyBuilder {
hostname: String,
http_bind_addresses: Vec<SocketAddr>,
https_bind_addresses: Vec<SocketAddr>,
tls_mode: TlsMode,
}
impl ProxyBuilder {
pub fn new() -> Self {
Self {
hostname: String::from("localhost"),
http_bind_addresses: Vec::new(),
https_bind_addresses: Vec::new(),
tls_mode: TlsMode::None,
}
}
pub fn hostname<T>(&mut self, hostname: T) -> &mut Self
where
T: ToString,
{
self.hostname = hostname.to_string();
self
}
pub fn http_bind_address(&mut self, addr: SocketAddr) -> &mut Self {
self.http_bind_addresses.push(addr);
self
}
pub fn https_bind_address(&mut self, addr: SocketAddr) -> &mut Self {
self.https_bind_addresses.push(addr);
self
}
pub fn tls_identity(&mut self, key: &[u8], cert: &[u8]) -> Result<&mut Self, Error> {
let identity = Identity::from_pkcs8(cert, key)?;
self.tls_mode = TlsMode::Simple(identity);
Ok(self)
}
pub fn lets_encrypt(&mut self) -> &mut Self {
self.tls_mode = TlsMode::LetsEncrypt;
self
}
pub async fn build<T>(&self, request_handler: T) -> Result<Proxy, Error>
where
T: RequestHandler + Send + Sync + 'static,
{
info!("starting GoodCam device proxy");
let res = self.build_inner(request_handler).await;
if let Err(err) = &res {
warn!("unable to start the proxy: {err}");
} else {
info!("proxy started");
}
res
}
async fn build_inner<T>(&self, request_handler: T) -> Result<Proxy, Error>
where
T: RequestHandler + Send + Sync + 'static,
{
let tls_acceptor = self.tls_mode.create_tls_acceptor()?;
let bindings = self.create_bindings(tls_acceptor.clone()).await?;
info!("hostname: {}", self.hostname);
let acme_challenges = ChallengeRegistrations::new();
let mut acme_watchdog = None;
if let Some(tls_acceptor) = tls_acceptor {
if let Some(watchdog) = self
.create_acme_watchdog(tls_acceptor, acme_challenges.clone())
.await?
{
acme_watchdog = Some(watchdog);
}
}
let handler = InternalRequestHandler {
acme_challenges,
devices: DeviceManager::new(),
handler: request_handler.into(),
};
let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded();
let server = self.create_proxy(bindings, handler, async move {
if shutdown_rx.next().await.is_none() {
futures::future::pending().await
}
});
self.start_proxy(server, shutdown_tx, acme_watchdog)
}
async fn create_bindings(&self, tls_acceptor: Option<TlsAcceptor>) -> io::Result<Bindings> {
let http_bind_addresses = self.http_bind_addresses.iter();
let https_bind_addresses = self.https_bind_addresses.iter();
let mut bindings = Bindings::new();
for binding in http_bind_addresses.clone() {
info!("HTTP binding: {binding}");
}
bindings
.add_tcp_bindings(http_bind_addresses.copied())
.await?;
if let Some(acceptor) = tls_acceptor {
for binding in https_bind_addresses.clone() {
info!("HTTPS binding: {binding}");
}
bindings
.add_tls_bindings(acceptor, https_bind_addresses.copied())
.await?;
}
Ok(bindings)
}
async fn create_acme_watchdog(
&self,
tls_acceptor: TlsAcceptor,
challenge_registrations: ChallengeRegistrations,
) -> Result<Option<Watchdog>, Error> {
if let Some(acme_account) = self.tls_mode.create_acme_account().await? {
let watchdog = Watchdog::new(
acme_account,
challenge_registrations,
tls_acceptor,
&self.hostname,
);
let res = watchdog.await?;
Ok(Some(res))
} else {
Ok(None)
}
}
fn create_proxy<T, F>(
&self,
mut bindings: Bindings,
handler: InternalRequestHandler<T>,
shutdown: F,
) -> impl Future<Output = Result<(), Error>>
where
T: RequestHandler + Send + Sync + 'static,
F: Future,
{
let executor = TokioExecutor::new();
let timer = TokioTimer::new();
let mut builder = hyper_util::server::conn::auto::Builder::new(executor);
builder
.http1()
.header_read_timeout(Duration::from_secs(60))
.keep_alive(true)
.timer(timer.clone());
builder
.http2()
.keep_alive_interval(Some(Duration::from_secs(120)))
.keep_alive_timeout(Duration::from_secs(20))
.timer(timer);
async move {
let connections = GracefulShutdown::new();
futures::pin_mut!(shutdown);
loop {
let next = bindings.next();
let connection = match futures::future::select(&mut shutdown, next).await {
Either::Left((_, _)) => break,
Either::Right((res, _)) => match res.transpose() {
Ok(Some(c)) => c,
Ok(None) => return Err(Error::from_static_msg("server socket(s) closed")),
Err(err) => return Err(err.into()),
},
};
let shutdown_registration = connections.register_task();
let info = connection.info();
let handler = handler.clone();
let builder = builder.clone();
let io = TokioIo::new(connection);
let service = hyper::service::service_fn(move |mut request| {
let extensions = request.extensions_mut();
extensions.insert(info);
let handler = handler.clone();
async move {
let response = handler.handle_request(request).await;
Ok(response) as Result<_, hyper::Error>
}
});
tokio::spawn(async move {
let serve = builder
.serve_connection_with_upgrades(io, service)
.with_graceful_shutdown(shutdown_registration);
if let Err(err) = serve.await {
debug!("connection error: {err}");
}
});
}
connections.shutdown().await;
Ok(())
}
}
fn start_proxy<F>(
&self,
proxy: F,
shutdown: UnboundedSender<()>,
acme_watchdog: Option<Watchdog>,
) -> Result<Proxy, Error>
where
F: Future<Output = Result<(), Error>> + Send + 'static,
{
let (server, server_handle) = futures::future::abortable(proxy);
let server = AbortOnDrop::from(tokio::spawn(server));
let watchdog = if let Some(watchdog) = acme_watchdog {
AbortOnDrop::from(tokio::spawn(watchdog.watch()))
} else {
AbortOnDrop::from(tokio::spawn(futures::future::pending()))
};
let join = async move {
let res = match server.await {
Ok(Ok(Ok(()))) => Ok(()),
Ok(Ok(Err(err))) => Err(err),
Ok(Err(_)) => Ok(()),
Err(_) => Ok(()),
};
watchdog.abort();
let _ = watchdog.await;
res
};
let handle = ProxyHandle {
shutdown,
abort: server_handle,
};
let proxy = Proxy {
join: Box::pin(join),
handle,
};
Ok(proxy)
}
}
impl Default for ProxyBuilder {
fn default() -> Self {
Self::new()
}
}
pub struct Proxy {
join: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
handle: ProxyHandle,
}
impl Proxy {
pub fn builder() -> ProxyBuilder {
ProxyBuilder::new()
}
pub fn handle(&self) -> ProxyHandle {
self.handle.clone()
}
}
impl Future for Proxy {
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.join.poll_unpin(cx)
}
}
#[derive(Clone)]
pub struct ProxyHandle {
shutdown: UnboundedSender<()>,
abort: AbortHandle,
}
impl ProxyHandle {
pub fn stop(&self) {
let _ = self.shutdown.unbounded_send(());
}
pub fn abort(&self) {
self.abort.abort();
}
}
struct InternalRequestHandler<T> {
acme_challenges: acme::ChallengeRegistrations,
devices: DeviceManager,
handler: Arc<T>,
}
impl<T> InternalRequestHandler<T>
where
T: RequestHandler + Send + Sync + 'static,
{
async fn handle_request(&self, request: Request<Incoming>) -> Response<Body> {
self.handle_request_inner(request.map(Body::from))
.await
.unwrap_or_else(|err| {
if let Some(response) = err.to_response() {
return response;
}
warn!("internal server error: {err}");
response::internal_server_error()
})
}
async fn handle_request_inner(
&self,
request: Request<Body>,
) -> Result<Response<Body>, HttpError> {
if let Some(token) = request.get_acme_challenge_token() {
if let Some(response) = self.acme_challenges.create_response(token) {
return Ok(response);
}
}
if request.is_device_request() {
self.handle_device_request(request).await
} else {
self.handle_client_request(request).await
}
}
async fn handle_device_request(
&self,
request: Request<Body>,
) -> Result<Response<Body>, HttpError> {
let session_id = Uuid::new_v4();
info!("received device connection request (session_id: {session_id})");
let authorization = request
.headers()
.get("authorization")
.map(|auth| auth.to_str())
.and_then(|res| res.ok())
.map(BasicAuthorization::from_str)
.and_then(|res| res.ok());
if authorization.is_none() {
warn!("unable to process device connection request (session_id: {session_id}): missing or invalid authorization header");
}
let authorization = authorization.ok_or(Unauthorized)?;
let res = self
.handle_device_request_inner(session_id, authorization, request)
.await;
if let Err(err) = &res {
warn!("unable to process device connection request (session_id: {session_id}): {err}");
}
res.map_err(HttpError::from)
}
async fn handle_device_request_inner(
&self,
session_id: Uuid,
authorization: BasicAuthorization,
request: Request<Body>,
) -> Result<Response<Body>, Error> {
let device_id = String::from(authorization.username());
match self.handler.handle_device_request(authorization).await? {
DeviceHandlerResult::Accept => {
let this = self.clone();
tokio::spawn(async move {
if let Err(err) = this
.handle_device_connection(&device_id, session_id, request)
.await
{
warn!("unable to upgrade device connection (device_id: {device_id}, session_id: {session_id}): {err}");
}
});
let res = Response::builder()
.status(101)
.header("Upgrade", "goodcam-device-proxy")
.body(Body::empty())
.unwrap();
Ok(res)
}
DeviceHandlerResult::Unauthorized => {
info!("unauthorized device (device_id: {device_id}, session_id: {session_id})");
Ok(response::unauthorized())
}
DeviceHandlerResult::Redirect(location) => {
info!("redirecting device (device_id: {device_id}, session_id: {session_id}, location: {location})");
Ok(response::temporary_redirect(location))
}
}
}
async fn handle_device_connection(
&self,
device_id: &str,
session_id: Uuid,
request: Request<Body>,
) -> Result<(), Error> {
let handshake = DeviceConnection::new(hyper::upgrade::on(request));
let (connection, handle) = tokio::time::timeout(DEVICE_HANDSHAKE_TIMEOUT, handshake)
.await
.map_err(|_| Error::from_static_msg("device handshake timeout"))??;
if let Some(old) = self.devices.add(device_id, session_id, handle) {
old.close();
}
info!("device connected (device_id: {device_id}, session_id: {session_id})");
if let Err(err) = connection.await {
warn!(
"device connection error (device_id: {device_id}, session_id: {session_id}): {err}"
);
} else {
info!("device disconnected (device_id: {device_id}, session_id: {session_id})");
}
self.devices.remove(device_id, Some(session_id));
Ok(())
}
async fn handle_client_request(
&self,
request: Request<Body>,
) -> Result<Response<Body>, HttpError> {
let request_id = Uuid::new_v4();
let method = request.method();
let uri = request.uri();
let path = uri.path();
info!("received client {method} request at {path} (request_id: {request_id})");
let res = self.handle_client_request_inner(request_id, request).await;
if let Err(err) = &res {
warn!("unable to complete client request (request_id: {request_id}): {err}");
}
res.map_err(HttpError::from)
}
async fn handle_client_request_inner(
&self,
request_id: Uuid,
request: Request<Body>,
) -> Result<Response<Body>, Error> {
match self.handler.handle_client_request(request).await? {
ClientHandlerResult::Block(response) => {
info!("client request blocked (request_id: {request_id})");
Ok(response)
}
ClientHandlerResult::Forward(device_id, request) => {
info!(
"forwarding client request (request_id: {request_id}, device_id: {device_id})"
);
if let Some(mut device) = self.devices.get(&device_id) {
match device.send_request(request).await {
Ok(response) => {
let status = response.status();
info!("forwarding device response (request_id: {request_id}, device_id: {device_id}, status: {status})");
Ok(response)
}
Err(err) => {
info!("unable to forward client request (request_id: {request_id}, device_id: {device_id}): {err}");
Ok(response::bad_gateway())
}
}
} else {
info!("unable to forward client request (request_id: {request_id}, device_id: {device_id}): device not connected");
Ok(response::bad_gateway())
}
}
}
}
}
impl<T> Clone for InternalRequestHandler<T> {
fn clone(&self) -> Self {
Self {
acme_challenges: self.acme_challenges.clone(),
devices: self.devices.clone(),
handler: self.handler.clone(),
}
}
}