1#![allow(dead_code)]
2use futures_util::sink::SinkExt;
3use futures_util::StreamExt;
4use std::future::Future;
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::sync::{Arc, Weak};
8use std::task::{Context, Poll};
9use tokio::spawn;
10use tokio::sync::{Mutex, RwLock};
11use tokio::task::JoinHandle;
12use yrs::encoding::read::Cursor;
13use yrs::sync::Awareness;
14use yrs::sync::{
15 DefaultProtocol, Error, Message, MessageReader, Protocol, SyncMessage,
16};
17use yrs::updates::decoder::{Decode, DecoderV1};
18use yrs::updates::encoder::{Encode, Encoder, EncoderV1};
19use yrs::Update;
20use std::time::Instant;
21
/// A y-sync connection: owns the outgoing sink plus a spawned background task
/// that sends the protocol's opening message and processes incoming frames.
///
/// `Connection` implements `Future`; awaiting it waits for that background
/// task and yields its result.
#[derive(Debug)]
pub struct Connection<Sink, Stream> {
    /// Handle to the spawned task that drives the incoming stream.
    processing_loop: JoinHandle<Result<(), Error>>,
    /// Shared document/awareness state; the background task holds only a weak
    /// reference to it.
    awareness: Arc<RwLock<Awareness>>,
    /// Outgoing sink, shared (behind a mutex) with the background task, which
    /// holds only a weak reference and uses it to send protocol replies.
    inbox: Arc<Mutex<Sink>>,
    /// Tracks progress of the initial sync handshake.
    sync_tracker: Arc<RwLock<SyncTracker>>,
    /// The stream itself is moved into the background task; this only keeps
    /// the type parameter in use.
    _stream: PhantomData<Stream>,
}
33
34impl<Sink, Stream, E> Connection<Sink, Stream>
35where
36 Sink: SinkExt<Vec<u8>, Error = E> + Send + Sync + Unpin + 'static,
37 E: Into<Error> + Send + Sync,
38{
39 pub async fn send(
40 &self,
41 msg: Vec<u8>,
42 ) -> Result<(), Error> {
43 let mut inbox = self.inbox.lock().await;
44 match inbox.send(msg).await {
45 Ok(_) => Ok(()),
46 Err(err) => Err(err.into()),
47 }
48 }
49
50 pub async fn close(self) -> Result<(), E> {
51 let mut inbox = self.inbox.lock().await;
52 inbox.close().await
53 }
54
55 pub fn sink(&self) -> Weak<Mutex<Sink>> {
56 Arc::downgrade(&self.inbox)
57 }
58}
59
60impl<Sink, Stream, E> Connection<Sink, Stream>
61where
62 Stream:
63 StreamExt<Item = Result<Vec<u8>, E>> + Send + Sync + Unpin + 'static,
64 Sink: SinkExt<Vec<u8>, Error = E> + Send + Sync + Unpin + 'static,
65 E: Into<Error> + Send + Sync,
66{
67 pub fn new_with_sync_detection(
69 awareness: Arc<RwLock<Awareness>>,
70 sink: Sink,
71 stream: Stream,
72 event_sender: Option<SyncEventSender>,
73 ) -> Self {
74 let sync_tracker =
75 Arc::new(RwLock::new(SyncTracker::new(event_sender)));
76 Self::with_protocol_and_sync(
77 awareness,
78 sink,
79 stream,
80 DefaultProtocol,
81 sync_tracker,
82 )
83 }
84 pub fn with_protocol_and_sync<P>(
86 awareness: Arc<RwLock<Awareness>>,
87 sink: Sink,
88 mut stream: Stream,
89 protocol: P,
90 sync_tracker: Arc<RwLock<SyncTracker>>,
91 ) -> Self
92 where
93 P: Protocol + Send + Sync + 'static,
94 {
95 let sink = Arc::new(Mutex::new(sink));
96 let inbox = sink.clone();
97 let loop_sink = Arc::downgrade(&sink);
98 let loop_awareness = Arc::downgrade(&awareness);
99 let loop_sync_tracker = Arc::downgrade(&sync_tracker);
100
101 let processing_loop: JoinHandle<Result<(), Error>> =
102 spawn(async move {
103 let payload = {
105 let awareness = loop_awareness.upgrade().unwrap();
106 let mut encoder = EncoderV1::new();
107 let awareness = awareness.read().await;
108 protocol.start(&awareness, &mut encoder)?;
109 encoder.to_vec()
110 };
111
112 if !payload.is_empty() {
113 if let Some(tracker) = loop_sync_tracker.upgrade() {
115 tracker.read().await.on_step1_sent();
116 }
117
118 if let Some(sink) = loop_sink.upgrade() {
119 let mut s = sink.lock().await;
120 if let Err(e) = s.send(payload).await {
121 return Err(e.into());
122 }
123 } else {
124 return Ok(());
125 }
126 }
127
128 while let Some(input) = stream.next().await {
130 match input {
131 Ok(data) => {
132 if let Some(mut sink) = loop_sink.upgrade() {
133 if let Some(awareness) =
134 loop_awareness.upgrade()
135 {
136 if let Some(sync_tracker) =
137 loop_sync_tracker.upgrade()
138 {
139 match Self::process_with_sync_detection(
140 &protocol,
141 &awareness,
142 &mut sink,
143 &sync_tracker,
144 data,
145 )
146 .await
147 {
148 Ok(()) => { },
149 Err(e) => return Err(e),
150 }
151 }
152 } else {
153 return Ok(());
154 }
155 } else {
156 return Ok(());
157 }
158 },
159 Err(e) => return Err(e.into()),
160 }
161 }
162
163 Ok(())
164 });
165
166 Connection {
167 processing_loop,
168 awareness,
169 inbox,
170 sync_tracker,
171 _stream: PhantomData::default(),
172 }
173 }
174 async fn process_with_sync_detection<P: Protocol>(
176 protocol: &P,
177 awareness: &Arc<RwLock<Awareness>>,
178 sink: &mut Arc<Mutex<Sink>>,
179 sync_tracker: &Arc<RwLock<SyncTracker>>,
180 input: Vec<u8>,
181 ) -> Result<(), Error> {
182 let mut decoder = DecoderV1::new(Cursor::new(&input));
183 let reader = MessageReader::new(&mut decoder);
184
185 for r in reader {
186 let msg = r?;
187
188 Self::track_sync_message(&msg, sync_tracker).await;
190
191 if let Some(reply) = handle_msg(protocol, &awareness, msg).await? {
192 let mut sender = sink.lock().await;
193 if let Err(e) = sender.send(reply.encode_v1()).await {
194 tracing::error!("连接发送回复失败");
195 return Err(e.into());
196 }
197 }
198 }
199
200 Ok(())
201 }
202 async fn track_sync_message(
204 msg: &Message,
205 sync_tracker: &Arc<RwLock<SyncTracker>>,
206 ) {
207 match msg {
208 Message::Sync(sync_msg) => {
209 match sync_msg {
210 SyncMessage::SyncStep2(_) => {
211 let mut tracker = sync_tracker.write().await;
213 tracker.on_step2_received();
214 },
215 SyncMessage::Update(_) => {
216 let tracker = sync_tracker.read().await;
218 tracker.on_update_received();
219 },
220 _ => {},
221 }
222 },
223 _ => {},
224 }
225 }
226
227 pub fn sync_tracker(&self) -> &Arc<RwLock<SyncTracker>> {
229 &self.sync_tracker
230 }
231
232 pub async fn wait_for_initial_sync(
234 &self,
235 timeout_ms: u64,
236 ) -> bool {
237 let start_time = Instant::now();
238 let timeout_duration = tokio::time::Duration::from_millis(timeout_ms);
239
240 loop {
241 {
242 let tracker = self.sync_tracker.read().await;
243 if tracker.is_initial_sync_completed() {
244 return true;
245 }
246 }
247
248 if start_time.elapsed() >= timeout_duration {
249 break;
250 }
251
252 tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
253 }
254
255 false
256 }
257
258 pub async fn get_protocol_sync_state(&self) -> ProtocolSyncState {
260 self.sync_tracker.read().await.get_protocol_state()
261 }
262 pub fn awareness(&self) -> &Arc<RwLock<Awareness>> {
264 &self.awareness
265 }
266}
267
// `Connection` never needs to stay pinned. The stream itself lives in the
// spawned task (only `PhantomData<Stream>` is stored here), and `poll` re-pins
// the `JoinHandle` with `Pin::new`. The auto-impl via `PhantomData<Stream>`
// would require `Stream: Unpin`; implementing `Unpin` unconditionally instead
// lets `poll` reach `self.processing_loop` mutably through `Pin<&mut Self>`
// for any `Stream`.
impl<Sink, Stream> Unpin for Connection<Sink, Stream> {}
269
270impl<Sink, Stream> Future for Connection<Sink, Stream> {
271 type Output = Result<(), Error>;
272
273 fn poll(
274 mut self: Pin<&mut Self>,
275 cx: &mut Context<'_>,
276 ) -> Poll<Self::Output> {
277 match Pin::new(&mut self.processing_loop).poll(cx) {
278 Poll::Pending => Poll::Pending,
279 Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
280 Poll::Ready(Ok(r)) => Poll::Ready(r),
281 }
282 }
283}
284
285pub async fn handle_msg<P: Protocol>(
286 protocol: &P,
287 a: &Arc<RwLock<Awareness>>,
288 msg: Message,
289) -> Result<Option<Message>, Error> {
290 match msg {
291 Message::Sync(msg) => match msg {
292 SyncMessage::SyncStep1(sv) => {
293 let awareness = a.read().await;
294 protocol.handle_sync_step1(&awareness, sv)
295 },
296 SyncMessage::SyncStep2(update) => {
297 let mut awareness = a.write().await;
298 protocol.handle_sync_step2(
299 &mut awareness,
300 Update::decode_v1(&update)?,
301 )
302 },
303 SyncMessage::Update(update) => {
304 let mut awareness = a.write().await;
305 protocol
306 .handle_update(&mut awareness, Update::decode_v1(&update)?)
307 },
308 },
309 Message::Auth(reason) => {
310 let awareness = a.read().await;
311 protocol.handle_auth(&awareness, reason)
312 },
313 Message::AwarenessQuery => {
314 let awareness = a.read().await;
315 protocol.handle_awareness_query(&awareness)
316 },
317 Message::Awareness(update) => {
318 let mut awareness = a.write().await;
319 protocol.handle_awareness_update(&mut awareness, update)
320 },
321 Message::Custom(tag, data) => {
322 let mut awareness = a.write().await;
323 protocol.missing_handle(&mut awareness, tag, data)
324 },
325 }
326}
327
328use crate::types::{ConnectionError, ProtocolSyncState, SyncEvent, SyncEventSender};
329use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
330
/// Tracks progress of the initial sync handshake for one connection and,
/// when an event sender is configured, reports state changes through it.
#[derive(Debug)]
pub struct SyncTracker {
    /// Handshake state: 0 = not started, 1 = Step1 sent, 2 = Step2 received,
    /// 3 = updating (mapped to `ProtocolSyncState` in `get_protocol_state`).
    protocol_state: AtomicU8,
    /// Set once any `Update` message has been recorded; cleared by `reset`.
    has_data: AtomicBool,
    /// When tracking (re)started; set by `new` and `reset`.
    start_time: Option<Instant>,
    /// When the Step1 -> Step2 transition was recorded, if it has been.
    step2_time: Option<Instant>,
    /// Destination for state-change events; `None` disables events.
    event_sender: Option<SyncEventSender>,
}
340
341impl SyncTracker {
342 pub fn new(event_sender: Option<SyncEventSender>) -> Self {
343 Self {
344 protocol_state: AtomicU8::new(0),
345 has_data: AtomicBool::new(false),
346 start_time: Some(Instant::now()),
347 step2_time: None,
348 event_sender,
349 }
350 }
351 pub fn on_step1_sent(&self) {
352 let prev = self.protocol_state.swap(1, Ordering::Relaxed);
353 if prev == 0 {
354 tracing::debug!("📡 协议: SyncStep1 已发送");
355 self.emit_event(SyncEvent::ProtocolStateChanged(
356 ProtocolSyncState::Step1Sent,
357 ));
358 }
359 }
360 pub fn on_step2_received(&mut self) -> bool {
361 let prev = self.protocol_state.swap(2, Ordering::Relaxed);
362 if prev == 1 {
363 self.step2_time = Some(Instant::now());
365
366 let elapsed_ms = if let (Some(start), Some(step2)) =
367 (self.start_time, self.step2_time)
368 {
369 step2.duration_since(start).as_millis() as u64
370 } else {
371 0
372 };
373
374 let has_data = self.has_data.load(Ordering::Relaxed);
375
376 tracing::info!(
377 "🎉 协议同步完成: Step1->Step2, 耗时 {}ms, 有数据: {}",
378 elapsed_ms,
379 has_data
380 );
381
382 self.emit_event(SyncEvent::ProtocolStateChanged(
383 ProtocolSyncState::Step2Received,
384 ));
385 self.emit_event(SyncEvent::InitialSyncCompleted {
386 has_data,
387 elapsed_ms,
388 });
389
390 return true; }
392 false
393 }
394 pub fn on_update_received(&self) {
395 let prev_state = self.protocol_state.load(Ordering::Relaxed);
396
397 self.has_data.store(true, Ordering::Relaxed);
399
400 if prev_state == 2 {
402 self.protocol_state.store(3, Ordering::Relaxed);
403 self.emit_event(SyncEvent::ProtocolStateChanged(
404 ProtocolSyncState::Updating,
405 ));
406 }
407
408 self.emit_event(SyncEvent::DataReceived);
409 }
410
411 pub fn is_initial_sync_completed(&self) -> bool {
412 self.protocol_state.load(Ordering::Relaxed) >= 2
413 }
414
415 pub fn get_protocol_state(&self) -> ProtocolSyncState {
416 match self.protocol_state.load(Ordering::Relaxed) {
417 0 => ProtocolSyncState::NotStarted,
418 1 => ProtocolSyncState::Step1Sent,
419 2 => ProtocolSyncState::Step2Received,
420 3 => ProtocolSyncState::Updating,
421 _ => ProtocolSyncState::NotStarted,
422 }
423 }
424 fn emit_event(
425 &self,
426 event: SyncEvent,
427 ) {
428 if let Some(sender) = &self.event_sender {
429 let _ = sender.send(event);
430 }
431 }
432 pub fn reset(&mut self) {
434 self.protocol_state.store(0, Ordering::Relaxed);
435 self.has_data.store(false, Ordering::Relaxed);
436 self.start_time = Some(Instant::now());
437 self.step2_time = None;
438 }
439
440 pub fn on_connection_failed(
442 &self,
443 error: &ConnectionError,
444 ) {
445 tracing::error!("🔌 连接失败: {}", error);
446 self.emit_event(SyncEvent::ConnectionFailed(error.clone()));
447 }
448}