geph4_aioutils/
lib.rs

1use std::{pin::Pin, time::Duration};
2
3use serde::{de::DeserializeOwned, Serialize};
4use smol::{channel::Receiver, prelude::*};
5
6mod dns;
7pub use dns::*;
8
9/// Race two different futures, returning the first non-Err, or an Err if both branches error.
10pub async fn try_race<T, E, F1, F2>(future1: F1, future2: F2) -> Result<T, E>
11where
12    F1: Future<Output = Result<T, E>>,
13    F2: Future<Output = Result<T, E>>,
14{
15    let (send_err, recv_err) = smol::channel::bounded(2);
16    // success future, always returns a success.
17    let success = smol::future::race(
18        async {
19            match future1.await {
20                Ok(v) => v,
21                Err(e) => {
22                    drop(send_err.try_send(e));
23                    smol::future::pending().await
24                }
25            }
26        },
27        async {
28            match future2.await {
29                Ok(v) => v,
30                Err(e) => {
31                    drop(send_err.try_send(e));
32                    smol::future::pending().await
33                }
34            }
35        },
36    );
37    // fail future. waits for two failures.
38    let fail = async {
39        if recv_err.recv().await.is_ok() {
40            if let Ok(err) = recv_err.recv().await {
41                err
42            } else {
43                smol::future::pending().await
44            }
45        } else {
46            smol::future::pending().await
47        }
48    };
49    // race success and future
50    async { Ok(success.await) }
51        .or(async { Err(fail.await) })
52        .await
53}
54
55/// Reads a bincode-deserializable value with a 16bbe length
56pub async fn read_pascalish<T: DeserializeOwned>(
57    reader: &mut (impl AsyncRead + Unpin),
58) -> anyhow::Result<T> {
59    // first read 2 bytes as length
60    let mut len_bts = [0u8; 2];
61    reader.read_exact(&mut len_bts).await?;
62    let len = u16::from_be_bytes(len_bts);
63    // then read len
64    let mut true_buf = vec![0u8; len as usize];
65    reader.read_exact(&mut true_buf).await?;
66    // then deserialize
67    Ok(bincode::deserialize(&true_buf)?)
68}
69
70/// Writes a bincode-serializable value with a 16bbe length
71pub async fn write_pascalish<T: Serialize>(
72    writer: &mut (impl AsyncWrite + Unpin),
73    value: &T,
74) -> anyhow::Result<()> {
75    let serialized = bincode::serialize(value).unwrap();
76    assert!(serialized.len() <= 65535);
77    // write bytes
78    writer
79        .write_all(&(serialized.len() as u16).to_be_bytes())
80        .await?;
81    writer.write_all(&serialized).await?;
82    Ok(())
83}
84
/// Idle deadline (one hour) used by the copy helpers in this file.
const IDLE_TIMEOUT: Duration = Duration::from_secs(60 * 60);
86
87/// Copies an AsyncRead to an AsyncWrite, with a callback for every write.
88#[inline]
89pub async fn copy_with_stats(
90    reader: impl AsyncRead + Unpin,
91    writer: impl AsyncWrite + Unpin,
92    mut on_write: impl FnMut(usize),
93) -> std::io::Result<()> {
94    copy_with_stats_async(reader, writer, move |n| {
95        on_write(n);
96        async {}
97    })
98    .await
99}
100
101/// Copies an AsyncRead to an AsyncWrite, with an async callback for every write.
102#[inline]
103pub async fn copy_with_stats_async<F: Future<Output = ()>>(
104    mut reader: impl AsyncRead + Unpin,
105    mut writer: impl AsyncWrite + Unpin,
106    mut on_write: impl FnMut(usize) -> F,
107) -> std::io::Result<()> {
108    let mut buffer = [0u8; 32768];
109    let mut timeout = smol::Timer::after(IDLE_TIMEOUT);
110    loop {
111        // first read into the small buffer
112        let n = reader
113            .read(&mut buffer)
114            .or(async {
115                (&mut timeout).await;
116                Err(std::io::Error::new(
117                    std::io::ErrorKind::TimedOut,
118                    "copy_with_stats timeout",
119                ))
120            })
121            .await?;
122        if n == 0 {
123            return Ok(());
124        }
125        timeout.set_after(IDLE_TIMEOUT);
126        writer
127            .write_all(&buffer[..n])
128            .or(async {
129                (&mut timeout).await;
130                Err(std::io::Error::new(
131                    std::io::ErrorKind::TimedOut,
132                    "copy_with_stats timeout",
133                ))
134            })
135            .await?;
136        on_write(n).await;
137    }
138}
139
140// /// Copies an Read to an Write, with a callback for every write.
141// pub fn copy_with_stats_sync(
142//     mut reader: impl std::io::Read,
143//     mut writer: impl std::io::Write,
144//     mut on_write: impl FnMut(usize),
145// ) -> std::io::Result<()> {
146//     let mut buffer = [0u8; 32768];
147//     loop {
148//         // first read into the small buffer
149//         let n = reader.read(&mut buffer)?;
150//         if n == 0 {
151//             return Ok(());
152//         }
153//         on_write(n);
154//         writer.write_all(&buffer[..n])?;
155//     }
156// }
157
158pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
159
160impl<T: AsyncRead + AsyncWrite> AsyncReadWrite for T {}
161
162pub type ConnLike = async_dup::Arc<async_dup::Mutex<Pin<Box<dyn AsyncReadWrite + 'static + Send>>>>;
163
164pub fn connify<T: AsyncRead + AsyncWrite + 'static + Send>(conn: T) -> ConnLike {
165    async_dup::Arc::new(async_dup::Mutex::new(Box::pin(conn)))
166}
167
/// Wraps any boxable error value into a [`std::io::Error`] of kind
/// [`std::io::ErrorKind::Other`], keeping the original as the inner error.
pub fn to_ioerror<T>(e: T) -> std::io::Error
where
    T: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let inner: Box<dyn std::error::Error + Send + Sync> = e.into();
    std::io::Error::new(std::io::ErrorKind::Other, inner)
}
171
172/// Reads from an async_channel::Receiver, but returns a vector of all available items instead of just one to save on context-switching.
173pub async fn recv_chan_many<T>(ch: Receiver<T>) -> Result<Vec<T>, smol::channel::RecvError> {
174    let mut toret = vec![ch.recv().await?];
175    // push as many as possible
176    while let Ok(val) = ch.try_recv() {
177        toret.push(val);
178    }
179    Ok(toret)
180}