1use crate::address::parse_utils::try_to_parse_str_to_ip;
2use rama_core::error::{ErrorContext, OpaqueError};
3#[cfg(feature = "http")]
4use rama_http_types::HeaderValue;
5use std::fmt;
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
7use std::str::FromStr;
8
/// An [`IpAddr`] paired with a port number.
///
/// The address can be either IPv4 or IPv6. Its textual form (see the
/// `Display` and `TryFrom<&str>` implementations) is `ip:port` for IPv4 and
/// `[ip]:port` for IPv6.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SocketAddress {
    /// IP (v4 or v6) half of the address.
    ip_addr: IpAddr,
    /// Port half of the address.
    port: u16,
}

impl SocketAddress {
    /// Builds a [`SocketAddress`] out of an IP address and a port.
    pub const fn new(ip_addr: IpAddr, port: u16) -> Self {
        Self { ip_addr, port }
    }

    /// Builds a [`SocketAddress`] for the IPv4 loopback address
    /// (`127.0.0.1`) on the given `port`.
    pub const fn local_ipv4(port: u16) -> Self {
        Self::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port)
    }

    /// Builds a [`SocketAddress`] for the IPv6 loopback address (`::1`)
    /// on the given `port`.
    pub const fn local_ipv6(port: u16) -> Self {
        Self::new(IpAddr::V6(Ipv6Addr::LOCALHOST), port)
    }

    /// Builds a [`SocketAddress`] for the unspecified IPv4 address
    /// (`0.0.0.0`) on the given `port`.
    pub const fn default_ipv4(port: u16) -> Self {
        Self::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port)
    }

    /// Builds a [`SocketAddress`] for the unspecified IPv6 address (`::`)
    /// on the given `port`.
    pub const fn default_ipv6(port: u16) -> Self {
        Self::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port)
    }

    /// Builds a [`SocketAddress`] for the IPv4 broadcast address
    /// (`255.255.255.255`) on the given `port`.
    pub const fn broadcast_ipv4(port: u16) -> Self {
        Self::new(IpAddr::V4(Ipv4Addr::BROADCAST), port)
    }

    /// Borrows the [`IpAddr`] of this [`SocketAddress`].
    pub fn ip_addr(&self) -> &IpAddr {
        &self.ip_addr
    }

    /// Consumes the [`SocketAddress`] and returns only its [`IpAddr`].
    pub fn into_ip_addr(self) -> IpAddr {
        let Self { ip_addr, .. } = self;
        ip_addr
    }

    /// Returns the port of this [`SocketAddress`].
    pub fn port(&self) -> u16 {
        self.port
    }

    /// Consumes the [`SocketAddress`] and returns its `(ip address, port)` pair.
    pub fn into_parts(self) -> (IpAddr, u16) {
        let Self { ip_addr, port } = self;
        (ip_addr, port)
    }
}
127
128impl From<SocketAddr> for SocketAddress {
129 fn from(addr: SocketAddr) -> Self {
130 SocketAddress {
131 ip_addr: addr.ip(),
132 port: addr.port(),
133 }
134 }
135}
136
137impl From<&SocketAddr> for SocketAddress {
138 fn from(addr: &SocketAddr) -> Self {
139 SocketAddress {
140 ip_addr: addr.ip(),
141 port: addr.port(),
142 }
143 }
144}
145
146impl From<SocketAddrV4> for SocketAddress {
147 fn from(value: SocketAddrV4) -> Self {
148 SocketAddress {
149 ip_addr: (*value.ip()).into(),
150 port: value.port(),
151 }
152 }
153}
154
155impl From<SocketAddrV6> for SocketAddress {
156 fn from(value: SocketAddrV6) -> Self {
157 SocketAddress {
158 ip_addr: (*value.ip()).into(),
159 port: value.port(),
160 }
161 }
162}
163
164impl From<SocketAddress> for SocketAddr {
165 fn from(addr: SocketAddress) -> Self {
166 SocketAddr::new(addr.ip_addr, addr.port)
167 }
168}
169
170impl From<(IpAddr, u16)> for SocketAddress {
171 #[inline]
172 fn from((ip_addr, port): (IpAddr, u16)) -> Self {
173 Self { ip_addr, port }
174 }
175}
176
177impl From<(Ipv4Addr, u16)> for SocketAddress {
178 #[inline]
179 fn from((ip, port): (Ipv4Addr, u16)) -> Self {
180 Self {
181 ip_addr: ip.into(),
182 port,
183 }
184 }
185}
186
187impl From<([u8; 4], u16)> for SocketAddress {
188 #[inline]
189 fn from((ip, port): ([u8; 4], u16)) -> Self {
190 let ip: IpAddr = ip.into();
191 (ip, port).into()
192 }
193}
194
195impl From<(Ipv6Addr, u16)> for SocketAddress {
196 #[inline]
197 fn from((ip, port): (Ipv6Addr, u16)) -> Self {
198 Self {
199 ip_addr: ip.into(),
200 port,
201 }
202 }
203}
204
205impl From<([u8; 16], u16)> for SocketAddress {
206 #[inline]
207 fn from((ip, port): ([u8; 16], u16)) -> Self {
208 let ip: IpAddr = ip.into();
209 (ip, port).into()
210 }
211}
212
213impl fmt::Display for SocketAddress {
214 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215 match &self.ip_addr {
216 IpAddr::V4(ip) => write!(f, "{}:{}", ip, self.port),
217 IpAddr::V6(ip) => write!(f, "[{}]:{}", ip, self.port),
218 }
219 }
220}
221
222impl FromStr for SocketAddress {
223 type Err = OpaqueError;
224
225 fn from_str(s: &str) -> Result<Self, Self::Err> {
226 SocketAddress::try_from(s)
227 }
228}
229
230impl TryFrom<String> for SocketAddress {
231 type Error = OpaqueError;
232
233 fn try_from(s: String) -> Result<Self, Self::Error> {
234 s.as_str().try_into()
235 }
236}
237
238impl TryFrom<&String> for SocketAddress {
239 type Error = OpaqueError;
240
241 fn try_from(value: &String) -> Result<Self, Self::Error> {
242 value.as_str().try_into()
243 }
244}
245
246impl TryFrom<&str> for SocketAddress {
247 type Error = OpaqueError;
248
249 fn try_from(s: &str) -> Result<Self, Self::Error> {
250 let (ip_addr, port) = crate::address::parse_utils::split_port_from_str(s)?;
251 let ip_addr =
252 try_to_parse_str_to_ip(ip_addr).context("parse ip address from socket address")?;
253 match ip_addr {
254 IpAddr::V6(_) if !s.starts_with('[') => Err(OpaqueError::from_display(
255 "missing brackets for IPv6 address with port",
256 )),
257 _ => Ok(SocketAddress { ip_addr, port }),
258 }
259 }
260}
261
/// Parses a [`SocketAddress`] from an owned header value by delegating to
/// the `TryFrom<&HeaderValue>` implementation.
#[cfg(feature = "http")]
impl TryFrom<HeaderValue> for SocketAddress {
    type Error = OpaqueError;

    fn try_from(header: HeaderValue) -> Result<Self, Self::Error> {
        (&header).try_into()
    }
}
270
/// Parses a [`SocketAddress`] from a borrowed header value.
///
/// # Errors
///
/// Returns an error when the header cannot be converted to a string, or
/// when that string is not a valid socket address.
#[cfg(feature = "http")]
impl TryFrom<&HeaderValue> for SocketAddress {
    type Error = OpaqueError;

    fn try_from(header: &HeaderValue) -> Result<Self, Self::Error> {
        let text = header.to_str().context("convert header to str")?;
        Self::try_from(text)
    }
}
279
280impl TryFrom<Vec<u8>> for SocketAddress {
281 type Error = OpaqueError;
282
283 fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
284 Self::try_from(bytes.as_slice())
285 }
286}
287
288impl TryFrom<&[u8]> for SocketAddress {
289 type Error = OpaqueError;
290
291 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
292 let s = std::str::from_utf8(bytes).context("parse sock address from bytes")?;
293 s.try_into()
294 }
295}
296
297impl serde::Serialize for SocketAddress {
298 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
299 where
300 S: serde::Serializer,
301 {
302 let address = self.to_string();
303 address.serialize(serializer)
304 }
305}
306
307impl<'de> serde::Deserialize<'de> for SocketAddress {
308 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
309 where
310 D: serde::Deserializer<'de>,
311 {
312 let s = <std::borrow::Cow<'de, str>>::deserialize(deserializer)?;
313 s.parse().map_err(serde::de::Error::custom)
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `sock_address` has the expected textual IP address and
    /// port. `s` is the input it was parsed from and only appears in the
    /// failure message.
    fn assert_eq(s: &str, sock_address: SocketAddress, ip_addr: &str, port: u16) {
        let (actual_ip, actual_port) = sock_address.into_parts();
        assert_eq!(actual_ip.to_string(), ip_addr, "parsing: {}", s);
        assert_eq!(actual_port, port, "parsing: {}", s);
    }

    /// Well-formed input is accepted by every string and byte conversion and
    /// yields the expected IP address and port.
    #[test]
    fn test_parse_valid() {
        let cases: [(&str, &str, u16); 3] = [
            ("[::1]:80", "::1", 80),
            ("127.0.0.1:80", "127.0.0.1", 80),
            (
                "[2001:db8:3333:4444:5555:6666:7777:8888]:80",
                "2001:db8:3333:4444:5555:6666:7777:8888",
                80,
            ),
        ];

        for (s, expected_ip_addr, expected_port) in cases {
            let msg = format!("parsing '{}'", s);

            // One parsed value per conversion entry point.
            let parsed: [SocketAddress; 5] = [
                s.parse().expect(&msg),
                SocketAddress::try_from(s).expect(&msg),
                SocketAddress::try_from(s.to_owned()).expect(&msg),
                SocketAddress::try_from(s.as_bytes()).expect(&msg),
                SocketAddress::try_from(s.as_bytes().to_vec()).expect(&msg),
            ];

            for sock_address in parsed {
                assert_eq(s, sock_address, expected_ip_addr, expected_port);
            }
        }
    }

    /// Malformed input is rejected by every string and byte conversion.
    #[test]
    fn test_parse_invalid() {
        let invalid_inputs: &[&str] = &[
            "",
            "-",
            ".",
            ":",
            ":80",
            "-.",
            ".-",
            "::1",
            "127.0.0.1",
            "[::1]",
            "2001:db8:3333:4444:5555:6666:7777:8888",
            "[2001:db8:3333:4444:5555:6666:7777:8888]",
            "example.com",
            "example.com:",
            "example.com:-1",
            "example.com:999999",
            "example.com:80",
            "example:com",
            "[127.0.0.1]:80",
            "2001:db8:3333:4444:5555:6666:7777:8888:80",
        ];

        for s in invalid_inputs.iter().copied() {
            let msg = format!("parsing '{}'", s);

            // One rejection flag per conversion entry point.
            let rejections = [
                s.parse::<SocketAddress>().is_err(),
                SocketAddress::try_from(s).is_err(),
                SocketAddress::try_from(s.to_owned()).is_err(),
                SocketAddress::try_from(s.as_bytes()).is_err(),
                SocketAddress::try_from(s.as_bytes().to_vec()).is_err(),
            ];

            for rejected in rejections {
                assert!(rejected, "{}", msg);
            }
        }
    }

    /// Parsing and then displaying an address gives back the original text.
    #[test]
    fn test_parse_display() {
        let cases = [("[::1]:80", "[::1]:80"), ("127.0.0.1:80", "127.0.0.1:80")];

        for (s, expected) in cases {
            let msg = format!("parsing '{}'", s);
            let rendered = s.parse::<SocketAddress>().expect(&msg).to_string();
            assert_eq!(rendered, expected, "{}", msg);
        }
    }
}