1pub mod cache;
5pub mod cid;
6pub mod db;
7pub mod encoding;
8pub mod flume;
9pub mod get_size;
10pub mod io;
11pub mod misc;
12pub mod multihash;
13pub mod net;
14pub mod p2p;
15pub mod proofs_api;
16pub mod rand;
17pub mod reqwest_resume;
18#[cfg(feature = "sqlite")]
19pub mod sqlite;
20pub mod stats;
21pub mod stream;
22pub mod version;
23
24use anyhow::{Context as _, bail};
25use futures::Future;
26use multiaddr::{Multiaddr, Protocol};
27use std::{str::FromStr, time::Duration};
28use tokio::time::sleep;
29use tracing::error;
30use url::Url;
31
/// A [`Url`] obtained by parsing a multiaddr string, optionally prefixed with a
/// password and a colon, e.g. `/dns/example.com/http` or
/// `hunter2:/dns/example.com/http`. See the [`FromStr`] implementation.
#[derive(Clone, Debug)]
pub struct UrlFromMultiAddr(pub Url);
35
36impl FromStr for UrlFromMultiAddr {
37 type Err = anyhow::Error;
38
39 fn from_str(s: &str) -> Result<Self, Self::Err> {
40 let (p, s) = match s.split_once(':') {
41 Some((first, rest)) => (Some(first), rest),
42 None => (None, s),
43 };
44 let m = Multiaddr::from_str(s).context("invalid multiaddr")?;
45 let mut u = multiaddr2url(&m).context("unsupported multiaddr")?;
46 if u.set_password(p).is_err() {
47 bail!("unsupported password")
48 }
49 Ok(Self(u))
50 }
51}
52
53fn multiaddr2url(m: &Multiaddr) -> Option<Url> {
60 let mut components = m.iter().peekable();
61 let host = match components.next()? {
62 Protocol::Dns(it) | Protocol::Dns4(it) | Protocol::Dns6(it) | Protocol::Dnsaddr(it) => {
63 it.to_string()
64 }
65 Protocol::Ip4(it) => it.to_string(),
66 Protocol::Ip6(it) => it.to_string(),
67 _ => return None,
68 };
69 let port = components
70 .next_if(|it| matches!(it, Protocol::Tcp(_)))
71 .map(|it| match it {
72 Protocol::Tcp(port) => port,
73 _ => unreachable!(),
74 });
75 let scheme = match components.next()? {
77 Protocol::Http => "http",
78 Protocol::Https => "https",
79 Protocol::Ws(it) if it == "/" => "ws",
80 Protocol::Wss(it) if it == "/" => "wss",
81 _ => return None,
82 };
83 let None = components.next() else { return None };
84 let parse_me = match port {
85 Some(port) => format!("{scheme}://{host}:{port}"),
86 None => format!("{scheme}://{host}"),
87 };
88 parse_me.parse().ok()
89}
90
/// Checks that multiaddr strings, with and without a password prefix, are parsed
/// into the expected URLs.
#[test]
fn test_url_from_multiaddr() {
    // Each entry is `(input, expected URL)`.
    let cases = [
        ("/dns/example.com/http", "http://example.com/"),
        ("/dns/example.com/tcp/8080/http", "http://example.com:8080/"),
        ("/dns/example.com/tcp/8081/ws", "ws://example.com:8081/"),
        ("/ip4/127.0.0.1/wss", "wss://127.0.0.1/"),
        // With a password prefix.
        (
            "hunter2:/dns/example.com/http",
            "http://:hunter2@example.com/",
        ),
        (
            "hunter2:/dns/example.com/tcp/8080/http",
            "http://:hunter2@example.com:8080/",
        ),
        ("hunter2:/ip4/127.0.0.1/wss", "wss://:hunter2@127.0.0.1/"),
    ];
    for (input, expected) in cases {
        let UrlFromMultiAddr(url) = input.parse().unwrap();
        assert_eq!(url.as_str(), expected, "input: {input}");
    }
}
114
115#[tracing::instrument(skip_all)]
119pub async fn retry<F, T, E>(
120 args: RetryArgs,
121 mut make_fut: impl FnMut() -> F,
122) -> Result<T, RetryError>
123where
124 F: Future<Output = Result<T, E>>,
125 E: std::fmt::Debug,
126{
127 let max_retries = args.max_retries.unwrap_or(usize::MAX);
128 let task = async {
129 for _ in 0..max_retries {
130 match make_fut().await {
131 Ok(ok) => return Ok(ok),
132 Err(err) => error!("retrying operation after {err:?}"),
133 }
134 if let Some(delay) = args.delay {
135 sleep(delay).await;
136 }
137 }
138 Err(RetryError::RetriesExceeded)
139 };
140
141 if let Some(timeout) = args.timeout {
142 tokio::time::timeout(timeout, task)
143 .await
144 .map_err(|_| RetryError::TimeoutExceeded)?
145 } else {
146 task.await
147 }
148}
149
/// Settings for [`retry`]. Setting a field to `None` removes the corresponding limit.
#[derive(Debug, Clone, Copy, smart_default::SmartDefault)]
pub struct RetryArgs {
    /// Upper bound on the total time [`retry`] may take; `None` means no timeout.
    /// The default is one second.
    #[default(Some(Duration::from_secs(1)))]
    pub timeout: Option<Duration>,
    /// Maximum number of attempts; `None` is treated as `usize::MAX`.
    /// The default is five.
    #[default(Some(5))]
    pub max_retries: Option<usize>,
    /// Time to sleep after a failed attempt; `None` means no pause.
    /// The default is 200 milliseconds.
    #[default(Some(Duration::from_millis(200)))]
    pub delay: Option<Duration>,
}
159
/// Reasons for which [`retry`] gives up without a successful result.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum RetryError {
    /// The overall timeout elapsed before an attempt succeeded.
    #[error("operation timed out")]
    TimeoutExceeded,
    /// Every permitted attempt failed.
    #[error("retry limit exceeded")]
    RetriesExceeded,
}
167
/// Returns `true` when compiled with `debug_assertions` enabled.
/// Only available in test builds.
#[allow(dead_code)]
#[cfg(test)]
pub fn is_debug_build() -> bool {
    cfg!(debug_assertions)
}
173
/// Returns whether the `CI` environment variable is set to a truthy value, as
/// decided by `misc::env::is_env_truthy`. Only available in test builds.
#[allow(dead_code)]
#[cfg(test)]
pub fn is_ci() -> bool {
    misc::env::is_env_truthy("CI")
}
180
#[cfg(test)]
mod tests {
    mod files;

    use RetryError::{RetriesExceeded, TimeoutExceeded};
    use futures::future::pending;
    use std::{future::ready, sync::atomic::AtomicUsize};

    use super::*;

    impl RetryArgs {
        /// Test-only constructor taking `timeout` and `delay` in milliseconds.
        fn new_ms(
            timeout: impl Into<Option<u64>>,
            max_retries: impl Into<Option<usize>>,
            delay: impl Into<Option<u64>>,
        ) -> Self {
            let millis = |ms: Option<u64>| ms.map(Duration::from_millis);
            Self {
                timeout: millis(timeout.into()),
                max_retries: max_retries.into(),
                delay: millis(delay.into()),
            }
        }
    }

    /// A future that never completes must be cut off by the timeout.
    #[tokio::test]
    async fn timeout() {
        let args = RetryArgs::new_ms(1, None, None);
        let res = retry(args, pending::<Result<(), ()>>).await;
        assert_eq!(Err(TimeoutExceeded), res);
    }

    /// An operation that always fails must exhaust the retry limit.
    #[tokio::test]
    async fn retries() {
        let args = RetryArgs::new_ms(None, 1, None);
        let res = retry(args, || ready(Err::<(), _>(()))).await;
        assert_eq!(Err(RetriesExceeded), res);
    }

    /// An operation that succeeds right away yields its value.
    #[tokio::test]
    async fn ok() {
        let res = retry(RetryArgs::default(), || ready(Ok::<_, ()>(()))).await;
        assert_eq!(Ok(()), res);
    }

    /// An operation that fails a few times first eventually succeeds.
    #[tokio::test]
    async fn needs_retry() {
        use std::sync::atomic::Ordering::SeqCst;
        let attempts = AtomicUsize::new(0);
        let res = retry(RetryArgs::new_ms(None, None, None), || async {
            if attempts.fetch_add(1, SeqCst) > 5 {
                Ok(())
            } else {
                Err(())
            }
        })
        .await;
        assert_eq!(Ok(()), res);
        assert!(attempts.load(SeqCst) > 5);
    }
}
237}