// stdio_server/stdio_server.rs
use std::collections::HashSet;

use mcprotocol_rs::message;
use mcprotocol_rs::{
    protocol::{Message, Response},
    transport::{ServerTransportFactory, TransportConfig, TransportType},
    Result,
};
use serde_json::json;

10#[tokio::main]
11async fn main() -> Result<()> {
12 // 跟踪会话中使用的请求 ID
13 // Track request IDs used in the session
14 let mut session_ids = HashSet::new();
15
16 // 配置 Stdio 服务器
17 // Configure Stdio server
18 let config = TransportConfig {
19 transport_type: TransportType::Stdio {
20 server_path: None,
21 server_args: None,
22 },
23 parameters: None,
24 };
25
26 // 创建服务器实例
27 // Create server instance
28 let factory = ServerTransportFactory;
29 let mut server = factory.create(config)?;
30
31 // 初始化服务器
32 // Initialize server
33 server.initialize().await?;
34 eprintln!("Server initialized and ready to receive messages...");
35
36 // 持续接收和处理消息
37 // Continuously receive and process messages
38 loop {
39 match server.receive().await {
40 Ok(message) => {
41 eprintln!("Received message: {:?}", message);
42
43 // 根据消息类型处理
44 // Process messages based on type
45 match message {
46 Message::Request(request) => {
47 // 验证请求 ID 的唯一性
48 // Validate request ID uniqueness
49 if !request.validate_id_uniqueness(&mut session_ids) {
50 let error = Message::Response(Response::error(
51 message::ResponseError {
52 code: message::error_codes::INVALID_REQUEST,
53 message: "Request ID has already been used".to_string(),
54 data: None,
55 },
56 request.id,
57 ));
58 if let Err(e) = server.send(error).await {
59 eprintln!("Error sending error response: {}", e);
60 break;
61 }
62 continue;
63 }
64
65 match request.method.as_str() {
66 "prompts/execute" => {
67 // 创建响应消息
68 // Create response message
69 let response = Message::Response(Response::success(
70 json!({
71 "content": "Hello from server!",
72 "role": "assistant"
73 }),
74 request.id,
75 ));
76
77 // 发送响应
78 // Send response
79 if let Err(e) = server.send(response).await {
80 eprintln!("Error sending response: {}", e);
81 break;
82 }
83 }
84 _ => {
85 eprintln!("Unknown method: {}", request.method);
86 let error = Message::Response(Response::error(
87 message::ResponseError {
88 code: message::error_codes::METHOD_NOT_FOUND,
89 message: "Method not found".to_string(),
90 data: None,
91 },
92 request.id,
93 ));
94 if let Err(e) = server.send(error).await {
95 eprintln!("Error sending error response: {}", e);
96 break;
97 }
98 }
99 }
100 }
101 _ => {
102 eprintln!("Unexpected message type");
103 }
104 }
105 }
106 Err(e) => {
107 eprintln!("Error receiving message: {}", e);
108 break;
109 }
110 }
111 }
112
113 // 关闭服务器
114 // Close server
115 server.close().await?;
116 Ok(())
117}