use std::{
io,
net::{SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
str::FromStr,
time::Duration,
};
use clap::{Parser, Subcommand};
use dumbpipe::EndpointTicket;
use iroh::{
endpoint::{presets, Accepting},
Endpoint, EndpointAddr, SecretKey,
};
use n0_error::{bail_any, ensure_any, AnyError, Result, StdResultExt};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
select,
time::timeout,
};
use tokio_util::sync::CancellationToken;
#[cfg(unix)]
use {
std::path::PathBuf,
tokio::net::{UnixListener, UnixStream},
};
/// Upper bound on how long to wait for `Endpoint::online()` before printing a
/// "Failed to connect to the home relay" warning and carrying on anyway.
const ONLINE_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Create a dumb pipe between two machines, using an iroh endpoint.
///
/// One side listens, the other side connects using the ticket printed by the
/// listener. Every subcommand takes its secret key from the IROH_SECRET
/// environment variable; if it is not set, a random key is generated and its
/// hex encoding is printed to stderr so it can be reused.
#[derive(Parser, Debug)]
pub struct Args {
    // The selected subcommand together with its arguments.
    #[clap(subcommand)]
    pub command: Commands,
}
// The dumbpipe subcommands; each variant's doc comment is its `--help` text.
#[derive(Subcommand, Debug)]
pub enum Commands {
    /// Print a ticket for this node's endpoint id and exit.
    ///
    /// The ticket contains only the endpoint id derived from the secret key;
    /// no endpoint is started.
    GenerateTicket,
    /// Listen on an iroh endpoint and forward stdin/stdout to the first peer.
    ///
    /// Prints a ticket to pass to `dumbpipe connect`.
    Listen(ListenArgs),
    /// Listen on an iroh endpoint and forward each incoming connection to a
    /// TCP server.
    ///
    /// Prints a ticket to pass to `dumbpipe connect-tcp`.
    ListenTcp(ListenTcpArgs),
    /// Connect to a listening endpoint and forward stdin/stdout to it.
    Connect(ConnectArgs),
    /// Accept local TCP connections and forward each one to a listening
    /// endpoint.
    ConnectTcp(ConnectTcpArgs),
    /// Listen on an iroh endpoint and forward each incoming connection to a
    /// Unix socket.
    #[cfg(unix)]
    ListenUnix(ListenUnixArgs),
    /// Accept connections on a local Unix socket and forward each one to a
    /// listening endpoint.
    #[cfg(unix)]
    ConnectUnix(ConnectUnixArgs),
}
// Options shared by every subcommand that creates an iroh endpoint.
#[derive(Parser, Debug)]
pub struct CommonArgs {
    /// IPv4 socket address to bind the endpoint to (e.g. 0.0.0.0:12345).
    ///
    /// If omitted, no explicit IPv4 bind address is configured.
    #[clap(long)]
    pub ipv4_addr: Option<SocketAddrV4>,
    /// IPv6 socket address to bind the endpoint to (e.g. [::]:12345).
    ///
    /// If omitted, no explicit IPv6 bind address is configured.
    #[clap(long)]
    pub ipv6_addr: Option<SocketAddrV6>,
    /// Use a custom ALPN instead of the dumbpipe default.
    ///
    /// Given either as hex bytes (e.g. `6161` for "aa") or as text with a
    /// `utf8:` prefix (e.g. `utf8:my-proto`). With a custom ALPN the dumbpipe
    /// handshake bytes are neither sent nor expected.
    #[clap(long)]
    pub custom_alpn: Option<String>,
    /// Verbosity; may be repeated. With at least one -v, listeners also print
    /// a shorter ticket that contains only the endpoint id and relay URLs.
    #[clap(short = 'v', long, action = clap::ArgAction::Count)]
    pub verbose: u8,
}
impl CommonArgs {
    /// Returns the ALPN bytes to use.
    ///
    /// This is the parsed `--custom-alpn` value when one was given, otherwise
    /// the dumbpipe default `dumbpipe::ALPN`.
    fn alpn(&self) -> Result<Vec<u8>> {
        match self.custom_alpn.as_deref() {
            None => Ok(dumbpipe::ALPN.to_vec()),
            Some(spec) => parse_alpn(spec),
        }
    }

    /// True when `--custom-alpn` was supplied; in that case the dumbpipe
    /// handshake is skipped by every subcommand.
    fn is_custom_alpn(&self) -> bool {
        matches!(self.custom_alpn, Some(_))
    }
}
fn parse_alpn(alpn: &str) -> Result<Vec<u8>> {
Ok(if let Some(text) = alpn.strip_prefix("utf8:") {
text.as_bytes().to_vec()
} else {
hex::decode(alpn).anyerr()?
})
}
// Arguments for `dumbpipe listen`.
#[derive(Parser, Debug)]
pub struct ListenArgs {
    /// Only receive data from the peer; local stdin is not forwarded.
    #[clap(long)]
    pub recv_only: bool,
    // Endpoint options shared with the other subcommands.
    #[clap(flatten)]
    pub common: CommonArgs,
}
// Arguments for `dumbpipe listen-tcp`.
#[derive(Parser, Debug)]
pub struct ListenTcpArgs {
    /// Address of the TCP server that incoming connections are forwarded to,
    /// as host:port (e.g. localhost:8080).
    #[clap(long)]
    pub host: String,
    // Endpoint options shared with the other subcommands.
    #[clap(flatten)]
    pub common: CommonArgs,
}
// Arguments for `dumbpipe connect-tcp`.
#[derive(Parser, Debug)]
pub struct ConnectTcpArgs {
    /// Local address to accept TCP connections on, as host:port
    /// (e.g. 127.0.0.1:8080).
    #[clap(long)]
    pub addr: String,
    /// Ticket of the listening endpoint to forward connections to.
    pub ticket: EndpointTicket,
    // Endpoint options shared with the other subcommands.
    #[clap(flatten)]
    pub common: CommonArgs,
}
// Arguments for `dumbpipe connect`.
#[derive(Parser, Debug)]
pub struct ConnectArgs {
    /// Ticket of the listening endpoint to connect to.
    pub ticket: EndpointTicket,
    /// Only receive data from the peer; local stdin is not forwarded.
    #[clap(long)]
    pub recv_only: bool,
    // Endpoint options shared with the other subcommands.
    #[clap(flatten)]
    pub common: CommonArgs,
}
// Arguments for `dumbpipe listen-unix`.
#[cfg(unix)]
#[derive(Parser, Debug)]
pub struct ListenUnixArgs {
    /// Path of the Unix socket that incoming connections are forwarded to.
    #[clap(long)]
    pub socket_path: PathBuf,
    // Endpoint options shared with the other subcommands.
    #[clap(flatten)]
    pub common: CommonArgs,
}
// Arguments for `dumbpipe connect-unix`.
#[cfg(unix)]
#[derive(Parser, Debug)]
pub struct ConnectUnixArgs {
    /// Path at which to create the local Unix socket that accepts connections
    /// to forward. An existing socket at this path is replaced, and the
    /// socket file is removed when the command finishes.
    #[clap(long)]
    pub socket_path: PathBuf,
    /// Ticket of the listening endpoint to forward connections to.
    pub ticket: EndpointTicket,
    // Endpoint options shared with the other subcommands.
    #[clap(flatten)]
    pub common: CommonArgs,
}
/// Copies everything from `from` into the QUIC send stream `send`, then
/// finishes the stream.
///
/// Returns the number of bytes copied. If `token` is cancelled first, the send
/// stream is reset with error code 0 and a "cancelled" error is returned.
async fn copy_to_noq(
    mut from: impl AsyncRead + Unpin,
    mut send: noq::SendStream,
    token: CancellationToken,
) -> io::Result<u64> {
    tracing::trace!("copying to noq");
    let copied = tokio::select! {
        res = tokio::io::copy(&mut from, &mut send) => res?,
        _ = token.cancelled() => {
            // Errors from reset are deliberately ignored.
            send.reset(0u8.into()).ok();
            return Err(io::Error::other("cancelled"));
        }
    };
    send.finish()?;
    Ok(copied)
}
/// Copies everything from the QUIC receive stream `recv` into `to`.
///
/// Returns the number of bytes copied. If `token` is cancelled first, the
/// receive stream is stopped with error code 0 and a "cancelled" error is
/// returned.
async fn copy_from_noq(
    mut recv: noq::RecvStream,
    mut to: impl AsyncWrite + Unpin,
    token: CancellationToken,
) -> io::Result<u64> {
    tokio::select! {
        _ = token.cancelled() => {
            // Errors from stop are deliberately ignored.
            recv.stop(0u8.into()).ok();
            Err(io::Error::other("cancelled"))
        }
        copied = tokio::io::copy(&mut recv, &mut to) => copied,
    }
}
fn get_or_create_secret() -> Result<SecretKey> {
match std::env::var("IROH_SECRET") {
Ok(secret) => SecretKey::from_str(&secret).std_context("invalid secret"),
Err(_) => {
let key = SecretKey::generate();
eprintln!(
"using secret key {}",
data_encoding::HEXLOWER.encode(&key.to_bytes())
);
Ok(key)
}
}
}
/// Builds and binds an iroh endpoint using the N0 preset and `secret_key`.
///
/// `alpns` is registered on the endpoint: listeners pass their ALPN, and
/// connect-only callers pass an empty list. `--ipv4-addr` / `--ipv6-addr` are
/// applied as bind addresses only when present.
async fn create_endpoint(
    secret_key: SecretKey,
    common: &CommonArgs,
    alpns: Vec<Vec<u8>>,
) -> Result<Endpoint> {
    let builder = Endpoint::builder(presets::N0)
        .secret_key(secret_key)
        .alpns(alpns);
    let builder = match common.ipv4_addr {
        Some(v4) => builder.bind_addr(v4)?,
        None => builder,
    };
    let builder = match common.ipv6_addr {
        Some(v6) => builder.bind_addr(v6)?,
        None => builder,
    };
    let bound = builder.bind().await.anyerr()?;
    Ok(bound)
}
/// Builds a pass-through closure that cancels `token` every time it is called.
///
/// It is used with `Result::map_err`, so an error in one copy direction also
/// cancels the other direction.
fn cancel_token<T>(token: CancellationToken) -> impl Fn(T) -> T {
    move |passthrough: T| -> T {
        token.cancel();
        passthrough
    }
}
async fn forward_bidi(
from1: impl AsyncRead + Send + Sync + Unpin + 'static,
to1: impl AsyncWrite + Send + Sync + Unpin + 'static,
from2: noq::RecvStream,
to2: noq::SendStream,
) -> Result<()> {
let token1 = CancellationToken::new();
let token2 = token1.clone();
let token3 = token1.clone();
let forward_from_stdin = tokio::spawn(async move {
copy_to_noq(from1, to2, token1.clone())
.await
.map_err(cancel_token(token1))
});
let forward_to_stdout = tokio::spawn(async move {
copy_from_noq(from2, to1, token2.clone())
.await
.map_err(cancel_token(token2))
});
let _control_c = tokio::spawn(async move {
tokio::signal::ctrl_c().await?;
token3.cancel();
io::Result::Ok(())
});
forward_to_stdout.await.anyerr()?.anyerr()?;
forward_from_stdin.await.anyerr()?.anyerr()?;
Ok(())
}
async fn listen_stdio(args: ListenArgs) -> Result<()> {
let secret_key = get_or_create_secret()?;
let endpoint = create_endpoint(secret_key, &args.common, vec![args.common.alpn()?]).await?;
if (timeout(ONLINE_TIMEOUT, endpoint.online()).await).is_err() {
eprintln!("Warning: Failed to connect to the home relay");
}
let addr = endpoint.addr();
let short = create_short_ticket(&addr);
let ticket = EndpointTicket::new(addr);
eprintln!("Listening. To connect, use:\ndumbpipe connect {ticket}");
if args.common.verbose > 0 {
eprintln!("or:\ndumbpipe connect {short}");
}
loop {
let Some(connecting) = endpoint.accept().await else {
break;
};
let connection = match connecting.await {
Ok(connection) => connection,
Err(cause) => {
tracing::warn!("error accepting connection: {}", cause);
continue;
}
};
let remote_endpoint_id = &connection.remote_id();
tracing::info!("got connection from {}", remote_endpoint_id);
let (s, mut r) = match connection.accept_bi().await {
Ok(x) => x,
Err(cause) => {
tracing::warn!("error accepting stream: {}", cause);
continue;
}
};
tracing::info!("accepted bidi stream from {}", remote_endpoint_id);
if !args.common.is_custom_alpn() {
let mut buf = [0u8; dumbpipe::HANDSHAKE.len()];
r.read_exact(&mut buf).await.anyerr()?;
ensure_any!(buf == dumbpipe::HANDSHAKE, "invalid handshake");
}
if args.recv_only {
tracing::info!(
"forwarding stdout to {} (ignoring stdin)",
remote_endpoint_id
);
forward_bidi(tokio::io::empty(), tokio::io::stdout(), r, s).await?;
} else {
tracing::info!("forwarding stdin/stdout to {}", remote_endpoint_id);
forward_bidi(tokio::io::stdin(), tokio::io::stdout(), r, s).await?;
}
break;
}
Ok(())
}
/// Connects to the endpoint in the ticket and pipes stdin/stdout over one
/// bidirectional stream.
///
/// Unless a custom ALPN is used, the dumbpipe handshake is written first. With
/// `--recv-only`, an empty reader is forwarded instead of stdin. Stdout is
/// flushed once forwarding completes successfully.
async fn connect_stdio(args: ConnectArgs) -> Result<()> {
    let endpoint = create_endpoint(get_or_create_secret()?, &args.common, vec![]).await?;
    let target = args.ticket.endpoint_addr();
    let remote_endpoint_id = target.id;
    let connection = endpoint
        .connect(target.clone(), &args.common.alpn()?)
        .await
        .anyerr()?;
    tracing::info!("connected to {}", remote_endpoint_id);

    let (mut send, recv) = connection.open_bi().await.anyerr()?;
    tracing::info!("opened bidi stream to {}", remote_endpoint_id);
    if !args.common.is_custom_alpn() {
        send.write_all(&dumbpipe::HANDSHAKE).await.anyerr()?;
    }

    if !args.recv_only {
        tracing::info!("forwarding stdin/stdout to {}", remote_endpoint_id);
        forward_bidi(tokio::io::stdin(), tokio::io::stdout(), recv, send).await?;
    } else {
        tracing::info!("forwarding stdout to {} (ignoring stdin)", remote_endpoint_id);
        forward_bidi(tokio::io::empty(), tokio::io::stdout(), recv, send).await?;
    }
    tokio::io::stdout().flush().await.anyerr()?;
    Ok(())
}
/// Accepts local TCP connections on `--addr` and forwards each one to the
/// remote endpoint named by the ticket.
///
/// Every TCP connection gets its own iroh connection and bidirectional stream,
/// handled in a separate task. Runs until Ctrl-C.
///
/// # Errors
/// Fails if the address cannot be resolved, the ALPN is invalid, the endpoint
/// cannot be created, or the local TCP listener cannot be bound.
async fn connect_tcp(args: ConnectTcpArgs) -> Result<()> {
    let addrs = args
        .addr
        .to_socket_addrs()
        .std_context(format!("invalid host string {}", args.addr))?
        .collect::<Vec<_>>();
    // Parse the ALPN once, at startup. Previously it was parsed inside the
    // accept loop, so an invalid --custom-alpn only surfaced after the first
    // TCP client connected (and that client was dropped).
    let alpn = args.common.alpn()?;
    let handshake = !args.common.is_custom_alpn();
    let secret_key = get_or_create_secret()?;
    let endpoint = create_endpoint(secret_key, &args.common, vec![])
        .await
        .std_context("unable to bind endpoint")?;
    if timeout(ONLINE_TIMEOUT, endpoint.online()).await.is_err() {
        eprintln!("Warning: Failed to connect to the home relay");
    }
    // A bind failure is fatal. Propagate it so the process exits non-zero
    // instead of logging and returning Ok(()).
    let tcp_listener = tokio::net::TcpListener::bind(addrs.as_slice())
        .await
        .std_context(format!("error binding tcp socket to {addrs:?}"))?;
    tracing::info!("tcp listening on {:?}", addrs);

    /// Forwards one accepted TCP connection over a fresh iroh connection to
    /// `addr`, sending the dumbpipe handshake first when `handshake` is set.
    async fn handle_tcp_accept(
        next: io::Result<(tokio::net::TcpStream, SocketAddr)>,
        addr: EndpointAddr,
        endpoint: Endpoint,
        handshake: bool,
        alpn: &[u8],
    ) -> Result<()> {
        let (tcp_stream, tcp_addr) = next.std_context("error accepting tcp connection")?;
        let (tcp_recv, tcp_send) = tcp_stream.into_split();
        tracing::info!("got tcp connection from {}", tcp_addr);
        let remote_endpoint_id = addr.id;
        let connection = endpoint
            .connect(addr, alpn)
            .await
            .std_context(format!("error connecting to {remote_endpoint_id}"))?;
        let (mut endpoint_send, endpoint_recv) = connection
            .open_bi()
            .await
            .std_context(format!("error opening bidi stream to {remote_endpoint_id}"))?;
        if handshake {
            endpoint_send
                .write_all(&dumbpipe::HANDSHAKE)
                .await
                .anyerr()?;
        }
        forward_bidi(tcp_recv, tcp_send, endpoint_recv, endpoint_send).await?;
        Ok::<_, AnyError>(())
    }

    let addr = args.ticket.endpoint_addr();
    loop {
        let next = tokio::select! {
            stream = tcp_listener.accept() => stream,
            _ = tokio::signal::ctrl_c() => {
                eprintln!("got ctrl-c, exiting");
                break;
            }
        };
        let endpoint = endpoint.clone();
        let addr = addr.clone();
        let alpn = alpn.clone();
        tokio::spawn(async move {
            if let Err(cause) = handle_tcp_accept(next, addr, endpoint, handshake, &alpn).await {
                tracing::warn!("error handling connection: {}", cause);
            }
        });
    }
    Ok(())
}
async fn listen_tcp(args: ListenTcpArgs) -> Result<()> {
let addrs = match args.host.to_socket_addrs() {
Ok(addrs) => addrs.collect::<Vec<_>>(),
Err(e) => bail_any!("invalid host string {}: {}", args.host, e),
};
let secret_key = get_or_create_secret()?;
let endpoint = create_endpoint(secret_key, &args.common, vec![args.common.alpn()?]).await?;
if (timeout(ONLINE_TIMEOUT, endpoint.online()).await).is_err() {
eprintln!("Warning: Failed to connect to the home relay");
}
let addr = endpoint.addr();
let short = create_short_ticket(&addr);
let ticket = EndpointTicket::new(addr);
eprintln!("Forwarding incoming requests to '{}'.", args.host);
eprintln!("To connect, use e.g.:");
eprintln!("dumbpipe connect-tcp {ticket}");
if args.common.verbose > 0 {
eprintln!("or:\ndumbpipe connect-tcp {short}");
}
tracing::info!("endpoint id is {}", ticket.endpoint_addr().id);
tracing::info!(
"relay url is {:?}",
ticket
.endpoint_addr()
.relay_urls()
.next()
.map_or("None".to_string(), |url| url.to_string())
);
async fn handle_endpoint_accept(
accepting: Accepting,
addrs: Vec<std::net::SocketAddr>,
handshake: bool,
) -> Result<()> {
let connection = accepting.await.std_context("error accepting connection")?;
let remote_endpoint_id = &connection.remote_id();
tracing::info!("got connection from {}", remote_endpoint_id);
let (s, mut r) = connection
.accept_bi()
.await
.std_context("error accepting stream")?;
tracing::info!("accepted bidi stream from {}", remote_endpoint_id);
if handshake {
let mut buf = [0u8; dumbpipe::HANDSHAKE.len()];
r.read_exact(&mut buf).await.anyerr()?;
ensure_any!(buf == dumbpipe::HANDSHAKE, "invalid handshake");
}
let connection = tokio::net::TcpStream::connect(addrs.as_slice())
.await
.std_context(format!("error connecting to {addrs:?}"))?;
let (read, write) = connection.into_split();
forward_bidi(read, write, r, s).await?;
Ok(())
}
loop {
let incoming = select! {
incoming = endpoint.accept() => incoming,
_ = tokio::signal::ctrl_c() => {
eprintln!("got ctrl-c, exiting");
break;
}
};
let Some(incoming) = incoming else {
break;
};
let Ok(connecting) = incoming.accept() else {
break;
};
let addrs = addrs.clone();
let handshake = !args.common.is_custom_alpn();
tokio::spawn(async move {
if let Err(cause) = handle_endpoint_accept(connecting, addrs, handshake).await {
tracing::warn!("error handling connection: {}", cause);
}
});
}
Ok(())
}
/// Builds a compact ticket that keeps only the endpoint id and the relay URLs
/// from `addr`.
fn create_short_ticket(addr: &EndpointAddr) -> EndpointTicket {
    let short = addr
        .relay_urls()
        .fold(EndpointAddr::new(addr.id), |acc, relay_url| {
            acc.with_relay_url(relay_url.clone())
        });
    short.into()
}
#[cfg(unix)]
async fn listen_unix(args: ListenUnixArgs) -> Result<()> {
let socket_path = args.socket_path.clone();
let secret_key = get_or_create_secret()?;
let endpoint = create_endpoint(secret_key, &args.common, vec![args.common.alpn()?]).await?;
if (timeout(ONLINE_TIMEOUT, endpoint.online()).await).is_err() {
eprintln!("Warning: Failed to connect to the home relay");
}
let addr = endpoint.addr();
let short = create_short_ticket(&addr);
let ticket = EndpointTicket::new(addr);
eprintln!(
"Forwarding incoming requests to '{}'.",
socket_path.display()
);
eprintln!("To connect, use e.g.:");
eprintln!("dumbpipe connect-unix --socket-path /path/to/client.sock {ticket}");
eprintln!("dumbpipe connect-tcp --addr 127.0.0.1:8080 {ticket}");
if args.common.verbose > 0 {
eprintln!("or:\ndumbpipe connect-unix --socket-path /path/to/client.sock {short}");
eprintln!("dumbpipe connect-tcp --addr 127.0.0.1:8080 {short}");
}
tracing::info!("endpoint id is {}", ticket.endpoint_addr().id);
tracing::info!(
"relay url is {:?}",
ticket
.endpoint_addr()
.relay_urls()
.next()
.map_or("None".to_string(), |url| url.to_string())
);
async fn handle_endpoint_accept(
accepting: Accepting,
socket_path: PathBuf,
handshake: bool,
) -> Result<()> {
tracing::trace!("accepting connection");
let connection = accepting.await.std_context("error accepting connection")?;
let remote_endpoint_id = &connection.remote_id();
tracing::info!("got connection from {}", remote_endpoint_id);
let (s, mut r) = connection
.accept_bi()
.await
.std_context("error accepting stream")?;
tracing::info!("accepted bidi stream from {}", remote_endpoint_id);
if handshake {
tracing::trace!("reading handshake");
let mut buf = [0u8; dumbpipe::HANDSHAKE.len()];
r.read_exact(&mut buf).await.anyerr()?;
ensure_any!(buf == dumbpipe::HANDSHAKE, "invalid handshake");
tracing::trace!("handshake verified");
}
tracing::trace!("connecting to backend socket {:?}", socket_path);
let connection = UnixStream::connect(&socket_path)
.await
.std_context(format!("error connecting to {socket_path:?}"))?;
tracing::trace!("connected to backend socket");
let (read, write) = connection.into_split();
tracing::trace!("starting forward_bidi");
forward_bidi(read, write, r, s).await?;
tracing::trace!("forward_bidi finished");
Ok(())
}
loop {
let incoming = select! {
incoming = endpoint.accept() => incoming,
_ = tokio::signal::ctrl_c() => {
eprintln!("got ctrl-c, exiting");
break;
}
};
let Some(incoming) = incoming else {
break;
};
let Ok(connecting) = incoming.accept() else {
break;
};
let socket_path = socket_path.clone();
let handshake = !args.common.is_custom_alpn();
tokio::spawn(async move {
if let Err(cause) = handle_endpoint_accept(connecting, socket_path, handshake).await {
tracing::warn!("error handling connection: {}", cause);
}
});
}
Ok(())
}
/// Deletes the Unix socket file at `path` when dropped.
///
/// A file that is already gone is not treated as an error. Any other removal
/// failure is only logged, never propagated.
#[cfg(unix)]
struct UnixSocketGuard {
    // Filesystem path of the socket file to remove on drop.
    path: PathBuf,
}
#[cfg(unix)]
impl Drop for UnixSocketGuard {
fn drop(&mut self) {
if let Err(e) = std::fs::remove_file(&self.path) {
if e.kind() != std::io::ErrorKind::NotFound {
tracing::error!("failed to remove socket file {:?}: {}", self.path, e);
}
}
}
}
/// Accepts connections on a local Unix socket and forwards each one to the
/// remote endpoint named by the ticket.
///
/// A single iroh connection to the remote is established up front. Every
/// accepted local connection opens a new bidirectional stream on it, in its
/// own task. The socket file is removed again when this function returns.
/// Runs until Ctrl-C.
///
/// # Errors
/// Fails if:
/// - the endpoint cannot be created;
/// - something other than a Unix socket already exists at `--socket-path`;
/// - the remote cannot be reached;
/// - the socket cannot be bound.
#[cfg(unix)]
async fn connect_unix(args: ConnectUnixArgs) -> Result<()> {
    let socket_path = args.socket_path.clone();
    let secret_key = get_or_create_secret()?;
    let endpoint = create_endpoint(secret_key, &args.common, vec![])
        .await
        .std_context("unable to bind endpoint")?;
    tracing::info!("unix listening on {:?}", socket_path);
    if timeout(ONLINE_TIMEOUT, endpoint.online()).await.is_err() {
        eprintln!("Warning: Failed to connect to the home relay");
    }
    // Clear out a leftover socket at the path (e.g. from an earlier run), but
    // never delete anything that is not a socket. A mistyped --socket-path
    // must not remove a regular file.
    match tokio::fs::symlink_metadata(&socket_path).await {
        Ok(meta) if std::os::unix::fs::FileTypeExt::is_socket(&meta.file_type()) => {
            if let Err(e) = tokio::fs::remove_file(&socket_path).await {
                if e.kind() != io::ErrorKind::NotFound {
                    bail_any!("failed to remove existing socket file: {}", e);
                }
            }
        }
        Ok(_) => bail_any!(
            "refusing to replace {:?}: it exists and is not a unix socket",
            socket_path
        ),
        Err(e) if e.kind() == io::ErrorKind::NotFound => {}
        Err(e) => bail_any!("failed to inspect existing socket path {:?}: {}", socket_path, e),
    }
    let addr = args.ticket.endpoint_addr();
    tracing::info!("connecting to remote endpoint: {:?}", addr);
    let connection = endpoint
        .connect(addr.clone(), &args.common.alpn()?)
        .await
        .std_context("failed to connect to remote endpoint")?;
    tracing::info!("connected to remote endpoint successfully");
    let unix_listener = UnixListener::bind(&socket_path)
        .with_std_context(|_| format!("failed to bind Unix socket at {socket_path:?}"))?;
    tracing::info!("bound local unix socket: {:?}", socket_path);
    // Removes the socket file when this function returns, on any path.
    let _guard = UnixSocketGuard {
        path: socket_path.clone(),
    };

    /// Forwards one accepted local connection over a new bidirectional stream
    /// on the shared `connection`, sending the handshake first when
    /// `handshake` is set.
    async fn handle_unix_accept(
        next: io::Result<(UnixStream, tokio::net::unix::SocketAddr)>,
        connection: iroh::endpoint::Connection,
        handshake: bool,
    ) -> Result<()> {
        tracing::trace!("handling new local connection");
        let (unix_stream, unix_addr) = next.std_context("error accepting unix connection")?;
        let (unix_recv, unix_send) = unix_stream.into_split();
        tracing::trace!("got unix connection from {:?}", unix_addr);
        tracing::trace!("opening bidi stream");
        let (mut endpoint_send, endpoint_recv) = connection
            .open_bi()
            .await
            .std_context("error opening bidi stream")?;
        tracing::trace!("bidi stream opened");
        if handshake {
            tracing::trace!("sending handshake");
            endpoint_send
                .write_all(&dumbpipe::HANDSHAKE)
                .await
                .anyerr()?;
            tracing::trace!("handshake sent");
        }
        tracing::trace!("starting forward_bidi");
        forward_bidi(unix_recv, unix_send, endpoint_recv, endpoint_send).await?;
        tracing::trace!("forward_bidi finished");
        Ok(())
    }

    // Whether the dumbpipe handshake is sent; loop-invariant.
    let handshake = !args.common.is_custom_alpn();
    tracing::info!("entering accept loop");
    loop {
        let next = tokio::select! {
            stream = unix_listener.accept() => stream,
            _ = tokio::signal::ctrl_c() => {
                eprintln!("got ctrl-c, exiting");
                break;
            }
        };
        tracing::trace!("accepted a local connection");
        let connection = connection.clone();
        tokio::spawn(async move {
            tracing::trace!("spawning handler task");
            if let Err(cause) = handle_unix_accept(next, connection, handshake).await {
                tracing::warn!("error handling connection: {}", cause);
            }
            tracing::trace!("handler task finished");
        });
    }
    Ok(())
}
/// Prints a ticket for this node's endpoint id to stdout.
///
/// The ticket carries only the endpoint id derived from the secret key. No
/// endpoint is created.
async fn generate_ticket() -> Result<()> {
    let endpoint_id = get_or_create_secret()?.public();
    let ticket = EndpointTicket::new(EndpointAddr::new(endpoint_id));
    println!("{ticket}");
    Ok(())
}
/// Entry point.
///
/// Initialises tracing, parses the command line and runs the selected
/// subcommand. It exits with status 0 on success; on failure it prints the
/// error and exits with status 1.
///
/// NOTE(review): the process ends via `std::process::exit` rather than by
/// returning. Presumably this keeps lingering background work (e.g. a blocking
/// stdin read) from delaying shutdown. Confirm before changing.
#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();
    let outcome = match Args::parse().command {
        Commands::GenerateTicket => generate_ticket().await,
        Commands::Listen(sub) => listen_stdio(sub).await,
        Commands::ListenTcp(sub) => listen_tcp(sub).await,
        Commands::Connect(sub) => connect_stdio(sub).await,
        Commands::ConnectTcp(sub) => connect_tcp(sub).await,
        #[cfg(unix)]
        Commands::ListenUnix(sub) => listen_unix(sub).await,
        #[cfg(unix)]
        Commands::ConnectUnix(sub) => connect_unix(sub).await,
    };
    let exit_code = match outcome {
        Ok(()) => 0,
        Err(e) => {
            eprintln!("error: {e}");
            1
        }
    };
    std::process::exit(exit_code)
}