use futures_util::{SinkExt, StreamExt};
use poem::{
handler,
http::{HeaderMap, Method, StatusCode},
web::{websocket::WebSocket, Data},
Body, Error, FromRequest, IntoResponse, Request, Response, Result,
};
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_tungstenite::connect_async;
/// Settings that tell the proxy handler where to forward traffic and which
/// kinds of traffic (plain web requests, WebSockets) it is allowed to forward.
///
/// A kind of traffic is only proxied once its security mode has been chosen
/// explicitly through the builder methods; while the corresponding field is
/// `None` the URI getters return `Err(())`.
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// Upstream address the scheme is prepended to, e.g. `localhost:3000`.
    /// A leading `http://`, `https://`, `ws://` or `wss://` is tolerated and
    /// ignored when URIs are built.
    proxy_target: String,
    /// `Some(true)` selects `https`, `Some(false)` selects `http`, `None`
    /// means web requests are not configured.
    web_secure: Option<bool>,
    /// `Some(true)` selects `wss`, `Some(false)` selects `ws`, `None` means
    /// WebSockets are not configured.
    ws_secure: Option<bool>,
    /// When `true`, the sub-path handed to [`ProxyConfig::get_web_request_uri`]
    /// is appended to the target.
    support_nesting: bool,
}

impl Default for ProxyConfig {
    /// Targets `http://localhost:3000` with web and WebSocket proxying both
    /// unconfigured and nesting disabled.
    fn default() -> Self {
        Self {
            proxy_target: "http://localhost:3000".into(),
            web_secure: None,
            ws_secure: None,
            support_nesting: false,
        }
    }
}

impl ProxyConfig {
    /// Creates a configuration for `target`, leaving every other setting at
    /// its [`Default`] value.
    pub fn new(target: impl Into<String>) -> ProxyConfig {
        ProxyConfig {
            proxy_target: target.into(),
            ..ProxyConfig::default()
        }
    }

    /// Enables WebSocket proxying over `wss://`.
    pub fn ws_secure(&mut self) -> &mut ProxyConfig {
        self.ws_secure = Some(true);
        self
    }

    /// Enables WebSocket proxying over `ws://`.
    pub fn ws_insecure(&mut self) -> &mut ProxyConfig {
        self.ws_secure = Some(false);
        self
    }

    /// Enables web request proxying over `https://`.
    pub fn web_secure(&mut self) -> &mut ProxyConfig {
        self.web_secure = Some(true);
        self
    }

    /// Enables web request proxying over `http://`.
    pub fn web_insecure(&mut self) -> &mut ProxyConfig {
        self.web_secure = Some(false);
        self
    }

    /// Makes web request URIs include the sub-path supplied by the caller.
    pub fn enable_nesting(&mut self) -> &mut ProxyConfig {
        self.support_nesting = true;
        self
    }

    /// Makes web request URIs ignore the sub-path supplied by the caller.
    pub fn disable_nesting(&mut self) -> &mut ProxyConfig {
        self.support_nesting = false;
        self
    }

    /// Ends a builder chain by returning an owned copy of the configuration.
    pub fn finish(&mut self) -> ProxyConfig {
        self.clone()
    }
}

impl ProxyConfig {
    /// Returns the target without a leading URL scheme.
    ///
    /// The URI getters always prepend their own scheme, so a target that
    /// already carries one (as the `Default` target does) would otherwise end
    /// up as `http://http://...`.
    fn target_without_scheme(&self) -> &str {
        const SCHEME_PREFIXES: [&str; 4] = ["https://", "http://", "wss://", "ws://"];
        SCHEME_PREFIXES
            .iter()
            .find_map(|prefix| self.proxy_target.strip_prefix(prefix))
            .unwrap_or(&self.proxy_target)
    }

    /// Builds the upstream URI for a plain web request.
    ///
    /// `subpath` is appended verbatim, but only when nesting is enabled;
    /// otherwise it is ignored.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when neither `web_secure` nor `web_insecure` has
    /// been selected.
    pub fn get_web_request_uri(&self, subpath: Option<String>) -> Result<String, ()> {
        let scheme = match self.web_secure {
            Some(true) => "https",
            Some(false) => "http",
            None => return Err(()),
        };
        let mut uri = format!("{scheme}://{}", self.target_without_scheme());
        if self.support_nesting {
            if let Some(sub) = subpath {
                uri.push_str(&sub);
            }
        }
        Ok(uri)
    }

    /// Builds the upstream URI for a WebSocket connection. No request path is
    /// appended, regardless of the nesting setting.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when neither `ws_secure` nor `ws_insecure` has been
    /// selected.
    pub fn get_web_socket_uri(&self) -> Result<String, ()> {
        let scheme = match self.ws_secure {
            Some(true) => "wss",
            Some(false) => "ws",
            None => return Err(()),
        };
        Ok(format!("{scheme}://{}", self.target_without_scheme()))
    }
}
#[handler]
pub async fn proxy(
req: &Request,
headers: &HeaderMap,
config: Data<&ProxyConfig>,
method: Method,
body: Body,
) -> Result<Response> {
if let Ok(ws) = WebSocket::from_request_without_body(req).await {
let Ok(uri) = config.get_web_socket_uri() else {
return Err(Error::from_string(
"Proxy endpoint not configured to support websockets!",
StatusCode::NOT_IMPLEMENTED,
));
};
let mut w_request = http::Request::builder().uri(&uri);
for (key, value) in headers.iter() {
w_request = w_request.header(key, value);
}
return Ok(ws
.on_upgrade(move |socket| async move {
let (mut clientsink, mut clientstream) = socket.split();
let (mut serversocket, _) =
connect_async(w_request.body(()).unwrap()).await.unwrap();
let (mut serversink, mut serverstream) = serversocket.split();
let client_live = Arc::new(RwLock::new(true));
let server_live = client_live.clone();
tokio::spawn(async move {
while let Some(Ok(msg)) = clientstream.next().await {
match serversink.send(msg.into()).await {
Err(_) => break,
_ => {}
};
if !*client_live.read().await {
break;
};
}
*client_live.write().await = false;
});
tokio::spawn(async move {
while let Some(Ok(msg)) = serverstream.next().await {
match clientsink.send(msg.into()).await {
Err(_) => break,
_ => {}
};
if !*server_live.read().await {
break;
};
}
*server_live.write().await = false;
});
})
.into_response());
}
else {
let Ok(uri) = config.get_web_request_uri(Some(req.uri().to_string())) else {
return Err(Error::from_string(
"Proxy endpoint not configured to support web requests!",
StatusCode::NOT_IMPLEMENTED,
));
};
let client = reqwest::Client::new();
let res = match method {
Method::GET => {
client.get( uri )
.headers( req.headers().clone() )
.body( body.into_bytes().await.unwrap() )
.send()
.await
},
Method::POST => {
client.post( uri )
.headers( req.headers().clone() )
.body( body.into_bytes().await.unwrap() )
.send()
.await
},
_ => {
return Err( Error::from_string( "Unsupported Method! The proxy endpoint currently only supports GET and POST requests!", StatusCode::METHOD_NOT_ALLOWED ) )
}
};
match res {
Ok(result) => {
let mut res = Response::default();
res.extensions().clone_from(&result.extensions());
result.headers().iter().for_each(|(key, val)| {
res.headers_mut().insert(key, val.to_owned());
});
res.set_status(result.status());
res.set_version(result.version());
res.set_body(result.bytes().await.unwrap());
Ok(res)
}
Err(error) => Err(Error::from_string(
error.to_string(),
error.status().unwrap_or(StatusCode::BAD_GATEWAY),
)),
}
}
}