cloudpub_common/
utils.rs

1use anyhow::{anyhow, Context, Result};
2use backoff::backoff::Backoff;
3use backoff::Notify;
4use std::future::Future;
5use std::io;
6use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
7use tokio::net::{lookup_host, TcpListener, TcpSocket, ToSocketAddrs, UdpSocket};
8use tokio::sync::watch;
9use tracing::{debug, trace};
10
11use crate::protocol::message::Message as ProtocolMessage;
12use futures::future::{BoxFuture, FutureExt};
13
14pub fn box_future<F, T>(future: F) -> BoxFuture<'static, T>
15where
16    F: Future<Output = T> + Send + 'static,
17    T: 'static,
18{
19    future.boxed()
20}
21
22pub async fn to_socket_addr<A: ToSocketAddrs>(addr: A) -> Result<std::net::SocketAddr> {
23    lookup_host(addr)
24        .await?
25        .next()
26        .ok_or_else(|| anyhow!("Failed to lookup the host"))
27}
28
29pub fn host_port_pair(s: &str) -> Result<(&str, u16)> {
30    let semi = s.rfind(':').expect("missing semicolon");
31    Ok((&s[..semi], s[semi + 1..].parse()?))
32}
33
34/// Create a UDP socket and connect to `addr`
35pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
36    let addr = to_socket_addr(addr).await?;
37
38    let bind_addr = match addr {
39        std::net::SocketAddr::V4(_) => "0.0.0.0:0",
40        std::net::SocketAddr::V6(_) => ":::0",
41    };
42
43    let s = UdpSocket::bind(bind_addr).await?;
44    s.connect(addr).await?;
45    Ok(s)
46}
47
48// Wrapper of retry_notify
49pub async fn retry_notify_with_deadline<I, E, Fn, Fut, B, N>(
50    backoff: B,
51    operation: Fn,
52    notify: N,
53    deadline: &mut watch::Receiver<bool>,
54) -> Result<I>
55where
56    E: std::error::Error + Send + Sync + 'static,
57    B: Backoff,
58    Fn: FnMut() -> Fut,
59    Fut: Future<Output = std::result::Result<I, backoff::Error<E>>>,
60    N: Notify<E>,
61{
62    tokio::select! {
63        v = backoff::future::retry_notify(backoff, operation, notify) => {
64            v.map_err(anyhow::Error::new)
65        }
66        _ = deadline.changed() => {
67            Err(anyhow!("shutdown"))
68        }
69    }
70}
71
72pub async fn find_free_tcp_port() -> Result<u16> {
73    let tcp_listener = TcpListener::bind("0.0.0.0:0").await?;
74    let port = tcp_listener.local_addr()?.port();
75    Ok(port)
76}
77
78pub async fn find_free_udp_port() -> Result<u16> {
79    let udp_listener = UdpSocket::bind("0.0.0.0:0").await?;
80    let port = udp_listener.local_addr()?.port();
81    Ok(port)
82}
83
84pub async fn is_udp_port_available(bind_addr: &str, port: u16) -> Result<bool> {
85    match UdpSocket::bind((bind_addr, port)).await {
86        Ok(_) => Ok(true),
87        Err(ref e) if e.kind() == io::ErrorKind::AddrInUse => Ok(false),
88        Err(e) => Err(e).context("Failed to check UDP port")?,
89    }
90}
91
92pub async fn is_tcp_port_available(bind_addr: &str, port: u16) -> Result<bool> {
93    // Parse as IpAddr to correctly support IPv4/IPv6 and avoid "[addr]:port" formatting issues
94    let ip: std::net::IpAddr = bind_addr
95        .parse()
96        .with_context(|| format!("Invalid bind address: {}", bind_addr))?;
97    let addr = SocketAddr::new(ip, port);
98
99    let tcp_socket = match addr {
100        SocketAddr::V4(_) => TcpSocket::new_v4()?,
101        SocketAddr::V6(_) => TcpSocket::new_v6()?,
102    };
103
104    debug!("Check port: {}", addr);
105    match tcp_socket.bind(addr) {
106        Ok(_) => Ok(true),
107        Err(ref e) if e.kind() == io::ErrorKind::AddrInUse => Ok(false),
108        Err(ref e) if e.kind() == io::ErrorKind::PermissionDenied => Ok(false),
109        Err(e) => Err(e).context("Failed to check TCP port")?,
110    }
111}
112
/// Converts a dotted version string such as `"1.2.3"` into a single integer
/// that can be compared numerically.
///
/// Each `.`-separated component occupies four decimal digits of the result
/// (`"1.2.3"` becomes `1_0002_0003`). A component that does not parse as an
/// `i64` counts as `0`.
///
/// The accumulation saturates at the `i64` bounds instead of overflowing, so a
/// string with unusually many or very large components yields `i64::MAX`
/// (or `i64::MIN`) rather than panicking or wrapping around.
pub fn get_version_number(version: &str) -> i64 {
    version.split('.').fold(0i64, |acc, part| {
        acc.saturating_mul(10000)
            .saturating_add(part.parse::<i64>().unwrap_or(0))
    })
}
120
/// Returns the `<os>-<arch>` identifier of the target this binary was
/// compiled for, e.g. `"linux-x86_64"`.
///
/// The value is selected at compile time through `cfg` attributes. Only the
/// OS/architecture combinations listed below define `platform`; building for
/// any other target fails to compile.
pub fn get_platform() -> String {
    #[cfg(all(target_os = "linux", target_arch = "x86_64"))]
    let platform = "linux-x86_64";
    #[cfg(all(target_os = "linux", target_arch = "arm"))]
    let platform = "linux-armv7";
    #[cfg(all(target_os = "linux", target_arch = "aarch64"))]
    let platform = "linux-aarch64";
    #[cfg(all(target_os = "linux", target_arch = "x86"))]
    let platform = "linux-i686";
    #[cfg(all(target_os = "linux", target_arch = "mips", target_endian = "big"))]
    let platform = "linux-mips";
    #[cfg(all(target_os = "linux", target_arch = "mips", target_endian = "little"))]
    let platform = "linux-mipsel";
    #[cfg(all(target_os = "windows", target_arch = "x86_64"))]
    let platform = "windows-x86_64";
    #[cfg(all(target_os = "windows", target_arch = "x86"))]
    let platform = "windows-i686";
    #[cfg(all(target_os = "macos", target_arch = "x86_64"))]
    let platform = "macos-x86_64";
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    let platform = "macos-aarch64";

    platform.to_string()
}
144
/// Splits `host_and_port` on `:` into an owned host and a port.
///
/// The host is everything before the first `:`. The port is parsed from the
/// text between the first and second `:`; when that piece is absent or is not
/// a valid `u16`, `default_port` is used. Anything after a second `:` is
/// ignored.
///
/// Because the split happens on every `:`, IPv6 literals such as `::1` or
/// `[::1]:80` are not understood by this function.
pub fn split_host_port(host_and_port: &str, default_port: u16) -> (String, u16) {
    let mut pieces = host_and_port.split(':');

    // `split` always yields at least one (possibly empty) piece.
    let host = pieces.next().unwrap_or("").to_string();
    let port = pieces
        .next()
        .and_then(|raw| raw.parse::<u16>().ok())
        .unwrap_or(default_port);

    (host, port)
}
156
157pub fn socket_addr_to_proto(addr: &SocketAddr) -> crate::protocol::SocketAddr {
158    match addr {
159        SocketAddr::V4(addr_v4) => crate::protocol::SocketAddr {
160            addr: Some(crate::protocol::socket_addr::Addr::V4(
161                crate::protocol::SocketAddrV4 {
162                    ip: u32::from(*addr_v4.ip()),
163                    port: addr_v4.port() as u32,
164                },
165            )),
166        },
167        SocketAddr::V6(addr_v6) => crate::protocol::SocketAddr {
168            addr: Some(crate::protocol::socket_addr::Addr::V6(
169                crate::protocol::SocketAddrV6 {
170                    ip: addr_v6.ip().octets().to_vec(),
171                    port: addr_v6.port() as u32,
172                    flowinfo: addr_v6.flowinfo(),
173                    scope_id: addr_v6.scope_id(),
174                },
175            )),
176        },
177    }
178}
179
180pub fn proto_to_socket_addr(proto_addr: &crate::protocol::SocketAddr) -> Result<SocketAddr> {
181    match &proto_addr.addr {
182        Some(crate::protocol::socket_addr::Addr::V4(v4)) => Ok(SocketAddr::V4(SocketAddrV4::new(
183            Ipv4Addr::from(v4.ip),
184            v4.port as u16,
185        ))),
186        Some(crate::protocol::socket_addr::Addr::V6(v6)) => {
187            if v6.ip.len() == 16 {
188                let mut ip_bytes = [0u8; 16];
189                ip_bytes.copy_from_slice(&v6.ip);
190                Ok(SocketAddr::V6(SocketAddrV6::new(
191                    Ipv6Addr::from(ip_bytes),
192                    v6.port as u16,
193                    v6.flowinfo,
194                    v6.scope_id,
195                )))
196            } else {
197                Err(anyhow!(
198                    "Invalid IPv6 address length: expected 16 bytes, got {}",
199                    v6.ip.len()
200                ))
201            }
202        }
203        None => Err(anyhow!("Missing socket address")),
204    }
205}
206
207pub fn trace_message(label: &str, msg: &ProtocolMessage) {
208    match msg {
209        ProtocolMessage::CreateDataChannelWithId(data) => {
210            trace!(
211                "{}: CreateDataChannelWithId {{ channel_id: {}, {:?} }}",
212                label,
213                data.channel_id,
214                data.endpoint
215            );
216        }
217        ProtocolMessage::DataChannelData(data) => {
218            //let data_str = String::from_utf8_lossy(&data.data);
219            trace!(
220                "{}: DataChannelData {{ channel_id: {}, data_size: {} bytes }}",
221                label,
222                data.channel_id,
223                data.data.len(),
224            );
225        }
226        ProtocolMessage::DataChannelDataUdp(data) => {
227            trace!(
228                "{}: DataChannelDataUdp {{ channel_id: {}, data_size: {} bytes, socket_addr: {:?} }}",
229                label,
230                data.channel_id,
231                data.data.len(),
232                data.socket_addr
233            );
234        }
235        ProtocolMessage::DataChannelAck(data) => {
236            trace!(
237                "{}: DataChannelAck {{ channel_id: {}, consumed: {} bytes }}",
238                label,
239                data.channel_id,
240                data.consumed
241            );
242        }
243        ProtocolMessage::DataChannelEof(data) => {
244            trace!(
245                "{}: DataChannelEof {{ channel_id: {}, error: {} }}",
246                label,
247                data.channel_id,
248                data.error
249            );
250        }
251        ProtocolMessage::Progress(data) => {
252            trace!("{}: Progress {:?}", label, data);
253        }
254        ProtocolMessage::HeartBeat(_) => {
255            trace!("{}: HeartBeat", label);
256        }
257        _ => {
258            debug!("{}: {:?}", label, msg);
259        }
260    }
261}