pingora_core/protocols/l4/
socket.rs1use crate::{Error, OrErr};
18use log::warn;
19#[cfg(unix)]
20use nix::sys::socket::{getpeername, getsockname, SockaddrStorage};
21use std::cmp::Ordering;
22use std::hash::{Hash, Hasher};
23use std::net::SocketAddr as StdSockAddr;
24#[cfg(unix)]
25use std::os::unix::net::SocketAddr as StdUnixSockAddr;
26#[cfg(unix)]
27use tokio::net::unix::SocketAddr as TokioUnixSockAddr;
28
/// A socket address that is either an Internet (IP) address or, on Unix
/// platforms, a Unix domain socket (UDS) address.
#[derive(Clone, Debug)]
pub enum SocketAddr {
    /// An IPv4 or IPv6 socket address.
    Inet(StdSockAddr),
    /// A Unix domain socket address (Unix platforms only).
    #[cfg(unix)]
    Unix(StdUnixSockAddr),
}
37
38impl SocketAddr {
39 pub fn as_inet(&self) -> Option<&StdSockAddr> {
41 if let SocketAddr::Inet(addr) = self {
42 Some(addr)
43 } else {
44 None
45 }
46 }
47
48 #[cfg(unix)]
50 pub fn as_unix(&self) -> Option<&StdUnixSockAddr> {
51 if let SocketAddr::Unix(addr) = self {
52 Some(addr)
53 } else {
54 None
55 }
56 }
57
58 pub fn set_port(&mut self, port: u16) {
60 if let SocketAddr::Inet(addr) = self {
61 addr.set_port(port)
62 }
63 }
64
65 #[cfg(unix)]
66 fn from_sockaddr_storage(sock: &SockaddrStorage) -> Option<SocketAddr> {
67 if let Some(v4) = sock.as_sockaddr_in() {
68 return Some(SocketAddr::Inet(StdSockAddr::V4(
69 std::net::SocketAddrV4::new(v4.ip().into(), v4.port()),
70 )));
71 } else if let Some(v6) = sock.as_sockaddr_in6() {
72 return Some(SocketAddr::Inet(StdSockAddr::V6(
73 std::net::SocketAddrV6::new(v6.ip(), v6.port(), v6.flowinfo(), v6.scope_id()),
74 )));
75 }
76
77 Some(SocketAddr::Unix(
80 sock.as_unix_addr()
81 .map(|addr| addr.path().map(StdUnixSockAddr::from_pathname))??
82 .ok()?,
83 ))
84 }
85
86 #[cfg(unix)]
87 pub fn from_raw_fd(fd: std::os::unix::io::RawFd, peer_addr: bool) -> Option<SocketAddr> {
88 let sockaddr_storage = if peer_addr {
89 getpeername(fd)
90 } else {
91 getsockname(fd)
92 };
93 match sockaddr_storage {
94 Ok(sockaddr) => Self::from_sockaddr_storage(&sockaddr),
95 Err(_e) => None,
98 }
99 }
100
101 #[cfg(windows)]
102 pub fn from_raw_socket(
103 sock: std::os::windows::io::RawSocket,
104 is_peer_addr: bool,
105 ) -> Option<SocketAddr> {
106 use crate::protocols::windows::{local_addr, peer_addr};
107 if is_peer_addr {
108 peer_addr(sock)
109 } else {
110 local_addr(sock)
111 }
112 .map(|s| s.into())
113 .ok()
114 }
115}
116
117impl std::fmt::Display for SocketAddr {
118 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
119 match self {
120 SocketAddr::Inet(addr) => write!(f, "{addr}"),
121 #[cfg(unix)]
122 SocketAddr::Unix(addr) => {
123 if let Some(path) = addr.as_pathname() {
124 write!(f, "{}", path.display())
125 } else {
126 write!(f, "{addr:?}")
127 }
128 }
129 }
130 }
131}
132
133impl Hash for SocketAddr {
134 fn hash<H: Hasher>(&self, state: &mut H) {
135 match self {
136 Self::Inet(sockaddr) => sockaddr.hash(state),
137 #[cfg(unix)]
138 Self::Unix(sockaddr) => {
139 if let Some(path) = sockaddr.as_pathname() {
140 path.hash(state);
142 } else {
143 panic!("Unnamed and abstract UDS types not yet supported for hashing")
147 }
148 }
149 }
150 }
151}
152
153impl PartialEq for SocketAddr {
154 fn eq(&self, other: &Self) -> bool {
155 match self {
156 Self::Inet(addr) => Some(addr) == other.as_inet(),
157 #[cfg(unix)]
158 Self::Unix(addr) => {
159 let path = addr.as_pathname();
160 path.is_some() && path == other.as_unix().and_then(|addr| addr.as_pathname())
162 }
163 }
164 }
165}
166
167impl PartialOrd for SocketAddr {
168 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
169 Some(self.cmp(other))
170 }
171}
172
173impl Ord for SocketAddr {
174 fn cmp(&self, other: &Self) -> Ordering {
175 match self {
176 Self::Inet(addr) => {
177 if let Some(o) = other.as_inet() {
178 addr.cmp(o)
179 } else {
180 Ordering::Less
182 }
183 }
184 #[cfg(unix)]
185 Self::Unix(addr) => {
186 if let Some(o) = other.as_unix() {
187 addr.as_pathname().cmp(&o.as_pathname())
189 } else {
190 Ordering::Greater
192 }
193 }
194 }
195 }
196}
197
// Marker impl: `SocketAddr` equality is declared total.
// NOTE(review): `Eq` requires reflexivity (`a == a`), but this type's `PartialEq`
// returns false for a `Unix` address without a pathname (unnamed or abstract), even
// when it is compared with itself. Confirm such values never reach code that relies
// on `Eq` (e.g. map keys).
impl Eq for SocketAddr {}
199
200impl std::str::FromStr for SocketAddr {
201 type Err = Box<Error>;
202
203 #[cfg(unix)]
205 fn from_str(s: &str) -> Result<Self, Self::Err> {
206 if s.starts_with("unix:") {
207 let path = s.trim_start_matches("unix:");
209 let uds_socket = StdUnixSockAddr::from_pathname(path)
210 .or_err(crate::BindError, "invalid UDS path")?;
211 Ok(SocketAddr::Unix(uds_socket))
212 } else {
213 match StdSockAddr::from_str(s) {
214 Ok(addr) => Ok(SocketAddr::Inet(addr)),
215 Err(_) => {
216 let uds_socket = StdUnixSockAddr::from_pathname(s)
218 .or_err(crate::BindError, "invalid UDS path")?;
219 warn!("Raw Unix domain socket path support will be deprecated, add 'unix:' prefix instead");
220 Ok(SocketAddr::Unix(uds_socket))
221 }
222 }
223 }
224 }
225
226 #[cfg(windows)]
227 fn from_str(s: &str) -> Result<Self, Self::Err> {
228 let addr = StdSockAddr::from_str(s).or_err(crate::BindError, "invalid socket addr")?;
229 Ok(SocketAddr::Inet(addr))
230 }
231}
232
233impl std::net::ToSocketAddrs for SocketAddr {
234 type Iter = std::iter::Once<StdSockAddr>;
235
236 fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
238 if let Some(inet) = self.as_inet() {
239 Ok(std::iter::once(*inet))
240 } else {
241 Err(std::io::Error::new(
242 std::io::ErrorKind::Other,
243 "UDS socket cannot be used as inet socket",
244 ))
245 }
246 }
247}
248
249impl From<StdSockAddr> for SocketAddr {
250 fn from(sockaddr: StdSockAddr) -> Self {
251 SocketAddr::Inet(sockaddr)
252 }
253}
254
255#[cfg(unix)]
256impl From<StdUnixSockAddr> for SocketAddr {
257 fn from(sockaddr: StdUnixSockAddr) -> Self {
258 SocketAddr::Unix(sockaddr)
259 }
260}
261
262#[cfg(unix)]
266impl TryFrom<TokioUnixSockAddr> for SocketAddr {
267 type Error = String;
268
269 fn try_from(value: TokioUnixSockAddr) -> Result<Self, Self::Error> {
270 if let Some(Ok(addr)) = value.as_pathname().map(StdUnixSockAddr::from_pathname) {
271 Ok(addr.into())
272 } else {
273 Err(format!("could not convert {value:?} to SocketAddr"))
275 }
276 }
277}
278
#[cfg(test)]
mod test {
    use super::*;

    /// An `IP:PORT` string parses into the Inet variant.
    #[test]
    fn parse_ip() {
        let parsed: Result<SocketAddr, _> = "127.0.0.1:80".parse();
        assert!(parsed.unwrap().as_inet().is_some());
    }

    /// A bare path (legacy form, no `unix:` prefix) parses into the Unix variant.
    #[cfg(unix)]
    #[test]
    fn parse_uds() {
        let parsed: Result<SocketAddr, _> = "/tmp/my.sock".parse();
        assert!(parsed.unwrap().as_unix().is_some());
    }

    /// A `unix:`-prefixed path parses into the Unix variant.
    #[cfg(unix)]
    #[test]
    fn parse_uds_with_prefix() {
        let parsed: Result<SocketAddr, _> = "unix:/tmp/my.sock".parse();
        assert!(parsed.unwrap().as_unix().is_some());
    }
}