1use super::types::*;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{broadcast, mpsc};
10use tokio::time::interval;
11
/// Settings for a `WebSocketManager` connection.
///
/// The heartbeat and reconnect durations are whole seconds.
#[derive(Debug, Clone)]
pub struct ConnectionConfig {
    /// Server URL. An `http://`/`https://` scheme is rewritten to
    /// `ws://`/`wss://`, a missing scheme defaults to `wss://`, and
    /// `/teleport/<session_id>` is appended unless the URL already contains
    /// `/teleport/`.
    pub url: String,
    /// Optional authentication token.
    /// NOTE(review): not read by the connection code in this module — confirm where it is used.
    pub auth_token: Option<String>,
    /// Identifier of the remote session; used in the URL path and in heartbeats.
    pub session_id: String,
    /// Seconds between heartbeat messages.
    pub heartbeat_interval: u64,
    /// Seconds to wait between connection attempts.
    pub reconnect_delay: u64,
    /// Total number of connection attempts before `connect` gives up
    /// (at least one attempt is always made).
    pub max_reconnect_attempts: u32,
    /// Per-attempt connection timeout, presumably in seconds like the other
    /// durations. NOTE(review): confirm the connection logic honors it.
    pub connect_timeout: u64,
}

impl Default for ConnectionConfig {
    /// Empty URL and session id, no auth token, a 30 s heartbeat, a 5 s
    /// reconnect delay, 10 connection attempts and a 30 s connect timeout.
    fn default() -> Self {
        let (heartbeat_interval, reconnect_delay, connect_timeout) = (30, 5, 30);
        ConnectionConfig {
            heartbeat_interval,
            reconnect_delay,
            max_reconnect_attempts: 10,
            connect_timeout,
            url: String::default(),
            auth_token: Option::default(),
            session_id: String::default(),
        }
    }
}
44
/// Events published by `WebSocketManager` on its broadcast channel
/// (obtain a receiver with `WebSocketManager::subscribe`).
#[derive(Debug, Clone)]
pub enum ConnectionEvent {
    /// A connection attempt succeeded (sent by `WebSocketManager::connect`).
    Connected,
    /// `WebSocketManager::disconnect` was called.
    Disconnected,
    /// A connection attempt failed and another will follow after the
    /// reconnect delay; `attempt` is the number of failed attempts so far
    /// (starting at 1).
    Reconnecting { attempt: u32 },
    /// A session message. The heartbeat task currently publishes its
    /// heartbeat messages through this variant.
    Message(RemoteMessage),
    /// `connect` failed; carries the error's display text.
    Error(String),
}
59
60pub struct WebSocketManager {
62 config: ConnectionConfig,
64 connected: Arc<AtomicBool>,
66 event_tx: broadcast::Sender<ConnectionEvent>,
68 outgoing_tx: Option<mpsc::Sender<RemoteMessage>>,
70 stop_tx: Option<mpsc::Sender<()>>,
72}
73
74impl WebSocketManager {
75 pub fn new(config: ConnectionConfig) -> Self {
77 let (event_tx, _) = broadcast::channel(100);
78
79 Self {
80 config,
81 connected: Arc::new(AtomicBool::new(false)),
82 event_tx,
83 outgoing_tx: None,
84 stop_tx: None,
85 }
86 }
87
88 pub fn subscribe(&self) -> broadcast::Receiver<ConnectionEvent> {
90 self.event_tx.subscribe()
91 }
92
93 pub fn is_connected(&self) -> bool {
95 self.connected.load(Ordering::SeqCst)
96 }
97
98 pub async fn send(&self, message: RemoteMessage) -> anyhow::Result<()> {
100 let tx = self
101 .outgoing_tx
102 .as_ref()
103 .ok_or_else(|| anyhow::anyhow!("未连接"))?;
104 tx.send(message).await?;
105 Ok(())
106 }
107
108 pub async fn connect(&mut self) -> anyhow::Result<()> {
110 let mut attempts = 0;
111
112 loop {
113 match self.try_connect().await {
114 Ok(_) => {
115 self.connected.store(true, Ordering::SeqCst);
116 let _ = self.event_tx.send(ConnectionEvent::Connected);
117 return Ok(());
118 }
119 Err(e) => {
120 attempts += 1;
121 if attempts >= self.config.max_reconnect_attempts {
122 let _ = self.event_tx.send(ConnectionEvent::Error(e.to_string()));
123 return Err(e);
124 }
125
126 let _ = self
127 .event_tx
128 .send(ConnectionEvent::Reconnecting { attempt: attempts });
129 tokio::time::sleep(Duration::from_secs(self.config.reconnect_delay)).await;
130 }
131 }
132 }
133 }
134
135 async fn try_connect(&mut self) -> anyhow::Result<()> {
137 let ws_url = self.build_websocket_url()?;
139
140 let (outgoing_tx, outgoing_rx) = mpsc::channel::<RemoteMessage>(100);
142 let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1);
143
144 self.outgoing_tx = Some(outgoing_tx);
145 self.stop_tx = Some(stop_tx);
146
147 let heartbeat_interval = self.config.heartbeat_interval;
149 let session_id = self.config.session_id.clone();
150 let event_tx = self.event_tx.clone();
151
152 let _ = outgoing_rx;
154 let connected = Arc::clone(&self.connected);
155
156 tokio::spawn(async move {
157 let mut ticker = interval(Duration::from_secs(heartbeat_interval));
158
159 loop {
160 tokio::select! {
161 _ = ticker.tick() => {
162 if connected.load(Ordering::SeqCst) {
163 let heartbeat = RemoteMessage {
164 message_type: RemoteMessageType::Heartbeat,
165 id: None,
166 session_id: session_id.clone(),
167 timestamp: chrono::Utc::now().to_rfc3339(),
168 payload: serde_json::json!({}),
169 };
170 let _ = event_tx.send(ConnectionEvent::Message(heartbeat));
171 }
172 }
173 _ = stop_rx.recv() => {
174 break;
175 }
176 }
177 }
178 });
179
180 tracing::info!("连接到 WebSocket: {}", ws_url);
183
184 Ok(())
185 }
186
187 pub async fn disconnect(&mut self) {
189 if let Some(tx) = self.stop_tx.take() {
191 let _ = tx.send(()).await;
192 }
193
194 self.connected.store(false, Ordering::SeqCst);
195 self.outgoing_tx = None;
196
197 let _ = self.event_tx.send(ConnectionEvent::Disconnected);
198 }
199
200 fn build_websocket_url(&self) -> anyhow::Result<String> {
202 let mut url = self.config.url.clone();
203
204 if url.is_empty() {
205 anyhow::bail!("WebSocket URL 为空");
206 }
207
208 if url.starts_with("http://") {
210 url = url.replace("http://", "ws://");
211 } else if url.starts_with("https://") {
212 url = url.replace("https://", "wss://");
213 } else if !url.starts_with("ws://") && !url.starts_with("wss://") {
214 url = format!("wss://{}", url);
215 }
216
217 if !url.contains("/teleport/") {
219 url = format!(
220 "{}/teleport/{}",
221 url.trim_end_matches('/'),
222 self.config.session_id
223 );
224 }
225
226 Ok(url)
227 }
228}
229
230pub async fn connect_to_remote_session(
232 session_id: &str,
233 ingress_url: Option<&str>,
234 auth_token: Option<&str>,
235) -> anyhow::Result<WebSocketManager> {
236 let url = ingress_url
238 .map(|s| s.to_string())
239 .or_else(|| std::env::var("ASTER_TELEPORT_URL").ok())
240 .ok_or_else(|| anyhow::anyhow!("未提供远程服务器 URL"))?;
241
242 let config = ConnectionConfig {
243 url,
244 auth_token: auth_token.map(|s| s.to_string()),
245 session_id: session_id.to_string(),
246 ..Default::default()
247 };
248
249 let mut manager = WebSocketManager::new(config);
250 manager.connect().await?;
251
252 Ok(manager)
253}
254
255pub async fn can_teleport_to_session(_session_id: &str) -> bool {
257 super::validation::get_current_repo_url().await.is_some()
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Config pointing at `url` for session `"test"`, other fields default.
    fn config_for(url: &str) -> ConnectionConfig {
        ConnectionConfig {
            url: url.to_string(),
            session_id: "test".to_string(),
            ..Default::default()
        }
    }

    /// The WebSocket URL a manager configured with `url` would use.
    fn built_url(url: &str) -> anyhow::Result<String> {
        WebSocketManager::new(config_for(url)).build_websocket_url()
    }

    /// Minimal message of the given type for event/send tests.
    fn sample_message(message_type: RemoteMessageType) -> RemoteMessage {
        RemoteMessage {
            message_type,
            id: None,
            session_id: "test".to_string(),
            payload: serde_json::json!({}),
            timestamp: "2026-01-14".to_string(),
        }
    }

    #[test]
    fn test_connection_config_default() {
        let config = ConnectionConfig::default();
        assert!(config.url.is_empty());
        assert!(config.auth_token.is_none());
        assert!(config.session_id.is_empty());
        assert_eq!(config.heartbeat_interval, 30);
        assert_eq!(config.reconnect_delay, 5);
        assert_eq!(config.max_reconnect_attempts, 10);
        assert_eq!(config.connect_timeout, 30);
    }

    #[test]
    fn test_connection_config_custom() {
        let config = ConnectionConfig {
            url: "wss://example.com".to_string(),
            auth_token: Some("token".to_string()),
            session_id: "session-1".to_string(),
            heartbeat_interval: 60,
            reconnect_delay: 10,
            max_reconnect_attempts: 5,
            connect_timeout: 60,
        };
        assert_eq!(config.url, "wss://example.com");
        assert_eq!(config.auth_token.as_deref(), Some("token"));
        assert_eq!(config.session_id, "session-1");
        assert_eq!(config.heartbeat_interval, 60);
        assert_eq!(config.reconnect_delay, 10);
        assert_eq!(config.max_reconnect_attempts, 5);
        assert_eq!(config.connect_timeout, 60);
    }

    #[test]
    fn test_websocket_manager_new() {
        let manager = WebSocketManager::new(config_for("wss://example.com"));
        assert!(!manager.is_connected());
    }

    #[tokio::test]
    async fn test_websocket_manager_subscribe() {
        let mut manager = WebSocketManager::new(ConnectionConfig::default());
        let mut rx = manager.subscribe();
        manager.disconnect().await;
        assert!(matches!(rx.try_recv(), Ok(ConnectionEvent::Disconnected)));
    }

    #[test]
    fn test_websocket_manager_is_connected() {
        let manager = WebSocketManager::new(ConnectionConfig::default());
        assert!(!manager.is_connected());
    }

    #[test]
    fn test_websocket_manager_build_url_http() {
        assert_eq!(
            built_url("http://example.com").unwrap(),
            "ws://example.com/teleport/test"
        );
    }

    #[test]
    fn test_websocket_manager_build_url_https() {
        assert_eq!(
            built_url("https://example.com").unwrap(),
            "wss://example.com/teleport/test"
        );
    }

    #[test]
    fn test_websocket_manager_build_url_ws() {
        assert_eq!(
            built_url("ws://example.com").unwrap(),
            "ws://example.com/teleport/test"
        );
    }

    #[test]
    fn test_websocket_manager_build_url_no_protocol() {
        assert_eq!(
            built_url("example.com").unwrap(),
            "wss://example.com/teleport/test"
        );
    }

    #[test]
    fn test_websocket_manager_build_url_trailing_slash() {
        assert_eq!(
            built_url("https://example.com/").unwrap(),
            "wss://example.com/teleport/test"
        );
    }

    #[test]
    fn test_websocket_manager_build_url_empty() {
        assert!(built_url("").is_err());
    }

    #[test]
    fn test_websocket_manager_build_url_with_teleport_path() {
        assert_eq!(
            built_url("wss://example.com/teleport/existing").unwrap(),
            "wss://example.com/teleport/existing"
        );
    }

    #[test]
    fn test_connection_event_variants() {
        assert_eq!(format!("{:?}", ConnectionEvent::Connected), "Connected");
        assert_eq!(format!("{:?}", ConnectionEvent::Disconnected), "Disconnected");
        assert_eq!(
            format!("{:?}", ConnectionEvent::Reconnecting { attempt: 1 }),
            "Reconnecting { attempt: 1 }"
        );
        assert_eq!(
            format!("{:?}", ConnectionEvent::Error("error".to_string())),
            "Error(\"error\")"
        );

        let event = ConnectionEvent::Message(sample_message(RemoteMessageType::Heartbeat));
        match event.clone() {
            ConnectionEvent::Message(msg) => assert_eq!(msg.session_id, "test"),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_can_teleport_to_session() {
        // Smoke test only: the result depends on the repository the tests run in.
        let can = can_teleport_to_session("test-session").await;
        println!("Can teleport: {}", can);
    }

    #[tokio::test]
    async fn test_websocket_manager_send_not_connected() {
        let manager = WebSocketManager::new(ConnectionConfig::default());
        let result = manager.send(sample_message(RemoteMessageType::Message)).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_connect_fails_on_empty_url() {
        // A single attempt keeps the test free of reconnect sleeps.
        let config = ConnectionConfig {
            max_reconnect_attempts: 1,
            ..Default::default()
        };
        let mut manager = WebSocketManager::new(config);
        let mut rx = manager.subscribe();

        assert!(manager.connect().await.is_err());
        assert!(!manager.is_connected());
        assert!(matches!(rx.try_recv(), Ok(ConnectionEvent::Error(_))));
    }

    #[tokio::test]
    async fn test_connect_and_disconnect() {
        let mut manager = WebSocketManager::new(config_for("wss://example.com"));
        let mut rx = manager.subscribe();

        manager.connect().await.unwrap();
        assert!(manager.is_connected());
        assert!(matches!(rx.try_recv(), Ok(ConnectionEvent::Connected)));

        manager.disconnect().await;
        assert!(!manager.is_connected());
    }
}