// mf_collab_client/provider.rs

1use std::time::Duration;
2
3use tokio::time::timeout;
4use tokio_tungstenite::connect_async;
5use yrs::sync::{Message, SyncMessage};
6use yrs::updates::encoder::Encode;
7use yrs::{Subscription};
8use url::Url;
9use crate::AwarenessRef;
10use crate::conn::Connection;
11use crate::types::*;
12use crate::client::{ClientSink, ClientStream};
13use futures_util::{SinkExt, StreamExt};
14
15pub struct WebsocketProvider {
16    pub server_url: String,
17    pub room_name: String,
18    pub awareness: AwarenessRef,
19    client_conn: Option<Connection<ClientSink, ClientStream>>,
20    pub status: ConnectionStatus,
21    // 同步检测相关
22    sync_event_sender: Option<SyncEventSender>,
23    sync_event_receiver: Option<SyncEventReceiver>,
24
25    pub ws_reconnect_attempts: u32,
26    pub max_backoff_time: u64,
27    pub ws_url: Option<Url>,
28    pub client_id: u64,
29    subscriptions: Vec<Subscription>,
30}
31
32impl WebsocketProvider {
33    pub async fn new(
34        server_url: String,
35        room_name: String,
36        awareness: AwarenessRef,
37    ) -> Self {
38        let (event_sender, event_receiver) =
39            tokio::sync::broadcast::channel(100);
40
41        let ws_url = Url::parse(&format!(
42            "{}/{}",
43            server_url.trim_end_matches('/'),
44            room_name
45        ))
46        .ok();
47
48        let client_id = awareness.read().await.doc().client_id();
49
50        Self {
51            client_id,
52            server_url,
53            room_name,
54            awareness,
55            client_conn: None,
56            status: ConnectionStatus::Disconnected,
57            sync_event_sender: Some(event_sender),
58            sync_event_receiver: Some(event_receiver),
59            ws_reconnect_attempts: 0,
60            max_backoff_time: 2500,
61            ws_url,
62            subscriptions: Vec::new(),
63        }
64    }
65
66    pub fn subscription(
67        &mut self,
68        subscription: Subscription,
69    ) {
70        self.subscriptions.push(subscription);
71    }
72    pub async fn connect(&mut self) {
73        if let Err(e) = self.smart_connect().await {
74            tracing::error!("{}", e);
75        }
76    }
77    pub async fn connect_with_retry(
78        &mut self,
79        config: Option<ConnectionRetryConfig>,
80    ) -> anyhow::Result<()> {
81        let config = config.unwrap_or_default();
82        let mut attempt = 0;
83        let mut delay = config.initial_delay_ms;
84
85        while attempt < config.max_attempts {
86            attempt += 1;
87            self.update_status(ConnectionStatus::Retrying {
88                attempt,
89                max_attempts: config.max_attempts,
90            });
91
92            tracing::info!("🔄 连接尝试 {}/{}", attempt, config.max_attempts);
93
94            match self.try_connect().await {
95                Ok(()) => {
96                    self.update_status(ConnectionStatus::Connected);
97                    return Ok(());
98                },
99                Err(e) => {
100                    let error = self.classify_connection_error(&e);
101
102                    if attempt >= config.max_attempts {
103                        // 🔥 发送连接失败事件
104                        if let Some(sender) = &self.sync_event_sender {
105                            let _ = sender.send(SyncEvent::ConnectionFailed(
106                                error.clone(),
107                            ));
108                        }
109                        self.update_status(ConnectionStatus::Failed(
110                            error.clone(),
111                        ));
112                        tracing::error!(
113                            "❌ 连接失败,已达到最大重试次数: {}",
114                            error
115                        );
116                        return Err(anyhow::anyhow!("连接失败: {}", error));
117                    }
118
119                    tracing::warn!(
120                        "⚠️ 连接失败 (尝试 {}/{}): {}",
121                        attempt,
122                        config.max_attempts,
123                        error
124                    );
125
126                    // 指数退避延迟
127                    tokio::time::sleep(Duration::from_millis(delay)).await;
128                    delay = (delay as f64 * config.backoff_multiplier) as u64;
129                    delay = delay.min(config.max_delay_ms);
130                },
131            }
132        }
133
134        Err(anyhow::anyhow!("连接失败,已达到最大重试次数"))
135    }
136    async fn try_connect(&mut self) -> anyhow::Result<()> {
137        if self.status == ConnectionStatus::Connected
138            || self.status == ConnectionStatus::Connecting
139        {
140            return Ok(());
141        }
142
143        self.status = ConnectionStatus::Connecting;
144
145        let ws_url = match &self.ws_url {
146            Some(url) => url.as_str(),
147            None => {
148                return Err(anyhow::anyhow!("无效的 WebSocket URL"));
149            },
150        };
151
152        // 设置连接超时
153        let connect_timeout = Duration::from_secs(10);
154
155        match timeout(connect_timeout, connect_async(ws_url)).await {
156            Ok(connect_result) => {
157                match connect_result {
158                    Ok((ws_stream, _)) => {
159                        let (sink, stream) = ws_stream.split();
160
161                        // 使用带同步检测的连接
162                        let client_conn = Connection::new_with_sync_detection(
163                            self.awareness.clone(),
164                            ClientSink(sink),
165                            ClientStream(stream),
166                            self.sync_event_sender.clone(),
167                        );
168
169                        self.client_conn = Some(client_conn);
170                        self.ws_reconnect_attempts = 0;
171
172                        Ok(())
173                    },
174                    Err(e) => {
175                        self.status = ConnectionStatus::Disconnected;
176                        Err(anyhow::anyhow!("WebSocket 连接失败: {}", e))
177                    },
178                }
179            },
180            Err(_) => {
181                self.status = ConnectionStatus::Disconnected;
182                Err(anyhow::anyhow!("连接超时"))
183            },
184        }
185    }
186
187    /// 分类连接错误
188    fn classify_connection_error(
189        &self,
190        error: &anyhow::Error,
191    ) -> ConnectionError {
192        let error_str = error.to_string().to_lowercase();
193
194        if error_str.contains("timeout") || error_str.contains("timed out") {
195            ConnectionError::Timeout(10000)
196        } else if error_str.contains("connection refused")
197            || error_str.contains("failed to connect")
198        {
199            ConnectionError::ServerUnavailable(
200                "服务端未启动或端口未开放".to_string(),
201            )
202        } else if error_str.contains("websocket") {
203            ConnectionError::WebSocketError(error.to_string())
204        } else {
205            ConnectionError::NetworkError(error.to_string())
206        }
207    }
208    fn update_status(
209        &mut self,
210        new_status: ConnectionStatus,
211    ) {
212        self.status = new_status.clone();
213
214        // 发送状态变化事件
215        if let Some(sender) = &self.sync_event_sender {
216            let _ = sender.send(SyncEvent::ConnectionChanged(new_status));
217        }
218    }
219    /// 检查服务端是否可用
220    pub async fn check_server_availability(&self) -> bool {
221        if let Some(ws_url) = &self.ws_url {
222            let http_url = ws_url
223                .as_str()
224                .replace("ws://", "http://")
225                .replace("wss://", "https://");
226
227            // 尝试 HTTP 连接检查
228            match tokio::time::timeout(
229                Duration::from_secs(3),
230                reqwest::get(&http_url),
231            )
232            .await
233            {
234                Ok(Ok(_)) => true,
235                _ => false,
236            }
237        } else {
238            false
239        }
240    }
241    // 智能连接(先检查服务端可用性)
242    pub async fn smart_connect(&mut self) -> anyhow::Result<()> {
243        // 先检查服务端是否可用
244        if !self.check_server_availability().await {
245            self.status = ConnectionStatus::Failed(
246                ConnectionError::ServerUnavailable("服务端未启动".to_string()),
247            );
248            return Err(anyhow::anyhow!("服务端未启动或不可访问"));
249        }
250
251        // 使用重试机制连接
252        self.connect_with_retry(None).await?;
253        self.setup_update_listeners().await;
254        Ok(())
255    }
256
257    /// 设置统一的文档变更监听器
258    /// 监听所有文档变更并发送事件通知
259    async fn setup_update_listeners(&mut self) {
260        // 延迟 100 毫秒,以避免与 yrs-warp 的初始事务发生竞争
261        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
262
263        // 1. 监听文档变更
264        let doc_subscription = {
265            let sink = self.client_conn.as_ref().unwrap().sink();
266            let client_id = self.client_id.clone();
267            let awareness_lock = self.awareness.read().await;
268            let doc = awareness_lock.doc();
269            doc.observe_update_v1(move |txn, event| {
270                let origin = txn.origin();
271
272                if let Some(origin_ref) = origin {
273                    let origin_bytes = origin_ref.as_ref();
274                    if let Ok(origin_str) = std::str::from_utf8(origin_bytes) {
275                        let update = event.update.to_owned();
276                        if origin_str == client_id.to_string() {
277                            let sink_weak = sink.clone();
278                            tokio::spawn(async move {
279                                let msg =
280                                    Message::Sync(SyncMessage::Update(update))
281                                        .encode_v1();
282                                let binding = sink_weak.upgrade().unwrap();
283                                let mut sink = binding.lock().await;
284                                sink.send(msg).await.unwrap();
285                            });
286                        }
287                    }
288                }
289            })
290        };
291
292        // 保存订阅
293        if let Ok(subscription) = doc_subscription {
294            self.subscriptions.push(subscription);
295        }
296
297        // 2. 监听本地 awareness 变更
298
299        {
300            let awareness_lock = self.awareness.write().await;
301            let sink = self.client_conn.as_ref().unwrap().sink();
302
303            // 修复 on_update 签名以匹配 yrs v0.18.8
304            let awareness_subscription =
305                awareness_lock.on_update(move |event| {
306                    let awareness_update = event.awareness_update().unwrap();
307                    let sink_weak = sink.clone();
308                    tokio::spawn(async move {
309                        let msg: Vec<u8> =
310                            Message::Awareness(awareness_update).encode_v1();
311                        let binding = sink_weak.upgrade().unwrap();
312                        let mut sink = binding.lock().await;
313                        sink.send(msg).await.unwrap();
314                    });
315                });
316            self.subscriptions.push(awareness_subscription);
317            tracing::info!("✅ 本地 Awareness 变更监听器已设置");
318        }
319    }
320
321    /// 等待协议级同步完成(包括空房间)
322    pub async fn wait_for_protocol_sync(
323        &self,
324        timeout_ms: u64,
325    ) -> anyhow::Result<bool> {
326        match &self.client_conn {
327            Some(conn) => Ok(conn.wait_for_initial_sync(timeout_ms).await),
328            None => Err(anyhow::anyhow!("连接未建立")),
329        }
330    }
331
332    /// 获取协议同步状态
333    pub async fn get_protocol_sync_state(&self) -> Option<ProtocolSyncState> {
334        match &self.client_conn {
335            Some(conn) => Some(conn.get_protocol_sync_state().await),
336            None => None,
337        }
338    }
339
340    /// 订阅同步事件
341    pub fn subscribe_sync_events(&mut self) -> Option<SyncEventReceiver> {
342        self.sync_event_receiver.take()
343    }
344
345    /// 断开连接并清理资源
346    pub async fn disconnect(&mut self) {
347        tracing::info!("🔌 断开 WebSocket 连接...");
348
349        // 清理连接
350        self.client_conn = None;
351        self.status = ConnectionStatus::Disconnected;
352        tracing::info!("✅ WebSocket 连接已断开");
353    }
354
355    /// 检查连接状态
356    pub fn is_connected(&self) -> bool {
357        self.status == ConnectionStatus::Connected && self.client_conn.is_some()
358    }
359
360    /// 获取连接状态
361    pub fn get_status(&self) -> &ConnectionStatus {
362        &self.status
363    }
364}
365
366impl Drop for WebsocketProvider {
367    fn drop(&mut self) {
368        // 在析构时清理监听器
369        tracing::debug!("🧹 WebsocketProvider 已清理");
370    }
371}