1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::{broadcast, RwLock, mpsc};
4use axum::extract::ws::{Message, WebSocket};
5use futures_util::{SinkExt, StreamExt};
6use boarddown_core::{BoardId, Error};
7use boarddown_schema::{BoardEvent, TaskOp};
8
/// Shared registry of connected WebSocket clients, keyed by the board id string.
///
/// Each entry holds one outbound channel sender per client connected to that board. Text
/// pushed into a sender is written to that client's socket by the writer task spawned in
/// `WebSocketHandler::handle`.
/// NOTE(review): the channels are unbounded, so a slow client can buffer without limit.
pub type BoardClients = Arc<RwLock<HashMap<String, Vec<mpsc::UnboundedSender<String>>>>>;
10
/// Handles a single WebSocket connection for one board.
pub struct WebSocketHandler {
    /// Board this connection belongs to; its inner string is the key used in `clients`.
    board_id: BoardId,
    /// Shared registry of outbound client channels, grouped by board.
    clients: BoardClients,
    /// Channel on which events derived from client commands are published.
    event_tx: broadcast::Sender<BoardEvent>,
}
16
17impl WebSocketHandler {
18 pub fn new(board_id: BoardId, clients: BoardClients, event_tx: broadcast::Sender<BoardEvent>) -> Self {
19 Self { board_id, clients, event_tx }
20 }
21
22 pub async fn handle(&self, socket: WebSocket) -> Result<(), Error> {
23 let (mut tx, mut rx) = socket.split();
24 let (client_tx, mut client_rx) = mpsc::unbounded_channel::<String>();
25
26 {
27 let mut clients = self.clients.write().await;
28 clients.entry(self.board_id.0.clone()).or_default().push(client_tx);
29 }
30
31 let board_id = self.board_id.0.clone();
32 let event_tx = self.event_tx.clone();
33
34 let send_task = tokio::spawn(async move {
35 while let Some(msg) = client_rx.recv().await {
36 if tx.send(Message::Text(msg)).await.is_err() {
37 break;
38 }
39 }
40 });
41
42 while let Some(msg) = rx.next().await {
43 match msg {
44 Ok(Message::Text(text)) => {
45 if let Ok(board_msg) = serde_json::from_str::<ClientMessage>(&text) {
46 match board_msg {
47 ClientMessage::MoveTask { task_id, to_column } => {
48 let event = BoardEvent::TaskMoved {
49 task_id: boarddown_schema::TaskId::parse(&task_id).unwrap(),
50 from: boarddown_schema::Status::Todo,
51 to: boarddown_schema::Status::InProgress,
52 };
53 let _ = event_tx.send(event);
54 }
55 ClientMessage::CreateTask { title, column } => {
56 let event = BoardEvent::TaskCreated {
57 board_id: BoardId(board_id.clone()),
58 task_id: boarddown_schema::TaskId::new(&board_id, 1),
59 };
60 let _ = event_tx.send(event);
61 }
62 ClientMessage::Sync { since_version: _ } => {}
63 }
64 }
65 }
66 Ok(Message::Close(_)) => break,
67 Err(_) => break,
68 _ => {}
69 }
70 }
71
72 {
73 let mut clients = self.clients.write().await;
74 if let Some(board_clients) = clients.get_mut(&self.board_id.0) {
75 board_clients.retain(|tx| !tx.is_closed());
76 }
77 }
78
79 send_task.abort();
80 Ok(())
81 }
82
83 pub async fn broadcast(&self, message: BoardMessage) -> Result<(), Error> {
84 let json = serde_json::to_string(&message)
85 .map_err(|e| Error::Storage(format!("JSON error: {}", e)))?;
86
87 let clients = self.clients.read().await;
88 if let Some(board_clients) = clients.get(&self.board_id.0) {
89 for client_tx in board_clients {
90 let _ = client_tx.send(json.clone());
91 }
92 }
93
94 Ok(())
95 }
96}
97
/// Server-to-client messages, sent as JSON by `WebSocketHandler::broadcast`.
///
/// NOTE(review): unlike `ClientMessage`, this enum has no `#[serde(tag = "...")]`. It therefore
/// uses serde's default externally tagged form, e.g. `{"TaskCreated":{"task_id":"..."}}`.
/// Confirm this is the shape clients expect.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum BoardMessage {
    /// Notification about task `task_id`, presumably sent after it is created.
    TaskCreated { task_id: String },
    /// Notification about task `task_id`, presumably sent after it is updated.
    TaskUpdated { task_id: String },
    /// Task `task_id` moved from `from` to `to` (presumably status or column names; confirm).
    TaskMoved { task_id: String, from: String, to: String },
    /// Request for changes newer than `since_version`.
    SyncRequest { since_version: u64 },
    /// Batch of task operations plus a version; presumably the version after applying them.
    SyncResponse { changes: Vec<TaskOp>, version: u64 },
    /// Version disagreement on `task_id` between a local and a remote copy.
    Conflict { task_id: String, local_version: u64, remote_version: u64 },
}
107
/// Commands a client can send over the socket, as internally tagged JSON keyed by `type`.
///
/// Example: `{"type":"move_task","task_id":"...","to_column":"..."}`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type")]
pub enum ClientMessage {
    /// `"move_task"`: move `task_id` to `to_column`.
    #[serde(rename = "move_task")]
    MoveTask { task_id: String, to_column: String },
    /// `"create_task"`: create a task titled `title`, optionally in `column`.
    ///
    /// `column` may be omitted from the JSON; serde then deserializes it as `None`.
    #[serde(rename = "create_task")]
    CreateTask { title: String, column: Option<String> },
    /// `"sync"`: ask for changes newer than `since_version`.
    #[serde(rename = "sync")]
    Sync { since_version: u64 },
}
118
/// Shared WebSocket sync state.
///
/// Cloning is cheap: every field is a shared handle (`Arc`, or a broadcast sender).
#[derive(Clone)]
pub struct SyncState {
    /// Registry of connected clients' outbound channels, per board.
    pub clients: BoardClients,
    /// Broadcast channel for board events produced from client commands.
    pub event_tx: broadcast::Sender<BoardEvent>,
    /// Version counter; `increment_version` bumps it by one.
    pub version: Arc<RwLock<u64>>,
}
125
126impl SyncState {
127 pub fn new() -> Self {
128 let (event_tx, _) = broadcast::channel::<BoardEvent>(1024);
129 Self {
130 clients: Arc::new(RwLock::new(HashMap::new())),
131 event_tx,
132 version: Arc::new(RwLock::new(0)),
133 }
134 }
135
136 pub async fn increment_version(&self) -> u64 {
137 let mut version = self.version.write().await;
138 *version += 1;
139 *version
140 }
141
142 pub async fn get_version(&self) -> u64 {
143 *self.version.read().await
144 }
145}
146
147impl Default for SyncState {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
153pub async fn handle_ws_upgrade(
154 socket: WebSocket,
155 board_id: String,
156 sync_state: SyncState,
157) {
158 let handler = WebSocketHandler::new(
159 BoardId(board_id),
160 sync_state.clients,
161 sync_state.event_tx,
162 );
163
164 if let Err(e) = handler.handle(socket).await {
165 tracing::error!("WebSocket error: {:?}", e);
166 }
167}