hocuspocus_rs_ws/types.rs
1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::sync::Arc;
4use yrs::
5 encoding::read::Error as DecodeError
6;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
9pub enum MessageType {
10 Sync = 0,
11 Awareness = 1,
12 Auth = 2,
13 QueryAwareness = 3,
14 SyncReply = 4, // ?
15 Stateless = 5, // ?
16 BroadcastStateless = 6, // ?
17 SyncStatus = 7, // ?
18 Close = 8, // ?
19}
20
21impl From<u8> for MessageType {
22 fn from(value: u8) -> Self {
23 match value {
24 0 => MessageType::Sync,
25 1 => MessageType::Awareness,
26 2 => MessageType::Auth,
27 3 => MessageType::QueryAwareness,
28 5 => MessageType::Stateless,
29 6 => MessageType::BroadcastStateless,
30 7 => MessageType::SyncStatus,
31 8 => MessageType::Close,
32 _ => panic!("Invalid message type: {}", value),
33 }
34 }
35}
36
/// Sub-type tag that follows a `Sync` message tag.
///
/// `#[repr(u8)]` pins each variant to the single byte used on the wire.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SyncType {
    /// Tag byte 0; the payload is a state vector (see `SyncMessage::Step1`).
    Step1 = 0,
    /// Tag byte 1; the payload is an update.
    Step2 = 1,
    /// Tag byte 2; the payload is an update.
    Update = 2,
}
44
45impl SyncType {
46 pub fn from_u8(value: u8) -> Result<Self, DecodeError> {
47 match value {
48 0 => Ok(SyncType::Step1),
49 1 => Ok(SyncType::Step2),
50 2 => Ok(SyncType::Update),
51 _ => panic!("Invalid sync type: {}", value),
52 }
53 }
54
55 pub fn from_u64(value: u64) -> Result<Self, DecodeError> {
56 Self::from_u8(value as u8)
57 }
58}
59
60#[derive(Debug, Clone, PartialEq)]
61pub struct ClientAwareness {
62 pub client_id: u64,
63 pub clock: u32,
64 pub state: Option<serde_json::Value>,
65}
66
// // Awareness protocol message
68// #[derive(Debug)]
69// pub struct AwarenessUpdate {
70// pub clients: Vec<AwarenessClient>,
71// }
72
73// #[derive(Debug, Clone)]
74// pub struct AwarenessClient {
75// pub client_id: u64,
76// pub clock: u32,
77// pub state: Option<Vec<u8>>, // JSON as bytes
78// }
79
// // Awareness state management
81// #[derive(Debug, Clone)]
82// pub struct Awareness {
83// pub states: HashMap<u64, AwarenessState>,
84// pub local_client_id: u64,
85// }
86
87// #[derive(Debug, Clone)]
88// pub struct AwarenessState {
89// pub clock: u32,
90// pub state: serde_json::Value,
91// pub last_updated: std::time::Instant,
92// }
93
94// impl Awareness {
95// pub fn new(client_id: u64) -> Self {
96// Self {
97// states: HashMap::new(),
98// local_client_id: client_id,
99// }
100// }
101
102// pub fn set_local_state(&mut self, state: serde_json::Value) {
103// let entry = self
104// .states
105// .entry(self.local_client_id)
106// .or_insert(AwarenessState {
107// clock: 0,
108// state: serde_json::Value::Null,
109// last_updated: std::time::Instant::now(),
110// });
111
112// entry.clock += 1;
113// entry.state = state;
114// entry.last_updated = std::time::Instant::now();
115// }
116
117// pub fn get_states(&self) -> &HashMap<u64, AwarenessState> {
118// &self.states
119// }
120
121// pub fn apply_update(&mut self, update: AwarenessUpdate) -> Vec<u64> {
122// let mut changed = Vec::new();
123
124// for client in update.clients {
125// let should_update = self
126// .states
127// .get(&client.client_id)
128// .map(|s| s.clock < client.clock)
129// .unwrap_or(true);
130
131// if should_update {
132// if let Some(state_bytes) = client.state {
133// if let Ok(state) = serde_json::from_slice(&state_bytes) {
134// self.states.insert(
135// client.client_id,
136// AwarenessState {
137// clock: client.clock,
138// state,
139// last_updated: std::time::Instant::now(),
140// },
141// );
142// changed.push(client.client_id);
143// }
144// } else {
145// // Client disconnected
146// self.states.remove(&client.client_id);
147// changed.push(client.client_id);
148// }
149// }
150// }
151
152// changed
153// }
154
155// pub fn encode_update(&self, clients: &[u64]) -> Vec<u8> {
156// let mut encoder = Encoder::new();
157
158// // Number of clients
159// encoder.write_var_uint(clients.len() as u64);
160
161// for &client_id in clients {
162// encoder.write_var_uint(client_id);
163
164// if let Some(state) = self.states.get(&client_id) {
165// encoder.write_var_uint(state.clock as u64);
166// let json_bytes = serde_json::to_vec(&state.state).unwrap_or_default();
167// encoder.write_var_string(&String::from_utf8_lossy(&json_bytes));
168// } else {
169// encoder.write_var_uint(0); // clock = 0 means removed
170// encoder.write_var_uint(0); // empty state
171// }
172// }
173
174// encoder.to_vec()
175// }
176
177// pub fn decode_update(data: &[u8]) -> Result<AwarenessUpdate, DecodeError> {
178// let mut decoder = Decoder::new(data);
179// let num_clients = decoder.read_var_uint()? as usize;
180// let mut clients = Vec::with_capacity(num_clients);
181
182// for _ in 0..num_clients {
183// let client_id = decoder.read_var_uint()?;
184// let clock = decoder.read_var_uint()? as u32;
185
186// let state = if clock > 0 {
187// let state_str = decoder.read_var_string()?;
188// Some(state_str.as_bytes().to_vec())
189// } else {
190// None
191// };
192
193// clients.push(AwarenessClient {
194// client_id,
195// clock,
196// state,
197// });
198// }
199
200// Ok(AwarenessUpdate { clients })
201// }
202// }
203
/// A more detailed Sync message structure: the payload of a sync message,
/// borrowed (not copied) from the input buffer.
#[derive(Debug)]
pub enum SyncMessage<'a> {
    /// Step 1: the payload is a state vector.
    Step1 {
        state_vector: &'a [u8],
    },
    /// Step 2: the payload is an update.
    Step2 {
        update: &'a [u8],
    },
    /// Incremental update: the payload is an update.
    Update {
        update: &'a [u8],
    },
}
211
212impl<'a> SyncMessage<'a> {
213 pub fn decode(sync_type: SyncType, data: &'a [u8]) -> Result<Self, DecodeError> {
214 match sync_type {
215 SyncType::Step1 => Ok(SyncMessage::Step1 { state_vector: data }),
216 SyncType::Step2 => Ok(SyncMessage::Step2 { update: data }),
217 SyncType::Update => Ok(SyncMessage::Update { update: data }),
218 }
219 }
220}
221
222// // Separate payload structs for each message type
223// #[derive(Debug, Clone)]
224// pub struct SyncPayload {
225// pub sync_type: SyncType,
226// pub data: Vec<u8>,
227// }
228
229// #[derive(Debug, Clone)]
230// pub struct SyncReplyPayload {
231// pub sync_type: SyncType,
232// pub data: Vec<u8>,
233// }
234
235// #[derive(Debug, Clone)]
236// pub struct AwarenessPayload {
237// pub clients: Vec<ClientAwareness>,
238// }
239
240// #[derive(Debug, Clone)]
241// pub struct AuthPayload {
242// pub token: String,
243// }
244
245// #[derive(Debug, Clone)]
246// pub struct QueryAwarenessPayload {
247// // No additional data needed
248// }
249
250// #[derive(Debug, Clone)]
251// pub struct StatelessPayload {
252// pub payload: Vec<u8>,
253// }
254
255// #[derive(Debug, Clone)]
256// pub struct BroadcastStatelessPayload {
257// pub payload: Vec<u8>,
258// }
259
260// #[derive(Debug, Clone)]
261// pub struct ClosePayload {
262// pub code: u16,
263// pub reason: String,
264// }
265
266// #[derive(Debug)]
267// pub enum MessageTypeV2 {
268// Sync(SyncPayload),
269// SyncReply(SyncReplyPayload),
270// Awareness(AwarenessPayload),
271// Auth(AuthPayload),
272// QueryAwareness(QueryAwarenessPayload),
273// Stateless(StatelessPayload),
274// BroadcastStateless(BroadcastStatelessPayload),
275// Close(ClosePayload),
276// }
277
278// impl MessageTypeV2 {
279// pub fn decode(buffer: &[u8], skip_document: bool) -> Result<Self, DecodeError> {
280// let mut decoder = DecoderV2::new(Cursor::new(buffer)).unwrap();
281
282// // Skip document name if present in the buffer
283// if !skip_document {
284// let _document = decoder.read_string()?;
285// }
286
287// let msg_type = decoder.read_u8()?;
288
289// match msg_type {
290// 0 => {
291// let sync_type = SyncType::from_u8(decoder.read_u8()?)?;
292// Ok(MessageTypeV2::Sync(SyncPayload {
293// sync_type,
294// data: decoder.read_buf()?.to_vec(),
295// }))
296// }
297// 1 => {
298// let clients = decode_awareness(&mut decoder)?;
299// Ok(MessageTypeV2::Awareness(AwarenessPayload { clients }))
300// }
301// 2 => {
302// let token = decoder.read_str()?.to_string();
303// Ok(MessageTypeV2::Auth(AuthPayload { token }))
304// }
305// 3 => Ok(MessageTypeV2::QueryAwareness(QueryAwarenessPayload {})),
306// 4 => {
307// let sync_type = SyncType::from_u8(decoder.read_u8()?)?;
308// Ok(MessageTypeV2::SyncReply(SyncReplyPayload {
309// sync_type,
310// data: decoder.read_buf()?.to_vec(),
311// }))
312// }
313// 5 => Ok(MessageTypeV2::Stateless(StatelessPayload {
314// payload: decoder.read_buf()?.to_vec(),
315// })),
316// 6 => Ok(MessageTypeV2::BroadcastStateless(
317// BroadcastStatelessPayload {
318// payload: decoder.read_buf()?.to_vec(),
319// },
320// )),
321// 8 => {
322// let code = decoder.read_u16()?;
323// let reason = decoder.read_str()?.to_string();
324// Ok(MessageTypeV2::Close(ClosePayload { code, reason }))
325// }
326// _ => Err(DecodeError::UnknownMessageType(msg_type)),
327// }
328// }
329
330// pub fn message_type(&self) -> MessageType {
331// match self {
332// MessageTypeV2::Sync(_) => MessageType::Sync,
333// MessageTypeV2::SyncReply(_) => MessageType::SyncReply,
334// MessageTypeV2::Awareness(_) => MessageType::Awareness,
335// MessageTypeV2::Auth(_) => MessageType::Auth,
336// MessageTypeV2::QueryAwareness(_) => MessageType::QueryAwareness,
337// MessageTypeV2::Stateless(_) => MessageType::Stateless,
338// MessageTypeV2::BroadcastStateless(_) => MessageType::BroadcastStateless,
339// MessageTypeV2::Close(_) => MessageType::Close,
340// }
341// }
342// }
343
/// Sub-type tag of a sync message; uses the same 0/1/2 wire values as `SyncType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SyncMessageType {
    /// Tag byte 0.
    Step1 = 0,
    /// Tag byte 1.
    Step2 = 1,
    /// Tag byte 2.
    Update = 2,
}

impl SyncMessageType {
    /// Decodes a tag byte without panicking.
    ///
    /// Returns `None` for anything other than 0, 1 or 2. Prefer this for bytes
    /// received from peers: the `From<u8>` conversion panics on unknown values,
    /// so a malformed message would tear down whatever task is decoding it.
    pub fn try_from_u8(value: u8) -> Option<Self> {
        match value {
            0 => Some(SyncMessageType::Step1),
            1 => Some(SyncMessageType::Step2),
            2 => Some(SyncMessageType::Update),
            _ => None,
        }
    }
}

impl From<u8> for SyncMessageType {
    /// Converts a tag byte into a [`SyncMessageType`].
    ///
    /// # Panics
    ///
    /// Panics with `"Invalid sync message type: {value}"` for values other than
    /// 0, 1 or 2. Use [`SyncMessageType::try_from_u8`] for untrusted input.
    fn from(value: u8) -> Self {
        Self::try_from_u8(value)
            .unwrap_or_else(|| panic!("Invalid sync message type: {}", value))
    }
}
361
362#[derive(Debug, Clone, Serialize, Deserialize)]
363pub struct ConnectionConfiguration {
364 pub id: String,
365 pub is_authenticated: bool,
366 pub readonly: bool,
367 pub user_id: Option<String>,
368 pub socket_id: String,
369 pub context: HashMap<String, serde_json::Value>,
370}
371
372#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct DocumentConfiguration {
374 pub name: String,
375 pub gc: bool,
376}
377
378impl Default for DocumentConfiguration {
379 fn default() -> Self {
380 Self {
381 name: String::new(),
382 gc: true,
383 }
384 }
385}
386
387#[derive(Clone)]
388pub struct HocuspocusConfiguration {
389 pub name: Option<String>,
390 pub timeout: u64,
391 pub debounce: u64,
392 pub max_debounce: u64,
393 pub quiet: bool,
394 pub extensions: Vec<Arc<dyn Extension>>,
395}
396
397impl Default for HocuspocusConfiguration {
398 fn default() -> Self {
399 Self {
400 name: None,
401 timeout: 30000,
402 debounce: 2000,
403 max_debounce: 10000,
404 quiet: false,
405 extensions: Vec::new(),
406 }
407 }
408}
409
/// Server extension: a set of optional async lifecycle hooks.
///
/// Every method has a default body that does nothing and succeeds (`Ok(())`,
/// `Ok(serde_json::Value::Null)` or `Ok(None)`), so implementors override only the
/// hooks they need. `#[async_trait]` boxes the returned futures, which lets the trait
/// be used as `dyn Extension` (as in `Vec<Arc<dyn Extension>>` on
/// `HocuspocusConfiguration`). The `Send + Sync` bound allows one instance to be
/// shared across tasks.
///
/// NOTE(review): when each hook fires, and what the server does with an `Err`, is
/// decided by calling code that is not in this file. The per-hook notes below are
/// inferred from hook names and should be confirmed against the call sites.
#[async_trait::async_trait]
pub trait Extension: Send + Sync {
    /// Receives mutable access to the server configuration (presumably at startup).
    /// Default: no-op.
    async fn on_configure(
        &self,
        _configuration: &mut HocuspocusConfiguration,
    ) -> anyhow::Result<()> {
        Ok(())
    }

    /// Receives the port number (presumably once the server is listening).
    /// Default: no-op.
    async fn on_listen(&self, _port: u16) -> anyhow::Result<()> {
        Ok(())
    }

    /// Receives the incoming HTTP request (presumably during the WebSocket upgrade).
    /// Default: no-op.
    async fn on_upgrade(
        &self,
        _request: &axum::http::Request<axum::body::Body>,
    ) -> anyhow::Result<()> {
        Ok(())
    }

    /// Receives connection details, request headers and path for a new connection.
    /// Default: no-op.
    async fn on_connect(&self, _data: ConnectEventData) -> anyhow::Result<()> {
        Ok(())
    }

    /// Receives the document ID and client token. The returned JSON value is
    /// presumably stored as auth context — confirm. Default: `Value::Null`.
    async fn on_authenticate(
        &self,
        _data: AuthenticateEventData,
    ) -> anyhow::Result<serde_json::Value> {
        Ok(serde_json::Value::Null)
    }

    /// Notified with the socket ID (presumably after a connection is fully set up).
    /// Default: no-op.
    async fn connected(&self, _data: ConnectedEventData) -> anyhow::Result<()> {
        Ok(())
    }

    /// Notified when a document is created for a connection. Default: no-op.
    async fn on_create_document(&self, _data: CreateDocumentEventData) -> anyhow::Result<()> {
        Ok(())
    }

    /// Lets an extension supply stored bytes for a document. `Some(bytes)` presumably
    /// provides the initial document state — confirm the expected encoding with the
    /// caller. Default: `None`.
    async fn on_load_document(
        &self,
        _data: LoadDocumentEventData,
    ) -> anyhow::Result<Option<Vec<u8>>> {
        Ok(None)
    }

    /// Receives a raw incoming message before it is handled. Default: no-op.
    async fn before_handle_message(
        &self,
        _data: BeforeHandleMessageEventData,
    ) -> anyhow::Result<()> {
        Ok(())
    }

    /// Receives an outgoing broadcast message and its exclusion list before sending.
    /// Default: no-op.
    async fn before_broadcast(&self, _data: BeforeBroadcastEventData) -> anyhow::Result<()> {
        Ok(())
    }

    /// Notified with a document update originating from a socket. Default: no-op.
    async fn on_change(&self, _data: ChangeEventData) -> anyhow::Result<()> {
        Ok(())
    }

    /// Receives a document's state bytes for persistence. Default: no-op.
    async fn on_store_document(&self, _data: StoreDocumentEventData) -> anyhow::Result<()> {
        Ok(())
    }

    /// Notified after a document has been stored. Default: no-op.
    async fn after_store_document(&self, _data: AfterStoreDocumentEventData) -> anyhow::Result<()> {
        Ok(())
    }

    /// Notified with an awareness update and the added/updated/removed client IDs.
    /// Default: no-op.
    async fn on_awareness_update(&self, _data: AwarenessUpdateEventData) -> anyhow::Result<()> {
        Ok(())
    }

    /// Receives a stateless payload before it is sent to a socket. Default: no-op.
    async fn before_send_stateless(
        &self,
        _data: BeforeSendStatelessEventData,
    ) -> anyhow::Result<()> {
        Ok(())
    }

    /// Receives awareness bytes before they are sent to a socket. Default: no-op.
    async fn before_send_awareness(
        &self,
        _data: BeforeSendAwarenessEventData,
    ) -> anyhow::Result<()> {
        Ok(())
    }

    /// Notified with the socket ID when a connection goes away. Default: no-op.
    async fn on_disconnect(&self, _data: DisconnectEventData) -> anyhow::Result<()> {
        Ok(())
    }

    /// Takes no event data; presumably called on server shutdown. Default: no-op.
    async fn on_destroy(&self) -> anyhow::Result<()> {
        Ok(())
    }

    /// Notified after a document has been unloaded. Default: no-op.
    async fn after_unload_document(
        &self,
        _data: AfterUnloadDocumentEventData,
    ) -> anyhow::Result<()> {
        Ok(())
    }

    /// Receives a stateless payload sent by a socket for a document. Default: no-op.
    async fn on_stateless(&self, _data: StatelessEventData) -> anyhow::Result<()> {
        Ok(())
    }
}
516
517#[derive(Clone)]
518pub struct ConnectEventData {
519 pub socket_id: String,
520 pub connection: ConnectionConfiguration,
521 pub request_headers: HashMap<String, String>,
522 pub request_path: String,
523}
524
/// Payload for the `on_authenticate` hook: the target document and the client's credential.
///
/// `Debug` is implemented by hand instead of derived. The derived version printed
/// `token` verbatim, so any `{:?}` log line would leak the credential. `doc_id`
/// still prints normally.
#[derive(Clone)]
pub struct AuthenticateEventData {
    /// Identifier of the document the client wants to access.
    pub doc_id: String,
    /// Credential presented by the client; treat as secret.
    pub token: String,
    // pub socket_id: String,
}

impl std::fmt::Debug for AuthenticateEventData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthenticateEventData")
            .field("doc_id", &self.doc_id)
            .field("token", &"<redacted>")
            .finish()
    }
}
531
/// Payload for the `connected` hook.
#[derive(Clone)]
pub struct ConnectedEventData {
    /// Identifier of the connected socket.
    pub socket_id: String,
    // pub context: HashMap<String, serde_json::Value>,
}
537
538#[derive(Clone)]
539pub struct CreateDocumentEventData {
540 pub document_name: String,
541 pub socket_id: String,
542 pub connection: ConnectionConfiguration,
543}
544
/// Payload for the `on_load_document` hook.
#[derive(Clone)]
pub struct LoadDocumentEventData {
    /// Name of the document to load.
    pub document_name: String,
    /// Identifier of the socket that requested the document.
    pub socket_id: String,
}
550
551#[derive(Clone)]
552pub struct BeforeHandleMessageEventData {
553 pub message: Vec<u8>,
554 pub socket_id: String,
555 pub connection: ConnectionConfiguration,
556}
557
/// Payload for the `before_broadcast` hook.
#[derive(Clone)]
pub struct BeforeBroadcastEventData {
    /// Document the broadcast belongs to.
    pub document_name: String,
    /// Identifiers to skip (presumably socket IDs — confirm at the call site).
    pub exclude: Vec<String>,
    /// Raw bytes of the message to broadcast.
    pub message: Vec<u8>,
}
564
/// Payload for the `on_change` hook.
#[derive(Clone)]
pub struct ChangeEventData {
    /// Document that changed.
    pub document_name: String,
    /// Identifier of the socket the change came from.
    pub socket_id: String,
    /// Raw bytes of the update.
    pub update: Vec<u8>,
}
571
/// Payload for the `on_store_document` hook.
#[derive(Clone)]
pub struct StoreDocumentEventData {
    /// Document to persist.
    pub document_name: String,
    /// Encoded document state to store.
    pub state: Vec<u8>,
}
577
/// Payload for the `after_store_document` hook.
#[derive(Clone)]
pub struct AfterStoreDocumentEventData {
    /// Document that was stored.
    pub document_name: String,
}
582
/// Payload for the `on_awareness_update` hook.
#[derive(Clone)]
pub struct AwarenessUpdateEventData {
    /// Document whose awareness changed.
    pub document_name: String,
    /// Raw awareness bytes.
    pub awareness: Vec<u8>,
    /// Client IDs that appeared.
    pub added: Vec<u64>,
    /// Client IDs whose state changed.
    pub updated: Vec<u64>,
    /// Client IDs that went away.
    pub removed: Vec<u64>,
}
591
/// Payload for the `before_send_stateless` hook.
#[derive(Clone)]
pub struct BeforeSendStatelessEventData {
    /// Document the payload relates to.
    pub document_name: String,
    /// Identifier of the destination socket.
    pub socket_id: String,
    /// Stateless payload text.
    pub payload: String,
}
598
/// Payload for the `before_send_awareness` hook.
#[derive(Clone)]
pub struct BeforeSendAwarenessEventData {
    /// Document the awareness data belongs to.
    pub document_name: String,
    /// Identifier of the destination socket.
    pub socket_id: String,
    /// Raw awareness bytes about to be sent.
    pub awareness: Vec<u8>,
}
605
/// Payload for the `on_disconnect` hook.
#[derive(Clone)]
pub struct DisconnectEventData {
    /// Identifier of the socket that disconnected.
    pub socket_id: String,
}
610
/// Payload for the `after_unload_document` hook.
#[derive(Clone)]
pub struct AfterUnloadDocumentEventData {
    /// Document that was unloaded.
    pub document_name: String,
}
615
/// Payload for the `on_stateless` hook.
#[derive(Clone)]
pub struct StatelessEventData {
    /// Document the payload relates to.
    pub document_name: String,
    /// Identifier of the sending socket.
    pub socket_id: String,
    /// Stateless payload text.
    pub payload: String,
}
622
/// Channel identifier (plain `String` alias).
pub type ChannelId = String;

/// Document name (plain `String` alias).
pub type DocumentName = String;

/// Socket identifier (plain `String` alias).
pub type SocketId = String;