1pub mod sync;
7#[cfg(feature = "sqlite")]
8pub mod db;
9
10pub use sync::{DocHandler, MSG_AUTH, MSG_AWARENESS, MSG_QUERY_AWARENESS, MSG_SYNC};
11#[cfg(feature = "sqlite")]
12pub use db::Database;
13
14#[cfg(feature = "server")]
15use axum::{
16 extract::{
17 ws::{Message, WebSocket, WebSocketUpgrade},
18 State,
19 },
20 response::Response,
21 routing::get,
22 Router,
23};
24#[cfg(feature = "server")]
25use dashmap::DashMap;
26#[cfg(feature = "server")]
27use futures_util::{
28 stream::{SplitSink, SplitStream},
29 SinkExt, StreamExt,
30};
31#[cfg(feature = "server")]
32use std::sync::Arc;
33
/// Shared server state: the live collaborative rooms and, with the `sqlite`
/// feature, the database new rooms are created from.
#[cfg(feature = "server")]
pub struct AppState {
    /// Live document handlers keyed by room name.
    ///
    /// Entries are created on first use by [`AppState::get_or_create_handler`].
    /// No code in this file removes entries.
    /// NOTE(review): confirm whether idle rooms are evicted anywhere else.
    pub rooms: DashMap<String, Arc<DocHandler>>,
    /// Database handle; a clone is passed to every newly created [`DocHandler`].
    #[cfg(feature = "sqlite")]
    pub db: Database,
}

#[cfg(feature = "server")]
impl AppState {
    /// Creates state with no rooms, backed by `db`.
    #[cfg(feature = "sqlite")]
    pub fn new(db: Database) -> Self {
        Self {
            rooms: DashMap::new(),
            db,
        }
    }

    /// Creates state with no rooms.
    #[cfg(not(feature = "sqlite"))]
    pub fn new() -> Self {
        Self {
            rooms: DashMap::new(),
        }
    }

    /// Returns the handler for `room_name`, creating it on first use.
    ///
    /// Looking up an existing room needs only a shared lookup and allocates
    /// nothing. Creating a room runs the async `DocHandler::new` to completion
    /// while the map entry is held. Concurrent callers asking for the same new
    /// room therefore wait, then receive the same `Arc` instead of building
    /// duplicates.
    ///
    /// # Panics
    ///
    /// Creating a new room uses [`tokio::task::block_in_place`], which panics
    /// when called from a `current_thread` Tokio runtime. Looking up an
    /// already-existing room does not reach that call.
    pub fn get_or_create_handler(&self, room_name: &str) -> Arc<DocHandler> {
        // Fast path: existing room. The read guard is consumed inside `map`, so
        // it is released before `entry` below write-locks the same shard.
        // Holding it across that call would deadlock.
        if let Some(existing) = self.rooms.get(room_name).map(|r| Arc::clone(r.value())) {
            return existing;
        }

        // Slow path. `entry` re-checks under the write lock, because another
        // caller may have inserted the room since the lookup above.
        let entry = self.rooms.entry(room_name.to_string()).or_insert_with(|| {
            let name = room_name.to_string();
            // `DocHandler::new` is async but this method is synchronous, so the
            // current worker thread blocks until construction finishes.
            tokio::task::block_in_place(|| {
                tokio::runtime::Handle::current().block_on(async {
                    #[cfg(feature = "sqlite")]
                    let handler = DocHandler::new(name, self.db.clone()).await;
                    #[cfg(not(feature = "sqlite"))]
                    let handler = DocHandler::new(name).await;
                    Arc::new(handler)
                })
            })
        });
        Arc::clone(entry.value())
    }
}

/// Without the `sqlite` feature, `AppState::new` takes no arguments, so the
/// empty state is also its `Default`.
#[cfg(all(feature = "server", not(feature = "sqlite")))]
impl Default for AppState {
    fn default() -> Self {
        Self::new()
    }
}
84
/// Builds the axum [`Router`] exposing the two sync endpoints.
///
/// * `GET /sync/:room_name`: the room is taken from the URL path segment.
/// * `GET /sync`: the room is named inside the client's first binary frame.
///
/// Both routes upgrade to a WebSocket and share `state`.
#[cfg(feature = "server")]
pub fn create_router(state: Arc<AppState>) -> Router {
    let routes = Router::new()
        .route("/sync/:room_name", get(ws_handler))
        .route("/sync", get(ws_handler_generic));
    routes.with_state(state)
}
93
/// Upgrades `GET /sync/:room_name` to a WebSocket bound to the room named in
/// the URL path.
#[cfg(feature = "server")]
async fn ws_handler(
    ws: WebSocketUpgrade,
    path: axum::extract::Path<String>,
    State(state): State<Arc<AppState>>,
) -> Response {
    let room = path.0;
    ws.on_upgrade(move |socket| handle_socket_with_room(socket, state, room))
}
103
/// Upgrades `GET /sync` to a WebSocket; the room is chosen later, from the
/// first frame the client sends.
#[cfg(feature = "server")]
async fn ws_handler_generic(ws: WebSocketUpgrade, State(app): State<Arc<AppState>>) -> Response {
    let on_socket = move |socket: WebSocket| handle_socket_generic(socket, app);
    ws.on_upgrade(on_socket)
}
108
/// Attaches an upgraded socket to the room named in the URL and runs the
/// connection until it ends.
#[cfg(feature = "server")]
async fn handle_socket_with_room(socket: WebSocket, state: Arc<AppState>, room_name: String) {
    let (ws_tx, ws_rx) = socket.split();
    let doc = state.get_or_create_handler(&room_name);
    run_connection(ws_tx, ws_rx, doc, room_name, None).await;
}
115
/// Serves a connection on `GET /sync`, where the room is not in the URL but is
/// named inside the client's first protocol frame.
///
/// The first frame must be a binary message from which
/// [`DocHandler::read_and_skip_doc_name`] can read a room name. Otherwise the
/// socket is dropped without a reply. That frame is then processed by the
/// room's handler, its responses are sent back, and the connection continues
/// in [`run_connection`].
#[cfg(feature = "server")]
async fn handle_socket_generic(socket: WebSocket, state: Arc<AppState>) {
    let (mut sender, mut receiver) = socket.split();

    // Nothing can be routed until the frame naming the room arrives.
    let first_msg = match receiver.next().await {
        Some(Ok(Message::Binary(data))) => data,
        _ => {
            tracing::debug!("Generic sync connection closed: first frame missing or not binary");
            return;
        }
    };

    let room_name = match DocHandler::read_and_skip_doc_name(&first_msg) {
        Some((_, name)) => name,
        None => {
            tracing::debug!("Generic sync connection closed: first frame does not name a room");
            return;
        }
    };

    let handler = state.get_or_create_handler(&room_name);

    // The first frame is a real protocol message, not only a room selector, so
    // it is handled like any later frame. Responses are consumed by value so
    // that no payload is cloned.
    for resp in handler.handle_message(&first_msg).await {
        if sender.send(Message::Binary(resp.into())).await.is_err() {
            return;
        }
    }

    run_connection(
        sender,
        receiver,
        handler,
        room_name,
        Some(first_msg.to_vec()),
    )
    .await;
}
154
/// Drives one WebSocket connection attached to `handler`'s room until either
/// side goes away.
///
/// 1. Subscribes to the room's broadcast channel.
/// 2. Sends the handler's initial sync messages.
/// 3. Loops:
///    * binary frames from the client go through `DocHandler::handle_message`,
///      and its responses are sent back;
///    * messages broadcast by the room are forwarded to the client.
///
/// Returns when the client closes or disconnects, the socket reports an error,
/// a send fails, or the room's broadcast channel is closed.
///
/// `_initial_message` is accepted for API compatibility but is currently
/// unused. Callers that already processed a first frame handle it themselves.
#[cfg(feature = "server")]
pub async fn run_connection(
    mut ws_sender: SplitSink<WebSocket, Message>,
    mut ws_receiver: SplitStream<WebSocket>,
    handler: Arc<DocHandler>,
    room_name: String,
    _initial_message: Option<Vec<u8>>,
) {
    // Subscribe *before* producing the initial sync. Anything broadcast while
    // the sync is generated or sent is then buffered for this client, instead
    // of falling into a gap between the two steps.
    let mut broadcast_rx = handler.subscribe();

    for msg in handler.generate_initial_sync() {
        if ws_sender.send(Message::Binary(msg.into())).await.is_err() {
            return;
        }
    }

    loop {
        tokio::select! {
            incoming = ws_receiver.next() => {
                match incoming {
                    Some(Ok(Message::Binary(data))) => {
                        for resp in handler.handle_message(&data).await {
                            if ws_sender.send(Message::Binary(resp.into())).await.is_err() {
                                return;
                            }
                        }
                    }
                    Some(Ok(Message::Ping(data))) => {
                        // Best effort: a dead socket surfaces on the next send or receive.
                        let _ = ws_sender.send(Message::Pong(data)).await;
                    }
                    Some(Ok(Message::Close(_))) | None => {
                        tracing::debug!("Client disconnected from room '{}'", room_name);
                        return;
                    }
                    Some(Err(err)) => {
                        // A receive error means the connection is unusable; do not keep polling it.
                        tracing::debug!("WebSocket error in room '{}': {}", room_name, err);
                        return;
                    }
                    // Text and Pong frames are not part of the sync protocol.
                    Some(Ok(_)) => {}
                }
            }
            outgoing = broadcast_rx.recv() => {
                match outgoing {
                    Ok(data) => {
                        if ws_sender.send(Message::Binary(data.into())).await.is_err() {
                            return;
                        }
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
                        // The receiver fell behind; the oldest updates were dropped for this client.
                        tracing::warn!(
                            "Client in room '{}' lagged and missed {} broadcast message(s)",
                            room_name,
                            skipped
                        );
                    }
                    Err(_) => {
                        // Channel closed: `recv` would now fail immediately on every poll and
                        // spin this loop, so end the connection instead.
                        tracing::debug!("Broadcast channel closed for room '{}'", room_name);
                        return;
                    }
                }
            }
        }
    }
}