use async_trait::async_trait;
use axum::{Router, routing::get};
use futures_util::{SinkExt, StreamExt};
use rullst::live::{Live, LiveComponent, live_ws_handler};
use serde_json::Value;
use tokio::net::TcpListener;
use tokio_tungstenite::tungstenite::Message;
/// Minimal counter used as the `LiveComponent` under test in this file.
#[derive(Clone, Default)]
struct CounterComponent {
    /// Value emitted by `render`; `mount` seeds it with 10.
    count: i32,
    /// Flipped to `true` by `mount`.
    mounted: bool,
}
#[async_trait]
impl LiveComponent for CounterComponent {
async fn mount(&mut self) {
self.mounted = true;
self.count = 10;
}
async fn handle_event(&mut self, payload: Value) {
if let Some(action) = payload.get("action").and_then(|v| v.as_str()) {
match action {
"increment" => self.count += 1,
"decrement" => self.count -= 1,
_ => {}
}
}
}
fn render(&self) -> String {
format!("<div id=\"counter\">Count: {}</div>", self.count)
}
}
/// Checks that the HTML returned by `Live::mount` carries `hx-ext="ws"`,
/// a `ws-connect` attribute for the requested path, and the component's
/// post-`mount` render (count 10).
#[tokio::test]
async fn test_live_mount_html() {
    let mounted_html = Live::mount::<CounterComponent>("/live/counter").await;
    let expected_fragments = [
        r#"hx-ext="ws""#,
        r#"ws-connect="/live/counter""#,
        r#"<div id="counter">Count: 10</div>"#,
    ];
    for fragment in expected_fragments {
        assert!(mounted_html.contains(fragment));
    }
}
/// A websocket path containing HTML-significant characters must be escaped
/// before it is placed in the double-quoted `ws-connect` attribute. Otherwise
/// a `"` in the path would terminate the attribute and allow markup injection.
#[tokio::test]
async fn test_live_mount_safe_path() {
    let html = Live::mount::<CounterComponent>("/live?param=\"hack\"&other='<>'").await;

    let attr_prefix = "ws-connect=\"";
    let value_start = html
        .find(attr_prefix)
        .expect("mounted HTML has no ws-connect attribute")
        + attr_prefix.len();
    // The attribute value ends at the first raw double quote. If the path's
    // quotes were left unescaped, this cuts the value short at `param=` and
    // the escaping assertions below fail.
    let value_len = html[value_start..]
        .find('"')
        .expect("ws-connect attribute is not terminated");
    let value = &html[value_start..value_start + value_len];

    assert!(
        value.starts_with("/live?param="),
        "unexpected ws-connect value: {value}"
    );
    assert!(
        value.contains("&quot;hack&quot;"),
        "double quotes not escaped in ws-connect value: {value}"
    );
    assert!(
        value.contains("&lt;&gt;"),
        "angle brackets not escaped in ws-connect value: {value}"
    );
}
/// End-to-end check of `live_ws_handler`. The handler is served on an
/// ephemeral local port, JSON events are sent over a real websocket, and each
/// reply must be the component's re-rendered HTML.
///
/// The first expected reply is `Count: 11`, which implies the handler runs
/// `CounterComponent::mount` (count = 10) before handling the first event.
#[tokio::test]
async fn test_live_ws_handler() {
    let app = Router::new().route("/ws", get(live_ws_handler::<CounterComponent>));
    // Port 0: let the OS choose a free port so concurrent tests don't collide.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    tokio::spawn(async move {
        axum::serve(listener, app).await.unwrap();
    });

    let ws_url = format!("ws://{addr}/ws");
    let (mut ws_stream, _) = tokio_tungstenite::connect_async(&ws_url)
        .await
        .expect("Failed to connect to WS");

    // All events hit the same server-side component instance, so the
    // expected counts depend on the order of this list.
    for (action, expected_count) in [("increment", 11), ("decrement", 10)] {
        let event = serde_json::json!({ "action": action });
        ws_stream
            .send(Message::Text(event.to_string().into()))
            .await
            .expect("Failed to send message");

        match ws_stream.next().await {
            Some(Ok(Message::Text(html_response))) => assert_eq!(
                html_response.to_string(),
                format!("<div id=\"counter\">Count: {expected_count}</div>"),
                "unexpected render after {action:?}"
            ),
            // Report what actually arrived (error, other frame type, or end
            // of stream) so a failure is diagnosable.
            other => panic!(
                "Did not receive text message from server after {action:?}: {other:?}"
            ),
        }
    }

    ws_stream.close(None).await.unwrap();
}
#[tokio::test]
async fn test_live_module_exists() {}