// Fail the build with an actionable message when the `tokio_unstable` cfg flag is not set.
#[cfg(not(tokio_unstable))]
compile_error!("tokio_unstable cfg must be enabled; see .cargo/config.toml");
120
121use anyhow::{Context, anyhow};
122use tracing::instrument;
123
124pub mod deploy;
125pub mod port_ranges;
126pub mod protocol;
127pub mod streams;
128pub mod tls;
129pub mod tracelog;
130
/// Kind of network a transfer runs over; used to pick default tuning values.
///
/// The textual form is `datacenter` / `internet` (see the `Display` and `FromStr` impls).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum NetworkProfile {
    /// The default profile; maps to the larger default buffer sizes in this file.
    #[default]
    Datacenter,
    /// Maps to the smaller default buffer sizes in this file.
    Internet,
}
146
147impl std::fmt::Display for NetworkProfile {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 match self {
150 Self::Datacenter => write!(f, "datacenter"),
151 Self::Internet => write!(f, "internet"),
152 }
153 }
154}
155
156impl std::str::FromStr for NetworkProfile {
157 type Err = String;
158 fn from_str(s: &str) -> Result<Self, Self::Err> {
159 match s.to_lowercase().as_str() {
160 "datacenter" => Ok(Self::Datacenter),
161 "internet" => Ok(Self::Internet),
162 _ => Err(format!(
163 "invalid network profile '{}', expected 'datacenter' or 'internet'",
164 s
165 )),
166 }
167 }
168}
169
/// Default remote-copy buffer size for [`NetworkProfile::Datacenter`]: 16 MiB.
pub const DATACENTER_REMOTE_COPY_BUFFER_SIZE: usize = 16 * 1024 * 1024;

/// Default remote-copy buffer size for [`NetworkProfile::Internet`]: 2 MiB.
pub const INTERNET_REMOTE_COPY_BUFFER_SIZE: usize = 2 * 1024 * 1024;
175
176impl NetworkProfile {
177 pub fn default_remote_copy_buffer_size(&self) -> usize {
183 match self {
184 Self::Datacenter => DATACENTER_REMOTE_COPY_BUFFER_SIZE,
185 Self::Internet => INTERNET_REMOTE_COPY_BUFFER_SIZE,
186 }
187 }
188}
189
/// Settings for the TCP endpoints created by this module.
#[derive(Debug, Clone)]
pub struct TcpConfig {
    /// Optional port-range specification. When set, listeners are bound through
    /// `port_ranges::PortRanges`; when `None`, the OS picks the port.
    pub port_ranges: Option<String>,
    /// Connection timeout in seconds (defaults to 15).
    /// NOTE(review): not consumed in this part of the file — confirm where it is applied.
    pub conn_timeout_sec: u64,
    /// Network profile used to pick the default buffer size.
    pub network_profile: NetworkProfile,
    /// Explicit buffer size in bytes; `None` means "use the profile default"
    /// (see `effective_buffer_size`).
    pub buffer_size: Option<usize>,
    /// Connection limit (defaults to 100).
    /// NOTE(review): exact semantics are defined by the consumers — confirm.
    pub max_connections: usize,
    /// Multiplier for pending writes (defaults to `DEFAULT_PENDING_WRITES_MULTIPLIER`).
    /// NOTE(review): what it multiplies is not visible here — confirm against the consumers.
    pub pending_writes_multiplier: usize,
}
208
/// Default value for [`TcpConfig::pending_writes_multiplier`].
pub const DEFAULT_PENDING_WRITES_MULTIPLIER: usize = 4;
211
212impl Default for TcpConfig {
213 fn default() -> Self {
214 Self {
215 port_ranges: None,
216 conn_timeout_sec: 15,
217 network_profile: NetworkProfile::default(),
218 buffer_size: None,
219 max_connections: 100,
220 pending_writes_multiplier: DEFAULT_PENDING_WRITES_MULTIPLIER,
221 }
222 }
223}
224
225impl TcpConfig {
226 pub fn with_timeout(conn_timeout_sec: u64) -> Self {
228 Self {
229 port_ranges: None,
230 conn_timeout_sec,
231 network_profile: NetworkProfile::default(),
232 buffer_size: None,
233 max_connections: 100,
234 pending_writes_multiplier: DEFAULT_PENDING_WRITES_MULTIPLIER,
235 }
236 }
237 pub fn with_port_ranges(mut self, ranges: impl Into<String>) -> Self {
239 self.port_ranges = Some(ranges.into());
240 self
241 }
242 pub fn with_network_profile(mut self, profile: NetworkProfile) -> Self {
244 self.network_profile = profile;
245 self
246 }
247 pub fn with_buffer_size(mut self, size: usize) -> Self {
249 self.buffer_size = Some(size);
250 self
251 }
252 pub fn with_max_connections(mut self, max: usize) -> Self {
254 self.max_connections = max;
255 self
256 }
257 pub fn with_pending_writes_multiplier(mut self, multiplier: usize) -> Self {
259 self.pending_writes_multiplier = multiplier;
260 self
261 }
262 pub fn effective_buffer_size(&self) -> usize {
264 self.buffer_size
265 .unwrap_or_else(|| self.network_profile.default_remote_copy_buffer_size())
266 }
267}
268
/// Target of an SSH connection: optional user, host name and optional port.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SshSession {
    /// Remote user; when `None` no user is put into the SSH destination.
    pub user: Option<String>,
    /// Remote host name or address.
    pub host: String,
    /// Remote port; when `None` no port is put into the SSH destination.
    pub port: Option<u16>,
}
275
276impl SshSession {
277 pub fn local() -> Self {
278 Self {
279 user: None,
280 host: "localhost".to_string(),
281 port: None,
282 }
283 }
284}
285
286pub use common::is_localhost;
288
289async fn setup_ssh_session(
290 session: &SshSession,
291) -> anyhow::Result<std::sync::Arc<openssh::Session>> {
292 let host = session.host.as_str();
293 let destination = match (session.user.as_deref(), session.port) {
294 (Some(user), Some(port)) => format!("ssh://{user}@{host}:{port}"),
295 (None, Some(port)) => format!("ssh://{}:{}", session.host, port),
296 (Some(user), None) => format!("ssh://{user}@{host}"),
297 (None, None) => format!("ssh://{host}"),
298 };
299 tracing::debug!("Connecting to SSH destination: {}", destination);
300 let session = std::sync::Arc::new(
301 openssh::Session::connect(destination, openssh::KnownHosts::Accept)
302 .await
303 .context("Failed to establish SSH connection")?,
304 );
305 Ok(session)
306}
307
308#[instrument]
309pub async fn get_remote_home_for_session(
310 session: &SshSession,
311) -> anyhow::Result<std::path::PathBuf> {
312 let ssh_session = setup_ssh_session(session).await?;
313 let home = get_remote_home(&ssh_session).await?;
314 Ok(std::path::PathBuf::from(home))
315}
316
317#[instrument]
318pub async fn wait_for_rcpd_process(
319 process: openssh::Child<std::sync::Arc<openssh::Session>>,
320) -> anyhow::Result<()> {
321 tracing::info!("Waiting on rcpd server on: {:?}", process);
322 let output = tokio::time::timeout(
324 std::time::Duration::from_secs(10),
325 process.wait_with_output(),
326 )
327 .await
328 .context("Timeout waiting for rcpd process to exit")?
329 .context("Failed to wait for rcpd process")?;
330 if !output.status.success() {
331 let stdout = String::from_utf8_lossy(&output.stdout);
332 let stderr = String::from_utf8_lossy(&output.stderr);
333 tracing::error!(
334 "rcpd command failed on remote host, status code: {:?}\nstdout:\n{}\nstderr:\n{}",
335 output.status.code(),
336 stdout,
337 stderr
338 );
339 return Err(anyhow!(
340 "rcpd command failed on remote host, status code: {:?}",
341 output.status.code(),
342 ));
343 }
344 if !output.stderr.is_empty() {
346 let stderr = String::from_utf8_lossy(&output.stderr);
347 tracing::debug!("rcpd stderr output:\n{}", stderr);
348 }
349 Ok(())
350}
351
/// Wraps `s` in single quotes so a POSIX shell treats it as one literal word.
///
/// Each embedded single quote is rewritten as `'\''` (close the quote, emit an
/// escaped quote, reopen the quote); every other character is copied unchanged.
pub(crate) fn shell_escape(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        if ch == '\'' {
            quoted.push_str(r"'\''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
358
359pub async fn get_remote_home(session: &std::sync::Arc<openssh::Session>) -> anyhow::Result<String> {
377 if let Ok(home_override) = std::env::var("RCP_REMOTE_HOME_OVERRIDE")
378 && !home_override.is_empty()
379 {
380 return Ok(home_override);
381 }
382 let output = session
383 .command("sh")
384 .arg("-c")
385 .arg("echo \"${HOME:?HOME not set}\"")
386 .output()
387 .await
388 .context("failed to check HOME environment variable on remote host")?;
389
390 if !output.status.success() {
391 let stderr = String::from_utf8_lossy(&output.stderr);
392 anyhow::bail!(
393 "HOME environment variable is not set on remote host\n\
394 \n\
395 stderr: {}\n\
396 \n\
397 The HOME environment variable is required for rcpd deployment and discovery.\n\
398 Please ensure your SSH configuration preserves environment variables.",
399 stderr
400 );
401 }
402
403 let home = String::from_utf8_lossy(&output.stdout).trim().to_string();
404
405 if home.is_empty() {
406 anyhow::bail!(
407 "HOME environment variable is empty on remote host\n\
408 \n\
409 The HOME environment variable is required for rcpd deployment and discovery.\n\
410 Please ensure your SSH configuration sets HOME correctly."
411 );
412 }
413
414 Ok(home)
415}
416
/// Unit tests for [`shell_escape`].
#[cfg(test)]
mod shell_escape_tests {
    use super::*;

    /// A plain word is only wrapped in single quotes.
    #[test]
    fn test_shell_escape_simple() {
        assert_eq!(shell_escape("simple"), "'simple'");
    }

    /// Spaces stay inside the quoted word.
    #[test]
    fn test_shell_escape_with_spaces() {
        assert_eq!(shell_escape("path with spaces"), "'path with spaces'");
    }

    /// Every embedded single quote becomes `'\''`.
    #[test]
    fn test_shell_escape_with_single_quote() {
        assert_eq!(
            shell_escape("path'with'quotes"),
            r"'path'\''with'\''quotes'"
        );
    }

    /// Command separators are kept inside the quotes.
    #[test]
    fn test_shell_escape_injection_attempt() {
        assert_eq!(shell_escape("foo; rm -rf /"), "'foo; rm -rf /'");
    }

    /// `$` and `&&` are kept inside the quotes.
    #[test]
    fn test_shell_escape_special_chars() {
        assert_eq!(shell_escape("$PATH && echo pwned"), "'$PATH && echo pwned'");
    }
}
453
/// Remote operations needed to locate the rcpd binary.
///
/// Methods return boxed futures so the trait can be implemented both by the real
/// SSH-backed session and by test doubles.
trait DiscoverySession {
    /// Resolves to `true` when `path` is an executable file on the remote host.
    fn test_executable<'a>(
        &'a self,
        path: &'a str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<bool>> + Send + 'a>>;
    /// Resolves to the path of `binary` found via the remote `PATH`, or `None`.
    fn which<'a>(
        &'a self,
        binary: &'a str,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = anyhow::Result<Option<String>>> + Send + 'a>,
    >;
    /// Resolves to the remote home directory.
    fn remote_home<'a>(
        &'a self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'a>>;
}
469
/// [`DiscoverySession`] implementation backed by a live SSH session.
struct RealDiscoverySession<'a> {
    /// Borrowed SSH session used to run the remote probe commands.
    session: &'a std::sync::Arc<openssh::Session>,
}
473
474impl<'a> DiscoverySession for RealDiscoverySession<'a> {
475 fn test_executable<'b>(
476 &'b self,
477 path: &'b str,
478 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<bool>> + Send + 'b>>
479 {
480 Box::pin(async move {
481 let output = self
482 .session
483 .command("sh")
484 .arg("-c")
485 .arg(format!("test -x {}", shell_escape(path)))
486 .output()
487 .await?;
488 Ok(output.status.success())
489 })
490 }
491 fn which<'b>(
492 &'b self,
493 binary: &'b str,
494 ) -> std::pin::Pin<
495 Box<dyn std::future::Future<Output = anyhow::Result<Option<String>>> + Send + 'b>,
496 > {
497 Box::pin(async move {
498 let output = self.session.command("which").arg(binary).output().await?;
499 if output.status.success() {
500 let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
501 if !path.is_empty() {
502 return Ok(Some(path));
503 }
504 }
505 Ok(None)
506 })
507 }
508 fn remote_home<'b>(
509 &'b self,
510 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'b>>
511 {
512 Box::pin(get_remote_home(self.session))
513 }
514}
515
516async fn discover_rcpd_path(
529 session: &std::sync::Arc<openssh::Session>,
530 explicit_path: Option<&str>,
531) -> anyhow::Result<String> {
532 let real_session = RealDiscoverySession { session };
533 discover_rcpd_path_internal(&real_session, explicit_path, None).await
534}
535
536async fn discover_rcpd_path_internal<S: DiscoverySession + ?Sized>(
537 session: &S,
538 explicit_path: Option<&str>,
539 current_exe_override: Option<std::path::PathBuf>,
540) -> anyhow::Result<String> {
541 let local_version = common::version::ProtocolVersion::current();
542 if let Some(path) = explicit_path {
544 tracing::debug!("Trying explicit rcpd path: {}", path);
545 if session.test_executable(path).await? {
546 tracing::info!("Found rcpd at explicit path: {}", path);
547 return Ok(path.to_string());
548 }
549 return Err(anyhow::anyhow!(
552 "rcpd binary not found or not executable at explicit path: {}",
553 path
554 ));
555 }
556 if let Ok(current_exe) = current_exe_override
558 .map(Ok)
559 .unwrap_or_else(std::env::current_exe)
560 && let Some(bin_dir) = current_exe.parent()
561 {
562 let path = bin_dir.join("rcpd").display().to_string();
563 tracing::debug!("Trying same directory as rcp: {}", path);
564 if session.test_executable(&path).await? {
565 tracing::info!("Found rcpd in same directory as rcp: {}", path);
566 return Ok(path);
567 }
568 }
569 tracing::debug!("Trying to find rcpd in PATH");
571 if let Some(path) = session.which("rcpd").await? {
572 tracing::info!("Found rcpd in PATH: {}", path);
573 return Ok(path);
574 }
575 let cache_path = match session.remote_home().await {
578 Ok(home) => {
579 let path = format!("{}/.cache/rcp/bin/rcpd-{}", home, local_version.semantic);
580 tracing::debug!("Trying deployed cache path: {}", path);
581 if session.test_executable(&path).await? {
582 tracing::info!("Found rcpd in deployed cache: {}", path);
583 return Ok(path);
584 }
585 Some(path)
586 }
587 Err(e) => {
588 tracing::debug!(
589 "HOME not set on remote host, skipping cache directory check: {:#}",
590 e
591 );
592 None
593 }
594 };
595 let mut searched = vec![];
597 searched.push("- Same directory as local rcp binary".to_string());
598 searched.push("- PATH (via 'which rcpd')".to_string());
599 if let Some(path) = cache_path.as_ref() {
600 searched.push(format!("- Deployed cache: {}", path));
601 } else {
602 searched.push("- Deployed cache: (skipped, HOME not set)".to_string());
603 }
604 if let Some(path) = explicit_path {
605 searched.insert(
606 0,
607 format!("- Explicit path: {} (not found or not executable)", path),
608 );
609 }
610 Err(anyhow::anyhow!(
611 "rcpd binary not found on remote host\n\
612 \n\
613 Searched in:\n\
614 {}\n\
615 \n\
616 Options:\n\
617 - Use automatic deployment: rcp --auto-deploy-rcpd ...\n\
618 - Install rcpd manually: cargo install rcp-tools-rcp --version {}\n\
619 - Specify explicit path: rcp --rcpd-path=/path/to/rcpd ...",
620 searched.join("\n"),
621 local_version.semantic
622 ))
623}
624
625async fn try_discover_and_check_version(
630 session: &std::sync::Arc<openssh::Session>,
631 explicit_path: Option<&str>,
632 remote_host: &str,
633) -> anyhow::Result<String> {
634 let rcpd_path = discover_rcpd_path(session, explicit_path).await?;
636 check_rcpd_version(session, &rcpd_path, remote_host).await?;
638 Ok(rcpd_path)
639}
640
641async fn check_rcpd_version(
645 session: &std::sync::Arc<openssh::Session>,
646 rcpd_path: &str,
647 remote_host: &str,
648) -> anyhow::Result<()> {
649 let local_version = common::version::ProtocolVersion::current();
650
651 tracing::debug!("Checking rcpd version on remote host: {}", remote_host);
652
653 let output = session
655 .command(rcpd_path)
656 .arg("--protocol-version")
657 .output()
658 .await
659 .context("Failed to execute rcpd --protocol-version on remote host")?;
660
661 if !output.status.success() {
662 let stderr = String::from_utf8_lossy(&output.stderr);
663 return Err(anyhow::anyhow!(
664 "rcpd --protocol-version failed on remote host '{}'\n\
665 \n\
666 stderr: {}\n\
667 \n\
668 This may indicate an old version of rcpd that does not support --protocol-version.\n\
669 Please install a matching version of rcpd on the remote host:\n\
670 - cargo install rcp-tools-rcp --version {}",
671 remote_host,
672 stderr,
673 local_version.semantic
674 ));
675 }
676
677 let stdout = String::from_utf8_lossy(&output.stdout);
678 let remote_version = common::version::ProtocolVersion::from_json(stdout.trim())
679 .context("Failed to parse rcpd version JSON from remote host")?;
680
681 tracing::info!(
682 "Local version: {}, Remote version: {}",
683 local_version,
684 remote_version
685 );
686
687 if !local_version.is_compatible_with(&remote_version) {
688 return Err(anyhow::anyhow!(
689 "rcpd version mismatch\n\
690 \n\
691 Local: rcp {}\n\
692 Remote: rcpd {} on host '{}'\n\
693 \n\
694 The rcpd version on the remote host must exactly match the rcp version.\n\
695 \n\
696 To fix this, install the matching version on the remote host:\n\
697 - ssh {} 'cargo install rcp-tools-rcp --version {}'",
698 local_version,
699 remote_version,
700 remote_host,
701 shell_escape(remote_host),
702 local_version.semantic
703 ));
704 }
705
706 Ok(())
707}
708
/// Connection details announced by a freshly started rcpd.
#[derive(Debug, Clone)]
pub struct RcpdConnectionInfo {
    /// Socket address rcpd reported it is listening on.
    pub addr: std::net::SocketAddr,
    /// TLS fingerprint reported by rcpd; `None` when rcpd announced a plain TCP endpoint.
    pub fingerprint: Option<tls::Fingerprint>,
}
717
/// A running remote rcpd together with the data needed to talk to it.
pub struct RcpdProcess {
    /// Handle to the remote process spawned over SSH.
    pub child: openssh::Child<std::sync::Arc<openssh::Session>>,
    /// Address (and optional TLS fingerprint) rcpd announced on startup.
    pub conn_info: RcpdConnectionInfo,
    /// Background task that reads rcpd's stderr line by line and logs it.
    _stderr_drain: tokio::task::JoinHandle<()>,
    /// Background task that does the same for stdout, when a stdout pipe was available.
    _stdout_drain: Option<tokio::task::JoinHandle<()>>,
}
729
730#[allow(clippy::too_many_arguments)]
731#[instrument]
732pub async fn start_rcpd(
733 rcpd_config: &protocol::RcpdConfig,
734 session: &SshSession,
735 explicit_rcpd_path: Option<&str>,
736 auto_deploy_rcpd: bool,
737 bind_ip: Option<&str>,
738 role: protocol::RcpdRole,
739) -> anyhow::Result<RcpdProcess> {
740 tracing::info!("Starting rcpd server on: {:?}", session);
741 let remote_host = &session.host;
742 let ssh_session = setup_ssh_session(session).await?;
743 let rcpd_path =
745 match try_discover_and_check_version(&ssh_session, explicit_rcpd_path, remote_host).await {
746 Ok(path) => {
747 path
749 }
750 Err(e) => {
751 if auto_deploy_rcpd {
753 tracing::info!(
754 "rcpd not found or version mismatch, attempting auto-deployment"
755 );
756 let local_rcpd = deploy::find_local_rcpd_binary()
758 .context("failed to find local rcpd binary for deployment")?;
759 tracing::info!("Found local rcpd binary at {}", local_rcpd.display());
760 let local_version = common::version::ProtocolVersion::current();
762 let deployed_path = deploy::deploy_rcpd(
764 &ssh_session,
765 &local_rcpd,
766 &local_version.semantic,
767 remote_host,
768 )
769 .await
770 .context("failed to deploy rcpd to remote host")?;
771 tracing::info!("Successfully deployed rcpd to {}", deployed_path);
772 if let Err(e) = deploy::cleanup_old_versions(&ssh_session, 3).await {
774 tracing::warn!("failed to cleanup old versions (non-fatal): {:#}", e);
775 }
776 deployed_path
777 } else {
778 return Err(e);
780 }
781 }
782 };
783 let rcpd_args = rcpd_config.to_args();
785 tracing::debug!("rcpd arguments: {:?}", rcpd_args);
786 let mut cmd = ssh_session.arc_command(&rcpd_path);
787 cmd.arg("--role").arg(role.to_string()).args(rcpd_args);
788 if let Some(ip) = bind_ip {
790 tracing::debug!("passing --bind-ip {} to rcpd", ip);
791 cmd.arg("--bind-ip").arg(ip);
792 }
793 cmd.stdin(openssh::Stdio::piped());
796 cmd.stdout(openssh::Stdio::piped());
797 cmd.stderr(openssh::Stdio::piped());
798 tracing::info!("Will run remotely: {cmd:?}");
799 let mut child = cmd.spawn().await.context("Failed to spawn rcpd command")?;
800 let stderr = child.stderr().take().context("rcpd stderr not available")?;
805 let mut stderr_reader = tokio::io::BufReader::new(stderr);
806 let mut line = String::new();
807 use tokio::io::AsyncBufReadExt;
808 stderr_reader
809 .read_line(&mut line)
810 .await
811 .context("failed to read connection info from rcpd")?;
812 let line = line.trim();
813 let host_stderr = session.host.clone();
816 let stderr_drain = tokio::spawn(async move {
817 let mut line = String::new();
818 loop {
819 line.clear();
820 match stderr_reader.read_line(&mut line).await {
821 Ok(0) => break, Ok(_) => {
823 let trimmed = line.trim();
824 if !trimmed.is_empty() {
825 tracing::debug!(host = %host_stderr, "rcpd stderr: {}", trimmed);
826 }
827 }
828 Err(e) => {
829 tracing::debug!(host = %host_stderr, "rcpd stderr read error: {:#}", e);
830 break;
831 }
832 }
833 }
834 });
835 let stdout_drain = if let Some(stdout) = child.stdout().take() {
838 let host_stdout = session.host.clone();
839 let mut stdout_reader = tokio::io::BufReader::new(stdout);
840 Some(tokio::spawn(async move {
841 let mut line = String::new();
842 loop {
843 line.clear();
844 match stdout_reader.read_line(&mut line).await {
845 Ok(0) => break, Ok(_) => {
847 let trimmed = line.trim();
848 if !trimmed.is_empty() {
849 tracing::debug!(host = %host_stdout, "rcpd stdout: {}", trimmed);
850 }
851 }
852 Err(e) => {
853 tracing::debug!(host = %host_stdout, "rcpd stdout read error: {:#}", e);
854 break;
855 }
856 }
857 }
858 }))
859 } else {
860 None
861 };
862 tracing::debug!("rcpd connection line: {}", line);
863 let conn_info = if let Some(rest) = line.strip_prefix("RCP_TLS ") {
864 let parts: Vec<&str> = rest.split_whitespace().collect();
866 if parts.len() != 2 {
867 anyhow::bail!("invalid RCP_TLS line from rcpd: {}", line);
868 }
869 let addr = parts[0]
870 .parse()
871 .with_context(|| format!("invalid address in RCP_TLS line: {}", parts[0]))?;
872 let fingerprint = tls::fingerprint_from_hex(parts[1])
873 .with_context(|| format!("invalid fingerprint in RCP_TLS line: {}", parts[1]))?;
874 RcpdConnectionInfo {
875 addr,
876 fingerprint: Some(fingerprint),
877 }
878 } else if let Some(rest) = line.strip_prefix("RCP_TCP ") {
879 let addr = rest
881 .trim()
882 .parse()
883 .with_context(|| format!("invalid address in RCP_TCP line: {}", rest))?;
884 RcpdConnectionInfo {
885 addr,
886 fingerprint: None,
887 }
888 } else {
889 anyhow::bail!(
890 "unexpected output from rcpd (expected RCP_TLS or RCP_TCP): {}",
891 line
892 );
893 };
894 tracing::info!(
895 "rcpd listening on {} (encryption={})",
896 conn_info.addr,
897 conn_info.fingerprint.is_some()
898 );
899 Ok(RcpdProcess {
900 child,
901 conn_info,
902 _stderr_drain: stderr_drain,
903 _stdout_drain: stdout_drain,
904 })
905}
906
907fn get_local_ip(explicit_bind_ip: Option<&str>) -> anyhow::Result<std::net::IpAddr> {
912 if let Some(ip_str) = explicit_bind_ip {
914 let ip = ip_str
915 .parse::<std::net::IpAddr>()
916 .with_context(|| format!("invalid IP address: {}", ip_str))?;
917 match ip {
918 std::net::IpAddr::V4(ipv4) => {
919 tracing::debug!("using explicit bind IP: {}", ipv4);
920 return Ok(std::net::IpAddr::V4(ipv4));
921 }
922 std::net::IpAddr::V6(_) => {
923 anyhow::bail!(
924 "IPv6 address not supported for binding (got {}). \
925 TCP endpoints bind to 0.0.0.0 (IPv4 only)",
926 ip
927 );
928 }
929 }
930 }
931 if let Some(ipv4) = try_ipv4_via_kernel_routing()? {
933 return Ok(std::net::IpAddr::V4(ipv4));
934 }
935 tracing::debug!("routing-based detection failed, falling back to interface enumeration");
937 let interfaces = collect_ipv4_interfaces().context("Failed to enumerate network interfaces")?;
938 if let Some(ipv4) = choose_best_ipv4(&interfaces) {
939 tracing::debug!("using IPv4 address from interface scan: {}", ipv4);
940 return Ok(std::net::IpAddr::V4(ipv4));
941 }
942 anyhow::bail!("No IPv4 interfaces found (TCP endpoints require IPv4 as they bind to 0.0.0.0)")
943}
944
945fn try_ipv4_via_kernel_routing() -> anyhow::Result<Option<std::net::Ipv4Addr>> {
946 let private_ips = ["10.0.0.1:80", "172.16.0.1:80", "192.168.1.1:80"];
949 for addr_str in &private_ips {
950 let addr = addr_str
951 .parse::<std::net::SocketAddr>()
952 .expect("hardcoded socket addresses are valid");
953 let socket = match std::net::UdpSocket::bind("0.0.0.0:0") {
954 Ok(socket) => socket,
955 Err(err) => {
956 tracing::debug!(?err, "failed to bind UDP socket for routing detection");
957 continue;
958 }
959 };
960 if let Err(err) = socket.connect(addr) {
961 tracing::debug!(?err, "connect() failed for routing target {}", addr);
962 continue;
963 }
964 match socket.local_addr() {
965 Ok(std::net::SocketAddr::V4(local_addr)) => {
966 let ipv4 = *local_addr.ip();
967 if !ipv4.is_loopback() && !ipv4.is_unspecified() {
968 tracing::debug!(
969 "using IPv4 address from kernel routing (via {}): {}",
970 addr,
971 ipv4
972 );
973 return Ok(Some(ipv4));
974 }
975 }
976 Ok(_) => {
977 tracing::debug!("kernel routing returned IPv6 despite IPv4 bind, ignoring");
978 }
979 Err(err) => {
980 tracing::debug!(?err, "local_addr() failed for routing-based detection");
981 }
982 }
983 }
984 Ok(None)
985}
986
/// One IPv4 address found on a local network interface.
#[derive(Clone, Debug, PartialEq, Eq)]
struct InterfaceIpv4 {
    /// Interface name as reported by the interface enumeration.
    name: String,
    /// IPv4 address assigned to that interface.
    addr: std::net::Ipv4Addr,
}
992
993fn collect_ipv4_interfaces() -> anyhow::Result<Vec<InterfaceIpv4>> {
994 use if_addrs::get_if_addrs;
995 let mut interfaces = Vec::new();
996 for iface in get_if_addrs()? {
997 if let std::net::IpAddr::V4(ipv4) = iface.addr.ip() {
998 interfaces.push(InterfaceIpv4 {
999 name: iface.name,
1000 addr: ipv4,
1001 });
1002 }
1003 }
1004 Ok(interfaces)
1005}
1006
1007fn choose_best_ipv4(interfaces: &[InterfaceIpv4]) -> Option<std::net::Ipv4Addr> {
1008 interfaces
1009 .iter()
1010 .filter(|iface| !iface.addr.is_unspecified())
1011 .min_by_key(|iface| interface_priority(&iface.name, &iface.addr))
1012 .map(|iface| iface.addr)
1013}
1014
1015fn interface_priority(
1016 name: &str,
1017 addr: &std::net::Ipv4Addr,
1018) -> (InterfaceCategory, u8, u8, std::net::Ipv4Addr) {
1019 (
1020 classify_interface(name, addr),
1021 if addr.is_link_local() { 1 } else { 0 },
1022 if addr.is_private() { 1 } else { 0 },
1023 *addr,
1024 )
1025}
1026
/// Coarse ranking of network interfaces; the derived ordering follows the
/// discriminants, so `Preferred` sorts first and `Loopback` last.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum InterfaceCategory {
    /// Name matches a known physical-interface prefix.
    Preferred = 0,
    /// Neither preferred, virtual nor loopback.
    Normal = 1,
    /// Name matches a known virtual-interface pattern.
    Virtual = 2,
    /// Address is a loopback address.
    Loopback = 3,
}
1034
1035fn classify_interface(name: &str, addr: &std::net::Ipv4Addr) -> InterfaceCategory {
1036 if addr.is_loopback() {
1037 return InterfaceCategory::Loopback;
1038 }
1039 let normalized = normalize_interface_name(name);
1040 if is_virtual_interface(&normalized) {
1041 return InterfaceCategory::Virtual;
1042 }
1043 if is_preferred_physical_interface(&normalized) {
1044 return InterfaceCategory::Preferred;
1045 }
1046 InterfaceCategory::Normal
1047}
1048
/// Reduces an interface name to its lowercase ASCII letters and digits; every
/// other character (spaces, punctuation, non-ASCII) is dropped.
fn normalize_interface_name(original: &str) -> String {
    original
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|ch| ch.to_ascii_lowercase())
        .collect()
}
1058
/// Returns `true` when `name` starts with one of the known virtual-interface
/// prefixes or contains the word `virtual`.
///
/// `name` is matched as given; callers in this file pass an already normalized
/// (lowercase, alphanumeric-only) name.
fn is_virtual_interface(name: &str) -> bool {
    const VIRTUAL_PREFIXES: &[&str] = &[
        "br",
        "docker",
        "veth",
        "virbr",
        "vmnet",
        "wg",
        "tailscale",
        "zt",
        "zerotier",
        "tap",
        "tun",
        "utun",
        "ham",
        "vpn",
        "lo",
        "lxc",
    ];

    for prefix in VIRTUAL_PREFIXES {
        if name.starts_with(*prefix) {
            return true;
        }
    }
    name.contains("virtual")
}
1083
/// Returns `true` when `name` starts with one of the known physical-interface
/// prefixes (wired or wireless).
///
/// `name` is matched as given; callers in this file pass an already normalized
/// (lowercase, alphanumeric-only) name.
fn is_preferred_physical_interface(name: &str) -> bool {
    const PHYSICAL_PREFIXES: &[&str] = &[
        "en", "eth", "em", "eno", "ens", "enp", "wl", "ww", "wlan", "ethernet", "lan", "wifi",
    ];

    for prefix in PHYSICAL_PREFIXES {
        if name.starts_with(*prefix) {
            return true;
        }
    }
    false
}
1092
1093#[instrument]
1095pub fn get_random_server_name() -> String {
1096 rand::random_iter::<u8>()
1097 .filter(|b| b.is_ascii_alphanumeric())
1098 .take(20)
1099 .map(char::from)
1100 .collect()
1101}
1102
1103#[instrument(skip(config))]
1111pub async fn create_tcp_control_listener(
1112 config: &TcpConfig,
1113 bind_ip: Option<&str>,
1114) -> anyhow::Result<tokio::net::TcpListener> {
1115 let bind_addr = if let Some(ip_str) = bind_ip {
1116 let ip = ip_str
1117 .parse::<std::net::IpAddr>()
1118 .with_context(|| format!("invalid IP address: {}", ip_str))?;
1119 std::net::SocketAddr::new(ip, 0)
1120 } else {
1121 std::net::SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0)
1122 };
1123 let listener = if let Some(ranges_str) = config.port_ranges.as_deref() {
1124 let ranges = port_ranges::PortRanges::parse(ranges_str)?;
1125 ranges.bind_tcp_listener(bind_addr.ip()).await?
1126 } else {
1127 tokio::net::TcpListener::bind(bind_addr).await?
1128 };
1129 let local_addr = listener.local_addr()?;
1130 tracing::info!("TCP control listener bound to {}", local_addr);
1131 Ok(listener)
1132}
1133
1134#[instrument(skip(config))]
1138pub async fn create_tcp_data_listener(
1139 config: &TcpConfig,
1140 bind_ip: Option<&str>,
1141) -> anyhow::Result<tokio::net::TcpListener> {
1142 let bind_addr = if let Some(ip_str) = bind_ip {
1143 let ip = ip_str
1144 .parse::<std::net::IpAddr>()
1145 .with_context(|| format!("invalid IP address: {}", ip_str))?;
1146 std::net::SocketAddr::new(ip, 0)
1147 } else {
1148 std::net::SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0)
1149 };
1150 let listener = if let Some(ranges_str) = config.port_ranges.as_deref() {
1151 let ranges = port_ranges::PortRanges::parse(ranges_str)?;
1152 ranges.bind_tcp_listener(bind_addr.ip()).await?
1153 } else {
1154 tokio::net::TcpListener::bind(bind_addr).await?
1155 };
1156 let local_addr = listener.local_addr()?;
1157 tracing::info!("TCP data listener bound to {}", local_addr);
1158 Ok(listener)
1159}
1160
1161pub fn get_tcp_listener_addr(
1165 listener: &tokio::net::TcpListener,
1166 bind_ip: Option<&str>,
1167) -> anyhow::Result<std::net::SocketAddr> {
1168 let local_addr = listener.local_addr()?;
1169 if local_addr.ip().is_unspecified() {
1170 let local_ip = get_local_ip(bind_ip).context("failed to get local IP address")?;
1171 Ok(std::net::SocketAddr::new(local_ip, local_addr.port()))
1172 } else {
1173 Ok(local_addr)
1174 }
1175}
1176
1177#[instrument]
1179pub async fn connect_tcp_control(
1180 addr: std::net::SocketAddr,
1181 timeout_sec: u64,
1182) -> anyhow::Result<tokio::net::TcpStream> {
1183 let stream = tokio::time::timeout(
1184 std::time::Duration::from_secs(timeout_sec),
1185 tokio::net::TcpStream::connect(addr),
1186 )
1187 .await
1188 .with_context(|| format!("connection to {} timed out after {}s", addr, timeout_sec))?
1189 .with_context(|| format!("failed to connect to {}", addr))?;
1190 stream.set_nodelay(true)?;
1191 tracing::debug!("connected to TCP control server at {}", addr);
1192 Ok(stream)
1193}
1194
1195pub fn configure_tcp_buffers(stream: &tokio::net::TcpStream, profile: NetworkProfile) {
1199 use socket2::SockRef;
1200 let (send_buf, recv_buf) = match profile {
1201 NetworkProfile::Datacenter => (16 * 1024 * 1024, 16 * 1024 * 1024),
1202 NetworkProfile::Internet => (2 * 1024 * 1024, 2 * 1024 * 1024),
1203 };
1204 let sock_ref = SockRef::from(stream);
1205 if let Err(err) = sock_ref.set_send_buffer_size(send_buf) {
1206 tracing::warn!("failed to set TCP send buffer size: {err:#}");
1207 }
1208 if let Err(err) = sock_ref.set_recv_buffer_size(recv_buf) {
1209 tracing::warn!("failed to set TCP receive buffer size: {err:#}");
1210 }
1211 if let (Ok(send), Ok(recv)) = (sock_ref.send_buffer_size(), sock_ref.recv_buffer_size()) {
1212 tracing::debug!(
1213 "TCP socket buffer sizes: send={} recv={}",
1214 bytesize::ByteSize(send as u64),
1215 bytesize::ByteSize(recv as u64),
1216 );
1217 }
1218}
1219
1220#[cfg(test)]
1221mod tests {
1222 use super::*;
1223 use std::collections::HashMap;
1224 use std::path::PathBuf;
1225 use std::sync::Mutex;
1226
    /// Test double for `DiscoverySession` with canned answers and a call log.
    struct MockDiscoverySession {
        /// Canned `test_executable` results keyed by path; unknown paths count as missing.
        test_responses: HashMap<String, bool>,
        /// Canned result returned by `which`.
        which_response: Option<String>,
        /// Canned result returned by `remote_home` (`Err` carries the error message).
        home_response: Result<String, String>,
        /// Every call made against the mock, in order.
        calls: Mutex<Vec<String>>,
    }
1233
    impl Default for MockDiscoverySession {
        /// A mock where nothing exists: no executables, no `which` hit, HOME unset.
        fn default() -> Self {
            Self {
                test_responses: HashMap::new(),
                which_response: None,
                home_response: Err("HOME not set".to_string()),
                calls: Mutex::new(Vec::new()),
            }
        }
    }
1244
    impl MockDiscoverySession {
        /// Same as `Default::default()`.
        fn new() -> Self {
            Self::default()
        }

        /// Sets the `remote_home` answer; `None` makes it fail with "HOME not set".
        fn with_home(mut self, home: Option<&str>) -> Self {
            self.home_response = match home {
                Some(home) => Ok(home.to_string()),
                None => Err("HOME not set".to_string()),
            };
            self
        }
        /// Sets the answer returned by `which`.
        fn with_which(mut self, path: Option<&str>) -> Self {
            self.which_response = path.map(|p| p.to_string());
            self
        }
        /// Registers whether `path` should be reported as executable.
        fn set_test_response(&mut self, path: &str, exists: bool) {
            self.test_responses.insert(path.to_string(), exists);
        }
        /// Snapshot of the recorded calls, in call order.
        fn calls(&self) -> Vec<String> {
            self.calls.lock().unwrap().clone()
        }
    }
1268
    impl DiscoverySession for MockDiscoverySession {
        /// Records `test:<path>` and answers from `test_responses` (default `false`).
        fn test_executable<'a>(
            &'a self,
            path: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<bool>> + Send + 'a>>
        {
            self.calls.lock().unwrap().push(format!("test:{}", path));
            let exists = self.test_responses.get(path).copied().unwrap_or(false);
            Box::pin(async move { Ok(exists) })
        }
        /// Records `which:<binary>` and answers with `which_response`.
        fn which<'a>(
            &'a self,
            binary: &'a str,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = anyhow::Result<Option<String>>> + Send + 'a>,
        > {
            self.calls.lock().unwrap().push(format!("which:{}", binary));
            let result = self.which_response.clone();
            Box::pin(async move { Ok(result) })
        }
        /// Records `home` and answers with `home_response`, turning the stored
        /// error message into an `anyhow` error.
        fn remote_home<'a>(
            &'a self,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'a>>
        {
            self.calls.lock().unwrap().push("home".to_string());
            let result = self.home_response.clone();
            Box::pin(async move {
                match result {
                    Ok(home) => Ok(home),
                    Err(e) => Err(anyhow::anyhow!(e)),
                }
            })
        }
    }
1303
    /// An executable explicit path is returned after a single probe.
    #[tokio::test]
    async fn discover_rcpd_prefers_explicit_path() {
        let mut session = MockDiscoverySession::new();
        session.set_test_response("/opt/rcpd", true);
        let path = discover_rcpd_path_internal(&session, Some("/opt/rcpd"), None)
            .await
            .expect("should return explicit path");
        assert_eq!(path, "/opt/rcpd");
        // Only the explicit path is probed; no other location is consulted.
        assert_eq!(session.calls(), vec!["test:/opt/rcpd"]);
    }
1314
    /// A missing explicit path is an error and does not trigger any fallback lookup.
    #[tokio::test]
    async fn discover_rcpd_explicit_path_errors_without_fallbacks() {
        let session = MockDiscoverySession::new();
        let err = discover_rcpd_path_internal(&session, Some("/missing/rcpd"), None)
            .await
            .expect_err("should fail when explicit path is missing");
        assert!(
            err.to_string()
                .contains("rcpd binary not found or not executable"),
            "unexpected error: {err}"
        );
        // The call log proves that no fallback location was tried.
        assert_eq!(session.calls(), vec!["test:/missing/rcpd"]);
    }
1328
1329 #[tokio::test]
1330 async fn discover_rcpd_uses_same_dir_first() {
1331 let mut session = MockDiscoverySession::new();
1332 session.set_test_response("/custom/bin/rcpd", true);
1333 let path =
1334 discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
1335 .await
1336 .expect("should find in same directory");
1337 assert_eq!(path, "/custom/bin/rcpd");
1338 assert_eq!(session.calls(), vec!["test:/custom/bin/rcpd"]);
1339 }
1340
1341 #[tokio::test]
1342 async fn discover_rcpd_falls_back_to_path_after_same_dir() {
1343 let mut session = MockDiscoverySession::new().with_which(Some("/usr/bin/rcpd"));
1344 session.set_test_response("/custom/bin/rcpd", false);
1345 let path =
1346 discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
1347 .await
1348 .expect("should find in PATH after same dir miss");
1349 assert_eq!(path, "/usr/bin/rcpd");
1350 assert_eq!(session.calls(), vec!["test:/custom/bin/rcpd", "which:rcpd"]);
1351 }
1352
1353 #[tokio::test]
1354 async fn discover_rcpd_uses_cache_last() {
1355 let mut session = MockDiscoverySession::new()
1356 .with_home(Some("/home/rcp"))
1357 .with_which(None);
1358 session.set_test_response("/custom/bin/rcpd", false);
1359 let local_version = common::version::ProtocolVersion::current();
1360 let cache_path = format!("/home/rcp/.cache/rcp/bin/rcpd-{}", local_version.semantic);
1361 session.set_test_response(&cache_path, true);
1362 let path =
1363 discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
1364 .await
1365 .expect("should fall back to cache");
1366 assert_eq!(path, cache_path);
1367 assert_eq!(
1368 session.calls(),
1369 vec![
1370 "test:/custom/bin/rcpd".to_string(),
1371 "which:rcpd".to_string(),
1372 "home".to_string(),
1373 format!("test:{cache_path}")
1374 ]
1375 );
1376 }
1377
1378 #[tokio::test]
1379 async fn discover_rcpd_reports_home_missing_in_error() {
1380 let mut session = MockDiscoverySession::new().with_which(None);
1381 session.set_test_response("/custom/bin/rcpd", false);
1382 let err =
1383 discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
1384 .await
1385 .expect_err("should fail when nothing is found");
1386 let msg = err.to_string();
1387 assert!(
1388 msg.contains("Deployed cache: (skipped, HOME not set)"),
1389 "expected searched list to mention skipped cache, got: {msg}"
1390 );
1391 assert_eq!(
1392 session.calls(),
1393 vec!["test:/custom/bin/rcpd", "which:rcpd", "home"]
1394 );
1395 }
1396
/// Checks that the test build was compiled with the `tokio_unstable` cfg flag.
#[test]
fn test_tokio_unstable_enabled() {
    // Compiled only when the flag is absent: fail with a pointer to where it is set.
    #[cfg(not(tokio_unstable))]
    {
        panic!(
            "tokio_unstable cfg flag is not enabled! \
             This is required for console-subscriber support. \
             Check .cargo/config.toml"
        );
    }

    // Compiled only when the flag is present: constructing a `JoinSet` is the whole
    // check; the value is discarded immediately.
    #[cfg(tokio_unstable)]
    {
        let _join_set: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
    }
}
1425
1426 fn iface(name: &str, addr: [u8; 4]) -> InterfaceIpv4 {
1427 InterfaceIpv4 {
1428 name: name.to_string(),
1429 addr: std::net::Ipv4Addr::new(addr[0], addr[1], addr[2], addr[3]),
1430 }
1431 }
1432
/// Given `docker0`, `enp3s0` and `tailscale0`, the address of `enp3s0` must win.
#[test]
fn choose_best_ipv4_prefers_physical_interfaces() {
    let candidates = vec![
        iface("docker0", [172, 17, 0, 1]),
        iface("enp3s0", [192, 168, 1, 44]),
        iface("tailscale0", [100, 115, 92, 5]),
    ];
    let expected = std::net::Ipv4Addr::new(192, 168, 1, 44);

    assert_eq!(choose_best_ipv4(&candidates), Some(expected));
}
1445
/// A 169.254.x.x address listed first must lose to the 10.x address listed after it.
#[test]
fn choose_best_ipv4_deprioritizes_link_local() {
    let candidates = vec![
        iface("enp0s8", [169, 254, 10, 2]),
        iface("wlan0", [10, 0, 0, 23]),
    ];
    let expected = std::net::Ipv4Addr::new(10, 0, 0, 23);

    assert_eq!(choose_best_ipv4(&candidates), Some(expected));
}
1457
/// With only `lo` (127.0.0.1) and an interface at 0.0.0.0 available, the loopback
/// address is the one selected.
#[test]
fn choose_best_ipv4_falls_back_to_loopback() {
    let candidates = vec![iface("lo", [127, 0, 0, 1]), iface("docker0", [0, 0, 0, 0])];
    let loopback = std::net::Ipv4Addr::new(127, 0, 0, 1);

    assert_eq!(choose_best_ipv4(&candidates), Some(loopback));
}
1466
/// An explicit, well-formed IPv4 string must be accepted and yield that same address.
#[test]
fn test_get_local_ip_with_explicit_ipv4() {
    // `expect` keeps the original failure message but also prints the underlying
    // error, which `assert!(result.is_ok())` followed by `unwrap()` did not.
    let ip = get_local_ip(Some("192.168.1.100")).expect("should accept valid IPv4 address");
    assert_eq!(
        ip,
        std::net::IpAddr::V4(std::net::Ipv4Addr::new(192, 168, 1, 100))
    );
}
1478
/// An explicit loopback address must be accepted and returned as given.
#[test]
fn test_get_local_ip_with_explicit_loopback() {
    // `expect` keeps the original failure message but also prints the underlying
    // error, which `assert!(result.is_ok())` followed by `unwrap()` did not.
    let ip = get_local_ip(Some("127.0.0.1")).expect("should accept loopback address");
    assert_eq!(
        ip,
        std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))
    );
}
1490
/// The IPv6 loopback `::1` must be rejected, and the full error chain must mention
/// both that IPv6 is unsupported and the `0.0.0.0` binding.
#[test]
fn test_get_local_ip_rejects_ipv6() {
    // `expect_err` matches the discovery tests in this module and, on failure, shows
    // the address that was unexpectedly accepted.
    let err = get_local_ip(Some("::1")).expect_err("should reject IPv6 address");
    // `{:#}` renders the whole context chain, not just the outermost message.
    let err_msg = format!("{err:#}");
    assert!(
        err_msg.contains("IPv6 address not supported"),
        "error should mention IPv6 not supported, got: {err_msg}"
    );
    assert!(
        err_msg.contains("0.0.0.0"),
        "error should mention IPv4-only binding, got: {err_msg}"
    );
}
1507
/// A non-loopback IPv6 address (`2001:db8::1`) must be rejected with an
/// "IPv6 address not supported" error.
#[test]
fn test_get_local_ip_rejects_ipv6_full() {
    // `expect_err` matches the discovery tests in this module and, on failure, shows
    // the address that was unexpectedly accepted.
    let err = get_local_ip(Some("2001:db8::1")).expect_err("should reject IPv6 address");
    // `{:#}` renders the whole context chain, not just the outermost message.
    let err_msg = format!("{err:#}");
    assert!(
        err_msg.contains("IPv6 address not supported"),
        "error should mention IPv6 not supported, got: {err_msg}"
    );
}
1520
/// A string that is not an IP address at all must be rejected with an
/// "invalid IP address" error.
#[test]
fn test_get_local_ip_rejects_invalid_ip() {
    // `expect_err` matches the discovery tests in this module and, on failure, shows
    // the address that was unexpectedly accepted.
    let err = get_local_ip(Some("not-an-ip")).expect_err("should reject invalid IP format");
    // `{:#}` renders the whole context chain, not just the outermost message.
    let err_msg = format!("{err:#}");
    assert!(
        err_msg.contains("invalid IP address"),
        "error should mention invalid IP address, got: {err_msg}"
    );
}
1533
/// A dotted quad with out-of-range octets must be rejected with an
/// "invalid IP address" error.
#[test]
fn test_get_local_ip_rejects_invalid_ipv4() {
    // `expect_err` matches the discovery tests in this module and, on failure, shows
    // the address that was unexpectedly accepted.
    let err =
        get_local_ip(Some("999.999.999.999")).expect_err("should reject invalid IPv4 address");
    // `{:#}` renders the whole context chain, not just the outermost message.
    let err_msg = format!("{err:#}");
    assert!(
        err_msg.contains("invalid IP address"),
        "error should mention invalid IP address, got: {err_msg}"
    );
}
1546}