1use std::time::Duration;
2
3use tokio::time::timeout;
4use tokio_tungstenite::connect_async;
5use yrs::sync::{Message, SyncMessage};
6use yrs::updates::encoder::Encode;
7use yrs::{Subscription};
8use url::Url;
9use crate::AwarenessRef;
10use crate::conn::Connection;
11use crate::types::*;
12use crate::client::{ClientSink, ClientStream};
13use futures_util::{SinkExt, StreamExt};
14
15pub struct WebsocketProvider {
16 pub server_url: String,
17 pub room_name: String,
18 pub awareness: AwarenessRef,
19 client_conn: Option<Connection<ClientSink, ClientStream>>,
20 pub status: ConnectionStatus,
21 sync_event_sender: Option<SyncEventSender>,
23 sync_event_receiver: Option<SyncEventReceiver>,
24
25 pub ws_reconnect_attempts: u32,
26 pub max_backoff_time: u64,
27 pub ws_url: Option<Url>,
28 pub client_id: u64,
29 subscriptions: Vec<Subscription>,
30}
31
32impl WebsocketProvider {
33 pub async fn new(
34 server_url: String,
35 room_name: String,
36 awareness: AwarenessRef,
37 ) -> Self {
38 let (event_sender, event_receiver) =
39 tokio::sync::broadcast::channel(100);
40
41 let ws_url = Url::parse(&format!(
42 "{}/{}",
43 server_url.trim_end_matches('/'),
44 room_name
45 ))
46 .ok();
47
48 let client_id = awareness.read().await.doc().client_id();
49
50 Self {
51 client_id,
52 server_url,
53 room_name,
54 awareness,
55 client_conn: None,
56 status: ConnectionStatus::Disconnected,
57 sync_event_sender: Some(event_sender),
58 sync_event_receiver: Some(event_receiver),
59 ws_reconnect_attempts: 0,
60 max_backoff_time: 2500,
61 ws_url,
62 subscriptions: Vec::new(),
63 }
64 }
65
66 pub fn subscription(
67 &mut self,
68 subscription: Subscription,
69 ) {
70 self.subscriptions.push(subscription);
71 }
72 pub async fn connect(&mut self) {
73 if let Err(e) = self.smart_connect().await {
74 tracing::error!("{}", e);
75 }
76 }
77 pub async fn connect_with_retry(
78 &mut self,
79 config: Option<ConnectionRetryConfig>,
80 ) -> anyhow::Result<()> {
81 let config = config.unwrap_or_default();
82 let mut attempt = 0;
83 let mut delay = config.initial_delay_ms;
84
85 while attempt < config.max_attempts {
86 attempt += 1;
87 self.update_status(ConnectionStatus::Retrying {
88 attempt,
89 max_attempts: config.max_attempts,
90 });
91
92 tracing::info!("🔄 连接尝试 {}/{}", attempt, config.max_attempts);
93
94 match self.try_connect().await {
95 Ok(()) => {
96 self.update_status(ConnectionStatus::Connected);
97 return Ok(());
98 },
99 Err(e) => {
100 let error = self.classify_connection_error(&e);
101
102 if attempt >= config.max_attempts {
103 if let Some(sender) = &self.sync_event_sender {
105 let _ = sender.send(SyncEvent::ConnectionFailed(
106 error.clone(),
107 ));
108 }
109 self.update_status(ConnectionStatus::Failed(
110 error.clone(),
111 ));
112 tracing::error!(
113 "❌ 连接失败,已达到最大重试次数: {}",
114 error
115 );
116 return Err(anyhow::anyhow!("连接失败: {}", error));
117 }
118
119 tracing::warn!(
120 "⚠️ 连接失败 (尝试 {}/{}): {}",
121 attempt,
122 config.max_attempts,
123 error
124 );
125
126 tokio::time::sleep(Duration::from_millis(delay)).await;
128 delay = (delay as f64 * config.backoff_multiplier) as u64;
129 delay = delay.min(config.max_delay_ms);
130 },
131 }
132 }
133
134 Err(anyhow::anyhow!("连接失败,已达到最大重试次数"))
135 }
136 async fn try_connect(&mut self) -> anyhow::Result<()> {
137 if self.status == ConnectionStatus::Connected
138 || self.status == ConnectionStatus::Connecting
139 {
140 return Ok(());
141 }
142
143 self.status = ConnectionStatus::Connecting;
144
145 let ws_url = match &self.ws_url {
146 Some(url) => url.as_str(),
147 None => {
148 return Err(anyhow::anyhow!("无效的 WebSocket URL"));
149 },
150 };
151
152 let connect_timeout = Duration::from_secs(10);
154
155 match timeout(connect_timeout, connect_async(ws_url)).await {
156 Ok(connect_result) => {
157 match connect_result {
158 Ok((ws_stream, _)) => {
159 let (sink, stream) = ws_stream.split();
160
161 let client_conn = Connection::new_with_sync_detection(
163 self.awareness.clone(),
164 ClientSink(sink),
165 ClientStream(stream),
166 self.sync_event_sender.clone(),
167 );
168
169 self.client_conn = Some(client_conn);
170 self.ws_reconnect_attempts = 0;
171
172 Ok(())
173 },
174 Err(e) => {
175 self.status = ConnectionStatus::Disconnected;
176 Err(anyhow::anyhow!("WebSocket 连接失败: {}", e))
177 },
178 }
179 },
180 Err(_) => {
181 self.status = ConnectionStatus::Disconnected;
182 Err(anyhow::anyhow!("连接超时"))
183 },
184 }
185 }
186
187 fn clear_subscriptions(&mut self) {
189 if !self.subscriptions.is_empty() {
190 tracing::debug!(
191 count = self.subscriptions.len(),
192 "🧹 正在清理订阅监听器: {} 个",
193 self.subscriptions.len()
194 );
195 }
196 self.subscriptions.drain(..);
198 }
199
200 fn classify_connection_error(
202 &self,
203 error: &anyhow::Error,
204 ) -> ConnectionError {
205 let error_str = error.to_string().to_lowercase();
206
207 if error_str.contains("timeout") || error_str.contains("timed out") {
208 ConnectionError::Timeout(10000)
209 } else if error_str.contains("connection refused")
210 || error_str.contains("failed to connect")
211 {
212 ConnectionError::ServerUnavailable(
213 "服务端未启动或端口未开放".to_string(),
214 )
215 } else if error_str.contains("websocket") {
216 ConnectionError::WebSocketError(error.to_string())
217 } else {
218 ConnectionError::NetworkError(error.to_string())
219 }
220 }
221 fn update_status(
222 &mut self,
223 new_status: ConnectionStatus,
224 ) {
225 self.status = new_status.clone();
226
227 if let Some(sender) = &self.sync_event_sender {
229 let _ = sender.send(SyncEvent::ConnectionChanged(new_status));
230 }
231 }
232 pub async fn check_server_availability(&self) -> bool {
234 if let Some(ws_url) = &self.ws_url {
235 let http_url = ws_url
236 .as_str()
237 .replace("ws://", "http://")
238 .replace("wss://", "https://");
239
240 matches!(
242 tokio::time::timeout(
243 Duration::from_secs(3),
244 reqwest::get(&http_url),
245 )
246 .await,
247 Ok(Ok(_))
248 )
249 } else {
250 false
251 }
252 }
253 pub async fn smart_connect(&mut self) -> anyhow::Result<()> {
255 if !self.check_server_availability().await {
257 self.status = ConnectionStatus::Failed(
258 ConnectionError::ServerUnavailable("服务端未启动".to_string()),
259 );
260 return Err(anyhow::anyhow!("服务端未启动或不可访问"));
261 }
262
263 self.connect_with_retry(None).await?;
265 self.setup_update_listeners().await;
266 Ok(())
267 }
268
269 async fn setup_update_listeners(&mut self) {
272 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
274
275 let conn = match self.client_conn.as_ref() {
277 Some(conn) => conn,
278 None => {
279 tracing::error!("尝试设置监听器时客户端连接不存在");
280 return;
281 },
282 };
283
284 let doc_subscription = {
286 let sink = conn.sink();
287 let client_id = self.client_id;
288 let awareness_lock = self.awareness.read().await;
289 let doc = awareness_lock.doc();
290 doc.observe_update_v1(move |txn, event| {
291 let origin = txn.origin();
292
293 if let Some(origin_ref) = origin {
294 let origin_bytes = origin_ref.as_ref();
295 if let Ok(origin_str) = std::str::from_utf8(origin_bytes) {
296 let update = event.update.to_owned();
297 if origin_str == client_id.to_string() {
298 let sink_weak = sink.clone();
299 tokio::spawn(async move {
300 let msg =
301 Message::Sync(SyncMessage::Update(update))
302 .encode_v1();
303 if let Some(binding) = sink_weak.upgrade() {
304 let mut sink = binding.lock().await;
305 if let Err(e) = sink.send(msg).await {
306 tracing::debug!(
307 "忽略发送错误(可能已断开): {}",
308 e
309 );
310 }
311 } else {
312 tracing::debug!(
313 "发送通道已释放(可能已断开),跳过文档更新发送"
314 );
315 }
316 });
317 }
318 }
319 }
320 })
321 };
322
323 if let Ok(subscription) = doc_subscription {
325 self.subscriptions.push(subscription);
326 }
327
328 {
331 let awareness_lock = self.awareness.write().await;
332 let conn = match self.client_conn.as_ref() {
334 Some(conn) => conn,
335 None => {
336 tracing::error!(
337 "尝试设置 awareness 监听器时客户端连接不存在"
338 );
339 return;
340 },
341 };
342 let sink = conn.sink();
343
344 let awareness_subscription = awareness_lock.on_update(move |event| {
346 if let Some(awareness_update) = event.awareness_update() {
347 let sink_weak = sink.clone();
348 tokio::spawn(async move {
349 let msg: Vec<u8> =
350 Message::Awareness(awareness_update).encode_v1();
351 if let Some(binding) = sink_weak.upgrade() {
352 let mut sink = binding.lock().await;
353 if let Err(e) = sink.send(msg).await {
354 tracing::debug!(
355 "忽略发送错误(可能已断开): {}",
356 e
357 );
358 }
359 } else {
360 tracing::debug!(
361 "发送通道已释放(可能已断开),跳过 Awareness 发送"
362 );
363 }
364 });
365 }
366 });
367 self.subscriptions.push(awareness_subscription);
368 tracing::info!("✅ 本地 Awareness 变更监听器已设置");
369 }
370 }
371
372 pub async fn wait_for_protocol_sync(
374 &self,
375 timeout_ms: u64,
376 ) -> anyhow::Result<bool> {
377 match &self.client_conn {
378 Some(conn) => Ok(conn.wait_for_initial_sync(timeout_ms).await),
379 None => Err(anyhow::anyhow!("连接未建立")),
380 }
381 }
382
383 pub async fn get_protocol_sync_state(&self) -> Option<ProtocolSyncState> {
385 match &self.client_conn {
386 Some(conn) => Some(conn.get_protocol_sync_state().await),
387 None => None,
388 }
389 }
390
391 pub fn subscribe_sync_events(&mut self) -> Option<SyncEventReceiver> {
393 self.sync_event_receiver.take()
394 }
395
396 pub async fn disconnect(&mut self) {
398 tracing::info!("🔌 断开 WebSocket 连接...");
399
400 self.clear_subscriptions();
402
403 if let Some(conn) = self.client_conn.take() {
405 if let Err(e) = conn.close().await {
406 tracing::debug!("关闭连接时出现错误(忽略): {:?}", e);
407 }
408 }
409
410 self.update_status(ConnectionStatus::Disconnected);
412 tracing::info!("✅ WebSocket 连接已断开且监听器已清理");
413 }
414
415 pub fn is_connected(&self) -> bool {
417 self.status == ConnectionStatus::Connected && self.client_conn.is_some()
418 }
419
420 pub fn get_status(&self) -> &ConnectionStatus {
422 &self.status
423 }
424}
425
426impl Drop for WebsocketProvider {
427 fn drop(&mut self) {
428 self.clear_subscriptions();
430 tracing::debug!("🧹 WebsocketProvider 已清理(订阅监听器已释放)");
431 }
432}