use std::net::SocketAddr;
use anyhow::{anyhow, Context, Result};
use iroh::endpoint::Connection;
use iroh::Endpoint;
use iroh_tickets::{endpoint::EndpointTicket, Ticket};
use serde::Deserialize;
use tokio::io;
use tokio::net::{TcpListener, TcpStream};
/// ALPN protocol identifier passed to `Endpoint::connect` when dialing the remote peer.
pub(crate) const ALPN: &[u8] = b"hey-proxy/tcp/0";
/// Prefix that every connection string given to `Client::connect` /
/// `Client::connect_with_host` must start with; it is stripped before the rest is resolved.
pub(crate) const TICKET_PREFIX: &str = "heyo://";
/// Local TCP entry point whose accepted connections are forwarded to a remote iroh peer.
///
/// Built by `Client::connect` / `Client::connect_with_host`; call `Client::run` to start
/// accepting and proxying connections.
pub(crate) struct Client {
/// Established connection to the remote peer; cloned for every proxied TCP connection.
conn: Connection,
/// Local listener whose accepted TCP connections are proxied to the peer.
listener: TcpListener,
// Never read. Stored so the endpoint lives as long as the client; presumably `conn`
// depends on it staying alive — NOTE(review): confirm against iroh docs.
_endpoint: Endpoint,
}
impl Client {
    /// Connects to the peer described by `ticket_url` and listens on `127.0.0.1:listen_port`.
    ///
    /// Shorthand for [`Client::connect_with_host`] with a loopback listen host.
    ///
    /// # Errors
    /// Same as [`Client::connect_with_host`].
    pub(crate) async fn connect(
        ticket_url: &str,
        listen_port: u16,
        relay_override: Option<&str>,
    ) -> Result<Self> {
        Self::connect_with_host(ticket_url, "127.0.0.1", listen_port, relay_override).await
    }

    /// Binds a TCP listener on `listen_host:listen_port` and connects to the remote peer
    /// described by `ticket_url`.
    ///
    /// `ticket_url` must start with [`TICKET_PREFIX`]. The remainder is either a serialized
    /// ticket or a code that is looked up on a relay server. `relay_override`, when given,
    /// is the relay used for that lookup. Pass `listen_port = 0` to let the OS pick a port,
    /// and read it back with [`Client::local_addr`].
    ///
    /// # Errors
    /// Fails if:
    /// - the prefix is missing;
    /// - the listener cannot be bound;
    /// - the ticket cannot be resolved or parsed;
    /// - the local endpoint cannot be created;
    /// - the connection to the peer fails.
    pub(crate) async fn connect_with_host(
        ticket_url: &str,
        listen_host: &str,
        listen_port: u16,
        relay_override: Option<&str>,
    ) -> Result<Self> {
        let payload = ticket_url
            .strip_prefix(TICKET_PREFIX)
            .ok_or_else(|| anyhow!("connection string must start with {TICKET_PREFIX}"))?;

        // Bind before any network work. A port conflict is then reported immediately,
        // instead of after a potentially slow relay lookup and peer handshake.
        let listener = TcpListener::bind((listen_host, listen_port))
            .await
            .with_context(|| format!("failed to bind to {listen_host}:{listen_port}"))?;

        let ticket_str = resolve_ticket(payload, relay_override).await?;
        let ticket = <EndpointTicket as Ticket>::deserialize(&ticket_str)
            .map_err(|e| anyhow!("invalid ticket: {e}"))?;
        let remote_addr = ticket.endpoint_addr().clone();

        let endpoint = Endpoint::bind().await?;
        // Wait for the local endpoint to report itself online before dialing the peer.
        endpoint.online().await;
        let conn = endpoint
            .connect(remote_addr, ALPN)
            .await
            .context("failed to connect to remote peer")?;

        Ok(Self {
            conn,
            listener,
            _endpoint: endpoint,
        })
    }

    /// Returns the address the local TCP listener is actually bound to.
    ///
    /// # Errors
    /// Propagates the I/O error from querying the listener's local address.
    pub(crate) fn local_addr(&self) -> Result<SocketAddr> {
        Ok(self.listener.local_addr()?)
    }

    /// Accepts TCP connections and proxies each one over its own stream on the peer
    /// connection.
    ///
    /// Every accepted connection is handled on a separate task. Failures of individual
    /// connections are deliberately ignored, so one bad client cannot stop the proxy.
    ///
    /// # Errors
    /// Returns an error when accepting a TCP connection fails. It also returns one when the
    /// connection to the remote peer closes, because no new streams could be opened after
    /// that and every further TCP client would just be dropped.
    pub(crate) async fn run(self) -> Result<()> {
        loop {
            let (tcp_stream, _peer_addr) = tokio::select! {
                accepted = self.listener.accept() => accepted?,
                reason = self.conn.closed() => {
                    return Err(anyhow!("connection to remote peer closed: {reason}"));
                }
            };
            let conn = self.conn.clone();
            tokio::spawn(async move {
                // Best effort: a failing connection only affects its own TCP client.
                let _ = handle_client_connection(conn, tcp_stream).await;
            });
        }
    }
}
/// Proxies one accepted TCP connection over a freshly opened bidirectional stream on `conn`.
///
/// Opening the stream is bounded by a 15-second timeout. Once it is open, the TCP socket is
/// split into halves and the actual copying is delegated to `proxy_streams`.
///
/// NOTE(review): per the quinn `open_bi` docs, the peer is not told about a new stream until
/// data is sent on it. A server-speaks-first protocol may stall until the TCP client writes
/// something. Confirm how the server side handles this.
///
/// # Errors
/// Fails if:
/// - the stream is not opened within 15 seconds;
/// - opening the stream fails;
/// - `proxy_streams` returns an error.
async fn handle_client_connection(conn: Connection, tcp_stream: TcpStream) -> Result<()> {
    let open_deadline = std::time::Duration::from_secs(15);
    let opened = match tokio::time::timeout(open_deadline, conn.open_bi()).await {
        Ok(result) => result,
        Err(_elapsed) => return Err(anyhow!("timed out opening bi-directional stream (15s)")),
    };
    let (send_half, recv_half) = opened.context("failed to open bi-directional stream")?;

    let (tcp_reader, tcp_writer) = tcp_stream.into_split();
    proxy_streams(recv_half, send_half, tcp_reader, tcp_writer).await
}
/// Copies bytes both ways between an iroh bidirectional stream and a TCP connection until
/// both directions reach EOF or fail.
///
/// End-of-stream is forwarded per direction as soon as it happens:
/// - when the TCP client stops sending, the iroh send stream is finished;
/// - when the remote peer finishes its stream, the TCP write half is shut down.
///
/// This keeps half-closing protocols working. Example: a client shuts down its write side
/// and then waits for the complete response.
///
/// # Errors
/// After both directions have stopped, returns the first copy error. The remote→TCP
/// direction is checked first.
async fn proxy_streams(
    mut iroh_recv: iroh::endpoint::RecvStream,
    mut iroh_send: iroh::endpoint::SendStream,
    mut tcp_read: tokio::net::tcp::OwnedReadHalf,
    mut tcp_write: tokio::net::tcp::OwnedWriteHalf,
) -> Result<()> {
    use tokio::io::AsyncWriteExt as _;

    let iroh_to_tcp = async {
        let copied = io::copy(&mut iroh_recv, &mut tcp_write).await;
        // `io::copy` never shuts down its writer. Do it here so the TCP client sees EOF
        // now, not only once the other direction has also finished. Best effort: the
        // socket may already be reset.
        let _ = tcp_write.shutdown().await;
        copied
    };
    let tcp_to_iroh = async {
        let copied = io::copy(&mut tcp_read, &mut iroh_send).await;
        // Likewise, signal EOF to the remote peer as soon as the TCP client stops sending.
        // Ignored if the stream was already finished or reset.
        let _ = iroh_send.finish();
        copied
    };

    let (to_tcp, to_iroh) = tokio::join!(iroh_to_tcp, tcp_to_iroh);
    to_tcp.context("failed to copy data from remote peer to TCP client")?;
    to_iroh.context("failed to copy data from TCP client to remote peer")?;
    Ok(())
}
/// JSON body returned by the relay server's `/api/lookup/{code}` endpoint.
#[derive(Deserialize)]
struct RelayLookupResponse {
/// Serialized ticket registered under the looked-up code; later parsed as an `EndpointTicket`.
ticket: String,
}
async fn lookup_from_relay(relay_url: &str, code: &str) -> Result<String> {
let base = relay_url.trim_end_matches('/');
let resp: RelayLookupResponse = reqwest::Client::new()
.get(format!("{base}/api/lookup/{code}"))
.send()
.await
.context("failed to contact relay server")?
.error_for_status()
.context("ticket not found on relay server")?
.json()
.await?;
Ok(resp.ticket)
}
/// Turns the part of a connection string after [`TICKET_PREFIX`] into a serialized ticket.
///
/// Accepted forms:
/// - `<authority>/<code>`: `code` is looked up on `relay_override` if given, otherwise on
///   `http://<authority>`. Trailing slashes after the code are ignored.
/// - a short bare code (fewer than 64 bytes): looked up on `relay_override`, which is
///   then required.
/// - anything longer: assumed to already be a serialized ticket and returned unchanged.
///
/// # Errors
/// Fails if:
/// - the code or the whole payload is empty;
/// - a bare short code is given without `relay_override`;
/// - the relay lookup fails.
async fn resolve_ticket(payload: &str, relay_override: Option<&str>) -> Result<String> {
    // Payloads shorter than this are assumed to be relay codes rather than serialized tickets.
    const TICKET_MIN_LEN: usize = 64;

    if let Some((authority, code)) = payload.split_once('/') {
        // Tolerate a trailing slash after the code, e.g. `heyo://host/abc/`.
        let code = code.trim_end_matches('/');
        if code.is_empty() {
            return Err(anyhow!(
                "connection string is missing a code after the relay address"
            ));
        }
        let relay_url = relay_override
            .map(String::from)
            .unwrap_or_else(|| format!("http://{authority}"));
        lookup_from_relay(&relay_url, code).await
    } else if payload.is_empty() {
        Err(anyhow!("connection string has nothing after {TICKET_PREFIX}"))
    } else if payload.len() < TICKET_MIN_LEN {
        let relay_url =
            relay_override.ok_or_else(|| anyhow!("short code requires a relay URL"))?;
        lookup_from_relay(relay_url, payload).await
    } else {
        Ok(payload.to_string())
    }
}