1use crate::error::{Error, Result};
28use crate::messages::{Event, Request, Response};
29use serde::{Deserialize, Serialize};
30use std::io::{BufRead, BufReader, Write};
31use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader as AsyncBufReader};
32use tracing::debug;
33
34pub struct Protocol;
36
37impl Protocol {
38 pub fn serialize<T: Serialize>(message: &T) -> Result<String> {
40 let json = serde_json::to_string(message)?;
41 Ok(format!("{}\n", json))
42 }
43
44 pub fn deserialize<T: for<'de> Deserialize<'de>>(line: &str) -> Result<T> {
46 let trimmed = line.trim();
47 if trimmed.is_empty() {
48 return Err(Error::Protocol("Empty line".to_string()));
49 }
50 Ok(serde_json::from_str(trimmed)?)
51 }
52
53 pub fn write_sync<W: Write, T: Serialize>(writer: &mut W, message: &T) -> Result<()> {
55 let line = Self::serialize(message)?;
56 debug!("[PROTOCOL] Sending: {}", line.trim());
57 writer.write_all(line.as_bytes())?;
58 writer.flush()?;
59 Ok(())
60 }
61
62 pub fn read_sync<R: BufRead, T: for<'de> Deserialize<'de>>(reader: &mut R) -> Result<T> {
64 let mut line = String::new();
65 let bytes_read = reader.read_line(&mut line)?;
66 if bytes_read == 0 {
67 return Err(Error::ConnectionClosed);
68 }
69 debug!("[PROTOCOL] Received: {}", line.trim());
70 Self::deserialize(&line)
71 }
72
73 pub async fn write_async<W: AsyncWriteExt + Unpin, T: Serialize>(
75 writer: &mut W,
76 message: &T,
77 ) -> Result<()> {
78 let line = Self::serialize(message)?;
79 debug!("[PROTOCOL] Sending async: {}", line.trim());
80 writer.write_all(line.as_bytes()).await?;
81 writer.flush().await?;
82 Ok(())
83 }
84
85 pub async fn read_async<R: AsyncBufReadExt + Unpin, T: for<'de> Deserialize<'de>>(
87 reader: &mut R,
88 ) -> Result<T> {
89 let mut line = String::new();
90 let bytes_read = reader.read_line(&mut line).await?;
91 if bytes_read == 0 {
92 return Err(Error::ConnectionClosed);
93 }
94 debug!("[PROTOCOL] Received async: {}", line.trim());
95 Self::deserialize(&line)
96 }
97}
98
/// Wrapper that lets a single stream carry requests, responses and events.
///
/// Serde encodes this enum as *internally tagged*. The wrapped message's JSON
/// object gains a `"message_class"` field whose value is the snake_case variant
/// name: `"request"`, `"response"` or `"event"`. Because of this representation,
/// each wrapped type must serialize as a JSON object (struct or map).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "message_class", rename_all = "snake_case")]
pub enum MessageEnvelope {
    /// A [`Request`], tagged `"message_class": "request"`.
    Request(Request),
    /// A [`Response`], tagged `"message_class": "response"`.
    Response(Response),
    /// An [`Event`], tagged `"message_class": "event"`.
    Event(Event),
}
107
108pub struct StreamProcessor<R> {
110 reader: BufReader<R>,
111}
112
113impl<R: std::io::Read> StreamProcessor<R> {
114 pub fn new(reader: R) -> Self {
116 Self {
117 reader: BufReader::new(reader),
118 }
119 }
120
121 pub fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
123 Protocol::read_sync(&mut self.reader)
124 }
125
126 pub fn process_all<T, F>(&mut self, mut handler: F) -> Result<()>
128 where
129 T: for<'de> Deserialize<'de>,
130 F: FnMut(T) -> Result<()>,
131 {
132 loop {
133 match self.next_message() {
134 Ok(message) => handler(message)?,
135 Err(Error::ConnectionClosed) => break,
136 Err(e) => return Err(e),
137 }
138 }
139 Ok(())
140 }
141}
142
143pub struct AsyncStreamProcessor<R> {
145 reader: AsyncBufReader<R>,
146}
147
148impl<R: tokio::io::AsyncRead + Unpin> AsyncStreamProcessor<R> {
149 pub fn new(reader: R) -> Self {
151 Self {
152 reader: AsyncBufReader::new(reader),
153 }
154 }
155
156 pub async fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
158 Protocol::read_async(&mut self.reader).await
159 }
160
161 pub async fn process_all<T, F, Fut>(&mut self, mut handler: F) -> Result<()>
163 where
164 T: for<'de> Deserialize<'de>,
165 F: FnMut(T) -> Fut,
166 Fut: std::future::Future<Output = Result<()>>,
167 {
168 loop {
169 match self.next_message().await {
170 Ok(message) => handler(message).await?,
171 Err(Error::ConnectionClosed) => break,
172 Err(e) => return Err(e),
173 }
174 }
175 Ok(())
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::*;

    /// Builds a minimal `Initialize` request carrying the given id.
    fn sample_request(id: &str) -> Request {
        Request {
            message_type: "request".to_string(),
            id: id.to_string(),
            session_id: Some("session-456".to_string()),
            payload: RequestPayload::Initialize(InitializeRequest {
                working_directory: Some("/home/user".to_string()),
                environment: None,
                capabilities: None,
            }),
            metadata: None,
        }
    }

    /// Serializes each id as a request frame into one contiguous buffer.
    fn framed_requests(ids: &[&str]) -> Vec<u8> {
        let mut buffer = Vec::new();
        for id in ids {
            Protocol::write_sync(&mut buffer, &sample_request(id)).unwrap();
        }
        buffer
    }

    #[test]
    fn test_serialize_deserialize() {
        let request = sample_request("test-123");

        let serialized = Protocol::serialize(&request).unwrap();
        assert!(serialized.ends_with('\n'));
        // One frame only: the terminator must be the sole newline.
        assert_eq!(serialized.matches('\n').count(), 1);

        let deserialized: Request = Protocol::deserialize(&serialized).unwrap();
        assert_eq!(deserialized.id, request.id);
        assert_eq!(deserialized.session_id, request.session_id);
    }

    #[test]
    fn test_empty_line_error() {
        let result: Result<Request> = Protocol::deserialize("");
        assert!(result.is_err());

        // Whitespace-only input (e.g. a bare line terminator) is rejected as well.
        let result: Result<Request> = Protocol::deserialize("  \r\n");
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_json_error() {
        let result: Result<Request> = Protocol::deserialize("not valid json");
        assert!(result.is_err());
    }

    #[test]
    fn test_write_then_read_sync_round_trip() {
        let buffer = framed_requests(&["first", "second"]);
        let mut reader: &[u8] = &buffer;

        let first: Request = Protocol::read_sync(&mut reader).unwrap();
        let second: Request = Protocol::read_sync(&mut reader).unwrap();
        assert_eq!(first.id, "first");
        assert_eq!(second.id, "second");

        // Exhausted input is reported as a closed connection, not a parse error.
        let eof: Result<Request> = Protocol::read_sync(&mut reader);
        assert!(matches!(eof, Err(crate::error::Error::ConnectionClosed)));
    }

    #[test]
    fn test_stream_processor_processes_until_eof() {
        let buffer = framed_requests(&["a", "b", "c"]);
        let mut processor = StreamProcessor::new(buffer.as_slice());

        let mut seen = Vec::new();
        processor
            .process_all(|request: Request| {
                seen.push(request.id);
                Ok(())
            })
            .unwrap();

        assert_eq!(seen, ["a", "b", "c"]);
    }

    #[test]
    fn test_stream_processor_propagates_handler_error() {
        let buffer = framed_requests(&["a", "b"]);
        let mut processor = StreamProcessor::new(buffer.as_slice());

        let mut calls = 0;
        let result = processor.process_all(|_: Request| {
            calls += 1;
            Err(crate::error::Error::Protocol("stop".to_string()))
        });

        assert!(result.is_err());
        // Processing must stop at the first handler failure.
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_envelope_is_tagged_with_message_class() {
        let envelope = MessageEnvelope::Request(sample_request("wrapped"));
        let line = Protocol::serialize(&envelope).unwrap();
        assert!(line.contains("\"message_class\":\"request\""));
    }
}