// koi_common/net.rs

//! Networking utilities: Happy Eyeballs-style endpoint resolution.
//!
//! When a URL's host is `localhost`, the OS may resolve it to both `[::1]` (IPv6)
//! and `127.0.0.1` (IPv4). If the server listens on only one of them, a client
//! that tries the other address first can stall for ~2 s on every request before
//! falling back. This module races a TCP connect to both addresses in parallel and
//! rewrites the URL to use whichever accepts first, eliminating that per-request
//! penalty.

9use std::net::{SocketAddr, TcpStream};
10use std::time::Duration;
11
12use tracing::{debug, trace};

/// Per-probe connect budget for the IPv4-vs-IPv6 race: 300 ms.
///
/// Kept short because only loopback addresses (`127.0.0.1` / `[::1]`) are probed.
const RACE_TIMEOUT: Duration = Duration::new(0, 300_000_000);

17/// If `endpoint` contains `localhost`, race IPv4 vs IPv6 TCP connects and
18/// return a rewritten URL using the winning literal address.  If both fail (or
19/// the URL doesn't use `localhost`), the original endpoint is returned unchanged.
20///
21/// # Examples
22/// ```
23/// use koi_common::net::resolve_localhost;
24///
25/// // Non-localhost URLs pass through unchanged.
26/// assert_eq!(
27///     resolve_localhost("http://192.168.1.5:5641"),
28///     "http://192.168.1.5:5641"
29/// );
30/// ```
31pub fn resolve_localhost(endpoint: &str) -> String {
32    // Only race when the host is literally "localhost".
33    let lower = endpoint.to_ascii_lowercase();
34    if !lower.contains("://localhost:") && !lower.ends_with("://localhost") {
35        return endpoint.to_string();
36    }
37
38    // Extract port - default to 80 if not present.
39    let port = extract_port(endpoint).unwrap_or(80);
40
41    let v4_addr: SocketAddr = ([127, 0, 0, 1], port).into();
42    let v6_addr: SocketAddr = ([0, 0, 0, 0, 0, 0, 0, 1], port).into();
43
44    debug!(port, "racing IPv4 vs IPv6 on localhost:{port}");
45
46    // Spawn two threads; first successful connect wins.
47    let (tx, rx) = std::sync::mpsc::channel::<&str>();
48
49    let tx4 = tx.clone();
50    std::thread::spawn(move || {
51        trace!(%v4_addr, "probing IPv4");
52        if TcpStream::connect_timeout(&v4_addr, RACE_TIMEOUT).is_ok() {
53            let _ = tx4.send("127.0.0.1");
54        }
55    });
56
57    let tx6 = tx;
58    std::thread::spawn(move || {
59        trace!(%v6_addr, "probing IPv6");
60        if TcpStream::connect_timeout(&v6_addr, RACE_TIMEOUT).is_ok() {
61            let _ = tx6.send("[::1]");
62        }
63    });
64
65    // Wait for the first winner or timeout.
66    match rx.recv_timeout(RACE_TIMEOUT + Duration::from_millis(50)) {
67        Ok(winner) => {
68            let resolved = replace_localhost(endpoint, winner, port);
69            debug!(winner, %resolved, "localhost resolved via Happy Eyeballs");
70            resolved
71        }
72        Err(_) => {
73            debug!("neither IPv4 nor IPv6 responded - keeping original endpoint");
74            endpoint.to_string()
75        }
76    }
77}
78
/// Extract the explicit port number from a URL-like endpoint.
///
/// The port is read from the authority component only. That is everything after
/// the first `://` (or the whole string if there is none), up to the first `/`,
/// `?` or `#`, with any `user:pass@` prefix removed. For bracketed IPv6 hosts the
/// port must follow the closing `]`. Returns `None` when there is no `:port`
/// or the text after the final `:` does not parse as a `u16`.
fn extract_port(endpoint: &str) -> Option<u16> {
    // Strip scheme.
    let after_scheme = endpoint.find("://").map_or(endpoint, |i| &endpoint[i + 3..]);

    // Strip path, query and fragment; `?`/`#` can follow the port directly.
    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(after_scheme);

    // User-info may contain `:`, which must not be mistaken for the port separator.
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host_port)| host_port);

    // Colons inside `[...]` belong to an IPv6 address, not to the port.
    let tail = host_port.rfind(']').map_or(host_port, |i| &host_port[i + 1..]);

    // Requiring a `:` stops a portless numeric host (`http://1234`) being read as a port.
    let (_, port) = tail.rsplit_once(':')?;
    port.parse().ok()
}

/// Replace the `localhost` host in `endpoint` with the winning address literal.
///
/// Only the host that starts immediately after the first `://` (or at the start
/// of a scheme-less endpoint) is considered. It is rewritten only when it is
/// exactly `localhost` in any ASCII case, i.e. followed by the end of the string,
/// `:`, `/`, `?` or `#`. Anything else is returned unchanged, such as
/// `localhost.example.com` or a `localhost` inside the path or query. The port,
/// path, query and fragment are kept exactly as written.
///
/// `winner` is inserted verbatim, so an IPv6 literal must already be bracketed
/// (`[::1]`). The port argument is unused because the endpoint's own port text is
/// preserved. It is kept only for signature stability.
fn replace_localhost(endpoint: &str, winner: &str, _port: u16) -> String {
    const HOST: &str = "localhost";

    let host_start = endpoint.find("://").map_or(0, |i| i + 3);
    let host_end = host_start + HOST.len();

    // `get` rather than indexing: the endpoint may be too short, or `host_end`
    // may land inside a multi-byte character.
    let rest = match endpoint.get(host_start..host_end) {
        Some(host) if host.eq_ignore_ascii_case(HOST) => &endpoint[host_end..],
        _ => return endpoint.to_string(),
    };

    // The host must end right after `localhost`.
    match rest.chars().next() {
        None | Some(':' | '/' | '?' | '#') => {
            let scheme = &endpoint[..host_start];
            format!("{scheme}{winner}{rest}")
        }
        Some(_) => endpoint.to_string(),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // ---- resolve_localhost --------------------------------------------------

    #[test]
    fn non_localhost_passthrough() {
        let endpoint = "http://192.168.1.5:5641";
        let resolved = resolve_localhost(endpoint);
        assert_eq!(resolved, endpoint);
    }

    // ---- extract_port -------------------------------------------------------

    #[test]
    fn extract_port_simple() {
        let port = extract_port("http://localhost:5641");
        assert_eq!(port, Some(5641));
    }

    #[test]
    fn extract_port_with_path() {
        let port = extract_port("http://localhost:8080/foo");
        assert_eq!(port, Some(8080));
    }

    #[test]
    fn extract_port_none() {
        let port = extract_port("http://localhost");
        assert_eq!(port, None);
    }

    // ---- replace_localhost --------------------------------------------------

    /// Rewrites `endpoint` with `winner`, passing the port every case below uses.
    fn rewrite(endpoint: &str, winner: &str) -> String {
        replace_localhost(endpoint, winner, 5641)
    }

    #[test]
    fn replace_localhost_ipv4() {
        let rewritten = rewrite("http://localhost:5641", "127.0.0.1");
        assert_eq!(rewritten, "http://127.0.0.1:5641");
    }

    #[test]
    fn replace_localhost_ipv6() {
        let rewritten = rewrite("http://localhost:5641", "[::1]");
        assert_eq!(rewritten, "http://[::1]:5641");
    }

    #[test]
    fn replace_localhost_with_path() {
        let rewritten = rewrite("http://localhost:5641/v1/foo", "127.0.0.1");
        assert_eq!(rewritten, "http://127.0.0.1:5641/v1/foo");
    }

    #[test]
    fn replace_localhost_case_insensitive() {
        let rewritten = rewrite("http://Localhost:5641", "127.0.0.1");
        assert_eq!(rewritten, "http://127.0.0.1:5641");
    }
}