// rrq_config/tcp_socket.rs

1use std::net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs};
2
3use anyhow::{Context, Result};
4
/// Parsed and validated TCP socket specification for a runner.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpSocketSpec {
    /// IP address of the socket.
    pub host: IpAddr,
    /// Port of the socket.
    pub port: u16,
}

impl TcpSocketSpec {
    /// Pairs this spec's host with the supplied `port` (used for pool port allocation).
    ///
    /// The `port` argument is used as given; the spec's own `port` field is not read.
    #[must_use]
    pub fn addr(&self, port: u16) -> SocketAddr {
        SocketAddr::from((self.host, port))
    }
}
19
20/// Parses and validates a tcp_socket string (e.g., "127.0.0.1:9000").
21///
22/// Validates:
23/// - Format is `host:port` or `[ipv6]:port`
24/// - Host is localhost/loopback only (127.0.0.1, ::1, localhost)
25/// - Port is > 0 and <= 65535
26pub fn parse_tcp_socket(raw: &str) -> Result<TcpSocketSpec> {
27    parse_tcp_socket_with_allowed_hosts(raw, &[])
28}
29
30/// Parses and validates a tcp_socket string with an explicit non-loopback allowlist.
31///
32/// By default RRQ requires loopback hosts only. When `allowed_hosts` contains the
33/// tcp host exactly (case-insensitive), non-loopback hosts are allowed.
34pub fn parse_tcp_socket_with_allowed_hosts(
35    raw: &str,
36    allowed_hosts: &[String],
37) -> Result<TcpSocketSpec> {
38    let raw = raw.trim();
39    if raw.is_empty() {
40        return Err(anyhow::anyhow!("tcp_socket cannot be empty"));
41    }
42
43    let (host, port_str) = if let Some(rest) = raw.strip_prefix('[') {
44        // IPv6 format: [::1]:port
45        let (host, port_str) = rest
46            .split_once("]:")
47            .ok_or_else(|| anyhow::anyhow!("tcp_socket must be in [host]:port format"))?;
48        (host, port_str)
49    } else {
50        // IPv4 or hostname format: host:port
51        let (host, port_str) = raw
52            .rsplit_once(':')
53            .ok_or_else(|| anyhow::anyhow!("tcp_socket must be in host:port format"))?;
54        if host.is_empty() {
55            return Err(anyhow::anyhow!("tcp_socket host cannot be empty"));
56        }
57        (host, port_str)
58    };
59
60    let port: u16 = port_str
61        .parse()
62        .with_context(|| format!("invalid tcp_socket port '{port_str}' - must be 1-65535"))?;
63    if port == 0 {
64        return Err(anyhow::anyhow!("tcp_socket port must be > 0"));
65    }
66
67    let is_allowed_host = allowed_hosts
68        .iter()
69        .any(|allowed| allowed.eq_ignore_ascii_case(host));
70
71    let ip = if host == "localhost" {
72        IpAddr::V4(Ipv4Addr::LOCALHOST)
73    } else {
74        match host.parse::<IpAddr>() {
75            Ok(parsed) => {
76                if !parsed.is_loopback() && !is_allowed_host {
77                    return Err(anyhow::anyhow!(
78                        "tcp_socket host must be loopback (127.0.0.1, ::1, or localhost) \
79for security, or explicitly listed in allowed_hosts - got '{host}'"
80                    ));
81                }
82                parsed
83            }
84            Err(_) => {
85                if !is_allowed_host {
86                    return Err(anyhow::anyhow!(
87                        "tcp_socket hostname '{host}' is not allowed; add it to allowed_hosts"
88                    ));
89                }
90                let resolved = (host, port)
91                    .to_socket_addrs()
92                    .with_context(|| format!("invalid tcp_socket host '{host}'"))?
93                    .next()
94                    .ok_or_else(|| {
95                        anyhow::anyhow!("invalid tcp_socket host '{host}' - no addresses resolved")
96                    })?;
97                resolved.ip()
98            }
99        }
100    };
101
102    Ok(TcpSocketSpec { host: ip, port })
103}
104
/// Unit tests for tcp_socket parsing. None of these cases perform a hostname
/// lookup: every host used is an IP literal, `localhost`, or a name that is
/// rejected before resolution.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_tcp_socket_ipv4_localhost() {
        let spec = parse_tcp_socket("127.0.0.1:9000").unwrap();
        assert_eq!(spec.host, IpAddr::V4(Ipv4Addr::LOCALHOST));
        assert_eq!(spec.port, 9000);
    }

    #[test]
    fn parse_tcp_socket_localhost_hostname() {
        let spec = parse_tcp_socket("localhost:1234").unwrap();
        assert_eq!(spec.host, IpAddr::V4(Ipv4Addr::LOCALHOST));
        assert_eq!(spec.port, 1234);
    }

    #[test]
    fn parse_tcp_socket_ipv6_loopback() {
        let spec = parse_tcp_socket("[::1]:8080").unwrap();
        assert!(spec.host.is_loopback());
        assert_eq!(spec.port, 8080);
    }

    #[test]
    fn parse_tcp_socket_trims_surrounding_whitespace() {
        let spec = parse_tcp_socket("  127.0.0.1:9000\n").unwrap();
        assert_eq!(spec.host, IpAddr::V4(Ipv4Addr::LOCALHOST));
        assert_eq!(spec.port, 9000);
    }

    #[test]
    fn parse_tcp_socket_rejects_non_loopback() {
        let err = parse_tcp_socket("10.0.0.1:1234").unwrap_err();
        assert!(err.to_string().contains("loopback"));
    }

    #[test]
    fn parse_tcp_socket_rejects_non_loopback_ipv6() {
        let err = parse_tcp_socket("[2001:db8::1]:8080").unwrap_err();
        assert!(err.to_string().contains("loopback"));
    }

    #[test]
    fn parse_tcp_socket_allows_non_loopback_ip_when_allowlisted() {
        let spec = parse_tcp_socket_with_allowed_hosts("10.0.0.1:1234", &["10.0.0.1".to_string()])
            .unwrap();
        assert_eq!(spec.host, IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
        assert_eq!(spec.port, 1234);
    }

    #[test]
    fn parse_tcp_socket_allowlist_match_ignores_ascii_case() {
        // The allowlist entry differs from the host only in hex-digit case.
        let spec = parse_tcp_socket_with_allowed_hosts(
            "[2001:db8::1]:8080",
            &["2001:DB8::1".to_string()],
        )
        .unwrap();
        assert!(!spec.host.is_loopback());
        assert_eq!(spec.port, 8080);
    }

    #[test]
    fn parse_tcp_socket_rejects_hostname_without_allowlist() {
        let err = parse_tcp_socket("docker-runner:9000").unwrap_err();
        assert!(err.to_string().contains("allowed_hosts"));
    }

    #[test]
    fn parse_tcp_socket_rejects_zero_port() {
        let err = parse_tcp_socket("127.0.0.1:0").unwrap_err();
        assert!(err.to_string().contains("port must be > 0"));
    }

    #[test]
    fn parse_tcp_socket_rejects_port_out_of_range() {
        let err = parse_tcp_socket("127.0.0.1:65536").unwrap_err();
        assert!(err.to_string().contains("1-65535"));
    }

    #[test]
    fn parse_tcp_socket_rejects_empty() {
        let err = parse_tcp_socket("").unwrap_err();
        assert!(err.to_string().contains("cannot be empty"));
    }

    #[test]
    fn parse_tcp_socket_rejects_empty_host() {
        let err = parse_tcp_socket(":9000").unwrap_err();
        assert!(err.to_string().contains("host cannot be empty"));
    }

    #[test]
    fn parse_tcp_socket_rejects_missing_port() {
        let err = parse_tcp_socket("127.0.0.1").unwrap_err();
        assert!(err.to_string().contains("host:port"));
    }

    #[test]
    fn parse_tcp_socket_rejects_malformed_bracket_form() {
        for raw in ["[::1", "[::1]"] {
            let err = parse_tcp_socket(raw).unwrap_err();
            assert!(err.to_string().contains("[host]:port"));
        }
    }

    #[test]
    fn parse_tcp_socket_rejects_invalid_port() {
        let err = parse_tcp_socket("127.0.0.1:abc").unwrap_err();
        assert!(err.to_string().contains("port"));
    }

    #[test]
    fn addr_uses_supplied_port_not_spec_port() {
        let spec = parse_tcp_socket("127.0.0.1:9000").unwrap();
        assert_eq!(
            spec.addr(9001),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9001)
        );
    }
}