hocuspocus_rs_ws/
doc_connection.rs

1// Portions of this module are adapted from the Hocuspocus JavaScript server
2// (https://github.com/ueberdosis/hocuspocus) and y-sweet
3// (https://github.com/y-sweet/y-sweet), both distributed under the MIT license.
4// Adapted code retains the original license terms.
5
6use crate::api_types::Authorization;
7use crate::client_connection::DocServer;
8use crate::sync::Message;
9use crate::sync::{
10    self, DefaultProtocol, MSG_SYNC, MSG_SYNC_UPDATE, Protocol, SyncMessage, awareness::Awareness,
11};
12use anyhow::Result;
13use std::sync::{Arc, OnceLock, RwLock};
14use tokio::sync::mpsc;
15use tracing::debug;
16use yrs::updates::decoder::DecoderV1;
17use yrs::{
18    Subscription, Update,
19    block::ClientID,
20    encoding::write::Write,
21    updates::{
22        decoder::Decode,
23        encoder::{Encode, Encoder, EncoderV1},
24    },
25};
26
// TODO: this is an implementation detail and should not be exposed.
/// Fixed document name exported by this module.
///
/// NOTE(review): nothing in this module reads this constant; presumably it is
/// consumed by callers elsewhere in the crate — confirm before removing.
pub const DOC_NAME: &str = "doc";

/// Tag of the custom message that `DocConnection::handle_msg` answers by
/// sending the received payload straight back to the client.
const SYNC_STATUS_MESSAGE: u8 = 102;
30
31pub struct DocConnection<T: Protocol = DefaultProtocol> {
32    doc_name: String,
33    doc_server: Arc<dyn DocServer>,
34    authorization: Arc<RwLock<Authorization>>,
35    awareness: Arc<RwLock<Awareness>>,
36    sender: mpsc::Sender<Vec<u8>>,
37    protocol: T,
38    closed: Arc<OnceLock<()>>,
39
40    /// If the client sends an awareness state, this will be set to its client ID.
41    /// It is used to clear the awareness state when a client disconnects.
42    client_id: OnceLock<ClientID>,
43
44    #[allow(unused)] // acts as RAII guard
45    doc_subscription: Subscription,
46    #[allow(unused)] // acts as RAII guard
47    awareness_subscription: Subscription,
48}
49
50impl DocConnection {
51    pub fn new(
52        doc_name: String,
53        doc_server: Arc<dyn DocServer>,
54        awareness: Arc<RwLock<Awareness>>,
55        callback: mpsc::Sender<Vec<u8>>,
56    ) -> Self {
57        Self::new_inner(doc_name, doc_server, awareness, callback)
58    }
59
60    pub fn new_inner(
61        doc_name: String,
62        doc_server: Arc<dyn DocServer>,
63        awareness: Arc<RwLock<Awareness>>,
64        callback: mpsc::Sender<Vec<u8>>,
65    ) -> Self {
66        let closed = Arc::new(OnceLock::new());
67
68        let (doc_subscription, awareness_subscription) = {
69            let mut awareness = awareness.write().unwrap();
70
71            let doc_subscription = {
72                let doc = awareness.doc();
73                let callback = callback.clone();
74                let doc_name = doc_name.clone();
75                let closed = closed.clone();
76                doc.observe_update_v1(move |_, event| {
77                    if closed.get().is_some() {
78                        return;
79                    }
80                    // https://github.com/y-crdt/y-sync/blob/56958e83acfd1f3c09f5dd67cf23c9c72f000707/src/net/broadcast.rs#L47-L52
81                    let mut encoder = EncoderV1::new();
82                    encoder.write_string(doc_name.as_str());
83                    encoder.write_var(MSG_SYNC);
84                    encoder.write_var(MSG_SYNC_UPDATE);
85                    encoder.write_buf(&event.update);
86                    let msg = encoder.to_vec();
87                    callback.try_send(msg).expect("todo err handling");
88                })
89                .unwrap()
90            };
91
92            let callback = callback.clone();
93            let closed = closed.clone();
94            let doc_name = doc_name.clone();
95            let awareness_subscription = awareness.on_update(move |awareness, e| {
96                if closed.get().is_some() {
97                    return;
98                }
99
100                debug!("awareneess update observed, sending to client");
101
102                // https://github.com/y-crdt/y-sync/blob/56958e83acfd1f3c09f5dd67cf23c9c72f000707/src/net/broadcast.rs#L59
103                let added = e.added();
104                let updated = e.updated();
105                let removed = e.removed();
106                let mut changed = Vec::with_capacity(added.len() + updated.len() + removed.len());
107                changed.extend_from_slice(added);
108                changed.extend_from_slice(updated);
109                changed.extend_from_slice(removed);
110
111                if let Ok(u) = awareness.update_with_clients(changed) {
112                    let mut encoder = EncoderV1::new();
113                    encoder.write_string(doc_name.as_str());
114                    Message::Awareness(u).encode(&mut encoder);
115                    let msg = encoder.to_vec();
116                    callback.try_send(msg).expect("todo err handling");
117                }
118            });
119
120            (doc_subscription, awareness_subscription)
121        };
122
123        let protocol = DefaultProtocol;
124        // Initial handshake is based on this:
125        // https://github.com/y-crdt/y-sync/blob/56958e83acfd1f3c09f5dd67cf23c9c72f000707/src/sync.rs#L45-L54
126
127        Self {
128            doc_name,
129            doc_server,
130            awareness,
131            doc_subscription,
132            awareness_subscription,
133            authorization: Arc::new(RwLock::new(Authorization::None)),
134            protocol,
135            sender: callback,
136            client_id: OnceLock::new(),
137            closed,
138        }
139    }
140
141    pub async fn send(&self, mut update: DecoderV1<'_>) -> Result<(), anyhow::Error> {
142        let msg = Message::decode(&mut update)?;
143        self.send_message(msg).await?;
144
145        Ok(())
146    }
147
148    pub async fn send_message(&self, msg: Message) -> Result<(), anyhow::Error> {
149        let mut encoder = EncoderV1::new();
150        encoder.write_string(self.doc_name.as_str());
151        msg.encode(&mut encoder);
152        self.send_raw(encoder.to_vec()).await
153    }
154
155    pub async fn send_raw(&self, msg: Vec<u8>) -> Result<(), anyhow::Error> {
156        self.sender.send(msg).await?;
157
158        Ok(())
159    }
160
161    // Adapted from:
162    // https://github.com/y-crdt/y-sync/blob/56958e83acfd1f3c09f5dd67cf23c9c72f000707/src/net/conn.rs#L184C1-L222C1
163    #[tracing::instrument(skip(self, msg), fields(doc_name = self.doc_name))]
164    pub async fn handle_msg(&self, msg: Message) -> Result<Option<Message>, sync::Error> {
165        debug!("Handling message for document: {:?}", msg);
166        let protocol = &self.protocol;
167        let awareness = &self.awareness;
168
169        let can_write = matches!(*self.authorization.read().unwrap(), Authorization::Full);
170
171        match msg {
172            Message::Sync(msg) => match msg {
173                SyncMessage::SyncStep1(sv) => {
174                    let awareness = awareness.read().unwrap();
175                    protocol.handle_sync_step1(&awareness, sv)
176                }
177                SyncMessage::SyncStep2(update) => {
178                    if can_write {
179                        let mut awareness = awareness.write().unwrap();
180                        protocol.handle_sync_step2(&mut awareness, Update::decode_v1(&update)?)
181                    } else {
182                        Err(sync::Error::PermissionDenied {
183                            reason: "Token does not have write access".to_string(),
184                        })
185                    }
186                }
187                SyncMessage::Update(update) => {
188                    if can_write {
189                        let mut awareness = awareness.write().unwrap();
190                        protocol.handle_update(&mut awareness, Update::decode_v1(&update)?)
191                    } else {
192                        Err(sync::Error::PermissionDenied {
193                            reason: "Token does not have write access".to_string(),
194                        })
195                    }
196                }
197            },
198            Message::Auth(token, _) => {
199                let token = token.unwrap_or_default();
200                let config = self
201                    .doc_server
202                    .authenticate(&self.doc_name, token.as_str())
203                    .await;
204
205                let mut auth_failed = false;
206                if let Ok(config) = config {
207                    if config.is_authenticated {
208                        if config.read_only {
209                            *self.authorization.write().unwrap() = Authorization::ReadOnly;
210                        } else {
211                            *self.authorization.write().unwrap() = Authorization::Full;
212                        }
213                    } else {
214                        *self.authorization.write().unwrap() = Authorization::None;
215                        auth_failed = true;
216                    }
217                }
218
219                if auth_failed {
220                    tracing::warn!(
221                        "Authentication failed for document: {}",
222                        self.doc_name
223                    );
224
225                    let handle_auth_message =
226                        protocol.handle_auth_fail(&self.awareness.read().unwrap());
227                    self.send_message(handle_auth_message).await?;
228                    return Err(sync::Error::PermissionDenied {
229                        reason: "Authentication failed".to_string(),
230                    });
231                }
232
233                let handle_auth_message =
234                    protocol.handle_auth_success(&self.awareness.read().unwrap(), true);
235                self.send_message(handle_auth_message).await?;
236
237                if !self.awareness.read().unwrap().clients().is_empty() {
238                    let awareness = protocol.awareness(&self.awareness.read().unwrap())?;
239                    self.send_message(awareness).await?;
240                } else {
241                    debug!("No existing awareness states to send to client");
242                }
243
244                let sync1_message = protocol.sync_step1(&self.awareness.read().unwrap())?;
245                self.send_message(sync1_message).await?;
246
247                Ok(None)
248            }
249            Message::AwarenessQuery => {
250                let awareness = awareness.read().unwrap();
251                protocol.handle_awareness_query(&awareness)
252            }
253            Message::Awareness(update) => {
254                if update.clients.len() == 1 {
255                    let client_id = update.clients.keys().next().unwrap();
256                    self.client_id.get_or_init(|| *client_id);
257                } else {
258                    tracing::warn!(
259                        "Received awareness update with more than one client {:?}",
260                        update.clients
261                    );
262                }
263                let mut awareness = awareness.write().unwrap();
264                protocol.handle_awareness_update(&mut awareness, update)
265            }
266            Message::SyncStatus(synced) => {
267                debug!("Client sync status changed: synced={}", synced);
268
269                Ok(None)
270            }
271            Message::Custom(SYNC_STATUS_MESSAGE, data) => {
272                // Respond to the client with the same payload it sent.
273                Ok(Some(Message::Custom(SYNC_STATUS_MESSAGE, data)))
274            }
275            Message::Custom(tag, data) => {
276                let mut awareness = awareness.write().unwrap();
277                protocol.missing_handle(&mut awareness, tag, data)
278            }
279        }
280    }
281}
282
283impl<T: Protocol> Drop for DocConnection<T> {
284    fn drop(&mut self) {
285        self.closed.set(()).unwrap();
286
287        // If this client had an awareness state, remove it.
288        if let Some(client_id) = self.client_id.get() {
289            let mut awareness = self.awareness.write().unwrap();
290            awareness.remove_state(*client_id);
291        }
292    }
293}