mf_collab_client/
conn.rs

1#![allow(dead_code)]
2use futures_util::sink::SinkExt;
3use futures_util::StreamExt;
4use std::future::Future;
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::sync::{Arc, Weak};
8use std::task::{Context, Poll};
9use tokio::spawn;
10use tokio::sync::{Mutex, RwLock};
11use tokio::task::JoinHandle;
12use yrs::encoding::read::Cursor;
13use yrs::sync::Awareness;
14use yrs::sync::{
15    DefaultProtocol, Error, Message, MessageReader, Protocol, SyncMessage,
16};
17use yrs::updates::decoder::{Decode, DecoderV1};
18use yrs::updates::encoder::{Encode, Encoder, EncoderV1};
19use yrs::Update;
20use std::time::Instant;
21
/// Connection handler that implements the Yjs/Yrs awareness and update exchange protocol
/// over a message stream.
///
/// The connection implements `Future`, so it can be awaited to let callers find out whether
/// the underlying websocket connection finished gracefully or ended abruptly.
#[derive(Debug)]
pub struct Connection<Sink, Stream> {
    /// Background task running the opening handshake and the incoming-message loop.
    processing_loop: JoinHandle<Result<(), Error>>,
    /// Shared awareness state that incoming sync/awareness messages are applied to.
    awareness: Arc<RwLock<Awareness>>,
    /// Outgoing sink, shared (behind a mutex) with the processing loop.
    inbox: Arc<Mutex<Sink>>,
    /// Sync-progress tracker (newly added); updated by the processing loop as Step1,
    /// Step2 and Update messages flow.
    sync_tracker: Arc<RwLock<SyncTracker>>,
    /// Records the stream type; the stream itself is owned by the processing loop.
    _stream: PhantomData<Stream>,
}
33
34impl<Sink, Stream, E> Connection<Sink, Stream>
35where
36    Sink: SinkExt<Vec<u8>, Error = E> + Send + Sync + Unpin + 'static,
37    E: Into<Error> + Send + Sync,
38{
39    pub async fn send(
40        &self,
41        msg: Vec<u8>,
42    ) -> Result<(), Error> {
43        let mut inbox = self.inbox.lock().await;
44        match inbox.send(msg).await {
45            Ok(_) => Ok(()),
46            Err(err) => Err(err.into()),
47        }
48    }
49
50    pub async fn close(self) -> Result<(), E> {
51        let mut inbox = self.inbox.lock().await;
52        inbox.close().await
53    }
54
55    pub fn sink(&self) -> Weak<Mutex<Sink>> {
56        Arc::downgrade(&self.inbox)
57    }
58}
59
60impl<Sink, Stream, E> Connection<Sink, Stream>
61where
62    Stream:
63        StreamExt<Item = Result<Vec<u8>, E>> + Send + Sync + Unpin + 'static,
64    Sink: SinkExt<Vec<u8>, Error = E> + Send + Sync + Unpin + 'static,
65    E: Into<Error> + Send + Sync,
66{
67    /// 创建带同步检测的连接
68    pub fn new_with_sync_detection(
69        awareness: Arc<RwLock<Awareness>>,
70        sink: Sink,
71        stream: Stream,
72        event_sender: Option<SyncEventSender>,
73    ) -> Self {
74        let sync_tracker =
75            Arc::new(RwLock::new(SyncTracker::new(event_sender)));
76        Self::with_protocol_and_sync(
77            awareness,
78            sink,
79            stream,
80            DefaultProtocol,
81            sync_tracker,
82        )
83    }
84    /// 创建带协议和同步检测的连接
85    pub fn with_protocol_and_sync<P>(
86        awareness: Arc<RwLock<Awareness>>,
87        sink: Sink,
88        mut stream: Stream,
89        protocol: P,
90        sync_tracker: Arc<RwLock<SyncTracker>>,
91    ) -> Self
92    where
93        P: Protocol + Send + Sync + 'static,
94    {
95        let sink = Arc::new(Mutex::new(sink));
96        let inbox = sink.clone();
97        let loop_sink = Arc::downgrade(&sink);
98        let loop_awareness = Arc::downgrade(&awareness);
99        let loop_sync_tracker = Arc::downgrade(&sync_tracker);
100
101        let processing_loop: JoinHandle<Result<(), Error>> =
102            spawn(async move {
103                // 发送 SyncStep1
104                let payload = {
105                    let awareness = loop_awareness.upgrade().unwrap();
106                    let mut encoder = EncoderV1::new();
107                    let awareness = awareness.read().await;
108                    protocol.start(&awareness, &mut encoder)?;
109                    encoder.to_vec()
110                };
111
112                if !payload.is_empty() {
113                    // 🔥 标记 Step1 已发送
114                    if let Some(tracker) = loop_sync_tracker.upgrade() {
115                        tracker.read().await.on_step1_sent();
116                    }
117
118                    if let Some(sink) = loop_sink.upgrade() {
119                        let mut s = sink.lock().await;
120                        if let Err(e) = s.send(payload).await {
121                            return Err(e.into());
122                        }
123                    } else {
124                        return Ok(());
125                    }
126                }
127
128                // 消息处理循环
129                while let Some(input) = stream.next().await {
130                    match input {
131                        Ok(data) => {
132                            if let Some(mut sink) = loop_sink.upgrade() {
133                                if let Some(awareness) =
134                                    loop_awareness.upgrade()
135                                {
136                                    if let Some(sync_tracker) =
137                                        loop_sync_tracker.upgrade()
138                                    {
139                                        match Self::process_with_sync_detection(
140                                            &protocol,
141                                            &awareness,
142                                            &mut sink,
143                                            &sync_tracker,
144                                            data,
145                                        )
146                                        .await
147                                        {
148                                            Ok(()) => { /* continue */ },
149                                            Err(e) => return Err(e),
150                                        }
151                                    }
152                                } else {
153                                    return Ok(());
154                                }
155                            } else {
156                                return Ok(());
157                            }
158                        },
159                        Err(e) => return Err(e.into()),
160                    }
161                }
162
163                Ok(())
164            });
165
166        Connection {
167            processing_loop,
168            awareness,
169            inbox,
170            sync_tracker,
171            _stream: PhantomData,
172        }
173    }
174    /// 带同步检测的消息处理
175    async fn process_with_sync_detection<P: Protocol>(
176        protocol: &P,
177        awareness: &Arc<RwLock<Awareness>>,
178        sink: &mut Arc<Mutex<Sink>>,
179        sync_tracker: &Arc<RwLock<SyncTracker>>,
180        input: Vec<u8>,
181    ) -> Result<(), Error> {
182        let mut decoder = DecoderV1::new(Cursor::new(&input));
183        let reader = MessageReader::new(&mut decoder);
184
185        for r in reader {
186            let msg = r?;
187
188            // 🔥 在处理消息前检测同步状态
189            Self::track_sync_message(&msg, sync_tracker).await;
190
191            if let Some(reply) = handle_msg(protocol, awareness, msg).await? {
192                let mut sender = sink.lock().await;
193                if let Err(e) = sender.send(reply.encode_v1()).await {
194                    tracing::error!("连接发送回复失败");
195                    return Err(e.into());
196                }
197            }
198        }
199
200        Ok(())
201    }
202    /// 跟踪同步消息
203    async fn track_sync_message(
204        msg: &Message,
205        sync_tracker: &Arc<RwLock<SyncTracker>>,
206    ) {
207        if let Message::Sync(sync_msg) = msg {
208            match sync_msg {
209                SyncMessage::SyncStep2(_) => {
210                    // 🎉 收到 Step2,首次同步完成!
211                    let mut tracker = sync_tracker.write().await;
212                    tracker.on_step2_received();
213                },
214                SyncMessage::Update(_) => {
215                    // 收到数据更新
216                    let tracker = sync_tracker.read().await;
217                    tracker.on_update_received();
218                },
219                _ => {},
220            }
221        }
222    }
223
224    /// 获取同步跟踪器
225    pub fn sync_tracker(&self) -> &Arc<RwLock<SyncTracker>> {
226        &self.sync_tracker
227    }
228
229    /// 等待初始同步完成
230    pub async fn wait_for_initial_sync(
231        &self,
232        timeout_ms: u64,
233    ) -> bool {
234        let start_time = Instant::now();
235        let timeout_duration = tokio::time::Duration::from_millis(timeout_ms);
236
237        loop {
238            {
239                let tracker = self.sync_tracker.read().await;
240                if tracker.is_initial_sync_completed() {
241                    return true;
242                }
243            }
244
245            if start_time.elapsed() >= timeout_duration {
246                break;
247            }
248
249            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
250        }
251
252        false
253    }
254
255    /// 获取当前协议同步状态
256    pub async fn get_protocol_sync_state(&self) -> ProtocolSyncState {
257        self.sync_tracker.read().await.get_protocol_state()
258    }
259    /// Returns an underlying [Awareness] structure, that contains client state of that connection.
260    pub fn awareness(&self) -> &Arc<RwLock<Awareness>> {
261        &self.awareness
262    }
263}
264
265impl<Sink, Stream> Unpin for Connection<Sink, Stream> {}
266
267impl<Sink, Stream> Future for Connection<Sink, Stream> {
268    type Output = Result<(), Error>;
269
270    fn poll(
271        mut self: Pin<&mut Self>,
272        cx: &mut Context<'_>,
273    ) -> Poll<Self::Output> {
274        match Pin::new(&mut self.processing_loop).poll(cx) {
275            Poll::Pending => Poll::Pending,
276            Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
277            Poll::Ready(Ok(r)) => Poll::Ready(r),
278        }
279    }
280}
281
282pub async fn handle_msg<P: Protocol>(
283    protocol: &P,
284    a: &Arc<RwLock<Awareness>>,
285    msg: Message,
286) -> Result<Option<Message>, Error> {
287    match msg {
288        Message::Sync(msg) => match msg {
289            SyncMessage::SyncStep1(sv) => {
290                let awareness = a.read().await;
291                protocol.handle_sync_step1(&awareness, sv)
292            },
293            SyncMessage::SyncStep2(update) => {
294                let mut awareness = a.write().await;
295                protocol.handle_sync_step2(
296                    &mut awareness,
297                    Update::decode_v1(&update)?,
298                )
299            },
300            SyncMessage::Update(update) => {
301                let mut awareness = a.write().await;
302                protocol
303                    .handle_update(&mut awareness, Update::decode_v1(&update)?)
304            },
305        },
306        Message::Auth(reason) => {
307            let awareness = a.read().await;
308            protocol.handle_auth(&awareness, reason)
309        },
310        Message::AwarenessQuery => {
311            let awareness = a.read().await;
312            protocol.handle_awareness_query(&awareness)
313        },
314        Message::Awareness(update) => {
315            let mut awareness = a.write().await;
316            protocol.handle_awareness_update(&mut awareness, update)
317        },
318        Message::Custom(tag, data) => {
319            let mut awareness = a.write().await;
320            protocol.missing_handle(&mut awareness, tag, data)
321        },
322    }
323}
324
325use crate::types::{ConnectionError, ProtocolSyncState, SyncEvent, SyncEventSender};
326use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
327
/// Sync state tracker: records the progress of the Yjs sync handshake and reports state
/// changes through the optional event sender.
#[derive(Debug)]
pub struct SyncTracker {
    // Encoded protocol state; decoded by `get_protocol_state`.
    protocol_state: AtomicU8, // 0=NotStarted, 1=Step1Sent, 2=Step2Received, 3=Updating
    // Set once any `Update` message has been reported via `on_update_received`.
    has_data: AtomicBool,
    // When tracking (re)started (set by `new` and `reset`); basis for the reported sync time.
    start_time: Option<Instant>,
    // When the Step1 -> Step2 transition happened, if it has.
    step2_time: Option<Instant>,
    // Optional destination for emitted events; send failures are ignored.
    event_sender: Option<SyncEventSender>,
}
337
338impl SyncTracker {
339    pub fn new(event_sender: Option<SyncEventSender>) -> Self {
340        Self {
341            protocol_state: AtomicU8::new(0),
342            has_data: AtomicBool::new(false),
343            start_time: Some(Instant::now()),
344            step2_time: None,
345            event_sender,
346        }
347    }
348    pub fn on_step1_sent(&self) {
349        let prev = self.protocol_state.swap(1, Ordering::Relaxed);
350        if prev == 0 {
351            tracing::debug!("📡 协议: SyncStep1 已发送");
352            self.emit_event(SyncEvent::ProtocolStateChanged(
353                ProtocolSyncState::Step1Sent,
354            ));
355        }
356    }
357    pub fn on_step2_received(&mut self) -> bool {
358        let prev = self.protocol_state.swap(2, Ordering::Relaxed);
359        if prev == 1 {
360            // Step1 -> Step2,首次同步完成!
361            self.step2_time = Some(Instant::now());
362
363            let elapsed_ms = if let (Some(start), Some(step2)) =
364                (self.start_time, self.step2_time)
365            {
366                step2.duration_since(start).as_millis() as u64
367            } else {
368                0
369            };
370
371            let has_data = self.has_data.load(Ordering::Relaxed);
372
373            tracing::info!(
374                "🎉 协议同步完成: Step1->Step2, 耗时 {}ms, 有数据: {}",
375                elapsed_ms,
376                has_data
377            );
378
379            self.emit_event(SyncEvent::ProtocolStateChanged(
380                ProtocolSyncState::Step2Received,
381            ));
382            self.emit_event(SyncEvent::InitialSyncCompleted {
383                has_data,
384                elapsed_ms,
385            });
386
387            return true; // 首次同步完成
388        }
389        false
390    }
391    pub fn on_update_received(&self) {
392        let prev_state = self.protocol_state.load(Ordering::Relaxed);
393
394        // 标记有数据
395        self.has_data.store(true, Ordering::Relaxed);
396
397        // 如果还在Step2状态,切换到Updating
398        if prev_state == 2 {
399            self.protocol_state.store(3, Ordering::Relaxed);
400            self.emit_event(SyncEvent::ProtocolStateChanged(
401                ProtocolSyncState::Updating,
402            ));
403        }
404
405        self.emit_event(SyncEvent::DataReceived);
406    }
407
408    pub fn is_initial_sync_completed(&self) -> bool {
409        self.protocol_state.load(Ordering::Relaxed) >= 2
410    }
411
412    pub fn get_protocol_state(&self) -> ProtocolSyncState {
413        match self.protocol_state.load(Ordering::Relaxed) {
414            0 => ProtocolSyncState::NotStarted,
415            1 => ProtocolSyncState::Step1Sent,
416            2 => ProtocolSyncState::Step2Received,
417            3 => ProtocolSyncState::Updating,
418            _ => ProtocolSyncState::NotStarted,
419        }
420    }
421    fn emit_event(
422        &self,
423        event: SyncEvent,
424    ) {
425        if let Some(sender) = &self.event_sender {
426            let _ = sender.send(event);
427        }
428    }
429    /// 重置同步状态(用于重连)
430    pub fn reset(&mut self) {
431        self.protocol_state.store(0, Ordering::Relaxed);
432        self.has_data.store(false, Ordering::Relaxed);
433        self.start_time = Some(Instant::now());
434        self.step2_time = None;
435    }
436
437    /// 标记连接失败
438    pub fn on_connection_failed(
439        &self,
440        error: &ConnectionError,
441    ) {
442        tracing::error!("🔌 连接失败: {}", error);
443        self.emit_event(SyncEvent::ConnectionFailed(error.clone()));
444    }
445}