Skip to main content

boarddown_server/
ws.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::{broadcast, RwLock, mpsc};
4use axum::extract::ws::{Message, WebSocket};
5use futures_util::{SinkExt, StreamExt};
6use boarddown_core::{BoardId, Error};
7use boarddown_schema::{BoardEvent, TaskOp};
8
9pub type BoardClients = Arc<RwLock<HashMap<String, Vec<mpsc::UnboundedSender<String>>>>>;
10
11pub struct WebSocketHandler {
12    board_id: BoardId,
13    clients: BoardClients,
14    event_tx: broadcast::Sender<BoardEvent>,
15}
16
17impl WebSocketHandler {
18    pub fn new(board_id: BoardId, clients: BoardClients, event_tx: broadcast::Sender<BoardEvent>) -> Self {
19        Self { board_id, clients, event_tx }
20    }
21
22    pub async fn handle(&self, socket: WebSocket) -> Result<(), Error> {
23        let (mut tx, mut rx) = socket.split();
24        let (client_tx, mut client_rx) = mpsc::unbounded_channel::<String>();
25        
26        {
27            let mut clients = self.clients.write().await;
28            clients.entry(self.board_id.0.clone()).or_default().push(client_tx);
29        }
30        
31        let board_id = self.board_id.0.clone();
32        let event_tx = self.event_tx.clone();
33        
34        let send_task = tokio::spawn(async move {
35            while let Some(msg) = client_rx.recv().await {
36                if tx.send(Message::Text(msg)).await.is_err() {
37                    break;
38                }
39            }
40        });
41        
42        while let Some(msg) = rx.next().await {
43            match msg {
44                Ok(Message::Text(text)) => {
45                    if let Ok(board_msg) = serde_json::from_str::<ClientMessage>(&text) {
46                        match board_msg {
47                            ClientMessage::MoveTask { task_id, to_column } => {
48                                let event = BoardEvent::TaskMoved {
49                                    task_id: boarddown_schema::TaskId::parse(&task_id).unwrap(),
50                                    from: boarddown_schema::Status::Todo,
51                                    to: boarddown_schema::Status::InProgress,
52                                };
53                                let _ = event_tx.send(event);
54                            }
55                            ClientMessage::CreateTask { title, column } => {
56                                let event = BoardEvent::TaskCreated {
57                                    board_id: BoardId(board_id.clone()),
58                                    task_id: boarddown_schema::TaskId::new(&board_id, 1),
59                                };
60                                let _ = event_tx.send(event);
61                            }
62                            ClientMessage::Sync { since_version: _ } => {}
63                        }
64                    }
65                }
66                Ok(Message::Close(_)) => break,
67                Err(_) => break,
68                _ => {}
69            }
70        }
71        
72        {
73            let mut clients = self.clients.write().await;
74            if let Some(board_clients) = clients.get_mut(&self.board_id.0) {
75                board_clients.retain(|tx| !tx.is_closed());
76            }
77        }
78        
79        send_task.abort();
80        Ok(())
81    }
82
83    pub async fn broadcast(&self, message: BoardMessage) -> Result<(), Error> {
84        let json = serde_json::to_string(&message)
85            .map_err(|e| Error::Storage(format!("JSON error: {}", e)))?;
86        
87        let clients = self.clients.read().await;
88        if let Some(board_clients) = clients.get(&self.board_id.0) {
89            for client_tx in board_clients {
90                let _ = client_tx.send(json.clone());
91            }
92        }
93        
94        Ok(())
95    }
96}
97
98#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
99pub enum BoardMessage {
100    TaskCreated { task_id: String },
101    TaskUpdated { task_id: String },
102    TaskMoved { task_id: String, from: String, to: String },
103    SyncRequest { since_version: u64 },
104    SyncResponse { changes: Vec<TaskOp>, version: u64 },
105    Conflict { task_id: String, local_version: u64, remote_version: u64 },
106}
107
108#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
109#[serde(tag = "type")]
110pub enum ClientMessage {
111    #[serde(rename = "move_task")]
112    MoveTask { task_id: String, to_column: String },
113    #[serde(rename = "create_task")]
114    CreateTask { title: String, column: Option<String> },
115    #[serde(rename = "sync")]
116    Sync { since_version: u64 },
117}
118
119#[derive(Clone)]
120pub struct SyncState {
121    pub clients: BoardClients,
122    pub event_tx: broadcast::Sender<BoardEvent>,
123    pub version: Arc<RwLock<u64>>,
124}
125
126impl SyncState {
127    pub fn new() -> Self {
128        let (event_tx, _) = broadcast::channel::<BoardEvent>(1024);
129        Self {
130            clients: Arc::new(RwLock::new(HashMap::new())),
131            event_tx,
132            version: Arc::new(RwLock::new(0)),
133        }
134    }
135    
136    pub async fn increment_version(&self) -> u64 {
137        let mut version = self.version.write().await;
138        *version += 1;
139        *version
140    }
141    
142    pub async fn get_version(&self) -> u64 {
143        *self.version.read().await
144    }
145}
146
147impl Default for SyncState {
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
153pub async fn handle_ws_upgrade(
154    socket: WebSocket,
155    board_id: String,
156    sync_state: SyncState,
157) {
158    let handler = WebSocketHandler::new(
159        BoardId(board_id),
160        sync_state.clients,
161        sync_state.event_tx,
162    );
163    
164    if let Err(e) = handler.handle(socket).await {
165        tracing::error!("WebSocket error: {:?}", e);
166    }
167}