1use std::sync::{
13 Condvar,
14 Mutex,
15 MutexGuard,
16};
17use std::time::Duration;
18
19#[cfg(feature = "async")]
20use tokio::sync::{
21 Notify,
22 watch,
23};
24
25#[cfg(feature = "async")]
26use super::{
27 AsyncConditionWaiter,
28 AsyncMonitorFuture,
29 AsyncNotificationWaiter,
30 AsyncTimeoutConditionWaiter,
31 AsyncTimeoutNotificationWaiter,
32};
33use super::{
34 ConditionWaiter,
35 NotificationWaiter,
36 Notifier,
37 TimeoutConditionWaiter,
38 TimeoutNotificationWaiter,
39 WaitTimeoutResult,
40 WaitTimeoutStatus,
41};
42
/// Test double for a monitor: a mutex-protected value paired with a virtual
/// clock that only moves when `advance`, `set_elapsed` or `reset_elapsed` is
/// called.
pub struct MockMonitor<T> {
    /// Protected value together with the virtual clock and the wake-up epochs.
    state: Mutex<MockMonitorState<T>>,
    /// Single condition variable shared by every blocking waiter; signalled on
    /// notifications and on every clock change.
    changed: Condvar,
    /// Wakes async waiters on `notify_one` / `notify_all`.
    #[cfg(feature = "async")]
    async_notification: Notify,
    /// Broadcasts the latest change epoch so async timeout waiters re-check the
    /// virtual clock (and, for condition waiters, the protected value).
    #[cfg(feature = "async")]
    async_change_sender: watch::Sender<u64>,
}
60
/// Everything guarded by the monitor's mutex.
struct MockMonitorState<T> {
    /// The value protected by the monitor.
    value: T,
    /// Current reading of the virtual clock.
    elapsed: Duration,
    /// Incremented (wrapping) by every `notify_one` / `notify_all`; blocking
    /// notification waiters compare it with the value they saw on entry.
    notification_epoch: u64,
    /// Incremented (wrapping) on every notification and every clock change;
    /// its new value is what gets broadcast to async waiters.
    change_epoch: u64,
}
72
73impl<T> MockMonitor<T> {
74 pub fn new(state: T) -> Self {
84 #[cfg(feature = "async")]
85 let (async_change_sender, _) = watch::channel(0);
86 Self {
87 state: Mutex::new(MockMonitorState {
88 value: state,
89 elapsed: Duration::ZERO,
90 notification_epoch: 0,
91 change_epoch: 0,
92 }),
93 changed: Condvar::new(),
94 #[cfg(feature = "async")]
95 async_notification: Notify::new(),
96 #[cfg(feature = "async")]
97 async_change_sender,
98 }
99 }
100
101 pub fn elapsed(&self) -> Duration {
107 self.lock_state().elapsed
108 }
109
110 pub fn set_elapsed(&self, elapsed: Duration) {
118 let change_epoch = {
119 let mut state = self.lock_state();
120 state.elapsed = elapsed;
121 Self::advance_change_epoch(&mut state)
122 };
123 self.changed.notify_all();
124 self.notify_async_change(change_epoch);
125 }
126
127 pub fn advance(&self, duration: Duration) {
133 let change_epoch = {
134 let mut state = self.lock_state();
135 state.elapsed = state.elapsed.saturating_add(duration);
136 Self::advance_change_epoch(&mut state)
137 };
138 self.changed.notify_all();
139 self.notify_async_change(change_epoch);
140 }
141
142 pub fn reset_elapsed(&self) {
144 self.set_elapsed(Duration::ZERO);
145 }
146
147 pub fn read<R, F>(&self, f: F) -> R
157 where
158 F: FnOnce(&T) -> R,
159 {
160 let state = self.lock_state();
161 f(&state.value)
162 }
163
164 pub fn write<R, F>(&self, f: F) -> R
176 where
177 F: FnOnce(&mut T) -> R,
178 {
179 let mut state = self.lock_state();
180 f(&mut state.value)
181 }
182
183 pub fn write_notify_one<R, F>(&self, f: F) -> R
193 where
194 F: FnOnce(&mut T) -> R,
195 {
196 let result = self.write(f);
197 self.notify_one();
198 result
199 }
200
201 pub fn write_notify_all<R, F>(&self, f: F) -> R
211 where
212 F: FnOnce(&mut T) -> R,
213 {
214 let result = self.write(f);
215 self.notify_all();
216 result
217 }
218
219 pub fn notify_one(&self) {
221 let change_epoch = self.advance_notification_epoch();
222 self.changed.notify_one();
223 #[cfg(feature = "async")]
224 self.async_notification.notify_one();
225 self.notify_async_change(change_epoch);
226 }
227
228 pub fn notify_all(&self) {
230 let change_epoch = self.advance_notification_epoch();
231 self.changed.notify_all();
232 #[cfg(feature = "async")]
233 self.async_notification.notify_waiters();
234 self.notify_async_change(change_epoch);
235 }
236
237 fn lock_state(&self) -> MutexGuard<'_, MockMonitorState<T>> {
243 self.state
244 .lock()
245 .unwrap_or_else(std::sync::PoisonError::into_inner)
246 }
247
248 fn advance_change_epoch(state: &mut MockMonitorState<T>) -> u64 {
258 state.change_epoch = state.change_epoch.wrapping_add(1);
259 state.change_epoch
260 }
261
262 fn advance_notification_epoch(&self) -> u64 {
268 let mut state = self.lock_state();
269 state.notification_epoch = state.notification_epoch.wrapping_add(1);
270 Self::advance_change_epoch(&mut state)
271 }
272
273 #[cfg(feature = "async")]
279 fn notify_async_change(&self, change_epoch: u64) {
280 let _ = self.async_change_sender.send(change_epoch);
281 }
282
283 #[cfg(not(feature = "async"))]
285 fn notify_async_change(&self, _change_epoch: u64) {}
286}
287
288impl<T> Notifier for MockMonitor<T> {
289 fn notify_one(&self) {
291 Self::notify_one(self);
292 }
293
294 fn notify_all(&self) {
296 Self::notify_all(self);
297 }
298}
299
300impl<T> NotificationWaiter for MockMonitor<T> {
301 fn wait(&self) {
303 let mut state = self.lock_state();
304 let observed_epoch = state.notification_epoch;
305 while state.notification_epoch == observed_epoch {
306 state = self
307 .changed
308 .wait(state)
309 .unwrap_or_else(std::sync::PoisonError::into_inner);
310 }
311 }
312}
313
314impl<T> TimeoutNotificationWaiter for MockMonitor<T> {
315 fn wait_for(&self, timeout: Duration) -> WaitTimeoutStatus {
317 let mut state = self.lock_state();
318 let observed_epoch = state.notification_epoch;
319 let target_elapsed = state.elapsed.saturating_add(timeout);
320 loop {
321 if state.notification_epoch != observed_epoch {
322 return WaitTimeoutStatus::Woken;
323 }
324 if state.elapsed >= target_elapsed {
325 return WaitTimeoutStatus::TimedOut;
326 }
327 state = self
328 .changed
329 .wait(state)
330 .unwrap_or_else(std::sync::PoisonError::into_inner);
331 }
332 }
333}
334
335impl<T> ConditionWaiter for MockMonitor<T> {
336 type State = T;
337
338 fn wait_until<R, P, F>(&self, mut predicate: P, action: F) -> R
340 where
341 P: FnMut(&Self::State) -> bool,
342 F: FnOnce(&mut Self::State) -> R,
343 {
344 self.wait_while(|state| !predicate(state), action)
345 }
346
347 fn wait_while<R, P, F>(&self, mut predicate: P, action: F) -> R
349 where
350 P: FnMut(&Self::State) -> bool,
351 F: FnOnce(&mut Self::State) -> R,
352 {
353 let mut state = self.lock_state();
354 while predicate(&state.value) {
355 state = self
356 .changed
357 .wait(state)
358 .unwrap_or_else(std::sync::PoisonError::into_inner);
359 }
360 action(&mut state.value)
361 }
362}
363
364impl<T> TimeoutConditionWaiter for MockMonitor<T> {
365 fn wait_until_for<R, P, F>(
367 &self,
368 timeout: Duration,
369 mut predicate: P,
370 action: F,
371 ) -> WaitTimeoutResult<R>
372 where
373 P: FnMut(&Self::State) -> bool,
374 F: FnOnce(&mut Self::State) -> R,
375 {
376 self.wait_while_for(timeout, |state| !predicate(state), action)
377 }
378
379 fn wait_while_for<R, P, F>(
381 &self,
382 timeout: Duration,
383 mut predicate: P,
384 action: F,
385 ) -> WaitTimeoutResult<R>
386 where
387 P: FnMut(&Self::State) -> bool,
388 F: FnOnce(&mut Self::State) -> R,
389 {
390 let mut state = self.lock_state();
391 let target_elapsed = state.elapsed.saturating_add(timeout);
392 loop {
393 if !predicate(&state.value) {
394 return WaitTimeoutResult::Ready(action(&mut state.value));
395 }
396 if state.elapsed >= target_elapsed {
397 return WaitTimeoutResult::TimedOut;
398 }
399 state = self
400 .changed
401 .wait(state)
402 .unwrap_or_else(std::sync::PoisonError::into_inner);
403 }
404 }
405}
406
#[cfg(feature = "async")]
impl<T: Send> AsyncNotificationWaiter for MockMonitor<T> {
    /// Resolves on the next async notification from `notify_one` /
    /// `notify_all`.
    ///
    /// The `Notified` future is created in this call rather than on first
    /// poll, so a `notify_all` issued in between is still observed. Per tokio's
    /// `Notify` semantics, a permit left by an earlier `notify_one` (issued
    /// while no async waiter was registered) completes it immediately.
    fn async_wait<'a>(&'a self) -> AsyncMonitorFuture<'a, ()> {
        Box::pin(self.async_notification.notified())
    }
}
415
#[cfg(feature = "async")]
impl<T: Send> AsyncTimeoutNotificationWaiter for MockMonitor<T> {
    /// Waits for the next async notification, or until the virtual clock has
    /// advanced by `timeout`, whichever comes first.
    ///
    /// The deadline is measured on the mock's virtual clock. It is fixed when
    /// this method is called and can only be reached through `advance` /
    /// `set_elapsed`, never through real time.
    fn async_wait_for<'a>(
        &'a self,
        timeout: Duration,
    ) -> AsyncMonitorFuture<'a, WaitTimeoutStatus> {
        let target_elapsed = self.elapsed().saturating_add(timeout);
        let mut change_receiver = self.async_change_sender.subscribe();
        // Create the `Notified` future once, up front, and keep it alive for the
        // whole wait. Tokio only delivers `notify_waiters` to `Notified` futures
        // that already exist when it is called. Recreating the future on every
        // loop iteration (and dropping it whenever the clock-change branch won
        // the `select!`) silently lost `notify_all` wake-ups.
        let notified = self.async_notification.notified();
        Box::pin(async move {
            tokio::pin!(notified);
            loop {
                if self.elapsed() >= target_elapsed {
                    return WaitTimeoutStatus::TimedOut;
                }
                tokio::select! {
                    // Poll the notification first, so a wake-up that arrives
                    // together with its change broadcast is reported as `Woken`
                    // instead of being discarded by a random branch choice.
                    biased;
                    () = &mut notified => return WaitTimeoutStatus::Woken,
                    changed = change_receiver.changed() => {
                        changed.expect("mock monitor sender should live while the monitor is borrowed");
                    }
                }
            }
        })
    }
}
441
#[cfg(feature = "async")]
impl<T: Send> AsyncConditionWaiter for MockMonitor<T> {
    type State = T;

    /// Waits until `predicate` holds, then runs `action` on the value while the
    /// lock is held; implemented as `async_wait_while` on the negated predicate.
    fn async_wait_until<'a, R, P, F>(
        &'a self,
        mut predicate: P,
        action: F,
    ) -> AsyncMonitorFuture<'a, R>
    where
        R: Send + 'a,
        P: FnMut(&Self::State) -> bool + Send + 'a,
        F: FnOnce(&mut Self::State) -> R + Send + 'a,
    {
        self.async_wait_while(move |state| !predicate(state), action)
    }

    /// Waits while `predicate` holds, then runs `action` on the value while the
    /// lock is held.
    ///
    /// The predicate is re-checked only after an async notification
    /// (`notify_one` / `notify_all`, including their `write_notify_*`
    /// wrappers). Clock changes and plain `write` calls do not wake this waiter.
    fn async_wait_while<'a, R, P, F>(
        &'a self,
        mut predicate: P,
        action: F,
    ) -> AsyncMonitorFuture<'a, R>
    where
        R: Send + 'a,
        P: FnMut(&Self::State) -> bool + Send + 'a,
        F: FnOnce(&mut Self::State) -> R + Send + 'a,
    {
        Box::pin(async move {
            loop {
                // The std mutex guard lives only inside this block, so it is
                // released before the `.await` below and never held across it.
                let notified = {
                    let mut state = self.lock_state();
                    if !predicate(&state.value) {
                        return action(&mut state.value);
                    }
                    // Created while the lock is still held: a `notify_waiters`
                    // issued after the predicate check reaches this future even
                    // though it has not been polled yet.
                    // NOTE(review): a concurrent `notify_one` may instead go to
                    // another async waiter that is already registered.
                    self.async_notification.notified()
                };
                notified.await;
            }
        })
    }
}
485
#[cfg(feature = "async")]
impl<T: Send> AsyncTimeoutConditionWaiter for MockMonitor<T> {
    /// Waits until `predicate` holds or the virtual-clock deadline passes;
    /// implemented as `async_wait_while_for` on the negated predicate.
    fn async_wait_until_for<'a, R, P, F>(
        &'a self,
        timeout: Duration,
        mut predicate: P,
        action: F,
    ) -> AsyncMonitorFuture<'a, WaitTimeoutResult<R>>
    where
        R: Send + 'a,
        P: FnMut(&Self::State) -> bool + Send + 'a,
        F: FnOnce(&mut Self::State) -> R + Send + 'a,
    {
        self.async_wait_while_for(timeout, move |state| !predicate(state), action)
    }

    /// Waits while `predicate` holds, up to `timeout` of virtual time measured
    /// from the clock reading when this method is called.
    ///
    /// The state is re-checked on every change broadcast, i.e. after any
    /// notification or clock change. A satisfied predicate wins over an expired
    /// deadline. Plain `write` calls broadcast nothing and therefore do not
    /// trigger a re-check.
    fn async_wait_while_for<'a, R, P, F>(
        &'a self,
        timeout: Duration,
        mut predicate: P,
        action: F,
    ) -> AsyncMonitorFuture<'a, WaitTimeoutResult<R>>
    where
        R: Send + 'a,
        P: FnMut(&Self::State) -> bool + Send + 'a,
        F: FnOnce(&mut Self::State) -> R + Send + 'a,
    {
        let target_elapsed = self.elapsed().saturating_add(timeout);
        // Subscribed before the first state check, so any change made after
        // that check is seen by `changed()` below.
        let mut change_receiver = self.async_change_sender.subscribe();
        Box::pin(async move {
            loop {
                // Scoped so the std mutex guard is dropped before the `.await`.
                {
                    let mut state = self.lock_state();
                    if !predicate(&state.value) {
                        return WaitTimeoutResult::Ready(action(&mut state.value));
                    }
                    if state.elapsed >= target_elapsed {
                        return WaitTimeoutResult::TimedOut;
                    }
                }
                change_receiver
                    .changed()
                    .await
                    .expect("mock monitor sender should live while the monitor is borrowed");
            }
        })
    }
}
536
537impl<T> From<T> for MockMonitor<T> {
538 fn from(value: T) -> Self {
540 Self::new(value)
541 }
542}
543
544impl<T: Default> Default for MockMonitor<T> {
545 fn default() -> Self {
547 Self::new(T::default())
548 }
549}