1#![allow(dead_code)]
2use futures_util::sink::SinkExt;
3use futures_util::StreamExt;
4use std::future::Future;
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::sync::{Arc, Weak};
8use std::task::{Context, Poll};
9use tokio::spawn;
10use tokio::sync::{Mutex, RwLock};
11use tokio::task::JoinHandle;
12use yrs::encoding::read::Cursor;
13use yrs::sync::Awareness;
14use yrs::sync::{
15 DefaultProtocol, Error, Message, MessageReader, Protocol, SyncMessage,
16};
17use yrs::updates::decoder::{Decode, DecoderV1};
18use yrs::updates::encoder::{Encode, Encoder, EncoderV1};
19use yrs::Update;
20use std::time::Instant;
21
22#[derive(Debug)]
26pub struct Connection<Sink, Stream> {
27 processing_loop: JoinHandle<Result<(), Error>>,
28 awareness: Arc<RwLock<Awareness>>,
29 inbox: Arc<Mutex<Sink>>,
30 sync_tracker: Arc<RwLock<SyncTracker>>, _stream: PhantomData<Stream>,
32}
33
34impl<Sink, Stream, E> Connection<Sink, Stream>
35where
36 Sink: SinkExt<Vec<u8>, Error = E> + Send + Sync + Unpin + 'static,
37 E: Into<Error> + Send + Sync,
38{
39 pub async fn send(
40 &self,
41 msg: Vec<u8>,
42 ) -> Result<(), Error> {
43 let mut inbox = self.inbox.lock().await;
44 match inbox.send(msg).await {
45 Ok(_) => Ok(()),
46 Err(err) => Err(err.into()),
47 }
48 }
49
50 pub async fn close(self) -> Result<(), E> {
51 let mut inbox = self.inbox.lock().await;
52 inbox.close().await
53 }
54
55 pub fn sink(&self) -> Weak<Mutex<Sink>> {
56 Arc::downgrade(&self.inbox)
57 }
58}
59
60impl<Sink, Stream, E> Connection<Sink, Stream>
61where
62 Stream:
63 StreamExt<Item = Result<Vec<u8>, E>> + Send + Sync + Unpin + 'static,
64 Sink: SinkExt<Vec<u8>, Error = E> + Send + Sync + Unpin + 'static,
65 E: Into<Error> + Send + Sync,
66{
67 pub fn new_with_sync_detection(
69 awareness: Arc<RwLock<Awareness>>,
70 sink: Sink,
71 stream: Stream,
72 event_sender: Option<SyncEventSender>,
73 ) -> Self {
74 let sync_tracker =
75 Arc::new(RwLock::new(SyncTracker::new(event_sender)));
76 Self::with_protocol_and_sync(
77 awareness,
78 sink,
79 stream,
80 DefaultProtocol,
81 sync_tracker,
82 )
83 }
84 pub fn with_protocol_and_sync<P>(
86 awareness: Arc<RwLock<Awareness>>,
87 sink: Sink,
88 mut stream: Stream,
89 protocol: P,
90 sync_tracker: Arc<RwLock<SyncTracker>>,
91 ) -> Self
92 where
93 P: Protocol + Send + Sync + 'static,
94 {
95 let sink = Arc::new(Mutex::new(sink));
96 let inbox = sink.clone();
97 let loop_sink = Arc::downgrade(&sink);
98 let loop_awareness = Arc::downgrade(&awareness);
99 let loop_sync_tracker = Arc::downgrade(&sync_tracker);
100
101 let processing_loop: JoinHandle<Result<(), Error>> =
102 spawn(async move {
103 let payload = {
105 let awareness = loop_awareness.upgrade().unwrap();
106 let mut encoder = EncoderV1::new();
107 let awareness = awareness.read().await;
108 protocol.start(&awareness, &mut encoder)?;
109 encoder.to_vec()
110 };
111
112 if !payload.is_empty() {
113 if let Some(tracker) = loop_sync_tracker.upgrade() {
115 tracker.read().await.on_step1_sent();
116 }
117
118 if let Some(sink) = loop_sink.upgrade() {
119 let mut s = sink.lock().await;
120 if let Err(e) = s.send(payload).await {
121 return Err(e.into());
122 }
123 } else {
124 return Ok(());
125 }
126 }
127
128 while let Some(input) = stream.next().await {
130 match input {
131 Ok(data) => {
132 if let Some(mut sink) = loop_sink.upgrade() {
133 if let Some(awareness) =
134 loop_awareness.upgrade()
135 {
136 if let Some(sync_tracker) =
137 loop_sync_tracker.upgrade()
138 {
139 match Self::process_with_sync_detection(
140 &protocol,
141 &awareness,
142 &mut sink,
143 &sync_tracker,
144 data,
145 )
146 .await
147 {
148 Ok(()) => { },
149 Err(e) => return Err(e),
150 }
151 }
152 } else {
153 return Ok(());
154 }
155 } else {
156 return Ok(());
157 }
158 },
159 Err(e) => return Err(e.into()),
160 }
161 }
162
163 Ok(())
164 });
165
166 Connection {
167 processing_loop,
168 awareness,
169 inbox,
170 sync_tracker,
171 _stream: PhantomData,
172 }
173 }
174 async fn process_with_sync_detection<P: Protocol>(
176 protocol: &P,
177 awareness: &Arc<RwLock<Awareness>>,
178 sink: &mut Arc<Mutex<Sink>>,
179 sync_tracker: &Arc<RwLock<SyncTracker>>,
180 input: Vec<u8>,
181 ) -> Result<(), Error> {
182 let mut decoder = DecoderV1::new(Cursor::new(&input));
183 let reader = MessageReader::new(&mut decoder);
184
185 for r in reader {
186 let msg = r?;
187
188 Self::track_sync_message(&msg, sync_tracker).await;
190
191 if let Some(reply) = handle_msg(protocol, awareness, msg).await? {
192 let mut sender = sink.lock().await;
193 if let Err(e) = sender.send(reply.encode_v1()).await {
194 tracing::error!("连接发送回复失败");
195 return Err(e.into());
196 }
197 }
198 }
199
200 Ok(())
201 }
202 async fn track_sync_message(
204 msg: &Message,
205 sync_tracker: &Arc<RwLock<SyncTracker>>,
206 ) {
207 if let Message::Sync(sync_msg) = msg {
208 match sync_msg {
209 SyncMessage::SyncStep2(_) => {
210 let mut tracker = sync_tracker.write().await;
212 tracker.on_step2_received();
213 },
214 SyncMessage::Update(_) => {
215 let tracker = sync_tracker.read().await;
217 tracker.on_update_received();
218 },
219 _ => {},
220 }
221 }
222 }
223
224 pub fn sync_tracker(&self) -> &Arc<RwLock<SyncTracker>> {
226 &self.sync_tracker
227 }
228
229 pub async fn wait_for_initial_sync(
231 &self,
232 timeout_ms: u64,
233 ) -> bool {
234 let start_time = Instant::now();
235 let timeout_duration = tokio::time::Duration::from_millis(timeout_ms);
236
237 loop {
238 {
239 let tracker = self.sync_tracker.read().await;
240 if tracker.is_initial_sync_completed() {
241 return true;
242 }
243 }
244
245 if start_time.elapsed() >= timeout_duration {
246 break;
247 }
248
249 tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
250 }
251
252 false
253 }
254
255 pub async fn get_protocol_sync_state(&self) -> ProtocolSyncState {
257 self.sync_tracker.read().await.get_protocol_state()
258 }
259 pub fn awareness(&self) -> &Arc<RwLock<Awareness>> {
261 &self.awareness
262 }
263}
264
265impl<Sink, Stream> Unpin for Connection<Sink, Stream> {}
266
267impl<Sink, Stream> Future for Connection<Sink, Stream> {
268 type Output = Result<(), Error>;
269
270 fn poll(
271 mut self: Pin<&mut Self>,
272 cx: &mut Context<'_>,
273 ) -> Poll<Self::Output> {
274 match Pin::new(&mut self.processing_loop).poll(cx) {
275 Poll::Pending => Poll::Pending,
276 Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
277 Poll::Ready(Ok(r)) => Poll::Ready(r),
278 }
279 }
280}
281
282pub async fn handle_msg<P: Protocol>(
283 protocol: &P,
284 a: &Arc<RwLock<Awareness>>,
285 msg: Message,
286) -> Result<Option<Message>, Error> {
287 match msg {
288 Message::Sync(msg) => match msg {
289 SyncMessage::SyncStep1(sv) => {
290 let awareness = a.read().await;
291 protocol.handle_sync_step1(&awareness, sv)
292 },
293 SyncMessage::SyncStep2(update) => {
294 let mut awareness = a.write().await;
295 protocol.handle_sync_step2(
296 &mut awareness,
297 Update::decode_v1(&update)?,
298 )
299 },
300 SyncMessage::Update(update) => {
301 let mut awareness = a.write().await;
302 protocol
303 .handle_update(&mut awareness, Update::decode_v1(&update)?)
304 },
305 },
306 Message::Auth(reason) => {
307 let awareness = a.read().await;
308 protocol.handle_auth(&awareness, reason)
309 },
310 Message::AwarenessQuery => {
311 let awareness = a.read().await;
312 protocol.handle_awareness_query(&awareness)
313 },
314 Message::Awareness(update) => {
315 let mut awareness = a.write().await;
316 protocol.handle_awareness_update(&mut awareness, update)
317 },
318 Message::Custom(tag, data) => {
319 let mut awareness = a.write().await;
320 protocol.missing_handle(&mut awareness, tag, data)
321 },
322 }
323}
324
325use crate::types::{ConnectionError, ProtocolSyncState, SyncEvent, SyncEventSender};
326use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
327
328#[derive(Debug)]
330pub struct SyncTracker {
331 protocol_state: AtomicU8, has_data: AtomicBool,
333 start_time: Option<Instant>,
334 step2_time: Option<Instant>,
335 event_sender: Option<SyncEventSender>,
336}
337
338impl SyncTracker {
339 pub fn new(event_sender: Option<SyncEventSender>) -> Self {
340 Self {
341 protocol_state: AtomicU8::new(0),
342 has_data: AtomicBool::new(false),
343 start_time: Some(Instant::now()),
344 step2_time: None,
345 event_sender,
346 }
347 }
348 pub fn on_step1_sent(&self) {
349 let prev = self.protocol_state.swap(1, Ordering::Relaxed);
350 if prev == 0 {
351 tracing::debug!("📡 协议: SyncStep1 已发送");
352 self.emit_event(SyncEvent::ProtocolStateChanged(
353 ProtocolSyncState::Step1Sent,
354 ));
355 }
356 }
357 pub fn on_step2_received(&mut self) -> bool {
358 let prev = self.protocol_state.swap(2, Ordering::Relaxed);
359 if prev == 1 {
360 self.step2_time = Some(Instant::now());
362
363 let elapsed_ms = if let (Some(start), Some(step2)) =
364 (self.start_time, self.step2_time)
365 {
366 step2.duration_since(start).as_millis() as u64
367 } else {
368 0
369 };
370
371 let has_data = self.has_data.load(Ordering::Relaxed);
372
373 tracing::info!(
374 "🎉 协议同步完成: Step1->Step2, 耗时 {}ms, 有数据: {}",
375 elapsed_ms,
376 has_data
377 );
378
379 self.emit_event(SyncEvent::ProtocolStateChanged(
380 ProtocolSyncState::Step2Received,
381 ));
382 self.emit_event(SyncEvent::InitialSyncCompleted {
383 has_data,
384 elapsed_ms,
385 });
386
387 return true; }
389 false
390 }
391 pub fn on_update_received(&self) {
392 let prev_state = self.protocol_state.load(Ordering::Relaxed);
393
394 self.has_data.store(true, Ordering::Relaxed);
396
397 if prev_state == 2 {
399 self.protocol_state.store(3, Ordering::Relaxed);
400 self.emit_event(SyncEvent::ProtocolStateChanged(
401 ProtocolSyncState::Updating,
402 ));
403 }
404
405 self.emit_event(SyncEvent::DataReceived);
406 }
407
408 pub fn is_initial_sync_completed(&self) -> bool {
409 self.protocol_state.load(Ordering::Relaxed) >= 2
410 }
411
412 pub fn get_protocol_state(&self) -> ProtocolSyncState {
413 match self.protocol_state.load(Ordering::Relaxed) {
414 0 => ProtocolSyncState::NotStarted,
415 1 => ProtocolSyncState::Step1Sent,
416 2 => ProtocolSyncState::Step2Received,
417 3 => ProtocolSyncState::Updating,
418 _ => ProtocolSyncState::NotStarted,
419 }
420 }
421 fn emit_event(
422 &self,
423 event: SyncEvent,
424 ) {
425 if let Some(sender) = &self.event_sender {
426 let _ = sender.send(event);
427 }
428 }
429 pub fn reset(&mut self) {
431 self.protocol_state.store(0, Ordering::Relaxed);
432 self.has_data.store(false, Ordering::Relaxed);
433 self.start_time = Some(Instant::now());
434 self.step2_time = None;
435 }
436
437 pub fn on_connection_failed(
439 &self,
440 error: &ConnectionError,
441 ) {
442 tracing::error!("🔌 连接失败: {}", error);
443 self.emit_event(SyncEvent::ConnectionFailed(error.clone()));
444 }
445}