// Fail the build up front when `--cfg tokio_unstable` is missing: console-subscriber
// support needs tokio's unstable APIs (see `test_tokio_unstable_enabled` in the tests).
#[cfg(not(tokio_unstable))]
compile_error!("tokio_unstable cfg must be enabled; see .cargo/config.toml");
use anyhow::{anyhow, Context};
use tracing::instrument;
pub mod deploy;
pub mod port_ranges;
pub mod protocol;
pub mod streams;
pub mod tls;
pub mod tracelog;
/// Network environment the transfer is tuned for.
///
/// Selects default buffer sizes (`default_remote_copy_buffer_size`, `configure_tcp_buffers`).
/// Parsed case-insensitively from "datacenter" / "internet" and displayed in lowercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum NetworkProfile {
/// Default profile; uses the larger 16 MiB buffers.
#[default]
Datacenter,
/// Uses the smaller 2 MiB buffers.
Internet,
}
impl std::fmt::Display for NetworkProfile {
    /// Writes the lowercase profile name accepted by `FromStr`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Datacenter => "datacenter",
            Self::Internet => "internet",
        };
        f.write_str(name)
    }
}
impl std::str::FromStr for NetworkProfile {
    type Err = String;

    /// Parses "datacenter" or "internet", ignoring case.
    ///
    /// # Errors
    /// Returns a human-readable message naming the rejected input and the valid choices.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let profile = match s.to_lowercase().as_str() {
            "datacenter" => Some(Self::Datacenter),
            "internet" => Some(Self::Internet),
            _ => None,
        };
        profile.ok_or_else(|| {
            format!(
                "invalid network profile '{}', expected 'datacenter' or 'internet'",
                s
            )
        })
    }
}
/// Default remote-copy buffer size for [`NetworkProfile::Datacenter`] (16 MiB).
pub const DATACENTER_REMOTE_COPY_BUFFER_SIZE: usize = 16 * 1024 * 1024;
/// Default remote-copy buffer size for [`NetworkProfile::Internet`] (2 MiB).
pub const INTERNET_REMOTE_COPY_BUFFER_SIZE: usize = 2 * 1024 * 1024;
impl NetworkProfile {
/// Remote-copy buffer size used when no explicit size is configured
/// (consulted by `TcpConfig::effective_buffer_size`).
pub fn default_remote_copy_buffer_size(&self) -> usize {
match self {
Self::Datacenter => DATACENTER_REMOTE_COPY_BUFFER_SIZE,
Self::Internet => INTERNET_REMOTE_COPY_BUFFER_SIZE,
}
}
}
/// TCP transport settings: listener port ranges, connect timeout and buffer sizing.
#[derive(Debug, Clone)]
pub struct TcpConfig {
/// Optional port-range spec for listeners, parsed by `port_ranges::PortRanges::parse`;
/// `None` lets the OS pick the port.
pub port_ranges: Option<String>,
/// Connection timeout in seconds (15 by default).
pub conn_timeout_sec: u64,
/// Profile that supplies the default buffer size.
pub network_profile: NetworkProfile,
/// Explicit buffer size; `None` falls back to the profile default (`effective_buffer_size`).
pub buffer_size: Option<usize>,
/// Connection limit (100 by default). NOTE(review): enforced outside this file — confirm.
pub max_connections: usize,
/// Defaults to `DEFAULT_PENDING_WRITES_MULTIPLIER`. NOTE(review): semantics live outside
/// this file — confirm what it multiplies.
pub pending_writes_multiplier: usize,
}
/// Default value for [`TcpConfig::pending_writes_multiplier`].
pub const DEFAULT_PENDING_WRITES_MULTIPLIER: usize = 4;

impl Default for TcpConfig {
    /// Default network profile, 15 s timeout, `max_connections` of 100, the default
    /// pending-writes multiplier, and neither port ranges nor a buffer-size override.
    fn default() -> Self {
        let network_profile = NetworkProfile::default();
        Self {
            network_profile,
            conn_timeout_sec: 15,
            max_connections: 100,
            pending_writes_multiplier: DEFAULT_PENDING_WRITES_MULTIPLIER,
            port_ranges: None,
            buffer_size: None,
        }
    }
}
impl TcpConfig {
pub fn with_timeout(conn_timeout_sec: u64) -> Self {
Self {
port_ranges: None,
conn_timeout_sec,
network_profile: NetworkProfile::default(),
buffer_size: None,
max_connections: 100,
pending_writes_multiplier: DEFAULT_PENDING_WRITES_MULTIPLIER,
}
}
pub fn with_port_ranges(mut self, ranges: impl Into<String>) -> Self {
self.port_ranges = Some(ranges.into());
self
}
pub fn with_network_profile(mut self, profile: NetworkProfile) -> Self {
self.network_profile = profile;
self
}
pub fn with_buffer_size(mut self, size: usize) -> Self {
self.buffer_size = Some(size);
self
}
pub fn with_max_connections(mut self, max: usize) -> Self {
self.max_connections = max;
self
}
pub fn with_pending_writes_multiplier(mut self, multiplier: usize) -> Self {
self.pending_writes_multiplier = multiplier;
self
}
pub fn effective_buffer_size(&self) -> usize {
self.buffer_size
.unwrap_or_else(|| self.network_profile.default_remote_copy_buffer_size())
}
}
/// Target of an SSH connection: optional user, host and optional port.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SshSession {
    pub user: Option<String>,
    pub host: String,
    pub port: Option<u16>,
}

impl SshSession {
    /// Session pointing at `localhost` with no explicit user or port.
    pub fn local() -> Self {
        let host = String::from("localhost");
        Self {
            host,
            user: None,
            port: None,
        }
    }
}
pub use common::is_localhost;
/// Opens an SSH connection to `session`, returning it behind an `Arc` for sharing.
///
/// The destination is formatted as `ssh://[user@]host[:port]`.
/// NOTE(review): host keys are handled with `openssh::KnownHosts::Accept` — confirm that
/// trust policy is intended.
///
/// # Errors
/// Fails if the SSH connection cannot be established.
async fn setup_ssh_session(
    session: &SshSession,
) -> anyhow::Result<std::sync::Arc<openssh::Session>> {
    let user_prefix = session
        .user
        .as_deref()
        .map(|user| format!("{user}@"))
        .unwrap_or_default();
    let port_suffix = session
        .port
        .map(|port| format!(":{port}"))
        .unwrap_or_default();
    let destination = format!("ssh://{user_prefix}{}{port_suffix}", session.host);
    tracing::debug!("Connecting to SSH destination: {}", destination);
    let connected = openssh::Session::connect(destination, openssh::KnownHosts::Accept)
        .await
        .context("Failed to establish SSH connection")?;
    Ok(std::sync::Arc::new(connected))
}
/// Connects to `session` and returns the remote user's home directory as a path.
///
/// # Errors
/// Fails if the SSH connection or the home-directory lookup fails.
#[instrument]
pub async fn get_remote_home_for_session(
    session: &SshSession,
) -> anyhow::Result<std::path::PathBuf> {
    let connection = setup_ssh_session(session).await?;
    get_remote_home(&connection)
        .await
        .map(std::path::PathBuf::from)
}
/// Waits up to 10 seconds for a remote rcpd process to exit and checks its status.
///
/// On an unsuccessful exit the captured stdout/stderr are logged at error level. On
/// success, any stderr output is logged at debug level.
/// NOTE(review): `start_rcpd` moves stdout/stderr into drain tasks, so output captured
/// here for an `RcpdProcess::child` is presumably empty — confirm.
///
/// # Errors
/// Fails on timeout, on a wait failure, or when rcpd exits unsuccessfully.
#[instrument]
pub async fn wait_for_rcpd_process(
    process: openssh::Child<std::sync::Arc<openssh::Session>>,
) -> anyhow::Result<()> {
    const EXIT_GRACE: std::time::Duration = std::time::Duration::from_secs(10);
    tracing::info!("Waiting on rcpd server on: {:?}", process);
    let output = tokio::time::timeout(EXIT_GRACE, process.wait_with_output())
        .await
        .context("Timeout waiting for rcpd process to exit")?
        .context("Failed to wait for rcpd process")?;
    let exit_code = output.status.code();
    if output.status.success() {
        if !output.stderr.is_empty() {
            tracing::debug!(
                "rcpd stderr output:\n{}",
                String::from_utf8_lossy(&output.stderr)
            );
        }
        return Ok(());
    }
    tracing::error!(
        "rcpd command failed on remote host, status code: {:?}\nstdout:\n{}\nstderr:\n{}",
        exit_code,
        String::from_utf8_lossy(&output.stdout),
        String::from_utf8_lossy(&output.stderr)
    );
    Err(anyhow!(
        "rcpd command failed on remote host, status code: {:?}",
        exit_code,
    ))
}
/// Quotes `s` for POSIX `sh`: wraps it in single quotes and rewrites each embedded
/// single quote as `'\''` (close quote, escaped quote, reopen quote).
pub(crate) fn shell_escape(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        if ch == '\'' {
            quoted.push_str(r"'\''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
/// Resolves the remote user's home directory.
///
/// A non-empty `RCP_REMOTE_HOME_OVERRIDE` in the *local* environment is returned as-is,
/// without contacting the host. Otherwise this runs `sh -c 'echo "${HOME:?HOME not set}"'`
/// remotely; `${HOME:?...}` makes the shell fail when HOME is unset or empty.
///
/// # Errors
/// Fails if the command cannot be run, exits unsuccessfully, or prints an empty value.
pub async fn get_remote_home(session: &std::sync::Arc<openssh::Session>) -> anyhow::Result<String> {
    let override_home = std::env::var("RCP_REMOTE_HOME_OVERRIDE")
        .ok()
        .filter(|value| !value.is_empty());
    if let Some(home) = override_home {
        return Ok(home);
    }
    let probe = session
        .command("sh")
        .arg("-c")
        .arg("echo \"${HOME:?HOME not set}\"")
        .output()
        .await
        .context("failed to check HOME environment variable on remote host")?;
    if !probe.status.success() {
        anyhow::bail!(
            "HOME environment variable is not set on remote host\n\
             \n\
             stderr: {}\n\
             \n\
             The HOME environment variable is required for rcpd deployment and discovery.\n\
             Please ensure your SSH configuration preserves environment variables.",
            String::from_utf8_lossy(&probe.stderr)
        );
    }
    let home = String::from_utf8_lossy(&probe.stdout).trim().to_string();
    if home.is_empty() {
        anyhow::bail!(
            "HOME environment variable is empty on remote host\n\
             \n\
             The HOME environment variable is required for rcpd deployment and discovery.\n\
             Please ensure your SSH configuration sets HOME correctly."
        );
    }
    Ok(home)
}
#[cfg(test)]
mod shell_escape_tests {
use super::*;
// These pin shell_escape's exact output: the value is wrapped in single quotes, and each
// embedded single quote is emitted as '\'' (close quote, escaped quote, reopen quote).
#[test]
fn test_shell_escape_simple() {
assert_eq!(shell_escape("simple"), "'simple'");
}
#[test]
fn test_shell_escape_with_spaces() {
assert_eq!(shell_escape("path with spaces"), "'path with spaces'");
}
#[test]
fn test_shell_escape_with_single_quote() {
assert_eq!(
shell_escape("path'with'quotes"),
r"'path'\''with'\''quotes'"
);
}
// Shell metacharacters must pass through verbatim inside the single quotes.
#[test]
fn test_shell_escape_injection_attempt() {
assert_eq!(shell_escape("foo; rm -rf /"), "'foo; rm -rf /'");
}
#[test]
fn test_shell_escape_special_chars() {
assert_eq!(shell_escape("$PATH && echo pwned"), "'$PATH && echo pwned'");
}
}
/// Remote operations used by rcpd discovery. They sit behind a trait so the search order in
/// `discover_rcpd_path_internal` can be tested with a mock instead of a live SSH session.
///
/// Methods return boxed `Send` futures (hand-written async trait methods).
trait DiscoverySession {
/// Resolves to whether `path` passes an executable check on the remote host.
fn test_executable<'a>(
&'a self,
path: &'a str,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<bool>> + Send + 'a>>;
/// Resolves to the path a PATH lookup reports for `binary`, or `None` if not found.
fn which<'a>(
&'a self,
binary: &'a str,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = anyhow::Result<Option<String>>> + Send + 'a>,
>;
/// Resolves to the remote home directory; an error means it could not be determined.
fn remote_home<'a>(
&'a self,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'a>>;
}
/// [`DiscoverySession`] that runs its probes over a live SSH connection.
struct RealDiscoverySession<'a> {
    session: &'a std::sync::Arc<openssh::Session>,
}

impl DiscoverySession for RealDiscoverySession<'_> {
    /// Runs `sh -c "test -x <path>"` remotely, with `path` quoted by `shell_escape`.
    fn test_executable<'a>(
        &'a self,
        path: &'a str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<bool>> + Send + 'a>>
    {
        Box::pin(async move {
            let script = format!("test -x {}", shell_escape(path));
            let probe = self
                .session
                .command("sh")
                .arg("-c")
                .arg(script)
                .output()
                .await?;
            Ok(probe.status.success())
        })
    }

    /// Runs `which <binary>` remotely; a failed or empty result maps to `None`.
    fn which<'a>(
        &'a self,
        binary: &'a str,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = anyhow::Result<Option<String>>> + Send + 'a>,
    > {
        Box::pin(async move {
            let lookup = self.session.command("which").arg(binary).output().await?;
            let resolved = lookup
                .status
                .success()
                .then(|| String::from_utf8_lossy(&lookup.stdout).trim().to_string())
                .filter(|found| !found.is_empty());
            Ok(resolved)
        })
    }

    /// Delegates to `get_remote_home`.
    fn remote_home<'a>(
        &'a self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'a>>
    {
        let session = self.session;
        Box::pin(async move { get_remote_home(session).await })
    }
}
/// Locates rcpd on the remote host over a real SSH session.
///
/// See `discover_rcpd_path_internal` for the search order. The local executable path comes
/// from `std::env::current_exe()`.
async fn discover_rcpd_path(
    session: &std::sync::Arc<openssh::Session>,
    explicit_path: Option<&str>,
) -> anyhow::Result<String> {
    discover_rcpd_path_internal(&RealDiscoverySession { session }, explicit_path, None).await
}
/// Locates the rcpd binary on the remote host.
///
/// Search order:
/// 1. `explicit_path`, if given. It is authoritative: the result is the path or an error,
///    never a fallback.
/// 2. `rcpd` in the same directory as the local rcp executable, probed at the same path on
///    the remote host. `current_exe_override` replaces `std::env::current_exe()`, which lets
///    tests inject a path.
/// 3. `which rcpd` on the remote host.
/// 4. `<remote HOME>/.cache/rcp/bin/rcpd-<local semantic version>`, skipped when the remote
///    home lookup fails.
///
/// # Errors
/// Returns an error listing the searched locations and install options when nothing is
/// found. Errors raised by the `session` probes themselves are propagated.
async fn discover_rcpd_path_internal<S: DiscoverySession + ?Sized>(
    session: &S,
    explicit_path: Option<&str>,
    current_exe_override: Option<std::path::PathBuf>,
) -> anyhow::Result<String> {
    let local_version = common::version::ProtocolVersion::current();
    if let Some(path) = explicit_path {
        tracing::debug!("Trying explicit rcpd path: {}", path);
        if session.test_executable(path).await? {
            tracing::info!("Found rcpd at explicit path: {}", path);
            return Ok(path.to_string());
        }
        return Err(anyhow::anyhow!(
            "rcpd binary not found or not executable at explicit path: {}",
            path
        ));
    }
    // Assumes rcp and rcpd are installed under the same directory layout on both hosts.
    if let Ok(current_exe) = current_exe_override
        .map(Ok)
        .unwrap_or_else(std::env::current_exe)
    {
        if let Some(bin_dir) = current_exe.parent() {
            let path = bin_dir.join("rcpd").display().to_string();
            tracing::debug!("Trying same directory as rcp: {}", path);
            if session.test_executable(&path).await? {
                tracing::info!("Found rcpd in same directory as rcp: {}", path);
                return Ok(path);
            }
        }
    }
    tracing::debug!("Trying to find rcpd in PATH");
    if let Some(path) = session.which("rcpd").await? {
        tracing::info!("Found rcpd in PATH: {}", path);
        return Ok(path);
    }
    let cache_path = match session.remote_home().await {
        Ok(home) => {
            let path = format!("{}/.cache/rcp/bin/rcpd-{}", home, local_version.semantic);
            tracing::debug!("Trying deployed cache path: {}", path);
            if session.test_executable(&path).await? {
                tracing::info!("Found rcpd in deployed cache: {}", path);
                return Ok(path);
            }
            Some(path)
        }
        Err(e) => {
            tracing::debug!(
                "HOME not set on remote host, skipping cache directory check: {:#}",
                e
            );
            None
        }
    };
    // `explicit_path` is always `None` here because that branch returns above, so only the
    // automatic locations are reported.
    let cache_entry = match cache_path {
        Some(path) => format!("- Deployed cache: {}", path),
        None => "- Deployed cache: (skipped, HOME not set)".to_string(),
    };
    let searched = [
        "- Same directory as local rcp binary".to_string(),
        "- PATH (via 'which rcpd')".to_string(),
        cache_entry,
    ];
    Err(anyhow::anyhow!(
        "rcpd binary not found on remote host\n\
         \n\
         Searched in:\n\
         {}\n\
         \n\
         Options:\n\
         - Use automatic deployment: rcp --auto-deploy-rcpd ...\n\
         - Install rcpd manually: cargo install rcp-tools-rcp --version {}\n\
         - Specify explicit path: rcp --rcpd-path=/path/to/rcpd ...",
        searched.join("\n"),
        local_version.semantic
    ))
}
/// Locates rcpd on the remote host and verifies its protocol version matches ours.
///
/// # Errors
/// Propagates discovery failures and version-check failures unchanged.
async fn try_discover_and_check_version(
    session: &std::sync::Arc<openssh::Session>,
    explicit_path: Option<&str>,
    remote_host: &str,
) -> anyhow::Result<String> {
    let found = discover_rcpd_path(session, explicit_path).await?;
    check_rcpd_version(session, &found, remote_host)
        .await
        .map(|()| found)
}
/// Runs `<rcpd_path> --protocol-version` remotely and compares the result with the local
/// protocol version via `is_compatible_with`.
///
/// # Errors
/// Fails if the command cannot run or exits unsuccessfully, or if its JSON cannot be parsed.
/// Also fails when the versions are incompatible; that error includes install hints.
async fn check_rcpd_version(
    session: &std::sync::Arc<openssh::Session>,
    rcpd_path: &str,
    remote_host: &str,
) -> anyhow::Result<()> {
    let local_version = common::version::ProtocolVersion::current();
    tracing::debug!("Checking rcpd version on remote host: {}", remote_host);
    let output = session
        .command(rcpd_path)
        .arg("--protocol-version")
        .output()
        .await
        .context("Failed to execute rcpd --protocol-version on remote host")?;
    if !output.status.success() {
        anyhow::bail!(
            "rcpd --protocol-version failed on remote host '{}'\n\
             \n\
             stderr: {}\n\
             \n\
             This may indicate an old version of rcpd that does not support --protocol-version.\n\
             Please install a matching version of rcpd on the remote host:\n\
             - cargo install rcp-tools-rcp --version {}",
            remote_host,
            String::from_utf8_lossy(&output.stderr),
            local_version.semantic
        );
    }
    let raw_json = String::from_utf8_lossy(&output.stdout);
    let remote_version = common::version::ProtocolVersion::from_json(raw_json.trim())
        .context("Failed to parse rcpd version JSON from remote host")?;
    tracing::info!(
        "Local version: {}, Remote version: {}",
        local_version,
        remote_version
    );
    if local_version.is_compatible_with(&remote_version) {
        return Ok(());
    }
    Err(anyhow::anyhow!(
        "rcpd version mismatch\n\
         \n\
         Local: rcp {}\n\
         Remote: rcpd {} on host '{}'\n\
         \n\
         The rcpd version on the remote host must exactly match the rcp version.\n\
         \n\
         To fix this, install the matching version on the remote host:\n\
         - ssh {} 'cargo install rcp-tools-rcp --version {}'",
        local_version,
        remote_version,
        remote_host,
        shell_escape(remote_host),
        local_version.semantic
    ))
}
/// Where to reach a freshly started rcpd, as announced on the first line of its stderr.
#[derive(Debug, Clone)]
pub struct RcpdConnectionInfo {
/// Address rcpd reported it is listening on.
pub addr: std::net::SocketAddr,
/// Fingerprint from an `RCP_TLS` line; `None` when rcpd announced plain `RCP_TCP`.
pub fingerprint: Option<tls::Fingerprint>,
}
/// A running remote rcpd together with the tasks that drain its output streams.
pub struct RcpdProcess {
/// The remote rcpd child (its type matches what `wait_for_rcpd_process` accepts).
pub child: openssh::Child<std::sync::Arc<openssh::Session>>,
/// Connection details parsed from rcpd's first stderr line.
pub conn_info: RcpdConnectionInfo,
/// Task logging the rest of rcpd's stderr at debug level. Dropping a tokio `JoinHandle`
/// detaches the task rather than cancelling it.
_stderr_drain: tokio::task::JoinHandle<()>,
/// Task logging rcpd's stdout at debug level, when stdout was available.
_stdout_drain: Option<tokio::task::JoinHandle<()>>,
}
#[allow(clippy::too_many_arguments)]
#[instrument]
pub async fn start_rcpd(
rcpd_config: &protocol::RcpdConfig,
session: &SshSession,
explicit_rcpd_path: Option<&str>,
auto_deploy_rcpd: bool,
bind_ip: Option<&str>,
role: protocol::RcpdRole,
) -> anyhow::Result<RcpdProcess> {
tracing::info!("Starting rcpd server on: {:?}", session);
let remote_host = &session.host;
let ssh_session = setup_ssh_session(session).await?;
let rcpd_path =
match try_discover_and_check_version(&ssh_session, explicit_rcpd_path, remote_host).await {
Ok(path) => {
path
}
Err(e) => {
if auto_deploy_rcpd {
tracing::info!(
"rcpd not found or version mismatch, attempting auto-deployment"
);
let local_rcpd = deploy::find_local_rcpd_binary()
.context("failed to find local rcpd binary for deployment")?;
tracing::info!("Found local rcpd binary at {}", local_rcpd.display());
let local_version = common::version::ProtocolVersion::current();
let deployed_path = deploy::deploy_rcpd(
&ssh_session,
&local_rcpd,
&local_version.semantic,
remote_host,
)
.await
.context("failed to deploy rcpd to remote host")?;
tracing::info!("Successfully deployed rcpd to {}", deployed_path);
if let Err(e) = deploy::cleanup_old_versions(&ssh_session, 3).await {
tracing::warn!("failed to cleanup old versions (non-fatal): {:#}", e);
}
deployed_path
} else {
return Err(e);
}
}
};
let rcpd_args = rcpd_config.to_args();
tracing::debug!("rcpd arguments: {:?}", rcpd_args);
let mut cmd = ssh_session.arc_command(&rcpd_path);
cmd.arg("--role").arg(role.to_string()).args(rcpd_args);
if let Some(ip) = bind_ip {
tracing::debug!("passing --bind-ip {} to rcpd", ip);
cmd.arg("--bind-ip").arg(ip);
}
cmd.stdin(openssh::Stdio::piped());
cmd.stdout(openssh::Stdio::piped());
cmd.stderr(openssh::Stdio::piped());
tracing::info!("Will run remotely: {cmd:?}");
let mut child = cmd.spawn().await.context("Failed to spawn rcpd command")?;
let stderr = child.stderr().take().context("rcpd stderr not available")?;
let mut stderr_reader = tokio::io::BufReader::new(stderr);
let mut line = String::new();
use tokio::io::AsyncBufReadExt;
stderr_reader
.read_line(&mut line)
.await
.context("failed to read connection info from rcpd")?;
let line = line.trim();
let host_stderr = session.host.clone();
let stderr_drain = tokio::spawn(async move {
let mut line = String::new();
loop {
line.clear();
match stderr_reader.read_line(&mut line).await {
Ok(0) => break, Ok(_) => {
let trimmed = line.trim();
if !trimmed.is_empty() {
tracing::debug!(host = %host_stderr, "rcpd stderr: {}", trimmed);
}
}
Err(e) => {
tracing::debug!(host = %host_stderr, "rcpd stderr read error: {:#}", e);
break;
}
}
}
});
let stdout_drain = if let Some(stdout) = child.stdout().take() {
let host_stdout = session.host.clone();
let mut stdout_reader = tokio::io::BufReader::new(stdout);
Some(tokio::spawn(async move {
let mut line = String::new();
loop {
line.clear();
match stdout_reader.read_line(&mut line).await {
Ok(0) => break, Ok(_) => {
let trimmed = line.trim();
if !trimmed.is_empty() {
tracing::debug!(host = %host_stdout, "rcpd stdout: {}", trimmed);
}
}
Err(e) => {
tracing::debug!(host = %host_stdout, "rcpd stdout read error: {:#}", e);
break;
}
}
}
}))
} else {
None
};
tracing::debug!("rcpd connection line: {}", line);
let conn_info = if let Some(rest) = line.strip_prefix("RCP_TLS ") {
let parts: Vec<&str> = rest.split_whitespace().collect();
if parts.len() != 2 {
anyhow::bail!("invalid RCP_TLS line from rcpd: {}", line);
}
let addr = parts[0]
.parse()
.with_context(|| format!("invalid address in RCP_TLS line: {}", parts[0]))?;
let fingerprint = tls::fingerprint_from_hex(parts[1])
.with_context(|| format!("invalid fingerprint in RCP_TLS line: {}", parts[1]))?;
RcpdConnectionInfo {
addr,
fingerprint: Some(fingerprint),
}
} else if let Some(rest) = line.strip_prefix("RCP_TCP ") {
let addr = rest
.trim()
.parse()
.with_context(|| format!("invalid address in RCP_TCP line: {}", rest))?;
RcpdConnectionInfo {
addr,
fingerprint: None,
}
} else {
anyhow::bail!(
"unexpected output from rcpd (expected RCP_TLS or RCP_TCP): {}",
line
);
};
tracing::info!(
"rcpd listening on {} (encryption={})",
conn_info.addr,
conn_info.fingerprint.is_some()
);
Ok(RcpdProcess {
child,
conn_info,
_stderr_drain: stderr_drain,
_stdout_drain: stdout_drain,
})
}
fn get_local_ip(explicit_bind_ip: Option<&str>) -> anyhow::Result<std::net::IpAddr> {
if let Some(ip_str) = explicit_bind_ip {
let ip = ip_str
.parse::<std::net::IpAddr>()
.with_context(|| format!("invalid IP address: {}", ip_str))?;
match ip {
std::net::IpAddr::V4(ipv4) => {
tracing::debug!("using explicit bind IP: {}", ipv4);
return Ok(std::net::IpAddr::V4(ipv4));
}
std::net::IpAddr::V6(_) => {
anyhow::bail!(
"IPv6 address not supported for binding (got {}). \
TCP endpoints bind to 0.0.0.0 (IPv4 only)",
ip
);
}
}
}
if let Some(ipv4) = try_ipv4_via_kernel_routing()? {
return Ok(std::net::IpAddr::V4(ipv4));
}
tracing::debug!("routing-based detection failed, falling back to interface enumeration");
let interfaces = collect_ipv4_interfaces().context("Failed to enumerate network interfaces")?;
if let Some(ipv4) = choose_best_ipv4(&interfaces) {
tracing::debug!("using IPv4 address from interface scan: {}", ipv4);
return Ok(std::net::IpAddr::V4(ipv4));
}
anyhow::bail!("No IPv4 interfaces found (TCP endpoints require IPv4 as they bind to 0.0.0.0)")
}
/// Asks the kernel which local IPv4 address it would use to reach a few private-range
/// targets, returning the first non-loopback, non-unspecified answer.
///
/// Calling `connect()` on a UDP socket only selects a route; no datagram is sent. The
/// socket's local address then reflects the chosen outbound interface. Failures are logged
/// at debug level and the next target is tried. This never returns `Err` in practice;
/// `Ok(None)` means no target produced a usable address.
fn try_ipv4_via_kernel_routing() -> anyhow::Result<Option<std::net::Ipv4Addr>> {
    const ROUTING_TARGETS: [&str; 3] = ["10.0.0.1:80", "172.16.0.1:80", "192.168.1.1:80"];
    let probe = |target: &str| -> Option<std::net::Ipv4Addr> {
        let addr = target
            .parse::<std::net::SocketAddr>()
            .expect("hardcoded socket addresses are valid");
        let socket = match std::net::UdpSocket::bind("0.0.0.0:0") {
            Ok(socket) => socket,
            Err(err) => {
                tracing::debug!(?err, "failed to bind UDP socket for routing detection");
                return None;
            }
        };
        if let Err(err) = socket.connect(addr) {
            tracing::debug!(?err, "connect() failed for routing target {}", addr);
            return None;
        }
        match socket.local_addr() {
            Ok(std::net::SocketAddr::V4(local)) => {
                let ipv4 = *local.ip();
                if ipv4.is_loopback() || ipv4.is_unspecified() {
                    return None;
                }
                tracing::debug!(
                    "using IPv4 address from kernel routing (via {}): {}",
                    addr,
                    ipv4
                );
                Some(ipv4)
            }
            Ok(_) => {
                tracing::debug!("kernel routing returned IPv6 despite IPv4 bind, ignoring");
                None
            }
            Err(err) => {
                tracing::debug!(?err, "local_addr() failed for routing-based detection");
                None
            }
        }
    };
    Ok(ROUTING_TARGETS.into_iter().find_map(probe))
}
/// One IPv4 address found on a local network interface.
#[derive(Clone, Debug, PartialEq, Eq)]
struct InterfaceIpv4 {
/// Interface name as reported by `if_addrs` (raw; normalized later for classification).
name: String,
/// The interface's IPv4 address.
addr: std::net::Ipv4Addr,
}
/// Lists every IPv4 address on the local interfaces (IPv6 entries are skipped).
///
/// # Errors
/// Fails if `if_addrs::get_if_addrs` cannot enumerate interfaces.
fn collect_ipv4_interfaces() -> anyhow::Result<Vec<InterfaceIpv4>> {
    use if_addrs::get_if_addrs;
    let found = get_if_addrs()?
        .into_iter()
        .filter_map(|iface| match iface.addr.ip() {
            std::net::IpAddr::V4(addr) => Some(InterfaceIpv4 {
                name: iface.name,
                addr,
            }),
            std::net::IpAddr::V6(_) => None,
        })
        .collect();
    Ok(found)
}
/// Returns the address with the lowest `interface_priority`, ignoring `0.0.0.0` entries.
/// Returns `None` when no candidate remains.
fn choose_best_ipv4(interfaces: &[InterfaceIpv4]) -> Option<std::net::Ipv4Addr> {
    // The priority key ends with the address itself, so equal keys imply equal addresses
    // and pairing the key with the address cannot change which address wins.
    interfaces
        .iter()
        .filter(|candidate| !candidate.addr.is_unspecified())
        .map(|candidate| {
            (
                interface_priority(&candidate.name, &candidate.addr),
                candidate.addr,
            )
        })
        .min()
        .map(|(_, addr)| addr)
}
/// Sort key for interface selection; smaller is better.
///
/// The key orders by category first, then puts non-link-local before link-local, then
/// public before private, then the lower address.
fn interface_priority(
    name: &str,
    addr: &std::net::Ipv4Addr,
) -> (InterfaceCategory, u8, u8, std::net::Ipv4Addr) {
    let category = classify_interface(name, addr);
    let link_local_penalty = u8::from(addr.is_link_local());
    let private_penalty = u8::from(addr.is_private());
    (category, link_local_penalty, private_penalty, *addr)
}
/// Ranking bucket for an interface. Lower variants are preferred: the derived `Ord`
/// follows declaration order, and `choose_best_ipv4` takes the minimum.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum InterfaceCategory {
/// Name matches a known physical (wired/wireless) prefix.
Preferred = 0,
/// Neither known-physical nor known-virtual.
Normal = 1,
/// Bridges, containers, VPN/tunnel and similar software interfaces.
Virtual = 2,
/// Loopback address.
Loopback = 3,
}
/// Buckets an interface by its address and normalized name.
///
/// A loopback address wins outright. Otherwise the virtual check runs before the
/// physical check.
fn classify_interface(name: &str, addr: &std::net::Ipv4Addr) -> InterfaceCategory {
    if addr.is_loopback() {
        return InterfaceCategory::Loopback;
    }
    let key = normalize_interface_name(name);
    if is_virtual_interface(&key) {
        InterfaceCategory::Virtual
    } else if is_preferred_physical_interface(&key) {
        InterfaceCategory::Preferred
    } else {
        InterfaceCategory::Normal
    }
}
/// Lowercases `original` and keeps only ASCII letters and digits,
/// e.g. "vEthernet (WSL)" becomes "vethernetwsl".
fn normalize_interface_name(original: &str) -> String {
    original
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|c| c.to_ascii_lowercase())
        .collect()
}
/// Returns true for software/virtual interfaces, given a *normalized* name
/// (see `normalize_interface_name`).
///
/// Matches bridges, containers, VMs, VPN/tunnel devices, loopback device names, and any
/// name containing "virtual". Loopback names are matched exactly (`lo`, `lo<digits>`,
/// `loopback...`). A bare "lo" prefix would also capture physical adapters such as
/// Windows' "Local Area Connection" (normalized: "localareaconnection").
fn is_virtual_interface(name: &str) -> bool {
    const VIRTUAL_PREFIXES: &[&str] = &[
        "br",
        "docker",
        "veth",
        "virbr",
        "vmnet",
        "wg",
        "tailscale",
        "zt",
        "zerotier",
        "tap",
        "tun",
        "utun",
        "ham",
        "vpn",
        "loopback",
        "lxc",
    ];
    // `lo`, `lo0`, `lo1` (also `lo:1` aliases, which normalize to `lo1`).
    let is_loopback_device = name
        .strip_prefix("lo")
        .is_some_and(|rest| rest.bytes().all(|b| b.is_ascii_digit()));
    is_loopback_device
        || VIRTUAL_PREFIXES
            .iter()
            .any(|prefix| name.starts_with(prefix))
        || name.contains("virtual")
}
/// Returns true when a *normalized* name looks like a physical wired/wireless adapter.
fn is_preferred_physical_interface(name: &str) -> bool {
    // "en" already covers eno/ens/enp, "eth" covers ethernet, and "wl" covers wlan,
    // so only the minimal distinct prefixes are checked.
    for prefix in ["en", "eth", "em", "wl", "ww", "lan", "wifi"] {
        if name.starts_with(prefix) {
            return true;
        }
    }
    false
}
/// Returns a random 20-character ASCII alphanumeric name.
///
/// Random bytes that are not ASCII alphanumeric are discarded.
#[instrument]
pub fn get_random_server_name() -> String {
    const NAME_LEN: usize = 20;
    rand::random_iter::<u8>()
        .map(char::from)
        .filter(char::is_ascii_alphanumeric)
        .take(NAME_LEN)
        .collect()
}
/// Creates the TCP listener for the control channel.
///
/// Binds to `bind_ip` (either IP family) or `0.0.0.0` when `None`. The port comes from
/// `config.port_ranges` when set, otherwise the OS assigns one.
///
/// # Errors
/// Fails on an invalid `bind_ip`, an invalid port-range spec, or a bind failure.
#[instrument(skip(config))]
pub async fn create_tcp_control_listener(
    config: &TcpConfig,
    bind_ip: Option<&str>,
) -> anyhow::Result<tokio::net::TcpListener> {
    bind_tcp_listener_for(config, bind_ip, "control").await
}

/// Creates the TCP listener for data connections; same binding rules as
/// `create_tcp_control_listener`.
///
/// # Errors
/// Fails on an invalid `bind_ip`, an invalid port-range spec, or a bind failure.
#[instrument(skip(config))]
pub async fn create_tcp_data_listener(
    config: &TcpConfig,
    bind_ip: Option<&str>,
) -> anyhow::Result<tokio::net::TcpListener> {
    bind_tcp_listener_for(config, bind_ip, "data").await
}

/// Shared implementation of the listener constructors; `purpose` labels logs and errors.
async fn bind_tcp_listener_for(
    config: &TcpConfig,
    bind_ip: Option<&str>,
    purpose: &str,
) -> anyhow::Result<tokio::net::TcpListener> {
    let ip = match bind_ip {
        Some(ip_str) => ip_str
            .parse::<std::net::IpAddr>()
            .with_context(|| format!("invalid IP address: {}", ip_str))?,
        None => std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
    };
    let listener = match config.port_ranges.as_deref() {
        Some(ranges_str) => {
            let ranges = port_ranges::PortRanges::parse(ranges_str)?;
            ranges.bind_tcp_listener(ip).await?
        }
        None => {
            let bind_addr = std::net::SocketAddr::new(ip, 0);
            tokio::net::TcpListener::bind(bind_addr)
                .await
                .with_context(|| format!("failed to bind TCP {purpose} listener on {bind_addr}"))?
        }
    };
    let local_addr = listener.local_addr()?;
    tracing::info!("TCP {} listener bound to {}", purpose, local_addr);
    Ok(listener)
}
/// Returns the address peers should use to reach `listener`.
///
/// If the listener is bound to an unspecified IP, that IP is replaced with the result of
/// `get_local_ip(bind_ip)` and the bound port is kept.
///
/// # Errors
/// Fails if the listener's local address is unavailable or local-IP detection fails.
pub fn get_tcp_listener_addr(
    listener: &tokio::net::TcpListener,
    bind_ip: Option<&str>,
) -> anyhow::Result<std::net::SocketAddr> {
    let local_addr = listener.local_addr()?;
    if !local_addr.ip().is_unspecified() {
        return Ok(local_addr);
    }
    let advertised_ip = get_local_ip(bind_ip).context("failed to get local IP address")?;
    Ok(std::net::SocketAddr::new(advertised_ip, local_addr.port()))
}
/// Connects to a TCP control server within `timeout_sec` seconds and enables `TCP_NODELAY`.
///
/// # Errors
/// Fails on timeout, on a connection failure, or if `TCP_NODELAY` cannot be set.
#[instrument]
pub async fn connect_tcp_control(
    addr: std::net::SocketAddr,
    timeout_sec: u64,
) -> anyhow::Result<tokio::net::TcpStream> {
    let limit = std::time::Duration::from_secs(timeout_sec);
    let attempt = tokio::time::timeout(limit, tokio::net::TcpStream::connect(addr)).await;
    let stream = attempt
        .with_context(|| format!("connection to {} timed out after {}s", addr, timeout_sec))?
        .with_context(|| format!("failed to connect to {}", addr))?;
    stream.set_nodelay(true)?;
    tracing::debug!("connected to TCP control server at {}", addr);
    Ok(stream)
}
/// Sizes the socket's send and receive buffers for `profile`: 16 MiB for datacenter,
/// 2 MiB for internet.
///
/// Best-effort: failures to set a size are logged as warnings. The sizes the OS actually
/// applied are logged at debug level when they can be read back.
pub fn configure_tcp_buffers(stream: &tokio::net::TcpStream, profile: NetworkProfile) {
    use socket2::SockRef;
    // Both directions use the same size for a given profile.
    let buffer_size: usize = match profile {
        NetworkProfile::Datacenter => 16 * 1024 * 1024,
        NetworkProfile::Internet => 2 * 1024 * 1024,
    };
    let socket = SockRef::from(stream);
    if let Err(err) = socket.set_send_buffer_size(buffer_size) {
        tracing::warn!("failed to set TCP send buffer size: {err:#}");
    }
    if let Err(err) = socket.set_recv_buffer_size(buffer_size) {
        tracing::warn!("failed to set TCP receive buffer size: {err:#}");
    }
    let applied = (socket.send_buffer_size(), socket.recv_buffer_size());
    if let (Ok(send), Ok(recv)) = applied {
        tracing::debug!(
            "TCP socket buffer sizes: send={} recv={}",
            bytesize::ByteSize(send as u64),
            bytesize::ByteSize(recv as u64),
        );
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Mutex;
// Scripted DiscoverySession: answers come from the fields below, and every call is
// appended to `calls` so tests can assert the exact probe order.
struct MockDiscoverySession {
test_responses: HashMap<String, bool>,
which_response: Option<String>,
home_response: Result<String, String>,
calls: Mutex<Vec<String>>,
}
// Defaults: no path is executable, `which` finds nothing, and the HOME lookup fails.
impl Default for MockDiscoverySession {
fn default() -> Self {
Self {
test_responses: HashMap::new(),
which_response: None,
home_response: Err("HOME not set".to_string()),
calls: Mutex::new(Vec::new()),
}
}
}
impl MockDiscoverySession {
fn new() -> Self {
Self::default()
}
fn with_home(mut self, home: Option<&str>) -> Self {
self.home_response = match home {
Some(home) => Ok(home.to_string()),
None => Err("HOME not set".to_string()),
};
self
}
fn with_which(mut self, path: Option<&str>) -> Self {
self.which_response = path.map(|p| p.to_string());
self
}
fn set_test_response(&mut self, path: &str, exists: bool) {
self.test_responses.insert(path.to_string(), exists);
}
fn calls(&self) -> Vec<String> {
self.calls.lock().unwrap().clone()
}
}
// Each method records "test:<path>", "which:<binary>" or "home" before answering; paths
// without a scripted response are reported as not executable.
impl DiscoverySession for MockDiscoverySession {
fn test_executable<'a>(
&'a self,
path: &'a str,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<bool>> + Send + 'a>>
{
self.calls.lock().unwrap().push(format!("test:{}", path));
let exists = self.test_responses.get(path).copied().unwrap_or(false);
Box::pin(async move { Ok(exists) })
}
fn which<'a>(
&'a self,
binary: &'a str,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = anyhow::Result<Option<String>>> + Send + 'a>,
> {
self.calls.lock().unwrap().push(format!("which:{}", binary));
let result = self.which_response.clone();
Box::pin(async move { Ok(result) })
}
fn remote_home<'a>(
&'a self,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'a>>
{
self.calls.lock().unwrap().push("home".to_string());
let result = self.home_response.clone();
Box::pin(async move {
match result {
Ok(home) => Ok(home),
Err(e) => Err(anyhow::anyhow!(e)),
}
})
}
}
// Discovery order tests: explicit path (no fallback) -> same directory as rcp -> PATH ->
// deployed cache under HOME. The recorded call lists pin that order.
#[tokio::test]
async fn discover_rcpd_prefers_explicit_path() {
let mut session = MockDiscoverySession::new();
session.set_test_response("/opt/rcpd", true);
let path = discover_rcpd_path_internal(&session, Some("/opt/rcpd"), None)
.await
.expect("should return explicit path");
assert_eq!(path, "/opt/rcpd");
assert_eq!(session.calls(), vec!["test:/opt/rcpd"]);
}
#[tokio::test]
async fn discover_rcpd_explicit_path_errors_without_fallbacks() {
let session = MockDiscoverySession::new();
let err = discover_rcpd_path_internal(&session, Some("/missing/rcpd"), None)
.await
.expect_err("should fail when explicit path is missing");
assert!(
err.to_string()
.contains("rcpd binary not found or not executable"),
"unexpected error: {err}"
);
assert_eq!(session.calls(), vec!["test:/missing/rcpd"]);
}
#[tokio::test]
async fn discover_rcpd_uses_same_dir_first() {
let mut session = MockDiscoverySession::new();
session.set_test_response("/custom/bin/rcpd", true);
let path =
discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
.await
.expect("should find in same directory");
assert_eq!(path, "/custom/bin/rcpd");
assert_eq!(session.calls(), vec!["test:/custom/bin/rcpd"]);
}
#[tokio::test]
async fn discover_rcpd_falls_back_to_path_after_same_dir() {
let mut session = MockDiscoverySession::new().with_which(Some("/usr/bin/rcpd"));
session.set_test_response("/custom/bin/rcpd", false);
let path =
discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
.await
.expect("should find in PATH after same dir miss");
assert_eq!(path, "/usr/bin/rcpd");
assert_eq!(session.calls(), vec!["test:/custom/bin/rcpd", "which:rcpd"]);
}
#[tokio::test]
async fn discover_rcpd_uses_cache_last() {
let mut session = MockDiscoverySession::new()
.with_home(Some("/home/rcp"))
.with_which(None);
session.set_test_response("/custom/bin/rcpd", false);
let local_version = common::version::ProtocolVersion::current();
let cache_path = format!("/home/rcp/.cache/rcp/bin/rcpd-{}", local_version.semantic);
session.set_test_response(&cache_path, true);
let path =
discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
.await
.expect("should fall back to cache");
assert_eq!(path, cache_path);
assert_eq!(
session.calls(),
vec![
"test:/custom/bin/rcpd".to_string(),
"which:rcpd".to_string(),
"home".to_string(),
format!("test:{cache_path}")
]
);
}
// When the HOME lookup fails, the cache probe is skipped and the final error says so.
#[tokio::test]
async fn discover_rcpd_reports_home_missing_in_error() {
let mut session = MockDiscoverySession::new().with_which(None);
session.set_test_response("/custom/bin/rcpd", false);
let err =
discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
.await
.expect_err("should fail when nothing is found");
let msg = err.to_string();
assert!(
msg.contains("Deployed cache: (skipped, HOME not set)"),
"expected searched list to mention skipped cache, got: {msg}"
);
assert_eq!(
session.calls(),
vec!["test:/custom/bin/rcpd", "which:rcpd", "home"]
);
}
// Guards the build configuration; the top-of-file compile_error! already rejects builds
// without tokio_unstable, so the panic branch is a second line of defence.
#[test]
fn test_tokio_unstable_enabled() {
#[cfg(not(tokio_unstable))]
{
panic!(
"tokio_unstable cfg flag is not enabled! \
This is required for console-subscriber support. \
Check .cargo/config.toml"
);
}
#[cfg(tokio_unstable)]
{
let _join_set: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
}
}
// Builds an InterfaceIpv4 fixture from a name and dotted-quad octets.
fn iface(name: &str, addr: [u8; 4]) -> InterfaceIpv4 {
InterfaceIpv4 {
name: name.to_string(),
addr: std::net::Ipv4Addr::new(addr[0], addr[1], addr[2], addr[3]),
}
}
#[test]
fn choose_best_ipv4_prefers_physical_interfaces() {
let interfaces = vec![
iface("docker0", [172, 17, 0, 1]),
iface("enp3s0", [192, 168, 1, 44]),
iface("tailscale0", [100, 115, 92, 5]),
];
assert_eq!(
choose_best_ipv4(&interfaces),
Some(std::net::Ipv4Addr::new(192, 168, 1, 44))
);
}
#[test]
fn choose_best_ipv4_deprioritizes_link_local() {
let interfaces = vec![
iface("enp0s8", [169, 254, 10, 2]),
iface("wlan0", [10, 0, 0, 23]),
];
assert_eq!(
choose_best_ipv4(&interfaces),
Some(std::net::Ipv4Addr::new(10, 0, 0, 23))
);
}
// 0.0.0.0 entries are filtered out, so loopback is the only remaining candidate.
#[test]
fn choose_best_ipv4_falls_back_to_loopback() {
let interfaces = vec![iface("lo", [127, 0, 0, 1]), iface("docker0", [0, 0, 0, 0])];
assert_eq!(
choose_best_ipv4(&interfaces),
Some(std::net::Ipv4Addr::new(127, 0, 0, 1))
);
}
// The get_local_ip tests below all pass an explicit address, so they exercise only the
// parse/validate branch, which returns or errors before any routing or interface probing.
#[test]
fn test_get_local_ip_with_explicit_ipv4() {
let result = get_local_ip(Some("192.168.1.100"));
assert!(result.is_ok(), "should accept valid IPv4 address");
let ip = result.unwrap();
assert_eq!(
ip,
std::net::IpAddr::V4(std::net::Ipv4Addr::new(192, 168, 1, 100))
);
}
#[test]
fn test_get_local_ip_with_explicit_loopback() {
let result = get_local_ip(Some("127.0.0.1"));
assert!(result.is_ok(), "should accept loopback address");
let ip = result.unwrap();
assert_eq!(
ip,
std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))
);
}
#[test]
fn test_get_local_ip_rejects_ipv6() {
let result = get_local_ip(Some("::1"));
assert!(result.is_err(), "should reject IPv6 address");
let err = result.unwrap_err();
let err_msg = format!("{err:#}");
assert!(
err_msg.contains("IPv6 address not supported"),
"error should mention IPv6 not supported, got: {err_msg}"
);
assert!(
err_msg.contains("0.0.0.0"),
"error should mention IPv4-only binding, got: {err_msg}"
);
}
#[test]
fn test_get_local_ip_rejects_ipv6_full() {
let result = get_local_ip(Some("2001:db8::1"));
assert!(result.is_err(), "should reject IPv6 address");
let err = result.unwrap_err();
let err_msg = format!("{err:#}");
assert!(
err_msg.contains("IPv6 address not supported"),
"error should mention IPv6 not supported, got: {err_msg}"
);
}
#[test]
fn test_get_local_ip_rejects_invalid_ip() {
let result = get_local_ip(Some("not-an-ip"));
assert!(result.is_err(), "should reject invalid IP format");
let err = result.unwrap_err();
let err_msg = format!("{err:#}");
assert!(
err_msg.contains("invalid IP address"),
"error should mention invalid IP address, got: {err_msg}"
);
}
#[test]
fn test_get_local_ip_rejects_invalid_ipv4() {
let result = get_local_ip(Some("999.999.999.999"));
assert!(result.is_err(), "should reject invalid IPv4 address");
let err = result.unwrap_err();
let err_msg = format!("{err:#}");
assert!(
err_msg.contains("invalid IP address"),
"error should mention invalid IP address, got: {err_msg}"
);
}
}