use std::{
sync::{
Arc, Mutex,
atomic::{AtomicU8, Ordering},
},
time::Duration,
};
/// Sending half of the one-shot channel that carries an authentication outcome:
/// `Ok(())` on success, or `Err` with a human-readable message.
pub type AuthResultSender = tokio::sync::oneshot::Sender<Result<(), String>>;
/// Receiving half of the one-shot channel that carries an authentication outcome.
pub type AuthResultReceiver = tokio::sync::oneshot::Receiver<Result<(), String>>;
/// Authentication status, representable as a single byte (`#[repr(u8)]`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum AuthState {
    /// No successful authentication is recorded; this is the default state.
    #[default]
    Unauthenticated = 0,
    /// The most recent authentication attempt succeeded.
    Authenticated = 1,
    /// The most recent authentication attempt was rejected.
    Failed = 2,
}

impl AuthState {
    /// Decodes a raw discriminant back into a state.
    ///
    /// `1` and `2` map to `Authenticated` and `Failed`; every other byte
    /// (including `0` and any unknown value) decodes as `Unauthenticated`.
    #[inline]
    #[must_use]
    fn from_u8(value: u8) -> Self {
        match value {
            1 => Self::Authenticated,
            2 => Self::Failed,
            // Covers 0 as well as out-of-range bytes.
            _ => Self::Unauthenticated,
        }
    }

    /// Returns the `u8` discriminant of this state.
    #[inline]
    #[must_use]
    const fn as_u8(self) -> u8 {
        self as u8
    }
}
/// Tracks one pending authentication attempt together with the latest [`AuthState`].
///
/// Every field is held behind an `Arc`, so clones of a tracker share the same
/// pending sender, state byte and notifier.
#[derive(Clone, Debug)]
pub struct AuthTracker {
/// Sender for the currently pending attempt, if any. It is taken out when the
/// attempt is resolved (`succeed`/`fail`) or superseded by a new `begin`.
tx: Arc<Mutex<Option<AuthResultSender>>>,
/// Current [`AuthState`] stored as its `u8` discriminant.
state: Arc<AtomicU8>,
/// Wakes tasks parked in `wait_for_authenticated`; signalled with
/// `notify_waiters` by `invalidate`, `succeed` and `fail` after they store a state.
state_notify: Arc<tokio::sync::Notify>,
}
impl AuthTracker {
#[must_use]
pub fn new() -> Self {
Self {
tx: Arc::new(Mutex::new(None)),
state: Arc::new(AtomicU8::new(AuthState::Unauthenticated.as_u8())),
state_notify: Arc::new(tokio::sync::Notify::new()),
}
}
#[must_use]
pub fn auth_state(&self) -> AuthState {
AuthState::from_u8(self.state.load(Ordering::Acquire))
}
#[must_use]
pub fn is_authenticated(&self) -> bool {
self.auth_state() == AuthState::Authenticated
}
pub fn invalidate(&self) {
self.state
.store(AuthState::Unauthenticated.as_u8(), Ordering::Release);
self.state_notify.notify_waiters();
}
#[allow(
clippy::must_use_candidate,
reason = "callers use this for side effects"
)]
pub fn begin(&self) -> AuthResultReceiver {
let (sender, receiver) = tokio::sync::oneshot::channel();
self.state
.store(AuthState::Unauthenticated.as_u8(), Ordering::Release);
if let Ok(mut guard) = self.tx.lock() {
if let Some(old) = guard.take() {
log::warn!("New authentication request superseding previous pending request");
let _ = old.send(Err("Authentication attempt superseded".to_string()));
} else {
log::debug!("Starting new authentication request");
}
*guard = Some(sender);
}
receiver
}
pub fn succeed(&self) {
self.state
.store(AuthState::Authenticated.as_u8(), Ordering::Release);
self.state_notify.notify_waiters();
if let Ok(mut guard) = self.tx.lock()
&& let Some(sender) = guard.take()
{
let _ = sender.send(Ok(()));
}
}
pub fn fail(&self, error: impl Into<String>) {
self.state
.store(AuthState::Failed.as_u8(), Ordering::Release);
self.state_notify.notify_waiters();
let message = error.into();
if let Ok(mut guard) = self.tx.lock()
&& let Some(sender) = guard.take()
{
let _ = sender.send(Err(message));
}
}
pub async fn wait_for_result<E>(
&self,
timeout: Duration,
receiver: AuthResultReceiver,
) -> Result<(), E>
where
E: From<String>,
{
match tokio::time::timeout(timeout, receiver).await {
Ok(Ok(Ok(()))) => Ok(()),
Ok(Ok(Err(msg))) => Err(E::from(msg)),
Ok(Err(_)) => Err(E::from("Authentication channel closed".to_string())),
Err(_) => {
Err(E::from("Authentication timed out".to_string()))
}
}
}
pub async fn wait_for_authenticated(&self, timeout: Duration) -> bool {
if self.is_authenticated() {
return true;
}
tokio::time::timeout(timeout, async {
loop {
let notified = self.state_notify.notified();
match self.auth_state() {
AuthState::Authenticated => return true,
AuthState::Failed => return false,
AuthState::Unauthenticated => notified.await,
}
}
})
.await
.unwrap_or(false)
}
}
impl Default for AuthTracker {
/// Returns the same value as [`AuthTracker::new`].
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use std::{
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use rstest::rstest;
use super::*;
#[derive(Debug, PartialEq)]
struct TestError(String);
impl From<String> for TestError {
fn from(msg: String) -> Self {
Self(msg)
}
}
#[rstest]
#[tokio::test]
async fn test_successful_authentication() {
let tracker = AuthTracker::new();
let rx = tracker.begin();
tracker.succeed();
let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx).await;
assert!(result.is_ok());
}
#[rstest]
#[tokio::test]
async fn test_failed_authentication() {
let tracker = AuthTracker::new();
let rx = tracker.begin();
tracker.fail("Invalid credentials");
let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx).await;
assert_eq!(
result.unwrap_err(),
TestError("Invalid credentials".to_string())
);
}
#[rstest]
#[tokio::test]
async fn test_authentication_timeout() {
let tracker = AuthTracker::new();
let rx = tracker.begin();
let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_millis(50), rx).await;
assert_eq!(
result.unwrap_err(),
TestError("Authentication timed out".to_string())
);
}
#[rstest]
#[tokio::test]
async fn test_begin_supersedes_previous_sender() {
let tracker = AuthTracker::new();
let first = tracker.begin();
let second = tracker.begin();
let result = first.await.expect("oneshot closed unexpectedly");
assert_eq!(result, Err("Authentication attempt superseded".to_string()));
tracker.succeed();
let result: Result<(), TestError> = tracker
.wait_for_result(Duration::from_secs(1), second)
.await;
assert!(result.is_ok());
}
#[rstest]
#[tokio::test]
async fn test_succeed_without_pending_auth() {
let tracker = AuthTracker::new();
tracker.succeed();
}
#[rstest]
#[tokio::test]
async fn test_fail_without_pending_auth() {
let tracker = AuthTracker::new();
tracker.fail("Some error");
}
#[rstest]
#[tokio::test]
async fn test_multiple_sequential_authentications() {
let tracker = AuthTracker::new();
let rx1 = tracker.begin();
tracker.succeed();
let result1: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx1).await;
assert!(result1.is_ok());
let rx2 = tracker.begin();
tracker.fail("Credentials expired");
let result2: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx2).await;
assert_eq!(
result2.unwrap_err(),
TestError("Credentials expired".to_string())
);
let rx3 = tracker.begin();
tracker.succeed();
let result3: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx3).await;
assert!(result3.is_ok());
}
#[rstest]
#[tokio::test]
async fn test_channel_closed_before_result() {
let tracker = AuthTracker::new();
let rx = tracker.begin();
tracker.begin();
let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx).await;
assert_eq!(
result.unwrap_err(),
TestError("Authentication attempt superseded".to_string())
);
}
#[rstest]
#[tokio::test]
async fn test_concurrent_auth_attempts() {
let tracker = Arc::new(AuthTracker::new());
let mut handles = vec![];
for i in 0..10 {
let tracker_clone = Arc::clone(&tracker);
let handle = tokio::spawn(async move {
let rx = tracker_clone.begin();
if i == 9 {
tokio::time::sleep(Duration::from_millis(10)).await;
tracker_clone.succeed();
}
let result: Result<(), TestError> = tracker_clone
.wait_for_result(Duration::from_secs(1), rx)
.await;
(i, result)
});
handles.push(handle);
}
let mut successes = 0;
let mut superseded = 0;
for handle in handles {
let (i, result) = handle.await.unwrap();
match result {
Ok(()) => {
assert_eq!(i, 9);
successes += 1;
}
Err(TestError(msg)) if msg.contains("superseded") => {
superseded += 1;
}
Err(e) => panic!("Unexpected error: {e:?}"),
}
}
assert_eq!(successes, 1);
assert_eq!(superseded, 9);
}
#[rstest]
fn test_default_trait() {
let _tracker = AuthTracker::default();
}
#[rstest]
#[tokio::test]
async fn test_clone_trait() {
let tracker = AuthTracker::new();
let cloned = tracker.clone();
let rx = tracker.begin();
cloned.succeed(); let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx).await;
assert!(result.is_ok());
}
#[rstest]
fn test_debug_trait() {
let tracker = AuthTracker::new();
let debug_str = format!("{tracker:?}");
assert!(debug_str.contains("AuthTracker"));
}
#[rstest]
#[tokio::test]
async fn test_timeout_clears_sender() {
let tracker = AuthTracker::new();
let rx1 = tracker.begin();
let result1: Result<(), TestError> = tracker
.wait_for_result(Duration::from_millis(50), rx1)
.await;
assert_eq!(
result1.unwrap_err(),
TestError("Authentication timed out".to_string())
);
let rx2 = tracker.begin();
tracker.succeed();
let result2: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx2).await;
assert!(result2.is_ok());
}
#[rstest]
#[tokio::test]
async fn test_fail_clears_sender() {
let tracker = AuthTracker::new();
let rx1 = tracker.begin();
tracker.fail("Bad credentials");
let result1: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx1).await;
assert!(result1.is_err());
let rx2 = tracker.begin();
tracker.succeed();
let result2: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx2).await;
assert!(result2.is_ok());
}
#[rstest]
#[tokio::test]
async fn test_succeed_clears_sender() {
let tracker = AuthTracker::new();
let rx1 = tracker.begin();
tracker.succeed();
let result1: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx1).await;
assert!(result1.is_ok());
let rx2 = tracker.begin();
tracker.succeed();
let result2: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx2).await;
assert!(result2.is_ok());
}
#[rstest]
#[tokio::test]
async fn test_rapid_begin_succeed_cycles() {
let tracker = AuthTracker::new();
for _ in 0..100 {
let rx = tracker.begin();
tracker.succeed();
let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx).await;
assert!(result.is_ok());
}
}
#[rstest]
#[tokio::test]
async fn test_double_succeed_is_safe() {
let tracker = AuthTracker::new();
let rx = tracker.begin();
tracker.succeed();
tracker.succeed();
let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx).await;
assert!(result.is_ok());
}
#[rstest]
#[tokio::test]
async fn test_double_fail_is_safe() {
let tracker = AuthTracker::new();
let rx = tracker.begin();
tracker.fail("Error 1");
tracker.fail("Error 2");
let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx).await;
assert_eq!(
result.unwrap_err(),
TestError("Error 1".to_string()) );
}
#[rstest]
#[tokio::test]
async fn test_succeed_after_fail_is_ignored() {
let tracker = AuthTracker::new();
let rx = tracker.begin();
tracker.fail("Auth failed");
tracker.succeed();
let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx).await;
assert!(result.is_err()); }
#[rstest]
#[tokio::test]
async fn test_fail_after_succeed_is_ignored() {
let tracker = AuthTracker::new();
let rx = tracker.begin();
tracker.succeed();
tracker.fail("Auth failed");
let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx).await;
assert!(result.is_ok()); }
#[rstest]
#[tokio::test]
async fn test_reconnect_flow_waits_for_auth() {
let tracker = Arc::new(AuthTracker::new());
let subscribed = Arc::new(tokio::sync::Notify::new());
let auth_completed = Arc::new(tokio::sync::Notify::new());
let tracker_reconnect = Arc::clone(&tracker);
let subscribed_reconnect = Arc::clone(&subscribed);
let auth_completed_reconnect = Arc::clone(&auth_completed);
let reconnect_task = tokio::spawn(async move {
let rx = tracker_reconnect.begin();
let tracker_resub = Arc::clone(&tracker_reconnect);
let subscribed_resub = Arc::clone(&subscribed_reconnect);
let auth_completed_resub = Arc::clone(&auth_completed_reconnect);
let resub_task = tokio::spawn(async move {
let result: Result<(), TestError> = tracker_resub
.wait_for_result(Duration::from_secs(5), rx)
.await;
if result.is_ok() {
auth_completed_resub.notify_one();
tokio::time::sleep(Duration::from_millis(10)).await;
subscribed_resub.notify_one();
}
});
resub_task.await.unwrap();
});
tokio::time::sleep(Duration::from_millis(100)).await;
tracker.succeed();
reconnect_task.await.unwrap();
tokio::select! {
() = auth_completed.notified() => {
}
() = tokio::time::sleep(Duration::from_secs(1)) => {
panic!("Auth never completed");
}
}
tokio::select! {
() = subscribed.notified() => {
}
() = tokio::time::sleep(Duration::from_secs(1)) => {
panic!("Subscription never completed");
}
}
}
#[rstest]
#[tokio::test]
async fn test_reconnect_flow_blocks_on_auth_failure() {
let tracker = Arc::new(AuthTracker::new());
let subscribed = Arc::new(AtomicBool::new(false));
let tracker_reconnect = Arc::clone(&tracker);
let subscribed_reconnect = Arc::clone(&subscribed);
let reconnect_task = tokio::spawn(async move {
let rx = tracker_reconnect.begin();
let tracker_resub = Arc::clone(&tracker_reconnect);
let subscribed_resub = Arc::clone(&subscribed_reconnect);
let resub_task = tokio::spawn(async move {
let result: Result<(), TestError> = tracker_resub
.wait_for_result(Duration::from_secs(5), rx)
.await;
if result.is_ok() {
subscribed_resub.store(true, Ordering::Relaxed);
}
});
resub_task.await.unwrap();
});
tokio::time::sleep(Duration::from_millis(50)).await;
tracker.fail("Invalid credentials");
reconnect_task.await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
assert!(!subscribed.load(Ordering::Relaxed));
}
#[rstest]
#[tokio::test]
async fn test_state_machine_transitions() {
let tracker = AuthTracker::new();
let rx1 = tracker.begin();
tracker.succeed();
let result1: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx1).await;
assert!(result1.is_ok());
let rx2 = tracker.begin();
tracker.fail("Error");
let result2: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx2).await;
assert!(result2.is_err());
let rx3 = tracker.begin();
let result3: Result<(), TestError> = tracker
.wait_for_result(Duration::from_millis(50), rx3)
.await;
assert_eq!(
result3.unwrap_err(),
TestError("Authentication timed out".to_string())
);
let rx4 = tracker.begin();
let rx5 = tracker.begin();
let result4: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx4).await;
assert_eq!(
result4.unwrap_err(),
TestError("Authentication attempt superseded".to_string())
);
tracker.succeed();
let result5: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx5).await;
assert!(result5.is_ok());
}
#[rstest]
#[tokio::test]
async fn test_no_sender_leaks() {
let tracker = AuthTracker::new();
for _ in 0..100 {
let rx = tracker.begin();
let _result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_millis(1), rx).await;
}
let rx = tracker.begin();
tracker.succeed();
let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx).await;
assert!(result.is_ok());
}
#[rstest]
#[tokio::test]
async fn test_concurrent_succeed_fail_calls() {
let tracker = Arc::new(AuthTracker::new());
let rx = tracker.begin();
let mut handles = vec![];
for _ in 0..50 {
let tracker_clone = Arc::clone(&tracker);
handles.push(tokio::spawn(async move {
tracker_clone.succeed();
}));
}
for _ in 0..50 {
let tracker_clone = Arc::clone(&tracker);
handles.push(tokio::spawn(async move {
tracker_clone.fail("Error");
}));
}
for handle in handles {
handle.await.unwrap();
}
let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_secs(1), rx).await;
let _ = result;
}
#[rstest]
fn test_is_authenticated_initial_state() {
let tracker = AuthTracker::new();
assert!(!tracker.is_authenticated());
}
#[rstest]
#[tokio::test]
async fn test_is_authenticated_after_succeed() {
let tracker = AuthTracker::new();
assert!(!tracker.is_authenticated());
let _rx = tracker.begin();
assert!(!tracker.is_authenticated());
tracker.succeed();
assert!(tracker.is_authenticated());
}
#[rstest]
#[tokio::test]
async fn test_is_authenticated_after_fail() {
let tracker = AuthTracker::new();
let _rx = tracker.begin();
tracker.fail("error");
assert!(!tracker.is_authenticated());
}
#[rstest]
#[tokio::test]
async fn test_invalidate_clears_auth_state() {
let tracker = AuthTracker::new();
let _rx = tracker.begin();
tracker.succeed();
assert!(tracker.is_authenticated());
tracker.invalidate();
assert!(!tracker.is_authenticated());
}
#[rstest]
#[tokio::test]
async fn test_begin_clears_auth_state() {
let tracker = AuthTracker::new();
let _rx1 = tracker.begin();
tracker.succeed();
assert!(tracker.is_authenticated());
let _rx2 = tracker.begin();
assert!(!tracker.is_authenticated());
}
#[rstest]
fn test_is_authenticated_shared_across_clones() {
let tracker = AuthTracker::new();
let cloned = tracker.clone();
let _rx = tracker.begin();
tracker.succeed();
assert!(cloned.is_authenticated());
}
#[rstest]
fn test_invalidate_shared_across_clones() {
let tracker = AuthTracker::new();
let cloned = tracker.clone();
let _rx = tracker.begin();
tracker.succeed();
assert!(tracker.is_authenticated());
cloned.invalidate();
assert!(!tracker.is_authenticated());
}
#[rstest]
fn test_succeed_without_begin_still_updates_auth_state() {
let tracker = AuthTracker::new();
assert!(!tracker.is_authenticated());
tracker.succeed();
assert!(tracker.is_authenticated());
}
#[rstest]
fn test_fail_without_begin_still_updates_auth_state() {
let tracker = AuthTracker::new();
tracker.succeed();
assert!(tracker.is_authenticated());
tracker.fail("error");
assert!(!tracker.is_authenticated());
}
#[rstest]
#[tokio::test]
async fn test_auth_state_false_after_timeout_until_late_response() {
let tracker = AuthTracker::new();
let rx = tracker.begin();
assert!(!tracker.is_authenticated());
let result: Result<(), TestError> =
tracker.wait_for_result(Duration::from_millis(10), rx).await;
assert!(result.is_err());
assert!(!tracker.is_authenticated());
tracker.succeed();
assert!(tracker.is_authenticated());
}
#[rstest]
#[tokio::test]
async fn test_wait_for_authenticated_already_authenticated() {
let tracker = AuthTracker::new();
let _rx = tracker.begin();
tracker.succeed();
assert!(
tracker
.wait_for_authenticated(Duration::from_millis(50))
.await
);
}
#[rstest]
#[tokio::test]
async fn test_wait_for_authenticated_succeeds_after_delay() {
let tracker = AuthTracker::new();
let _rx = tracker.begin();
let tracker_clone = tracker.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(50)).await;
tracker_clone.succeed();
});
assert!(tracker.wait_for_authenticated(Duration::from_secs(1)).await);
}
#[rstest]
#[tokio::test]
async fn test_wait_for_authenticated_returns_false_on_failure() {
let tracker = AuthTracker::new();
let _rx = tracker.begin();
let tracker_clone = tracker.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(50)).await;
tracker_clone.fail("rejected");
});
let start = tokio::time::Instant::now();
let result = tracker.wait_for_authenticated(Duration::from_secs(5)).await;
let elapsed = start.elapsed();
assert!(!result);
assert!(elapsed < Duration::from_secs(1));
}
#[rstest]
#[tokio::test]
async fn test_wait_for_authenticated_times_out() {
let tracker = AuthTracker::new();
let _rx = tracker.begin();
assert!(
!tracker
.wait_for_authenticated(Duration::from_millis(50))
.await
);
}
#[rstest]
#[tokio::test]
async fn test_wait_for_authenticated_begin_clears_failed() {
let tracker = AuthTracker::new();
let _rx = tracker.begin();
tracker.fail("first attempt");
assert!(
!tracker
.wait_for_authenticated(Duration::from_millis(10))
.await
);
let _rx = tracker.begin();
let tracker_clone = tracker.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(50)).await;
tracker_clone.succeed();
});
assert!(tracker.wait_for_authenticated(Duration::from_secs(1)).await);
}
#[rstest]
#[tokio::test]
async fn test_wait_for_authenticated_invalidate_does_not_return_false() {
let tracker = AuthTracker::new();
let _rx = tracker.begin();
let tracker_clone = tracker.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(20)).await;
tracker_clone.invalidate();
tokio::time::sleep(Duration::from_millis(20)).await;
tracker_clone.succeed();
});
assert!(tracker.wait_for_authenticated(Duration::from_secs(1)).await);
}
#[rstest]
#[tokio::test]
async fn test_wait_for_authenticated_concurrent_waiters() {
let tracker = Arc::new(AuthTracker::new());
let _rx = tracker.begin();
let mut handles = vec![];
for _ in 0..10 {
let t = Arc::clone(&tracker);
handles.push(tokio::spawn(async move {
t.wait_for_authenticated(Duration::from_secs(1)).await
}));
}
tokio::time::sleep(Duration::from_millis(50)).await;
tracker.succeed();
for handle in handles {
assert!(handle.await.unwrap());
}
}
#[rstest]
#[tokio::test]
async fn test_wait_for_authenticated_not_authenticated_initially() {
let tracker = AuthTracker::new();
assert!(
!tracker
.wait_for_authenticated(Duration::from_millis(50))
.await
);
}
}
// Property-based tests: random sequences of tracker operations, encoded as
// 0 = `begin`, 1 = `succeed`, 2 = `fail`, 3 = `invalidate`.
#[cfg(test)]
mod proptest_tests {
use std::{sync::Arc, time::Duration};
use proptest::prelude::*;
use rstest::rstest;
use super::*;
proptest! {
// `is_authenticated` must match a simple model in which only `succeed` sets the flag
// and every other operation clears it.
#[rstest]
fn test_state_consistency_after_random_operations(
ops in proptest::collection::vec(0u8..4, 1..50)
) {
let tracker = AuthTracker::new();
let mut expected_auth = false;
for op in &ops {
match op {
0 => {
let _rx = tracker.begin();
expected_auth = false;
}
1 => {
tracker.succeed();
expected_auth = true;
}
2 => {
tracker.fail("test");
expected_auth = false;
}
3 => {
tracker.invalidate();
expected_auth = false;
}
_ => unreachable!(),
}
}
prop_assert_eq!(tracker.is_authenticated(), expected_auth);
}
// Whatever happened before, a final `begin` leaves the state `Unauthenticated`.
#[rstest]
fn test_begin_always_clears_failed(
prior_ops in proptest::collection::vec(0u8..4, 0..20)
) {
let tracker = AuthTracker::new();
for op in &prior_ops {
match op {
0 => { let _rx = tracker.begin(); }
1 => tracker.succeed(),
2 => tracker.fail("test"),
3 => tracker.invalidate(),
_ => unreachable!(),
}
}
let _rx = tracker.begin();
prop_assert_eq!(tracker.auth_state(), AuthState::Unauthenticated);
}
// Whatever happened before, a final `succeed` leaves the state `Authenticated`.
#[rstest]
fn test_succeed_always_sets_authenticated(
prior_ops in proptest::collection::vec(0u8..4, 0..20)
) {
let tracker = AuthTracker::new();
for op in &prior_ops {
match op {
0 => { let _rx = tracker.begin(); }
1 => tracker.succeed(),
2 => tracker.fail("test"),
3 => tracker.invalidate(),
_ => unreachable!(),
}
}
tracker.succeed();
prop_assert_eq!(tracker.auth_state(), AuthState::Authenticated);
}
}
// For both outcomes, a result delivered after 30 ms must wake `wait_for_authenticated`
// within 500 ms, far below its 10 s timeout.
#[rstest]
#[tokio::test]
async fn test_wait_responds_within_bounded_time() {
for auth_result in [true, false] {
let tracker = Arc::new(AuthTracker::new());
let _rx = tracker.begin();
let tracker_clone = Arc::clone(&tracker);
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(30)).await;
if auth_result {
tracker_clone.succeed();
} else {
tracker_clone.fail("rejected");
}
});
let start = tokio::time::Instant::now();
let result = tracker
.wait_for_authenticated(Duration::from_secs(10))
.await;
let elapsed = start.elapsed();
assert_eq!(result, auth_result);
assert!(
elapsed < Duration::from_millis(500),
"wait_for_authenticated took {elapsed:?} for auth_result={auth_result}"
);
}
}
}