Skip to main content

aster/teleport/
connection.rs

1//! WebSocket 连接管理
2//!
3//! 提供 WebSocket 连接、心跳、断线重连等功能
4
5use super::types::*;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{broadcast, mpsc};
10use tokio::time::interval;
11
/// Settings used by the WebSocket connection manager.
///
/// All durations are expressed in whole seconds.
#[derive(Debug, Clone)]
pub struct ConnectionConfig {
    /// Server URL; may use an `http(s)://` or `ws(s)://` scheme, or no scheme at all.
    pub url: String,
    /// Optional authentication token.
    pub auth_token: Option<String>,
    /// Identifier of the session this connection belongs to.
    pub session_id: String,
    /// Heartbeat interval, in seconds.
    pub heartbeat_interval: u64,
    /// Delay between reconnect attempts, in seconds.
    pub reconnect_delay: u64,
    /// Maximum number of reconnect attempts.
    pub max_reconnect_attempts: u32,
    /// Connection timeout, in seconds.
    pub connect_timeout: u64,
}

impl Default for ConnectionConfig {
    /// Returns a configuration with an empty URL and session id, no token,
    /// and the stock timing values named below.
    fn default() -> Self {
        const HEARTBEAT_SECS: u64 = 30;
        const RECONNECT_DELAY_SECS: u64 = 5;
        const MAX_RECONNECT_ATTEMPTS: u32 = 10;
        const CONNECT_TIMEOUT_SECS: u64 = 30;

        ConnectionConfig {
            url: String::default(),
            auth_token: Option::None,
            session_id: String::default(),
            heartbeat_interval: HEARTBEAT_SECS,
            reconnect_delay: RECONNECT_DELAY_SECS,
            max_reconnect_attempts: MAX_RECONNECT_ATTEMPTS,
            connect_timeout: CONNECT_TIMEOUT_SECS,
        }
    }
}
44
45/// 连接事件
46#[derive(Debug, Clone)]
47pub enum ConnectionEvent {
48    /// 已连接
49    Connected,
50    /// 已断开
51    Disconnected,
52    /// 重连中
53    Reconnecting { attempt: u32 },
54    /// 收到消息
55    Message(RemoteMessage),
56    /// 错误
57    Error(String),
58}
59
60/// WebSocket 连接管理器
61pub struct WebSocketManager {
62    /// 配置
63    config: ConnectionConfig,
64    /// 是否已连接
65    connected: Arc<AtomicBool>,
66    /// 事件发送器
67    event_tx: broadcast::Sender<ConnectionEvent>,
68    /// 消息发送通道
69    outgoing_tx: Option<mpsc::Sender<RemoteMessage>>,
70    /// 停止信号
71    stop_tx: Option<mpsc::Sender<()>>,
72}
73
74impl WebSocketManager {
75    /// 创建新的连接管理器
76    pub fn new(config: ConnectionConfig) -> Self {
77        let (event_tx, _) = broadcast::channel(100);
78
79        Self {
80            config,
81            connected: Arc::new(AtomicBool::new(false)),
82            event_tx,
83            outgoing_tx: None,
84            stop_tx: None,
85        }
86    }
87
88    /// 订阅事件
89    pub fn subscribe(&self) -> broadcast::Receiver<ConnectionEvent> {
90        self.event_tx.subscribe()
91    }
92
93    /// 是否已连接
94    pub fn is_connected(&self) -> bool {
95        self.connected.load(Ordering::SeqCst)
96    }
97
98    /// 发送消息
99    pub async fn send(&self, message: RemoteMessage) -> anyhow::Result<()> {
100        let tx = self
101            .outgoing_tx
102            .as_ref()
103            .ok_or_else(|| anyhow::anyhow!("未连接"))?;
104        tx.send(message).await?;
105        Ok(())
106    }
107
108    /// 连接(带重连逻辑)
109    pub async fn connect(&mut self) -> anyhow::Result<()> {
110        let mut attempts = 0;
111
112        loop {
113            match self.try_connect().await {
114                Ok(_) => {
115                    self.connected.store(true, Ordering::SeqCst);
116                    let _ = self.event_tx.send(ConnectionEvent::Connected);
117                    return Ok(());
118                }
119                Err(e) => {
120                    attempts += 1;
121                    if attempts >= self.config.max_reconnect_attempts {
122                        let _ = self.event_tx.send(ConnectionEvent::Error(e.to_string()));
123                        return Err(e);
124                    }
125
126                    let _ = self
127                        .event_tx
128                        .send(ConnectionEvent::Reconnecting { attempt: attempts });
129                    tokio::time::sleep(Duration::from_secs(self.config.reconnect_delay)).await;
130                }
131            }
132        }
133    }
134
135    /// 尝试连接一次
136    async fn try_connect(&mut self) -> anyhow::Result<()> {
137        // 构建 WebSocket URL
138        let ws_url = self.build_websocket_url()?;
139
140        // 创建通道
141        let (outgoing_tx, outgoing_rx) = mpsc::channel::<RemoteMessage>(100);
142        let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1);
143
144        self.outgoing_tx = Some(outgoing_tx);
145        self.stop_tx = Some(stop_tx);
146
147        // 启动心跳任务
148        let heartbeat_interval = self.config.heartbeat_interval;
149        let session_id = self.config.session_id.clone();
150        let event_tx = self.event_tx.clone();
151
152        // 标记 outgoing_rx 为使用(实际连接逻辑待实现)
153        let _ = outgoing_rx;
154        let connected = Arc::clone(&self.connected);
155
156        tokio::spawn(async move {
157            let mut ticker = interval(Duration::from_secs(heartbeat_interval));
158
159            loop {
160                tokio::select! {
161                    _ = ticker.tick() => {
162                        if connected.load(Ordering::SeqCst) {
163                            let heartbeat = RemoteMessage {
164                                message_type: RemoteMessageType::Heartbeat,
165                                id: None,
166                                session_id: session_id.clone(),
167                                timestamp: chrono::Utc::now().to_rfc3339(),
168                                payload: serde_json::json!({}),
169                            };
170                            let _ = event_tx.send(ConnectionEvent::Message(heartbeat));
171                        }
172                    }
173                    _ = stop_rx.recv() => {
174                        break;
175                    }
176                }
177            }
178        });
179
180        // TODO: 实际的 WebSocket 连接
181        // 这里是框架代码,实际需要使用 tokio-tungstenite
182        tracing::info!("连接到 WebSocket: {}", ws_url);
183
184        Ok(())
185    }
186
187    /// 断开连接
188    pub async fn disconnect(&mut self) {
189        // 发送停止信号
190        if let Some(tx) = self.stop_tx.take() {
191            let _ = tx.send(()).await;
192        }
193
194        self.connected.store(false, Ordering::SeqCst);
195        self.outgoing_tx = None;
196
197        let _ = self.event_tx.send(ConnectionEvent::Disconnected);
198    }
199
200    /// 构建 WebSocket URL
201    fn build_websocket_url(&self) -> anyhow::Result<String> {
202        let mut url = self.config.url.clone();
203
204        if url.is_empty() {
205            anyhow::bail!("WebSocket URL 为空");
206        }
207
208        // 转换协议
209        if url.starts_with("http://") {
210            url = url.replace("http://", "ws://");
211        } else if url.starts_with("https://") {
212            url = url.replace("https://", "wss://");
213        } else if !url.starts_with("ws://") && !url.starts_with("wss://") {
214            url = format!("wss://{}", url);
215        }
216
217        // 添加会话路径
218        if !url.contains("/teleport/") {
219            url = format!(
220                "{}/teleport/{}",
221                url.trim_end_matches('/'),
222                self.config.session_id
223            );
224        }
225
226        Ok(url)
227    }
228}
229
230/// 便捷函数:连接到远程会话
231pub async fn connect_to_remote_session(
232    session_id: &str,
233    ingress_url: Option<&str>,
234    auth_token: Option<&str>,
235) -> anyhow::Result<WebSocketManager> {
236    // 从环境变量获取 URL
237    let url = ingress_url
238        .map(|s| s.to_string())
239        .or_else(|| std::env::var("ASTER_TELEPORT_URL").ok())
240        .ok_or_else(|| anyhow::anyhow!("未提供远程服务器 URL"))?;
241
242    let config = ConnectionConfig {
243        url,
244        auth_token: auth_token.map(|s| s.to_string()),
245        session_id: session_id.to_string(),
246        ..Default::default()
247    };
248
249    let mut manager = WebSocketManager::new(config);
250    manager.connect().await?;
251
252    Ok(manager)
253}
254
255/// 检查会话是否可以进行 teleport
256pub async fn can_teleport_to_session(_session_id: &str) -> bool {
257    // 检查是否在 git 仓库中
258    super::validation::get_current_repo_url().await.is_some()
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a manager whose config has the given URL and the session id "test".
    fn manager_for(url: &str) -> WebSocketManager {
        WebSocketManager::new(ConnectionConfig {
            url: url.to_string(),
            session_id: "test".to_string(),
            ..Default::default()
        })
    }

    /// Builds a message of the requested type with fixed placeholder contents.
    fn message_of(message_type: RemoteMessageType) -> RemoteMessage {
        RemoteMessage {
            message_type,
            id: None,
            session_id: "test".to_string(),
            payload: serde_json::json!({}),
            timestamp: "2026-01-14".to_string(),
        }
    }

    #[test]
    fn test_connection_config_default() {
        let ConnectionConfig {
            url,
            auth_token,
            session_id,
            heartbeat_interval,
            reconnect_delay,
            max_reconnect_attempts,
            connect_timeout,
        } = ConnectionConfig::default();

        assert!(url.is_empty());
        assert!(auth_token.is_none());
        assert!(session_id.is_empty());
        assert_eq!(heartbeat_interval, 30);
        assert_eq!(reconnect_delay, 5);
        assert_eq!(max_reconnect_attempts, 10);
        assert_eq!(connect_timeout, 30);
    }

    #[test]
    fn test_connection_config_custom() {
        let custom = ConnectionConfig {
            url: "wss://example.com".to_string(),
            auth_token: Some("token".to_string()),
            session_id: "session-1".to_string(),
            heartbeat_interval: 60,
            reconnect_delay: 10,
            max_reconnect_attempts: 5,
            connect_timeout: 60,
        };
        assert_eq!(custom.url, "wss://example.com");
        assert_eq!(custom.heartbeat_interval, 60);
    }

    #[test]
    fn test_websocket_manager_new() {
        let manager = manager_for("wss://example.com");
        assert!(!manager.is_connected());
    }

    #[test]
    fn test_websocket_manager_subscribe() {
        let manager = WebSocketManager::new(ConnectionConfig::default());
        // Subscribing must succeed even before any connection exists.
        let _rx = manager.subscribe();
    }

    #[test]
    fn test_websocket_manager_is_connected() {
        let manager = WebSocketManager::new(ConnectionConfig::default());
        assert!(!manager.is_connected());
    }

    #[test]
    fn test_websocket_manager_build_url_http() {
        let built = manager_for("http://example.com")
            .build_websocket_url()
            .unwrap();
        assert!(built.starts_with("ws://"));
        assert!(built.contains("/teleport/test"));
    }

    #[test]
    fn test_websocket_manager_build_url_https() {
        let built = manager_for("https://example.com")
            .build_websocket_url()
            .unwrap();
        assert!(built.starts_with("wss://"));
    }

    #[test]
    fn test_websocket_manager_build_url_ws() {
        let built = manager_for("ws://example.com")
            .build_websocket_url()
            .unwrap();
        assert!(built.starts_with("ws://"));
    }

    #[test]
    fn test_websocket_manager_build_url_no_protocol() {
        let built = manager_for("example.com").build_websocket_url().unwrap();
        assert!(built.starts_with("wss://"));
    }

    #[test]
    fn test_websocket_manager_build_url_empty() {
        let outcome = manager_for("").build_websocket_url();
        assert!(outcome.is_err());
    }

    #[test]
    fn test_websocket_manager_build_url_with_teleport_path() {
        let built = manager_for("wss://example.com/teleport/existing")
            .build_websocket_url()
            .unwrap();
        // The /teleport/ segment must not be appended a second time.
        assert!(!built.contains("/teleport/test"));
    }

    #[test]
    fn test_connection_event_variants() {
        let events = [
            ConnectionEvent::Connected,
            ConnectionEvent::Disconnected,
            ConnectionEvent::Reconnecting { attempt: 1 },
            ConnectionEvent::Message(message_of(RemoteMessageType::Heartbeat)),
            ConnectionEvent::Error("error".to_string()),
        ];
        assert_eq!(events.len(), 5);
    }

    #[tokio::test]
    async fn test_can_teleport_to_session() {
        let can = can_teleport_to_session("test-session").await;
        // Expected to be true when run inside a git repository; only printed, not asserted.
        println!("Can teleport: {}", can);
    }

    #[tokio::test]
    async fn test_websocket_manager_send_not_connected() {
        let manager = WebSocketManager::new(ConnectionConfig::default());
        let outcome = manager
            .send(message_of(RemoteMessageType::Message))
            .await;
        assert!(outcome.is_err());
    }
}