Skip to main content

mf_collab_client/
provider.rs

1use std::time::Duration;
2
3use tokio::time::timeout;
4use tokio_tungstenite::connect_async;
5use yrs::sync::{Message, SyncMessage};
6use yrs::updates::encoder::Encode;
7use yrs::{Subscription};
8use url::Url;
9use crate::AwarenessRef;
10use crate::conn::Connection;
11use crate::types::*;
12use crate::client::{ClientSink, ClientStream};
13use futures_util::{SinkExt, StreamExt};
14
/// WebSocket-backed sync provider for a single collaboration room.
///
/// Holds the shared awareness handle, the active connection (if any), the
/// current connection status and the channel used to publish sync events.
pub struct WebsocketProvider {
    /// Base server URL exactly as passed to `new`.
    pub server_url: String,
    /// Room name; joined onto `server_url` to build `ws_url`.
    pub room_name: String,
    /// Shared awareness handle; its document supplies `client_id`.
    pub awareness: AwarenessRef,
    /// Active connection; `None` until a connect attempt succeeds and again
    /// after `disconnect`.
    client_conn: Option<Connection<ClientSink, ClientStream>>,
    /// Last recorded connection status.
    pub status: ConnectionStatus,
    // Sync detection: both halves of the event channel created in `new`.
    sync_event_sender: Option<SyncEventSender>,
    sync_event_receiver: Option<SyncEventReceiver>,

    /// Reset to 0 after a successful connect; not incremented in the visible
    /// code.
    pub ws_reconnect_attempts: u32,
    /// Initialised to 2500 in `new` and not read in the visible code
    /// (presumably milliseconds — TODO confirm).
    pub max_backoff_time: u64,
    /// Parsed `server_url/room_name`; `None` when parsing failed.
    pub ws_url: Option<Url>,
    /// Client id of the awareness document, captured in `new`.
    pub client_id: u64,
    /// Observer subscriptions owned by the provider; dropped by
    /// `clear_subscriptions`.
    subscriptions: Vec<Subscription>,
}
31
32impl WebsocketProvider {
33    pub async fn new(
34        server_url: String,
35        room_name: String,
36        awareness: AwarenessRef,
37    ) -> Self {
38        let (event_sender, event_receiver) =
39            tokio::sync::broadcast::channel(100);
40
41        let ws_url = Url::parse(&format!(
42            "{}/{}",
43            server_url.trim_end_matches('/'),
44            room_name
45        ))
46        .ok();
47
48        let client_id = awareness.read().await.doc().client_id();
49
50        Self {
51            client_id,
52            server_url,
53            room_name,
54            awareness,
55            client_conn: None,
56            status: ConnectionStatus::Disconnected,
57            sync_event_sender: Some(event_sender),
58            sync_event_receiver: Some(event_receiver),
59            ws_reconnect_attempts: 0,
60            max_backoff_time: 2500,
61            ws_url,
62            subscriptions: Vec::new(),
63        }
64    }
65
    /// Stores `subscription` in the provider so it stays alive until the
    /// provider clears its subscriptions (on `disconnect` or drop).
    pub fn subscription(
        &mut self,
        subscription: Subscription,
    ) {
        self.subscriptions.push(subscription);
    }
72    pub async fn connect(&mut self) {
73        if let Err(e) = self.smart_connect().await {
74            tracing::error!("{}", e);
75        }
76    }
77    pub async fn connect_with_retry(
78        &mut self,
79        config: Option<ConnectionRetryConfig>,
80    ) -> anyhow::Result<()> {
81        let config = config.unwrap_or_default();
82        let mut attempt = 0;
83        let mut delay = config.initial_delay_ms;
84
85        while attempt < config.max_attempts {
86            attempt += 1;
87            self.update_status(ConnectionStatus::Retrying {
88                attempt,
89                max_attempts: config.max_attempts,
90            });
91
92            tracing::info!("🔄 连接尝试 {}/{}", attempt, config.max_attempts);
93
94            match self.try_connect().await {
95                Ok(()) => {
96                    self.update_status(ConnectionStatus::Connected);
97                    return Ok(());
98                },
99                Err(e) => {
100                    let error = self.classify_connection_error(&e);
101
102                    if attempt >= config.max_attempts {
103                        // 🔥 发送连接失败事件
104                        if let Some(sender) = &self.sync_event_sender {
105                            let _ = sender.send(SyncEvent::ConnectionFailed(
106                                error.clone(),
107                            ));
108                        }
109                        self.update_status(ConnectionStatus::Failed(
110                            error.clone(),
111                        ));
112                        tracing::error!(
113                            "❌ 连接失败,已达到最大重试次数: {}",
114                            error
115                        );
116                        return Err(anyhow::anyhow!("连接失败: {}", error));
117                    }
118
119                    tracing::warn!(
120                        "⚠️ 连接失败 (尝试 {}/{}): {}",
121                        attempt,
122                        config.max_attempts,
123                        error
124                    );
125
126                    // 指数退避延迟
127                    tokio::time::sleep(Duration::from_millis(delay)).await;
128                    delay = (delay as f64 * config.backoff_multiplier) as u64;
129                    delay = delay.min(config.max_delay_ms);
130                },
131            }
132        }
133
134        Err(anyhow::anyhow!("连接失败,已达到最大重试次数"))
135    }
136    async fn try_connect(&mut self) -> anyhow::Result<()> {
137        if self.status == ConnectionStatus::Connected
138            || self.status == ConnectionStatus::Connecting
139        {
140            return Ok(());
141        }
142
143        self.status = ConnectionStatus::Connecting;
144
145        let ws_url = match &self.ws_url {
146            Some(url) => url.as_str(),
147            None => {
148                return Err(anyhow::anyhow!("无效的 WebSocket URL"));
149            },
150        };
151
152        // 设置连接超时
153        let connect_timeout = Duration::from_secs(10);
154
155        match timeout(connect_timeout, connect_async(ws_url)).await {
156            Ok(connect_result) => {
157                match connect_result {
158                    Ok((ws_stream, _)) => {
159                        let (sink, stream) = ws_stream.split();
160
161                        // 使用带同步检测的连接
162                        let client_conn = Connection::new_with_sync_detection(
163                            self.awareness.clone(),
164                            ClientSink(sink),
165                            ClientStream(stream),
166                            self.sync_event_sender.clone(),
167                        );
168
169                        self.client_conn = Some(client_conn);
170                        self.ws_reconnect_attempts = 0;
171
172                        Ok(())
173                    },
174                    Err(e) => {
175                        self.status = ConnectionStatus::Disconnected;
176                        Err(anyhow::anyhow!("WebSocket 连接失败: {}", e))
177                    },
178                }
179            },
180            Err(_) => {
181                self.status = ConnectionStatus::Disconnected;
182                Err(anyhow::anyhow!("连接超时"))
183            },
184        }
185    }
186
187    /// 清理文档与 awareness 的订阅监听器
188    fn clear_subscriptions(&mut self) {
189        if !self.subscriptions.is_empty() {
190            tracing::debug!(
191                count = self.subscriptions.len(),
192                "🧹 正在清理订阅监听器: {} 个",
193                self.subscriptions.len()
194            );
195        }
196        // 逐一 drop 以解除注册
197        self.subscriptions.drain(..);
198    }
199
200    /// 分类连接错误
201    fn classify_connection_error(
202        &self,
203        error: &anyhow::Error,
204    ) -> ConnectionError {
205        let error_str = error.to_string().to_lowercase();
206
207        if error_str.contains("timeout") || error_str.contains("timed out") {
208            ConnectionError::Timeout(10000)
209        } else if error_str.contains("connection refused")
210            || error_str.contains("failed to connect")
211        {
212            ConnectionError::ServerUnavailable(
213                "服务端未启动或端口未开放".to_string(),
214            )
215        } else if error_str.contains("websocket") {
216            ConnectionError::WebSocketError(error.to_string())
217        } else {
218            ConnectionError::NetworkError(error.to_string())
219        }
220    }
221    fn update_status(
222        &mut self,
223        new_status: ConnectionStatus,
224    ) {
225        self.status = new_status.clone();
226
227        // 发送状态变化事件
228        if let Some(sender) = &self.sync_event_sender {
229            let _ = sender.send(SyncEvent::ConnectionChanged(new_status));
230        }
231    }
232    /// 检查服务端是否可用
233    pub async fn check_server_availability(&self) -> bool {
234        if let Some(ws_url) = &self.ws_url {
235            let http_url = ws_url
236                .as_str()
237                .replace("ws://", "http://")
238                .replace("wss://", "https://");
239
240            // 尝试 HTTP 连接检查
241            matches!(
242                tokio::time::timeout(
243                    Duration::from_secs(3),
244                    reqwest::get(&http_url),
245                )
246                .await,
247                Ok(Ok(_))
248            )
249        } else {
250            false
251        }
252    }
253    // 智能连接(先检查服务端可用性)
254    pub async fn smart_connect(&mut self) -> anyhow::Result<()> {
255        // 先检查服务端是否可用
256        if !self.check_server_availability().await {
257            self.status = ConnectionStatus::Failed(
258                ConnectionError::ServerUnavailable("服务端未启动".to_string()),
259            );
260            return Err(anyhow::anyhow!("服务端未启动或不可访问"));
261        }
262
263        // 使用重试机制连接
264        self.connect_with_retry(None).await?;
265        self.setup_update_listeners().await;
266        Ok(())
267    }
268
    /// Installs the listeners that forward local changes to the server.
    ///
    /// Two subscriptions are created and stored in `self.subscriptions`: a
    /// document update observer and an awareness update observer. Each one
    /// encodes the change and sends it through the connection's sink from a
    /// spawned task. Returns early (after logging) when no connection exists.
    async fn setup_update_listeners(&mut self) {
        // Wait 100 ms to avoid racing with the initial yrs-warp transaction.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // Make sure a client connection exists.
        let conn = match self.client_conn.as_ref() {
            Some(conn) => conn,
            None => {
                tracing::error!("尝试设置监听器时客户端连接不存在");
                return;
            },
        };

        // 1. Observe document updates.
        let doc_subscription = {
            let sink = conn.sink();
            let client_id = self.client_id;
            // The read guard lives only until the end of this block.
            let awareness_lock = self.awareness.read().await;
            let doc = awareness_lock.doc();
            doc.observe_update_v1(move |txn, event| {
                let origin = txn.origin();

                // Only transactions whose origin is valid UTF-8 and equals
                // this client's id (as a decimal string) are forwarded;
                // everything else is ignored.
                if let Some(origin_ref) = origin {
                    let origin_bytes = origin_ref.as_ref();
                    if let Ok(origin_str) = std::str::from_utf8(origin_bytes) {
                        // NOTE(review): the update is copied before the origin
                        // comparison, so non-matching origins pay for the copy.
                        let update = event.update.to_owned();
                        if origin_str == client_id.to_string() {
                            let sink_weak = sink.clone();
                            // Sending is async, so it runs in a spawned task.
                            tokio::spawn(async move {
                                let msg =
                                    Message::Sync(SyncMessage::Update(update))
                                        .encode_v1();
                                // A failed `upgrade()` is treated as "already
                                // disconnected" and the update is skipped.
                                if let Some(binding) = sink_weak.upgrade() {
                                    let mut sink = binding.lock().await;
                                    if let Err(e) = sink.send(msg).await {
                                        tracing::debug!(
                                            "忽略发送错误(可能已断开): {}",
                                            e
                                        );
                                    }
                                } else {
                                    tracing::debug!(
                                        "发送通道已释放(可能已断开),跳过文档更新发送"
                                    );
                                }
                            });
                        }
                    }
                }
            })
        };

        // Keep the subscription.
        // NOTE(review): a failed registration (`Err`) is silently skipped.
        if let Ok(subscription) = doc_subscription {
            self.subscriptions.push(subscription);
        }

        // 2. Observe local awareness changes.

        {
            let awareness_lock = self.awareness.write().await;
            // Fetch the connection again (already checked above; repeated for
            // clarity).
            let conn = match self.client_conn.as_ref() {
                Some(conn) => conn,
                None => {
                    tracing::error!(
                        "尝试设置 awareness 监听器时客户端连接不存在"
                    );
                    return;
                },
            };
            let sink = conn.sink();

            // `on_update` signature adjusted to match yrs v0.18.8.
            let awareness_subscription = awareness_lock.on_update(move |event| {
                // Events that carry no awareness update are ignored.
                if let Some(awareness_update) = event.awareness_update() {
                    let sink_weak = sink.clone();
                    tokio::spawn(async move {
                        let msg: Vec<u8> =
                            Message::Awareness(awareness_update).encode_v1();
                        if let Some(binding) = sink_weak.upgrade() {
                            let mut sink = binding.lock().await;
                            if let Err(e) = sink.send(msg).await {
                                tracing::debug!(
                                    "忽略发送错误(可能已断开): {}",
                                    e
                                );
                            }
                        } else {
                            tracing::debug!(
                                "发送通道已释放(可能已断开),跳过 Awareness 发送"
                            );
                        }
                    });
                }
            });
            self.subscriptions.push(awareness_subscription);
            tracing::info!("✅ 本地 Awareness 变更监听器已设置");
        }
    }
371
372    /// 等待协议级同步完成(包括空房间)
373    pub async fn wait_for_protocol_sync(
374        &self,
375        timeout_ms: u64,
376    ) -> anyhow::Result<bool> {
377        match &self.client_conn {
378            Some(conn) => Ok(conn.wait_for_initial_sync(timeout_ms).await),
379            None => Err(anyhow::anyhow!("连接未建立")),
380        }
381    }
382
383    /// 获取协议同步状态
384    pub async fn get_protocol_sync_state(&self) -> Option<ProtocolSyncState> {
385        match &self.client_conn {
386            Some(conn) => Some(conn.get_protocol_sync_state().await),
387            None => None,
388        }
389    }
390
391    /// 订阅同步事件
392    pub fn subscribe_sync_events(&mut self) -> Option<SyncEventReceiver> {
393        self.sync_event_receiver.take()
394    }
395
396    /// 断开连接并清理资源
397    pub async fn disconnect(&mut self) {
398        tracing::info!("🔌 断开 WebSocket 连接...");
399
400        // 1) 先清理订阅监听器,防止回调在断连后继续触发
401        self.clear_subscriptions();
402
403        // 2) 优雅关闭连接(关闭 sink 以促使处理循环退出)
404        if let Some(conn) = self.client_conn.take() {
405            if let Err(e) = conn.close().await {
406                tracing::debug!("关闭连接时出现错误(忽略): {:?}", e);
407            }
408        }
409
410        // 3) 更新状态并通知
411        self.update_status(ConnectionStatus::Disconnected);
412        tracing::info!("✅ WebSocket 连接已断开且监听器已清理");
413    }
414
415    /// 检查连接状态
416    pub fn is_connected(&self) -> bool {
417        self.status == ConnectionStatus::Connected && self.client_conn.is_some()
418    }
419
    /// Returns a reference to the current connection status.
    pub fn get_status(&self) -> &ConnectionStatus {
        &self.status
    }
424}
425
// Releases the stored listener subscriptions when the provider is dropped.
// NOTE(review): only subscriptions are cleared here; the connection itself is
// explicitly closed only in `disconnect`.
impl Drop for WebsocketProvider {
    fn drop(&mut self) {
        // Clear the listeners on destruction.
        self.clear_subscriptions();
        tracing::debug!("🧹 WebsocketProvider 已清理(订阅监听器已释放)");
    }
}