hocuspocus_rs_ws/
client_connection.rs

1// Portions of this module are adapted from the Hocuspocus JavaScript server
2// (https://github.com/ueberdosis/hocuspocus) and y-sweet
3// (https://github.com/y-sweet/y-sweet), both distributed under the MIT license.
4// Adapted code retains the original license terms.
5
6use crate::doc_connection::DocConnection;
7use crate::sync::awareness::Awareness;
8use crate::sync::Message;
9use anyhow::{Result, anyhow};
10use async_trait::async_trait;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex, RwLock};
13use std::time::Duration;
14use tokio::sync::mpsc;
15use tracing::error;
16use uuid::Uuid;
17use yrs::encoding::read::{Cursor, Read};
18use yrs::updates::decoder::{Decode as _, Decoder, DecoderV1};
19
/// Per-document settings attached to a client's connection to one document.
#[derive(Debug, Clone)]
pub struct DocConnectionConfig {
    /// Whether the client should be limited to read access for the document.
    /// NOTE(review): enforcement is not visible in this module — confirm where it happens.
    pub read_only: bool,
    /// Whether the client is considered authenticated for the document.
    pub is_authenticated: bool,
}

impl Default for DocConnectionConfig {
    /// Writable (`read_only == false`) and authenticated (`is_authenticated == true`).
    ///
    /// NOTE(review): defaulting to "authenticated" is permissive; confirm this is intended.
    fn default() -> Self {
        let read_only = false;
        let is_authenticated = true;
        Self {
            read_only,
            is_authenticated,
        }
    }
}
34
/// A raw client message paired with the name of the document it targets.
///
/// NOTE(review): nothing in this file constructs or consumes this type.
#[derive(Debug)]
pub struct MessageQueueEntry {
    /// The message bytes exactly as received.
    pub data: Vec<u8>,
    /// Name of the document the message is addressed to.
    pub document_name: String,
}
40
41#[async_trait]
42pub trait DocServer: Send + Sync {
43    async fn fetch(&self, doc_id: &str) -> Result<Arc<RwLock<Awareness>>>;
44    async fn authenticate(&self, doc_id: &str, token: &str) -> Result<DocConnectionConfig>;
45}
46
47pub struct ClientConnection {
48    // Unique identifier for this connection
49    socket_id: String,
50
51    doc_server: Arc<dyn DocServer>,
52
53    // Document connections by document name
54    document_connections: Arc<Mutex<HashMap<String, Arc<DocConnection>>>>,
55
56    // Callback for sending messages back to the client
57    send_callback: mpsc::Sender<Vec<u8>>,
58    // Timeout settings
59    timeout: Duration,
60
61    // Context for the connection
62    context: Arc<Mutex<HashMap<String, String>>>,
63
64    // Whether the connection is closed
65    closed: Arc<Mutex<bool>>,
66}
67
68impl ClientConnection {
69    pub fn new(
70        doc_server: Arc<dyn DocServer>,
71        send_callback: mpsc::Sender<Vec<u8>>,
72        timeout: Duration,
73        default_context: HashMap<String, String>,
74    ) -> Self {
75        Self {
76            doc_server,
77            socket_id: Uuid::new_v4().to_string(),
78            document_connections: Arc::new(Mutex::new(HashMap::new())),
79            send_callback,
80            timeout,
81            context: Arc::new(Mutex::new(default_context)),
82            closed: Arc::new(Mutex::new(false)),
83        }
84    }
85
86    pub fn socket_id(&self) -> &str {
87        &self.socket_id
88    }
89
90    pub fn is_closed(&self) -> bool {
91        *self.closed.lock().unwrap()
92    }
93
94    pub fn close(&self) -> Result<()> {
95        let mut closed = self.closed.lock().unwrap();
96        if *closed {
97            return Ok(());
98        }
99        *closed = true;
100
101        // Close all document connections
102        let connections = self.document_connections.lock().unwrap();
103        for (_name, _connection) in connections.iter() {
104            // DocConnection will handle cleanup in its Drop implementation
105        }
106
107        Ok(())
108    }
109
110    /// Handle incoming message for a specific document
111    /// The document name must be provided separately as the ysweet protocol
112    /// doesn't include document names in the message payload
113    pub async fn handle_message(&self, data: &[u8]) -> Result<()> {
114        if self.is_closed() {
115            return Err(anyhow!("Connection is closed"));
116        }
117
118        let mut decoder = DecoderV1::new(Cursor::new(data));
119        let document_name = decoder.read_string()?.to_owned();
120        let msg = Message::decode_v1(decoder.read_to_end()?)?;
121
122        let doc_connection = self.fetch_connection(&document_name).await;
123        match doc_connection {
124            Err(err) => {
125                error!(
126                    "Failed to fetch connection for document '{}': {}",
127                    document_name, err
128                );
129                return Ok(());
130            }
131            Ok(connection) => {
132                match connection.handle_msg(msg).await {
133                    Ok(result_msg) => {
134                        // Handle the result message if needed
135                        if let Some(response) = result_msg {
136                            connection.send_message(response).await?;
137                        }
138                    }
139                    Err(err) => return Err(anyhow!("Failed to handle message: {}", err)),
140                }
141            }
142        }
143
144        // Queue the message if connection is not established yet
145        // {
146        //     let mut queue = self.incoming_message_queue.lock().unwrap();
147        //     let entry = queue.entry(document_name.to_string()).or_default();
148        //     entry.push(data.to_vec());
149        // }
150
151        // Check if this is the first message for this document
152        // let mut established = self.document_connections_established.lock().unwrap();
153        // if !established.contains_key(document_name) {
154        //     established.insert(document_name.to_string(), false);
155
156        //     // Initialize connection config for this document
157        //     let mut configs = self.connection_configs.lock().unwrap();
158        //     configs.insert(document_name.to_string(), ConnectionConfig::default());
159        //     drop(configs);
160
161        //     // Try to establish the connection
162
163        // }
164
165        // let establish_connection = self.try_establish_connection(document_name, msg).await;
166        // if let Err(err) = establish_connection {
167        //     error!(
168        //         "Failed to establish connection for document '{}': {}",
169        //         document_name, err
170        //     );
171        // }
172
173        Ok(())
174    }
175
176    async fn fetch_connection(&self, document_name: &str) -> Result<Arc<DocConnection>> {
177        // For now, we'll create a basic connection without authentication
178        // In a real implementation, you'd handle authentication here
179        {
180            let connections = self.document_connections.lock().unwrap();
181            let doc_connection = connections.get(document_name).cloned();
182            if let Some(conn) = doc_connection {
183                return Ok(conn.clone());
184            }
185        }
186
187        let awareness = Arc::new(RwLock::new(self.doc_server.fetch(document_name).await?))
188            .read()
189            .unwrap()
190            .clone();
191
192        let send_callback = self.send_callback.clone();
193
194        let connection = Arc::new(DocConnection::new(
195            document_name.to_string(),
196            self.doc_server.clone(),
197            awareness.clone(),
198            send_callback.clone(),
199        ));
200
201        {
202            let mut connections = self.document_connections.lock().unwrap();
203            connections.insert(document_name.to_string(), connection.clone());
204        }
205
206        Ok(connection)
207    }
208
209    pub fn get_context(&self, key: &str) -> Option<String> {
210        let context = self.context.lock().unwrap();
211        context.get(key).cloned()
212    }
213
214    pub fn set_context(&self, key: String, value: String) {
215        let mut context = self.context.lock().unwrap();
216        context.insert(key, value);
217    }
218
219    pub fn get_document_count(&self) -> usize {
220        let connections = self.document_connections.lock().unwrap();
221        connections.len()
222    }
223
224    pub fn has_document(&self, document_name: &str) -> bool {
225        let connections = self.document_connections.lock().unwrap();
226        connections.contains_key(document_name)
227    }
228}
229
230impl Drop for ClientConnection {
231    fn drop(&mut self) {
232        let _ = self.close();
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// [`DocServer`] stub that serves no documents; it is enough to build a
    /// [`ClientConnection`] without a real backend.
    struct NoDocServer;

    #[async_trait]
    impl DocServer for NoDocServer {
        async fn fetch(&self, doc_id: &str) -> Result<Arc<RwLock<Awareness>>> {
            Err(anyhow!("document '{}' is not available in tests", doc_id))
        }

        async fn authenticate(&self, _doc_id: &str, _token: &str) -> Result<DocConnectionConfig> {
            Ok(DocConnectionConfig::default())
        }
    }

    /// Builds a connection backed by [`NoDocServer`] with the given initial context.
    fn new_connection(context: HashMap<String, String>) -> ClientConnection {
        // Creating a tokio channel does not require a running runtime.
        let (tx, _rx) = mpsc::channel(16);
        ClientConnection::new(
            Arc::new(NoDocServer),
            tx,
            Duration::from_secs(30),
            context,
        )
    }

    #[test]
    fn test_client_connection_creation() {
        let connection = new_connection(HashMap::new());

        assert!(!connection.is_closed());
        assert_eq!(connection.get_document_count(), 0);
        assert!(!connection.has_document("doc"));
        assert!(!connection.socket_id().is_empty());
    }

    #[test]
    fn test_client_connection_close() {
        let connection = new_connection(HashMap::new());

        connection.close().unwrap();
        assert!(connection.is_closed());

        // Closing again is a no-op.
        connection.close().unwrap();
        assert!(connection.is_closed());
    }

    #[test]
    fn test_context_round_trip() {
        let mut defaults = HashMap::new();
        defaults.insert("user".to_string(), "alice".to_string());
        let connection = new_connection(defaults);

        assert_eq!(connection.get_context("user").as_deref(), Some("alice"));
        assert_eq!(connection.get_context("role"), None);

        connection.set_context("role".to_string(), "editor".to_string());
        assert_eq!(connection.get_context("role").as_deref(), Some("editor"));
    }

    #[test]
    fn test_read_to_end() {
        let data = vec![1, 2, 3, 4, 5];
        let cursor = Cursor::new(&data);
        let mut decoder = DecoderV1::new(cursor);

        decoder.read_u8().unwrap(); // read one byte

        assert_eq!(decoder.read_to_end().unwrap(), &data[1..]);
    }
}