Skip to main content

forest/utils/
mod.rs

1// Copyright 2019-2026 ChainSafe Systems
2// SPDX-License-Identifier: Apache-2.0, MIT
3
4pub mod broadcast;
5pub mod cache;
6pub mod cid;
7pub mod db;
8pub mod encoding;
9pub mod flume;
10pub mod get_size;
11pub mod hash;
12pub mod io;
13pub mod misc;
14pub mod multihash;
15pub mod net;
16pub mod p2p;
17pub mod proofs_api;
18pub mod rand;
19pub mod reqwest_resume;
20mod shallow_clone;
21pub use shallow_clone::ShallowClone;
22#[cfg(feature = "sqlite")]
23pub mod sqlite;
24pub mod stats;
25pub mod stream;
26pub mod version;
27
28use anyhow::{Context as _, bail};
29use futures::Future;
30use multiaddr::{Multiaddr, Protocol};
31use std::{str::FromStr, time::Duration};
32use tokio::time::sleep;
33use tracing::error;
34use url::Url;
35
/// A [`Url`] obtained by parsing a string of the form `[password:]multiaddr`
/// (see the [`FromStr`] implementation).
///
/// `"hunter2:/ip4/127.0.0.1/wss" -> "wss://:hunter2@127.0.0.1/"`
///
/// The wrapped [`Url`] is public and can be taken out directly with `.0`.
#[derive(Clone, Debug)]
pub struct UrlFromMultiAddr(pub Url);
39
40impl FromStr for UrlFromMultiAddr {
41    type Err = anyhow::Error;
42
43    fn from_str(s: &str) -> Result<Self, Self::Err> {
44        let (p, s) = match s.split_once(':') {
45            Some((first, rest)) => (Some(first), rest),
46            None => (None, s),
47        };
48        let m = Multiaddr::from_str(s).context("invalid multiaddr")?;
49        let mut u = multiaddr2url(&m).context("unsupported multiaddr")?;
50        if u.set_password(p).is_err() {
51            bail!("unsupported password")
52        }
53        Ok(Self(u))
54    }
55}
56
57/// `"/dns/example.com/tcp/8080/http" -> "http://example.com:8080/"`
58///
59/// Returns [`None`] on unsupported formats, or if there is a URL parsing error.
60///
61/// Note that [`Multiaddr`]s do NOT support a (URL) `path`, so that must be handled
62/// out-of-band.
63fn multiaddr2url(m: &Multiaddr) -> Option<Url> {
64    let mut components = m.iter().peekable();
65    let host = match components.next()? {
66        Protocol::Dns(it) | Protocol::Dns4(it) | Protocol::Dns6(it) | Protocol::Dnsaddr(it) => {
67            it.to_string()
68        }
69        Protocol::Ip4(it) => it.to_string(),
70        Protocol::Ip6(it) => it.to_string(),
71        _ => return None,
72    };
73    let port = components
74        .next_if(|it| matches!(it, Protocol::Tcp(_)))
75        .map(|it| match it {
76            Protocol::Tcp(port) => port,
77            _ => unreachable!(),
78        });
79    // ENHANCEMENT: could recognise `Tcp/443/Tls` as `https`
80    let scheme = match components.next()? {
81        Protocol::Http => "http",
82        Protocol::Https => "https",
83        Protocol::Ws(it) if it == "/" => "ws",
84        Protocol::Wss(it) if it == "/" => "wss",
85        _ => return None,
86    };
87    let None = components.next() else { return None };
88    let parse_me = match port {
89        Some(port) => format!("{scheme}://{host}:{port}"),
90        None => format!("{scheme}://{host}"),
91    };
92    parse_me.parse().ok()
93}
94
/// Checks that `[password:]multiaddr` strings are parsed into the expected URLs.
#[test]
fn test_url_from_multiaddr() {
    // Each entry is `(input, expected URL)`.
    const CASES: &[(&str, &str)] = &[
        ("/dns/example.com/http", "http://example.com/"),
        ("/dns/example.com/tcp/8080/http", "http://example.com:8080/"),
        ("/dns/example.com/tcp/8081/ws", "ws://example.com:8081/"),
        ("/ip4/127.0.0.1/wss", "wss://127.0.0.1/"),
        // with password
        (
            "hunter2:/dns/example.com/http",
            "http://:hunter2@example.com/",
        ),
        (
            "hunter2:/dns/example.com/tcp/8080/http",
            "http://:hunter2@example.com:8080/",
        ),
        ("hunter2:/ip4/127.0.0.1/wss", "wss://:hunter2@127.0.0.1/"),
    ];

    for &(input, expected) in CASES {
        let UrlFromMultiAddr(url) = input.parse().unwrap();
        assert_eq!(url.as_str(), expected, "input: {input}");
    }
}
118
119/// Keep running the future created by `make_fut` until the timeout or retry
120/// limit in `args` is reached.
121/// `F` _must_ be cancel safe.
122#[tracing::instrument(skip_all)]
123pub async fn retry<F, T, E>(
124    args: RetryArgs,
125    mut make_fut: impl FnMut() -> F,
126) -> Result<T, RetryError>
127where
128    F: Future<Output = Result<T, E>>,
129    E: std::fmt::Debug,
130{
131    let max_retries = args.max_retries.unwrap_or(usize::MAX);
132    let task = async {
133        for _ in 0..max_retries {
134            match make_fut().await {
135                Ok(ok) => return Ok(ok),
136                Err(err) => error!("retrying operation after {err:?}"),
137            }
138            if let Some(delay) = args.delay {
139                sleep(delay).await;
140            }
141        }
142        Err(RetryError::RetriesExceeded)
143    };
144
145    if let Some(timeout) = args.timeout {
146        tokio::time::timeout(timeout, task)
147            .await
148            .map_err(|_| RetryError::TimeoutExceeded)?
149    } else {
150        task.await
151    }
152}
153
/// Limits applied by [`retry`].
///
/// Every field is optional; [`None`] switches the corresponding limit off.
/// The `#[default(..)]` attributes define what [`Default::default`] yields.
#[derive(Debug, Clone, Copy, smart_default::SmartDefault)]
pub struct RetryArgs {
    /// Upper bound on the duration of the whole retry operation.
    /// Defaults to one second.
    #[default(Some(Duration::from_secs(1)))]
    pub timeout: Option<Duration>,
    /// Upper bound on the number of attempts; [`None`] is treated as
    /// `usize::MAX` by [`retry`]. Defaults to 5.
    #[default(Some(5))]
    pub max_retries: Option<usize>,
    /// Pause inserted between attempts. Defaults to 200 milliseconds.
    #[default(Some(Duration::from_millis(200)))]
    pub delay: Option<Duration>,
}
163
/// Reasons for which [`retry`] gives up.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum RetryError {
    /// The timeout configured in [`RetryArgs::timeout`] elapsed before any
    /// attempt succeeded.
    #[error("operation timed out")]
    TimeoutExceeded,
    /// Every attempt allowed by [`RetryArgs::max_retries`] failed.
    #[error("retry limit exceeded")]
    RetriesExceeded,
}
171
/// Reports whether the crate is being compiled with debug assertions enabled.
///
/// Only compiled for tests.
#[cfg(test)]
#[allow(dead_code)]
pub fn is_debug_build() -> bool {
    cfg!(debug_assertions)
}
177
/// Reports whether the `CI` environment variable holds a truthy value, as
/// judged by `misc::env::is_env_truthy`.
///
/// `CI` is one of the default environment variables documented for GitHub
/// Actions:
/// <https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables>
///
/// Only compiled for tests.
#[cfg(test)]
#[allow(dead_code)]
pub fn is_ci() -> bool {
    misc::env::is_env_truthy("CI")
}
184
#[cfg(test)]
mod tests {
    mod files;

    use super::*;

    use RetryError::{RetriesExceeded, TimeoutExceeded};
    use futures::future::pending;
    use std::{future::ready, sync::atomic::AtomicUsize};

    impl RetryArgs {
        /// Test-only constructor taking the timeout and the delay in
        /// milliseconds; `None` leaves the corresponding limit unset.
        fn new_ms(
            timeout: impl Into<Option<u64>>,
            max_retries: impl Into<Option<usize>>,
            delay: impl Into<Option<u64>>,
        ) -> Self {
            let to_duration = |ms: Option<u64>| ms.map(Duration::from_millis);
            Self {
                timeout: to_duration(timeout.into()),
                max_retries: max_retries.into(),
                delay: to_duration(delay.into()),
            }
        }
    }

    /// A future that never completes must be cut off by the timeout.
    #[tokio::test]
    async fn timeout() {
        let args = RetryArgs::new_ms(1, None, None);
        let outcome = retry(args, pending::<Result<(), ()>>).await;
        assert_eq!(outcome, Err(TimeoutExceeded));
    }

    /// A single permitted attempt that fails exhausts the retry limit.
    #[tokio::test]
    async fn retries() {
        let args = RetryArgs::new_ms(None, 1, None);
        let outcome = retry(args, || ready(Err::<(), _>(()))).await;
        assert_eq!(outcome, Err(RetriesExceeded));
    }

    /// An operation that succeeds straight away is returned as-is.
    #[tokio::test]
    async fn ok() {
        let outcome = retry(RetryArgs::default(), || ready(Ok::<_, ()>(()))).await;
        assert_eq!(outcome, Ok(()));
    }

    /// An operation that fails several times before succeeding is retried
    /// until it succeeds when no limit is configured.
    #[tokio::test]
    async fn needs_retry() {
        use std::sync::atomic::Ordering::SeqCst;

        let attempts = AtomicUsize::new(0);
        let args = RetryArgs::new_ms(None, None, None);
        let outcome = retry(args, || async {
            // `fetch_add` returns the value before the increment.
            if attempts.fetch_add(1, SeqCst) > 5 {
                Ok(())
            } else {
                Err(())
            }
        })
        .await;
        assert_eq!(outcome, Ok(()));
        assert!(attempts.load(SeqCst) > 5);
    }
}