1use std::net::{SocketAddr, TcpStream};
10use std::time::Duration;
11
12use tracing::{debug, trace};
13
14const RACE_TIMEOUT: Duration = Duration::from_millis(300);
16
17pub fn resolve_localhost(endpoint: &str) -> String {
32 let lower = endpoint.to_ascii_lowercase();
34 if !lower.contains("://localhost:") && !lower.ends_with("://localhost") {
35 return endpoint.to_string();
36 }
37
38 let port = extract_port(endpoint).unwrap_or(80);
40
41 let v4_addr: SocketAddr = ([127, 0, 0, 1], port).into();
42 let v6_addr: SocketAddr = ([0, 0, 0, 0, 0, 0, 0, 1], port).into();
43
44 debug!(port, "racing IPv4 vs IPv6 on localhost:{port}");
45
46 let (tx, rx) = std::sync::mpsc::channel::<&str>();
48
49 let tx4 = tx.clone();
50 std::thread::spawn(move || {
51 trace!(%v4_addr, "probing IPv4");
52 if TcpStream::connect_timeout(&v4_addr, RACE_TIMEOUT).is_ok() {
53 let _ = tx4.send("127.0.0.1");
54 }
55 });
56
57 let tx6 = tx;
58 std::thread::spawn(move || {
59 trace!(%v6_addr, "probing IPv6");
60 if TcpStream::connect_timeout(&v6_addr, RACE_TIMEOUT).is_ok() {
61 let _ = tx6.send("[::1]");
62 }
63 });
64
65 match rx.recv_timeout(RACE_TIMEOUT + Duration::from_millis(50)) {
67 Ok(winner) => {
68 let resolved = replace_localhost(endpoint, winner, port);
69 debug!(winner, %resolved, "localhost resolved via Happy Eyeballs");
70 resolved
71 }
72 Err(_) => {
73 debug!("neither IPv4 nor IPv6 responded - keeping original endpoint");
74 endpoint.to_string()
75 }
76 }
77}
78
/// Extracts the explicit TCP port from an endpoint such as `http://host:8080/path`.
///
/// The authority is the text after `://`, or the whole string when there is no
/// scheme. It runs up to the first `/`, `?` or `#`, and any `user[:pass]@` userinfo
/// in it is skipped.
///
/// Returns `None` in three cases:
/// - the host has no `:port` suffix;
/// - the suffix is empty;
/// - the suffix is not a valid `u16`.
fn extract_port(endpoint: &str) -> Option<u16> {
    let after_scheme = endpoint
        .find("://")
        .map_or(endpoint, |i| &endpoint[i + 3..]);

    // The port can be followed directly by a path, a query, or a fragment.
    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(after_scheme);

    // Userinfo may itself contain ':' (user:password), so drop everything up to the last '@'.
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, hp)| hp);

    // `rsplit_once` (rather than `rsplit`) so a host with no ':' yields None instead of
    // having the host text itself parsed as a port.
    let (_, port) = host_port.rsplit_once(':')?;
    port.parse().ok()
}
93
/// Replaces the `localhost` host component of `endpoint` with `winner`.
///
/// How the host is located:
/// - it starts right after `scheme://`, or at the start of a scheme-less `host:port`;
/// - any `user[:pass]@` userinfo is skipped;
/// - it ends at `:` or at the first `/`, `?` or `#`.
///
/// Only an exact `localhost` host is replaced, compared ASCII case-insensitively.
/// Hosts such as `mylocalhost` or `localhost.example.com` are left alone. The
/// surrounding text (scheme, userinfo, port, path, query, fragment) is kept verbatim.
/// If the host is not `localhost`, a copy of `endpoint` is returned unchanged.
///
/// `_port` is not used: any explicit port already written in `endpoint` is preserved
/// as-is. The parameter is kept so existing call sites remain valid.
fn replace_localhost(endpoint: &str, winner: &str, _port: u16) -> String {
    // Byte range of the authority (`[userinfo@]host[:port]`).
    let authority_start = endpoint.find("://").map_or(0, |i| i + 3);
    let authority_end = endpoint[authority_start..]
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .map_or(endpoint.len(), |i| authority_start + i);

    // Skip userinfo so only the host itself is examined.
    let host_start = endpoint[authority_start..authority_end]
        .rfind('@')
        .map_or(authority_start, |i| authority_start + i + 1);
    let host_port = &endpoint[host_start..authority_end];
    let host_end = host_start + host_port.find(':').unwrap_or(host_port.len());

    if endpoint[host_start..host_end].eq_ignore_ascii_case("localhost") {
        format!("{}{winner}{}", &endpoint[..host_start], &endpoint[host_end..])
    } else {
        endpoint.to_string()
    }
}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Rewrites `endpoint` with `winner`, using the port shared by every fixture below.
    fn rewrite(endpoint: &str, winner: &str) -> String {
        replace_localhost(endpoint, winner, 5641)
    }

    #[test]
    fn non_localhost_passthrough() {
        let endpoint = "http://192.168.1.5:5641";
        let resolved = resolve_localhost(endpoint);
        assert_eq!(resolved, endpoint);
    }

    #[test]
    fn extract_port_simple() {
        let port = extract_port("http://localhost:5641");
        assert_eq!(port, Some(5641));
    }

    #[test]
    fn extract_port_with_path() {
        let port = extract_port("http://localhost:8080/foo");
        assert_eq!(port, Some(8080));
    }

    #[test]
    fn extract_port_none() {
        let port = extract_port("http://localhost");
        assert!(port.is_none());
    }

    #[test]
    fn replace_localhost_ipv4() {
        let rewritten = rewrite("http://localhost:5641", "127.0.0.1");
        assert_eq!(rewritten, "http://127.0.0.1:5641");
    }

    #[test]
    fn replace_localhost_ipv6() {
        let rewritten = rewrite("http://localhost:5641", "[::1]");
        assert_eq!(rewritten, "http://[::1]:5641");
    }

    #[test]
    fn replace_localhost_with_path() {
        let rewritten = rewrite("http://localhost:5641/v1/foo", "127.0.0.1");
        assert_eq!(rewritten, "http://127.0.0.1:5641/v1/foo");
    }

    #[test]
    fn replace_localhost_case_insensitive() {
        let rewritten = rewrite("http://Localhost:5641", "127.0.0.1");
        assert_eq!(rewritten, "http://127.0.0.1:5641");
    }
}