use crate::error::{Error, Result};
use crate::sprite::Sprite;
use futures_util::{SinkExt, StreamExt};
use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::broadcast;
use tokio_tungstenite::tungstenite::Message;
/// Pairs a local TCP port with the port inside a sprite that traffic is forwarded to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PortMapping {
    /// Port the local listener binds on `127.0.0.1`.
    pub local_port: u16,
    /// Port inside the sprite that accepted connections are tunneled to.
    pub remote_port: u16,
}

impl PortMapping {
    /// Creates a mapping that forwards `local_port` to `remote_port`.
    #[must_use]
    pub const fn new(local_port: u16, remote_port: u16) -> Self {
        Self {
            local_port,
            remote_port,
        }
    }
}
/// A running local TCP listener whose accepted connections are tunneled to a
/// port inside a sprite.
///
/// Dropping the session signals shutdown to all open connections and aborts
/// the accept-loop task.
pub struct ProxySession {
    local_addr: SocketAddr,
    remote_port: u16,
    shutdown_tx: broadcast::Sender<()>,
    /// Set by `close` before the shutdown message is broadcast. A broadcast
    /// receiver only sees values sent after it subscribed, so `wait` needs this
    /// flag to notice a close that already happened.
    closed: std::sync::atomic::AtomicBool,
    task_handle: tokio::task::JoinHandle<()>,
}

impl ProxySession {
    /// Returns the address the local listener is bound to.
    ///
    /// Currently always returns `Some`.
    pub fn local_addr(&self) -> Option<SocketAddr> {
        Some(self.local_addr)
    }

    /// Returns the local port the listener is bound to.
    pub fn local_port(&self) -> u16 {
        self.local_addr.port()
    }

    /// Returns the sprite-side port connections are forwarded to.
    pub fn remote_port(&self) -> u16 {
        self.remote_port
    }

    /// Waits until [`ProxySession::close`] has been called.
    ///
    /// Returns immediately if the session was already closed before this call.
    pub async fn wait(&self) {
        // Subscribe *before* reading the flag. A concurrent `close` then shows up
        // either through the flag or through the broadcast message, never neither.
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        if self.closed.load(std::sync::atomic::Ordering::SeqCst) {
            return;
        }
        let _ = shutdown_rx.recv().await;
    }

    /// Signals the accept loop and every open connection to shut down.
    ///
    /// Safe to call more than once. Each call re-broadcasts the signal.
    pub fn close(&self) {
        self.closed.store(true, std::sync::atomic::Ordering::SeqCst);
        // An error here only means nobody is listening any more.
        let _ = self.shutdown_tx.send(());
    }
}

impl Drop for ProxySession {
    fn drop(&mut self) {
        self.close();
        self.task_handle.abort();
    }
}

/// Binds `127.0.0.1:local_port` and spawns a task that tunnels every accepted
/// connection to `remote_port` inside `sprite`.
///
/// The bound address (useful when `local_port` is `0` and the OS picks a port)
/// is exposed through [`ProxySession::local_addr`] / [`ProxySession::local_port`].
///
/// # Errors
///
/// Returns an error if the local port cannot be bound or the listener's address
/// cannot be read.
pub(crate) async fn start_proxy(sprite: Sprite, local_port: u16, remote_port: u16) -> Result<ProxySession> {
    let listener = TcpListener::bind(format!("127.0.0.1:{local_port}"))
        .await
        .map_err(|e| Error::connection(format!("Failed to bind to port {local_port}: {e}")))?;
    let local_addr = listener.local_addr()?;

    let (shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);
    // The loop keeps its own sender so it can hand out a receiver to each connection.
    let loop_shutdown_tx = shutdown_tx.clone();
    // `sprite` is owned here and not needed afterwards, so it moves straight into the task.
    let task_handle = tokio::spawn(run_proxy_loop(
        listener,
        sprite,
        remote_port,
        loop_shutdown_tx,
        shutdown_rx,
    ));

    Ok(ProxySession {
        local_addr,
        remote_port,
        shutdown_tx,
        closed: std::sync::atomic::AtomicBool::new(false),
        task_handle,
    })
}
async fn run_proxy_loop(
listener: TcpListener,
sprite: Sprite,
remote_port: u16,
shutdown_tx: broadcast::Sender<()>,
mut shutdown_rx: broadcast::Receiver<()>,
) {
loop {
tokio::select! {
accept_result = listener.accept() => {
match accept_result {
Ok((stream, peer_addr)) => {
let sprite = sprite.clone();
let shutdown_rx = shutdown_tx.subscribe();
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, sprite, remote_port, shutdown_rx).await {
eprintln!("Proxy connection from {peer_addr} failed: {e}");
}
});
}
Err(e) => {
eprintln!("Failed to accept connection: {e}");
}
}
}
_ = shutdown_rx.recv() => {
break;
}
}
}
}
/// Tunnels one accepted local TCP connection to `remote_port` inside `sprite`
/// over a WebSocket.
///
/// Bytes are copied in both directions until the local client closes, the
/// WebSocket closes, or a shutdown message arrives on `shutdown_rx`. I/O errors
/// on either side after the tunnel is up end the connection without being
/// reported as an error.
///
/// # Errors
///
/// Returns an error if the upgrade request cannot be built or the WebSocket
/// connection cannot be established.
async fn handle_connection(
    mut tcp_stream: TcpStream,
    sprite: Sprite,
    remote_port: u16,
    mut shutdown_rx: broadcast::Receiver<()>,
) -> Result<()> {
    let base_url = sprite.client().base_url();
    let (ws_base, host) = websocket_base(&base_url);
    let url = format!(
        "{}/v1/sprites/{}/proxy/{}",
        ws_base,
        sprite.name(),
        remote_port
    );
    let token = sprite.client().token().to_string();
    // RFC 6455 §4.1: the client nonce must be 16 freshly random bytes, base64-encoded.
    let ws_key = tokio_tungstenite::tungstenite::handshake::client::generate_key();

    // A hand-built `Request` is sent as-is, so every upgrade header must be set here.
    let request = tokio_tungstenite::tungstenite::http::Request::builder()
        .method("GET")
        .uri(&url)
        .header("Authorization", format!("Bearer {token}"))
        .header("Connection", "Upgrade")
        .header("Upgrade", "websocket")
        .header("Sec-WebSocket-Version", "13")
        .header("Sec-WebSocket-Key", &ws_key)
        .header("Host", host.as_str())
        .body(())
        .map_err(|e| Error::InvalidResponse(e.to_string()))?;
    let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?;

    let (mut ws_write, mut ws_read) = ws_stream.split();
    let (mut tcp_read, mut tcp_write) = tcp_stream.split();
    let mut tcp_buf = vec![0u8; 16384];
    loop {
        tokio::select! {
            read_result = tcp_read.read(&mut tcp_buf) => {
                match read_result {
                    // The local client closed its side, so close the tunnel too.
                    Ok(0) => {
                        let _ = ws_write.close().await;
                        break;
                    }
                    Ok(n) => {
                        let data = tcp_buf[..n].to_vec();
                        if ws_write.send(Message::Binary(data)).await.is_err() {
                            break;
                        }
                    }
                    Err(_) => break,
                }
            }
            ws_msg = ws_read.next() => {
                match ws_msg {
                    Some(Ok(Message::Binary(data))) => {
                        if tcp_write.write_all(&data).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
                    // Other frame types (text, ping, pong, ...) are ignored.
                    _ => {}
                }
            }
            _ = shutdown_rx.recv() => {
                let _ = ws_write.close().await;
                break;
            }
        }
    }
    Ok(())
}

/// Converts an API base URL to its WebSocket form and extracts the authority
/// for the `Host` header.
///
/// Only a leading `https://` / `http://` is rewritten (to `wss://` / `ws://`).
/// Any other input is kept as-is. Trailing `/` characters are dropped so joined
/// paths do not contain `//`. If no authority can be found after a `wss://` or
/// `ws://` scheme, the host falls back to `api.sprites.dev`.
fn websocket_base(base_url: &str) -> (String, String) {
    let base = base_url.trim_end_matches('/');
    let ws_base = if let Some(rest) = base.strip_prefix("https://") {
        format!("wss://{rest}")
    } else if let Some(rest) = base.strip_prefix("http://") {
        format!("ws://{rest}")
    } else {
        base.to_string()
    };
    let host = ws_base
        .strip_prefix("wss://")
        .or_else(|| ws_base.strip_prefix("ws://"))
        .and_then(|rest| rest.split('/').next())
        .filter(|authority| !authority.is_empty())
        .unwrap_or("api.sprites.dev")
        .to_string();
    (ws_base, host)
}
pub(crate) async fn start_proxies(sprite: Sprite, mappings: &[PortMapping]) -> Result<Vec<ProxySession>> {
let mut sessions = Vec::with_capacity(mappings.len());
for mapping in mappings {
let session = start_proxy(sprite.clone(), mapping.local_port, mapping.remote_port).await?;
sessions.push(session);
}
Ok(sessions)
}