1use std::net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs};
2
3use anyhow::{Context, Result};
4
/// A TCP endpoint described by an IP address and a port.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpSocketSpec {
    /// IP address of the endpoint.
    pub host: IpAddr,
    /// TCP port of the endpoint.
    pub port: u16,
}

impl TcpSocketSpec {
    /// Builds a [`SocketAddr`] from this spec's `host` and the supplied `port`.
    ///
    /// Note that the `port` argument is what ends up in the returned address;
    /// the `port` field stored on the spec is not consulted here.
    #[must_use]
    pub fn addr(&self, port: u16) -> SocketAddr {
        let ip = self.host;
        SocketAddr::from((ip, port))
    }
}
19
20pub fn parse_tcp_socket(raw: &str) -> Result<TcpSocketSpec> {
27 parse_tcp_socket_with_allowed_hosts(raw, &[])
28}
29
30pub fn parse_tcp_socket_with_allowed_hosts(
35 raw: &str,
36 allowed_hosts: &[String],
37) -> Result<TcpSocketSpec> {
38 let raw = raw.trim();
39 if raw.is_empty() {
40 return Err(anyhow::anyhow!("tcp_socket cannot be empty"));
41 }
42
43 let (host, port_str) = if let Some(rest) = raw.strip_prefix('[') {
44 let (host, port_str) = rest
46 .split_once("]:")
47 .ok_or_else(|| anyhow::anyhow!("tcp_socket must be in [host]:port format"))?;
48 (host, port_str)
49 } else {
50 let (host, port_str) = raw
52 .rsplit_once(':')
53 .ok_or_else(|| anyhow::anyhow!("tcp_socket must be in host:port format"))?;
54 if host.is_empty() {
55 return Err(anyhow::anyhow!("tcp_socket host cannot be empty"));
56 }
57 (host, port_str)
58 };
59
60 let port: u16 = port_str
61 .parse()
62 .with_context(|| format!("invalid tcp_socket port '{port_str}' - must be 1-65535"))?;
63 if port == 0 {
64 return Err(anyhow::anyhow!("tcp_socket port must be > 0"));
65 }
66
67 let is_allowed_host = allowed_hosts
68 .iter()
69 .any(|allowed| allowed.eq_ignore_ascii_case(host));
70
71 let ip = if host == "localhost" {
72 IpAddr::V4(Ipv4Addr::LOCALHOST)
73 } else {
74 match host.parse::<IpAddr>() {
75 Ok(parsed) => {
76 if !parsed.is_loopback() && !is_allowed_host {
77 return Err(anyhow::anyhow!(
78 "tcp_socket host must be loopback (127.0.0.1, ::1, or localhost) \
79for security, or explicitly listed in allowed_hosts - got '{host}'"
80 ));
81 }
82 parsed
83 }
84 Err(_) => {
85 if !is_allowed_host {
86 return Err(anyhow::anyhow!(
87 "tcp_socket hostname '{host}' is not allowed; add it to allowed_hosts"
88 ));
89 }
90 let resolved = (host, port)
91 .to_socket_addrs()
92 .with_context(|| format!("invalid tcp_socket host '{host}'"))?
93 .next()
94 .ok_or_else(|| {
95 anyhow::anyhow!("invalid tcp_socket host '{host}' - no addresses resolved")
96 })?;
97 resolved.ip()
98 }
99 }
100 };
101
102 Ok(TcpSocketSpec { host: ip, port })
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` with the default (empty) allow-list and returns the
    /// rendered error message; panics if parsing unexpectedly succeeds.
    fn rejection_message(raw: &str) -> String {
        parse_tcp_socket(raw).unwrap_err().to_string()
    }

    #[test]
    fn parse_tcp_socket_ipv4_localhost() {
        let TcpSocketSpec { host, port } = parse_tcp_socket("127.0.0.1:9000").unwrap();
        assert_eq!(host, IpAddr::V4(Ipv4Addr::LOCALHOST));
        assert_eq!(port, 9000);
    }

    #[test]
    fn parse_tcp_socket_localhost_hostname() {
        let TcpSocketSpec { host, port } = parse_tcp_socket("localhost:1234").unwrap();
        assert_eq!(host, IpAddr::V4(Ipv4Addr::LOCALHOST));
        assert_eq!(port, 1234);
    }

    #[test]
    fn parse_tcp_socket_ipv6_loopback() {
        let TcpSocketSpec { host, port } = parse_tcp_socket("[::1]:8080").unwrap();
        assert!(host.is_loopback());
        assert_eq!(port, 8080);
    }

    #[test]
    fn parse_tcp_socket_rejects_non_loopback() {
        let message = rejection_message("10.0.0.1:1234");
        assert!(message.contains("loopback"));
    }

    #[test]
    fn parse_tcp_socket_allows_non_loopback_ip_when_allowlisted() {
        let allow_list = ["10.0.0.1".to_string()];
        let TcpSocketSpec { host, port } =
            parse_tcp_socket_with_allowed_hosts("10.0.0.1:1234", &allow_list).unwrap();
        assert_eq!(host, IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
        assert_eq!(port, 1234);
    }

    #[test]
    fn parse_tcp_socket_rejects_hostname_without_allowlist() {
        let message = rejection_message("docker-runner:9000");
        assert!(message.contains("allowed_hosts"));
    }

    #[test]
    fn parse_tcp_socket_rejects_zero_port() {
        let message = rejection_message("127.0.0.1:0");
        assert!(message.contains("port must be > 0"));
    }

    #[test]
    fn parse_tcp_socket_rejects_empty() {
        let message = rejection_message("");
        assert!(message.contains("cannot be empty"));
    }

    #[test]
    fn parse_tcp_socket_rejects_missing_port() {
        let message = rejection_message("127.0.0.1");
        assert!(message.contains("host:port"));
    }

    #[test]
    fn parse_tcp_socket_rejects_invalid_port() {
        let message = rejection_message("127.0.0.1:abc");
        assert!(message.contains("port"));
    }
}