hocuspocus_axum_server/
lib.rs

1use std::sync::Arc;
2use std::time::{Duration, SystemTime, UNIX_EPOCH};
3
4use anyhow::Result;
5use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
6use axum::extract::State;
7use axum::response::IntoResponse;
8
9use dashmap::DashMap;
10use futures_util::StreamExt;
11use hocuspocus_extension_database::types::{FetchContext, StoreContext};
12use hocuspocus_extension_database::DatabaseExtension;
13use tokio::sync::mpsc as tokio_mpsc;
14use tokio::time::{sleep_until, Instant};
15use yrs::encoding::read::{Cursor as YCursor, Read as YRead};
16use yrs::encoding::write::Write as YWrite;
17use yrs::sync::{Awareness, DefaultProtocol, Message as YMsg, Protocol};
18use yrs::updates::decoder::Decode;
19use yrs::updates::encoder::Encode;
20use yrs::{Doc, ReadTxn, StateVector, Transact, Update};
21
22#[cfg(feature = "redis")]
23use hocuspocus_extension_redis::RedisBroadcaster;
24
25pub struct DocRegistry {
26    doc_counts: DashMap<String, usize>,
27    doc_latest: DashMap<String, Vec<u8>>, // last full state per doc
28}
29
30impl DocRegistry {
31    pub fn new() -> Self {
32        Self {
33            doc_counts: DashMap::new(),
34            doc_latest: DashMap::new(),
35        }
36    }
37
38    pub fn increment(&self, name: &str) -> usize {
39        let entry = self
40            .doc_counts
41            .entry(name.to_string())
42            .and_modify(|c| *c += 1)
43            .or_insert(1);
44        *entry
45    }
46
47    pub fn decrement(&self, name: &str) -> Option<usize> {
48        if let Some(mut entry) = self.doc_counts.get_mut(name) {
49            if *entry > 0 {
50                *entry -= 1;
51            }
52            let remaining = *entry;
53            drop(entry);
54            Some(remaining)
55        } else {
56            None
57        }
58    }
59
60    pub fn set_latest(&self, name: &str, bytes: Vec<u8>) {
61        self.doc_latest.insert(name.to_string(), bytes);
62    }
63
64    pub fn get_latest_cloned(&self, name: &str) -> Option<Vec<u8>> {
65        self.doc_latest.get(name).map(|v| v.clone())
66    }
67
68    pub fn remove(&self, name: &str) {
69        self.doc_counts.remove(name);
70        self.doc_latest.remove(name);
71    }
72}
73
74pub struct AppState<E: DatabaseExtension> {
75    pub db: Arc<E>,
76    pub debounce_ms: u64,
77    pub max_debounce_ms: u64,
78    pub doc_registry: DocRegistry,
79    #[cfg(feature = "redis")]
80    pub redis: Option<Arc<RedisBroadcaster>>, // optional broadcaster
81    // Optional authentication provider; when None, auth is disabled
82    pub auth: Option<Arc<dyn AuthProvider + Send + Sync>>,
83}
84
85pub async fn ws_handler<E: DatabaseExtension + 'static>(
86    ws: WebSocketUpgrade,
87    State(state): State<Arc<AppState<E>>>,
88) -> impl IntoResponse {
89    tracing::debug!("upgrade request");
90    ws.on_upgrade(move |socket| on_ws::<E>(socket, state))
91}
92
// Outer message types: the varuint that follows the document name in every frame.

/// y-sync payload (SyncStep1, SyncStep2 or Update sub-message).
const MSG_SYNC: u32 = 0;
/// Awareness payload.
const MSG_AWARENESS: u32 = 1;
/// Authentication handshake (token request, permission denied, authenticated).
const MSG_AUTH: u32 = 2;
/// Request for the current awareness state.
const MSG_QUERY_AWARENESS: u32 = 3;
/// Acknowledgement sent back after an incoming update frame was handled.
const MSG_SYNC_STATUS: u32 = 8;
98
/// Commands sent from the connection task to the document worker thread.
enum WorkerCmd {
    /// Apply a stored full-state blob (a v1 update) to the worker's document.
    ApplyState(Vec<u8>),
    /// Process one complete inbound frame (document name + type + body).
    InboundWs(Vec<u8>),
    /// Switch the worker in or out of read-only mode.
    SetReadonly(bool),
    /// Leave the command loop and end the thread.
    Stop,
}
105
/// Events sent from the document worker thread back to the connection task.
enum WorkerEvent {
    /// A complete frame that should be written to the WebSocket.
    OutgoingWs(Vec<u8>),
    /// The full document state, to be scheduled for storage.
    StoreState(Vec<u8>),
}
110
111fn worker_thread(
112    cmd_rx: std::sync::mpsc::Receiver<WorkerCmd>,
113    ev_tx: tokio_mpsc::Sender<WorkerEvent>,
114) {
115    // Rationale:
116    // - Use Hocuspocus framing (varstring document name + varuint message type) instead of
117    //   raw yrs protocol. This enables multiplexing multiple documents over a single socket and
118    //   supports Hocuspocus-specific message kinds (e.g. SyncReply, SyncStatus, Stateless) that
119    //   yrs::sync does not define.
120    // - Libraries like yrs-axum speak pure yrs (no outer envelope). That approach is simpler,
121    //   but incompatible with Hocuspocus Provider semantics and our routing/debounce/storage flow.
122    // - We therefore parse the outer envelope here and handle inner y-sync/awareness with yrs
123    //   types. Auth is ignored for the MVP, and we persist full state blobs only (no increments),
124    //   while the server side debounces store operations.
125    // Invariants:
126    // - All inbound/outbound frames are prefixed with the document name.
127    // - Storage events carry the full document state; the database extension remains stateless.
128    let doc = Doc::new();
129    let mut awareness = Awareness::new(doc.clone());
130    let protocol = DefaultProtocol;
131    let mut is_readonly = false;
132
133    // when doc updates, compute state and send StoreState
134    let ev_tx_updates = ev_tx.clone();
135    let doc_for_obs = doc.clone();
136    let doc_for_txn = doc.clone();
137    doc_for_obs
138        .observe_update_v1(move |_txn, _u| {
139            let bytes = {
140                let txn = doc_for_txn.transact();
141                let sv = StateVector::default();
142                txn.encode_state_as_update_v1(&sv)
143            };
144            let _ = ev_tx_updates.try_send(WorkerEvent::StoreState(bytes));
145        })
146        .ok();
147
148    while let Ok(cmd) = cmd_rx.recv() {
149        match cmd {
150            WorkerCmd::ApplyState(bytes) => {
151                if let Ok(update) = Update::decode_v1(&bytes) {
152                    if let Err(e) = doc.transact_mut().apply_update(update) {
153                        tracing::warn!(?e, "failed to apply initial state");
154                    }
155                }
156            }
157            WorkerCmd::InboundWs(data) => {
158                if data.is_empty() {
159                    continue;
160                }
161                // read incoming document name (varstring) and outer message type using yrs encoding
162                let mut cur = YCursor::new(&data);
163                let frame_doc_name = cur.read_string().unwrap_or("").to_string();
164                let t: u32 = cur.read_var().unwrap_or(0);
165                tracing::debug!(doc_name = %frame_doc_name, msg_type = t, len = data.len(), "inbound frame");
166                let body = &data[cur.next..];
167                match t {
168                    MSG_SYNC => {
169                        // Manual parse of y-sync submessage
170                        if body.is_empty() {
171                            tracing::debug!("empty y-sync body");
172                        } else {
173                            let mut bcur = YCursor::new(body);
174                            let subtag: u32 = bcur.read_var().unwrap_or(u32::MAX);
175                            match subtag {
176                                0 => {
177                                    // SyncStep1(sv)
178                                    if let Ok(sv_bytes) = bcur.read_buf() {
179                                        if let Ok(sv) = StateVector::decode_v1(sv_bytes) {
180                                            // reply with SyncStep2 from doc state diff
181                                            let update = {
182                                                let txn = doc.transact();
183                                                txn.encode_state_as_update_v1(&sv)
184                                            };
185                                            let mut out = Vec::new();
186                                            out.write_string(&frame_doc_name);
187                                            out.write_var(MSG_SYNC);
188                                            // subtag 1: SyncStep2(update)
189                                            out.write_var(1u32);
190                                            // writeVarUint8Array(update)
191                                            out.write_buf(&update);
192                                            tracing::debug!(len = out.len(), out_doc = %frame_doc_name, manual = true, "outbound sync step2 reply");
193                                            let _ =
194                                                ev_tx.blocking_send(WorkerEvent::OutgoingWs(out));
195                                        }
196                                    }
197                                }
198                                1 | 2 => {
199                                    // SyncStep2 or Update
200                                    if let Ok(upd_bytes) = bcur.read_buf() {
201                                        if is_readonly {
202                                            // In read-only mode, ignore incoming updates but still acknowledge to avoid client loops.
203                                            let mut ack = Vec::new();
204                                            ack.write_string(&frame_doc_name);
205                                            ack.write_var(MSG_SYNC_STATUS);
206                                            ack.write_var(1u32);
207                                            tracing::debug!(out_doc = %frame_doc_name, manual = true, "ack in readonly mode (no apply)");
208                                            let _ =
209                                                ev_tx.blocking_send(WorkerEvent::OutgoingWs(ack));
210                                        } else {
211                                            // apply incoming update (v1 only)
212                                            if let Ok(update) = Update::decode_v1(upd_bytes) {
213                                                if let Err(e) =
214                                                    doc.transact_mut().apply_update(update)
215                                                {
216                                                    tracing::warn!(?e, "failed to apply update");
217                                                }
218                                                // send SyncStatus applied=1
219                                                let mut ack = Vec::new();
220                                                ack.write_string(&frame_doc_name);
221                                                ack.write_var(MSG_SYNC_STATUS);
222                                                ack.write_var(1u32);
223                                                tracing::debug!(out_doc = %frame_doc_name, manual = true, "outbound sync status ack");
224                                                let _ = ev_tx
225                                                    .blocking_send(WorkerEvent::OutgoingWs(ack));
226
227                                                // also emit StoreState immediately to ensure persistence
228                                                let bytes = {
229                                                    let txn = doc.transact();
230                                                    let sv = StateVector::default();
231                                                    txn.encode_state_as_update_v1(&sv)
232                                                };
233                                                let _ = ev_tx
234                                                    .blocking_send(WorkerEvent::StoreState(bytes));
235                                            } else {
236                                                tracing::debug!(
237                                                    "failed to decode update bytes (v1 only)"
238                                                );
239                                            }
240                                        }
241                                    }
242                                }
243                                _ => {
244                                    tracing::debug!(subtag, "unknown y-sync submessage");
245                                }
246                            }
247                        }
248                    }
249                    2 => {
250                        // Auth messages are handled in the outer task (on_ws). Worker shouldn't receive them.
251                        tracing::debug!("auth message received by worker; ignoring");
252                    }
253                    MSG_AWARENESS => {
254                        if let Ok(inner) = YMsg::decode_v1(body) {
255                            if let YMsg::Awareness(update) = inner {
256                                let reply = protocol
257                                    .handle_awareness_update(&mut awareness, update)
258                                    .ok()
259                                    .flatten();
260                                if let Some(msg) = reply {
261                                    let mut out = Vec::new();
262                                    out.write_string(&frame_doc_name);
263                                    out.write_var(MSG_AWARENESS);
264                                    out.extend(msg.encode_v1());
265                                    tracing::debug!(len = out.len(), out_doc = %frame_doc_name, "outbound awareness echo");
266                                    let _ = ev_tx.blocking_send(WorkerEvent::OutgoingWs(out));
267                                }
268                            }
269                        } else {
270                            tracing::debug!("failed to decode awareness message");
271                        }
272                    }
273                    MSG_QUERY_AWARENESS => {
274                        if let Ok(Some(reply)) = protocol.handle_awareness_query(&awareness) {
275                            let mut out = Vec::new();
276                            out.write_string(&frame_doc_name);
277                            out.write_var(MSG_AWARENESS);
278                            out.extend(reply.encode_v1());
279                            tracing::debug!(len = out.len(), out_doc = %frame_doc_name, "outbound awareness reply");
280                            let _ = ev_tx.blocking_send(WorkerEvent::OutgoingWs(out));
281                        }
282                    }
283                    _ => {}
284                }
285            }
286            WorkerCmd::SetReadonly(ro) => {
287                is_readonly = ro;
288            }
289            WorkerCmd::Stop => break,
290        }
291    }
292}
293
294async fn on_ws<E: DatabaseExtension + 'static>(mut socket: WebSocket, state: Arc<AppState<E>>) {
295    tracing::debug!("connection established");
296
297    // channels for worker communication
298    let (cmd_tx, cmd_rx) = std::sync::mpsc::channel::<WorkerCmd>();
299    let (ev_tx, mut ev_rx) = tokio_mpsc::channel::<WorkerEvent>(64);
300
301    // spawn worker thread
302    std::thread::spawn(move || worker_thread(cmd_rx, ev_tx));
303
304    // debounce state
305    let mut pending = false;
306    let mut first_pending_at: Option<Instant> = None;
307    let mut next_deadline: Option<Instant> = None;
308    let mut latest_state_bytes: Option<Vec<u8>> = None;
309    let mut selected_doc_name: Option<String> = None;
310    let mut is_authenticated: bool = state.auth.is_none();
311    let mut loaded_state: bool = false;
312    #[cfg(feature = "redis")]
313    let mut redis_sub_handle: Option<tokio::task::JoinHandle<()>> = None;
314
315    loop {
316        let sleep_fut = if let Some(deadline) = next_deadline {
317            Some(Box::pin(sleep_until(deadline)))
318        } else {
319            None
320        };
321
322        tokio::select! {
323            // WebSocket input
324            maybe_msg = socket.next() => {
325                match maybe_msg {
326                    Some(Ok(Message::Binary(b))) => {
327                        // detect document name from first frame and load state before forwarding
328                        let mut handled_by_auth = false;
329                        if selected_doc_name.is_none() {
330                            let mut cur = YCursor::new(b.as_ref());
331                            if let Ok(name_str) = cur.read_string() {
332                                let name = name_str.to_string();
333                                tracing::debug!(document_name = %name, "first frame received");
334                                // increment connection count for this doc
335                                let _ = state.doc_registry.increment(&name);
336                                selected_doc_name = Some(name.clone());
337
338                                // If auth is enabled, require auth before loading state or forwarding messages
339                                if state.auth.is_some() && !is_authenticated {
340                                    // Peek message type
341                                    let mtype: u32 = cur.read_var().unwrap_or(u32::MAX);
342                                    if mtype == MSG_AUTH {
343                                        // subtag
344                                        let sub: u32 = cur.read_var().unwrap_or(u32::MAX);
345                                        if sub == 0 {
346                                            // token
347                                            if let Ok(token_str) = cur.read_string() {
348                                                if let Some(provider) = state.auth.as_ref() {
349                                                    match provider.on_authenticate(&name, &token_str) {
350                                                        Ok(scope) => {
351                                                            // send authenticated reply
352                                                            let readonly = matches!(scope, AuthScope::ReadOnly);
353                                                            send_auth_authenticated(&mut socket, &name, readonly).await;
354                                                            is_authenticated = true;
355                                                            let _ = cmd_tx.send(WorkerCmd::SetReadonly(readonly));
356                                                            handled_by_auth = true; // don't forward auth frame
357                                                        }
358                                                        Err(e) => {
359                                                            send_auth_permission_denied(&mut socket, &name, Some(&format!("{}", e))).await;
360                                                            // Keep the connection open; ask for token again.
361                                                            send_auth_token_request(&mut socket, &name).await;
362                                                            handled_by_auth = true;
363                                                        }
364                                                    }
365                                                }
366                                            }
367                                        } else {
368                                            // unexpected auth subtype
369                                            send_auth_permission_denied(&mut socket, &name, Some("invalid-auth-message")).await;
370                                            // Keep the connection open; ask for token again.
371                                            send_auth_token_request(&mut socket, &name).await;
372                                            handled_by_auth = true;
373                                        }
374                                    } else {
375                                        // Not an auth message: request token
376                                        send_auth_token_request(&mut socket, &name).await;
377                                        handled_by_auth = true; // don't forward non-auth
378                                    }
379                                }
380
381                                // subscribe to redis after doc is known
382                                #[cfg(feature = "redis")]
383                                if let Some(bc) = state.redis.as_ref() {
384                                    let (handle, mut rx) = bc.subscribe(name.clone()).await.expect("redis subscribe");
385                                    redis_sub_handle = Some(handle);
386                                    let tx_clone = cmd_tx.clone();
387                                    tokio::spawn(async move {
388                                        while let Some((_is_sync, body)) = rx.recv().await {
389                                            let _ = tx_clone.send(WorkerCmd::InboundWs(body.to_vec()));
390                                        }
391                                    });
392                                }
393                            } else {
394                                tracing::debug!("failed to read document name from first frame");
395                            }
396                        } else if state.auth.is_some() && !is_authenticated {
397                            // Post-selection but pre-auth: only process auth frames
398                            let mut cur = YCursor::new(b.as_ref());
399                            let _ = cur.read_string(); // skip name
400                            let mtype: u32 = cur.read_var().unwrap_or(u32::MAX);
401                            if mtype == MSG_AUTH {
402                                let name = selected_doc_name.as_ref().unwrap().clone();
403                                let sub: u32 = cur.read_var().unwrap_or(u32::MAX);
404                                if sub == 0 {
405                                    if let Ok(token_str) = cur.read_string() {
406                                        if let Some(provider) = state.auth.as_ref() {
407                                            match provider.on_authenticate(&name, &token_str) {
408                                                Ok(scope) => {
409                                                    let readonly = matches!(scope, AuthScope::ReadOnly);
410                                                    send_auth_authenticated(&mut socket, &name, readonly).await;
411                                                    is_authenticated = true;
412                                                    let _ = cmd_tx.send(WorkerCmd::SetReadonly(readonly));
413                                                    handled_by_auth = true;
414                                                }
415                                                Err(e) => {
416                                                    send_auth_permission_denied(&mut socket, &name, Some(&format!("{}", e))).await;
417                                                    // Keep the connection open; ask for token again, same as Hocuspocus.
418                                                    send_auth_token_request(&mut socket, &name).await;
419                                                    handled_by_auth = true;
420                                                }
421                                            }
422                                        }
423                                    }
424                                } else {
425                                    send_auth_permission_denied(&mut socket, &name, Some("invalid-auth-message")).await;
426                                    // Keep the connection open; ask for token again, same as Hocuspocus.
427                                    send_auth_token_request(&mut socket, &name).await;
428                                    handled_by_auth = true;
429                                }
430                            } else {
431                                // request token again and ignore
432                                let name = selected_doc_name.as_ref().unwrap().clone();
433                                send_auth_token_request(&mut socket, &name).await;
434                                handled_by_auth = true;
435                            }
436                        }
437
438                        // If we just authenticated or auth is disabled, ensure state is loaded once
439                        if !loaded_state {
440                            if let Some(name) = selected_doc_name.as_ref() {
441                                if is_authenticated {
442                                    tracing::debug!(document_name = %name, "loading state after auth");
443                                    if let Err(e) = load_and_send_state(&*state.db, name, &cmd_tx).await {
444                                        tracing::warn!(error = %e, document_name = %name, "failed to load/apply state");
445                                    }
446                                    loaded_state = true;
447                                }
448                            }
449                        }
450
451                        if !handled_by_auth {
452                            let _ = cmd_tx.send(WorkerCmd::InboundWs(b.to_vec()));
453                        }
454                    }
455                    Some(Ok(Message::Ping(p))) => {
456                        tracing::debug!(size = %p.len(), "ping received");
457                        let _ = socket.send(Message::Pong(p)).await;
458                    }
459                    Some(Ok(Message::Close(frame))) => {
460                        tracing::debug!(pending = pending, ?frame, "closing connection");
461                        // decrement connection count and if last, force save latest known state
462                        if let Some(name) = selected_doc_name.as_ref() {
463                            if let Some(remaining) = state.doc_registry.decrement(name) {
464                                if remaining == 0 {
465                                    let to_store = latest_state_bytes
466                                        .as_ref()
467                                        .cloned()
468                                        .or_else(|| state.doc_registry.get_latest_cloned(name));
469                                    if let Some(bytes) = to_store {
470                                        tracing::debug!(document_name = %name, "last client left; force storing state");
471                                        let _ = store_bytes(&*state.db, name, &bytes).await;
472                                    }
473                                    state.doc_registry.remove(name);
474                                }
475                            }
476                        }
477                        if pending {
478                            tracing::debug!(pending = pending, "storing state on close");
479                            if let (Some(bytes), Some(name)) = (latest_state_bytes.as_ref(), selected_doc_name.as_ref()) {
480                                let _ = store_bytes(&*state.db, name, bytes).await;
481                            }
482                        }
483                        let _ = cmd_tx.send(WorkerCmd::Stop);
484                        #[cfg(feature = "redis")]
485                        if let Some(h) = redis_sub_handle.take() { h.abort(); }
486                        break;
487                    }
488                    None => {
489                        tracing::debug!(pending = pending, "socket closed by peer");
490                        // decrement connection count and if last, force save latest known state
491                        if let Some(name) = selected_doc_name.as_ref() {
492                            if let Some(remaining) = state.doc_registry.decrement(name) {
493                                if remaining == 0 {
494                                    let to_store = latest_state_bytes
495                                        .as_ref()
496                                        .cloned()
497                                        .or_else(|| state.doc_registry.get_latest_cloned(name));
498                                    if let Some(bytes) = to_store {
499                                        tracing::debug!(document_name = %name, "last client left; force storing state");
500                                        let _ = store_bytes(&*state.db, name, &bytes).await;
501                                    }
502                                    state.doc_registry.remove(name);
503                                }
504                            }
505                        }
506                        if pending {
507                            tracing::debug!(pending = pending, "storing state on close");
508                            if let (Some(bytes), Some(name)) = (latest_state_bytes.as_ref(), selected_doc_name.as_ref()) {
509                                let _ = store_bytes(&*state.db, name, bytes).await;
510                            }
511                        }
512                        let _ = cmd_tx.send(WorkerCmd::Stop);
513                        #[cfg(feature = "redis")]
514                        if let Some(h) = redis_sub_handle.take() { h.abort(); }
515                        break;
516                    }
517                    _ => {}
518                }
519            }
520            // Worker events
521            Some(ev) = ev_rx.recv() => {
522                match ev {
523                    WorkerEvent::OutgoingWs(bytes) => {
524                        tracing::debug!(len = bytes.len(), "sending binary to client");
525                        let _ = socket
526                            .send(Message::Binary(axum::body::Bytes::from(bytes.clone())))
527                            .await;
528                        #[cfg(feature = "redis")]
529                        if let (Some(name), Some(bc)) = (selected_doc_name.as_ref(), state.redis.as_ref()) {
530                            // determine message type after varstring
531                            let mut cur = YCursor::new(&bytes);
532                            let _ = cur.read_string();
533                            let mtype: u32 = cur.read_var().unwrap_or(u32::MAX);
534                            let payload = &bytes[..];
535                            // publish full framed message so other instances can forward as-is
536                            if mtype == MSG_SYNC {
537                                let _ = bc.publish_sync(name, payload).await;
538                            } else if mtype == MSG_AWARENESS {
539                                let _ = bc.publish_awareness(name, payload).await;
540                            }
541                        }
542                    }
543                    WorkerEvent::StoreState(bytes) => {
544                        tracing::debug!(state_len = bytes.len(), "doc update observed; scheduling store");
545                        if let Some(name) = selected_doc_name.as_ref() {
546                            state.doc_registry.set_latest(name, bytes.clone());
547                        }
548                        latest_state_bytes = Some(bytes);
549                        let now = Instant::now();
550                        pending = true;
551                        if first_pending_at.is_none() { first_pending_at = Some(now); }
552                        let debounced = now + Duration::from_millis(state.debounce_ms);
553                        let cap = first_pending_at.unwrap() + Duration::from_millis(state.max_debounce_ms);
554                        let target = if debounced > cap { cap } else { debounced };
555                        next_deadline = Some(target);
556                    }
557                }
558            }
559            // Debounce timer
560            _ = async { if let Some(fut) = sleep_fut { fut.await } } , if next_deadline.is_some() => {
561                if pending {
562                    if let (Some(bytes), Some(name)) = (latest_state_bytes.as_ref(), selected_doc_name.as_ref()) {
563                        tracing::debug!(document_name = %name, "debounce elapsed; storing state");
564                        if let Err(e) = store_bytes(&*state.db, name, bytes).await {
565                            tracing::warn!(error = %e, document_name = %name, "failed to store state");
566                        }
567                    }
568                    pending = false;
569                }
570                first_pending_at = None;
571                next_deadline = None;
572            }
573        }
574    }
575}
576
577async fn load_and_send_state<E: DatabaseExtension>(
578    db: &E,
579    name: &str,
580    cmd_tx: &std::sync::mpsc::Sender<WorkerCmd>,
581) -> Result<()> {
582    tracing::debug!(document_name = %name, "fetching state");
583    if let Some(bytes) = db
584        .fetch(FetchContext {
585            document_name: name.to_string(),
586        })
587        .await?
588    {
589        tracing::debug!(document_name = %name, bytes = bytes.len(), "applying fetched state (worker)");
590        let _ = cmd_tx.send(WorkerCmd::ApplyState(bytes));
591    } else {
592        tracing::debug!(document_name = %name, "no stored state; starting empty doc");
593    }
594    Ok(())
595}
596
597async fn store_bytes<E: DatabaseExtension>(db: &E, name: &str, bytes: &[u8]) -> Result<()> {
598    let now = SystemTime::now()
599        .duration_since(UNIX_EPOCH)
600        .unwrap()
601        .as_millis() as i64;
602    db.store(StoreContext {
603        document_name: name.to_string(),
604        state: bytes,
605        updated_at_millis: now,
606    })
607    .await?;
608    Ok(())
609}
610
611// ===== Auth support (Hocuspocus-compatible) =====
612
/// Access level granted to a connection by an [`AuthProvider`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AuthScope {
    /// Updates sent by the client are acknowledged but not applied.
    ReadOnly,
    /// Updates sent by the client are applied to the document.
    ReadWrite,
}
618
619pub trait AuthProvider {
620    fn on_authenticate(&self, document_name: &str, token: &str) -> anyhow::Result<AuthScope>;
621}
622
623pub struct StaticTokenAuth {
624    pub token: String,
625    pub scope: AuthScope,
626}
627
628impl AuthProvider for StaticTokenAuth {
629    fn on_authenticate(&self, _document_name: &str, token: &str) -> anyhow::Result<AuthScope> {
630        if token == self.token {
631            Ok(self.scope)
632        } else {
633            anyhow::bail!("permission-denied");
634        }
635    }
636}
637
638async fn send_auth_token_request(socket: &mut WebSocket, name: &str) {
639    let mut out = Vec::new();
640    out.write_string(name);
641    out.write_var(MSG_AUTH);
642    // AuthMessageType.Token = 0
643    out.write_var(0u32);
644    let _ = socket
645        .send(Message::Binary(axum::body::Bytes::from(out)))
646        .await;
647}
648
649async fn send_auth_authenticated(socket: &mut WebSocket, name: &str, readonly: bool) {
650    let mut out = Vec::new();
651    out.write_string(name);
652    out.write_var(MSG_AUTH);
653    // AuthMessageType.Authenticated = 2
654    out.write_var(2u32);
655    out.write_string(if readonly { "readonly" } else { "read-write" });
656    let _ = socket
657        .send(Message::Binary(axum::body::Bytes::from(out)))
658        .await;
659}
660
661async fn send_auth_permission_denied(socket: &mut WebSocket, name: &str, reason: Option<&str>) {
662    let mut out = Vec::new();
663    out.write_string(name);
664    out.write_var(MSG_AUTH);
665    // AuthMessageType.PermissionDenied = 1
666    out.write_var(1u32);
667    out.write_string(reason.unwrap_or("permission-denied"));
668    let _ = socket
669        .send(Message::Binary(axum::body::Bytes::from(out)))
670        .await;
671}