awaitility/
least.rs

1use super::backend::Backend;
2use std::{future::Future, time::{Duration, Instant}};
3
4pub struct LeastWait<'a> {
5    duration: Duration,
6    backend: Backend<'a>,
7}
8
9pub fn at_least(duration: Duration) -> LeastWait<'static> {
10    LeastWait {
11        duration,
12        backend: Backend::default(),
13    }
14}
15
16pub fn at_least_backend<'a>(duration: Duration, backend: Backend<'a>) -> LeastWait<'a> {
17    LeastWait {
18        duration,
19        backend,
20    }
21}
22
23impl<'a> LeastWait<'a> {
24    pub fn poll_interval(&mut self, interval: Duration) -> &mut Self {
25        self.backend.set_interval(interval);
26        self
27    }
28
29    pub fn describe<'b: 'a>(&mut self, desc: &'b str) -> &mut Self {
30        self.backend.set_description(desc);
31        self
32    }
33
34    pub fn always(&mut self, f: impl Fn() -> bool) -> &Self {
35        let now = Instant::now();
36        loop {
37            let elapsed = now.elapsed();
38            if elapsed > self.duration {
39                break;
40            }
41            if !f() {
42                let desc = format!("Condition failed before duration {:?} elapsed.", elapsed);
43                self.backend.fail(&desc);
44                break;
45            }
46            std::thread::sleep(self.backend.interval);
47        }
48        self
49    }
50
51    pub async fn always_async<Fut>(&mut self, f: impl Fn() -> Fut) -> &Self where Fut: Future<Output = bool> {
52        let now = Instant::now();
53        loop {
54            let elapsed = now.elapsed();
55            if elapsed > self.duration {
56                break;
57            }
58            if !f().await {
59                let desc = format!("Condition failed before duration {:?} elapsed.", elapsed);
60                self.backend.fail(&desc);
61                break;
62            }
63            std::thread::sleep(self.backend.interval);
64        }
65        self
66    }
67
68    pub fn once(&mut self, f: impl Fn() -> bool) -> &Self {
69        let now = Instant::now();
70        loop {
71            let elapsed = now.elapsed();
72            if elapsed > self.duration {
73                let desc = format!("Condition failed before duration {:?} elapsed.", elapsed);
74                self.backend.fail(&desc);
75                break;
76            }
77            if f() {
78                break;
79            }
80            std::thread::sleep(self.backend.interval);
81        }
82        self
83    }
84
85    pub async fn once_async<Fut>(&mut self, f: impl Fn() -> Fut) -> &Self where Fut: Future<Output = bool> {
86        let now = Instant::now();
87        loop {
88            let elapsed = now.elapsed();
89            if elapsed > self.duration {
90                let desc = format!("Condition failed before duration {:?} elapsed.", elapsed);
91                self.backend.fail(&desc);
92                break;
93            }
94            if f().await {
95                break;
96            }
97            std::thread::sleep(self.backend.interval);
98        }
99        self
100    }
101}
102
#[cfg(test)]
mod least_test {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Spawns a thread that waits for `delay` and then raises `counter` to 15.
    fn raise_counter_after(counter: &Arc<AtomicUsize>, delay: Duration) {
        let counter = Arc::clone(counter);
        std::thread::spawn(move || {
            std::thread::sleep(delay);
            while counter.load(Ordering::SeqCst) < 15 {
                counter.fetch_add(1, Ordering::SeqCst);
            }
        });
    }

    #[test]
    fn at_least_test() {
        let counter = Arc::new(AtomicUsize::new(5));
        // The counter only moves after the 100 ms window has closed.
        raise_counter_after(&counter, Duration::from_millis(150));
        super::at_least(Duration::from_millis(100))
            .always(|| counter.load(Ordering::SeqCst) < 10);
    }

    #[test]
    #[should_panic]
    fn at_least_panic() {
        let counter = Arc::new(AtomicUsize::new(5));
        raise_counter_after(&counter, Duration::ZERO);
        super::at_least(Duration::from_millis(100))
            .always(|| counter.load(Ordering::SeqCst) < 10);
    }

    #[tokio::test]
    async fn at_least_async_fn() {
        let counter = Arc::new(AtomicUsize::new(5));
        raise_counter_after(&counter, Duration::from_millis(150));
        super::at_least(Duration::from_millis(100))
            .always_async(|| async { counter.load(Ordering::SeqCst) < 10 })
            .await;
    }

    #[tokio::test]
    #[should_panic]
    async fn at_least_async_panic() {
        let counter = Arc::new(AtomicUsize::new(5));
        raise_counter_after(&counter, Duration::ZERO);
        super::at_least(Duration::from_millis(100))
            .always_async(|| async { counter.load(Ordering::SeqCst) < 10 })
            .await;
    }

    #[test]
    fn once_test() {
        let counter = Arc::new(AtomicUsize::new(5));
        // Without the delay the helper thread could push the counter past 10
        // before the first poll, which made this test flaky.
        raise_counter_after(&counter, Duration::from_millis(50));
        super::at_least(Duration::from_millis(100))
            .once(|| counter.load(Ordering::SeqCst) < 10);
    }

    #[test]
    #[should_panic]
    fn once_panic() {
        let counter = Arc::new(AtomicUsize::new(5));
        raise_counter_after(&counter, Duration::ZERO);
        super::at_least(Duration::from_millis(100))
            .once(|| counter.load(Ordering::SeqCst) < 5);
    }

    #[tokio::test]
    async fn once_async_fn() {
        let counter = Arc::new(AtomicUsize::new(5));
        // Same race as in `once_test`: let the first poll see the start value.
        raise_counter_after(&counter, Duration::from_millis(50));
        super::at_least(Duration::from_millis(100))
            .once_async(|| async { counter.load(Ordering::SeqCst) < 10 })
            .await;
    }

    #[tokio::test]
    #[should_panic]
    async fn once_async_panic() {
        let counter = Arc::new(AtomicUsize::new(5));
        raise_counter_after(&counter, Duration::ZERO);
        super::at_least(Duration::from_millis(100))
            .once_async(|| async { counter.load(Ordering::SeqCst) < 5 })
            .await;
    }
}