use super::{write_file, Scaffold, ScaffoldArgs, ScaffoldResult};
use anyhow::Result;
/// Scaffold that emits a WebSocket controller and an in-memory room manager.
pub struct WebsocketScaffold;

impl Scaffold for WebsocketScaffold {
    /// Identifier used to select this scaffold.
    fn name(&self) -> &'static str {
        "websocket"
    }

    /// One-line summary of what this scaffold generates.
    fn description(&self) -> &'static str {
        "WebSocket: upgrade handler, channel auth, pub/sub, presence tracking, room management"
    }

    /// Writes the controller and room-manager templates (honouring `args.dry_run`),
    /// then records the manual follow-up steps as warnings on the result.
    ///
    /// # Errors
    /// Propagates the first error returned by `write_file`; later files are not written.
    fn generate(&self, args: &ScaffoldArgs) -> Result<ScaffoldResult> {
        let mut result = ScaffoldResult::default();
        let dry_run = args.dry_run;

        // Destination path paired with the template written there, in write order.
        let outputs = [
            ("src/app/controllers/ws_controller.rs", CONTROLLER),
            ("src/app/ws/room_manager.rs", ROOM_MANAGER),
        ];
        for (path, template) in outputs {
            write_file(&mut result, path, template, dry_run)?;
        }

        // Steps the scaffold cannot perform itself; surfaced to the user.
        for note in [
            "Register GET /ws route in src/app/routes.rs",
            "rok-websocket must be in Cargo.toml dependencies",
        ] {
            result.warnings.push(note.into());
        }

        Ok(result)
    }
}
/// Template written to `src/app/controllers/ws_controller.rs`.
///
/// The generated file defines:
/// - `handle(ws: WebSocketUpgrade, ctx: Ctx)`: on upgrade it only logs the
///   connection via `tracing::info!` with `ctx.user_id()`. Authentication and
///   channel/message handling are left as a TODO.
/// - `broadcast(State<PubSub>, channel, msg)`: forwards a `serde_json::Value` to
///   `pubsub.publish(channel, msg)`.
///
/// The generated code references the `axum`, `rok_websocket`, `rok_auth`,
/// `serde_json` and `tracing` crates.
///
/// NOTE(review): in the stub as written, the upgraded `socket` is never used and
/// `Channel` is imported but unreferenced. The generated file will likely build with
/// unused-variable/unused-import warnings until the TODO is implemented.
/// NOTE(review): `broadcast` takes a plain `&str` argument, so it presumably is meant
/// to be called directly rather than mounted as an axum route handler — confirm.
const CONTROLLER: &str = r#"use axum::{extract::{State, WebSocketUpgrade}, response::IntoResponse};
use rok_websocket::{Channel, PubSub};
use rok_auth::axum::Ctx;
pub async fn handle(ws: WebSocketUpgrade, ctx: Ctx) -> impl IntoResponse {
ws.on_upgrade(move |socket| async move {
// TODO: authenticate, join channels, broadcast/receive messages
tracing::info!(user_id = ctx.user_id(), "WebSocket connected");
})
}
pub async fn broadcast(State(pubsub): State<PubSub>, channel: &str, msg: serde_json::Value) {
pubsub.publish(channel, msg).await;
}
"#;
/// Template written to `src/app/ws/room_manager.rs`: a shared, async-safe
/// registry mapping room names to the set of member user ids.
///
/// `leave` removes a room entirely once its last member leaves, so rooms that are
/// no longer in use do not accumulate in the shared map.
const ROOM_MANAGER: &str = r#"use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;

/// In-memory registry of which user ids belong to which room.
///
/// Cloning is cheap: all clones share the same underlying map.
#[derive(Debug, Default, Clone)]
pub struct RoomManager {
    rooms: Arc<RwLock<HashMap<String, HashSet<i64>>>>,
}

impl RoomManager {
    /// Adds `user_id` to `room`, creating the room if it does not exist yet.
    pub async fn join(&self, room: &str, user_id: i64) {
        self.rooms
            .write()
            .await
            .entry(room.to_string())
            .or_default()
            .insert(user_id);
    }

    /// Removes `user_id` from `room`.
    ///
    /// The room itself is dropped once it has no members left, so abandoned
    /// rooms do not accumulate in memory.
    pub async fn leave(&self, room: &str, user_id: i64) {
        let mut rooms = self.rooms.write().await;
        if let Some(members) = rooms.get_mut(room) {
            members.remove(&user_id);
            if members.is_empty() {
                rooms.remove(room);
            }
        }
    }

    /// Returns the member ids of `room` in no particular order, or an empty
    /// vector if the room does not exist.
    pub async fn members(&self, room: &str) -> Vec<i64> {
        self.rooms
            .read()
            .await
            .get(room)
            .map(|s| s.iter().copied().collect())
            .unwrap_or_default()
    }
}
"#;