use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{Notify, RwLock};
use tracing::{info, warn};
/// Connection status tracked by `ConnectionMonitor`.
///
/// `#[repr(u8)]` with explicit discriminants lets the value live in an
/// `AtomicU8` and be read without taking a lock.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ConnState {
    /// The link is up.
    Connected = 0,
    /// The link is down.
    Disconnected = 1,
    /// Intermediate recovery state. NOTE(review): nothing in this file stores
    /// it; presumably set elsewhere — confirm.
    Reconnecting = 2,
}

impl ConnState {
    /// Decodes a raw discriminant loaded from the monitor's atomic.
    ///
    /// Only `Connected` and `Reconnecting` are matched explicitly; every other
    /// byte (including the `Disconnected` discriminant itself) decodes as
    /// `Disconnected`, so an unexpected value is treated as "not connected".
    fn from_u8(raw: u8) -> Self {
        const CONNECTED: u8 = ConnState::Connected as u8;
        const RECONNECTING: u8 = ConnState::Reconnecting as u8;
        match raw {
            CONNECTED => Self::Connected,
            RECONNECTING => Self::Reconnecting,
            _ => Self::Disconnected,
        }
    }
}
/// Shared tracker of a connection's up/down status.
///
/// Every field is reference-counted, so all clones of a monitor observe and
/// mutate the same underlying state.
#[derive(Clone)]
pub struct ConnectionMonitor {
    /// Current [`ConnState`], stored as its `u8` discriminant for lock-free reads.
    state: Arc<AtomicU8>,
    /// Signalled when the connection transitions back to connected.
    reconnected: Arc<Notify>,
    /// Start of the current outage; cleared when the connection is restored.
    disconnected_since: Arc<RwLock<Option<Instant>>>,
}
unsafe impl Send for ConnectionMonitor {}
unsafe impl Sync for ConnectionMonitor {}
impl ConnectionMonitor {
    /// Creates a monitor that starts in [`ConnState::Connected`] with no outage
    /// recorded.
    pub fn new() -> Self {
        Self {
            state: Arc::new(AtomicU8::new(ConnState::Connected as u8)),
            reconnected: Arc::new(Notify::new()),
            disconnected_since: Arc::new(RwLock::new(None)),
        }
    }

    /// Returns the current connection state (a lock-free atomic load).
    pub fn state(&self) -> ConnState {
        ConnState::from_u8(self.state.load(Ordering::SeqCst))
    }

    /// Returns `true` when the current state is [`ConnState::Connected`].
    pub fn is_connected(&self) -> bool {
        self.state() == ConnState::Connected
    }

    /// Records that the connection has dropped.
    ///
    /// Sets the state to [`ConnState::Disconnected`] and, if no outage is being
    /// tracked yet, records the current instant as the outage start. Repeated
    /// calls during one outage keep the first timestamp.
    pub async fn on_disconnected(&self) {
        // The state change happens while the write lock is held, so it cannot
        // interleave with `on_reconnected`. Otherwise a concurrent reconnect
        // could leave the monitor `Connected` with a stale outage timestamp.
        let mut since = self.disconnected_since.write().await;
        self.state
            .store(ConnState::Disconnected as u8, Ordering::SeqCst);
        if since.is_none() {
            *since = Some(Instant::now());
        }
        drop(since);
        warn!("connection monitor: disconnected");
    }

    /// Records that the connection is back up.
    ///
    /// Sets the state to [`ConnState::Connected`] and clears the outage start.
    /// If the previous state was `Disconnected` or `Reconnecting`, logs the
    /// event and wakes every task currently waiting in
    /// [`Self::wait_for_reconnect`]. If the monitor was already connected, it
    /// neither logs nor notifies.
    pub async fn on_reconnected(&self) {
        let prev = {
            // Same lock discipline as `on_disconnected`: the state swap and the
            // timestamp reset are one step relative to other transitions.
            let mut since = self.disconnected_since.write().await;
            let prev = ConnState::from_u8(
                self.state
                    .swap(ConnState::Connected as u8, Ordering::SeqCst),
            );
            *since = None;
            prev
        };
        if matches!(prev, ConnState::Disconnected | ConnState::Reconnecting) {
            info!("connection monitor: reconnected");
            self.reconnected.notify_waiters();
        }
    }

    /// Returns how long the current outage has lasted, or `None` if no outage
    /// is being tracked.
    pub async fn disconnected_duration(&self) -> Option<std::time::Duration> {
        let since = self.disconnected_since.read().await;
        since.map(|start| start.elapsed())
    }

    /// Waits for the next reconnect event signalled by [`Self::on_reconnected`].
    ///
    /// This waits for a *transition* back to connected, not for the connected
    /// state. If the monitor is already connected, it keeps waiting until a
    /// later disconnect/reconnect cycle.
    ///
    /// The wakeup uses `Notify::notify_waiters`, which stores no permit. As a
    /// result, a reconnect that happens before this future is first polled is
    /// missed. In particular, a caller that checks [`Self::is_connected`] and
    /// then calls this can miss a reconnect that lands in between.
    pub async fn wait_for_reconnect(&self) {
        self.reconnected.notified().await;
    }
}
impl Default for ConnectionMonitor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn initial_state_is_connected() {
        let mon = ConnectionMonitor::new();
        assert_eq!(mon.state(), ConnState::Connected);
        assert!(mon.is_connected());
        assert!(mon.disconnected_duration().await.is_none());
    }

    #[tokio::test]
    async fn tracks_disconnect() {
        let mon = ConnectionMonitor::new();
        mon.on_disconnected().await;
        assert_eq!(mon.state(), ConnState::Disconnected);
        assert!(!mon.is_connected());
        assert!(mon.disconnected_duration().await.is_some());
    }

    #[tokio::test]
    async fn tracks_reconnect() {
        let mon = ConnectionMonitor::new();
        mon.on_disconnected().await;
        mon.on_reconnected().await;
        assert_eq!(mon.state(), ConnState::Connected);
        assert!(mon.is_connected());
        assert!(mon.disconnected_duration().await.is_none());
    }

    #[tokio::test]
    async fn notify_fires_on_reconnect() {
        let mon = ConnectionMonitor::new();
        let mon2 = mon.clone();
        let handle = tokio::spawn(async move {
            mon2.wait_for_reconnect().await;
            true
        });
        // Let the waiter register: `notify_waiters` stores no permit.
        tokio::time::sleep(Duration::from_millis(20)).await;
        mon.on_disconnected().await;
        mon.on_reconnected().await;
        let result = tokio::time::timeout(Duration::from_secs(2), handle)
            .await
            .expect("timed out waiting for reconnect notify")
            .expect("task panicked");
        assert!(result);
    }

    #[tokio::test]
    async fn reconnect_while_connected_does_not_notify() {
        let mon = ConnectionMonitor::new();
        let mon2 = mon.clone();
        let handle = tokio::spawn(async move {
            mon2.wait_for_reconnect().await;
        });
        tokio::time::sleep(Duration::from_millis(20)).await;
        // No preceding disconnect, so this is not a reconnect transition.
        mon.on_reconnected().await;
        let result = tokio::time::timeout(Duration::from_millis(50), handle).await;
        assert!(
            result.is_err(),
            "waiter must not wake when reconnecting from the connected state"
        );
        assert!(mon.is_connected());
        assert!(mon.disconnected_duration().await.is_none());
    }

    #[tokio::test]
    async fn disconnect_duration_increases() {
        let mon = ConnectionMonitor::new();
        mon.on_disconnected().await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let dur = mon
            .disconnected_duration()
            .await
            .expect("should have duration");
        assert!(
            dur >= Duration::from_millis(90),
            "duration was {:?}, expected >= 90ms",
            dur
        );
    }

    #[tokio::test]
    async fn multiple_disconnects_keep_first_timestamp() {
        let mon = ConnectionMonitor::new();
        mon.on_disconnected().await;
        let first = mon
            .disconnected_duration()
            .await
            .expect("should have duration");
        tokio::time::sleep(Duration::from_millis(50)).await;
        mon.on_disconnected().await;
        let second = mon
            .disconnected_duration()
            .await
            .expect("should have duration");
        assert!(
            second >= first,
            "second duration {:?} should be >= first {:?}",
            second,
            first
        );
        // A reset timestamp would make `second` near zero; keeping the first
        // timestamp means it must include the 50ms sleep (small slack for timers).
        assert!(
            second >= Duration::from_millis(40),
            "second duration {:?} should include the time since the first disconnect",
            second
        );
    }
}