use std::{collections::HashMap, sync::Arc, time::Duration};
use axum::{
body::Body,
extract::{FromRef, Path, Request as ExtractRequest, State, WebSocketUpgrade},
http::{Request, StatusCode, header::CONTENT_TYPE},
middleware::Next,
response::{IntoResponse, Response},
routing::get,
};
use serde::Deserialize;
use tokio::{
net::{TcpListener, TcpStream},
sync::RwLock,
};
use tokio_util::sync::CancellationToken;
use tower_http::trace::TraceLayer;
use tracing::{Span, error, info};
use wsrx::proxy;
use crate::cli::logger::init_logger;
/// Initializes logging, binds the listening socket and serves the wsrx router.
///
/// * `host` - address to bind to; defaults to `127.0.0.1` when `None`.
/// * `port` - port to bind to; defaults to `0` when `None`. The address that
///   was actually bound is logged once the listener is ready.
/// * `secret` - optional secret forwarded to [`build_router`].
/// * `log_json` - forwarded to `init_logger`; defaults to `false` when `None`.
///
/// # Panics
///
/// Panics if the address cannot be bound, if the bound address cannot be read
/// back from the listener, or if the server terminates with an error.
pub async fn launch(
    host: Option<String>, port: Option<u16>, secret: Option<String>, log_json: Option<bool>,
) {
    init_logger(log_json.unwrap_or(false));
    let router = build_router(secret);

    // Build the default host lazily so nothing is allocated when one is given.
    let host = host.unwrap_or_else(|| String::from("127.0.0.1"));
    let bind_addr = format!("{}:{}", host, port.unwrap_or(0));
    let listener = TcpListener::bind(&bind_addr)
        .await
        .expect("failed to bind port");

    // Query the bound address once; binding has already succeeded here, so a
    // failure is reported as what it is rather than as a bind error.
    let local_addr = listener
        .local_addr()
        .expect("failed to read local address of listener");
    info!("wsrx server is listening on {}", local_addr);
    info!("you can access manage api at http://{}/pool", local_addr);

    axum::serve(listener, router)
        .await
        .expect("failed to launch server");
}
/// Shared, async-lock-protected table of registered tunnels: the key is the
/// tunnel identifier looked up by the `/traffic/{*key}` handlers, the value is
/// the address handed to `TcpStream::connect`.
type ConnectionMap = Arc<RwLock<HashMap<String, String>>>;
/// Router-wide state. Handlers and middleware in this file extract the
/// individual fields via `State<Option<String>>` and `State<ConnectionMap>`.
#[derive(Clone, FromRef)]
pub struct GlobalState {
// Value the `authorization` header must equal to use `/pool`; `None` disables the check.
pub secret: Option<String>,
// Table of currently registered tunnels.
pub connections: ConnectionMap,
}
fn build_router(secret: Option<String>) -> axum::Router {
let state = GlobalState {
secret,
connections: Default::default(),
};
axum::Router::new()
.route(
"/pool",
get(get_tunnels).post(launch_tunnel).delete(close_tunnel),
)
.layer(axum::middleware::from_fn_with_state(
state.clone(),
|State(secret): State<Option<String>>, req: ExtractRequest, next: Next| async move {
if let Some(secret) = secret {
if let Some(auth) = req.headers().get("authorization")
&& auth.to_str().map_err(|_| StatusCode::UNAUTHORIZED)? == secret
{
return Ok(next.run(req).await);
}
return Err(StatusCode::UNAUTHORIZED);
}
Ok(next.run(req).await)
},
))
.route("/traffic/{*key}", get(process_traffic).options(ping))
.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &Request<Body>| {
tracing::info_span!(
"http",
method = %request.method(),
uri = %request.uri().path(),
)
})
.on_request(())
.on_response(|response: &Response, latency: Duration, _span: &Span| {
info!("[{}] in {}ms", response.status(), latency.as_millis());
}),
)
.with_state::<()>(state)
}
/// JSON body accepted by `launch_tunnel`.
#[derive(Deserialize)]
struct TunnelRequest {
// Key under which the tunnel is stored in the connection map.
pub from: String,
// Value stored for that key; `process_traffic` passes it to `TcpStream::connect`.
pub to: String,
}
async fn launch_tunnel(
State(connections): State<ConnectionMap>, axum::Json(req): axum::Json<TunnelRequest>,
) -> Result<impl IntoResponse, (StatusCode, &'static str)> {
let mut pool = connections.write().await;
pool.insert(req.from, req.to);
Ok(StatusCode::CREATED)
}
/// Returns the whole connection map as a JSON object.
///
/// Answers `200 OK` with `Content-Type: application/json` on success. If the
/// map cannot be serialized, answers `500 Internal Server Error` with a
/// plain-text description instead of panicking.
async fn get_tunnels(State(connections): State<ConnectionMap>) -> impl IntoResponse {
    let pool = connections.read().await;
    match serde_json::to_string::<HashMap<String, String>>(&pool) {
        Ok(body) => (
            StatusCode::OK,
            [(CONTENT_TYPE, "application/json")],
            body,
        )
            .into_response(),
        // Report the failure to the client rather than unwrapping the error.
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("failed to serialize pool: {e}"),
        )
            .into_response(),
    }
}
/// JSON body accepted by `close_tunnel`.
#[derive(Deserialize)]
struct CloseTunnelRequest {
// Key of the connection-map entry to remove.
pub key: String,
}
/// Removes the tunnel registered under `req.key`.
///
/// Answers `200 OK` when an entry was removed and `404 Not Found` (body
/// `not found`) when no entry exists for that key.
async fn close_tunnel(
    State(connections): State<ConnectionMap>, axum::Json(req): axum::Json<CloseTunnelRequest>,
) -> Result<impl IntoResponse, (StatusCode, &'static str)> {
    let removed = connections.write().await.remove(&req.key);
    match removed {
        Some(_) => Ok(StatusCode::OK),
        None => Err((StatusCode::NOT_FOUND, "not found")),
    }
}
/// Upgrades the request to a WebSocket and proxies it to the tunnel target.
///
/// Looks up `key` in the connection map; answers `404 Not Found` when it is
/// not registered. Otherwise, once the upgrade completes, a TCP connection to
/// the stored address is opened and handed to `proxy` together with the
/// socket. A failed TCP connect is logged and ends the task; the result of
/// `proxy` is discarded.
async fn process_traffic(
    State(connections): State<ConnectionMap>, Path(key): Path<String>, ws: WebSocketUpgrade,
) -> Result<impl IntoResponse, (StatusCode, &'static str)> {
    // Copy the target address out so the upgrade task owns it.
    let Some(tcp_addr) = connections.read().await.get(&key).cloned() else {
        return Err((StatusCode::NOT_FOUND, "not found"));
    };

    Ok(ws.on_upgrade(move |socket| async move {
        match TcpStream::connect(&tcp_addr).await {
            Ok(tcp) => {
                proxy(socket.into(), tcp, CancellationToken::new())
                    .await
                    .ok();
            }
            Err(e) => {
                error!("failed to connect to tcp server: {e:?}");
            }
        }
    }))
}
/// Reports whether a tunnel is registered under `key`: `200 OK` if it is,
/// `404 Not Found` otherwise.
async fn ping(
    State(connections): State<ConnectionMap>, Path(key): Path<String>,
) -> impl IntoResponse {
    let registered = connections.read().await.contains_key(&key);
    if registered {
        StatusCode::OK
    } else {
        StatusCode::NOT_FOUND
    }
}