hocuspocus_rs_ws/
client_connection.rs1use crate::doc_connection::DocConnection;
7use crate::sync::awareness::Awareness;
8use crate::sync::Message;
9use anyhow::{Result, anyhow};
10use async_trait::async_trait;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex, RwLock};
13use std::time::Duration;
14use tokio::sync::mpsc;
15use tracing::error;
16use uuid::Uuid;
17use yrs::encoding::read::{Cursor, Read};
18use yrs::updates::decoder::{Decode as _, Decoder, DecoderV1};
19
/// Settings for a single document connection, as returned by
/// [`DocServer::authenticate`].
#[derive(Clone, Debug)]
pub struct DocConnectionConfig {
    /// Whether the client should be treated as read-only for this document.
    pub read_only: bool,
    /// Whether the client is considered authenticated for this document.
    pub is_authenticated: bool,
}
25
26impl Default for DocConnectionConfig {
27 fn default() -> Self {
28 Self {
29 read_only: false,
30 is_authenticated: true,
31 }
32 }
33}
34
/// A raw byte payload tagged with the name of the document it belongs to.
#[derive(Debug)]
pub struct MessageQueueEntry {
    /// Raw message bytes.
    pub data: Vec<u8>,
    /// Name of the document this payload is associated with.
    pub document_name: String,
}
40
41#[async_trait]
42pub trait DocServer: Send + Sync {
43 async fn fetch(&self, doc_id: &str) -> Result<Arc<RwLock<Awareness>>>;
44 async fn authenticate(&self, doc_id: &str, token: &str) -> Result<DocConnectionConfig>;
45}
46
47pub struct ClientConnection {
48 socket_id: String,
50
51 doc_server: Arc<dyn DocServer>,
52
53 document_connections: Arc<Mutex<HashMap<String, Arc<DocConnection>>>>,
55
56 send_callback: mpsc::Sender<Vec<u8>>,
58 timeout: Duration,
60
61 context: Arc<Mutex<HashMap<String, String>>>,
63
64 closed: Arc<Mutex<bool>>,
66}
67
68impl ClientConnection {
69 pub fn new(
70 doc_server: Arc<dyn DocServer>,
71 send_callback: mpsc::Sender<Vec<u8>>,
72 timeout: Duration,
73 default_context: HashMap<String, String>,
74 ) -> Self {
75 Self {
76 doc_server,
77 socket_id: Uuid::new_v4().to_string(),
78 document_connections: Arc::new(Mutex::new(HashMap::new())),
79 send_callback,
80 timeout,
81 context: Arc::new(Mutex::new(default_context)),
82 closed: Arc::new(Mutex::new(false)),
83 }
84 }
85
86 pub fn socket_id(&self) -> &str {
87 &self.socket_id
88 }
89
90 pub fn is_closed(&self) -> bool {
91 *self.closed.lock().unwrap()
92 }
93
94 pub fn close(&self) -> Result<()> {
95 let mut closed = self.closed.lock().unwrap();
96 if *closed {
97 return Ok(());
98 }
99 *closed = true;
100
101 let connections = self.document_connections.lock().unwrap();
103 for (_name, _connection) in connections.iter() {
104 }
106
107 Ok(())
108 }
109
110 pub async fn handle_message(&self, data: &[u8]) -> Result<()> {
114 if self.is_closed() {
115 return Err(anyhow!("Connection is closed"));
116 }
117
118 let mut decoder = DecoderV1::new(Cursor::new(data));
119 let document_name = decoder.read_string()?.to_owned();
120 let msg = Message::decode_v1(decoder.read_to_end()?)?;
121
122 let doc_connection = self.fetch_connection(&document_name).await;
123 match doc_connection {
124 Err(err) => {
125 error!(
126 "Failed to fetch connection for document '{}': {}",
127 document_name, err
128 );
129 return Ok(());
130 }
131 Ok(connection) => {
132 match connection.handle_msg(msg).await {
133 Ok(result_msg) => {
134 if let Some(response) = result_msg {
136 connection.send_message(response).await?;
137 }
138 }
139 Err(err) => return Err(anyhow!("Failed to handle message: {}", err)),
140 }
141 }
142 }
143
144 Ok(())
174 }
175
176 async fn fetch_connection(&self, document_name: &str) -> Result<Arc<DocConnection>> {
177 {
180 let connections = self.document_connections.lock().unwrap();
181 let doc_connection = connections.get(document_name).cloned();
182 if let Some(conn) = doc_connection {
183 return Ok(conn.clone());
184 }
185 }
186
187 let awareness = Arc::new(RwLock::new(self.doc_server.fetch(document_name).await?))
188 .read()
189 .unwrap()
190 .clone();
191
192 let send_callback = self.send_callback.clone();
193
194 let connection = Arc::new(DocConnection::new(
195 document_name.to_string(),
196 self.doc_server.clone(),
197 awareness.clone(),
198 send_callback.clone(),
199 ));
200
201 {
202 let mut connections = self.document_connections.lock().unwrap();
203 connections.insert(document_name.to_string(), connection.clone());
204 }
205
206 Ok(connection)
207 }
208
209 pub fn get_context(&self, key: &str) -> Option<String> {
210 let context = self.context.lock().unwrap();
211 context.get(key).cloned()
212 }
213
214 pub fn set_context(&self, key: String, value: String) {
215 let mut context = self.context.lock().unwrap();
216 context.insert(key, value);
217 }
218
219 pub fn get_document_count(&self) -> usize {
220 let connections = self.document_connections.lock().unwrap();
221 connections.len()
222 }
223
224 pub fn has_document(&self, document_name: &str) -> bool {
225 let connections = self.document_connections.lock().unwrap();
226 connections.contains_key(document_name)
227 }
228}
229
230impl Drop for ClientConnection {
231 fn drop(&mut self) {
232 let _ = self.close();
233 }
234}
235
#[cfg(test)]
mod tests {
    // `Decoder`, `DecoderV1`, `Cursor` and `Read` all come from the parent module's imports.
    use super::*;

    #[test]
    fn test_read_to_end() {
        let data = vec![1, 2, 3, 4, 5];
        let mut decoder = DecoderV1::new(Cursor::new(&data));

        // After consuming the first byte, `read_to_end` must yield exactly the remainder.
        assert_eq!(decoder.read_u8().unwrap(), 1);
        assert_eq!(decoder.read_to_end().unwrap(), &data[1..]);
    }
}