dink-sdk 0.3.1

Rust SDK for Dink edge mesh platform — JSON-over-NATS RPC for IoT and edge computing
Documentation
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{Notify, RwLock};
use tracing::{info, warn};

/// Coarse health of the NATS connection.
///
/// The explicit `u8` discriminants let a value be kept in an `AtomicU8`;
/// `ConnState::from_u8` performs the reverse mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ConnState {
    /// The connection is up (discriminant `0`).
    Connected = 0,
    /// The connection has been lost (discriminant `1`).
    Disconnected = 1,
    /// Reconnecting (discriminant `2`). The monitor in this module reads this
    /// value but never stores it itself.
    Reconnecting = 2,
}

impl ConnState {
    /// Decode a raw discriminant loaded from the atomic.
    ///
    /// Bytes that are not a known discriminant decode to
    /// [`ConnState::Disconnected`].
    fn from_u8(raw: u8) -> Self {
        match raw {
            0 => Self::Connected,
            2 => Self::Reconnecting,
            // `1`, plus any byte outside the known range, maps to Disconnected.
            _ => Self::Disconnected,
        }
    }
}

/// Shared, cheaply clonable view of the NATS connection's health.
///
/// Combines three pieces of state: the current [`ConnState`] (held in an
/// atomic), a `Notify` that is signalled when the connection comes back, and
/// the instant the current outage began. One instance is built during
/// `EdgeClient::connect()` and handed to the `event_callback` closure.
/// Callers reach it through `EdgeClient::connection_monitor()`. Every field
/// is an `Arc`, so all clones observe the same state.
#[derive(Clone)]
pub struct ConnectionMonitor {
    /// Current [`ConnState`], encoded as its `u8` discriminant.
    state: Arc<AtomicU8>,
    /// Signalled when the state moves back to `Connected` after an outage.
    reconnected: Arc<Notify>,
    /// Start of the current outage; `None` while connected.
    disconnected_since: Arc<RwLock<Option<Instant>>>,
}

// Safety: all inner types are Send + Sync.
unsafe impl Send for ConnectionMonitor {}
unsafe impl Sync for ConnectionMonitor {}

impl ConnectionMonitor {
    /// Create a new monitor in the [`ConnState::Connected`] state with no
    /// recorded disconnect timestamp.
    pub fn new() -> Self {
        Self {
            state: Arc::new(AtomicU8::new(ConnState::Connected as u8)),
            reconnected: Arc::new(Notify::new()),
            disconnected_since: Arc::new(RwLock::new(None)),
        }
    }

    /// Current connection state.
    ///
    /// A plain atomic load, so it can be called from synchronous code.
    pub fn state(&self) -> ConnState {
        ConnState::from_u8(self.state.load(Ordering::SeqCst))
    }

    /// Whether the connection is currently alive, i.e. the state is
    /// [`ConnState::Connected`].
    pub fn is_connected(&self) -> bool {
        self.state() == ConnState::Connected
    }

    /// Called by the event_callback when NATS reports a disconnect.
    ///
    /// Sets the state to [`ConnState::Disconnected`] and records when the
    /// outage began. Repeated disconnects without an intervening reconnect
    /// keep the original timestamp.
    pub async fn on_disconnected(&self) {
        {
            // Update the state while holding the timestamp lock so a concurrent
            // `on_reconnected` cannot interleave between the two writes and
            // leave e.g. `Connected` paired with a stale `Some(timestamp)`.
            let mut since = self.disconnected_since.write().await;
            self.state
                .store(ConnState::Disconnected as u8, Ordering::SeqCst);
            // Keep the first disconnect timestamp — don't reset on repeated disconnects.
            if since.is_none() {
                *since = Some(Instant::now());
            }
        } // lock released here, before logging
        warn!("connection monitor: disconnected");
    }

    /// Called by the event_callback when NATS reports a (re)connect.
    ///
    /// Sets the state to [`ConnState::Connected`] and clears the disconnect
    /// timestamp. If the previous state was `Disconnected` or `Reconnecting`,
    /// it also logs the reconnect and wakes every task currently waiting in
    /// [`wait_for_reconnect`](Self::wait_for_reconnect). A connect event while
    /// already connected does not notify.
    pub async fn on_reconnected(&self) {
        let prev = {
            // Same lock discipline as `on_disconnected`: state and timestamp
            // change together under the write lock.
            let mut since = self.disconnected_since.write().await;
            *since = None;
            ConnState::from_u8(
                self.state
                    .swap(ConnState::Connected as u8, Ordering::SeqCst),
            )
        }; // lock released here, before logging/notifying
        if matches!(prev, ConnState::Disconnected | ConnState::Reconnecting) {
            info!("connection monitor: reconnected");
            self.reconnected.notify_waiters();
        }
    }

    /// How long we've been disconnected, or `None` if connected.
    ///
    /// Measured from the first disconnect of the current outage.
    pub async fn disconnected_duration(&self) -> Option<std::time::Duration> {
        let guard = self.disconnected_since.read().await;
        guard.map(|since| since.elapsed())
    }

    /// Wait until the next reconnect event fires.
    ///
    /// Reconnects are signalled with tokio's `Notify::notify_waiters`, which
    /// stores no permit. Only reconnects that occur after this call has
    /// created its `notified()` future are observed. A reconnect that
    /// happened earlier is not replayed.
    pub async fn wait_for_reconnect(&self) {
        self.reconnected.notified().await;
    }
}

impl Default for ConnectionMonitor {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn initial_state_is_connected() {
        let mon = ConnectionMonitor::new();
        assert_eq!(mon.state(), ConnState::Connected);
        assert!(mon.is_connected());
        assert!(mon.disconnected_duration().await.is_none());
    }

    #[tokio::test]
    async fn tracks_disconnect() {
        let mon = ConnectionMonitor::new();
        mon.on_disconnected().await;
        assert_eq!(mon.state(), ConnState::Disconnected);
        assert!(!mon.is_connected());
        assert!(mon.disconnected_duration().await.is_some());
    }

    #[tokio::test]
    async fn tracks_reconnect() {
        let mon = ConnectionMonitor::new();
        mon.on_disconnected().await;
        mon.on_reconnected().await;
        assert_eq!(mon.state(), ConnState::Connected);
        assert!(mon.is_connected());
        assert!(mon.disconnected_duration().await.is_none());
    }

    #[tokio::test]
    async fn reconnect_while_connected_keeps_state() {
        let mon = ConnectionMonitor::new();
        mon.on_reconnected().await;
        assert_eq!(mon.state(), ConnState::Connected);
        assert!(mon.disconnected_duration().await.is_none());
    }

    #[tokio::test]
    async fn notify_fires_on_reconnect() {
        let mon = ConnectionMonitor::new();
        let mon2 = mon.clone();

        // Must register the notified future BEFORE the event fires.
        let handle = tokio::spawn(async move {
            mon2.wait_for_reconnect().await;
            true
        });

        // Give the spawned task a moment to register the waiter.
        tokio::time::sleep(Duration::from_millis(20)).await;

        mon.on_disconnected().await;
        mon.on_reconnected().await;

        let result = tokio::time::timeout(Duration::from_secs(2), handle)
            .await
            .expect("timed out waiting for reconnect notify")
            .expect("task panicked");
        assert!(result);
    }

    #[tokio::test]
    async fn disconnect_duration_increases() {
        let mon = ConnectionMonitor::new();
        mon.on_disconnected().await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let dur = mon
            .disconnected_duration()
            .await
            .expect("should have duration");
        assert!(
            dur >= Duration::from_millis(90),
            "duration was {:?}, expected >= 90ms",
            dur
        );
    }

    #[tokio::test]
    async fn multiple_disconnects_keep_first_timestamp() {
        let mon = ConnectionMonitor::new();
        mon.on_disconnected().await;

        tokio::time::sleep(Duration::from_millis(50)).await;
        // Second disconnect should NOT reset the timestamp.
        mon.on_disconnected().await;
        let dur = mon
            .disconnected_duration()
            .await
            .expect("should have duration");

        // If the second call had reset the timestamp, the elapsed time would be
        // near zero instead of covering the 50ms sleep above.
        assert!(
            dur >= Duration::from_millis(40),
            "duration was {:?}, expected >= 40ms (timestamp was reset)",
            dur
        );
    }

    #[test]
    fn unknown_state_byte_maps_to_disconnected() {
        assert_eq!(
            ConnState::from_u8(ConnState::Reconnecting as u8),
            ConnState::Reconnecting
        );
        assert_eq!(ConnState::from_u8(u8::MAX), ConnState::Disconnected);
    }

    #[test]
    fn monitor_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ConnectionMonitor>();
    }
}