mf_collab_client/
provider.rs1use std::time::Duration;
2
3use tokio::time::timeout;
4use tokio_tungstenite::connect_async;
5use yrs::sync::{Message, SyncMessage};
6use yrs::updates::encoder::Encode;
7use yrs::{Subscription};
8use url::Url;
9use crate::AwarenessRef;
10use crate::conn::Connection;
11use crate::types::*;
12use crate::client::{ClientSink, ClientStream};
13use futures_util::{SinkExt, StreamExt};
14
15pub struct WebsocketProvider {
16 pub server_url: String,
17 pub room_name: String,
18 pub awareness: AwarenessRef,
19 client_conn: Option<Connection<ClientSink, ClientStream>>,
20 pub status: ConnectionStatus,
21 sync_event_sender: Option<SyncEventSender>,
23 sync_event_receiver: Option<SyncEventReceiver>,
24
25 pub ws_reconnect_attempts: u32,
26 pub max_backoff_time: u64,
27 pub ws_url: Option<Url>,
28 pub client_id: u64,
29 subscriptions: Vec<Subscription>,
30}
31
32impl WebsocketProvider {
33 pub async fn new(
34 server_url: String,
35 room_name: String,
36 awareness: AwarenessRef,
37 ) -> Self {
38 let (event_sender, event_receiver) =
39 tokio::sync::broadcast::channel(100);
40
41 let ws_url = Url::parse(&format!(
42 "{}/{}",
43 server_url.trim_end_matches('/'),
44 room_name
45 ))
46 .ok();
47
48 let client_id = awareness.read().await.doc().client_id();
49
50 Self {
51 client_id,
52 server_url,
53 room_name,
54 awareness,
55 client_conn: None,
56 status: ConnectionStatus::Disconnected,
57 sync_event_sender: Some(event_sender),
58 sync_event_receiver: Some(event_receiver),
59 ws_reconnect_attempts: 0,
60 max_backoff_time: 2500,
61 ws_url,
62 subscriptions: Vec::new(),
63 }
64 }
65
66 pub fn subscription(
67 &mut self,
68 subscription: Subscription,
69 ) {
70 self.subscriptions.push(subscription);
71 }
72 pub async fn connect(&mut self) {
73 if let Err(e) = self.smart_connect().await {
74 tracing::error!("{}", e);
75 }
76 }
77 pub async fn connect_with_retry(
78 &mut self,
79 config: Option<ConnectionRetryConfig>,
80 ) -> anyhow::Result<()> {
81 let config = config.unwrap_or_default();
82 let mut attempt = 0;
83 let mut delay = config.initial_delay_ms;
84
85 while attempt < config.max_attempts {
86 attempt += 1;
87 self.update_status(ConnectionStatus::Retrying {
88 attempt,
89 max_attempts: config.max_attempts,
90 });
91
92 tracing::info!("🔄 连接尝试 {}/{}", attempt, config.max_attempts);
93
94 match self.try_connect().await {
95 Ok(()) => {
96 self.update_status(ConnectionStatus::Connected);
97 return Ok(());
98 },
99 Err(e) => {
100 let error = self.classify_connection_error(&e);
101
102 if attempt >= config.max_attempts {
103 if let Some(sender) = &self.sync_event_sender {
105 let _ = sender.send(SyncEvent::ConnectionFailed(
106 error.clone(),
107 ));
108 }
109 self.update_status(ConnectionStatus::Failed(
110 error.clone(),
111 ));
112 tracing::error!(
113 "❌ 连接失败,已达到最大重试次数: {}",
114 error
115 );
116 return Err(anyhow::anyhow!("连接失败: {}", error));
117 }
118
119 tracing::warn!(
120 "⚠️ 连接失败 (尝试 {}/{}): {}",
121 attempt,
122 config.max_attempts,
123 error
124 );
125
126 tokio::time::sleep(Duration::from_millis(delay)).await;
128 delay = (delay as f64 * config.backoff_multiplier) as u64;
129 delay = delay.min(config.max_delay_ms);
130 },
131 }
132 }
133
134 Err(anyhow::anyhow!("连接失败,已达到最大重试次数"))
135 }
136 async fn try_connect(&mut self) -> anyhow::Result<()> {
137 if self.status == ConnectionStatus::Connected
138 || self.status == ConnectionStatus::Connecting
139 {
140 return Ok(());
141 }
142
143 self.status = ConnectionStatus::Connecting;
144
145 let ws_url = match &self.ws_url {
146 Some(url) => url.as_str(),
147 None => {
148 return Err(anyhow::anyhow!("无效的 WebSocket URL"));
149 },
150 };
151
152 let connect_timeout = Duration::from_secs(10);
154
155 match timeout(connect_timeout, connect_async(ws_url)).await {
156 Ok(connect_result) => {
157 match connect_result {
158 Ok((ws_stream, _)) => {
159 let (sink, stream) = ws_stream.split();
160
161 let client_conn = Connection::new_with_sync_detection(
163 self.awareness.clone(),
164 ClientSink(sink),
165 ClientStream(stream),
166 self.sync_event_sender.clone(),
167 );
168
169 self.client_conn = Some(client_conn);
170 self.ws_reconnect_attempts = 0;
171
172 Ok(())
173 },
174 Err(e) => {
175 self.status = ConnectionStatus::Disconnected;
176 Err(anyhow::anyhow!("WebSocket 连接失败: {}", e))
177 },
178 }
179 },
180 Err(_) => {
181 self.status = ConnectionStatus::Disconnected;
182 Err(anyhow::anyhow!("连接超时"))
183 },
184 }
185 }
186
187 fn classify_connection_error(
189 &self,
190 error: &anyhow::Error,
191 ) -> ConnectionError {
192 let error_str = error.to_string().to_lowercase();
193
194 if error_str.contains("timeout") || error_str.contains("timed out") {
195 ConnectionError::Timeout(10000)
196 } else if error_str.contains("connection refused")
197 || error_str.contains("failed to connect")
198 {
199 ConnectionError::ServerUnavailable(
200 "服务端未启动或端口未开放".to_string(),
201 )
202 } else if error_str.contains("websocket") {
203 ConnectionError::WebSocketError(error.to_string())
204 } else {
205 ConnectionError::NetworkError(error.to_string())
206 }
207 }
208 fn update_status(
209 &mut self,
210 new_status: ConnectionStatus,
211 ) {
212 self.status = new_status.clone();
213
214 if let Some(sender) = &self.sync_event_sender {
216 let _ = sender.send(SyncEvent::ConnectionChanged(new_status));
217 }
218 }
219 pub async fn check_server_availability(&self) -> bool {
221 if let Some(ws_url) = &self.ws_url {
222 let http_url = ws_url
223 .as_str()
224 .replace("ws://", "http://")
225 .replace("wss://", "https://");
226
227 match tokio::time::timeout(
229 Duration::from_secs(3),
230 reqwest::get(&http_url),
231 )
232 .await
233 {
234 Ok(Ok(_)) => true,
235 _ => false,
236 }
237 } else {
238 false
239 }
240 }
241 pub async fn smart_connect(&mut self) -> anyhow::Result<()> {
243 if !self.check_server_availability().await {
245 self.status = ConnectionStatus::Failed(
246 ConnectionError::ServerUnavailable("服务端未启动".to_string()),
247 );
248 return Err(anyhow::anyhow!("服务端未启动或不可访问"));
249 }
250
251 self.connect_with_retry(None).await?;
253 self.setup_update_listeners().await;
254 Ok(())
255 }
256
257 async fn setup_update_listeners(&mut self) {
260 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
262
263 let doc_subscription = {
265 let sink = self.client_conn.as_ref().unwrap().sink();
266 let client_id = self.client_id.clone();
267 let awareness_lock = self.awareness.read().await;
268 let doc = awareness_lock.doc();
269 doc.observe_update_v1(move |txn, event| {
270 let origin = txn.origin();
271
272 if let Some(origin_ref) = origin {
273 let origin_bytes = origin_ref.as_ref();
274 if let Ok(origin_str) = std::str::from_utf8(origin_bytes) {
275 let update = event.update.to_owned();
276 if origin_str == client_id.to_string() {
277 let sink_weak = sink.clone();
278 tokio::spawn(async move {
279 let msg =
280 Message::Sync(SyncMessage::Update(update))
281 .encode_v1();
282 let binding = sink_weak.upgrade().unwrap();
283 let mut sink = binding.lock().await;
284 sink.send(msg).await.unwrap();
285 });
286 }
287 }
288 }
289 })
290 };
291
292 if let Ok(subscription) = doc_subscription {
294 self.subscriptions.push(subscription);
295 }
296
297 {
300 let awareness_lock = self.awareness.write().await;
301 let sink = self.client_conn.as_ref().unwrap().sink();
302
303 let awareness_subscription =
305 awareness_lock.on_update(move |event| {
306 let awareness_update = event.awareness_update().unwrap();
307 let sink_weak = sink.clone();
308 tokio::spawn(async move {
309 let msg: Vec<u8> =
310 Message::Awareness(awareness_update).encode_v1();
311 let binding = sink_weak.upgrade().unwrap();
312 let mut sink = binding.lock().await;
313 sink.send(msg).await.unwrap();
314 });
315 });
316 self.subscriptions.push(awareness_subscription);
317 tracing::info!("✅ 本地 Awareness 变更监听器已设置");
318 }
319 }
320
321 pub async fn wait_for_protocol_sync(
323 &self,
324 timeout_ms: u64,
325 ) -> anyhow::Result<bool> {
326 match &self.client_conn {
327 Some(conn) => Ok(conn.wait_for_initial_sync(timeout_ms).await),
328 None => Err(anyhow::anyhow!("连接未建立")),
329 }
330 }
331
332 pub async fn get_protocol_sync_state(&self) -> Option<ProtocolSyncState> {
334 match &self.client_conn {
335 Some(conn) => Some(conn.get_protocol_sync_state().await),
336 None => None,
337 }
338 }
339
340 pub fn subscribe_sync_events(&mut self) -> Option<SyncEventReceiver> {
342 self.sync_event_receiver.take()
343 }
344
345 pub async fn disconnect(&mut self) {
347 tracing::info!("🔌 断开 WebSocket 连接...");
348
349 self.client_conn = None;
351 self.status = ConnectionStatus::Disconnected;
352 tracing::info!("✅ WebSocket 连接已断开");
353 }
354
355 pub fn is_connected(&self) -> bool {
357 self.status == ConnectionStatus::Connected && self.client_conn.is_some()
358 }
359
360 pub fn get_status(&self) -> &ConnectionStatus {
362 &self.status
363 }
364}
365
366impl Drop for WebsocketProvider {
367 fn drop(&mut self) {
368 tracing::debug!("🧹 WebsocketProvider 已清理");
370 }
371}