1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
4use tokio::process::{ChildStdin, ChildStdout};
5use tracing::{debug, info};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ProtocolMessage {
10 pub id: Option<u64>,
11 pub jsonrpc: String,
12 pub method: Option<String>,
13 pub params: Option<serde_json::Value>,
14 pub result: Option<serde_json::Value>,
15 pub error: Option<ProtocolError>,
16}
17
/// Error payload carried in the `error` member of a [`ProtocolMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtocolError {
    /// Numeric error code (this file's tests use `-32600` and `-32601`).
    pub code: i64,
    /// Short human-readable description of the failure.
    pub message: String,
    /// Optional additional detail about the failure, as arbitrary JSON.
    pub data: Option<serde_json::Value>,
}
25
26pub struct StdinTransport {
28 stdin: ChildStdin,
29 stdout: ChildStdout,
30}
31
32impl StdinTransport {
33 pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
34 StdinTransport { stdin, stdout }
35 }
36
37 pub async fn send(&mut self, message: &ProtocolMessage) -> Result<()> {
39 let json = serde_json::to_string(message)?;
40 debug!("Sending to CLI: {}", json);
41
42 self.stdin
43 .write_all(format!("{}\n", json).as_bytes())
44 .await?;
45 self.stdin.flush().await?;
46
47 Ok(())
48 }
49
50 pub async fn receive(&mut self) -> Result<Option<ProtocolMessage>> {
52 let mut reader = BufReader::new(&mut self.stdout);
53 let mut line = String::new();
54
55 let bytes_read = reader.read_line(&mut line).await?;
56
57 if bytes_read == 0 {
58 return Ok(None); }
60
61 debug!("Received from CLI: {}", line.trim());
62
63 let message: ProtocolMessage = serde_json::from_str(&line)?;
64 Ok(Some(message))
65 }
66
67 pub async fn close(&mut self) -> Result<()> {
69 info!("Closing stdin transport");
70 self.stdin.shutdown().await?;
71 Ok(())
72 }
73}
74
75pub fn create_request(id: u64, method: &str, params: Option<serde_json::Value>) -> ProtocolMessage {
77 ProtocolMessage {
78 id: Some(id),
79 jsonrpc: "2.0".to_string(),
80 method: Some(method.to_string()),
81 params,
82 result: None,
83 error: None,
84 }
85}
86
87pub fn create_response(id: u64, result: serde_json::Value) -> ProtocolMessage {
89 ProtocolMessage {
90 id: Some(id),
91 jsonrpc: "2.0".to_string(),
92 method: None,
93 params: None,
94 result: Some(result),
95 error: None,
96 }
97}
98
99pub fn create_error(id: u64, code: i64, message: &str) -> ProtocolMessage {
101 ProtocolMessage {
102 id: Some(id),
103 jsonrpc: "2.0".to_string(),
104 method: None,
105 params: None,
106 result: None,
107 error: Some(ProtocolError {
108 code,
109 message: message.to_string(),
110 data: None,
111 }),
112 }
113}
114
115pub fn create_notification(method: &str, params: Option<serde_json::Value>) -> ProtocolMessage {
117 ProtocolMessage {
118 id: None,
119 jsonrpc: "2.0".to_string(),
120 method: Some(method.to_string()),
121 params,
122 result: None,
123 error: None,
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// A request carries id, version, method and params, and nothing else.
    #[test]
    fn test_create_request() {
        let params = serde_json::json!({"version": "1.0"});
        let request = create_request(1, "initialize", Some(params));

        assert_eq!(request.id, Some(1));
        assert_eq!(request.jsonrpc, "2.0");
        assert_eq!(request.method.as_deref(), Some("initialize"));
        assert!(request.params.is_some());
        assert!(request.result.is_none());
        assert!(request.error.is_none());
    }

    /// A success response carries id, version and result, without method or error.
    #[test]
    fn test_create_response() {
        let response = create_response(1, serde_json::json!({"status": "ok"}));

        assert_eq!(response.id, Some(1));
        assert_eq!(response.jsonrpc, "2.0");
        assert!(response.method.is_none());
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    /// An error response carries id, version and the error object only.
    #[test]
    fn test_create_error() {
        let response = create_error(1, -32600, "Invalid Request");

        assert_eq!(response.id, Some(1));
        assert_eq!(response.jsonrpc, "2.0");
        assert!(response.method.is_none());
        assert!(response.result.is_none());

        // `expect` doubles as the presence check for the error object.
        let failure = response.error.expect("error object should be set");
        assert_eq!(failure.code, -32600);
        assert_eq!(failure.message, "Invalid Request");
    }

    /// A notification looks like a request but has no id.
    #[test]
    fn test_create_notification() {
        let params = serde_json::json!({"progress": 50});
        let notification = create_notification("update", Some(params));

        assert_eq!(notification.id, None);
        assert_eq!(notification.jsonrpc, "2.0");
        assert_eq!(notification.method.as_deref(), Some("update"));
        assert!(notification.params.is_some());
        assert!(notification.result.is_none());
        assert!(notification.error.is_none());
    }

    /// Serialized output contains the id, version and method members.
    #[test]
    fn test_protocol_message_serialization() {
        let request = create_request(42, "test_method", None);
        let encoded = serde_json::to_string(&request).unwrap();

        let expected_fragments = [
            "\"id\":42",
            "\"jsonrpc\":\"2.0\"",
            "\"method\":\"test_method\"",
        ];
        for fragment in expected_fragments.iter() {
            assert!(encoded.contains(*fragment));
        }
    }

    /// A JSON request without `result`/`error` members decodes successfully.
    #[test]
    fn test_protocol_message_deserialization() {
        let raw = r#"{
            "id": 1,
            "jsonrpc": "2.0",
            "method": "test",
            "params": {"key": "value"}
        }"#;

        let decoded: ProtocolMessage = serde_json::from_str(raw).unwrap();

        assert_eq!(decoded.id, Some(1));
        assert_eq!(decoded.jsonrpc, "2.0");
        assert_eq!(decoded.method.as_deref(), Some("test"));
        assert!(decoded.params.is_some());
    }

    /// `ProtocolError` can be built directly, including the optional `data`.
    #[test]
    fn test_protocol_error_structure() {
        let details = serde_json::json!({"details": "unknown method"});
        let failure = ProtocolError {
            code: -32601,
            message: String::from("Method not found"),
            data: Some(details),
        };

        assert_eq!(failure.code, -32601);
        assert_eq!(failure.message, "Method not found");
        assert!(failure.data.is_some());
    }

    /// The derived `Debug` output names the type and its fields.
    #[test]
    fn test_protocol_message_debug_format() {
        let request = create_request(1, "init", None);
        let rendered = format!("{:?}", request);

        assert!(rendered.contains("ProtocolMessage"));
        assert!(rendered.contains("jsonrpc"));
    }
}