// hocuspocus_rs_ws/types.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::sync::Arc;
4use yrs::
5    encoding::read::Error as DecodeError
6;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
9pub enum MessageType {
10    Sync = 0,
11    Awareness = 1,
12    Auth = 2,
13    QueryAwareness = 3,
14    SyncReply = 4,          // ?
15    Stateless = 5,          // ?
16    BroadcastStateless = 6, // ?
17    SyncStatus = 7,         // ?
18    Close = 8,              // ?
19}
20
21impl From<u8> for MessageType {
22    fn from(value: u8) -> Self {
23        match value {
24            0 => MessageType::Sync,
25            1 => MessageType::Awareness,
26            2 => MessageType::Auth,
27            3 => MessageType::QueryAwareness,
28            5 => MessageType::Stateless,
29            6 => MessageType::BroadcastStateless,
30            7 => MessageType::SyncStatus,
31            8 => MessageType::Close,
32            _ => panic!("Invalid message type: {}", value),
33        }
34    }
35}
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
38#[repr(u8)]
39pub enum SyncType {
40    Step1 = 0,
41    Step2 = 1,
42    Update = 2,
43}
44
45impl SyncType {
46    pub fn from_u8(value: u8) -> Result<Self, DecodeError> {
47        match value {
48            0 => Ok(SyncType::Step1),
49            1 => Ok(SyncType::Step2),
50            2 => Ok(SyncType::Update),
51            _ => panic!("Invalid sync type: {}", value),
52        }
53    }
54
55    pub fn from_u64(value: u64) -> Result<Self, DecodeError> {
56        Self::from_u8(value as u8)
57    }
58}
59
60#[derive(Debug, Clone, PartialEq)]
61pub struct ClientAwareness {
62    pub client_id: u64,
63    pub clock: u32,
64    pub state: Option<serde_json::Value>,
65}
66
// // Awareness protocol message
68// #[derive(Debug)]
69// pub struct AwarenessUpdate {
70//     pub clients: Vec<AwarenessClient>,
71// }
72
73// #[derive(Debug, Clone)]
74// pub struct AwarenessClient {
75//     pub client_id: u64,
76//     pub clock: u32,
77//     pub state: Option<Vec<u8>>, // JSON as bytes
78// }
79
// // Awareness state management
81// #[derive(Debug, Clone)]
82// pub struct Awareness {
83//     pub states: HashMap<u64, AwarenessState>,
84//     pub local_client_id: u64,
85// }
86
87// #[derive(Debug, Clone)]
88// pub struct AwarenessState {
89//     pub clock: u32,
90//     pub state: serde_json::Value,
91//     pub last_updated: std::time::Instant,
92// }
93
94// impl Awareness {
95//     pub fn new(client_id: u64) -> Self {
96//         Self {
97//             states: HashMap::new(),
98//             local_client_id: client_id,
99//         }
100//     }
101
102//     pub fn set_local_state(&mut self, state: serde_json::Value) {
103//         let entry = self
104//             .states
105//             .entry(self.local_client_id)
106//             .or_insert(AwarenessState {
107//                 clock: 0,
108//                 state: serde_json::Value::Null,
109//                 last_updated: std::time::Instant::now(),
110//             });
111
112//         entry.clock += 1;
113//         entry.state = state;
114//         entry.last_updated = std::time::Instant::now();
115//     }
116
117//     pub fn get_states(&self) -> &HashMap<u64, AwarenessState> {
118//         &self.states
119//     }
120
121//     pub fn apply_update(&mut self, update: AwarenessUpdate) -> Vec<u64> {
122//         let mut changed = Vec::new();
123
124//         for client in update.clients {
125//             let should_update = self
126//                 .states
127//                 .get(&client.client_id)
128//                 .map(|s| s.clock < client.clock)
129//                 .unwrap_or(true);
130
131//             if should_update {
132//                 if let Some(state_bytes) = client.state {
133//                     if let Ok(state) = serde_json::from_slice(&state_bytes) {
134//                         self.states.insert(
135//                             client.client_id,
136//                             AwarenessState {
137//                                 clock: client.clock,
138//                                 state,
139//                                 last_updated: std::time::Instant::now(),
140//                             },
141//                         );
142//                         changed.push(client.client_id);
143//                     }
144//                 } else {
145//                     // Client disconnected
146//                     self.states.remove(&client.client_id);
147//                     changed.push(client.client_id);
148//                 }
149//             }
150//         }
151
152//         changed
153//     }
154
155//     pub fn encode_update(&self, clients: &[u64]) -> Vec<u8> {
156//         let mut encoder = Encoder::new();
157
158//         // Number of clients
159//         encoder.write_var_uint(clients.len() as u64);
160
161//         for &client_id in clients {
162//             encoder.write_var_uint(client_id);
163
164//             if let Some(state) = self.states.get(&client_id) {
165//                 encoder.write_var_uint(state.clock as u64);
166//                 let json_bytes = serde_json::to_vec(&state.state).unwrap_or_default();
167//                 encoder.write_var_string(&String::from_utf8_lossy(&json_bytes));
168//             } else {
169//                 encoder.write_var_uint(0); // clock = 0 means removed
170//                 encoder.write_var_uint(0); // empty state
171//             }
172//         }
173
174//         encoder.to_vec()
175//     }
176
177//     pub fn decode_update(data: &[u8]) -> Result<AwarenessUpdate, DecodeError> {
178//         let mut decoder = Decoder::new(data);
179//         let num_clients = decoder.read_var_uint()? as usize;
180//         let mut clients = Vec::with_capacity(num_clients);
181
182//         for _ in 0..num_clients {
183//             let client_id = decoder.read_var_uint()?;
184//             let clock = decoder.read_var_uint()? as u32;
185
186//             let state = if clock > 0 {
187//                 let state_str = decoder.read_var_string()?;
188//                 Some(state_str.as_bytes().to_vec())
189//             } else {
190//                 None
191//             };
192
193//             clients.push(AwarenessClient {
194//                 client_id,
195//                 clock,
196//                 state,
197//             });
198//         }
199
200//         Ok(AwarenessUpdate { clients })
201//     }
202// }
203
204// 더 자세한 Sync 메시지 구조
205#[derive(Debug)]
206pub enum SyncMessage<'a> {
207    Step1 { state_vector: &'a [u8] },
208    Step2 { update: &'a [u8] },
209    Update { update: &'a [u8] },
210}
211
212impl<'a> SyncMessage<'a> {
213    pub fn decode(sync_type: SyncType, data: &'a [u8]) -> Result<Self, DecodeError> {
214        match sync_type {
215            SyncType::Step1 => Ok(SyncMessage::Step1 { state_vector: data }),
216            SyncType::Step2 => Ok(SyncMessage::Step2 { update: data }),
217            SyncType::Update => Ok(SyncMessage::Update { update: data }),
218        }
219    }
220}
221
222// // Separate payload structs for each message type
223// #[derive(Debug, Clone)]
224// pub struct SyncPayload {
225//     pub sync_type: SyncType,
226//     pub data: Vec<u8>,
227// }
228
229// #[derive(Debug, Clone)]
230// pub struct SyncReplyPayload {
231//     pub sync_type: SyncType,
232//     pub data: Vec<u8>,
233// }
234
235// #[derive(Debug, Clone)]
236// pub struct AwarenessPayload {
237//     pub clients: Vec<ClientAwareness>,
238// }
239
240// #[derive(Debug, Clone)]
241// pub struct AuthPayload {
242//     pub token: String,
243// }
244
245// #[derive(Debug, Clone)]
246// pub struct QueryAwarenessPayload {
247//     // No additional data needed
248// }
249
250// #[derive(Debug, Clone)]
251// pub struct StatelessPayload {
252//     pub payload: Vec<u8>,
253// }
254
255// #[derive(Debug, Clone)]
256// pub struct BroadcastStatelessPayload {
257//     pub payload: Vec<u8>,
258// }
259
260// #[derive(Debug, Clone)]
261// pub struct ClosePayload {
262//     pub code: u16,
263//     pub reason: String,
264// }
265
266// #[derive(Debug)]
267// pub enum MessageTypeV2 {
268//     Sync(SyncPayload),
269//     SyncReply(SyncReplyPayload),
270//     Awareness(AwarenessPayload),
271//     Auth(AuthPayload),
272//     QueryAwareness(QueryAwarenessPayload),
273//     Stateless(StatelessPayload),
274//     BroadcastStateless(BroadcastStatelessPayload),
275//     Close(ClosePayload),
276// }
277
278// impl MessageTypeV2 {
279//     pub fn decode(buffer: &[u8], skip_document: bool) -> Result<Self, DecodeError> {
280//         let mut decoder = DecoderV2::new(Cursor::new(buffer)).unwrap();
281
282//         // Skip document name if present in the buffer
283//         if !skip_document {
284//             let _document = decoder.read_string()?;
285//         }
286
287//         let msg_type = decoder.read_u8()?;
288
289//         match msg_type {
290//             0 => {
291//                 let sync_type = SyncType::from_u8(decoder.read_u8()?)?;
292//                 Ok(MessageTypeV2::Sync(SyncPayload {
293//                     sync_type,
294//                     data: decoder.read_buf()?.to_vec(),
295//                 }))
296//             }
297//             1 => {
298//                 let clients = decode_awareness(&mut decoder)?;
299//                 Ok(MessageTypeV2::Awareness(AwarenessPayload { clients }))
300//             }
301//             2 => {
302//                 let token = decoder.read_str()?.to_string();
303//                 Ok(MessageTypeV2::Auth(AuthPayload { token }))
304//             }
305//             3 => Ok(MessageTypeV2::QueryAwareness(QueryAwarenessPayload {})),
306//             4 => {
307//                 let sync_type = SyncType::from_u8(decoder.read_u8()?)?;
308//                 Ok(MessageTypeV2::SyncReply(SyncReplyPayload {
309//                     sync_type,
310//                     data: decoder.read_buf()?.to_vec(),
311//                 }))
312//             }
313//             5 => Ok(MessageTypeV2::Stateless(StatelessPayload {
314//                 payload: decoder.read_buf()?.to_vec(),
315//             })),
316//             6 => Ok(MessageTypeV2::BroadcastStateless(
317//                 BroadcastStatelessPayload {
318//                     payload: decoder.read_buf()?.to_vec(),
319//                 },
320//             )),
321//             8 => {
322//                 let code = decoder.read_u16()?;
323//                 let reason = decoder.read_str()?.to_string();
324//                 Ok(MessageTypeV2::Close(ClosePayload { code, reason }))
325//             }
326//             _ => Err(DecodeError::UnknownMessageType(msg_type)),
327//         }
328//     }
329
330//     pub fn message_type(&self) -> MessageType {
331//         match self {
332//             MessageTypeV2::Sync(_) => MessageType::Sync,
333//             MessageTypeV2::SyncReply(_) => MessageType::SyncReply,
334//             MessageTypeV2::Awareness(_) => MessageType::Awareness,
335//             MessageTypeV2::Auth(_) => MessageType::Auth,
336//             MessageTypeV2::QueryAwareness(_) => MessageType::QueryAwareness,
337//             MessageTypeV2::Stateless(_) => MessageType::Stateless,
338//             MessageTypeV2::BroadcastStateless(_) => MessageType::BroadcastStateless,
339//             MessageTypeV2::Close(_) => MessageType::Close,
340//         }
341//     }
342// }
343
/// Sync sub-message tag, convertible from its byte value via `From<u8>`.
///
/// Uses the same values as `SyncType`: 0 = Step1, 1 = Step2, 2 = Update.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncMessageType {
    Step1 = 0,
    Step2 = 1,
    Update = 2,
}

impl From<u8> for SyncMessageType {
    /// # Panics
    /// Panics with `Invalid sync message type: N` if `value` is greater than 2.
    fn from(value: u8) -> Self {
        // Indexed by tag value; the order must match the discriminants above.
        const BY_TAG: [SyncMessageType; 3] = [
            SyncMessageType::Step1,
            SyncMessageType::Step2,
            SyncMessageType::Update,
        ];
        BY_TAG
            .get(usize::from(value))
            .copied()
            .unwrap_or_else(|| panic!("Invalid sync message type: {}", value))
    }
}
361
362#[derive(Debug, Clone, Serialize, Deserialize)]
363pub struct ConnectionConfiguration {
364    pub id: String,
365    pub is_authenticated: bool,
366    pub readonly: bool,
367    pub user_id: Option<String>,
368    pub socket_id: String,
369    pub context: HashMap<String, serde_json::Value>,
370}
371
372#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct DocumentConfiguration {
374    pub name: String,
375    pub gc: bool,
376}
377
378impl Default for DocumentConfiguration {
379    fn default() -> Self {
380        Self {
381            name: String::new(),
382            gc: true,
383        }
384    }
385}
386
387#[derive(Clone)]
388pub struct HocuspocusConfiguration {
389    pub name: Option<String>,
390    pub timeout: u64,
391    pub debounce: u64,
392    pub max_debounce: u64,
393    pub quiet: bool,
394    pub extensions: Vec<Arc<dyn Extension>>,
395}
396
397impl Default for HocuspocusConfiguration {
398    fn default() -> Self {
399        Self {
400            name: None,
401            timeout: 30000,
402            debounce: 2000,
403            max_debounce: 10000,
404            quiet: false,
405            extensions: Vec::new(),
406        }
407    }
408}
409
410#[async_trait::async_trait]
411pub trait Extension: Send + Sync {
412    async fn on_configure(
413        &self,
414        _configuration: &mut HocuspocusConfiguration,
415    ) -> anyhow::Result<()> {
416        Ok(())
417    }
418
419    async fn on_listen(&self, _port: u16) -> anyhow::Result<()> {
420        Ok(())
421    }
422
423    async fn on_upgrade(
424        &self,
425        _request: &axum::http::Request<axum::body::Body>,
426    ) -> anyhow::Result<()> {
427        Ok(())
428    }
429
430    async fn on_connect(&self, _data: ConnectEventData) -> anyhow::Result<()> {
431        Ok(())
432    }
433
434    async fn on_authenticate(
435        &self,
436        _data: AuthenticateEventData,
437    ) -> anyhow::Result<serde_json::Value> {
438        Ok(serde_json::Value::Null)
439    }
440
441    async fn connected(&self, _data: ConnectedEventData) -> anyhow::Result<()> {
442        Ok(())
443    }
444
445    async fn on_create_document(&self, _data: CreateDocumentEventData) -> anyhow::Result<()> {
446        Ok(())
447    }
448
449    async fn on_load_document(
450        &self,
451        _data: LoadDocumentEventData,
452    ) -> anyhow::Result<Option<Vec<u8>>> {
453        Ok(None)
454    }
455
456    async fn before_handle_message(
457        &self,
458        _data: BeforeHandleMessageEventData,
459    ) -> anyhow::Result<()> {
460        Ok(())
461    }
462
463    async fn before_broadcast(&self, _data: BeforeBroadcastEventData) -> anyhow::Result<()> {
464        Ok(())
465    }
466
467    async fn on_change(&self, _data: ChangeEventData) -> anyhow::Result<()> {
468        Ok(())
469    }
470
471    async fn on_store_document(&self, _data: StoreDocumentEventData) -> anyhow::Result<()> {
472        Ok(())
473    }
474
475    async fn after_store_document(&self, _data: AfterStoreDocumentEventData) -> anyhow::Result<()> {
476        Ok(())
477    }
478
479    async fn on_awareness_update(&self, _data: AwarenessUpdateEventData) -> anyhow::Result<()> {
480        Ok(())
481    }
482
483    async fn before_send_stateless(
484        &self,
485        _data: BeforeSendStatelessEventData,
486    ) -> anyhow::Result<()> {
487        Ok(())
488    }
489
490    async fn before_send_awareness(
491        &self,
492        _data: BeforeSendAwarenessEventData,
493    ) -> anyhow::Result<()> {
494        Ok(())
495    }
496
497    async fn on_disconnect(&self, _data: DisconnectEventData) -> anyhow::Result<()> {
498        Ok(())
499    }
500
501    async fn on_destroy(&self) -> anyhow::Result<()> {
502        Ok(())
503    }
504
505    async fn after_unload_document(
506        &self,
507        _data: AfterUnloadDocumentEventData,
508    ) -> anyhow::Result<()> {
509        Ok(())
510    }
511
512    async fn on_stateless(&self, _data: StatelessEventData) -> anyhow::Result<()> {
513        Ok(())
514    }
515}
516
517#[derive(Clone)]
518pub struct ConnectEventData {
519    pub socket_id: String,
520    pub connection: ConnectionConfiguration,
521    pub request_headers: HashMap<String, String>,
522    pub request_path: String,
523}
524
525#[derive(Clone, Debug)]
526pub struct AuthenticateEventData {
527    pub doc_id: String,
528    pub token: String,
529    // pub socket_id: String,
530}
531
532#[derive(Clone)]
533pub struct ConnectedEventData {
534    pub socket_id: String,
535    // pub context: HashMap<String, serde_json::Value>,
536}
537
538#[derive(Clone)]
539pub struct CreateDocumentEventData {
540    pub document_name: String,
541    pub socket_id: String,
542    pub connection: ConnectionConfiguration,
543}
544
545#[derive(Clone)]
546pub struct LoadDocumentEventData {
547    pub document_name: String,
548    pub socket_id: String,
549}
550
551#[derive(Clone)]
552pub struct BeforeHandleMessageEventData {
553    pub message: Vec<u8>,
554    pub socket_id: String,
555    pub connection: ConnectionConfiguration,
556}
557
558#[derive(Clone)]
559pub struct BeforeBroadcastEventData {
560    pub document_name: String,
561    pub exclude: Vec<String>,
562    pub message: Vec<u8>,
563}
564
565#[derive(Clone)]
566pub struct ChangeEventData {
567    pub document_name: String,
568    pub socket_id: String,
569    pub update: Vec<u8>,
570}
571
572#[derive(Clone)]
573pub struct StoreDocumentEventData {
574    pub document_name: String,
575    pub state: Vec<u8>,
576}
577
578#[derive(Clone)]
579pub struct AfterStoreDocumentEventData {
580    pub document_name: String,
581}
582
583#[derive(Clone)]
584pub struct AwarenessUpdateEventData {
585    pub document_name: String,
586    pub awareness: Vec<u8>,
587    pub added: Vec<u64>,
588    pub updated: Vec<u64>,
589    pub removed: Vec<u64>,
590}
591
592#[derive(Clone)]
593pub struct BeforeSendStatelessEventData {
594    pub document_name: String,
595    pub socket_id: String,
596    pub payload: String,
597}
598
599#[derive(Clone)]
600pub struct BeforeSendAwarenessEventData {
601    pub document_name: String,
602    pub socket_id: String,
603    pub awareness: Vec<u8>,
604}
605
606#[derive(Clone)]
607pub struct DisconnectEventData {
608    pub socket_id: String,
609}
610
611#[derive(Clone)]
612pub struct AfterUnloadDocumentEventData {
613    pub document_name: String,
614}
615
616#[derive(Clone)]
617pub struct StatelessEventData {
618    pub document_name: String,
619    pub socket_id: String,
620    pub payload: String,
621}
622
/// Identifier of a channel (plain `String` alias).
pub type ChannelId = String;
/// Name of a document (plain `String` alias).
pub type DocumentName = String;
/// Identifier of a client socket (plain `String` alias).
pub type SocketId = String;