use std::net::Ipv4Addr;
use std::path::PathBuf;
use std::process::Stdio;
use tokio::net::TcpStream;
use tokio::process::{Child, Command};
use tokio::signal;
use tokio::time::{sleep, Duration, Instant};
use tracing::debug;
/// Upper bound on how long `wait_for_forwarder` polls for the local listener
/// before killing cloudflared and giving up.
const CLOUDFLARED_READY_TIMEOUT: Duration = Duration::from_secs(10);
/// Delay between successive readiness probes of the local listener.
const CLOUDFLARED_READY_POLL_INTERVAL: Duration = Duration::from_millis(120);
/// Settings for launching `cloudflared access tcp`.
#[derive(Debug, Clone)]
pub struct CloudflaredTcpOptions {
    /// Value passed to cloudflared's `--hostname` flag.
    pub hostname: String,
    /// Local `host:port` passed to `--url`; when `None`, a free port on
    /// `127.0.0.1` is reserved automatically.
    pub listener: Option<String>,
    /// Optional value for `--destination`; blank values are trimmed away and ignored.
    pub destination: Option<String>,
    /// Path to the cloudflared executable; defaults to `cloudflared`.
    pub binary_path: Option<PathBuf>,
}
/// A running `cloudflared access tcp` forwarder process and the address it serves.
#[derive(Debug)]
pub struct CloudflaredTunnel {
    /// Handle to the spawned cloudflared process.
    child: Child,
    /// Hostname the tunnel was opened for (`--hostname`).
    pub hostname: String,
    /// Local address cloudflared listens on (`--url`).
    pub listener_addr: String,
    /// Port component of `listener_addr`.
    pub local_port: u16,
}
impl CloudflaredTunnel {
    /// Spawns `cloudflared access tcp` and waits until its local listener accepts connections.
    ///
    /// When `options.listener` is absent or blank, a free port on `127.0.0.1` is
    /// reserved. `options.binary_path` defaults to `cloudflared`. A blank
    /// `destination` is ignored. When `debug_mode` is set, the launch parameters
    /// are logged at debug level.
    ///
    /// The child is configured with `kill_on_drop`, so it cannot outlive this
    /// value (or a failed `start`) even if [`CloudflaredTunnel::shutdown`] is
    /// never called.
    ///
    /// # Errors
    ///
    /// Returns a message if no local port can be reserved, the listener address
    /// has no valid port, the process cannot be spawned, or the forwarder does
    /// not become reachable in time.
    pub async fn start(options: CloudflaredTcpOptions, debug_mode: bool) -> Result<Self, String> {
        // Treat a blank/whitespace listener like `None`, consistent with `destination`.
        let listener_addr = match options.listener.and_then(normalize_optional_string) {
            Some(listener) => listener,
            None => format!("127.0.0.1:{}", reserve_local_port()?),
        };
        let local_port = parse_listener_port(&listener_addr)?;
        let binary = options
            .binary_path
            .unwrap_or_else(|| PathBuf::from("cloudflared"));
        let mut command = Command::new(&binary);
        command
            .arg("access")
            .arg("tcp")
            .arg("--hostname")
            .arg(&options.hostname)
            .arg("--url")
            .arg(&listener_addr)
            .stdin(Stdio::null())
            .stdout(Stdio::null())
            .stderr(Stdio::inherit())
            // Without this, any early `?` return below (or dropping the tunnel)
            // would leave an orphaned cloudflared process running.
            .kill_on_drop(true);
        if let Some(destination) = options.destination.and_then(normalize_optional_string) {
            command.arg("--destination").arg(destination);
        }
        if debug_mode {
            debug!(
                "Starting cloudflared => bin: {}, hostname: {}, listen: {}",
                binary.display(),
                options.hostname,
                listener_addr
            );
        }
        let mut child = command
            .spawn()
            .map_err(|e| format!("Failed to start cloudflared: {}", e))?;
        wait_for_forwarder(&mut child, &listener_addr).await?;
        Ok(Self {
            child,
            hostname: options.hostname,
            listener_addr,
            local_port,
        })
    }

    /// Kills the cloudflared process if it is still running and waits for it.
    ///
    /// Best-effort: errors from querying, killing or waiting are ignored.
    pub async fn shutdown(&mut self) {
        match self.child.try_wait() {
            // Already exited: nothing to kill.
            Ok(Some(_)) => {}
            Ok(None) => {
                let _ = self.child.kill().await;
                let _ = self.child.wait().await;
            }
            Err(_) => {}
        }
    }
}
/// Starts a cloudflared TCP forwarder and keeps it alive until Ctrl+C is received.
///
/// The tunnel is always shut down once the Ctrl+C wait finishes, even when that
/// wait itself failed; the failure is reported only after cleanup.
///
/// # Errors
///
/// Propagates errors from [`CloudflaredTunnel::start`], and reports a failure to
/// wait for the Ctrl+C signal.
pub async fn run_cloudflared_tcp(
    options: CloudflaredTcpOptions,
    debug_mode: bool,
) -> Result<(), String> {
    let mut forwarder = CloudflaredTunnel::start(options, debug_mode).await?;

    println!("cloudflared tcp ready on {}", forwarder.listener_addr);
    println!("Press Ctrl+C to stop the forwarder.");

    // Capture the signal outcome first so teardown happens unconditionally.
    let interrupt = signal::ctrl_c().await;
    forwarder.shutdown().await;

    interrupt.map_err(|e| format!("Failed to wait for Ctrl+C: {}", e))
}
/// Extracts the port from a `host:port` listener address.
///
/// The split happens on the last `:`, so bracketed IPv6 addresses such as
/// `[::1]:2222` are handled.
///
/// # Errors
///
/// Fails when the address contains no `:` at all, or when the text after the
/// last `:` is not a valid `u16`.
fn parse_listener_port(listener_addr: &str) -> Result<u16, String> {
    match listener_addr.rsplit_once(':') {
        Some((_, port_text)) => port_text
            .parse::<u16>()
            .map_err(|e| format!("Invalid listener port in `{listener_addr}`: {}", e)),
        None => Err(format!("Invalid listener address `{listener_addr}`")),
    }
}
/// Asks the OS for a currently unused TCP port on `127.0.0.1`.
///
/// The probe listener is closed when this function returns, so the port is only
/// *likely* still free: another process may claim it before cloudflared binds it.
fn reserve_local_port() -> Result<u16, String> {
    let probe = match std::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0)) {
        Ok(bound) => bound,
        Err(e) => return Err(format!("Failed to reserve local port: {}", e)),
    };
    match probe.local_addr() {
        Ok(addr) => Ok(addr.port()),
        Err(e) => Err(format!("Failed to inspect local listener address: {}", e)),
    }
}
async fn wait_for_forwarder(child: &mut Child, listener_addr: &str) -> Result<(), String> {
let deadline = Instant::now() + CLOUDFLARED_READY_TIMEOUT;
loop {
if let Ok(stream) = TcpStream::connect(listener_addr).await {
drop(stream);
return Ok(());
}
match child.try_wait() {
Ok(Some(status)) => {
return Err(format!(
"cloudflared exited before opening local tunnel (status: {})",
status
));
}
Ok(None) => {}
Err(err) => return Err(format!("Failed to inspect cloudflared process: {}", err)),
}
if Instant::now() >= deadline {
let _ = child.kill().await;
let _ = child.wait().await;
return Err(format!(
"Timed out waiting for cloudflared to open {}",
listener_addr
));
}
sleep(CLOUDFLARED_READY_POLL_INTERVAL).await;
}
}
/// Trims surrounding whitespace from `value`.
///
/// Returns `None` when nothing is left after trimming.
fn normalize_optional_string(value: String) -> Option<String> {
    let stripped = value.trim();
    (!stripped.is_empty()).then(|| stripped.to_owned())
}
#[cfg(test)]
mod tests {
    use super::parse_listener_port;

    #[test]
    fn parses_listener_port_from_host_port() {
        assert_eq!(parse_listener_port("127.0.0.1:2222").unwrap(), 2222);
    }

    #[test]
    fn parses_listener_port_from_bracketed_ipv6() {
        // The port is taken after the *last* colon, so IPv6 literals work when bracketed.
        assert_eq!(parse_listener_port("[::1]:2222").unwrap(), 2222);
    }

    #[test]
    fn rejects_invalid_listener_address() {
        assert!(parse_listener_port("localhost").is_err());
    }

    #[test]
    fn rejects_missing_or_out_of_range_port() {
        assert!(parse_listener_port("127.0.0.1:").is_err());
        assert!(parse_listener_port("127.0.0.1:65536").is_err());
    }
}