use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use pushwire_client::{ChannelKind, ClientConfig, Frame, PushClient, ReconnectPolicy};
use pushwire_server::PushServer;
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
/// Channel set used by these integration tests.
///
/// `rename_all = "lowercase"` makes serde (de)serialize the variants as `"data"` and
/// `"system"`, matching the strings returned by `ChannelKind::name` for this type.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum Ch {
/// Application payload channel (wire id 0x01); the echo server's handler is registered on it.
Data,
/// System channel (wire id 0x05); the only variant for which `is_system()` is true.
System,
}
/// Wire and naming contract for the test channel set.
///
/// The forward mappings (`wire_id`, `name`, `priority`) are explicit matches. The reverse
/// lookups (`from_wire_id`, `from_name`) scan `all()` and reuse those forward mappings, so
/// the two directions cannot drift apart.
impl ChannelKind for Ch {
    fn priority(&self) -> u8 {
        // `System` gets 0, `Data` gets 1.
        match *self {
            Self::System => 0,
            Self::Data => 1,
        }
    }

    fn wire_id(&self) -> u8 {
        match *self {
            Self::Data => 0x01,
            Self::System => 0x05,
        }
    }

    /// Inverse of `wire_id`; unknown ids yield `None`.
    fn from_wire_id(id: u8) -> Option<Self> {
        Self::all().iter().copied().find(|ch| ch.wire_id() == id)
    }

    /// Inverse of `name`; matching is exact (case-sensitive), unknown names yield `None`.
    fn from_name(s: &str) -> Option<Self> {
        Self::all().iter().copied().find(|ch| ch.name() == s)
    }

    fn name(&self) -> &'static str {
        match *self {
            Self::Data => "data",
            Self::System => "system",
        }
    }

    fn is_system(&self) -> bool {
        *self == Self::System
    }

    fn all() -> &'static [Self] {
        const EVERY_CHANNEL: &[Ch] = &[Ch::Data, Ch::System];
        EVERY_CHANNEL
    }
}
async fn start_echo_server() -> (String, Arc<PushServer<Ch>>) {
let server: Arc<PushServer<Ch>> = Arc::new(PushServer::new());
server.register_handler(Ch::Data, |client_id, frame, srv| {
let reply = Frame::new(Ch::Data, frame.payload);
let _ = srv.send(client_id, reply);
});
let rps = server.clone();
let app = rps.router().with_state(server.clone());
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let url = format!("http://127.0.0.1:{}", addr.port());
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
tokio::time::sleep(Duration::from_millis(50)).await;
(url, server)
}
#[tokio::test]
async fn connect_send_receive_disconnect() {
let (url, _server) = start_echo_server().await;
let mut config = ClientConfig::new(&url);
config.reconnect = ReconnectPolicy::disabled();
let mut client = PushClient::<Ch>::new(config);
let received = Arc::new(AtomicBool::new(false));
let received_payload = Arc::new(std::sync::Mutex::new(serde_json::Value::Null));
let r = received.clone();
let rp = received_payload.clone();
client.on(Ch::Data, move |frame: Frame<Ch>| {
*rp.lock().unwrap() = frame.payload;
r.store(true, Ordering::SeqCst);
});
client.connect().await.expect("connect");
assert_eq!(client.state(), pushwire_client::ConnectionState::Connected);
let msg = Frame::new(Ch::Data, serde_json::json!({"echo": "test"}));
client.send(msg).await.expect("send");
wait_for(&received, Duration::from_secs(2)).await;
assert!(received.load(Ordering::SeqCst), "should have received echo");
let payload = received_payload.lock().unwrap().clone();
assert_eq!(payload, serde_json::json!({"echo": "test"}));
client.disconnect().await.expect("disconnect");
assert_eq!(
client.state(),
pushwire_client::ConnectionState::Disconnected
);
}
/// Server-pushed `Ch::Data` frames must advance the client's per-channel cursor.
///
/// Three frames are pushed directly from the server; the client sends nothing. After the
/// handler has fired three times, two checks are made:
/// - the client's tracked `Ch::Data` cursor must be at least 3;
/// - it must not lag behind the last `frame.cursor` the handler observed.
#[tokio::test]
async fn cursor_tracking() {
    let (url, server) = start_echo_server().await;
    let mut config = ClientConfig::new(&url);
    config.reconnect = ReconnectPolicy::disabled();
    // Read before `config` is moved into the client; it addresses the server-side pushes.
    let client_id = config.client_id;
    let mut client = PushClient::<Ch>::new(config);

    // Last `frame.cursor` seen by the handler; stays 0 if delivered frames carry no cursor.
    let cursor_seen = Arc::new(AtomicU64::new(0));
    let cs = Arc::clone(&cursor_seen);
    let count = Arc::new(AtomicU64::new(0));
    let cnt = Arc::clone(&count);
    client.on(Ch::Data, move |frame: Frame<Ch>| {
        if let Some(c) = frame.cursor {
            cs.store(c, Ordering::SeqCst);
        }
        cnt.fetch_add(1, Ordering::SeqCst);
    });

    client.connect().await.expect("connect");
    for i in 0..3 {
        let frame = Frame::new(Ch::Data, serde_json::json!({"seq": i}));
        server.send(client_id, frame).expect("server send");
    }
    wait_for_count(&count, 3, Duration::from_secs(2)).await;
    assert_eq!(count.load(Ordering::SeqCst), 3);

    let cursors = client.cursors();
    let tracked = cursors.get(&Ch::Data).copied().unwrap_or(0);
    assert!(tracked >= 3, "cursor should be >= 3, got {:?}", cursors);
    // `cursor_seen` was previously recorded but never checked. A cursor already delivered to
    // the handler should never be ahead of what the client reports as tracked.
    let last_seen = cursor_seen.load(Ordering::SeqCst);
    assert!(
        last_seen <= tracked,
        "handler saw cursor {} but client tracks {:?}",
        last_seen,
        cursors
    );

    client.disconnect().await.expect("disconnect");
}
/// Receives one server-pushed frame while a retrying reconnect policy is configured, then
/// checks that the client has recorded a `Ch::Data` cursor.
///
/// NOTE(review): despite its name, this test never drops the connection. It therefore
/// exercises neither the reconnect loop nor cursor-based resumption. A real check needs:
/// - a way to sever the connection server-side (not visible from this file);
/// - an assertion that frames pushed while disconnected are delivered after reconnect,
///   with the cursor continuing from `cursors_before`.
///
/// TODO: add that once such a hook exists.
#[tokio::test]
async fn reconnect_with_cursor_resume() {
    let (url, server) = start_echo_server().await;
    let mut config = ClientConfig::new(&url);
    let mut reconnect = ReconnectPolicy::default();
    reconnect.initial_delay = Duration::from_millis(100);
    reconnect.max_delay = Duration::from_secs(1);
    reconnect.backoff_factor = 1.5;
    reconnect.max_retries = Some(3);
    // Jitter disabled so retry delays are deterministic.
    reconnect.jitter = false;
    config.reconnect = reconnect;
    // Read before `config` is moved into the client; it addresses the server-side push.
    let client_id = config.client_id;
    let mut client = PushClient::<Ch>::new(config);

    let count = Arc::new(AtomicU64::new(0));
    let cnt = Arc::clone(&count);
    client.on(Ch::Data, move |_frame: Frame<Ch>| {
        cnt.fetch_add(1, Ordering::SeqCst);
    });

    client.connect().await.expect("connect");
    let frame = Frame::new(Ch::Data, serde_json::json!({"pre": true}));
    server.send(client_id, frame).expect("send");
    wait_for_count(&count, 1, Duration::from_secs(2)).await;
    assert_eq!(
        count.load(Ordering::SeqCst),
        1,
        "should have received the pushed frame"
    );

    let cursors_before = client.cursors();
    assert!(!cursors_before.is_empty(), "should have cursor state");
    assert!(
        cursors_before.get(&Ch::Data).copied().unwrap_or(0) >= 1,
        "data cursor should be >= 1 after one frame, got {:?}",
        cursors_before
    );

    client.disconnect().await.expect("disconnect");
}
/// `send` on a client that was never connected must return an error.
#[tokio::test]
async fn send_before_connect_fails() {
    // `connect()` is never called on this client; the address only satisfies the constructor.
    let client = PushClient::<Ch>::new(ClientConfig::new("http://127.0.0.1:1"));
    let frame = Frame::new(Ch::Data, serde_json::json!({}));
    let outcome = client.send(frame).await;
    assert!(outcome.is_err());
}
/// Polls `flag` every 20 ms until it reads `true` or `timeout` has elapsed.
///
/// Returns without signalling which case happened. Callers re-check `flag` afterwards to
/// detect a timeout.
async fn wait_for(flag: &AtomicBool, timeout: Duration) {
    let give_up_at = tokio::time::Instant::now() + timeout;
    loop {
        // Short-circuit keeps the original order: flag first, then the deadline.
        if flag.load(Ordering::SeqCst) || tokio::time::Instant::now() > give_up_at {
            return;
        }
        tokio::time::sleep(Duration::from_millis(20)).await;
    }
}
/// Polls `count` every 20 ms until it reaches `target` or `timeout` has elapsed.
///
/// Returns without signalling which case happened. Callers re-read `count` afterwards to
/// detect a timeout.
async fn wait_for_count(count: &AtomicU64, target: u64, timeout: Duration) {
    let give_up_at = tokio::time::Instant::now() + timeout;
    loop {
        let reached = count.load(Ordering::SeqCst) >= target;
        if reached || tokio::time::Instant::now() > give_up_at {
            break;
        }
        tokio::time::sleep(Duration::from_millis(20)).await;
    }
}