pushwire-client 0.1.1

Generic multiplexed push client with WebSocket and SSE transports
Documentation
//! Integration tests — spin up a PushServer, connect a PushClient,
//! and verify the full protocol flow.

use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;

use pushwire_client::{ChannelKind, ClientConfig, Frame, PushClient, ReconnectPolicy};
use pushwire_server::PushServer;
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;

// -------------------------------------------------------------------------
// Test channel type
// -------------------------------------------------------------------------

/// Minimal channel set used by these integration tests.
///
/// With `rename_all = "lowercase"` the variants (de)serialize as `"data"` and
/// `"system"`, the same strings the `ChannelKind` impl in this file uses for
/// `name()` / `from_name()`.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum Ch {
    /// Application payload channel; the echo handler in this file is registered on it.
    Data,
    /// Channel reported as the system channel by `is_system()`.
    System,
}

/// `ChannelKind` implementation for the test channel set.
///
/// The forward mappings (`wire_id`, `name`) are the single source of truth;
/// the reverse lookups scan `all()` and compare against them, so the two
/// directions cannot drift apart.
impl ChannelKind for Ch {
    /// Returns 0 for `System` and 1 for `Data`.
    // NOTE(review): presumably a lower number means higher priority — confirm
    // against the `ChannelKind` trait documentation.
    fn priority(&self) -> u8 {
        match *self {
            Self::Data => 1,
            Self::System => 0,
        }
    }

    /// Numeric identifier of the channel: `0x01` for `Data`, `0x05` for `System`.
    fn wire_id(&self) -> u8 {
        match *self {
            Self::System => 0x05,
            Self::Data => 0x01,
        }
    }

    /// Reverse of [`wire_id`](Self::wire_id); unknown ids yield `None`.
    fn from_wire_id(id: u8) -> Option<Self> {
        Self::all().iter().copied().find(|kind| kind.wire_id() == id)
    }

    /// Reverse of [`name`](Self::name); unknown names yield `None`.
    fn from_name(s: &str) -> Option<Self> {
        Self::all().iter().copied().find(|kind| kind.name() == s)
    }

    /// Lowercase textual name of the channel.
    fn name(&self) -> &'static str {
        match *self {
            Self::System => "system",
            Self::Data => "data",
        }
    }

    /// `true` only for the `System` variant.
    fn is_system(&self) -> bool {
        match *self {
            Self::System => true,
            Self::Data => false,
        }
    }

    /// Every variant of the enum, in declaration order.
    fn all() -> &'static [Self] {
        const VARIANTS: &[Ch] = &[Ch::Data, Ch::System];
        VARIANTS
    }
}

// -------------------------------------------------------------------------
// Helper: start a server on a random port, return its URL and the server handle
// -------------------------------------------------------------------------

async fn start_echo_server() -> (String, Arc<PushServer<Ch>>) {
    let server: Arc<PushServer<Ch>> = Arc::new(PushServer::new());

    // Echo handler: reflect frames back to sender.
    server.register_handler(Ch::Data, |client_id, frame, srv| {
        let reply = Frame::new(Ch::Data, frame.payload);
        let _ = srv.send(client_id, reply);
    });

    let rps = server.clone();
    let app = rps.router().with_state(server.clone());
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let url = format!("http://127.0.0.1:{}", addr.port());

    tokio::spawn(async move {
        axum::serve(listener, app).await.unwrap();
    });

    // Give the server a moment to start.
    tokio::time::sleep(Duration::from_millis(50)).await;

    (url, server)
}

// -------------------------------------------------------------------------
// Tests
// -------------------------------------------------------------------------

#[tokio::test]
async fn connect_send_receive_disconnect() {
    let (url, _server) = start_echo_server().await;

    let mut config = ClientConfig::new(&url);
    config.reconnect = ReconnectPolicy::disabled();

    let mut client = PushClient::<Ch>::new(config);

    let received = Arc::new(AtomicBool::new(false));
    let received_payload = Arc::new(std::sync::Mutex::new(serde_json::Value::Null));
    let r = received.clone();
    let rp = received_payload.clone();

    client.on(Ch::Data, move |frame: Frame<Ch>| {
        *rp.lock().unwrap() = frame.payload;
        r.store(true, Ordering::SeqCst);
    });

    client.connect().await.expect("connect");
    assert_eq!(client.state(), pushwire_client::ConnectionState::Connected);

    // Send a frame.
    let msg = Frame::new(Ch::Data, serde_json::json!({"echo": "test"}));
    client.send(msg).await.expect("send");

    // Wait for echo.
    wait_for(&received, Duration::from_secs(2)).await;
    assert!(received.load(Ordering::SeqCst), "should have received echo");

    let payload = received_payload.lock().unwrap().clone();
    assert_eq!(payload, serde_json::json!({"echo": "test"}));

    client.disconnect().await.expect("disconnect");
    assert_eq!(
        client.state(),
        pushwire_client::ConnectionState::Disconnected
    );
}

/// The server pushes three frames directly to the client; the client must
/// deliver all three and expose a `Ch::Data` cursor of at least 3.
#[tokio::test]
async fn cursor_tracking() {
    let (url, server) = start_echo_server().await;

    let mut cfg = ClientConfig::new(&url);
    cfg.reconnect = ReconnectPolicy::disabled();
    // Captured before `cfg` is moved into the client; used to address server pushes.
    let client_id = cfg.client_id;
    let mut client = PushClient::<Ch>::new(cfg);

    // NOTE(review): `last_cursor` is recorded by the handler but never asserted
    // on; only the delivery count and `client.cursors()` are checked below.
    let last_cursor = Arc::new(AtomicU64::new(0));
    let delivered = Arc::new(AtomicU64::new(0));

    client.on(Ch::Data, {
        let last_cursor = Arc::clone(&last_cursor);
        let delivered = Arc::clone(&delivered);
        move |frame: Frame<Ch>| {
            if let Some(cursor) = frame.cursor {
                last_cursor.store(cursor, Ordering::SeqCst);
            }
            delivered.fetch_add(1, Ordering::SeqCst);
        }
    });

    client.connect().await.expect("connect");

    // Server pushes 3 frames directly.
    (0..3).for_each(|seq| {
        server
            .send(client_id, Frame::new(Ch::Data, serde_json::json!({"seq": seq})))
            .expect("server send");
    });

    // Wait for all 3.
    wait_for_count(&delivered, 3, Duration::from_secs(2)).await;
    assert_eq!(delivered.load(Ordering::SeqCst), 3);

    // Client should have cursor state.
    let cursors = client.cursors();
    let data_cursor = cursors.get(&Ch::Data).copied().unwrap_or(0);
    assert!(
        data_cursor >= 3,
        "cursor should be >= 3, got {:?}",
        cursors
    );

    client.disconnect().await.expect("disconnect");
}

/// Connects with a reconnect-enabled policy, receives one server-pushed frame,
/// and checks that the client has recorded cursor state.
// NOTE(review): despite its name, this test never drops the connection or
// triggers a reconnect; it only verifies that cursors exist before a clean
// disconnect. Covering cursor resume needs a forced disconnect — TODO.
#[tokio::test]
async fn reconnect_with_cursor_resume() {
    let (url, server) = start_echo_server().await;

    let mut cfg = ClientConfig::new(&url);
    // Fast, deterministic retry schedule: short delays, bounded retries, no jitter.
    cfg.reconnect = ReconnectPolicy::default();
    cfg.reconnect.initial_delay = Duration::from_millis(100);
    cfg.reconnect.max_delay = Duration::from_secs(1);
    cfg.reconnect.backoff_factor = 1.5;
    cfg.reconnect.max_retries = Some(3);
    cfg.reconnect.jitter = false;
    let client_id = cfg.client_id;

    let mut client = PushClient::<Ch>::new(cfg);

    let delivered = Arc::new(AtomicU64::new(0));
    client.on(Ch::Data, {
        let delivered = Arc::clone(&delivered);
        // Only the number of deliveries matters here; the frame itself is unused.
        move |_frame: Frame<Ch>| {
            delivered.fetch_add(1, Ordering::SeqCst);
        }
    });

    client.connect().await.expect("connect");

    // Send a frame to verify connection works.
    server
        .send(client_id, Frame::new(Ch::Data, serde_json::json!({"pre": true})))
        .expect("send");
    wait_for_count(&delivered, 1, Duration::from_secs(2)).await;

    // Verify cursors were tracked.
    let has_cursor_state = !client.cursors().is_empty();
    assert!(has_cursor_state, "should have cursor state");

    client.disconnect().await.expect("disconnect");
}

/// Sending on a client that was never connected must return an error.
///
/// `connect()` is never called, so the URL in the config is only a placeholder.
#[tokio::test]
async fn send_before_connect_fails() {
    let client = PushClient::<Ch>::new(ClientConfig::new("http://127.0.0.1:1"));

    let empty_frame = Frame::new(Ch::Data, serde_json::json!({}));
    let outcome = client.send(empty_frame).await;

    assert!(outcome.is_err());
}

// -------------------------------------------------------------------------
// Helpers
// -------------------------------------------------------------------------

/// Polls `flag` every 20 ms until it is set or `timeout` has elapsed.
///
/// Returns silently in both cases; callers are expected to re-check the flag
/// and assert on it themselves.
async fn wait_for(flag: &AtomicBool, timeout: Duration) {
    let give_up_at = tokio::time::Instant::now() + timeout;
    loop {
        if flag.load(Ordering::SeqCst) {
            return;
        }
        if tokio::time::Instant::now() > give_up_at {
            return;
        }
        tokio::time::sleep(Duration::from_millis(20)).await;
    }
}

/// Polls `count` every 20 ms until it reaches at least `target` or `timeout`
/// has elapsed.
///
/// Returns silently in both cases; callers are expected to re-read the counter
/// and assert on it themselves.
async fn wait_for_count(count: &AtomicU64, target: u64, timeout: Duration) {
    let give_up_at = tokio::time::Instant::now() + timeout;
    loop {
        if count.load(Ordering::SeqCst) >= target {
            return;
        }
        if tokio::time::Instant::now() > give_up_at {
            return;
        }
        tokio::time::sleep(Duration::from_millis(20)).await;
    }
}