1use anyhow::{anyhow, Context, Result};
2use backoff::backoff::Backoff;
3use backoff::Notify;
4use std::future::Future;
5use std::io;
6use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
7use tokio::net::{lookup_host, TcpListener, TcpSocket, ToSocketAddrs, UdpSocket};
8use tokio::sync::watch;
9use tracing::{debug, trace};
10
11use crate::protocol::message::Message as ProtocolMessage;
12use futures::future::{BoxFuture, FutureExt};
13
14pub fn box_future<F, T>(future: F) -> BoxFuture<'static, T>
15where
16 F: Future<Output = T> + Send + 'static,
17 T: 'static,
18{
19 future.boxed()
20}
21
22pub async fn to_socket_addr<A: ToSocketAddrs>(addr: A) -> Result<std::net::SocketAddr> {
23 lookup_host(addr)
24 .await?
25 .next()
26 .ok_or_else(|| anyhow!("Failed to lookup the host"))
27}
28
29pub fn host_port_pair(s: &str) -> Result<(&str, u16)> {
30 let semi = s.rfind(':').expect("missing semicolon");
31 Ok((&s[..semi], s[semi + 1..].parse()?))
32}
33
34pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
36 let addr = to_socket_addr(addr).await?;
37
38 let bind_addr = match addr {
39 std::net::SocketAddr::V4(_) => "0.0.0.0:0",
40 std::net::SocketAddr::V6(_) => ":::0",
41 };
42
43 let s = UdpSocket::bind(bind_addr).await?;
44 s.connect(addr).await?;
45 Ok(s)
46}
47
48pub async fn retry_notify_with_deadline<I, E, Fn, Fut, B, N>(
50 backoff: B,
51 operation: Fn,
52 notify: N,
53 deadline: &mut watch::Receiver<bool>,
54) -> Result<I>
55where
56 E: std::error::Error + Send + Sync + 'static,
57 B: Backoff,
58 Fn: FnMut() -> Fut,
59 Fut: Future<Output = std::result::Result<I, backoff::Error<E>>>,
60 N: Notify<E>,
61{
62 tokio::select! {
63 v = backoff::future::retry_notify(backoff, operation, notify) => {
64 v.map_err(anyhow::Error::new)
65 }
66 _ = deadline.changed() => {
67 Err(anyhow!("shutdown"))
68 }
69 }
70}
71
72pub async fn find_free_tcp_port() -> Result<u16> {
73 let tcp_listener = TcpListener::bind("0.0.0.0:0").await?;
74 let port = tcp_listener.local_addr()?.port();
75 Ok(port)
76}
77
78pub async fn find_free_udp_port() -> Result<u16> {
79 let udp_listener = UdpSocket::bind("0.0.0.0:0").await?;
80 let port = udp_listener.local_addr()?.port();
81 Ok(port)
82}
83
84pub async fn is_udp_port_available(bind_addr: &str, port: u16) -> Result<bool> {
85 match UdpSocket::bind((bind_addr, port)).await {
86 Ok(_) => Ok(true),
87 Err(ref e) if e.kind() == io::ErrorKind::AddrInUse => Ok(false),
88 Err(e) => Err(e).context("Failed to check UDP port")?,
89 }
90}
91
92pub async fn is_tcp_port_available(bind_addr: &str, port: u16) -> Result<bool> {
93 let ip: std::net::IpAddr = bind_addr
95 .parse()
96 .with_context(|| format!("Invalid bind address: {}", bind_addr))?;
97 let addr = SocketAddr::new(ip, port);
98
99 let tcp_socket = match addr {
100 SocketAddr::V4(_) => TcpSocket::new_v4()?,
101 SocketAddr::V6(_) => TcpSocket::new_v6()?,
102 };
103
104 debug!("Check port: {}", addr);
105 match tcp_socket.bind(addr) {
106 Ok(_) => Ok(true),
107 Err(ref e) if e.kind() == io::ErrorKind::AddrInUse => Ok(false),
108 Err(ref e) if e.kind() == io::ErrorKind::PermissionDenied => Ok(false),
109 Err(e) => Err(e).context("Failed to check TCP port")?,
110 }
111}
112
/// Folds a dotted version string into a single comparable integer.
///
/// Each `.`-separated component is parsed as an `i64` (components that fail
/// to parse count as 0) and the running total is multiplied by 10000 before
/// the next component is added, e.g. `"1.2.3"` becomes `100020003`.
///
/// The arithmetic saturates at the `i64` bounds, so overly long or huge
/// version strings yield `i64::MAX`/`i64::MIN` instead of overflowing.
pub fn get_version_number(version: &str) -> i64 {
    version.split('.').fold(0i64, |acc, part| {
        acc.saturating_mul(10000)
            .saturating_add(part.parse::<i64>().unwrap_or(0))
    })
}
120
/// Returns the `<os>-<arch>` identifier of the target this binary was
/// compiled for (for example `"linux-x86_64"`).
///
/// The value is chosen at compile time. Only the OS/architecture pairs listed
/// below are supported; building for any other target fails to compile
/// because no identifier is defined.
pub fn get_platform() -> String {
    #[cfg(all(target_os = "linux", target_arch = "x86_64"))]
    const PLATFORM: &str = "linux-x86_64";
    #[cfg(all(target_os = "linux", target_arch = "arm"))]
    const PLATFORM: &str = "linux-armv7";
    #[cfg(all(target_os = "linux", target_arch = "aarch64"))]
    const PLATFORM: &str = "linux-aarch64";
    #[cfg(all(target_os = "linux", target_arch = "x86"))]
    const PLATFORM: &str = "linux-i686";
    #[cfg(all(target_os = "linux", target_arch = "mips", target_endian = "big"))]
    const PLATFORM: &str = "linux-mips";
    #[cfg(all(target_os = "linux", target_arch = "mips", target_endian = "little"))]
    const PLATFORM: &str = "linux-mipsel";
    #[cfg(all(target_os = "windows", target_arch = "x86_64"))]
    const PLATFORM: &str = "windows-x86_64";
    #[cfg(all(target_os = "windows", target_arch = "x86"))]
    const PLATFORM: &str = "windows-i686";
    #[cfg(all(target_os = "macos", target_arch = "x86_64"))]
    const PLATFORM: &str = "macos-x86_64";
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    const PLATFORM: &str = "macos-aarch64";
    #[cfg(all(target_os = "android", target_arch = "aarch64"))]
    const PLATFORM: &str = "android-aarch64";

    PLATFORM.to_string()
}
146
/// Splits a `host[:port]` string into its host and port, falling back to
/// `default_port` when the port is absent or not a valid `u16`.
///
/// Accepted forms:
/// - `"host"` / `"host:port"`: split at the first `':'`.
/// - `"[v6addr]"` / `"[v6addr]:port"`: the brackets are stripped from the host.
/// - a bare IPv6 literal such as `"::1"`: returned whole with `default_port`,
///   since every colon belongs to the address.
///
/// Any other input keeps the legacy behaviour: the text before the first
/// `':'` is the host and the segment up to the next `':'` is the port.
pub fn split_host_port(host_and_port: &str, default_port: u16) -> (String, u16) {
    // Bracketed IPv6 literal, optionally followed by ":port".
    if let Some(rest) = host_and_port.strip_prefix('[') {
        if let Some((host, tail)) = rest.split_once(']') {
            let port = tail
                .strip_prefix(':')
                .and_then(|p| p.parse::<u16>().ok())
                .unwrap_or(default_port);
            return (host.to_string(), port);
        }
    }

    // Bare IPv6 literal: splitting on ':' would cut the address apart.
    if host_and_port.parse::<Ipv6Addr>().is_ok() {
        return (host_and_port.to_string(), default_port);
    }

    let mut parts = host_and_port.split(':');
    let host = parts.next().unwrap_or_default().to_string();
    let port = parts
        .next()
        .and_then(|p| p.parse::<u16>().ok())
        .unwrap_or(default_port);
    (host, port)
}
158
159pub fn socket_addr_to_proto(addr: &SocketAddr) -> crate::protocol::SocketAddr {
160 match addr {
161 SocketAddr::V4(addr_v4) => crate::protocol::SocketAddr {
162 addr: Some(crate::protocol::socket_addr::Addr::V4(
163 crate::protocol::SocketAddrV4 {
164 ip: u32::from(*addr_v4.ip()),
165 port: addr_v4.port() as u32,
166 },
167 )),
168 },
169 SocketAddr::V6(addr_v6) => crate::protocol::SocketAddr {
170 addr: Some(crate::protocol::socket_addr::Addr::V6(
171 crate::protocol::SocketAddrV6 {
172 ip: addr_v6.ip().octets().to_vec(),
173 port: addr_v6.port() as u32,
174 flowinfo: addr_v6.flowinfo(),
175 scope_id: addr_v6.scope_id(),
176 },
177 )),
178 },
179 }
180}
181
182pub fn proto_to_socket_addr(proto_addr: &crate::protocol::SocketAddr) -> Result<SocketAddr> {
183 match &proto_addr.addr {
184 Some(crate::protocol::socket_addr::Addr::V4(v4)) => Ok(SocketAddr::V4(SocketAddrV4::new(
185 Ipv4Addr::from(v4.ip),
186 v4.port as u16,
187 ))),
188 Some(crate::protocol::socket_addr::Addr::V6(v6)) => {
189 if v6.ip.len() == 16 {
190 let mut ip_bytes = [0u8; 16];
191 ip_bytes.copy_from_slice(&v6.ip);
192 Ok(SocketAddr::V6(SocketAddrV6::new(
193 Ipv6Addr::from(ip_bytes),
194 v6.port as u16,
195 v6.flowinfo,
196 v6.scope_id,
197 )))
198 } else {
199 Err(anyhow!(
200 "Invalid IPv6 address length: expected 16 bytes, got {}",
201 v6.ip.len()
202 ))
203 }
204 }
205 None => Err(anyhow!("Missing socket address")),
206 }
207}
208
209pub fn trace_message(label: &str, msg: &ProtocolMessage) {
210 match msg {
211 ProtocolMessage::CreateDataChannelWithId(data) => {
212 trace!(
213 "{}: CreateDataChannelWithId {{ channel_id: {}, {:?} }}",
214 label,
215 data.channel_id,
216 data.endpoint
217 );
218 }
219 ProtocolMessage::DataChannelData(data) => {
220 trace!(
222 "{}: DataChannelData {{ channel_id: {}, data_size: {} bytes }}",
223 label,
224 data.channel_id,
225 data.data.len(),
226 );
227 }
228 ProtocolMessage::DataChannelDataUdp(data) => {
229 trace!(
230 "{}: DataChannelDataUdp {{ channel_id: {}, data_size: {} bytes, socket_addr: {:?} }}",
231 label,
232 data.channel_id,
233 data.data.len(),
234 data.socket_addr
235 );
236 }
237 ProtocolMessage::DataChannelAck(data) => {
238 trace!(
239 "{}: DataChannelAck {{ channel_id: {}, consumed: {} bytes }}",
240 label,
241 data.channel_id,
242 data.consumed
243 );
244 }
245 ProtocolMessage::DataChannelEof(data) => {
246 trace!(
247 "{}: DataChannelEof {{ channel_id: {}, error: {} }}",
248 label,
249 data.channel_id,
250 data.error
251 );
252 }
253 ProtocolMessage::Progress(data) => {
254 trace!("{}: Progress {:?}", label, data);
255 }
256 ProtocolMessage::HeartBeat(_) => {
257 trace!("{}: HeartBeat", label);
258 }
259 _ => {
260 debug!("{}: {:?}", label, msg);
261 }
262 }
263}