// Refuse to build unless the `tokio_unstable` cfg flag is set; the error
// message points at the config file that is expected to set it.
#[cfg(not(tokio_unstable))]
compile_error!("tokio_unstable cfg must be enabled; see .cargo/config.toml");
120
121use anyhow::{anyhow, Context};
122use tracing::instrument;
123
124pub mod deploy;
125pub mod port_ranges;
126pub mod protocol;
127pub mod streams;
128pub mod tls;
129pub mod tracelog;
130
/// Network environment a transfer is tuned for.
///
/// Selects default buffer sizes (see `default_remote_copy_buffer_size`).
/// Displayed in lowercase and parsed case-insensitively from
/// "datacenter" / "internet".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum NetworkProfile {
    /// Default profile; uses the larger (16 MiB) remote-copy buffer.
    #[default]
    Datacenter,
    /// Uses the smaller (2 MiB) remote-copy buffer.
    Internet,
}
146
147impl std::fmt::Display for NetworkProfile {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 match self {
150 Self::Datacenter => write!(f, "datacenter"),
151 Self::Internet => write!(f, "internet"),
152 }
153 }
154}
155
156impl std::str::FromStr for NetworkProfile {
157 type Err = String;
158 fn from_str(s: &str) -> Result<Self, Self::Err> {
159 match s.to_lowercase().as_str() {
160 "datacenter" => Ok(Self::Datacenter),
161 "internet" => Ok(Self::Internet),
162 _ => Err(format!(
163 "invalid network profile '{}', expected 'datacenter' or 'internet'",
164 s
165 )),
166 }
167 }
168}
169
170pub const DATACENTER_REMOTE_COPY_BUFFER_SIZE: usize = 16 * 1024 * 1024;
172
173pub const INTERNET_REMOTE_COPY_BUFFER_SIZE: usize = 2 * 1024 * 1024;
175
176impl NetworkProfile {
177 pub fn default_remote_copy_buffer_size(&self) -> usize {
183 match self {
184 Self::Datacenter => DATACENTER_REMOTE_COPY_BUFFER_SIZE,
185 Self::Internet => INTERNET_REMOTE_COPY_BUFFER_SIZE,
186 }
187 }
188}
189
/// TCP transport settings: listener port ranges, timeouts, buffer sizing and
/// connection limits. Build with `TcpConfig::default()` / `with_timeout` and
/// the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct TcpConfig {
    /// Optional port range spec (parsed by `port_ranges::PortRanges::parse`);
    /// when `None`, listeners bind an OS-assigned port.
    pub port_ranges: Option<String>,
    /// Connection timeout in seconds (defaults to 15).
    pub conn_timeout_sec: u64,
    /// Network profile; picks the buffer size when `buffer_size` is unset.
    pub network_profile: NetworkProfile,
    /// Explicit buffer size override; see `effective_buffer_size`.
    pub buffer_size: Option<usize>,
    /// Connection limit (defaults to 100). NOTE(review): presumably the
    /// maximum number of concurrent connections — confirm with consumers.
    pub max_connections: usize,
    /// Pending-writes multiplier (defaults to `DEFAULT_PENDING_WRITES_MULTIPLIER`).
    /// NOTE(review): exact semantics are defined by the consumer — confirm.
    pub pending_writes_multiplier: usize,
}
208
209pub const DEFAULT_PENDING_WRITES_MULTIPLIER: usize = 4;
211
212impl Default for TcpConfig {
213 fn default() -> Self {
214 Self {
215 port_ranges: None,
216 conn_timeout_sec: 15,
217 network_profile: NetworkProfile::default(),
218 buffer_size: None,
219 max_connections: 100,
220 pending_writes_multiplier: DEFAULT_PENDING_WRITES_MULTIPLIER,
221 }
222 }
223}
224
225impl TcpConfig {
226 pub fn with_timeout(conn_timeout_sec: u64) -> Self {
228 Self {
229 port_ranges: None,
230 conn_timeout_sec,
231 network_profile: NetworkProfile::default(),
232 buffer_size: None,
233 max_connections: 100,
234 pending_writes_multiplier: DEFAULT_PENDING_WRITES_MULTIPLIER,
235 }
236 }
237 pub fn with_port_ranges(mut self, ranges: impl Into<String>) -> Self {
239 self.port_ranges = Some(ranges.into());
240 self
241 }
242 pub fn with_network_profile(mut self, profile: NetworkProfile) -> Self {
244 self.network_profile = profile;
245 self
246 }
247 pub fn with_buffer_size(mut self, size: usize) -> Self {
249 self.buffer_size = Some(size);
250 self
251 }
252 pub fn with_max_connections(mut self, max: usize) -> Self {
254 self.max_connections = max;
255 self
256 }
257 pub fn with_pending_writes_multiplier(mut self, multiplier: usize) -> Self {
259 self.pending_writes_multiplier = multiplier;
260 self
261 }
262 pub fn effective_buffer_size(&self) -> usize {
264 self.buffer_size
265 .unwrap_or_else(|| self.network_profile.default_remote_copy_buffer_size())
266 }
267}
268
/// An SSH destination: optional user, host, and optional port.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SshSession {
    pub user: Option<String>,
    pub host: String,
    pub port: Option<u16>,
}

impl SshSession {
    /// A session targeting `localhost` with no explicit user or port.
    pub fn local() -> Self {
        SshSession {
            host: String::from("localhost"),
            user: None,
            port: None,
        }
    }
}
285
286pub use common::is_localhost;
288
289async fn setup_ssh_session(
290 session: &SshSession,
291) -> anyhow::Result<std::sync::Arc<openssh::Session>> {
292 let host = session.host.as_str();
293 let destination = match (session.user.as_deref(), session.port) {
294 (Some(user), Some(port)) => format!("ssh://{user}@{host}:{port}"),
295 (None, Some(port)) => format!("ssh://{}:{}", session.host, port),
296 (Some(user), None) => format!("ssh://{user}@{host}"),
297 (None, None) => format!("ssh://{host}"),
298 };
299 tracing::debug!("Connecting to SSH destination: {}", destination);
300 let session = std::sync::Arc::new(
301 openssh::Session::connect(destination, openssh::KnownHosts::Accept)
302 .await
303 .context("Failed to establish SSH connection")?,
304 );
305 Ok(session)
306}
307
308#[instrument]
309pub async fn get_remote_home_for_session(
310 session: &SshSession,
311) -> anyhow::Result<std::path::PathBuf> {
312 let ssh_session = setup_ssh_session(session).await?;
313 let home = get_remote_home(&ssh_session).await?;
314 Ok(std::path::PathBuf::from(home))
315}
316
317#[instrument]
318pub async fn wait_for_rcpd_process(
319 process: openssh::Child<std::sync::Arc<openssh::Session>>,
320) -> anyhow::Result<()> {
321 tracing::info!("Waiting on rcpd server on: {:?}", process);
322 let output = tokio::time::timeout(
324 std::time::Duration::from_secs(10),
325 process.wait_with_output(),
326 )
327 .await
328 .context("Timeout waiting for rcpd process to exit")?
329 .context("Failed to wait for rcpd process")?;
330 if !output.status.success() {
331 let stdout = String::from_utf8_lossy(&output.stdout);
332 let stderr = String::from_utf8_lossy(&output.stderr);
333 tracing::error!(
334 "rcpd command failed on remote host, status code: {:?}\nstdout:\n{}\nstderr:\n{}",
335 output.status.code(),
336 stdout,
337 stderr
338 );
339 return Err(anyhow!(
340 "rcpd command failed on remote host, status code: {:?}",
341 output.status.code(),
342 ));
343 }
344 if !output.stderr.is_empty() {
346 let stderr = String::from_utf8_lossy(&output.stderr);
347 tracing::debug!("rcpd stderr output:\n{}", stderr);
348 }
349 Ok(())
350}
351
/// Quotes `s` as a single POSIX-shell word.
///
/// The result is wrapped in single quotes, and each embedded `'` becomes
/// `'\''` (close quote, escaped quote, reopen quote). No other characters need
/// escaping inside single quotes.
pub(crate) fn shell_escape(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        if ch == '\'' {
            quoted.push_str(r"'\''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
358
359pub async fn get_remote_home(session: &std::sync::Arc<openssh::Session>) -> anyhow::Result<String> {
377 if let Ok(home_override) = std::env::var("RCP_REMOTE_HOME_OVERRIDE") {
378 if !home_override.is_empty() {
379 return Ok(home_override);
380 }
381 }
382 let output = session
383 .command("sh")
384 .arg("-c")
385 .arg("echo \"${HOME:?HOME not set}\"")
386 .output()
387 .await
388 .context("failed to check HOME environment variable on remote host")?;
389
390 if !output.status.success() {
391 let stderr = String::from_utf8_lossy(&output.stderr);
392 anyhow::bail!(
393 "HOME environment variable is not set on remote host\n\
394 \n\
395 stderr: {}\n\
396 \n\
397 The HOME environment variable is required for rcpd deployment and discovery.\n\
398 Please ensure your SSH configuration preserves environment variables.",
399 stderr
400 );
401 }
402
403 let home = String::from_utf8_lossy(&output.stdout).trim().to_string();
404
405 if home.is_empty() {
406 anyhow::bail!(
407 "HOME environment variable is empty on remote host\n\
408 \n\
409 The HOME environment variable is required for rcpd deployment and discovery.\n\
410 Please ensure your SSH configuration sets HOME correctly."
411 );
412 }
413
414 Ok(home)
415}
416
#[cfg(test)]
mod shell_escape_tests {
    //! Unit tests for `shell_escape` (POSIX single-quote escaping).
    use super::*;

    #[test]
    fn test_shell_escape_simple() {
        assert_eq!(shell_escape("simple"), "'simple'");
    }

    #[test]
    fn test_shell_escape_with_spaces() {
        assert_eq!(shell_escape("path with spaces"), "'path with spaces'");
    }

    #[test]
    fn test_shell_escape_with_single_quote() {
        // Each quote closes the string, emits an escaped quote, and reopens it.
        assert_eq!(
            shell_escape("path'with'quotes"),
            r"'path'\''with'\''quotes'"
        );
    }

    #[test]
    fn test_shell_escape_injection_attempt() {
        // Shell metacharacters stay inert inside single quotes.
        assert_eq!(shell_escape("foo; rm -rf /"), "'foo; rm -rf /'");
    }

    #[test]
    fn test_shell_escape_special_chars() {
        assert_eq!(shell_escape("$PATH && echo pwned"), "'$PATH && echo pwned'");
    }

    #[test]
    fn test_shell_escape_empty() {
        // An empty input must still yield one (empty) shell word, not nothing.
        assert_eq!(shell_escape(""), "''");
    }

    #[test]
    fn test_shell_escape_only_quote() {
        assert_eq!(shell_escape("'"), r"''\'''");
    }
}
453
/// Remote probes needed to locate an `rcpd` binary.
///
/// Implemented by `RealDiscoverySession` (a live SSH session) and by a mock in
/// this file's tests, so `discover_rcpd_path_internal` can be exercised without
/// SSH. Methods return boxed `Send` futures rather than using `async fn`.
trait DiscoverySession {
    /// Resolves to `Ok(true)` if `path` is executable on the remote host.
    fn test_executable<'a>(
        &'a self,
        path: &'a str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<bool>> + Send + 'a>>;
    /// Looks `binary` up on the remote `PATH`; resolves to `Ok(None)` if not found.
    fn which<'a>(
        &'a self,
        binary: &'a str,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = anyhow::Result<Option<String>>> + Send + 'a>,
    >;
    /// Resolves to the remote user's home directory, or an error if it cannot
    /// be determined.
    fn remote_home<'a>(
        &'a self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'a>>;
}
469
470struct RealDiscoverySession<'a> {
471 session: &'a std::sync::Arc<openssh::Session>,
472}
473
474impl<'a> DiscoverySession for RealDiscoverySession<'a> {
475 fn test_executable<'b>(
476 &'b self,
477 path: &'b str,
478 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<bool>> + Send + 'b>>
479 {
480 Box::pin(async move {
481 let output = self
482 .session
483 .command("sh")
484 .arg("-c")
485 .arg(format!("test -x {}", shell_escape(path)))
486 .output()
487 .await?;
488 Ok(output.status.success())
489 })
490 }
491 fn which<'b>(
492 &'b self,
493 binary: &'b str,
494 ) -> std::pin::Pin<
495 Box<dyn std::future::Future<Output = anyhow::Result<Option<String>>> + Send + 'b>,
496 > {
497 Box::pin(async move {
498 let output = self.session.command("which").arg(binary).output().await?;
499 if output.status.success() {
500 let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
501 if !path.is_empty() {
502 return Ok(Some(path));
503 }
504 }
505 Ok(None)
506 })
507 }
508 fn remote_home<'b>(
509 &'b self,
510 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'b>>
511 {
512 Box::pin(get_remote_home(self.session))
513 }
514}
515
516async fn discover_rcpd_path(
529 session: &std::sync::Arc<openssh::Session>,
530 explicit_path: Option<&str>,
531) -> anyhow::Result<String> {
532 let real_session = RealDiscoverySession { session };
533 discover_rcpd_path_internal(&real_session, explicit_path, None).await
534}
535
536async fn discover_rcpd_path_internal<S: DiscoverySession + ?Sized>(
537 session: &S,
538 explicit_path: Option<&str>,
539 current_exe_override: Option<std::path::PathBuf>,
540) -> anyhow::Result<String> {
541 let local_version = common::version::ProtocolVersion::current();
542 if let Some(path) = explicit_path {
544 tracing::debug!("Trying explicit rcpd path: {}", path);
545 if session.test_executable(path).await? {
546 tracing::info!("Found rcpd at explicit path: {}", path);
547 return Ok(path.to_string());
548 }
549 return Err(anyhow::anyhow!(
552 "rcpd binary not found or not executable at explicit path: {}",
553 path
554 ));
555 }
556 if let Ok(current_exe) = current_exe_override
558 .map(Ok)
559 .unwrap_or_else(std::env::current_exe)
560 {
561 if let Some(bin_dir) = current_exe.parent() {
562 let path = bin_dir.join("rcpd").display().to_string();
563 tracing::debug!("Trying same directory as rcp: {}", path);
564 if session.test_executable(&path).await? {
565 tracing::info!("Found rcpd in same directory as rcp: {}", path);
566 return Ok(path);
567 }
568 }
569 }
570 tracing::debug!("Trying to find rcpd in PATH");
572 if let Some(path) = session.which("rcpd").await? {
573 tracing::info!("Found rcpd in PATH: {}", path);
574 return Ok(path);
575 }
576 let cache_path = match session.remote_home().await {
579 Ok(home) => {
580 let path = format!("{}/.cache/rcp/bin/rcpd-{}", home, local_version.semantic);
581 tracing::debug!("Trying deployed cache path: {}", path);
582 if session.test_executable(&path).await? {
583 tracing::info!("Found rcpd in deployed cache: {}", path);
584 return Ok(path);
585 }
586 Some(path)
587 }
588 Err(e) => {
589 tracing::debug!(
590 "HOME not set on remote host, skipping cache directory check: {:#}",
591 e
592 );
593 None
594 }
595 };
596 let mut searched = vec![];
598 searched.push("- Same directory as local rcp binary".to_string());
599 searched.push("- PATH (via 'which rcpd')".to_string());
600 if let Some(path) = cache_path.as_ref() {
601 searched.push(format!("- Deployed cache: {}", path));
602 } else {
603 searched.push("- Deployed cache: (skipped, HOME not set)".to_string());
604 }
605 if let Some(path) = explicit_path {
606 searched.insert(
607 0,
608 format!("- Explicit path: {} (not found or not executable)", path),
609 );
610 }
611 Err(anyhow::anyhow!(
612 "rcpd binary not found on remote host\n\
613 \n\
614 Searched in:\n\
615 {}\n\
616 \n\
617 Options:\n\
618 - Use automatic deployment: rcp --auto-deploy-rcpd ...\n\
619 - Install rcpd manually: cargo install rcp-tools-rcp --version {}\n\
620 - Specify explicit path: rcp --rcpd-path=/path/to/rcpd ...",
621 searched.join("\n"),
622 local_version.semantic
623 ))
624}
625
626async fn try_discover_and_check_version(
631 session: &std::sync::Arc<openssh::Session>,
632 explicit_path: Option<&str>,
633 remote_host: &str,
634) -> anyhow::Result<String> {
635 let rcpd_path = discover_rcpd_path(session, explicit_path).await?;
637 check_rcpd_version(session, &rcpd_path, remote_host).await?;
639 Ok(rcpd_path)
640}
641
642async fn check_rcpd_version(
646 session: &std::sync::Arc<openssh::Session>,
647 rcpd_path: &str,
648 remote_host: &str,
649) -> anyhow::Result<()> {
650 let local_version = common::version::ProtocolVersion::current();
651
652 tracing::debug!("Checking rcpd version on remote host: {}", remote_host);
653
654 let output = session
656 .command(rcpd_path)
657 .arg("--protocol-version")
658 .output()
659 .await
660 .context("Failed to execute rcpd --protocol-version on remote host")?;
661
662 if !output.status.success() {
663 let stderr = String::from_utf8_lossy(&output.stderr);
664 return Err(anyhow::anyhow!(
665 "rcpd --protocol-version failed on remote host '{}'\n\
666 \n\
667 stderr: {}\n\
668 \n\
669 This may indicate an old version of rcpd that does not support --protocol-version.\n\
670 Please install a matching version of rcpd on the remote host:\n\
671 - cargo install rcp-tools-rcp --version {}",
672 remote_host,
673 stderr,
674 local_version.semantic
675 ));
676 }
677
678 let stdout = String::from_utf8_lossy(&output.stdout);
679 let remote_version = common::version::ProtocolVersion::from_json(stdout.trim())
680 .context("Failed to parse rcpd version JSON from remote host")?;
681
682 tracing::info!(
683 "Local version: {}, Remote version: {}",
684 local_version,
685 remote_version
686 );
687
688 if !local_version.is_compatible_with(&remote_version) {
689 return Err(anyhow::anyhow!(
690 "rcpd version mismatch\n\
691 \n\
692 Local: rcp {}\n\
693 Remote: rcpd {} on host '{}'\n\
694 \n\
695 The rcpd version on the remote host must exactly match the rcp version.\n\
696 \n\
697 To fix this, install the matching version on the remote host:\n\
698 - ssh {} 'cargo install rcp-tools-rcp --version {}'",
699 local_version,
700 remote_version,
701 remote_host,
702 shell_escape(remote_host),
703 local_version.semantic
704 ));
705 }
706
707 Ok(())
708}
709
/// How to reach a started rcpd, as announced on the first line of its stderr
/// (`RCP_TLS <addr> <fingerprint>` or `RCP_TCP <addr>`).
#[derive(Debug, Clone)]
pub struct RcpdConnectionInfo {
    /// Address rcpd reported it is listening on.
    pub addr: std::net::SocketAddr,
    /// TLS fingerprint from an `RCP_TLS` line; `None` for unencrypted `RCP_TCP`.
    pub fingerprint: Option<tls::Fingerprint>,
}
718
/// A running remote rcpd plus the background tasks servicing its output.
pub struct RcpdProcess {
    /// The remote rcpd process (spawned with stdin piped; stdin is not taken).
    pub child: openssh::Child<std::sync::Arc<openssh::Session>>,
    /// Connection details parsed from rcpd's first stderr line.
    pub conn_info: RcpdConnectionInfo,
    /// Task logging the remainder of rcpd's stderr at debug level.
    /// NOTE(review): presumably also keeps the pipe drained so rcpd never
    /// blocks on a full stderr buffer — confirm.
    _stderr_drain: tokio::task::JoinHandle<()>,
    /// Task logging rcpd's stdout at debug level, if stdout was available.
    _stdout_drain: Option<tokio::task::JoinHandle<()>>,
}
730
731#[allow(clippy::too_many_arguments)]
732#[instrument]
733pub async fn start_rcpd(
734 rcpd_config: &protocol::RcpdConfig,
735 session: &SshSession,
736 explicit_rcpd_path: Option<&str>,
737 auto_deploy_rcpd: bool,
738 bind_ip: Option<&str>,
739 role: protocol::RcpdRole,
740) -> anyhow::Result<RcpdProcess> {
741 tracing::info!("Starting rcpd server on: {:?}", session);
742 let remote_host = &session.host;
743 let ssh_session = setup_ssh_session(session).await?;
744 let rcpd_path =
746 match try_discover_and_check_version(&ssh_session, explicit_rcpd_path, remote_host).await {
747 Ok(path) => {
748 path
750 }
751 Err(e) => {
752 if auto_deploy_rcpd {
754 tracing::info!(
755 "rcpd not found or version mismatch, attempting auto-deployment"
756 );
757 let local_rcpd = deploy::find_local_rcpd_binary()
759 .context("failed to find local rcpd binary for deployment")?;
760 tracing::info!("Found local rcpd binary at {}", local_rcpd.display());
761 let local_version = common::version::ProtocolVersion::current();
763 let deployed_path = deploy::deploy_rcpd(
765 &ssh_session,
766 &local_rcpd,
767 &local_version.semantic,
768 remote_host,
769 )
770 .await
771 .context("failed to deploy rcpd to remote host")?;
772 tracing::info!("Successfully deployed rcpd to {}", deployed_path);
773 if let Err(e) = deploy::cleanup_old_versions(&ssh_session, 3).await {
775 tracing::warn!("failed to cleanup old versions (non-fatal): {:#}", e);
776 }
777 deployed_path
778 } else {
779 return Err(e);
781 }
782 }
783 };
784 let rcpd_args = rcpd_config.to_args();
786 tracing::debug!("rcpd arguments: {:?}", rcpd_args);
787 let mut cmd = ssh_session.arc_command(&rcpd_path);
788 cmd.arg("--role").arg(role.to_string()).args(rcpd_args);
789 if let Some(ip) = bind_ip {
791 tracing::debug!("passing --bind-ip {} to rcpd", ip);
792 cmd.arg("--bind-ip").arg(ip);
793 }
794 cmd.stdin(openssh::Stdio::piped());
797 cmd.stdout(openssh::Stdio::piped());
798 cmd.stderr(openssh::Stdio::piped());
799 tracing::info!("Will run remotely: {cmd:?}");
800 let mut child = cmd.spawn().await.context("Failed to spawn rcpd command")?;
801 let stderr = child.stderr().take().context("rcpd stderr not available")?;
806 let mut stderr_reader = tokio::io::BufReader::new(stderr);
807 let mut line = String::new();
808 use tokio::io::AsyncBufReadExt;
809 stderr_reader
810 .read_line(&mut line)
811 .await
812 .context("failed to read connection info from rcpd")?;
813 let line = line.trim();
814 let host_stderr = session.host.clone();
817 let stderr_drain = tokio::spawn(async move {
818 let mut line = String::new();
819 loop {
820 line.clear();
821 match stderr_reader.read_line(&mut line).await {
822 Ok(0) => break, Ok(_) => {
824 let trimmed = line.trim();
825 if !trimmed.is_empty() {
826 tracing::debug!(host = %host_stderr, "rcpd stderr: {}", trimmed);
827 }
828 }
829 Err(e) => {
830 tracing::debug!(host = %host_stderr, "rcpd stderr read error: {:#}", e);
831 break;
832 }
833 }
834 }
835 });
836 let stdout_drain = if let Some(stdout) = child.stdout().take() {
839 let host_stdout = session.host.clone();
840 let mut stdout_reader = tokio::io::BufReader::new(stdout);
841 Some(tokio::spawn(async move {
842 let mut line = String::new();
843 loop {
844 line.clear();
845 match stdout_reader.read_line(&mut line).await {
846 Ok(0) => break, Ok(_) => {
848 let trimmed = line.trim();
849 if !trimmed.is_empty() {
850 tracing::debug!(host = %host_stdout, "rcpd stdout: {}", trimmed);
851 }
852 }
853 Err(e) => {
854 tracing::debug!(host = %host_stdout, "rcpd stdout read error: {:#}", e);
855 break;
856 }
857 }
858 }
859 }))
860 } else {
861 None
862 };
863 tracing::debug!("rcpd connection line: {}", line);
864 let conn_info = if let Some(rest) = line.strip_prefix("RCP_TLS ") {
865 let parts: Vec<&str> = rest.split_whitespace().collect();
867 if parts.len() != 2 {
868 anyhow::bail!("invalid RCP_TLS line from rcpd: {}", line);
869 }
870 let addr = parts[0]
871 .parse()
872 .with_context(|| format!("invalid address in RCP_TLS line: {}", parts[0]))?;
873 let fingerprint = tls::fingerprint_from_hex(parts[1])
874 .with_context(|| format!("invalid fingerprint in RCP_TLS line: {}", parts[1]))?;
875 RcpdConnectionInfo {
876 addr,
877 fingerprint: Some(fingerprint),
878 }
879 } else if let Some(rest) = line.strip_prefix("RCP_TCP ") {
880 let addr = rest
882 .trim()
883 .parse()
884 .with_context(|| format!("invalid address in RCP_TCP line: {}", rest))?;
885 RcpdConnectionInfo {
886 addr,
887 fingerprint: None,
888 }
889 } else {
890 anyhow::bail!(
891 "unexpected output from rcpd (expected RCP_TLS or RCP_TCP): {}",
892 line
893 );
894 };
895 tracing::info!(
896 "rcpd listening on {} (encryption={})",
897 conn_info.addr,
898 conn_info.fingerprint.is_some()
899 );
900 Ok(RcpdProcess {
901 child,
902 conn_info,
903 _stderr_drain: stderr_drain,
904 _stdout_drain: stdout_drain,
905 })
906}
907
908fn get_local_ip(explicit_bind_ip: Option<&str>) -> anyhow::Result<std::net::IpAddr> {
913 if let Some(ip_str) = explicit_bind_ip {
915 let ip = ip_str
916 .parse::<std::net::IpAddr>()
917 .with_context(|| format!("invalid IP address: {}", ip_str))?;
918 match ip {
919 std::net::IpAddr::V4(ipv4) => {
920 tracing::debug!("using explicit bind IP: {}", ipv4);
921 return Ok(std::net::IpAddr::V4(ipv4));
922 }
923 std::net::IpAddr::V6(_) => {
924 anyhow::bail!(
925 "IPv6 address not supported for binding (got {}). \
926 TCP endpoints bind to 0.0.0.0 (IPv4 only)",
927 ip
928 );
929 }
930 }
931 }
932 if let Some(ipv4) = try_ipv4_via_kernel_routing()? {
934 return Ok(std::net::IpAddr::V4(ipv4));
935 }
936 tracing::debug!("routing-based detection failed, falling back to interface enumeration");
938 let interfaces = collect_ipv4_interfaces().context("Failed to enumerate network interfaces")?;
939 if let Some(ipv4) = choose_best_ipv4(&interfaces) {
940 tracing::debug!("using IPv4 address from interface scan: {}", ipv4);
941 return Ok(std::net::IpAddr::V4(ipv4));
942 }
943 anyhow::bail!("No IPv4 interfaces found (TCP endpoints require IPv4 as they bind to 0.0.0.0)")
944}
945
946fn try_ipv4_via_kernel_routing() -> anyhow::Result<Option<std::net::Ipv4Addr>> {
947 let private_ips = ["10.0.0.1:80", "172.16.0.1:80", "192.168.1.1:80"];
950 for addr_str in &private_ips {
951 let addr = addr_str
952 .parse::<std::net::SocketAddr>()
953 .expect("hardcoded socket addresses are valid");
954 let socket = match std::net::UdpSocket::bind("0.0.0.0:0") {
955 Ok(socket) => socket,
956 Err(err) => {
957 tracing::debug!(?err, "failed to bind UDP socket for routing detection");
958 continue;
959 }
960 };
961 if let Err(err) = socket.connect(addr) {
962 tracing::debug!(?err, "connect() failed for routing target {}", addr);
963 continue;
964 }
965 match socket.local_addr() {
966 Ok(std::net::SocketAddr::V4(local_addr)) => {
967 let ipv4 = *local_addr.ip();
968 if !ipv4.is_loopback() && !ipv4.is_unspecified() {
969 tracing::debug!(
970 "using IPv4 address from kernel routing (via {}): {}",
971 addr,
972 ipv4
973 );
974 return Ok(Some(ipv4));
975 }
976 }
977 Ok(_) => {
978 tracing::debug!("kernel routing returned IPv6 despite IPv4 bind, ignoring");
979 }
980 Err(err) => {
981 tracing::debug!(?err, "local_addr() failed for routing-based detection");
982 }
983 }
984 }
985 Ok(None)
986}
987
/// An IPv4 address assigned to a named local network interface.
#[derive(Clone, Debug, PartialEq, Eq)]
struct InterfaceIpv4 {
    /// Interface name as reported by `if_addrs::get_if_addrs`.
    name: String,
    /// The interface's IPv4 address.
    addr: std::net::Ipv4Addr,
}
993
994fn collect_ipv4_interfaces() -> anyhow::Result<Vec<InterfaceIpv4>> {
995 use if_addrs::get_if_addrs;
996 let mut interfaces = Vec::new();
997 for iface in get_if_addrs()? {
998 if let std::net::IpAddr::V4(ipv4) = iface.addr.ip() {
999 interfaces.push(InterfaceIpv4 {
1000 name: iface.name,
1001 addr: ipv4,
1002 });
1003 }
1004 }
1005 Ok(interfaces)
1006}
1007
1008fn choose_best_ipv4(interfaces: &[InterfaceIpv4]) -> Option<std::net::Ipv4Addr> {
1009 interfaces
1010 .iter()
1011 .filter(|iface| !iface.addr.is_unspecified())
1012 .min_by_key(|iface| interface_priority(&iface.name, &iface.addr))
1013 .map(|iface| iface.addr)
1014}
1015
1016fn interface_priority(
1017 name: &str,
1018 addr: &std::net::Ipv4Addr,
1019) -> (InterfaceCategory, u8, u8, std::net::Ipv4Addr) {
1020 (
1021 classify_interface(name, addr),
1022 if addr.is_link_local() { 1 } else { 0 },
1023 if addr.is_private() { 1 } else { 0 },
1024 *addr,
1025 )
1026}
1027
/// Coarse interface ranking used when choosing an address; lower is better.
///
/// Declaration order matters: the derived `Ord` compares variants in the order
/// they are declared, so `Preferred < Normal < Virtual < Loopback`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum InterfaceCategory {
    /// Name matches a common physical-adapter prefix.
    Preferred = 0,
    /// Recognized as neither physical nor virtual.
    Normal = 1,
    /// Name looks like a bridge, container, VM, VPN/tunnel or similar interface.
    Virtual = 2,
    /// Address is a loopback address.
    Loopback = 3,
}
1035
1036fn classify_interface(name: &str, addr: &std::net::Ipv4Addr) -> InterfaceCategory {
1037 if addr.is_loopback() {
1038 return InterfaceCategory::Loopback;
1039 }
1040 let normalized = normalize_interface_name(name);
1041 if is_virtual_interface(&normalized) {
1042 return InterfaceCategory::Virtual;
1043 }
1044 if is_preferred_physical_interface(&normalized) {
1045 return InterfaceCategory::Preferred;
1046 }
1047 InterfaceCategory::Normal
1048}
1049
/// Lowercases an interface name and drops every character that is not ASCII
/// alphanumeric (e.g. `"Wi-Fi 2"` becomes `"wifi2"`).
fn normalize_interface_name(original: &str) -> String {
    original
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|ch| ch.to_ascii_lowercase())
        .collect()
}
1059
/// Returns `true` if a normalized interface name (lowercase ASCII
/// alphanumerics, as produced by `normalize_interface_name`) looks like a
/// virtual or software interface. This covers bridges, containers, VM
/// networks, VPNs/tunnels, loopback names, and anything containing "virtual".
fn is_virtual_interface(name: &str) -> bool {
    const VIRTUAL_PREFIXES: &[&str] = &[
        "br",
        "docker",
        "veth",
        "virbr",
        "vmnet",
        "wg",
        "tailscale",
        "zt",
        "zerotier",
        "tap",
        "tun",
        "utun",
        "ham",
        "vpn",
        "lxc",
    ];
    // Loopback names are matched precisely (`lo`, `lo<digits>`, `loopback*`).
    // A bare "lo" prefix would also catch physical adapters such as
    // "Local Area Connection" (normalized: "localareaconnection").
    let is_loopback_name = name.starts_with("loopback")
        || matches!(
            name.strip_prefix("lo"),
            Some(rest) if rest.bytes().all(|b| b.is_ascii_digit())
        );
    is_loopback_name
        || VIRTUAL_PREFIXES
            .iter()
            .any(|prefix| name.starts_with(prefix))
        || name.contains("virtual")
}
1084
/// Returns `true` if a normalized interface name starts with a common
/// physical-adapter prefix (`eth`, `en*`, `em`, `wl*`, `ww`, `ethernet`, `lan`,
/// `wifi`). Some entries are subsumed by shorter ones and are listed for clarity.
fn is_preferred_physical_interface(name: &str) -> bool {
    const PHYSICAL_PREFIXES: &[&str] = &[
        "en", "eth", "em", "eno", "ens", "enp", "wl", "ww", "wlan", "ethernet", "lan", "wifi",
    ];
    for prefix in PHYSICAL_PREFIXES {
        if name.starts_with(prefix) {
            return true;
        }
    }
    false
}
1093
1094#[instrument]
1096pub fn get_random_server_name() -> String {
1097 rand::random_iter::<u8>()
1098 .filter(|b| b.is_ascii_alphanumeric())
1099 .take(20)
1100 .map(char::from)
1101 .collect()
1102}
1103
1104#[instrument(skip(config))]
1112pub async fn create_tcp_control_listener(
1113 config: &TcpConfig,
1114 bind_ip: Option<&str>,
1115) -> anyhow::Result<tokio::net::TcpListener> {
1116 let bind_addr = if let Some(ip_str) = bind_ip {
1117 let ip = ip_str
1118 .parse::<std::net::IpAddr>()
1119 .with_context(|| format!("invalid IP address: {}", ip_str))?;
1120 std::net::SocketAddr::new(ip, 0)
1121 } else {
1122 std::net::SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0)
1123 };
1124 let listener = if let Some(ranges_str) = config.port_ranges.as_deref() {
1125 let ranges = port_ranges::PortRanges::parse(ranges_str)?;
1126 ranges.bind_tcp_listener(bind_addr.ip()).await?
1127 } else {
1128 tokio::net::TcpListener::bind(bind_addr).await?
1129 };
1130 let local_addr = listener.local_addr()?;
1131 tracing::info!("TCP control listener bound to {}", local_addr);
1132 Ok(listener)
1133}
1134
1135#[instrument(skip(config))]
1139pub async fn create_tcp_data_listener(
1140 config: &TcpConfig,
1141 bind_ip: Option<&str>,
1142) -> anyhow::Result<tokio::net::TcpListener> {
1143 let bind_addr = if let Some(ip_str) = bind_ip {
1144 let ip = ip_str
1145 .parse::<std::net::IpAddr>()
1146 .with_context(|| format!("invalid IP address: {}", ip_str))?;
1147 std::net::SocketAddr::new(ip, 0)
1148 } else {
1149 std::net::SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0)
1150 };
1151 let listener = if let Some(ranges_str) = config.port_ranges.as_deref() {
1152 let ranges = port_ranges::PortRanges::parse(ranges_str)?;
1153 ranges.bind_tcp_listener(bind_addr.ip()).await?
1154 } else {
1155 tokio::net::TcpListener::bind(bind_addr).await?
1156 };
1157 let local_addr = listener.local_addr()?;
1158 tracing::info!("TCP data listener bound to {}", local_addr);
1159 Ok(listener)
1160}
1161
1162pub fn get_tcp_listener_addr(
1166 listener: &tokio::net::TcpListener,
1167 bind_ip: Option<&str>,
1168) -> anyhow::Result<std::net::SocketAddr> {
1169 let local_addr = listener.local_addr()?;
1170 if local_addr.ip().is_unspecified() {
1171 let local_ip = get_local_ip(bind_ip).context("failed to get local IP address")?;
1172 Ok(std::net::SocketAddr::new(local_ip, local_addr.port()))
1173 } else {
1174 Ok(local_addr)
1175 }
1176}
1177
1178#[instrument]
1180pub async fn connect_tcp_control(
1181 addr: std::net::SocketAddr,
1182 timeout_sec: u64,
1183) -> anyhow::Result<tokio::net::TcpStream> {
1184 let stream = tokio::time::timeout(
1185 std::time::Duration::from_secs(timeout_sec),
1186 tokio::net::TcpStream::connect(addr),
1187 )
1188 .await
1189 .with_context(|| format!("connection to {} timed out after {}s", addr, timeout_sec))?
1190 .with_context(|| format!("failed to connect to {}", addr))?;
1191 stream.set_nodelay(true)?;
1192 tracing::debug!("connected to TCP control server at {}", addr);
1193 Ok(stream)
1194}
1195
1196pub fn configure_tcp_buffers(stream: &tokio::net::TcpStream, profile: NetworkProfile) {
1200 use socket2::SockRef;
1201 let (send_buf, recv_buf) = match profile {
1202 NetworkProfile::Datacenter => (16 * 1024 * 1024, 16 * 1024 * 1024),
1203 NetworkProfile::Internet => (2 * 1024 * 1024, 2 * 1024 * 1024),
1204 };
1205 let sock_ref = SockRef::from(stream);
1206 if let Err(err) = sock_ref.set_send_buffer_size(send_buf) {
1207 tracing::warn!("failed to set TCP send buffer size: {err:#}");
1208 }
1209 if let Err(err) = sock_ref.set_recv_buffer_size(recv_buf) {
1210 tracing::warn!("failed to set TCP receive buffer size: {err:#}");
1211 }
1212 if let (Ok(send), Ok(recv)) = (sock_ref.send_buffer_size(), sock_ref.recv_buffer_size()) {
1213 tracing::debug!(
1214 "TCP socket buffer sizes: send={} recv={}",
1215 bytesize::ByteSize(send as u64),
1216 bytesize::ByteSize(recv as u64),
1217 );
1218 }
1219}
1220
1221#[cfg(test)]
1222mod tests {
1223 use super::*;
1224 use std::collections::HashMap;
1225 use std::path::PathBuf;
1226 use std::sync::Mutex;
1227
1228 struct MockDiscoverySession {
1229 test_responses: HashMap<String, bool>,
1230 which_response: Option<String>,
1231 home_response: Result<String, String>,
1232 calls: Mutex<Vec<String>>,
1233 }
1234
1235 impl Default for MockDiscoverySession {
1236 fn default() -> Self {
1237 Self {
1238 test_responses: HashMap::new(),
1239 which_response: None,
1240 home_response: Err("HOME not set".to_string()),
1241 calls: Mutex::new(Vec::new()),
1242 }
1243 }
1244 }
1245
1246 impl MockDiscoverySession {
1247 fn new() -> Self {
1248 Self::default()
1249 }
1250
1251 fn with_home(mut self, home: Option<&str>) -> Self {
1252 self.home_response = match home {
1253 Some(home) => Ok(home.to_string()),
1254 None => Err("HOME not set".to_string()),
1255 };
1256 self
1257 }
1258 fn with_which(mut self, path: Option<&str>) -> Self {
1259 self.which_response = path.map(|p| p.to_string());
1260 self
1261 }
1262 fn set_test_response(&mut self, path: &str, exists: bool) {
1263 self.test_responses.insert(path.to_string(), exists);
1264 }
1265 fn calls(&self) -> Vec<String> {
1266 self.calls.lock().unwrap().clone()
1267 }
1268 }
1269
1270 impl DiscoverySession for MockDiscoverySession {
1271 fn test_executable<'a>(
1272 &'a self,
1273 path: &'a str,
1274 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<bool>> + Send + 'a>>
1275 {
1276 self.calls.lock().unwrap().push(format!("test:{}", path));
1277 let exists = self.test_responses.get(path).copied().unwrap_or(false);
1278 Box::pin(async move { Ok(exists) })
1279 }
1280 fn which<'a>(
1281 &'a self,
1282 binary: &'a str,
1283 ) -> std::pin::Pin<
1284 Box<dyn std::future::Future<Output = anyhow::Result<Option<String>>> + Send + 'a>,
1285 > {
1286 self.calls.lock().unwrap().push(format!("which:{}", binary));
1287 let result = self.which_response.clone();
1288 Box::pin(async move { Ok(result) })
1289 }
1290 fn remote_home<'a>(
1291 &'a self,
1292 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'a>>
1293 {
1294 self.calls.lock().unwrap().push("home".to_string());
1295 let result = self.home_response.clone();
1296 Box::pin(async move {
1297 match result {
1298 Ok(home) => Ok(home),
1299 Err(e) => Err(anyhow::anyhow!(e)),
1300 }
1301 })
1302 }
1303 }
1304
1305 #[tokio::test]
1306 async fn discover_rcpd_prefers_explicit_path() {
1307 let mut session = MockDiscoverySession::new();
1308 session.set_test_response("/opt/rcpd", true);
1309 let path = discover_rcpd_path_internal(&session, Some("/opt/rcpd"), None)
1310 .await
1311 .expect("should return explicit path");
1312 assert_eq!(path, "/opt/rcpd");
1313 assert_eq!(session.calls(), vec!["test:/opt/rcpd"]);
1314 }
1315
1316 #[tokio::test]
1317 async fn discover_rcpd_explicit_path_errors_without_fallbacks() {
1318 let session = MockDiscoverySession::new();
1319 let err = discover_rcpd_path_internal(&session, Some("/missing/rcpd"), None)
1320 .await
1321 .expect_err("should fail when explicit path is missing");
1322 assert!(
1323 err.to_string()
1324 .contains("rcpd binary not found or not executable"),
1325 "unexpected error: {err}"
1326 );
1327 assert_eq!(session.calls(), vec!["test:/missing/rcpd"]);
1328 }
1329
1330 #[tokio::test]
1331 async fn discover_rcpd_uses_same_dir_first() {
1332 let mut session = MockDiscoverySession::new();
1333 session.set_test_response("/custom/bin/rcpd", true);
1334 let path =
1335 discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
1336 .await
1337 .expect("should find in same directory");
1338 assert_eq!(path, "/custom/bin/rcpd");
1339 assert_eq!(session.calls(), vec!["test:/custom/bin/rcpd"]);
1340 }
1341
1342 #[tokio::test]
1343 async fn discover_rcpd_falls_back_to_path_after_same_dir() {
1344 let mut session = MockDiscoverySession::new().with_which(Some("/usr/bin/rcpd"));
1345 session.set_test_response("/custom/bin/rcpd", false);
1346 let path =
1347 discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
1348 .await
1349 .expect("should find in PATH after same dir miss");
1350 assert_eq!(path, "/usr/bin/rcpd");
1351 assert_eq!(session.calls(), vec!["test:/custom/bin/rcpd", "which:rcpd"]);
1352 }
1353
1354 #[tokio::test]
1355 async fn discover_rcpd_uses_cache_last() {
1356 let mut session = MockDiscoverySession::new()
1357 .with_home(Some("/home/rcp"))
1358 .with_which(None);
1359 session.set_test_response("/custom/bin/rcpd", false);
1360 let local_version = common::version::ProtocolVersion::current();
1361 let cache_path = format!("/home/rcp/.cache/rcp/bin/rcpd-{}", local_version.semantic);
1362 session.set_test_response(&cache_path, true);
1363 let path =
1364 discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
1365 .await
1366 .expect("should fall back to cache");
1367 assert_eq!(path, cache_path);
1368 assert_eq!(
1369 session.calls(),
1370 vec![
1371 "test:/custom/bin/rcpd".to_string(),
1372 "which:rcpd".to_string(),
1373 "home".to_string(),
1374 format!("test:{cache_path}")
1375 ]
1376 );
1377 }
1378
1379 #[tokio::test]
1380 async fn discover_rcpd_reports_home_missing_in_error() {
1381 let mut session = MockDiscoverySession::new().with_which(None);
1382 session.set_test_response("/custom/bin/rcpd", false);
1383 let err =
1384 discover_rcpd_path_internal(&session, None, Some(PathBuf::from("/custom/bin/rcp")))
1385 .await
1386 .expect_err("should fail when nothing is found");
1387 let msg = err.to_string();
1388 assert!(
1389 msg.contains("Deployed cache: (skipped, HOME not set)"),
1390 "expected searched list to mention skipped cache, got: {msg}"
1391 );
1392 assert_eq!(
1393 session.calls(),
1394 vec!["test:/custom/bin/rcpd", "which:rcpd", "home"]
1395 );
1396 }
1397
1398 #[test]
1407 fn test_tokio_unstable_enabled() {
1408 #[cfg(not(tokio_unstable))]
1410 {
1411 panic!(
1412 "tokio_unstable cfg flag is not enabled! \
1413 This is required for console-subscriber support. \
1414 Check .cargo/config.toml"
1415 );
1416 }
1417
1418 #[cfg(tokio_unstable)]
1420 {
1421 let _join_set: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
1424 }
1425 }
1426
1427 fn iface(name: &str, addr: [u8; 4]) -> InterfaceIpv4 {
1428 InterfaceIpv4 {
1429 name: name.to_string(),
1430 addr: std::net::Ipv4Addr::new(addr[0], addr[1], addr[2], addr[3]),
1431 }
1432 }
1433
1434 #[test]
1435 fn choose_best_ipv4_prefers_physical_interfaces() {
1436 let interfaces = vec![
1437 iface("docker0", [172, 17, 0, 1]),
1438 iface("enp3s0", [192, 168, 1, 44]),
1439 iface("tailscale0", [100, 115, 92, 5]),
1440 ];
1441 assert_eq!(
1442 choose_best_ipv4(&interfaces),
1443 Some(std::net::Ipv4Addr::new(192, 168, 1, 44))
1444 );
1445 }
1446
1447 #[test]
1448 fn choose_best_ipv4_deprioritizes_link_local() {
1449 let interfaces = vec![
1450 iface("enp0s8", [169, 254, 10, 2]),
1451 iface("wlan0", [10, 0, 0, 23]),
1452 ];
1453 assert_eq!(
1454 choose_best_ipv4(&interfaces),
1455 Some(std::net::Ipv4Addr::new(10, 0, 0, 23))
1456 );
1457 }
1458
1459 #[test]
1460 fn choose_best_ipv4_falls_back_to_loopback() {
1461 let interfaces = vec![iface("lo", [127, 0, 0, 1]), iface("docker0", [0, 0, 0, 0])];
1462 assert_eq!(
1463 choose_best_ipv4(&interfaces),
1464 Some(std::net::Ipv4Addr::new(127, 0, 0, 1))
1465 );
1466 }
1467
1468 #[test]
1469 fn test_get_local_ip_with_explicit_ipv4() {
1470 let result = get_local_ip(Some("192.168.1.100"));
1472 assert!(result.is_ok(), "should accept valid IPv4 address");
1473 let ip = result.unwrap();
1474 assert_eq!(
1475 ip,
1476 std::net::IpAddr::V4(std::net::Ipv4Addr::new(192, 168, 1, 100))
1477 );
1478 }
1479
1480 #[test]
1481 fn test_get_local_ip_with_explicit_loopback() {
1482 let result = get_local_ip(Some("127.0.0.1"));
1484 assert!(result.is_ok(), "should accept loopback address");
1485 let ip = result.unwrap();
1486 assert_eq!(
1487 ip,
1488 std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))
1489 );
1490 }
1491
1492 #[test]
1493 fn test_get_local_ip_rejects_ipv6() {
1494 let result = get_local_ip(Some("::1"));
1496 assert!(result.is_err(), "should reject IPv6 address");
1497 let err = result.unwrap_err();
1498 let err_msg = format!("{err:#}");
1499 assert!(
1500 err_msg.contains("IPv6 address not supported"),
1501 "error should mention IPv6 not supported, got: {err_msg}"
1502 );
1503 assert!(
1504 err_msg.contains("0.0.0.0"),
1505 "error should mention IPv4-only binding, got: {err_msg}"
1506 );
1507 }
1508
1509 #[test]
1510 fn test_get_local_ip_rejects_ipv6_full() {
1511 let result = get_local_ip(Some("2001:db8::1"));
1513 assert!(result.is_err(), "should reject IPv6 address");
1514 let err = result.unwrap_err();
1515 let err_msg = format!("{err:#}");
1516 assert!(
1517 err_msg.contains("IPv6 address not supported"),
1518 "error should mention IPv6 not supported, got: {err_msg}"
1519 );
1520 }
1521
1522 #[test]
1523 fn test_get_local_ip_rejects_invalid_ip() {
1524 let result = get_local_ip(Some("not-an-ip"));
1526 assert!(result.is_err(), "should reject invalid IP format");
1527 let err = result.unwrap_err();
1528 let err_msg = format!("{err:#}");
1529 assert!(
1530 err_msg.contains("invalid IP address"),
1531 "error should mention invalid IP address, got: {err_msg}"
1532 );
1533 }
1534
1535 #[test]
1536 fn test_get_local_ip_rejects_invalid_ipv4() {
1537 let result = get_local_ip(Some("999.999.999.999"));
1539 assert!(result.is_err(), "should reject invalid IPv4 address");
1540 let err = result.unwrap_err();
1541 let err_msg = format!("{err:#}");
1542 assert!(
1543 err_msg.contains("invalid IP address"),
1544 "error should mention invalid IP address, got: {err_msg}"
1545 );
1546 }
1547}