stdio_server/
stdio_server.rs

1use mcprotocol_rs::message;
2use mcprotocol_rs::{
3    protocol::{Message, Response},
4    transport::{ServerTransportFactory, TransportConfig, TransportType},
5    Result,
6};
7use serde_json::json;
8use std::collections::HashSet;
9
10#[tokio::main]
11async fn main() -> Result<()> {
12    // 跟踪会话中使用的请求 ID
13    // Track request IDs used in the session
14    let mut session_ids = HashSet::new();
15
16    // 配置 Stdio 服务器
17    // Configure Stdio server
18    let config = TransportConfig {
19        transport_type: TransportType::Stdio {
20            server_path: None,
21            server_args: None,
22        },
23        parameters: None,
24    };
25
26    // 创建服务器实例
27    // Create server instance
28    let factory = ServerTransportFactory;
29    let mut server = factory.create(config)?;
30
31    // 初始化服务器
32    // Initialize server
33    server.initialize().await?;
34    eprintln!("Server initialized and ready to receive messages...");
35
36    // 持续接收和处理消息
37    // Continuously receive and process messages
38    loop {
39        match server.receive().await {
40            Ok(message) => {
41                eprintln!("Received message: {:?}", message);
42
43                // 根据消息类型处理
44                // Process messages based on type
45                match message {
46                    Message::Request(request) => {
47                        // 验证请求 ID 的唯一性
48                        // Validate request ID uniqueness
49                        if !request.validate_id_uniqueness(&mut session_ids) {
50                            let error = Message::Response(Response::error(
51                                message::ResponseError {
52                                    code: message::error_codes::INVALID_REQUEST,
53                                    message: "Request ID has already been used".to_string(),
54                                    data: None,
55                                },
56                                request.id,
57                            ));
58                            if let Err(e) = server.send(error).await {
59                                eprintln!("Error sending error response: {}", e);
60                                break;
61                            }
62                            continue;
63                        }
64
65                        match request.method.as_str() {
66                            "prompts/execute" => {
67                                // 创建响应消息
68                                // Create response message
69                                let response = Message::Response(Response::success(
70                                    json!({
71                                        "content": "Hello from server!",
72                                        "role": "assistant"
73                                    }),
74                                    request.id,
75                                ));
76
77                                // 发送响应
78                                // Send response
79                                if let Err(e) = server.send(response).await {
80                                    eprintln!("Error sending response: {}", e);
81                                    break;
82                                }
83                            }
84                            _ => {
85                                eprintln!("Unknown method: {}", request.method);
86                                let error = Message::Response(Response::error(
87                                    message::ResponseError {
88                                        code: message::error_codes::METHOD_NOT_FOUND,
89                                        message: "Method not found".to_string(),
90                                        data: None,
91                                    },
92                                    request.id,
93                                ));
94                                if let Err(e) = server.send(error).await {
95                                    eprintln!("Error sending error response: {}", e);
96                                    break;
97                                }
98                            }
99                        }
100                    }
101                    _ => {
102                        eprintln!("Unexpected message type");
103                    }
104                }
105            }
106            Err(e) => {
107                eprintln!("Error receiving message: {}", e);
108                break;
109            }
110        }
111    }
112
113    // 关闭服务器
114    // Close server
115    server.close().await?;
116    Ok(())
117}