ferroid/runtime/
tokio.rs

1use crate::{Result, SleepProvider, Snowflake, SnowflakeGenerator, TimeSource};
2
3/// Extension trait for asynchronously generating Snowflake IDs using the
4/// [`tokio`](https://docs.rs/tokio) async runtime.
5///
6/// This trait provides a convenience method for using a [`SleepProvider`]
7/// backed by the `tokio` runtime, allowing you to call `.try_next_id_async()`
8/// without specifying the sleep strategy manually.
9pub trait SnowflakeGeneratorAsyncTokioExt<ID, T>
10where
11    ID: Snowflake,
12    T: TimeSource<ID::Ty>,
13{
14    /// Returns a future that resolves to the next available Snowflake ID using
15    /// the [`TokioSleep`] provider.
16    ///
17    /// Internally delegates to
18    /// [`SnowflakeGeneratorAsyncExt::try_next_id_async`] method with
19    /// [`TokioSleep`] as the sleep strategy.
20    ///
21    /// # Errors
22    ///
23    /// This future may return an error if the underlying generator does.
24    ///
25    /// [`SnowflakeGeneratorAsyncExt::try_next_id_async`]:
26    ///     crate::SnowflakeGeneratorAsyncExt::try_next_id_async
27    fn try_next_id_async(&self) -> impl Future<Output = Result<ID>>;
28}
29
30impl<G, ID, T> SnowflakeGeneratorAsyncTokioExt<ID, T> for G
31where
32    G: SnowflakeGenerator<ID, T>,
33    ID: Snowflake,
34    T: TimeSource<ID::Ty>,
35{
36    fn try_next_id_async(&self) -> impl Future<Output = Result<ID>> {
37        <Self as crate::SnowflakeGeneratorAsyncExt<ID, T>>::try_next_id_async::<TokioSleep>(self)
38    }
39}
40
41/// An implementation of [`SleepProvider`] using Tokio's timer.
42///
43/// This is the default provider for use in async applications built on Tokio.
44pub struct TokioSleep;
45impl SleepProvider for TokioSleep {
46    type Sleep = tokio::time::Sleep;
47
48    fn sleep_for(dur: tokio::time::Duration) -> Self::Sleep {
49        tokio::time::sleep(dur)
50    }
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        AtomicSnowflakeGenerator, LockSnowflakeGenerator, MonotonicClock, Result, Snowflake,
        SnowflakeGenerator, SnowflakeTwitterId, TimeSource,
    };
    use core::fmt;
    use futures::future::try_join_all;
    use std::collections::HashSet;

    /// Base count from which the per-generator workload is derived.
    const TOTAL_IDS: usize = 4096;
    /// Number of generators (one spawned task each); machine IDs are
    /// `0..NUM_GENERATORS`.
    const NUM_GENERATORS: u64 = 32;
    /// Number of IDs every generator is asked to produce.
    const IDS_PER_GENERATOR: usize = TOTAL_IDS * 32; // Enough to simulate at least 32 Pending cycles

    /// Uniqueness check against the lock-based generator.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn generates_many_unique_ids_lock() -> Result<()> {
        let run = test_many_unique_ids::<_, SnowflakeTwitterId, MonotonicClock>(
            LockSnowflakeGenerator::new,
            MonotonicClock::default,
        );
        run.await
    }

    /// Uniqueness check against the atomic generator.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn generates_many_unique_ids_atomic() -> Result<()> {
        let run = test_many_unique_ids::<_, SnowflakeTwitterId, MonotonicClock>(
            AtomicSnowflakeGenerator::new,
            MonotonicClock::default,
        );
        run.await
    }

    /// Builds `NUM_GENERATORS` generators that share one clock, drives each on
    /// its own Tokio task for `IDS_PER_GENERATOR` IDs, and asserts that the
    /// combined output has the expected size and contains no duplicates.
    ///
    /// `generator_fn` receives the machine ID and a clone of the clock;
    /// `clock_factory` is invoked exactly once.
    async fn test_many_unique_ids<G, ID, T>(
        generator_fn: impl Fn(u64, T) -> G,
        clock_factory: impl Fn() -> T,
    ) -> Result<()>
    where
        G: SnowflakeGenerator<ID, T> + Send + Sync + 'static,
        ID: Snowflake + fmt::Debug + Send + 'static,
        T: TimeSource<ID::Ty> + Clone + Send,
    {
        let shared_clock = clock_factory();

        // Construct every generator up front, before any task is spawned.
        let mut generators = Vec::new();
        for machine_id in 0..NUM_GENERATORS {
            generators.push(generator_fn(machine_id, shared_clock.clone()));
        }

        // One task per generator; each task owns its generator and returns
        // the batch of IDs it produced.
        let mut handles: Vec<tokio::task::JoinHandle<Result<Vec<ID>>>> = Vec::new();
        for generator in generators {
            handles.push(tokio::spawn(async move {
                let mut batch = Vec::with_capacity(IDS_PER_GENERATOR);
                for _ in 0..IDS_PER_GENERATOR {
                    batch.push(generator.try_next_id_async().await?);
                }
                Ok(batch)
            }));
        }

        let expected_total = NUM_GENERATORS as usize * IDS_PER_GENERATOR;

        // A join failure is propagated with `?`; a generator error inside a
        // task panics via `unwrap`, failing the test.
        let mut all_ids = Vec::new();
        for outcome in try_join_all(handles).await? {
            all_ids.extend(outcome.unwrap());
        }

        assert_eq!(
            all_ids.len(),
            expected_total,
            "Expected {} IDs but got {}",
            expected_total,
            all_ids.len()
        );

        // `insert` returns `false` for a value that is already present.
        let mut unique = HashSet::with_capacity(all_ids.len());
        all_ids.iter().for_each(|id| {
            assert!(unique.insert(id), "Duplicate ID found: {:?}", id);
        });

        Ok(())
    }
}