use std::{
sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
},
time::Duration,
};
/// Sending half of the one-shot channel that delivers the outcome of an
/// authentication attempt: `Ok(())` on success, `Err(message)` on failure or
/// when the attempt is superseded.
pub type AuthResultSender = tokio::sync::oneshot::Sender<Result<(), String>>;
/// Receiving half paired with [`AuthResultSender`]; handed out by
/// `AuthTracker::begin` and consumed by `AuthTracker::wait_for_result`.
pub type AuthResultReceiver = tokio::sync::oneshot::Receiver<Result<(), String>>;
/// Tracks an in-flight authentication attempt and whether the connection is
/// currently authenticated.
///
/// Both fields live behind `Arc`, so every clone shares the same pending
/// attempt and the same authenticated flag.
#[derive(Clone, Debug)]
pub struct AuthTracker {
    // Sender for the most recently begun attempt that has not yet been resolved
    // by `succeed`/`fail`, or `None` when nothing is pending.
    tx: Arc<Mutex<Option<AuthResultSender>>>,
    // Set to `true` by `succeed`; reset to `false` by `begin`, `fail`, and `invalidate`.
    authenticated: Arc<AtomicBool>,
}
impl AuthTracker {
pub fn new() -> Self {
Self {
tx: Arc::new(Mutex::new(None)),
authenticated: Arc::new(AtomicBool::new(false)),
}
}
#[must_use]
pub fn is_authenticated(&self) -> bool {
self.authenticated.load(Ordering::Acquire)
}
pub fn invalidate(&self) {
self.authenticated.store(false, Ordering::Release);
}
pub fn begin(&self) -> AuthResultReceiver {
let (sender, receiver) = tokio::sync::oneshot::channel();
self.authenticated.store(false, Ordering::Release);
if let Ok(mut guard) = self.tx.lock() {
if let Some(old) = guard.take() {
log::warn!("New authentication request superseding previous pending request");
let _ = old.send(Err("Authentication attempt superseded".to_string()));
} else {
log::debug!("Starting new authentication request");
}
*guard = Some(sender);
}
receiver
}
pub fn succeed(&self) {
self.authenticated.store(true, Ordering::Release);
if let Ok(mut guard) = self.tx.lock()
&& let Some(sender) = guard.take()
{
let _ = sender.send(Ok(()));
}
}
pub fn fail(&self, error: impl Into<String>) {
self.authenticated.store(false, Ordering::Release);
let message = error.into();
if let Ok(mut guard) = self.tx.lock()
&& let Some(sender) = guard.take()
{
let _ = sender.send(Err(message));
}
}
pub async fn wait_for_result<E>(
&self,
timeout: Duration,
receiver: AuthResultReceiver,
) -> Result<(), E>
where
E: From<String>,
{
match tokio::time::timeout(timeout, receiver).await {
Ok(Ok(Ok(()))) => Ok(()),
Ok(Ok(Err(msg))) => Err(E::from(msg)),
Ok(Err(_)) => Err(E::from("Authentication channel closed".to_string())),
Err(_) => {
Err(E::from("Authentication timed out".to_string()))
}
}
}
}
impl Default for AuthTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `AuthTracker`: covers outcome delivery, supersession,
    //! timeouts, channel closure, and the shared `is_authenticated` flag.

    use std::{
        sync::atomic::{AtomicBool, Ordering},
        time::Duration,
    };

    use rstest::rstest;

    use super::*;

    /// Minimal error type satisfying the `E: From<String>` bound of `wait_for_result`.
    #[derive(Debug, PartialEq)]
    struct TestError(String);

    impl From<String> for TestError {
        fn from(msg: String) -> Self {
            Self(msg)
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_successful_authentication() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();
        tracker.succeed();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_failed_authentication() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();
        tracker.fail("Invalid credentials");
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert_eq!(
            result.unwrap_err(),
            TestError("Invalid credentials".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_authentication_timeout() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_millis(50), rx).await;
        assert_eq!(
            result.unwrap_err(),
            TestError("Authentication timed out".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_begin_supersedes_previous_sender() {
        let tracker = AuthTracker::new();
        let first = tracker.begin();
        let second = tracker.begin();
        let result = first.await.expect("oneshot closed unexpectedly");
        assert_eq!(result, Err("Authentication attempt superseded".to_string()));
        tracker.succeed();
        let result: Result<(), TestError> = tracker
            .wait_for_result(Duration::from_secs(1), second)
            .await;
        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_succeed_without_pending_auth() {
        let tracker = AuthTracker::new();
        tracker.succeed();
        assert!(tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_fail_without_pending_auth() {
        let tracker = AuthTracker::new();
        tracker.fail("Some error");
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_multiple_sequential_authentications() {
        let tracker = AuthTracker::new();
        let rx1 = tracker.begin();
        tracker.succeed();
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_ok());
        let rx2 = tracker.begin();
        tracker.fail("Credentials expired");
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert_eq!(
            result2.unwrap_err(),
            TestError("Credentials expired".to_string())
        );
        let rx3 = tracker.begin();
        tracker.succeed();
        let result3: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx3).await;
        assert!(result3.is_ok());
    }

    /// A receiver whose attempt is superseded resolves with the supersession
    /// error, not with a closed-channel error.
    #[rstest]
    #[tokio::test]
    async fn test_channel_closed_before_result() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();
        tracker.begin();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert_eq!(
            result.unwrap_err(),
            TestError("Authentication attempt superseded".to_string())
        );
    }

    /// Dropping every handle to the issuing tracker drops the pending sender
    /// unsent, which the waiter reports as a closed channel.
    #[rstest]
    #[tokio::test]
    async fn test_channel_closed_when_tracker_dropped() {
        let rx = {
            let issuer = AuthTracker::new();
            issuer.begin()
        };
        let waiter = AuthTracker::new();
        let result: Result<(), TestError> =
            waiter.wait_for_result(Duration::from_secs(1), rx).await;
        assert_eq!(
            result.unwrap_err(),
            TestError("Authentication channel closed".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_concurrent_auth_attempts() {
        let tracker = Arc::new(AuthTracker::new());
        let mut handles = vec![];
        for i in 0..10 {
            let tracker_clone = Arc::clone(&tracker);
            let handle = tokio::spawn(async move {
                let rx = tracker_clone.begin();
                if i == 9 {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    tracker_clone.succeed();
                }
                let result: Result<(), TestError> = tracker_clone
                    .wait_for_result(Duration::from_secs(1), rx)
                    .await;
                (i, result)
            });
            handles.push(handle);
        }
        let mut successes = 0;
        let mut superseded = 0;
        for handle in handles {
            let (i, result) = handle.await.unwrap();
            match result {
                Ok(()) => {
                    assert_eq!(i, 9);
                    successes += 1;
                }
                Err(TestError(msg)) if msg.contains("superseded") => {
                    superseded += 1;
                }
                Err(e) => panic!("Unexpected error: {e:?}"),
            }
        }
        assert_eq!(successes, 1);
        assert_eq!(superseded, 9);
    }

    #[rstest]
    fn test_default_trait() {
        let _tracker = AuthTracker::default();
    }

    #[rstest]
    #[tokio::test]
    async fn test_clone_trait() {
        let tracker = AuthTracker::new();
        let cloned = tracker.clone();
        let rx = tracker.begin();
        // Resolving through the clone must reach the receiver issued by the original.
        cloned.succeed();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    #[rstest]
    fn test_debug_trait() {
        let tracker = AuthTracker::new();
        let debug_str = format!("{tracker:?}");
        assert!(debug_str.contains("AuthTracker"));
    }

    /// An attempt that timed out must not prevent the next attempt from succeeding.
    #[rstest]
    #[tokio::test]
    async fn test_timeout_clears_sender() {
        let tracker = AuthTracker::new();
        let rx1 = tracker.begin();
        let result1: Result<(), TestError> = tracker
            .wait_for_result(Duration::from_millis(50), rx1)
            .await;
        assert_eq!(
            result1.unwrap_err(),
            TestError("Authentication timed out".to_string())
        );
        let rx2 = tracker.begin();
        tracker.succeed();
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_fail_clears_sender() {
        let tracker = AuthTracker::new();
        let rx1 = tracker.begin();
        tracker.fail("Bad credentials");
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_err());
        let rx2 = tracker.begin();
        tracker.succeed();
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_succeed_clears_sender() {
        let tracker = AuthTracker::new();
        let rx1 = tracker.begin();
        tracker.succeed();
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_ok());
        let rx2 = tracker.begin();
        tracker.succeed();
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_rapid_begin_succeed_cycles() {
        let tracker = AuthTracker::new();
        for _ in 0..100 {
            let rx = tracker.begin();
            tracker.succeed();
            let result: Result<(), TestError> =
                tracker.wait_for_result(Duration::from_secs(1), rx).await;
            assert!(result.is_ok());
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_double_succeed_is_safe() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();
        tracker.succeed();
        tracker.succeed();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    /// The first `fail` consumes the sender, so its message is the one observed.
    #[rstest]
    #[tokio::test]
    async fn test_double_fail_is_safe() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();
        tracker.fail("Error 1");
        tracker.fail("Error 2");
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert_eq!(result.unwrap_err(), TestError("Error 1".to_string()));
    }

    #[rstest]
    #[tokio::test]
    async fn test_succeed_after_fail_is_ignored() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();
        tracker.fail("Auth failed");
        tracker.succeed();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_err());
    }

    #[rstest]
    #[tokio::test]
    async fn test_fail_after_succeed_is_ignored() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();
        tracker.succeed();
        tracker.fail("Auth failed");
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    /// Simulates a reconnect where resubscription waits on the auth outcome and
    /// only proceeds once `succeed` is called.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_flow_waits_for_auth() {
        let tracker = Arc::new(AuthTracker::new());
        let subscribed = Arc::new(tokio::sync::Notify::new());
        let auth_completed = Arc::new(tokio::sync::Notify::new());
        let tracker_reconnect = Arc::clone(&tracker);
        let subscribed_reconnect = Arc::clone(&subscribed);
        let auth_completed_reconnect = Arc::clone(&auth_completed);
        let reconnect_task = tokio::spawn(async move {
            let rx = tracker_reconnect.begin();
            let tracker_resub = Arc::clone(&tracker_reconnect);
            let subscribed_resub = Arc::clone(&subscribed_reconnect);
            let auth_completed_resub = Arc::clone(&auth_completed_reconnect);
            let resub_task = tokio::spawn(async move {
                let result: Result<(), TestError> = tracker_resub
                    .wait_for_result(Duration::from_secs(5), rx)
                    .await;
                if result.is_ok() {
                    auth_completed_resub.notify_one();
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    subscribed_resub.notify_one();
                }
            });
            resub_task.await.unwrap();
        });
        tokio::time::sleep(Duration::from_millis(100)).await;
        tracker.succeed();
        reconnect_task.await.unwrap();
        tokio::select! {
            () = auth_completed.notified() => {}
            () = tokio::time::sleep(Duration::from_secs(1)) => {
                panic!("Auth never completed");
            }
        }
        tokio::select! {
            () = subscribed.notified() => {}
            () = tokio::time::sleep(Duration::from_secs(1)) => {
                panic!("Subscription never completed");
            }
        }
    }

    /// Simulates a reconnect whose authentication fails: resubscription must not run.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_flow_blocks_on_auth_failure() {
        let tracker = Arc::new(AuthTracker::new());
        let subscribed = Arc::new(AtomicBool::new(false));
        let tracker_reconnect = Arc::clone(&tracker);
        let subscribed_reconnect = Arc::clone(&subscribed);
        let reconnect_task = tokio::spawn(async move {
            let rx = tracker_reconnect.begin();
            let tracker_resub = Arc::clone(&tracker_reconnect);
            let subscribed_resub = Arc::clone(&subscribed_reconnect);
            let resub_task = tokio::spawn(async move {
                let result: Result<(), TestError> = tracker_resub
                    .wait_for_result(Duration::from_secs(5), rx)
                    .await;
                if result.is_ok() {
                    subscribed_resub.store(true, Ordering::Relaxed);
                }
            });
            resub_task.await.unwrap();
        });
        tokio::time::sleep(Duration::from_millis(50)).await;
        tracker.fail("Invalid credentials");
        reconnect_task.await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(!subscribed.load(Ordering::Relaxed));
    }

    /// Walks through success, failure, timeout, and supersession on one tracker.
    #[rstest]
    #[tokio::test]
    async fn test_state_machine_transitions() {
        let tracker = AuthTracker::new();
        let rx1 = tracker.begin();
        tracker.succeed();
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_ok());
        let rx2 = tracker.begin();
        tracker.fail("Error");
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_err());
        let rx3 = tracker.begin();
        let result3: Result<(), TestError> = tracker
            .wait_for_result(Duration::from_millis(50), rx3)
            .await;
        assert_eq!(
            result3.unwrap_err(),
            TestError("Authentication timed out".to_string())
        );
        let rx4 = tracker.begin();
        let rx5 = tracker.begin();
        let result4: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx4).await;
        assert_eq!(
            result4.unwrap_err(),
            TestError("Authentication attempt superseded".to_string())
        );
        tracker.succeed();
        let result5: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx5).await;
        assert!(result5.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_no_sender_leaks() {
        let tracker = AuthTracker::new();
        for _ in 0..100 {
            let rx = tracker.begin();
            let _result: Result<(), TestError> =
                tracker.wait_for_result(Duration::from_millis(1), rx).await;
        }
        let rx = tracker.begin();
        tracker.succeed();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    /// Racing `succeed`/`fail` calls must resolve the receiver exactly once, with
    /// one of the two outcomes (never a closed channel or timeout).
    #[rstest]
    #[tokio::test]
    async fn test_concurrent_succeed_fail_calls() {
        let tracker = Arc::new(AuthTracker::new());
        let rx = tracker.begin();
        let mut handles = vec![];
        for _ in 0..50 {
            let tracker_clone = Arc::clone(&tracker);
            handles.push(tokio::spawn(async move {
                tracker_clone.succeed();
            }));
        }
        for _ in 0..50 {
            let tracker_clone = Arc::clone(&tracker);
            handles.push(tokio::spawn(async move {
                tracker_clone.fail("Error");
            }));
        }
        for handle in handles {
            handle.await.unwrap();
        }
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        match result {
            Ok(()) => {}
            Err(TestError(msg)) => assert_eq!(msg, "Error"),
        }
    }

    #[rstest]
    fn test_is_authenticated_initial_state() {
        let tracker = AuthTracker::new();
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_is_authenticated_after_succeed() {
        let tracker = AuthTracker::new();
        assert!(!tracker.is_authenticated());
        let _rx = tracker.begin();
        assert!(!tracker.is_authenticated());
        tracker.succeed();
        assert!(tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_is_authenticated_after_fail() {
        let tracker = AuthTracker::new();
        let _rx = tracker.begin();
        tracker.fail("error");
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_invalidate_clears_auth_state() {
        let tracker = AuthTracker::new();
        let _rx = tracker.begin();
        tracker.succeed();
        assert!(tracker.is_authenticated());
        tracker.invalidate();
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_begin_clears_auth_state() {
        let tracker = AuthTracker::new();
        let _rx1 = tracker.begin();
        tracker.succeed();
        assert!(tracker.is_authenticated());
        let _rx2 = tracker.begin();
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    fn test_is_authenticated_shared_across_clones() {
        let tracker = AuthTracker::new();
        let cloned = tracker.clone();
        let _rx = tracker.begin();
        tracker.succeed();
        assert!(cloned.is_authenticated());
    }

    #[rstest]
    fn test_invalidate_shared_across_clones() {
        let tracker = AuthTracker::new();
        let cloned = tracker.clone();
        let _rx = tracker.begin();
        tracker.succeed();
        assert!(tracker.is_authenticated());
        cloned.invalidate();
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    fn test_succeed_without_begin_still_updates_auth_state() {
        let tracker = AuthTracker::new();
        assert!(!tracker.is_authenticated());
        tracker.succeed();
        assert!(tracker.is_authenticated());
    }

    #[rstest]
    fn test_fail_without_begin_still_updates_auth_state() {
        let tracker = AuthTracker::new();
        tracker.succeed();
        assert!(tracker.is_authenticated());
        tracker.fail("error");
        assert!(!tracker.is_authenticated());
    }

    /// A timeout leaves the tracker unauthenticated, but a late `succeed` still
    /// flips the flag even though nobody is waiting any more.
    #[rstest]
    #[tokio::test]
    async fn test_auth_state_false_after_timeout_until_late_response() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();
        assert!(!tracker.is_authenticated());
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_millis(10), rx).await;
        assert!(result.is_err());
        assert!(!tracker.is_authenticated());
        tracker.succeed();
        assert!(tracker.is_authenticated());
    }
}