use rama::{
Context, Layer, Service,
context::{Extensions, RequestContextExt},
http::{
Body, Request, Response, StatusCode,
client::EasyHttpWebClient,
layer::{
proxy_auth::ProxyAuthLayer,
remove_header::{RemoveRequestHeaderLayer, RemoveResponseHeaderLayer},
trace::TraceLayer,
upgrade::{UpgradeLayer, Upgraded},
},
matcher::{DomainMatcher, HttpMatcher, MethodMatcher},
server::HttpServer,
service::web::{
extract::Path,
match_service,
response::{IntoResponse, Json},
},
},
layer::HijackLayer,
net::{
address::Domain, conn::is_connection_error, http::RequestContext, stream::ClientSocketInfo,
stream::layer::http::BodyLimitLayer, user::Basic,
},
rt::Executor,
service::service_fn,
tcp::{client::default_tcp_connect, server::TcpListener},
username::{
UsernameLabelParser, UsernameLabelState, UsernameLabels, UsernameOpaqueLabelParser,
},
};
use serde::Deserialize;
use serde_json::json;
use std::{convert::Infallible, sync::Arc, time::Duration};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
/// Runs an authenticated HTTP proxy (with `CONNECT` tunnelling) on `127.0.0.1:62001`.
///
/// Requests whose domain is exactly `echo.example.internal` are answered by a small
/// in-process JSON API. `CONNECT` requests are accepted by `http_connect_accept` and
/// tunnelled by `http_connect_proxy`. All remaining requests go to `http_plain_proxy`.
/// After spawning the server task, `main` waits for graceful shutdown with a 30 second limit.
#[tokio::main]
async fn main() {
    let env_filter = EnvFilter::builder()
        .with_default_directive(LevelFilter::DEBUG.into())
        .from_env_lossy();
    tracing_subscriber::registry()
        .with(fmt::layer())
        .with(env_filter)
        .init();

    /// Path parameters of the `/lucky/:number` echo endpoint.
    #[derive(Deserialize)]
    struct ApiLuckyParams {
        number: u32,
    }

    let shutdown = rama::graceful::Shutdown::default();

    shutdown.spawn_task_fn(async move |guard| {
        let listener = TcpListener::build()
            .bind("127.0.0.1:62001")
            .await
            .expect("bind tcp proxy to 127.0.0.1:62001");

        // Served in-process (instead of being proxied) for `echo.example.internal`.
        let echo_api = Arc::new(match_service! {
            HttpMatcher::post("/lucky/:number") => async move |path: Path<ApiLuckyParams>| {
                Json(json!({
                    "lucky_number": path.number,
                }))
            },
            HttpMatcher::get("/*") => async move |ctx: Context<()>, req: Request| {
                Json(json!({
                    "method": req.method().as_str(),
                    "path": req.uri().path(),
                    "username_labels": ctx.get::<UsernameLabels>().map(|labels| &labels.0),
                    "user_priority": ctx.get::<Priority>().map(|priority| match priority {
                        Priority::High => "high",
                        Priority::Medium => "medium",
                        Priority::Low => "low",
                    }),
                }))
            },
            _ => StatusCode::NOT_FOUND,
        });

        // Layers wrap the plain proxy service in this order (outermost first).
        let middleware = (
            TraceLayer::new_for_http(),
            ProxyAuthLayer::new(Basic::new("john", "secret"))
                .with_labels::<(PriorityUsernameLabelParser, UsernameOpaqueLabelParser)>(),
            HijackLayer::new(
                DomainMatcher::exact(Domain::from_static("echo.example.internal")),
                echo_api,
            ),
            UpgradeLayer::new(
                MethodMatcher::CONNECT,
                service_fn(http_connect_accept),
                service_fn(http_connect_proxy),
            ),
            RemoveResponseHeaderLayer::hop_by_hop(),
            RemoveRequestHeaderLayer::hop_by_hop(),
        );

        let http_server = HttpServer::auto(Executor::graceful(guard.clone()))
            .service(middleware.into_layer(service_fn(http_plain_proxy)));

        listener
            .serve_graceful(
                guard,
                (BodyLimitLayer::symmetric(2 * 1024 * 1024),).into_layer(http_server),
            )
            .await;
    });

    shutdown
        .shutdown_with_limit(Duration::from_secs(30))
        .await
        .expect("graceful shutdown");
}
async fn http_connect_accept<S>(
mut ctx: Context<S>,
req: Request,
) -> Result<(Response, Context<S>, Request), Response>
where
S: Clone + Send + Sync + 'static,
{
match ctx.get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into()) {
Ok(request_ctx) => tracing::info!("accept CONNECT to {}", request_ctx.authority),
Err(err) => {
tracing::error!(err = %err, "error extracting authority");
return Err(StatusCode::BAD_REQUEST.into_response());
}
}
Ok((StatusCode::OK.into_response(), ctx, req))
}
async fn http_connect_proxy<S>(ctx: Context<S>, mut upgraded: Upgraded) -> Result<(), Infallible>
where
S: Clone + Send + Sync + 'static,
{
let authority = ctx .get::<RequestContext>()
.unwrap()
.authority
.clone();
tracing::info!("CONNECT to {authority}");
let (mut stream, _) = match default_tcp_connect(&ctx, authority).await {
Ok(stream) => stream,
Err(err) => {
tracing::error!(error = %err, "error connecting to host");
return Ok(());
}
};
if let Err(err) = tokio::io::copy_bidirectional(&mut upgraded, &mut stream).await {
if !is_connection_error(&err) {
tracing::error!(error = %err, "error copying data");
}
}
Ok(())
}
/// Forwards a non-`CONNECT` request upstream using an [`EasyHttpWebClient`].
///
/// On success, the upstream response is returned unchanged. The status is logged,
/// along with the local and peer socket addresses when the response extensions carry
/// [`ClientSocketInfo`]. If the client call fails, the error is logged and an empty
/// `500 Internal Server Error` response is returned, so this never yields `Err`.
async fn http_plain_proxy<S>(ctx: Context<S>, req: Request) -> Result<Response, Infallible>
where
    S: Clone + Send + Sync + 'static,
{
    let web_client = EasyHttpWebClient::default();
    let upstream_resp = match web_client.serve(ctx, req).await {
        Ok(upstream_resp) => upstream_resp,
        Err(err) => {
            tracing::error!(error = %err, "error in client request");
            return Ok(Response::builder()
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .body(Body::empty())
                .unwrap());
        }
    };

    let socket_info = upstream_resp
        .extensions()
        .get::<RequestContextExt>()
        .and_then(|req_ext| req_ext.get::<ClientSocketInfo>());

    if let Some(info) = socket_info {
        tracing::info!(
            status = %upstream_resp.status(),
            local_addr = ?info.local_addr(),
            server_addr = %info.peer_addr(),
            "http plain text proxy received response",
        );
    } else {
        tracing::info!(
            status = %upstream_resp.status(),
            "http plain text proxy received response, IP info unknown",
        );
    }

    Ok(upstream_resp)
}
/// Request priority a proxy user can select via a `priority` username label
/// followed by a value label.
///
/// Variants are declared from highest to lowest, so the derived ordering is
/// `High < Medium < Low`. Keep this declaration order when editing.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Priority {
    /// Selected by the value label `high`.
    High,
    /// Selected by the value label `medium`.
    Medium,
    /// Selected by the value label `low`.
    Low,
}
/// Username label parser that extracts an optional [`Priority`] from a
/// `priority` key label followed by a value label (`high`, `medium` or `low`).
#[derive(Clone, Debug, Default)]
pub struct PriorityUsernameLabelParser {
    /// `true` right after the `priority` key label, while its value is pending.
    key_seen: bool,
    /// The priority parsed from the labels, if any.
    priority: Option<Priority>,
}
impl UsernameLabelParser for PriorityUsernameLabelParser {
    type Error = Infallible;

    /// Recognises a `priority` key label followed by its value label.
    ///
    /// Labels are trimmed and compared case-insensitively.
    ///
    /// # Returns
    /// * [`UsernameLabelState::Used`] for the `priority` key and for a valid value
    ///   (`high`, `medium`, `low`) directly following it.
    /// * [`UsernameLabelState::Abort`] when the label after `priority` is not a valid value.
    /// * [`UsernameLabelState::Ignored`] for any other label. This parser did not
    ///   consume it, so sibling parsers (or strict label handling) can decide about it.
    fn parse_label(&mut self, label: &str) -> UsernameLabelState {
        let label = label.trim().to_ascii_lowercase();

        if self.key_seen {
            // This label is the value for the preceding `priority` key.
            self.key_seen = false;
            let priority = match label.as_str() {
                "high" => Priority::High,
                "medium" => Priority::Medium,
                "low" => Priority::Low,
                _ => {
                    tracing::trace!("invalid priority username label value: {label}");
                    return UsernameLabelState::Abort;
                }
            };
            self.priority = Some(priority);
            UsernameLabelState::Used
        } else if label == "priority" {
            self.key_seen = true;
            UsernameLabelState::Used
        } else {
            UsernameLabelState::Ignored
        }
    }

    /// Inserts the parsed [`Priority`] into `ext` when one was found. Nothing is
    /// inserted otherwise. Never fails.
    fn build(self, ext: &mut Extensions) -> Result<(), Self::Error> {
        if self.key_seen {
            // The labels ended right after the `priority` key, so no value was given.
            tracing::trace!("priority username label key without value");
        }
        ext.maybe_insert(self.priority);
        Ok(())
    }
}