//! Abstraction over the async runtimes (`tokio` / `async-std`) selected through Cargo features.
mod acknowledged_message;
mod async_read_ext;
mod async_write_ext;
mod http;
#[cfg(feature = "async-std-runtime")]
mod interval;
mod join_handle;
mod resolver;
mod stream;

use std::{future::Future, net::SocketAddr, time::Duration};

pub(crate) use self::{
    acknowledged_message::AcknowledgedMessage,
    async_read_ext::AsyncLittleEndianRead,
    async_write_ext::AsyncLittleEndianWrite,
    join_handle::AsyncJoinHandle,
    resolver::AsyncResolver,
    stream::AsyncStream,
};
use crate::{
    error::{ErrorKind, Result},
    options::StreamAddress,
};
pub(crate) use http::HttpClient;
#[cfg(feature = "async-std-runtime")]
use interval::Interval;
#[cfg(feature = "tokio-runtime")]
use tokio::time::Interval;

/// A lightweight, copyable token identifying which async runtime to dispatch to.
///
/// Each variant only exists when the matching Cargo feature is enabled.
#[derive(Debug, Clone, Copy)]
pub(crate) enum AsyncRuntime {
    /// Dispatch to `tokio` (requires the `tokio-runtime` feature).
    #[cfg(feature = "tokio-runtime")]
    Tokio,

    /// Dispatch to `async-std` (requires the `async-std-runtime` feature).
    #[cfg(feature = "async-std-runtime")]
    AsyncStd,
}

impl AsyncRuntime {
    /// Start `fut` as a background task on this runtime.
    ///
    /// Returns `Some` with a join handle for the new task. With tokio, `None` is returned
    /// instead when `TokioCallingContext::current()` reports a synchronous context, so this
    /// should be invoked from an async block or function that is running on a runtime.
    // The async-std arm always yields `Some`, so the `Option` wrapper can look redundant to
    // clippy depending on which runtime feature is enabled.
    #[allow(clippy::unnecessary_wraps)]
    pub(crate) fn spawn<F, O>(self, fut: F) -> Option<AsyncJoinHandle<O>>
    where
        F: Future<Output = O> + Send + 'static,
        O: Send + 'static,
    {
        match self {
            #[cfg(feature = "tokio-runtime")]
            Self::Tokio => {
                if let TokioCallingContext::Async(rt) = TokioCallingContext::current() {
                    Some(AsyncJoinHandle::Tokio(rt.spawn(fut)))
                } else {
                    None
                }
            }

            #[cfg(feature = "async-std-runtime")]
            Self::AsyncStd => {
                let task = async_std::task::spawn(fut);
                Some(AsyncJoinHandle::AsyncStd(task))
            }
        }
    }

    /// Start `fut` as a background task without keeping a handle to it.
    ///
    /// Note: this must only be called from an async block or function running on a runtime.
    pub(crate) fn execute<F, O>(self, fut: F)
    where
        F: Future<Output = O> + Send + 'static,
        O: Send + 'static,
    {
        // The join handle (if any) is deliberately discarded; nothing waits on the task here.
        let _ = self.spawn(fut);
    }

    /// Drive `fut` to completion on the calling thread and return its output.
    ///
    /// With tokio (and async-std disabled) the future is run by `futures::executor::block_on`
    /// inside `tokio::task::block_in_place`. With async-std it is handed to
    /// `async_std::task::block_on`.
    ///
    /// # Panics
    ///
    /// Panics if tokio is being used and this is called from a synchronous context.
    #[cfg(any(feature = "sync", test))]
    pub(crate) fn block_on<F, T>(self, fut: F) -> T
    where
        F: Future<Output = T> + Send,
        T: Send,
    {
        #[cfg(all(feature = "tokio-runtime", not(feature = "async-std-runtime")))]
        {
            match TokioCallingContext::current() {
                TokioCallingContext::Sync => {
                    panic!("block_on called from tokio outside of async context")
                }
                TokioCallingContext::Async(_rt) => {
                    tokio::task::block_in_place(|| futures::executor::block_on(fut))
                }
            }
        }

        #[cfg(feature = "async-std-runtime")]
        {
            async_std::task::block_on(fut)
        }
    }

    /// Drive `fut` to completion on the calling thread via `futures::executor::block_on`.
    ///
    /// The runtime is not told that the thread is about to block, so this should only be used
    /// for operations that will complete immediately (or quickly).
    pub(crate) fn block_in_place<F, T>(self, fut: F) -> T
    where
        F: Future<Output = T> + Send,
        T: Send,
    {
        futures::executor::block_on(fut)
    }

    /// Asynchronously wait until `delay` has elapsed, using the enabled runtime's timer.
    pub(crate) async fn delay_for(self, delay: Duration) {
        #[cfg(feature = "tokio-runtime")]
        {
            tokio::time::delay_for(delay).await
        }

        #[cfg(feature = "async-std-runtime")]
        {
            async_std::task::sleep(delay).await
        }
    }

    /// Await `future`, giving up once `timeout` has passed.
    ///
    /// On success the future's output is returned in `Ok`. If the runtime's timeout fires
    /// first, the result is an `ErrorKind::Io` error: built from tokio's timeout error under
    /// tokio, and from `std::io::ErrorKind::TimedOut` under async-std.
    pub(crate) async fn timeout<F: Future>(
        self,
        timeout: Duration,
        future: F,
    ) -> Result<F::Output> {
        #[cfg(feature = "tokio-runtime")]
        {
            let outcome = tokio::time::timeout(timeout, future).await;
            outcome.map_err(|elapsed| ErrorKind::Io(elapsed.into()).into())
        }

        #[cfg(feature = "async-std-runtime")]
        {
            let outcome = async_std::future::timeout(timeout, future).await;
            outcome.map_err(|_| ErrorKind::Io(std::io::ErrorKind::TimedOut.into()).into())
        }
    }

    /// Build an `Interval` with a period of `duration` for the selected runtime.
    /// See: https://docs.rs/tokio/latest/tokio/time/fn.interval.html
    pub(crate) fn interval(self, duration: Duration) -> Interval {
        match self {
            #[cfg(feature = "tokio-runtime")]
            Self::Tokio => tokio::time::interval(duration),

            #[cfg(feature = "async-std-runtime")]
            Self::AsyncStd => Interval::new(duration),
        }
    }

    /// Resolve `address` into the socket addresses it maps to, using the selected runtime.
    ///
    /// Under tokio the address is formatted with its `Display` impl and passed to
    /// `tokio::net::lookup_host`. Under async-std the hostname and port are resolved through
    /// `async_std::net::ToSocketAddrs`, with port 27017 used when the address carries none.
    /// A failed lookup is propagated as an error.
    pub(crate) async fn resolve_address(
        self,
        address: &StreamAddress,
    ) -> Result<impl Iterator<Item = SocketAddr>> {
        match self {
            #[cfg(feature = "tokio-runtime")]
            Self::Tokio => {
                let resolved = tokio::net::lookup_host(format!("{}", address)).await?;
                Ok(resolved)
            }

            #[cfg(feature = "async-std-runtime")]
            Self::AsyncStd => {
                let port = address.port.unwrap_or(27017);
                let target = (address.hostname.as_str(), port);
                let resolved = async_std::net::ToSocketAddrs::to_socket_addrs(&target).await?;
                Ok(resolved)
            }
        }
    }
}

/// Describes where a runtime method is being invoked from when tokio is in use.
#[cfg(feature = "tokio-runtime")]
enum TokioCallingContext {
    /// Called from an asynchronous setting (i.e. from an async block or function being run on
    /// a runtime). Carries a handle to the current runtime.
    Async(tokio::runtime::Handle),

    /// Called from a synchronous setting (i.e. not from a runtime thread).
    Sync,
}

#[cfg(feature = "tokio-runtime")]
impl TokioCallingContext {
    /// Determine the calling context of the current thread.
    ///
    /// Yields `Async` with the runtime handle when `tokio::runtime::Handle::try_current`
    /// succeeds, and `Sync` when it returns an error (the error itself is discarded).
    fn current() -> Self {
        tokio::runtime::Handle::try_current()
            .map(TokioCallingContext::Async)
            .unwrap_or(TokioCallingContext::Sync)
    }
}