hocuspocus_rs_ws/
doc_connection.rs1use crate::api_types::Authorization;
7use crate::client_connection::DocServer;
8use crate::sync::Message;
9use crate::sync::{
10 self, DefaultProtocol, MSG_SYNC, MSG_SYNC_UPDATE, Protocol, SyncMessage, awareness::Awareness,
11};
12use anyhow::Result;
13use std::sync::{Arc, OnceLock, RwLock};
14use tokio::sync::mpsc;
15use tracing::debug;
16use yrs::updates::decoder::DecoderV1;
17use yrs::{
18 Subscription, Update,
19 block::ClientID,
20 encoding::write::Write,
21 updates::{
22 decoder::Decode,
23 encoder::{Encode, Encoder, EncoderV1},
24 },
25};
26
/// Document name constant exported by this module (its consumers are outside this file's view).
pub const DOC_NAME: &str = "doc";
/// Custom message tag that `handle_msg` treats as a sync-status frame and echoes back unchanged.
const SYNC_STATUS_MESSAGE: u8 = 0x66; // decimal 102
30
/// One client's attachment to a shared document.
///
/// Holds the shared [`Awareness`] (from which the yrs document is obtained), the
/// outbound channel on which encoded frames for this client are queued, and the
/// subscriptions that forward document and awareness changes onto that channel.
pub struct DocConnection<T: Protocol = DefaultProtocol> {
    /// Name written as the first field of every outbound frame.
    doc_name: String,
    /// Backend used to validate tokens received in `Auth` messages.
    doc_server: Arc<dyn DocServer>,
    /// Access level granted by the most recent `Auth` message; gates write operations.
    authorization: Arc<RwLock<Authorization>>,
    /// Awareness state shared with the other connections to the same document.
    awareness: Arc<RwLock<Awareness>>,
    /// Outbound queue of encoded frames destined for this client.
    sender: mpsc::Sender<Vec<u8>>,
    /// Sync protocol implementation used to answer incoming messages.
    protocol: T,
    /// Set when the connection is dropped; the observer callbacks check it and stop forwarding.
    closed: Arc<OnceLock<()>>,

    /// Client id taken from the first single-client awareness update; its awareness
    /// state is removed when this connection is dropped.
    client_id: OnceLock<ClientID>,

    // The subscription handles are never read. They are stored only so the observers
    // stay registered while this connection lives (a yrs `Subscription` unsubscribes
    // when dropped). They are declared last, so they are dropped after the other fields.
    #[allow(unused)]
    doc_subscription: Subscription,
    #[allow(unused)]
    awareness_subscription: Subscription,
}
49
50impl DocConnection {
51 pub fn new(
52 doc_name: String,
53 doc_server: Arc<dyn DocServer>,
54 awareness: Arc<RwLock<Awareness>>,
55 callback: mpsc::Sender<Vec<u8>>,
56 ) -> Self {
57 Self::new_inner(doc_name, doc_server, awareness, callback)
58 }
59
60 pub fn new_inner(
61 doc_name: String,
62 doc_server: Arc<dyn DocServer>,
63 awareness: Arc<RwLock<Awareness>>,
64 callback: mpsc::Sender<Vec<u8>>,
65 ) -> Self {
66 let closed = Arc::new(OnceLock::new());
67
68 let (doc_subscription, awareness_subscription) = {
69 let mut awareness = awareness.write().unwrap();
70
71 let doc_subscription = {
72 let doc = awareness.doc();
73 let callback = callback.clone();
74 let doc_name = doc_name.clone();
75 let closed = closed.clone();
76 doc.observe_update_v1(move |_, event| {
77 if closed.get().is_some() {
78 return;
79 }
80 let mut encoder = EncoderV1::new();
82 encoder.write_string(doc_name.as_str());
83 encoder.write_var(MSG_SYNC);
84 encoder.write_var(MSG_SYNC_UPDATE);
85 encoder.write_buf(&event.update);
86 let msg = encoder.to_vec();
87 callback.try_send(msg).expect("todo err handling");
88 })
89 .unwrap()
90 };
91
92 let callback = callback.clone();
93 let closed = closed.clone();
94 let doc_name = doc_name.clone();
95 let awareness_subscription = awareness.on_update(move |awareness, e| {
96 if closed.get().is_some() {
97 return;
98 }
99
100 debug!("awareneess update observed, sending to client");
101
102 let added = e.added();
104 let updated = e.updated();
105 let removed = e.removed();
106 let mut changed = Vec::with_capacity(added.len() + updated.len() + removed.len());
107 changed.extend_from_slice(added);
108 changed.extend_from_slice(updated);
109 changed.extend_from_slice(removed);
110
111 if let Ok(u) = awareness.update_with_clients(changed) {
112 let mut encoder = EncoderV1::new();
113 encoder.write_string(doc_name.as_str());
114 Message::Awareness(u).encode(&mut encoder);
115 let msg = encoder.to_vec();
116 callback.try_send(msg).expect("todo err handling");
117 }
118 });
119
120 (doc_subscription, awareness_subscription)
121 };
122
123 let protocol = DefaultProtocol;
124 Self {
128 doc_name,
129 doc_server,
130 awareness,
131 doc_subscription,
132 awareness_subscription,
133 authorization: Arc::new(RwLock::new(Authorization::None)),
134 protocol,
135 sender: callback,
136 client_id: OnceLock::new(),
137 closed,
138 }
139 }
140
141 pub async fn send(&self, mut update: DecoderV1<'_>) -> Result<(), anyhow::Error> {
142 let msg = Message::decode(&mut update)?;
143 self.send_message(msg).await?;
144
145 Ok(())
146 }
147
148 pub async fn send_message(&self, msg: Message) -> Result<(), anyhow::Error> {
149 let mut encoder = EncoderV1::new();
150 encoder.write_string(self.doc_name.as_str());
151 msg.encode(&mut encoder);
152 self.send_raw(encoder.to_vec()).await
153 }
154
155 pub async fn send_raw(&self, msg: Vec<u8>) -> Result<(), anyhow::Error> {
156 self.sender.send(msg).await?;
157
158 Ok(())
159 }
160
161 #[tracing::instrument(skip(self, msg), fields(doc_name = self.doc_name))]
164 pub async fn handle_msg(&self, msg: Message) -> Result<Option<Message>, sync::Error> {
165 debug!("Handling message for document: {:?}", msg);
166 let protocol = &self.protocol;
167 let awareness = &self.awareness;
168
169 let can_write = matches!(*self.authorization.read().unwrap(), Authorization::Full);
170
171 match msg {
172 Message::Sync(msg) => match msg {
173 SyncMessage::SyncStep1(sv) => {
174 let awareness = awareness.read().unwrap();
175 protocol.handle_sync_step1(&awareness, sv)
176 }
177 SyncMessage::SyncStep2(update) => {
178 if can_write {
179 let mut awareness = awareness.write().unwrap();
180 protocol.handle_sync_step2(&mut awareness, Update::decode_v1(&update)?)
181 } else {
182 Err(sync::Error::PermissionDenied {
183 reason: "Token does not have write access".to_string(),
184 })
185 }
186 }
187 SyncMessage::Update(update) => {
188 if can_write {
189 let mut awareness = awareness.write().unwrap();
190 protocol.handle_update(&mut awareness, Update::decode_v1(&update)?)
191 } else {
192 Err(sync::Error::PermissionDenied {
193 reason: "Token does not have write access".to_string(),
194 })
195 }
196 }
197 },
198 Message::Auth(token, _) => {
199 let token = token.unwrap_or_default();
200 let config = self
201 .doc_server
202 .authenticate(&self.doc_name, token.as_str())
203 .await;
204
205 let mut auth_failed = false;
206 if let Ok(config) = config {
207 if config.is_authenticated {
208 if config.read_only {
209 *self.authorization.write().unwrap() = Authorization::ReadOnly;
210 } else {
211 *self.authorization.write().unwrap() = Authorization::Full;
212 }
213 } else {
214 *self.authorization.write().unwrap() = Authorization::None;
215 auth_failed = true;
216 }
217 }
218
219 if auth_failed {
220 tracing::warn!(
221 "Authentication failed for document: {}",
222 self.doc_name
223 );
224
225 let handle_auth_message =
226 protocol.handle_auth_fail(&self.awareness.read().unwrap());
227 self.send_message(handle_auth_message).await?;
228 return Err(sync::Error::PermissionDenied {
229 reason: "Authentication failed".to_string(),
230 });
231 }
232
233 let handle_auth_message =
234 protocol.handle_auth_success(&self.awareness.read().unwrap(), true);
235 self.send_message(handle_auth_message).await?;
236
237 if !self.awareness.read().unwrap().clients().is_empty() {
238 let awareness = protocol.awareness(&self.awareness.read().unwrap())?;
239 self.send_message(awareness).await?;
240 } else {
241 debug!("No existing awareness states to send to client");
242 }
243
244 let sync1_message = protocol.sync_step1(&self.awareness.read().unwrap())?;
245 self.send_message(sync1_message).await?;
246
247 Ok(None)
248 }
249 Message::AwarenessQuery => {
250 let awareness = awareness.read().unwrap();
251 protocol.handle_awareness_query(&awareness)
252 }
253 Message::Awareness(update) => {
254 if update.clients.len() == 1 {
255 let client_id = update.clients.keys().next().unwrap();
256 self.client_id.get_or_init(|| *client_id);
257 } else {
258 tracing::warn!(
259 "Received awareness update with more than one client {:?}",
260 update.clients
261 );
262 }
263 let mut awareness = awareness.write().unwrap();
264 protocol.handle_awareness_update(&mut awareness, update)
265 }
266 Message::SyncStatus(synced) => {
267 debug!("Client sync status changed: synced={}", synced);
268
269 Ok(None)
270 }
271 Message::Custom(SYNC_STATUS_MESSAGE, data) => {
272 Ok(Some(Message::Custom(SYNC_STATUS_MESSAGE, data)))
274 }
275 Message::Custom(tag, data) => {
276 let mut awareness = awareness.write().unwrap();
277 protocol.missing_handle(&mut awareness, tag, data)
278 }
279 }
280 }
281}
282
283impl<T: Protocol> Drop for DocConnection<T> {
284 fn drop(&mut self) {
285 self.closed.set(()).unwrap();
286
287 if let Some(client_id) = self.client_id.get() {
289 let mut awareness = self.awareness.write().unwrap();
290 awareness.remove_state(*client_id);
291 }
292 }
293}