pingora_core/protocols/l4/
socket.rs

1// Copyright 2025 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Generic socket type
16
17use crate::{Error, OrErr};
18use log::warn;
19#[cfg(unix)]
20use nix::sys::socket::{getpeername, getsockname, SockaddrStorage};
21use std::cmp::Ordering;
22use std::hash::{Hash, Hasher};
23use std::net::SocketAddr as StdSockAddr;
24#[cfg(unix)]
25use std::os::unix::net::SocketAddr as StdUnixSockAddr;
26#[cfg(unix)]
27use tokio::net::unix::SocketAddr as TokioUnixSockAddr;
28
/// [`SocketAddr`] is a storage type that contains either an Internet (IP address)
/// socket address or a Unix domain socket address.
#[derive(Debug, Clone)]
pub enum SocketAddr {
    /// An Internet (IPv4 or IPv6) socket address.
    Inet(StdSockAddr),
    /// A Unix domain socket address (only available on Unix platforms).
    #[cfg(unix)]
    Unix(StdUnixSockAddr),
}
37
38impl SocketAddr {
39    /// Get a reference to the IP socket if it is one
40    pub fn as_inet(&self) -> Option<&StdSockAddr> {
41        if let SocketAddr::Inet(addr) = self {
42            Some(addr)
43        } else {
44            None
45        }
46    }
47
48    /// Get a reference to the Unix domain socket if it is one
49    #[cfg(unix)]
50    pub fn as_unix(&self) -> Option<&StdUnixSockAddr> {
51        if let SocketAddr::Unix(addr) = self {
52            Some(addr)
53        } else {
54            None
55        }
56    }
57
58    /// Set the port if the address is an IP socket.
59    pub fn set_port(&mut self, port: u16) {
60        if let SocketAddr::Inet(addr) = self {
61            addr.set_port(port)
62        }
63    }
64
65    #[cfg(unix)]
66    fn from_sockaddr_storage(sock: &SockaddrStorage) -> Option<SocketAddr> {
67        if let Some(v4) = sock.as_sockaddr_in() {
68            return Some(SocketAddr::Inet(StdSockAddr::V4(
69                std::net::SocketAddrV4::new(v4.ip().into(), v4.port()),
70            )));
71        } else if let Some(v6) = sock.as_sockaddr_in6() {
72            return Some(SocketAddr::Inet(StdSockAddr::V6(
73                std::net::SocketAddrV6::new(v6.ip(), v6.port(), v6.flowinfo(), v6.scope_id()),
74            )));
75        }
76
77        // TODO: don't set abstract / unnamed for now,
78        // for parity with how we treat these types in TryFrom<TokioUnixSockAddr>
79        Some(SocketAddr::Unix(
80            sock.as_unix_addr()
81                .map(|addr| addr.path().map(StdUnixSockAddr::from_pathname))??
82                .ok()?,
83        ))
84    }
85
86    #[cfg(unix)]
87    pub fn from_raw_fd(fd: std::os::unix::io::RawFd, peer_addr: bool) -> Option<SocketAddr> {
88        let sockaddr_storage = if peer_addr {
89            getpeername(fd)
90        } else {
91            getsockname(fd)
92        };
93        match sockaddr_storage {
94            Ok(sockaddr) => Self::from_sockaddr_storage(&sockaddr),
95            // could be errors such as EBADF, i.e. fd is no longer a valid socket
96            // fail open in this case
97            Err(_e) => None,
98        }
99    }
100
101    #[cfg(windows)]
102    pub fn from_raw_socket(
103        sock: std::os::windows::io::RawSocket,
104        is_peer_addr: bool,
105    ) -> Option<SocketAddr> {
106        use crate::protocols::windows::{local_addr, peer_addr};
107        if is_peer_addr {
108            peer_addr(sock)
109        } else {
110            local_addr(sock)
111        }
112        .map(|s| s.into())
113        .ok()
114    }
115}
116
117impl std::fmt::Display for SocketAddr {
118    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
119        match self {
120            SocketAddr::Inet(addr) => write!(f, "{addr}"),
121            #[cfg(unix)]
122            SocketAddr::Unix(addr) => {
123                if let Some(path) = addr.as_pathname() {
124                    write!(f, "{}", path.display())
125                } else {
126                    write!(f, "{addr:?}")
127                }
128            }
129        }
130    }
131}
132
133impl Hash for SocketAddr {
134    fn hash<H: Hasher>(&self, state: &mut H) {
135        match self {
136            Self::Inet(sockaddr) => sockaddr.hash(state),
137            #[cfg(unix)]
138            Self::Unix(sockaddr) => {
139                if let Some(path) = sockaddr.as_pathname() {
140                    // use the underlying path as the hash
141                    path.hash(state);
142                } else {
143                    // unnamed or abstract UDS
144                    // abstract UDS name not yet exposed by std API
145                    // panic for now, we can decide on the right way to hash them later
146                    panic!("Unnamed and abstract UDS types not yet supported for hashing")
147                }
148            }
149        }
150    }
151}
152
153impl PartialEq for SocketAddr {
154    fn eq(&self, other: &Self) -> bool {
155        match self {
156            Self::Inet(addr) => Some(addr) == other.as_inet(),
157            #[cfg(unix)]
158            Self::Unix(addr) => {
159                let path = addr.as_pathname();
160                // can only compare UDS with path, assume false on all unnamed UDS
161                path.is_some() && path == other.as_unix().and_then(|addr| addr.as_pathname())
162            }
163        }
164    }
165}
166
167impl PartialOrd for SocketAddr {
168    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
169        Some(self.cmp(other))
170    }
171}
172
173impl Ord for SocketAddr {
174    fn cmp(&self, other: &Self) -> Ordering {
175        match self {
176            Self::Inet(addr) => {
177                if let Some(o) = other.as_inet() {
178                    addr.cmp(o)
179                } else {
180                    // always make Inet < Unix "smallest for variants at the top"
181                    Ordering::Less
182                }
183            }
184            #[cfg(unix)]
185            Self::Unix(addr) => {
186                if let Some(o) = other.as_unix() {
187                    // NOTE: unnamed UDS are consider the same
188                    addr.as_pathname().cmp(&o.as_pathname())
189                } else {
190                    // always make Inet < Unix "smallest for variants at the top"
191                    Ordering::Greater
192                }
193            }
194        }
195    }
196}
197
198impl Eq for SocketAddr {}
199
200impl std::str::FromStr for SocketAddr {
201    type Err = Box<Error>;
202
203    // This is very basic parsing logic, it might treat invalid IP:PORT str as UDS path
204    #[cfg(unix)]
205    fn from_str(s: &str) -> Result<Self, Self::Err> {
206        if s.starts_with("unix:") {
207            // format unix:/tmp/server.socket
208            let path = s.trim_start_matches("unix:");
209            let uds_socket = StdUnixSockAddr::from_pathname(path)
210                .or_err(crate::BindError, "invalid UDS path")?;
211            Ok(SocketAddr::Unix(uds_socket))
212        } else {
213            match StdSockAddr::from_str(s) {
214                Ok(addr) => Ok(SocketAddr::Inet(addr)),
215                Err(_) => {
216                    // Try to parse as UDS for backward compatibility
217                    let uds_socket = StdUnixSockAddr::from_pathname(s)
218                        .or_err(crate::BindError, "invalid UDS path")?;
219                    warn!("Raw Unix domain socket path support will be deprecated, add 'unix:' prefix instead");
220                    Ok(SocketAddr::Unix(uds_socket))
221                }
222            }
223        }
224    }
225
226    #[cfg(windows)]
227    fn from_str(s: &str) -> Result<Self, Self::Err> {
228        let addr = StdSockAddr::from_str(s).or_err(crate::BindError, "invalid socket addr")?;
229        Ok(SocketAddr::Inet(addr))
230    }
231}
232
233impl std::net::ToSocketAddrs for SocketAddr {
234    type Iter = std::iter::Once<StdSockAddr>;
235
236    // Error if UDS addr
237    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
238        if let Some(inet) = self.as_inet() {
239            Ok(std::iter::once(*inet))
240        } else {
241            Err(std::io::Error::new(
242                std::io::ErrorKind::Other,
243                "UDS socket cannot be used as inet socket",
244            ))
245        }
246    }
247}
248
249impl From<StdSockAddr> for SocketAddr {
250    fn from(sockaddr: StdSockAddr) -> Self {
251        SocketAddr::Inet(sockaddr)
252    }
253}
254
255#[cfg(unix)]
256impl From<StdUnixSockAddr> for SocketAddr {
257    fn from(sockaddr: StdUnixSockAddr) -> Self {
258        SocketAddr::Unix(sockaddr)
259    }
260}
261
262// TODO: ideally mio/tokio will start using the std version of the unix `SocketAddr`
263// so we can avoid a fallible conversion
264// https://github.com/tokio-rs/mio/issues/1527
265#[cfg(unix)]
266impl TryFrom<TokioUnixSockAddr> for SocketAddr {
267    type Error = String;
268
269    fn try_from(value: TokioUnixSockAddr) -> Result<Self, Self::Error> {
270        if let Some(Ok(addr)) = value.as_pathname().map(StdUnixSockAddr::from_pathname) {
271            Ok(addr.into())
272        } else {
273            // may be unnamed/abstract UDS
274            Err(format!("could not convert {value:?} to SocketAddr"))
275        }
276    }
277}
278
#[cfg(test)]
mod test {
    use super::*;

    /// A plain `ip:port` string must parse into the Internet variant.
    #[test]
    fn parse_ip() {
        let parsed = "127.0.0.1:80".parse::<SocketAddr>().unwrap();
        assert!(matches!(parsed, SocketAddr::Inet(_)));
    }

    /// A bare path (legacy form, no `unix:` prefix) must parse into the Unix variant.
    #[cfg(unix)]
    #[test]
    fn parse_uds() {
        let parsed = "/tmp/my.sock".parse::<SocketAddr>().unwrap();
        assert!(matches!(parsed, SocketAddr::Unix(_)));
    }

    /// A `unix:`-prefixed path must parse into the Unix variant.
    #[cfg(unix)]
    #[test]
    fn parse_uds_with_prefix() {
        let parsed = "unix:/tmp/my.sock".parse::<SocketAddr>().unwrap();
        assert!(matches!(parsed, SocketAddr::Unix(_)));
    }
}