// sprites 0.1.0
//
// Official Rust SDK for Sprites - stateful sandbox environments from Fly.io
//! Port forwarding for sprites
//!
//! This module provides TCP port forwarding from a local port to a port inside a sprite.
//!
//! # Example
//!
//! ```no_run
//! use sprites::SpritesClient;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let client = SpritesClient::new("token");
//!     let sprite = client.sprite("my-sprite");
//!
//!     // Forward local port 8080 to port 3000 inside the sprite
//!     let proxy = sprite.proxy_port(8080, 3000).await?;
//!     println!("Proxy listening on {}", proxy.local_addr().unwrap());
//!
//!     // Keep the proxy running
//!     proxy.wait().await;
//!
//!     Ok(())
//! }
//! ```

use crate::error::{Error, Result};
use crate::sprite::Sprite;
use futures_util::{SinkExt, StreamExt};
use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::broadcast;
use tokio_tungstenite::tungstenite::Message;

/// A mapping from a local port to a remote port
///
/// Each mapping describes one forwarded port: connections accepted on `local_port`
/// are tunnelled to `remote_port` inside the sprite.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PortMapping {
    /// The local port to listen on (the proxy binds it on `127.0.0.1`)
    pub local_port: u16,
    /// The remote port inside the sprite
    pub remote_port: u16,
}

impl PortMapping {
    /// Create a new mapping that forwards `local_port` to `remote_port`
    pub fn new(local_port: u16, remote_port: u16) -> Self {
        Self {
            local_port,
            remote_port,
        }
    }
}

/// An active proxy session forwarding traffic to a sprite
///
/// The proxy runs in the background and forwards TCP connections from a local port
/// to a port inside the sprite via WebSocket. Dropping the session closes it and
/// aborts the background accept task.
pub struct ProxySession {
    /// Local address the proxy is listening on
    local_addr: SocketAddr,
    /// Remote port inside the sprite
    remote_port: u16,
    /// Channel used to signal shutdown to the accept loop and to every active connection
    shutdown_tx: broadcast::Sender<()>,
    /// Set by [`ProxySession::close`]. A broadcast receiver created after a send never
    /// observes that value, so [`ProxySession::wait`] consults this flag to notice a
    /// `close` that happened before it subscribed.
    closed: std::sync::atomic::AtomicBool,
    /// Handle to the proxy task
    task_handle: tokio::task::JoinHandle<()>,
}

impl ProxySession {
    /// Get the local address the proxy is listening on
    ///
    /// Always returns `Some`; the `Option` is kept for API compatibility.
    pub fn local_addr(&self) -> Option<SocketAddr> {
        Some(self.local_addr)
    }

    /// Get the local port the proxy is listening on
    pub fn local_port(&self) -> u16 {
        self.local_addr.port()
    }

    /// Get the remote port inside the sprite
    pub fn remote_port(&self) -> u16 {
        self.remote_port
    }

    /// Wait for the proxy to stop
    ///
    /// Resolves once [`close`](Self::close) has been called. This includes the case
    /// where `close` was called before `wait`. The accept loop does not stop by itself
    /// when an accept fails, so without a call to `close` this future never completes.
    pub async fn wait(&self) {
        // Subscribe *before* reading the flag. A concurrent `close` is then seen either
        // through the flag (it already ran) or through the channel (it runs later).
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        if self.closed.load(std::sync::atomic::Ordering::SeqCst) {
            return;
        }
        // `self` keeps a sender alive, so `recv` completes only with a value or a lag
        // error. Both mean `close` was called.
        let _ = shutdown_rx.recv().await;
    }

    /// Close the proxy session
    ///
    /// Stops accepting new connections. It also broadcasts a shutdown signal, which
    /// makes each in-flight connection send a WebSocket close frame and stop
    /// forwarding. Calling this more than once is harmless.
    pub fn close(&self) {
        // Record the close before broadcasting so that `wait` cannot miss it.
        self.closed.store(true, std::sync::atomic::Ordering::SeqCst);
        // `send` fails only when no receiver is left (the accept loop and all
        // connections have already exited), in which case nobody needs notifying.
        let _ = self.shutdown_tx.send(());
    }
}

impl Drop for ProxySession {
    fn drop(&mut self) {
        self.close();
        self.task_handle.abort();
    }
}

/// Start a proxy session
///
/// Binds `127.0.0.1:{local_port}` and spawns the accept loop that forwards each
/// accepted connection to `remote_port` inside `sprite`. Use
/// [`ProxySession::local_addr`] to learn the address actually bound.
///
/// # Errors
///
/// Returns a connection error if the local port cannot be bound, or an error if the
/// bound address cannot be read back from the listener.
pub(crate) async fn start_proxy(sprite: Sprite, local_port: u16, remote_port: u16) -> Result<ProxySession> {
    // Bind to local port
    let listener = TcpListener::bind(format!("127.0.0.1:{local_port}"))
        .await
        .map_err(|e| Error::connection(format!("Failed to bind to port {local_port}: {e}")))?;

    let local_addr = listener.local_addr()?;

    // The channel only ever carries the shutdown signal, so capacity 1 suffices.
    let (shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);

    // The accept loop keeps its own sender so it can subscribe a fresh receiver for
    // every connection it spawns. `sprite` is not needed here afterwards, so it is
    // moved into the task rather than cloned.
    let loop_shutdown_tx = shutdown_tx.clone();
    let task_handle = tokio::spawn(run_proxy_loop(
        listener,
        sprite,
        remote_port,
        loop_shutdown_tx,
        shutdown_rx,
    ));

    Ok(ProxySession {
        local_addr,
        remote_port,
        shutdown_tx,
        closed: std::sync::atomic::AtomicBool::new(false),
        task_handle,
    })
}

/// Main proxy accept loop
async fn run_proxy_loop(
    listener: TcpListener,
    sprite: Sprite,
    remote_port: u16,
    shutdown_tx: broadcast::Sender<()>,
    mut shutdown_rx: broadcast::Receiver<()>,
) {
    loop {
        tokio::select! {
            accept_result = listener.accept() => {
                match accept_result {
                    Ok((stream, peer_addr)) => {
                        let sprite = sprite.clone();
                        let shutdown_rx = shutdown_tx.subscribe();

                        // Spawn a task to handle this connection
                        tokio::spawn(async move {
                            if let Err(e) = handle_connection(stream, sprite, remote_port, shutdown_rx).await {
                                eprintln!("Proxy connection from {peer_addr} failed: {e}");
                            }
                        });
                    }
                    Err(e) => {
                        eprintln!("Failed to accept connection: {e}");
                    }
                }
            }

            _ = shutdown_rx.recv() => {
                // Shutdown signal received
                break;
            }
        }
    }
}

/// Handle a single proxied connection
///
/// Opens a WebSocket to `{base_url}/v1/sprites/{name}/proxy/{remote_port}`, with the
/// `https://`/`http://` scheme rewritten to `wss://`/`ws://`. The request carries the
/// client's bearer token. The function then copies data in both directions:
///
/// * bytes read from `tcp_stream` are sent as binary WebSocket messages;
/// * binary WebSocket messages are written back to `tcp_stream`;
/// * every other WebSocket message except `Close` is ignored.
///
/// Forwarding stops when any of these happens:
/// * the TCP peer reaches EOF;
/// * the WebSocket closes or errors;
/// * a write fails;
/// * a signal arrives on `shutdown_rx`.
///
/// On TCP EOF and on shutdown, a WebSocket close frame is sent first (best-effort).
///
/// # Errors
///
/// Fails only if the upgrade request cannot be built or the WebSocket handshake fails.
/// I/O errors while forwarding end the connection and still return `Ok(())`.
async fn handle_connection(
    mut tcp_stream: TcpStream,
    sprite: Sprite,
    remote_port: u16,
    mut shutdown_rx: broadcast::Receiver<()>,
) -> Result<()> {
    // Connect to the sprite's proxy endpoint via WebSocket
    let base_url = sprite.client().base_url();
    let ws_base = base_url
        .replace("https://", "wss://")
        .replace("http://", "ws://");

    let url = format!(
        "{}/v1/sprites/{}/proxy/{}",
        ws_base,
        sprite.name(),
        remote_port
    );

    let token = sprite.client().token().to_string();

    // Sec-WebSocket-Key: the 16 bytes of the current time in nanoseconds, base64-encoded.
    // NOTE(review): RFC 6455 asks for a randomly selected nonce; a timestamp is predictable.
    let ws_key = {
        use std::time::{SystemTime, UNIX_EPOCH};
        // A clock set before the epoch yields a zero duration instead of panicking
        // inside a spawned library task; the key only has to be well-formed.
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        crate::exec::base64_encode_public(&nanos.to_le_bytes()[..16])
    };

    // The Host header must name the server actually being contacted. That includes
    // plain `ws://` endpoints produced from an `http://` base URL.
    let host = websocket_host(&url);

    let request = tokio_tungstenite::tungstenite::http::Request::builder()
        .method("GET")
        .uri(&url)
        .header("Authorization", format!("Bearer {token}"))
        .header("Connection", "Upgrade")
        .header("Upgrade", "websocket")
        .header("Sec-WebSocket-Version", "13")
        .header("Sec-WebSocket-Key", &ws_key)
        .header("Host", host)
        .body(())
        .map_err(|e| Error::InvalidResponse(e.to_string()))?;

    let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?;
    let (mut ws_write, mut ws_read) = ws_stream.split();

    // Split TCP stream for bidirectional forwarding
    let (mut tcp_read, mut tcp_write) = tcp_stream.split();

    // Buffer for reading from TCP
    let mut tcp_buf = vec![0u8; 16384];

    loop {
        tokio::select! {
            // Read from TCP, send to WebSocket
            read_result = tcp_read.read(&mut tcp_buf) => {
                match read_result {
                    Ok(0) => {
                        // TCP connection closed
                        let _ = ws_write.close().await;
                        break;
                    }
                    Ok(n) => {
                        let data = tcp_buf[..n].to_vec();
                        if ws_write.send(Message::Binary(data)).await.is_err() {
                            break;
                        }
                    }
                    Err(_) => break,
                }
            }

            // Read from WebSocket, send to TCP
            ws_msg = ws_read.next() => {
                match ws_msg {
                    Some(Ok(Message::Binary(data))) => {
                        if tcp_write.write_all(&data).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(Message::Close(_))) | None => break,
                    Some(Err(_)) => break,
                    _ => {}
                }
            }

            // Shutdown signal
            _ = shutdown_rx.recv() => {
                let _ = ws_write.close().await;
                break;
            }
        }
    }

    Ok(())
}

/// Return the `host[:port]` part of a `wss://` or `ws://` URL, for the `Host` header.
///
/// Falls back to `api.sprites.dev` when `url` uses neither scheme.
fn websocket_host(url: &str) -> &str {
    url.strip_prefix("wss://")
        .or_else(|| url.strip_prefix("ws://"))
        .and_then(|rest| rest.split('/').next())
        .unwrap_or("api.sprites.dev")
}

/// Start multiple proxy sessions
pub(crate) async fn start_proxies(sprite: Sprite, mappings: &[PortMapping]) -> Result<Vec<ProxySession>> {
    let mut sessions = Vec::with_capacity(mappings.len());

    for mapping in mappings {
        let session = start_proxy(sprite.clone(), mapping.local_port, mapping.remote_port).await?;
        sessions.push(session);
    }

    Ok(sessions)
}