use futures_util::{ SinkExt, StreamExt };
use poem::{
Request, Result, Response, Error, handler, Body, FromRequest, IntoResponse,
http::{ StatusCode, Method, HeaderMap },
web::{ Data, websocket::{ WebSocket } }
};
use tokio_tungstenite::connect_async;
use tokio::sync::RwLock;
use std::sync::Arc;
/// Configuration for the reverse-proxy handler in this module.
///
/// `proxy_target` is the upstream authority (host, optional port, optional
/// base path) **without** a scheme: the scheme is prepended per request kind
/// from `web_secure` / `ws_secure`. A request kind whose flag is still `None`
/// has no upstream URI, and the getters below return `Err(())` for it.
///
/// Build one with [`ProxyConfig::new`] plus the chaining setters, then take an
/// owned copy with [`ProxyConfig::finish`].
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// Upstream `host[:port][/base]`, without `http://` / `ws://`.
    proxy_target: String,
    /// `Some(true)` → `https`, `Some(false)` → `http`, `None` → HTTP proxying unconfigured.
    web_secure: Option<bool>,
    /// `Some(true)` → `wss`, `Some(false)` → `ws`, `None` → websocket proxying unconfigured.
    ws_secure: Option<bool>,
    /// Whether the caller-supplied subpath is appended to the upstream HTTP URI.
    support_nesting: bool,
}

impl Default for ProxyConfig {
    /// Targets `localhost:3000` with both protocols unconfigured and nesting off.
    fn default() -> Self {
        Self {
            // No scheme here: the URI getters prepend one, so a value such as
            // "http://localhost:3000" would produce "http://http://localhost:3000".
            proxy_target: "localhost:3000".into(),
            web_secure: None,
            ws_secure: None,
            support_nesting: false,
        }
    }
}

impl ProxyConfig {
    /// Creates a config for `target` (`host[:port]`, no scheme) with both
    /// protocols unconfigured and nesting disabled.
    pub fn new(target: impl Into<String>) -> ProxyConfig {
        ProxyConfig {
            proxy_target: target.into(),
            ..ProxyConfig::default()
        }
    }

    /// Proxies websocket upgrades over `wss://`.
    pub fn ws_secure(&mut self) -> &mut ProxyConfig {
        self.ws_secure = Some(true);
        self
    }

    /// Proxies websocket upgrades over plain `ws://`.
    pub fn ws_insecure(&mut self) -> &mut ProxyConfig {
        self.ws_secure = Some(false);
        self
    }

    /// Proxies plain HTTP requests over `https://`.
    pub fn web_secure(&mut self) -> &mut ProxyConfig {
        self.web_secure = Some(true);
        self
    }

    /// Proxies plain HTTP requests over `http://`.
    pub fn web_insecure(&mut self) -> &mut ProxyConfig {
        self.web_secure = Some(false);
        self
    }

    /// Appends the caller-supplied subpath to upstream HTTP URIs.
    pub fn enable_nesting(&mut self) -> &mut ProxyConfig {
        self.support_nesting = true;
        self
    }

    /// Sends every HTTP request to the bare upstream target (the default).
    pub fn disable_nesting(&mut self) -> &mut ProxyConfig {
        self.support_nesting = false;
        self
    }

    /// Returns an owned snapshot of the configuration built so far.
    pub fn finish(&mut self) -> ProxyConfig {
        self.clone()
    }
}

impl ProxyConfig {
    /// Builds the upstream URI for a plain HTTP request.
    ///
    /// `subpath` is appended verbatim only when nesting is enabled; otherwise
    /// it is ignored and the bare target is returned.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if neither [`web_secure`](Self::web_secure) nor
    /// [`web_insecure`](Self::web_insecure) has been called.
    pub fn get_web_request_uri(&self, subpath: Option<String>) -> Result<String, ()> {
        let scheme = if self.web_secure.ok_or(())? { "https" } else { "http" };
        let sub = subpath
            .filter(|_| self.support_nesting)
            .unwrap_or_default();
        Ok(format!("{scheme}://{}{sub}", self.proxy_target))
    }

    /// Builds the upstream URI for a websocket upgrade.
    ///
    /// No subpath is ever appended here, regardless of the nesting setting.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if neither [`ws_secure`](Self::ws_secure) nor
    /// [`ws_insecure`](Self::ws_insecure) has been called.
    pub fn get_web_socket_uri(&self) -> Result<String, ()> {
        let scheme = if self.ws_secure.ok_or(())? { "wss" } else { "ws" };
        Ok(format!("{scheme}://{}", self.proxy_target))
    }
}
#[handler]
pub async fn proxy(
req: &Request,
headers: &HeaderMap,
config: Data<&ProxyConfig>,
method: Method,
body: Body,
) -> Result<Response> {
if let Ok( ws ) = WebSocket::from_request_without_body( req ).await {
let Ok( uri ) = config.get_web_socket_uri() else {
return Err( Error::from_string( "Proxy endpoint not configured to support websockets!", StatusCode::NOT_IMPLEMENTED ) )
};
let mut w_request = http::Request::builder().uri( &uri );
for (key, value) in headers.iter() {
w_request = w_request.header( key, value );
}
return Ok(
ws.on_upgrade(move |socket| async move {
let ( mut clientsink, mut clientstream ) = socket.split();
let ( mut serversocket, _ ) = connect_async( w_request.body(()).unwrap() ).await.unwrap();
let ( mut serversink, mut serverstream ) = serversocket.split();
let client_live = Arc::new( RwLock::new( true ) );
let server_live = client_live.clone();
tokio::spawn( async move {
while let Some( Ok( msg ) ) = clientstream.next().await {
match serversink.send( msg.into() ).await {
Err( _ ) => break,
_ => {},
};
if !*client_live.read().await { break };
};
*client_live.write().await = false;
});
tokio::spawn( async move {
while let Some( Ok( msg ) ) = serverstream.next().await {
match clientsink.send( msg.into() ).await {
Err( _ ) => break,
_ => {},
};
if !*server_live.read().await { break };
};
*server_live.write().await = false;
});
}).into_response()
);
}
else {
let Ok( uri ) = config.get_web_request_uri( Some( req.uri().to_string() ) ) else {
return Err( Error::from_string( "Proxy endpoint not configured to support web requests!", StatusCode::NOT_IMPLEMENTED ) )
};
let client = reqwest::Client::new();
let res = match method {
Method::GET => {
client.get( uri )
.headers( req.headers().clone() )
.body( body.into_bytes().await.unwrap() )
.send()
.await
},
Method::POST => {
client.post( uri )
.headers( req.headers().clone() )
.body( body.into_bytes().await.unwrap() )
.send()
.await
},
_ => {
return Err( Error::from_string( "Unsupported Method! The proxy endpoint currently only supports GET and POST requests!", StatusCode::METHOD_NOT_ALLOWED ) )
}
};
match res {
Ok( result ) => {
let mut res = Response::default();
res.extensions().clone_from( &result.extensions() );
result.headers().iter().for_each(|(key, val)| {
res.headers_mut().insert( key, val.to_owned() );
});
res.set_status( result.status() );
res.set_version( result.version() );
res.set_body( result.bytes().await.unwrap() );
Ok( res )
},
Err( error ) => {
Err( Error::from_string( error.to_string(), error.status().unwrap_or( StatusCode::BAD_GATEWAY ) ) )
}
}
}
}