// Refuse to build unless the `tokio_unstable` cfg is set. Presumably something in
// this crate depends on tokio's unstable API surface (not visible here — TODO
// confirm); failing early points the developer at the cargo config instead of
// surfacing confusing missing-item errors later.
#[cfg(not(tokio_unstable))]
compile_error!("tokio_unstable cfg must be enabled; see .cargo/config.toml");
120
121use anyhow::{anyhow, Context};
122use tracing::instrument;
123
124pub mod deploy;
125pub mod port_ranges;
126pub mod protocol;
127pub mod streams;
128pub mod tls;
129pub mod tracelog;
130
/// Network environment a transfer runs in; used to pick default buffer sizes.
///
/// `Datacenter` is the default. See
/// [`NetworkProfile::default_remote_copy_buffer_size`] for the per-profile
/// remote-copy buffer size.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum NetworkProfile {
    /// Default profile; uses the larger (16 MiB) remote-copy buffer.
    #[default]
    Datacenter,
    /// Uses the smaller (2 MiB) remote-copy buffer.
    Internet,
}
146
147impl std::fmt::Display for NetworkProfile {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 match self {
150 Self::Datacenter => write!(f, "datacenter"),
151 Self::Internet => write!(f, "internet"),
152 }
153 }
154}
155
156impl std::str::FromStr for NetworkProfile {
157 type Err = String;
158 fn from_str(s: &str) -> Result<Self, Self::Err> {
159 match s.to_lowercase().as_str() {
160 "datacenter" => Ok(Self::Datacenter),
161 "internet" => Ok(Self::Internet),
162 _ => Err(format!(
163 "invalid network profile '{}', expected 'datacenter' or 'internet'",
164 s
165 )),
166 }
167 }
168}
169
/// Default remote-copy buffer size for [`NetworkProfile::Datacenter`] (16 MiB).
pub const DATACENTER_REMOTE_COPY_BUFFER_SIZE: usize = 16 * 1024 * 1024;

/// Default remote-copy buffer size for [`NetworkProfile::Internet`] (2 MiB).
pub const INTERNET_REMOTE_COPY_BUFFER_SIZE: usize = 2 * 1024 * 1024;
175
176impl NetworkProfile {
177 pub fn default_remote_copy_buffer_size(&self) -> usize {
183 match self {
184 Self::Datacenter => DATACENTER_REMOTE_COPY_BUFFER_SIZE,
185 Self::Internet => INTERNET_REMOTE_COPY_BUFFER_SIZE,
186 }
187 }
188}
189
/// Settings for TCP listeners and connections.
#[derive(Debug, Clone)]
pub struct TcpConfig {
    /// Port-range specification parsed by `port_ranges::PortRanges::parse` when
    /// binding listeners; `None` binds port 0 (OS-assigned ephemeral port).
    pub port_ranges: Option<String>,
    /// Connection timeout, in seconds.
    pub conn_timeout_sec: u64,
    /// Network profile; supplies the default buffer size when `buffer_size` is unset.
    pub network_profile: NetworkProfile,
    /// Explicit remote-copy buffer size; overrides the profile default when `Some`.
    pub buffer_size: Option<usize>,
    /// Connection limit (enforced by consumers of this config — TODO confirm semantics).
    pub max_connections: usize,
    /// Pending-writes multiplier (interpreted by consumers of this config — TODO confirm).
    pub pending_writes_multiplier: usize,
}
208
/// Default value of [`TcpConfig::pending_writes_multiplier`].
pub const DEFAULT_PENDING_WRITES_MULTIPLIER: usize = 4;
211
212impl Default for TcpConfig {
213 fn default() -> Self {
214 Self {
215 port_ranges: None,
216 conn_timeout_sec: 15,
217 network_profile: NetworkProfile::default(),
218 buffer_size: None,
219 max_connections: 100,
220 pending_writes_multiplier: DEFAULT_PENDING_WRITES_MULTIPLIER,
221 }
222 }
223}
224
225impl TcpConfig {
226 pub fn with_timeout(conn_timeout_sec: u64) -> Self {
228 Self {
229 port_ranges: None,
230 conn_timeout_sec,
231 network_profile: NetworkProfile::default(),
232 buffer_size: None,
233 max_connections: 100,
234 pending_writes_multiplier: DEFAULT_PENDING_WRITES_MULTIPLIER,
235 }
236 }
237 pub fn with_port_ranges(mut self, ranges: impl Into<String>) -> Self {
239 self.port_ranges = Some(ranges.into());
240 self
241 }
242 pub fn with_network_profile(mut self, profile: NetworkProfile) -> Self {
244 self.network_profile = profile;
245 self
246 }
247 pub fn with_buffer_size(mut self, size: usize) -> Self {
249 self.buffer_size = Some(size);
250 self
251 }
252 pub fn with_max_connections(mut self, max: usize) -> Self {
254 self.max_connections = max;
255 self
256 }
257 pub fn with_pending_writes_multiplier(mut self, multiplier: usize) -> Self {
259 self.pending_writes_multiplier = multiplier;
260 self
261 }
262 pub fn effective_buffer_size(&self) -> usize {
264 self.buffer_size
265 .unwrap_or_else(|| self.network_profile.default_remote_copy_buffer_size())
266 }
267}
268
/// SSH destination: an optional login user, a host, and an optional port.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SshSession {
    pub user: Option<String>,
    pub host: String,
    pub port: Option<u16>,
}

impl SshSession {
    /// Session targeting `localhost` with no explicit user or port.
    pub fn local() -> Self {
        let host = String::from("localhost");
        SshSession {
            host,
            user: None,
            port: None,
        }
    }
}
285
286pub use common::is_localhost;
288
289async fn setup_ssh_session(
290 session: &SshSession,
291) -> anyhow::Result<std::sync::Arc<openssh::Session>> {
292 let host = session.host.as_str();
293 let destination = match (session.user.as_deref(), session.port) {
294 (Some(user), Some(port)) => format!("ssh://{user}@{host}:{port}"),
295 (None, Some(port)) => format!("ssh://{}:{}", session.host, port),
296 (Some(user), None) => format!("ssh://{user}@{host}"),
297 (None, None) => format!("ssh://{host}"),
298 };
299 tracing::debug!("Connecting to SSH destination: {}", destination);
300 let session = std::sync::Arc::new(
301 openssh::Session::connect(destination, openssh::KnownHosts::Accept)
302 .await
303 .context("Failed to establish SSH connection")?,
304 );
305 Ok(session)
306}
307
308#[instrument]
309pub async fn get_remote_home_for_session(
310 session: &SshSession,
311) -> anyhow::Result<std::path::PathBuf> {
312 let ssh_session = setup_ssh_session(session).await?;
313 let home = get_remote_home(&ssh_session).await?;
314 Ok(std::path::PathBuf::from(home))
315}
316
317#[instrument]
318pub async fn wait_for_rcpd_process(
319 process: openssh::Child<std::sync::Arc<openssh::Session>>,
320) -> anyhow::Result<()> {
321 tracing::info!("Waiting on rcpd server on: {:?}", process);
322 let output = tokio::time::timeout(
324 std::time::Duration::from_secs(10),
325 process.wait_with_output(),
326 )
327 .await
328 .context("Timeout waiting for rcpd process to exit")?
329 .context("Failed to wait for rcpd process")?;
330 if !output.status.success() {
331 let stdout = String::from_utf8_lossy(&output.stdout);
332 let stderr = String::from_utf8_lossy(&output.stderr);
333 tracing::error!(
334 "rcpd command failed on remote host, status code: {:?}\nstdout:\n{}\nstderr:\n{}",
335 output.status.code(),
336 stdout,
337 stderr
338 );
339 return Err(anyhow!(
340 "rcpd command failed on remote host, status code: {:?}",
341 output.status.code(),
342 ));
343 }
344 if !output.stderr.is_empty() {
346 let stderr = String::from_utf8_lossy(&output.stderr);
347 tracing::debug!("rcpd stderr output:\n{}", stderr);
348 }
349 Ok(())
350}
351
/// Quotes `s` as a single POSIX shell word.
///
/// The whole string is wrapped in single quotes, and every embedded `'` is written
/// as `'\''` (close the quote, emit an escaped quote, reopen the quote). As a
/// result, no character of `s` is interpreted by the shell.
pub(crate) fn shell_escape(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        if ch == '\'' {
            quoted.push_str(r"'\''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
358
359pub async fn get_remote_home(session: &std::sync::Arc<openssh::Session>) -> anyhow::Result<String> {
377 if let Ok(home_override) = std::env::var("RCP_REMOTE_HOME_OVERRIDE") {
378 if !home_override.is_empty() {
379 return Ok(home_override);
380 }
381 }
382 let output = session
383 .command("sh")
384 .arg("-c")
385 .arg("echo \"${HOME:?HOME not set}\"")
386 .output()
387 .await
388 .context("failed to check HOME environment variable on remote host")?;
389
390 if !output.status.success() {
391 let stderr = String::from_utf8_lossy(&output.stderr);
392 anyhow::bail!(
393 "HOME environment variable is not set on remote host\n\
394 \n\
395 stderr: {}\n\
396 \n\
397 The HOME environment variable is required for rcpd deployment and discovery.\n\
398 Please ensure your SSH configuration preserves environment variables.",
399 stderr
400 );
401 }
402
403 let home = String::from_utf8_lossy(&output.stdout).trim().to_string();
404
405 if home.is_empty() {
406 anyhow::bail!(
407 "HOME environment variable is empty on remote host\n\
408 \n\
409 The HOME environment variable is required for rcpd deployment and discovery.\n\
410 Please ensure your SSH configuration sets HOME correctly."
411 );
412 }
413
414 Ok(home)
415}
416
#[cfg(test)]
mod shell_escape_tests {
    use super::*;

    #[test]
    fn test_shell_escape_simple() {
        assert_eq!(shell_escape("simple"), "'simple'");
    }

    #[test]
    fn test_shell_escape_with_spaces() {
        assert_eq!(shell_escape("path with spaces"), "'path with spaces'");
    }

    #[test]
    fn test_shell_escape_with_single_quote() {
        // Each `'` becomes `'\''`: close the quote, escaped quote, reopen the quote.
        assert_eq!(
            shell_escape("path'with'quotes"),
            r"'path'\''with'\''quotes'"
        );
    }

    #[test]
    fn test_shell_escape_injection_attempt() {
        // Command separators stay literal inside single quotes.
        assert_eq!(shell_escape("foo; rm -rf /"), "'foo; rm -rf /'");
    }

    #[test]
    fn test_shell_escape_special_chars() {
        // Neither `$` expansion nor `&&` is interpreted inside single quotes.
        assert_eq!(shell_escape("$PATH && echo pwned"), "'$PATH && echo pwned'");
    }

    #[test]
    fn test_shell_escape_empty() {
        // An empty argument must still yield one (empty) shell word, not vanish.
        assert_eq!(shell_escape(""), "''");
    }

    #[test]
    fn test_shell_escape_lone_quote() {
        assert_eq!(shell_escape("'"), r"''\'''");
    }
}
453
/// Remote operations needed to locate rcpd.
///
/// Abstracted behind a trait so `discover_rcpd_path_internal` can be exercised
/// with a mock session in tests. The methods return boxed `Send` futures instead
/// of being `async fn`s (presumably to keep the trait usable through `?Sized` /
/// trait objects — TODO confirm).
trait DiscoverySession {
    /// Returns whether `path` passes an executability check on the remote host
    /// (the SSH-backed implementation runs `test -x`).
    fn test_executable<'a>(
        &'a self,
        path: &'a str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<bool>> + Send + 'a>>;
    /// Looks `binary` up on the remote `PATH`; `Ok(None)` if it is not found.
    fn which<'a>(
        &'a self,
        binary: &'a str,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = anyhow::Result<Option<String>>> + Send + 'a>,
    >;
    /// Returns the remote home directory; an error means it could not be determined.
    fn remote_home<'a>(
        &'a self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'a>>;
}
469
470struct RealDiscoverySession<'a> {
471 session: &'a std::sync::Arc<openssh::Session>,
472}
473
474impl<'a> DiscoverySession for RealDiscoverySession<'a> {
475 fn test_executable<'b>(
476 &'b self,
477 path: &'b str,
478 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<bool>> + Send + 'b>>
479 {
480 Box::pin(async move {
481 let output = self
482 .session
483 .command("sh")
484 .arg("-c")
485 .arg(format!("test -x {}", shell_escape(path)))
486 .output()
487 .await?;
488 Ok(output.status.success())
489 })
490 }
491 fn which<'b>(
492 &'b self,
493 binary: &'b str,
494 ) -> std::pin::Pin<
495 Box<dyn std::future::Future<Output = anyhow::Result<Option<String>>> + Send + 'b>,
496 > {
497 Box::pin(async move {
498 let output = self.session.command("which").arg(binary).output().await?;
499 if output.status.success() {
500 let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
501 if !path.is_empty() {
502 return Ok(Some(path));
503 }
504 }
505 Ok(None)
506 })
507 }
508 fn remote_home<'b>(
509 &'b self,
510 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'b>>
511 {
512 Box::pin(get_remote_home(self.session))
513 }
514}
515
516async fn discover_rcpd_path(
529 session: &std::sync::Arc<openssh::Session>,
530 explicit_path: Option<&str>,
531) -> anyhow::Result<String> {
532 let real_session = RealDiscoverySession { session };
533 discover_rcpd_path_internal(&real_session, explicit_path, None).await
534}
535
536async fn discover_rcpd_path_internal<S: DiscoverySession + ?Sized>(
537 session: &S,
538 explicit_path: Option<&str>,
539 current_exe_override: Option<std::path::PathBuf>,
540) -> anyhow::Result<String> {
541 let local_version = common::version::ProtocolVersion::current();
542 if let Some(path) = explicit_path {
544 tracing::debug!("Trying explicit rcpd path: {}", path);
545 if session.test_executable(path).await? {
546 tracing::info!("Found rcpd at explicit path: {}", path);
547 return Ok(path.to_string());
548 }
549 return Err(anyhow::anyhow!(
552 "rcpd binary not found or not executable at explicit path: {}",
553 path
554 ));
555 }
556 if let Ok(current_exe) = current_exe_override
558 .map(Ok)
559 .unwrap_or_else(std::env::current_exe)
560 {
561 if let Some(bin_dir) = current_exe.parent() {
562 let path = bin_dir.join("rcpd").display().to_string();
563 tracing::debug!("Trying same directory as rcp: {}", path);
564 if session.test_executable(&path).await? {
565 tracing::info!("Found rcpd in same directory as rcp: {}", path);
566 return Ok(path);
567 }
568 }
569 }
570 tracing::debug!("Trying to find rcpd in PATH");
572 if let Some(path) = session.which("rcpd").await? {
573 tracing::info!("Found rcpd in PATH: {}", path);
574 return Ok(path);
575 }
576 let cache_path = match session.remote_home().await {
579 Ok(home) => {
580 let path = format!("{}/.cache/rcp/bin/rcpd-{}", home, local_version.semantic);
581 tracing::debug!("Trying deployed cache path: {}", path);
582 if session.test_executable(&path).await? {
583 tracing::info!("Found rcpd in deployed cache: {}", path);
584 return Ok(path);
585 }
586 Some(path)
587 }
588 Err(e) => {
589 tracing::debug!(
590 "HOME not set on remote host, skipping cache directory check: {:#}",
591 e
592 );
593 None
594 }
595 };
596 let mut searched = vec![];
598 searched.push("- Same directory as local rcp binary".to_string());
599 searched.push("- PATH (via 'which rcpd')".to_string());
600 if let Some(path) = cache_path.as_ref() {
601 searched.push(format!("- Deployed cache: {}", path));
602 } else {
603 searched.push("- Deployed cache: (skipped, HOME not set)".to_string());
604 }
605 if let Some(path) = explicit_path {
606 searched.insert(
607 0,
608 format!("- Explicit path: {} (not found or not executable)", path),
609 );
610 }
611 Err(anyhow::anyhow!(
612 "rcpd binary not found on remote host\n\
613 \n\
614 Searched in:\n\
615 {}\n\
616 \n\
617 Options:\n\
618 - Use automatic deployment: rcp --auto-deploy-rcpd ...\n\
619 - Install rcpd manually: cargo install rcp-tools-rcp --version {}\n\
620 - Specify explicit path: rcp --rcpd-path=/path/to/rcpd ...",
621 searched.join("\n"),
622 local_version.semantic
623 ))
624}
625
626async fn try_discover_and_check_version(
631 session: &std::sync::Arc<openssh::Session>,
632 explicit_path: Option<&str>,
633 remote_host: &str,
634) -> anyhow::Result<String> {
635 let rcpd_path = discover_rcpd_path(session, explicit_path).await?;
637 check_rcpd_version(session, &rcpd_path, remote_host).await?;
639 Ok(rcpd_path)
640}
641
642async fn check_rcpd_version(
646 session: &std::sync::Arc<openssh::Session>,
647 rcpd_path: &str,
648 remote_host: &str,
649) -> anyhow::Result<()> {
650 let local_version = common::version::ProtocolVersion::current();
651
652 tracing::debug!("Checking rcpd version on remote host: {}", remote_host);
653
654 let output = session
656 .command(rcpd_path)
657 .arg("--protocol-version")
658 .output()
659 .await
660 .context("Failed to execute rcpd --protocol-version on remote host")?;
661
662 if !output.status.success() {
663 let stderr = String::from_utf8_lossy(&output.stderr);
664 return Err(anyhow::anyhow!(
665 "rcpd --protocol-version failed on remote host '{}'\n\
666 \n\
667 stderr: {}\n\
668 \n\
669 This may indicate an old version of rcpd that does not support --protocol-version.\n\
670 Please install a matching version of rcpd on the remote host:\n\
671 - cargo install rcp-tools-rcp --version {}",
672 remote_host,
673 stderr,
674 local_version.semantic
675 ));
676 }
677
678 let stdout = String::from_utf8_lossy(&output.stdout);
679 let remote_version = common::version::ProtocolVersion::from_json(stdout.trim())
680 .context("Failed to parse rcpd version JSON from remote host")?;
681
682 tracing::info!(
683 "Local version: {}, Remote version: {}",
684 local_version,
685 remote_version
686 );
687
688 if !local_version.is_compatible_with(&remote_version) {
689 return Err(anyhow::anyhow!(
690 "rcpd version mismatch\n\
691 \n\
692 Local: rcp {}\n\
693 Remote: rcpd {} on host '{}'\n\
694 \n\
695 The rcpd version on the remote host must exactly match the rcp version.\n\
696 \n\
697 To fix this, install the matching version on the remote host:\n\
698 - ssh {} 'cargo install rcp-tools-rcp --version {}'",
699 local_version,
700 remote_version,
701 remote_host,
702 shell_escape(remote_host),
703 local_version.semantic
704 ));
705 }
706
707 Ok(())
708}
709
/// How to reach a running rcpd, as announced on the first line of its stderr.
#[derive(Debug, Clone)]
pub struct RcpdConnectionInfo {
    /// Address rcpd reported it is listening on.
    pub addr: std::net::SocketAddr,
    /// Certificate fingerprint from an `RCP_TLS` announcement; `None` when rcpd
    /// announced plain TCP (`RCP_TCP`).
    pub fingerprint: Option<tls::Fingerprint>,
}
718
/// A remote rcpd started by `start_rcpd`, together with its output-draining tasks.
pub struct RcpdProcess {
    /// Handle to the remote rcpd process (spawned with stdin, stdout and stderr piped).
    pub child: openssh::Child<std::sync::Arc<openssh::Session>>,
    /// Connection details parsed from rcpd's startup announcement.
    pub conn_info: RcpdConnectionInfo,
    /// Task forwarding the rest of rcpd's stderr to debug logs. It is only held
    /// here: dropping a tokio `JoinHandle` detaches the task rather than cancelling it.
    _stderr_drain: tokio::task::JoinHandle<()>,
    /// Same as `_stderr_drain` for stdout; `None` if stdout was not available.
    _stdout_drain: Option<tokio::task::JoinHandle<()>>,
}
730
731#[allow(clippy::too_many_arguments)]
732#[instrument]
733pub async fn start_rcpd(
734 rcpd_config: &protocol::RcpdConfig,
735 session: &SshSession,
736 explicit_rcpd_path: Option<&str>,
737 auto_deploy_rcpd: bool,
738 bind_ip: Option<&str>,
739 role: protocol::RcpdRole,
740) -> anyhow::Result<RcpdProcess> {
741 tracing::info!("Starting rcpd server on: {:?}", session);
742 let remote_host = &session.host;
743 let ssh_session = setup_ssh_session(session).await?;
744 let rcpd_path =
746 match try_discover_and_check_version(&ssh_session, explicit_rcpd_path, remote_host).await {
747 Ok(path) => {
748 path
750 }
751 Err(e) => {
752 if auto_deploy_rcpd {
754 tracing::info!(
755 "rcpd not found or version mismatch, attempting auto-deployment"
756 );
757 let local_rcpd = deploy::find_local_rcpd_binary()
759 .context("failed to find local rcpd binary for deployment")?;
760 tracing::info!("Found local rcpd binary at {}", local_rcpd.display());
761 let local_version = common::version::ProtocolVersion::current();
763 let deployed_path = deploy::deploy_rcpd(
765 &ssh_session,
766 &local_rcpd,
767 &local_version.semantic,
768 remote_host,
769 )
770 .await
771 .context("failed to deploy rcpd to remote host")?;
772 tracing::info!("Successfully deployed rcpd to {}", deployed_path);
773 if let Err(e) = deploy::cleanup_old_versions(&ssh_session, 3).await {
775 tracing::warn!("failed to cleanup old versions (non-fatal): {:#}", e);
776 }
777 deployed_path
778 } else {
779 return Err(e);
781 }
782 }
783 };
784 let rcpd_args = rcpd_config.to_args();
786 tracing::debug!("rcpd arguments: {:?}", rcpd_args);
787 let mut cmd = ssh_session.arc_command(&rcpd_path);
788 cmd.arg("--role").arg(role.to_string()).args(rcpd_args);
789 if let Some(ip) = bind_ip {
791 tracing::debug!("passing --bind-ip {} to rcpd", ip);
792 cmd.arg("--bind-ip").arg(ip);
793 }
794 cmd.stdin(openssh::Stdio::piped());
797 cmd.stdout(openssh::Stdio::piped());
798 cmd.stderr(openssh::Stdio::piped());
799 tracing::info!("Will run remotely: {cmd:?}");
800 let mut child = cmd.spawn().await.context("Failed to spawn rcpd command")?;
801 let stderr = child.stderr().take().context("rcpd stderr not available")?;
806 let mut stderr_reader = tokio::io::BufReader::new(stderr);
807 let mut line = String::new();
808 use tokio::io::AsyncBufReadExt;
809 stderr_reader
810 .read_line(&mut line)
811 .await
812 .context("failed to read connection info from rcpd")?;
813 let line = line.trim();
814 let host_stderr = session.host.clone();
817 let stderr_drain = tokio::spawn(async move {
818 let mut line = String::new();
819 loop {
820 line.clear();
821 match stderr_reader.read_line(&mut line).await {
822 Ok(0) => break, Ok(_) => {
824 let trimmed = line.trim();
825 if !trimmed.is_empty() {
826 tracing::debug!(host = %host_stderr, "rcpd stderr: {}", trimmed);
827 }
828 }
829 Err(e) => {
830 tracing::debug!(host = %host_stderr, "rcpd stderr read error: {:#}", e);
831 break;
832 }
833 }
834 }
835 });
836 let stdout_drain = if let Some(stdout) = child.stdout().take() {
839 let host_stdout = session.host.clone();
840 let mut stdout_reader = tokio::io::BufReader::new(stdout);
841 Some(tokio::spawn(async move {
842 let mut line = String::new();
843 loop {
844 line.clear();
845 match stdout_reader.read_line(&mut line).await {
846 Ok(0) => break, Ok(_) => {
848 let trimmed = line.trim();
849 if !trimmed.is_empty() {
850 tracing::debug!(host = %host_stdout, "rcpd stdout: {}", trimmed);
851 }
852 }
853 Err(e) => {
854 tracing::debug!(host = %host_stdout, "rcpd stdout read error: {:#}", e);
855 break;
856 }
857 }
858 }
859 }))
860 } else {
861 None
862 };
863 tracing::debug!("rcpd connection line: {}", line);
864 let conn_info = if let Some(rest) = line.strip_prefix("RCP_TLS ") {
865 let parts: Vec<&str> = rest.split_whitespace().collect();
867 if parts.len() != 2 {
868 anyhow::bail!("invalid RCP_TLS line from rcpd: {}", line);
869 }
870 let addr = parts[0]
871 .parse()
872 .with_context(|| format!("invalid address in RCP_TLS line: {}", parts[0]))?;
873 let fingerprint = tls::fingerprint_from_hex(parts[1])
874 .with_context(|| format!("invalid fingerprint in RCP_TLS line: {}", parts[1]))?;
875 RcpdConnectionInfo {
876 addr,
877 fingerprint: Some(fingerprint),
878 }
879 } else if let Some(rest) = line.strip_prefix("RCP_TCP ") {
880 let addr = rest
882 .trim()
883 .parse()
884 .with_context(|| format!("invalid address in RCP_TCP line: {}", rest))?;
885 RcpdConnectionInfo {
886 addr,
887 fingerprint: None,
888 }
889 } else {
890 anyhow::bail!(
891 "unexpected output from rcpd (expected RCP_TLS or RCP_TCP): {}",
892 line
893 );
894 };
895 tracing::info!(
896 "rcpd listening on {} (encryption={})",
897 conn_info.addr,
898 conn_info.fingerprint.is_some()
899 );
900 Ok(RcpdProcess {
901 child,
902 conn_info,
903 _stderr_drain: stderr_drain,
904 _stdout_drain: stdout_drain,
905 })
906}
907
908fn get_local_ip(explicit_bind_ip: Option<&str>) -> anyhow::Result<std::net::IpAddr> {
913 if let Some(ip_str) = explicit_bind_ip {
915 let ip = ip_str
916 .parse::<std::net::IpAddr>()
917 .with_context(|| format!("invalid IP address: {}", ip_str))?;
918 match ip {
919 std::net::IpAddr::V4(ipv4) => {
920 tracing::debug!("using explicit bind IP: {}", ipv4);
921 return Ok(std::net::IpAddr::V4(ipv4));
922 }
923 std::net::IpAddr::V6(_) => {
924 anyhow::bail!(
925 "IPv6 address not supported for binding (got {}). \
926 TCP endpoints bind to 0.0.0.0 (IPv4 only)",
927 ip
928 );
929 }
930 }
931 }
932 if let Some(ipv4) = try_ipv4_via_kernel_routing()? {
934 return Ok(std::net::IpAddr::V4(ipv4));
935 }
936 tracing::debug!("routing-based detection failed, falling back to interface enumeration");
938 let interfaces = collect_ipv4_interfaces().context("Failed to enumerate network interfaces")?;
939 if let Some(ipv4) = choose_best_ipv4(&interfaces) {
940 tracing::debug!("using IPv4 address from interface scan: {}", ipv4);
941 return Ok(std::net::IpAddr::V4(ipv4));
942 }
943 anyhow::bail!("No IPv4 interfaces found (TCP endpoints require IPv4 as they bind to 0.0.0.0)")
944}
945
946fn try_ipv4_via_kernel_routing() -> anyhow::Result<Option<std::net::Ipv4Addr>> {
947 let private_ips = ["10.0.0.1:80", "172.16.0.1:80", "192.168.1.1:80"];
950 for addr_str in &private_ips {
951 let addr = addr_str
952 .parse::<std::net::SocketAddr>()
953 .expect("hardcoded socket addresses are valid");
954 let socket = match std::net::UdpSocket::bind("0.0.0.0:0") {
955 Ok(socket) => socket,
956 Err(err) => {
957 tracing::debug!(?err, "failed to bind UDP socket for routing detection");
958 continue;
959 }
960 };
961 if let Err(err) = socket.connect(addr) {
962 tracing::debug!(?err, "connect() failed for routing target {}", addr);
963 continue;
964 }
965 match socket.local_addr() {
966 Ok(std::net::SocketAddr::V4(local_addr)) => {
967 let ipv4 = *local_addr.ip();
968 if !ipv4.is_loopback() && !ipv4.is_unspecified() {
969 tracing::debug!(
970 "using IPv4 address from kernel routing (via {}): {}",
971 addr,
972 ipv4
973 );
974 return Ok(Some(ipv4));
975 }
976 }
977 Ok(_) => {
978 tracing::debug!("kernel routing returned IPv6 despite IPv4 bind, ignoring");
979 }
980 Err(err) => {
981 tracing::debug!(?err, "local_addr() failed for routing-based detection");
982 }
983 }
984 }
985 Ok(None)
986}
987
/// One IPv4 address assigned to a named local network interface.
#[derive(Clone, Debug, PartialEq, Eq)]
struct InterfaceIpv4 {
    /// Interface name as reported by the OS (e.g. `eth0`).
    name: String,
    /// An IPv4 address on that interface.
    addr: std::net::Ipv4Addr,
}
993
994fn collect_ipv4_interfaces() -> anyhow::Result<Vec<InterfaceIpv4>> {
995 use if_addrs::get_if_addrs;
996 let mut interfaces = Vec::new();
997 for iface in get_if_addrs()? {
998 if let std::net::IpAddr::V4(ipv4) = iface.addr.ip() {
999 interfaces.push(InterfaceIpv4 {
1000 name: iface.name,
1001 addr: ipv4,
1002 });
1003 }
1004 }
1005 Ok(interfaces)
1006}
1007
1008fn choose_best_ipv4(interfaces: &[InterfaceIpv4]) -> Option<std::net::Ipv4Addr> {
1009 interfaces
1010 .iter()
1011 .filter(|iface| !iface.addr.is_unspecified())
1012 .min_by_key(|iface| interface_priority(&iface.name, &iface.addr))
1013 .map(|iface| iface.addr)
1014}
1015
1016fn interface_priority(
1017 name: &str,
1018 addr: &std::net::Ipv4Addr,
1019) -> (InterfaceCategory, u8, u8, std::net::Ipv4Addr) {
1020 (
1021 classify_interface(name, addr),
1022 if addr.is_link_local() { 1 } else { 0 },
1023 if addr.is_private() { 1 } else { 0 },
1024 *addr,
1025 )
1026}
1027
/// Coarse interface ranking for address selection. The derived `Ord` follows
/// declaration order, so earlier variants (smaller values) are preferred.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum InterfaceCategory {
    /// Name starts with a typical physical NIC prefix (`en`, `eth`, `wl`, ...).
    Preferred = 0,
    /// Neither virtual nor recognised as physical.
    Normal = 1,
    /// Name looks like a bridge, container, VPN/tunnel or similar virtual device.
    Virtual = 2,
    /// Loopback address.
    Loopback = 3,
}
1035
1036fn classify_interface(name: &str, addr: &std::net::Ipv4Addr) -> InterfaceCategory {
1037 if addr.is_loopback() {
1038 return InterfaceCategory::Loopback;
1039 }
1040 let normalized = normalize_interface_name(name);
1041 if is_virtual_interface(&normalized) {
1042 return InterfaceCategory::Virtual;
1043 }
1044 if is_preferred_physical_interface(&normalized) {
1045 return InterfaceCategory::Preferred;
1046 }
1047 InterfaceCategory::Normal
1048}
1049
/// Keeps only ASCII letters and digits, lowercased (e.g. `"Wi-Fi"` becomes
/// `"wifi"`), so prefix checks ignore case, punctuation and spaces.
fn normalize_interface_name(original: &str) -> String {
    original
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|ch| ch.to_ascii_lowercase())
        .collect()
}
1059
/// Returns `true` if a normalized interface name looks virtual. That means it
/// contains `"virtual"` or starts with a prefix used by common bridge, container,
/// VM, VPN/tunnel/overlay or loopback devices.
fn is_virtual_interface(name: &str) -> bool {
    const VIRTUAL_PREFIXES: &[&str] = &[
        "br", "docker", "veth", "virbr", "vmnet", "wg", "tailscale", "zt", "zerotier", "tap",
        "tun", "utun", "ham", "vpn", "lo", "lxc",
    ];
    if name.contains("virtual") {
        return true;
    }
    for prefix in VIRTUAL_PREFIXES {
        if name.starts_with(prefix) {
            return true;
        }
    }
    false
}
1084
/// Returns `true` if a normalized interface name starts with a prefix typical of
/// physical wired or wireless NICs (`en*`, `eth*`, `em*`, `wl*`, `ww*`, or
/// Windows-style `ethernet` / `wifi` / `lan`).
fn is_preferred_physical_interface(name: &str) -> bool {
    const PHYSICAL_PREFIXES: [&str; 12] = [
        "en", "eth", "em", "eno", "ens", "enp", "wl", "ww", "wlan", "ethernet", "lan", "wifi",
    ];
    for candidate in PHYSICAL_PREFIXES {
        if name.starts_with(candidate) {
            return true;
        }
    }
    false
}
1093
1094#[instrument]
1096pub fn get_random_server_name() -> String {
1097 use rand::Rng;
1098 rand::thread_rng()
1099 .sample_iter(&rand::distributions::Alphanumeric)
1100 .take(20)
1101 .map(char::from)
1102 .collect()
1103}
1104
1105#[instrument(skip(config))]
1113pub async fn create_tcp_control_listener(
1114 config: &TcpConfig,
1115 bind_ip: Option<&str>,
1116) -> anyhow::Result<tokio::net::TcpListener> {
1117 let bind_addr = if let Some(ip_str) = bind_ip {
1118 let ip = ip_str
1119 .parse::<std::net::IpAddr>()
1120 .with_context(|| format!("invalid IP address: {}", ip_str))?;
1121 std::net::SocketAddr::new(ip, 0)
1122 } else {
1123 std::net::SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0)
1124 };
1125 let listener = if let Some(ranges_str) = config.port_ranges.as_deref() {
1126 let ranges = port_ranges::PortRanges::parse(ranges_str)?;
1127 ranges.bind_tcp_listener(bind_addr.ip()).await?
1128 } else {
1129 tokio::net::TcpListener::bind(bind_addr).await?
1130 };
1131 let local_addr = listener.local_addr()?;
1132 tracing::info!("TCP control listener bound to {}", local_addr);
1133 Ok(listener)
1134}
1135
1136#[instrument(skip(config))]
1140pub async fn create_tcp_data_listener(
1141 config: &TcpConfig,
1142 bind_ip: Option<&str>,
1143) -> anyhow::Result<tokio::net::TcpListener> {
1144 let bind_addr = if let Some(ip_str) = bind_ip {
1145 let ip = ip_str
1146 .parse::<std::net::IpAddr>()
1147 .with_context(|| format!("invalid IP address: {}", ip_str))?;
1148 std::net::SocketAddr::new(ip, 0)
1149 } else {
1150 std::net::SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0)
1151 };
1152 let listener = if let Some(ranges_str) = config.port_ranges.as_deref() {
1153 let ranges = port_ranges::PortRanges::parse(ranges_str)?;
1154 ranges.bind_tcp_listener(bind_addr.ip()).await?
1155 } else {
1156 tokio::net::TcpListener::bind(bind_addr).await?
1157 };
1158 let local_addr = listener.local_addr()?;
1159 tracing::info!("TCP data listener bound to {}", local_addr);
1160 Ok(listener)
1161}
1162
1163pub fn get_tcp_listener_addr(
1167 listener: &tokio::net::TcpListener,
1168 bind_ip: Option<&str>,
1169) -> anyhow::Result<std::net::SocketAddr> {
1170 let local_addr = listener.local_addr()?;
1171 if local_addr.ip().is_unspecified() {
1172 let local_ip = get_local_ip(bind_ip).context("failed to get local IP address")?;
1173 Ok(std::net::SocketAddr::new(local_ip, local_addr.port()))
1174 } else {
1175 Ok(local_addr)
1176 }
1177}
1178
1179#[instrument]
1181pub async fn connect_tcp_control(
1182 addr: std::net::SocketAddr,
1183 timeout_sec: u64,
1184) -> anyhow::Result<tokio::net::TcpStream> {
1185 let stream = tokio::time::timeout(
1186 std::time::Duration::from_secs(timeout_sec),
1187 tokio::net::TcpStream::connect(addr),
1188 )
1189 .await
1190 .with_context(|| format!("connection to {} timed out after {}s", addr, timeout_sec))?
1191 .with_context(|| format!("failed to connect to {}", addr))?;
1192 stream.set_nodelay(true)?;
1193 tracing::debug!("connected to TCP control server at {}", addr);
1194 Ok(stream)
1195}
1196
1197pub fn configure_tcp_buffers(stream: &tokio::net::TcpStream, profile: NetworkProfile) {
1201 use socket2::SockRef;
1202 let (send_buf, recv_buf) = match profile {
1203 NetworkProfile::Datacenter => (16 * 1024 * 1024, 16 * 1024 * 1024),
1204 NetworkProfile::Internet => (2 * 1024 * 1024, 2 * 1024 * 1024),
1205 };
1206 let sock_ref = SockRef::from(stream);
1207 if let Err(err) = sock_ref.set_send_buffer_size(send_buf) {
1208 tracing::warn!("failed to set TCP send buffer size: {err:#}");
1209 }
1210 if let Err(err) = sock_ref.set_recv_buffer_size(recv_buf) {
1211 tracing::warn!("failed to set TCP receive buffer size: {err:#}");
1212 }
1213 if let (Ok(send), Ok(recv)) = (sock_ref.send_buffer_size(), sock_ref.recv_buffer_size()) {
1214 tracing::debug!(
1215 "TCP socket buffer sizes: send={} recv={}",
1216 bytesize::ByteSize(send as u64),
1217 bytesize::ByteSize(recv as u64),
1218 );
1219 }
1220}
1221
1222#[cfg(test)]
1223mod tests {
1224 use super::*;
1225 use std::collections::HashMap;
1226 use std::path::PathBuf;
1227 use std::sync::Mutex;
1228
1229 struct MockDiscoverySession {
1230 test_responses: HashMap<String, bool>,
1231 which_response: Option<String>,
1232 home_response: Result<String, String>,
1233 calls: Mutex<Vec<String>>,
1234 }
1235
1236 impl Default for MockDiscoverySession {
1237 fn default() -> Self {
1238 Self {
1239 test_responses: HashMap::new(),
1240 which_response: None,
1241 home_response: Err("HOME not set".to_string()),
1242 calls: Mutex::new(Vec::new()),
1243 }
1244 }
1245 }
1246
1247 impl MockDiscoverySession {
1248 fn new() -> Self {
1249 Self::default()
1250 }
1251
1252 fn with_home(mut self, home: Option<&str>) -> Self {
1253 self.home_response = match home {
1254 Some(home) => Ok(home.to_string()),
1255 None => Err("HOME not set".to_string()),
1256 };
1257 self
1258 }
1259 fn with_which(mut self, path: Option<&str>) -> Self {
1260 self.which_response = path.map(|p| p.to_string());
1261 self
1262 }
1263 fn set_test_response(&mut self, path: &str, exists: bool) {
1264 self.test_responses.insert(path.to_string(), exists);
1265 }
1266 fn calls(&self) -> Vec<String> {
1267 self.calls.lock().unwrap().clone()
1268 }
1269 }
1270
1271 impl DiscoverySession for MockDiscoverySession {
1272 fn test_executable<'a>(
1273 &'a self,
1274 path: &'a str,
1275 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<bool>> + Send + 'a>>
1276 {
1277 self.calls.lock().unwrap().push(format!("test:{}", path));
1278 let exists = self.test_responses.get(path).copied().unwrap_or(false);
1279 Box::pin(async move { Ok(exists) })
1280 }
1281 fn which<'a>(
1282 &'a self,
1283 binary: &'a str,
1284 ) -> std::pin::Pin<
1285 Box<dyn std::future::Future<Output = anyhow::Result<Option<String>>> + Send + 'a>,
1286 > {
1287 self.calls.lock().unwrap().push(format!("which:{}", binary));
1288 let result = self.which_response.clone();
1289 Box::pin(async move { Ok(result) })
1290 }
1291 fn remote_home<'a>(
1292 &'a self,
1293 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'a>>
1294 {
1295 self.calls.lock().unwrap().push("home".to_string());
1296 let result = self.home_response.clone();
1297 Box::pin(async move {
1298 match result {
1299 Ok(home) => Ok(home),
1300 Err(e) => Err(anyhow::anyhow!(e)),
1301 }
1302 })
1303 }
1304 }
1305
1306 #[tokio::test]
1307 async fn discover_rcpd_prefers_explicit_path() {
1308 let mut session = MockDiscoverySession::new();
1309 session.set_test_response("/opt/rcpd", true);
1310 let path = discover_rcpd_path_internal(&session, Some("/opt/rcpd"), None)
1311 .await
1312 .expect("should return explicit path");
1313 assert_eq!(path, "/opt/rcpd");
1314 assert_eq!(session.calls(), vec!["test:/opt/rcpd"]);
1315 }
1316
1317 #[tokio::test]
1318 async fn discover_rcpd_explicit_path_errors_without_fallbacks() {
1319 let session = MockDiscoverySession::new();
1320 let err = discover_rcpd_path_internal(&session, Some("/missing/rcpd"), None)
1321 .await
1322 .expect_err("should fail when explicit path is missing");
1323 assert!(
1324 err.to_string()
1325 .contains("rcpd binary not found or not executable"),
1326 "unexpected error: {err}"
1327 );
1328 assert_eq!(session.calls(), vec!["test:/missing/rcpd"]);
1329 }
1330
1331 #[tokio::test]
1332 async fn discover_rcpd_uses_same_dir_first() {
1333 let mut session = MockDiscoverySession::new();
1334 session.set_test_response("/custom/bin/rcpd", true);
1335 let path =
1336 discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
1337 .await
1338 .expect("should find in same directory");
1339 assert_eq!(path, "/custom/bin/rcpd");
1340 assert_eq!(session.calls(), vec!["test:/custom/bin/rcpd"]);
1341 }
1342
1343 #[tokio::test]
1344 async fn discover_rcpd_falls_back_to_path_after_same_dir() {
1345 let mut session = MockDiscoverySession::new().with_which(Some("/usr/bin/rcpd"));
1346 session.set_test_response("/custom/bin/rcpd", false);
1347 let path =
1348 discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
1349 .await
1350 .expect("should find in PATH after same dir miss");
1351 assert_eq!(path, "/usr/bin/rcpd");
1352 assert_eq!(session.calls(), vec!["test:/custom/bin/rcpd", "which:rcpd"]);
1353 }
1354
1355 #[tokio::test]
1356 async fn discover_rcpd_uses_cache_last() {
1357 let mut session = MockDiscoverySession::new()
1358 .with_home(Some("/home/rcp"))
1359 .with_which(None);
1360 session.set_test_response("/custom/bin/rcpd", false);
1361 let local_version = common::version::ProtocolVersion::current();
1362 let cache_path = format!("/home/rcp/.cache/rcp/bin/rcpd-{}", local_version.semantic);
1363 session.set_test_response(&cache_path, true);
1364 let path =
1365 discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
1366 .await
1367 .expect("should fall back to cache");
1368 assert_eq!(path, cache_path);
1369 assert_eq!(
1370 session.calls(),
1371 vec![
1372 "test:/custom/bin/rcpd".to_string(),
1373 "which:rcpd".to_string(),
1374 "home".to_string(),
1375 format!("test:{cache_path}")
1376 ]
1377 );
1378 }
1379
1380 #[tokio::test]
1381 async fn discover_rcpd_reports_home_missing_in_error() {
1382 let mut session = MockDiscoverySession::new().with_which(None);
1383 session.set_test_response("/custom/bin/rcpd", false);
1384 let err =
1385 discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
1386 .await
1387 .expect_err("should fail when nothing is found");
1388 let msg = err.to_string();
1389 assert!(
1390 msg.contains("Deployed cache: (skipped, HOME not set)"),
1391 "expected searched list to mention skipped cache, got: {msg}"
1392 );
1393 assert_eq!(
1394 session.calls(),
1395 vec!["test:/custom/bin/rcpd", "which:rcpd", "home"]
1396 );
1397 }
1398
#[test]
fn test_tokio_unstable_enabled() {
    // The crate root already raises compile_error! without this cfg; this test keeps the
    // requirement visible in the test suite as well.
    assert!(
        cfg!(tokio_unstable),
        "tokio_unstable cfg flag is not enabled! \
         This is required for console-subscriber support. \
         Check .cargo/config.toml"
    );

    #[cfg(tokio_unstable)]
    {
        let _join_set: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
    }
}
1427
1428 fn iface(name: &str, addr: [u8; 4]) -> InterfaceIpv4 {
1429 InterfaceIpv4 {
1430 name: name.to_string(),
1431 addr: std::net::Ipv4Addr::new(addr[0], addr[1], addr[2], addr[3]),
1432 }
1433 }
1434
#[test]
fn choose_best_ipv4_prefers_physical_interfaces() {
    // `enp3s0` must win over the docker bridge and the tailscale interface.
    let candidates = vec![
        iface("docker0", [172, 17, 0, 1]),
        iface("enp3s0", [192, 168, 1, 44]),
        iface("tailscale0", [100, 115, 92, 5]),
    ];
    let expected = std::net::Ipv4Addr::from([192, 168, 1, 44]);
    assert_eq!(choose_best_ipv4(&candidates), Some(expected));
}
1447
#[test]
fn choose_best_ipv4_deprioritizes_link_local() {
    // A 169.254.x.x address loses even though it is listed first.
    let candidates = vec![
        iface("enp0s8", [169, 254, 10, 2]),
        iface("wlan0", [10, 0, 0, 23]),
    ];
    let expected = std::net::Ipv4Addr::from([10, 0, 0, 23]);
    assert_eq!(choose_best_ipv4(&candidates), Some(expected));
}
1459
#[test]
fn choose_best_ipv4_falls_back_to_loopback() {
    // With only loopback and an unspecified (0.0.0.0) address, loopback is chosen.
    let candidates = vec![iface("lo", [127, 0, 0, 1]), iface("docker0", [0, 0, 0, 0])];
    assert_eq!(
        choose_best_ipv4(&candidates),
        Some(std::net::Ipv4Addr::LOCALHOST)
    );
}
1468
#[test]
fn test_get_local_ip_with_explicit_ipv4() {
    // An explicit IPv4 address is returned verbatim.
    let ip = get_local_ip(Some("192.168.1.100")).expect("should accept valid IPv4 address");
    assert_eq!(ip, std::net::IpAddr::from([192, 168, 1, 100]));
}
1480
#[test]
fn test_get_local_ip_with_explicit_loopback() {
    // Loopback is an acceptable explicit choice, not filtered out.
    let ip = get_local_ip(Some("127.0.0.1")).expect("should accept loopback address");
    assert_eq!(ip, std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
}
1492
#[test]
fn test_get_local_ip_rejects_ipv6() {
    // IPv6 loopback is refused, and the message points at IPv4-only (0.0.0.0) binding.
    let err = get_local_ip(Some("::1")).expect_err("should reject IPv6 address");
    let rendered = format!("{err:#}");
    assert!(
        rendered.contains("IPv6 address not supported"),
        "error should mention IPv6 not supported, got: {rendered}"
    );
    assert!(
        rendered.contains("0.0.0.0"),
        "error should mention IPv4-only binding, got: {rendered}"
    );
}
1509
#[test]
fn test_get_local_ip_rejects_ipv6_full() {
    // A non-loopback IPv6 address (documentation prefix) is refused the same way.
    let err = get_local_ip(Some("2001:db8::1")).expect_err("should reject IPv6 address");
    let rendered = format!("{err:#}");
    assert!(
        rendered.contains("IPv6 address not supported"),
        "error should mention IPv6 not supported, got: {rendered}"
    );
}
1522
#[test]
fn test_get_local_ip_rejects_invalid_ip() {
    // Text that is not an address at all is reported as an invalid IP.
    let err = get_local_ip(Some("not-an-ip")).expect_err("should reject invalid IP format");
    let rendered = format!("{err:#}");
    assert!(
        rendered.contains("invalid IP address"),
        "error should mention invalid IP address, got: {rendered}"
    );
}
1535
#[test]
fn test_get_local_ip_rejects_invalid_ipv4() {
    // Dotted-quad syntax with out-of-range octets is still an invalid IP.
    let err = get_local_ip(Some("999.999.999.999")).expect_err("should reject invalid IPv4 address");
    let rendered = format!("{err:#}");
    assert!(
        rendered.contains("invalid IP address"),
        "error should mention invalid IP address, got: {rendered}"
    );
}
1548}