Skip to main content

forest/utils/
mod.rs

1// Copyright 2019-2026 ChainSafe Systems
2// SPDX-License-Identifier: Apache-2.0, MIT
3
4pub mod cache;
5pub mod cid;
6pub mod db;
7pub mod encoding;
8pub mod flume;
9pub mod get_size;
10pub mod io;
11pub mod misc;
12pub mod multihash;
13pub mod net;
14pub mod p2p;
15pub mod proofs_api;
16pub mod rand;
17pub mod reqwest_resume;
18#[cfg(feature = "sqlite")]
19pub mod sqlite;
20pub mod stats;
21pub mod stream;
22pub mod version;
23
24use anyhow::{Context as _, bail};
25use futures::Future;
26use multiaddr::{Multiaddr, Protocol};
27use std::{str::FromStr, time::Duration};
28use tokio::time::sleep;
29use tracing::error;
30use url::Url;
31
32/// `"hunter2:/ip4/127.0.0.1/wss" -> "wss://:hunter2@127.0.0.1/"`
33#[derive(Clone, Debug)]
34pub struct UrlFromMultiAddr(pub Url);
35
36impl FromStr for UrlFromMultiAddr {
37    type Err = anyhow::Error;
38
39    fn from_str(s: &str) -> Result<Self, Self::Err> {
40        let (p, s) = match s.split_once(':') {
41            Some((first, rest)) => (Some(first), rest),
42            None => (None, s),
43        };
44        let m = Multiaddr::from_str(s).context("invalid multiaddr")?;
45        let mut u = multiaddr2url(&m).context("unsupported multiaddr")?;
46        if u.set_password(p).is_err() {
47            bail!("unsupported password")
48        }
49        Ok(Self(u))
50    }
51}
52
53/// `"/dns/example.com/tcp/8080/http" -> "http://example.com:8080/"`
54///
55/// Returns [`None`] on unsupported formats, or if there is a URL parsing error.
56///
57/// Note that [`Multiaddr`]s do NOT support a (URL) `path`, so that must be handled
58/// out-of-band.
59fn multiaddr2url(m: &Multiaddr) -> Option<Url> {
60    let mut components = m.iter().peekable();
61    let host = match components.next()? {
62        Protocol::Dns(it) | Protocol::Dns4(it) | Protocol::Dns6(it) | Protocol::Dnsaddr(it) => {
63            it.to_string()
64        }
65        Protocol::Ip4(it) => it.to_string(),
66        Protocol::Ip6(it) => it.to_string(),
67        _ => return None,
68    };
69    let port = components
70        .next_if(|it| matches!(it, Protocol::Tcp(_)))
71        .map(|it| match it {
72            Protocol::Tcp(port) => port,
73            _ => unreachable!(),
74        });
75    // ENHANCEMENT: could recognise `Tcp/443/Tls` as `https`
76    let scheme = match components.next()? {
77        Protocol::Http => "http",
78        Protocol::Https => "https",
79        Protocol::Ws(it) if it == "/" => "ws",
80        Protocol::Wss(it) if it == "/" => "wss",
81        _ => return None,
82    };
83    let None = components.next() else { return None };
84    let parse_me = match port {
85        Some(port) => format!("{scheme}://{host}:{port}"),
86        None => format!("{scheme}://{host}"),
87    };
88    parse_me.parse().ok()
89}
90
/// Checks [`UrlFromMultiAddr`] parsing, both bare and password-prefixed.
#[test]
fn test_url_from_multiaddr() {
    let cases = [
        // without password
        ("/dns/example.com/http", "http://example.com/"),
        ("/dns/example.com/tcp/8080/http", "http://example.com:8080/"),
        ("/dns/example.com/tcp/8081/ws", "ws://example.com:8081/"),
        ("/ip4/127.0.0.1/wss", "wss://127.0.0.1/"),
        // with password
        (
            "hunter2:/dns/example.com/http",
            "http://:hunter2@example.com/",
        ),
        (
            "hunter2:/dns/example.com/tcp/8080/http",
            "http://:hunter2@example.com:8080/",
        ),
        ("hunter2:/ip4/127.0.0.1/wss", "wss://:hunter2@127.0.0.1/"),
    ];
    for (input, expected) in cases {
        let UrlFromMultiAddr(url) = input.parse().unwrap();
        assert_eq!(url.as_str(), expected, "input: {input}");
    }
}
114
115/// Keep running the future created by `make_fut` until the timeout or retry
116/// limit in `args` is reached.
117/// `F` _must_ be cancel safe.
118#[tracing::instrument(skip_all)]
119pub async fn retry<F, T, E>(
120    args: RetryArgs,
121    mut make_fut: impl FnMut() -> F,
122) -> Result<T, RetryError>
123where
124    F: Future<Output = Result<T, E>>,
125    E: std::fmt::Debug,
126{
127    let max_retries = args.max_retries.unwrap_or(usize::MAX);
128    let task = async {
129        for _ in 0..max_retries {
130            match make_fut().await {
131                Ok(ok) => return Ok(ok),
132                Err(err) => error!("retrying operation after {err:?}"),
133            }
134            if let Some(delay) = args.delay {
135                sleep(delay).await;
136            }
137        }
138        Err(RetryError::RetriesExceeded)
139    };
140
141    if let Some(timeout) = args.timeout {
142        tokio::time::timeout(timeout, task)
143            .await
144            .map_err(|_| RetryError::TimeoutExceeded)?
145    } else {
146        task.await
147    }
148}
149
150#[derive(Debug, Clone, Copy, smart_default::SmartDefault)]
151pub struct RetryArgs {
152    #[default(Some(Duration::from_secs(1)))]
153    pub timeout: Option<Duration>,
154    #[default(Some(5))]
155    pub max_retries: Option<usize>,
156    #[default(Some(Duration::from_millis(200)))]
157    pub delay: Option<Duration>,
158}
159
160#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
161pub enum RetryError {
162    #[error("operation timed out")]
163    TimeoutExceeded,
164    #[error("retry limit exceeded")]
165    RetriesExceeded,
166}
167
/// Reports whether this test build has debug assertions enabled.
// `allow(dead_code)`: presumably not every test build calls this helper.
#[cfg(test)]
#[allow(dead_code)]
pub fn is_debug_build() -> bool {
    cfg!(debug_assertions)
}
173
/// Reports whether the tests appear to be running in CI, judged by the `CI`
/// environment variable as interpreted by `misc::env::is_env_truthy`.
// `allow(dead_code)`: presumably not every test build calls this helper.
#[cfg(test)]
#[allow(dead_code)]
pub fn is_ci() -> bool {
    // GitHub Actions sets `CI` by default, see:
    // https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables
    self::misc::env::is_env_truthy("CI")
}
180
#[cfg(test)]
mod tests {
    mod files;

    use RetryError::{RetriesExceeded, TimeoutExceeded};
    use futures::future::pending;
    use std::{future::ready, sync::atomic::AtomicUsize};

    use super::*;

    impl RetryArgs {
        /// Test-only constructor: durations are given in milliseconds, and
        /// `None` disables the corresponding limit.
        fn new_ms(
            timeout: impl Into<Option<u64>>,
            max_retries: impl Into<Option<usize>>,
            delay: impl Into<Option<u64>>,
        ) -> Self {
            let millis = |ms: Option<u64>| ms.map(Duration::from_millis);
            Self {
                timeout: millis(timeout.into()),
                max_retries: max_retries.into(),
                delay: millis(delay.into()),
            }
        }
    }

    /// A never-completing operation is cut off by the timeout.
    #[tokio::test]
    async fn timeout() {
        let args = RetryArgs::new_ms(1, None, None);
        let outcome = retry(args, pending::<Result<(), ()>>).await;
        assert_eq!(outcome, Err(TimeoutExceeded));
    }

    /// An always-failing operation exhausts the attempt limit.
    #[tokio::test]
    async fn retries() {
        let always_fail = || ready(Err::<(), _>(()));
        let outcome = retry(RetryArgs::new_ms(None, 1, None), always_fail).await;
        assert_eq!(outcome, Err(RetriesExceeded));
    }

    /// A succeeding operation returns its value.
    #[tokio::test]
    async fn ok() {
        let outcome = retry(RetryArgs::default(), || ready(Ok::<_, ()>(()))).await;
        assert_eq!(outcome, Ok(()));
    }

    /// An operation that fails a few times before succeeding is retried.
    #[tokio::test]
    async fn needs_retry() {
        use std::sync::atomic::Ordering::SeqCst;
        let attempts = AtomicUsize::new(0);
        let outcome = retry(RetryArgs::new_ms(None, None, None), || async {
            if attempts.fetch_add(1, SeqCst) > 5 {
                Ok(())
            } else {
                Err(())
            }
        })
        .await;
        assert_eq!(outcome, Ok(()));
        assert!(attempts.load(SeqCst) > 5);
    }
}