use rama::{
Layer, Service,
error::{BoxError, ErrorContext, OpaqueError},
http::{
Body, Request, Response, StatusCode,
client::{EasyHttpWebClient, TlsConnectorConfig},
layer::{
compress_adapter::CompressAdaptLayer,
map_response_body::MapResponseBodyLayer,
proxy_auth::ProxyAuthLayer,
remove_header::{RemoveRequestHeaderLayer, RemoveResponseHeaderLayer},
required_header::AddRequiredRequestHeadersLayer,
trace::TraceLayer,
traffic_writer::{self, RequestWriterInspector},
upgrade::{UpgradeLayer, Upgraded},
},
matcher::MethodMatcher,
server::HttpServer,
service::web::response::IntoResponse,
},
layer::ConsumeErrLayer,
net::{
http::RequestContext,
stream::layer::http::BodyLimitLayer,
tls::{
ApplicationProtocol, SecureTransport,
client::{
ClientConfig, ClientHelloExtension, ServerVerifyMode,
extract_client_config_from_ctx,
},
server::{SelfSignedData, ServerAuth, ServerConfig},
},
user::Basic,
},
rt::Executor,
service::service_fn,
tcp::server::TcpListener,
tls::boring::server::{TlsAcceptorData, TlsAcceptorLayer},
ua::{
emulate::{
UserAgentEmulateHttpConnectModifier, UserAgentEmulateHttpRequestModifier,
UserAgentEmulateLayer,
},
profile::UserAgentDatabase,
},
};
use std::{convert::Infallible, sync::Arc, time::Duration};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
/// Application state carried in the rama [`Context`] of every service in this proxy.
#[derive(Clone, Debug)]
struct State {
    /// TLS acceptor data (built by `new_mitm_tls_service_data`) used to terminate
    /// the client side of intercepted TLS connections.
    mitm_tls_service_data: TlsAcceptorData,
    /// User-agent profile database consumed by the UA emulation layer; wrapped in an
    /// `Arc` so cloning the state does not copy the database.
    ua_db: Arc<UserAgentDatabase>,
}

/// Rama context specialised to this proxy's [`State`].
type Context = rama::Context<State>;
/// Entry point: start the MITM proxy on `127.0.0.1:62017` and run until graceful shutdown.
///
/// Sets up tracing (default level `INFO`, overridable via the environment), generates
/// the self-signed MITM TLS acceptor data, binds the TCP listener and serves
/// basic-auth-protected HTTP proxy traffic. `CONNECT` requests are upgraded and handled
/// by `http_connect_proxy`; all other requests go through the MITM HTTP service directly.
///
/// # Errors
///
/// Returns an error if the TLS acceptor data cannot be created, if the listener cannot
/// be bound, or if graceful shutdown does not complete within its 30 second limit.
#[tokio::main]
async fn main() -> Result<(), BoxError> {
    tracing_subscriber::registry()
        .with(fmt::layer())
        .with(
            EnvFilter::builder()
                .with_default_directive(LevelFilter::INFO.into())
                .from_env_lossy(),
        )
        .init();

    let mitm_tls_service_data =
        new_mitm_tls_service_data().context("generate self-signed mitm tls cert")?;

    let state = State {
        mitm_tls_service_data,
        ua_db: Arc::new(UserAgentDatabase::embedded()),
    };

    let graceful = rama::graceful::Shutdown::default();

    // Bind before spawning so that a bind failure (e.g. port already in use) is
    // returned from `main`. A panic inside the spawned task would not stop `main`,
    // which would carry on into `shutdown_with_limit` with nothing listening.
    let tcp_service = TcpListener::build_with_state(state.clone())
        .bind("127.0.0.1:62017")
        .await
        .map_err(|err| format!("bind tcp proxy to 127.0.0.1:62017: {err}"))?;

    graceful.spawn_task_fn(async |guard| {
        let exec = Executor::graceful(guard.clone());
        let http_mitm_service = new_http_mitm_proxy(&Context::with_state(state));
        let http_service = HttpServer::auto(exec).service(
            (
                TraceLayer::new_for_http(),
                ProxyAuthLayer::new(Basic::new("john", "secret")),
                // CONNECT requests are accepted and upgraded into a MITM'd tunnel;
                // everything else is served by `http_mitm_service`.
                UpgradeLayer::new(
                    MethodMatcher::CONNECT,
                    service_fn(http_connect_accept),
                    service_fn(http_connect_proxy),
                ),
            )
                .into_layer(http_mitm_service),
        );

        tcp_service
            .serve_graceful(
                guard,
                (
                    // 2 MiB limit applied to both request and response bodies.
                    BodyLimitLayer::symmetric(2 * 1024 * 1024),
                )
                    .into_layer(http_service),
            )
            .await;
    });

    graceful
        .shutdown_with_limit(Duration::from_secs(30))
        .await
        .context("graceful shutdown")?;

    Ok(())
}
/// Decide whether an HTTP `CONNECT` request may be upgraded.
///
/// Fetches the `RequestContext` from `ctx`, deriving it from `ctx` and `req` and
/// inserting it when it is not present yet. On success the target authority is logged
/// and `200 OK` is returned together with the (updated) context and the request so the
/// upgrade can proceed. If the request context cannot be derived, the error is logged and
/// a `400 Bad Request` response is returned instead.
async fn http_connect_accept(
    mut ctx: Context,
    req: Request,
) -> Result<(Response, Context, Request), Response> {
    let request_ctx = ctx
        .get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into())
        .map_err(|err| {
            tracing::error!(err = %err, "error extracting authority");
            StatusCode::BAD_REQUEST.into_response()
        })?;
    tracing::info!("accept CONNECT to {}", request_ctx.authority);

    Ok((StatusCode::OK.into_response(), ctx, req))
}
async fn http_connect_proxy(ctx: Context, upgraded: Upgraded) -> Result<(), Infallible> {
let http_service = new_http_mitm_proxy(&ctx);
let http_transport_service = HttpServer::auto(ctx.executor().clone()).service(http_service);
let https_service = TlsAcceptorLayer::new(ctx.state().mitm_tls_service_data.clone())
.with_store_client_hello(true)
.into_layer(http_transport_service);
https_service
.serve(ctx, upgraded)
.await
.expect("infallible");
Ok(())
}
/// Build the MITM HTTP service: `http_mitm_proxy` wrapped in the proxy's middleware stack.
///
/// The stack, outermost first:
/// - boxes response bodies into `Body`;
/// - traces requests;
/// - consumes errors so the service's error type is `Infallible`;
/// - applies optional user-agent emulation backed by the state's UA database, with
///   auto-detection of the incoming user agent;
/// - strips hop-by-hop headers from responses and requests;
/// - adapts compression;
/// - adds required request headers.
fn new_http_mitm_proxy(
    ctx: &Context,
) -> impl Service<State, Request, Response = Response, Error = Infallible> {
    let ua_emulation = UserAgentEmulateLayer::new(ctx.state().ua_db.clone())
        .try_auto_detect_user_agent(true)
        .optional(true);

    let middleware = (
        MapResponseBodyLayer::new(Body::new),
        TraceLayer::new_for_http(),
        ConsumeErrLayer::default(),
        ua_emulation,
        RemoveResponseHeaderLayer::hop_by_hop(),
        RemoveRequestHeaderLayer::hop_by_hop(),
        CompressAdaptLayer::default(),
        AddRequiredRequestHeadersLayer::new(),
    );

    middleware.into_layer(service_fn(http_mitm_proxy))
}
/// Forward one intercepted request (plain HTTP or TLS-terminated) upstream and return
/// the upstream response.
///
/// Upstream TLS client config:
/// - starts from the client's own ClientHello when one was stored in the
///   `SecureTransport` of `ctx`;
/// - otherwise starts from a default that offers HTTP/2 and HTTP/1.1 via ALPN;
/// - always has server certificate verification disabled;
/// - then has every client config returned by `extract_client_config_from_ctx` merged
///   into it, in order.
///
/// Upstream errors are logged and mapped to an empty `500 Internal Server Error`
/// response, so this function never returns `Err`.
async fn http_mitm_proxy(ctx: Context, req: Request) -> Result<Response, Infallible> {
    let mut client = EasyHttpWebClient::default()
        // Connection-time inspector.
        .with_http_conn_req_inspector(UserAgentEmulateHttpConnectModifier::default())
        // Per-request inspectors go in the *serve* slot. Calling
        // `with_http_conn_req_inspector` a second time would overwrite the connect
        // modifier above rather than adding to it.
        .with_http_serve_req_inspector((
            UserAgentEmulateHttpRequestModifier::default(),
            // Writes request headers to stdout; example/debug use only.
            RequestWriterInspector::stdout_unbounded(
                ctx.executor(),
                Some(traffic_writer::WriterMode::Headers),
            ),
        ));

    let mut base_tls_cfg = ctx
        .get::<SecureTransport>()
        .and_then(|st| st.client_hello())
        .cloned()
        .map(Into::into)
        .unwrap_or_else(|| ClientConfig {
            extensions: Some(vec![
                ClientHelloExtension::ApplicationLayerProtocolNegotiation(vec![
                    ApplicationProtocol::HTTP_2,
                    ApplicationProtocol::HTTP_11,
                ]),
            ]),
            ..Default::default()
        });
    base_tls_cfg.server_verify_mode = Some(ServerVerifyMode::Disable);

    let tls_client_config = match extract_client_config_from_ctx(&ctx) {
        Some(chain) => {
            let mut cfg = base_tls_cfg;
            for other_cfg in chain.iter() {
                cfg.merge(other_cfg.clone());
            }
            cfg
        }
        None => base_tls_cfg,
    };
    client.set_tls_connector_config(TlsConnectorConfig::Boring(Some(tls_client_config)));

    match client.serve(ctx, req).await {
        Ok(resp) => Ok(resp),
        Err(err) => {
            tracing::error!(error = ?err, "error in client request");
            // Static status + empty body: building this response cannot fail.
            Ok(Response::builder()
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .body(Body::empty())
                .unwrap())
        }
    }
}
/// Create the TLS acceptor data used to terminate intercepted client connections.
///
/// Uses a self-signed certificate with organisation name "Example Server Acceptor". The
/// server offers HTTP/2 and HTTP/1.1 via ALPN.
///
/// # Errors
///
/// Returns an [`OpaqueError`] (context: "create tls server config") if the server config
/// cannot be converted into [`TlsAcceptorData`].
fn new_mitm_tls_service_data() -> Result<TlsAcceptorData, OpaqueError> {
    let self_signed = SelfSignedData {
        organisation_name: Some("Example Server Acceptor".to_owned()),
        ..Default::default()
    };

    let mut server_config = ServerConfig::new(ServerAuth::SelfSigned(self_signed));
    server_config.application_layer_protocol_negotiation = Some(vec![
        ApplicationProtocol::HTTP_2,
        ApplicationProtocol::HTTP_11,
    ]);

    server_config
        .try_into()
        .context("create tls server config")
}