// ping_example/ping_example.rs

use mcprotocol_rs::{
    transport::{ClientTransportFactory, ServerTransportFactory, TransportConfig, TransportType},
    Message, Method, Notification, Request, RequestId, Result,
};
use std::{collections::HashSet, time::Duration};
use tokio::{
    self,
    sync::oneshot,
    time::{sleep, timeout, timeout_at},
};

// Timing and network settings shared by the example client and server.

/// Nominal interval between client pings. `run_client` caps the actual pause
/// between pings at 30 seconds, so pings are sent more often than this.
const PING_INTERVAL: Duration = Duration::from_secs(60);
/// How long the client waits for a ping's response (and for the shutdown response).
const PING_TIMEOUT: Duration = Duration::from_secs(2);
/// Upper bound on the client's `initialize()` call.
const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
/// Safety-net lifetime of the server (5 minutes); the client also ends its
/// session early when it gets within 30 seconds of this.
const SERVER_TIMEOUT: Duration = Duration::from_secs(300);
/// Port reported in the server's start-up log. Must stay in sync with the port
/// embedded in `SERVER_URL`, which is written out separately.
const SERVER_PORT: u16 = 3000;
/// Address passed as `base_url` to the server transport; the client prefixes
/// it with `http://`.
const SERVER_URL: &str = "127.0.0.1:3000";
/// Auth token configured on both the server and the client transport.
const AUTH_TOKEN: &str = "test-auth-token";

18#[tokio::main]
19async fn main() -> Result<()> {
20    // 启动服务器
21    // Start server
22    let server_handle = tokio::spawn(run_server());
23
24    // 等待服务器启动
25    // Wait for server to start
26    sleep(Duration::from_millis(100)).await;
27
28    // 启动客户端
29    // Start client
30    let client_handle = tokio::spawn(run_client());
31
32    // 等待客户端和服务器完成
33    // Wait for client and server to complete
34    match tokio::try_join!(server_handle, client_handle) {
35        Ok((server_result, client_result)) => {
36            server_result?;
37            client_result?;
38            Ok(())
39        }
40        Err(e) => {
41            eprintln!("Error in task execution: {}", e);
42            Err(mcprotocol_rs::Error::Transport(e.to_string()))
43        }
44    }
45}
46
47async fn run_server() -> Result<()> {
48    // 配置服务器
49    // Configure server
50    let config = TransportConfig {
51        transport_type: TransportType::Http {
52            base_url: SERVER_URL.to_string(),
53            auth_token: Some(AUTH_TOKEN.to_string()),
54        },
55        parameters: None,
56    };
57
58    // 创建服务器实例
59    // Create server instance
60    let factory = ServerTransportFactory;
61    let mut server = factory.create(config)?;
62
63    // 初始化服务器
64    // Initialize server
65    server.initialize().await?;
66    eprintln!(
67        "Server started and waiting for ping requests on port {}",
68        SERVER_PORT
69    );
70
71    // 等待退出信号或超时
72    // Wait for exit signal or timeout
73    let (tx, mut rx) = tokio::sync::oneshot::channel::<()>();
74
75    let exit_signal = async move {
76        rx.await.ok();
77    };
78
79    tokio::select! {
80        _ = exit_signal => {
81            eprintln!("Server received exit signal");
82        }
83        _ = tokio::time::sleep(SERVER_TIMEOUT) => {
84            eprintln!("Server timeout after {} seconds", SERVER_TIMEOUT.as_secs());
85        }
86    }
87
88    server.close().await?;
89    eprintln!("Server stopped");
90    Ok(())
91}
92
93async fn run_client() -> Result<()> {
94    // 跟踪会话中使用的请求 ID
95    // Track request IDs used in the session
96    let mut session_ids = HashSet::new();
97    let mut ping_count = 0;
98    let total_pings = 3;
99
100    // 配置客户端
101    // Configure client
102    let config = TransportConfig {
103        transport_type: TransportType::Http {
104            base_url: format!("http://{}", SERVER_URL),
105            auth_token: Some(AUTH_TOKEN.to_string()),
106        },
107        parameters: None,
108    };
109
110    // 创建客户端实例
111    // Create client instance
112    let factory = ClientTransportFactory;
113    let mut client = factory.create(config)?;
114
115    // 初始化客户端
116    // Initialize client
117    match timeout(CONNECTION_TIMEOUT, client.initialize()).await {
118        Ok(result) => result?,
119        Err(_) => {
120            return Err(mcprotocol_rs::Error::Transport(
121                "Client initialization timeout".into(),
122            ))
123        }
124    }
125    eprintln!("Client started");
126
127    // 发送 ping 请求并保持连接活跃
128    // Send ping requests and keep connection alive
129    let start_time = std::time::Instant::now();
130
131    while ping_count < total_pings {
132        // 检查是否接近服务器超时时间
133        // Check if approaching server timeout
134        if start_time.elapsed() > SERVER_TIMEOUT - Duration::from_secs(30) {
135            eprintln!("Approaching server timeout, ending session");
136            break;
137        }
138
139        // 发送 ping 请求
140        // Send ping request
141        let request_id = RequestId::String(format!("ping-{}", ping_count + 1));
142        let ping_request = Request::new(Method::Ping, None, request_id.clone());
143
144        // 验证请求 ID 的唯一性
145        // Validate request ID uniqueness
146        if !ping_request.validate_id_uniqueness(&mut session_ids) {
147            eprintln!("Request ID has already been used in this session");
148            break;
149        }
150
151        eprintln!("Sending ping request #{}", ping_count + 1);
152        client.send(Message::Request(ping_request.clone())).await?;
153
154        // 等待 pong 响应,带超时
155        // Wait for pong response with timeout
156        match timeout(PING_TIMEOUT, client.receive()).await {
157            Ok(Ok(Message::Response(response))) => {
158                if !request_id_matches(&request_id, &response.id) {
159                    eprintln!(
160                        "Received response with mismatched ID: expected {}, got {}",
161                        request_id_to_string(&request_id),
162                        request_id_to_string(&response.id)
163                    );
164                    continue;
165                }
166
167                if response.error.is_some() {
168                    eprintln!("Received error response: {:?}", response.error);
169                    break;
170                }
171                eprintln!("Received pong response #{}", ping_count + 1);
172            }
173            Ok(Ok(message)) => {
174                eprintln!("Unexpected message type: {:?}", message);
175                continue;
176            }
177            Ok(Err(e)) => {
178                eprintln!("Error receiving response: {}", e);
179                break;
180            }
181            Err(_) => {
182                eprintln!("Ping timeout for request #{}", ping_count + 1);
183                break;
184            }
185        }
186
187        ping_count += 1;
188        if ping_count < total_pings {
189            // 使用较短的间隔以避免服务器超时
190            // Use shorter interval to avoid server timeout
191            sleep(PING_INTERVAL.min(Duration::from_secs(30))).await;
192        }
193    }
194
195    // 发送关闭请求
196    // Send shutdown request
197    if ping_count == total_pings {
198        let shutdown_request = Request::new(
199            Method::Shutdown,
200            None,
201            RequestId::String("shutdown".to_string()),
202        );
203
204        if shutdown_request.validate_id_uniqueness(&mut session_ids) {
205            client.send(Message::Request(shutdown_request)).await?;
206
207            // 等待关闭响应
208            // Wait for shutdown response
209            match timeout(PING_TIMEOUT, client.receive()).await {
210                Ok(Ok(Message::Response(response))) => {
211                    if response.error.is_some() {
212                        eprintln!("Shutdown failed: {:?}", response.error);
213                    } else {
214                        // 发送退出通知
215                        // Send exit notification
216                        let exit_notification = Notification::new(Method::Exit, None);
217                        client
218                            .send(Message::Notification(exit_notification))
219                            .await?;
220                    }
221                }
222                Ok(Ok(_)) => eprintln!("Unexpected response type"),
223                Ok(Err(e)) => eprintln!("Error receiving shutdown response: {}", e),
224                Err(_) => eprintln!("Shutdown response timeout"),
225            }
226        } else {
227            eprintln!("Shutdown request ID has already been used in this session");
228        }
229    }
230
231    client.close().await?;
232    eprintln!("Client stopped");
233    Ok(())
234}
235
236// 辅助函数:检查请求 ID 是否匹配
237// Helper function: Check if request ID matches
238fn request_id_matches(request_id: &RequestId, response_id: &RequestId) -> bool {
239    match (request_id, response_id) {
240        (RequestId::String(req), RequestId::String(res)) => req == res,
241        (RequestId::Number(req), RequestId::Number(res)) => req == res,
242        _ => false,
243    }
244}
245
246// 辅助函数:将请求 ID 转换为字符串
247// Helper function: Convert request ID to string
248fn request_id_to_string(id: &RequestId) -> String {
249    match id {
250        RequestId::String(s) => s.clone(),
251        RequestId::Number(n) => n.to_string(),
252    }
253}