1use crate::error::{Error, Result};
4use crate::messages::{Event, Request, Response};
5use serde::{Deserialize, Serialize};
6use std::io::{BufRead, BufReader, Write};
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader as AsyncBufReader};
8use tracing::debug;
9
10pub struct Protocol;
12
13impl Protocol {
14 pub fn serialize<T: Serialize>(message: &T) -> Result<String> {
16 let json = serde_json::to_string(message)?;
17 Ok(format!("{}\n", json))
18 }
19
20 pub fn deserialize<T: for<'de> Deserialize<'de>>(line: &str) -> Result<T> {
22 let trimmed = line.trim();
23 if trimmed.is_empty() {
24 return Err(Error::Protocol("Empty line".to_string()));
25 }
26 Ok(serde_json::from_str(trimmed)?)
27 }
28
29 pub fn write_sync<W: Write, T: Serialize>(writer: &mut W, message: &T) -> Result<()> {
31 let line = Self::serialize(message)?;
32 debug!("[PROTOCOL] Sending: {}", line.trim());
33 writer.write_all(line.as_bytes())?;
34 writer.flush()?;
35 Ok(())
36 }
37
38 pub fn read_sync<R: BufRead, T: for<'de> Deserialize<'de>>(reader: &mut R) -> Result<T> {
40 let mut line = String::new();
41 let bytes_read = reader.read_line(&mut line)?;
42 if bytes_read == 0 {
43 return Err(Error::ConnectionClosed);
44 }
45 debug!("[PROTOCOL] Received: {}", line.trim());
46 Self::deserialize(&line)
47 }
48
49 pub async fn write_async<W: AsyncWriteExt + Unpin, T: Serialize>(
51 writer: &mut W,
52 message: &T,
53 ) -> Result<()> {
54 let line = Self::serialize(message)?;
55 debug!("[PROTOCOL] Sending async: {}", line.trim());
56 writer.write_all(line.as_bytes()).await?;
57 writer.flush().await?;
58 Ok(())
59 }
60
61 pub async fn read_async<R: AsyncBufReadExt + Unpin, T: for<'de> Deserialize<'de>>(
63 reader: &mut R,
64 ) -> Result<T> {
65 let mut line = String::new();
66 let bytes_read = reader.read_line(&mut line).await?;
67 if bytes_read == 0 {
68 return Err(Error::ConnectionClosed);
69 }
70 debug!("[PROTOCOL] Received async: {}", line.trim());
71 Self::deserialize(&line)
72 }
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77#[serde(tag = "message_class", rename_all = "snake_case")]
78pub enum MessageEnvelope {
79 Request(Request),
80 Response(Response),
81 Event(Event),
82}
83
84pub struct StreamProcessor<R> {
86 reader: BufReader<R>,
87}
88
89impl<R: std::io::Read> StreamProcessor<R> {
90 pub fn new(reader: R) -> Self {
92 Self {
93 reader: BufReader::new(reader),
94 }
95 }
96
97 pub fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
99 Protocol::read_sync(&mut self.reader)
100 }
101
102 pub fn process_all<T, F>(&mut self, mut handler: F) -> Result<()>
104 where
105 T: for<'de> Deserialize<'de>,
106 F: FnMut(T) -> Result<()>,
107 {
108 loop {
109 match self.next_message() {
110 Ok(message) => handler(message)?,
111 Err(Error::ConnectionClosed) => break,
112 Err(e) => return Err(e),
113 }
114 }
115 Ok(())
116 }
117}
118
119pub struct AsyncStreamProcessor<R> {
121 reader: AsyncBufReader<R>,
122}
123
124impl<R: tokio::io::AsyncRead + Unpin> AsyncStreamProcessor<R> {
125 pub fn new(reader: R) -> Self {
127 Self {
128 reader: AsyncBufReader::new(reader),
129 }
130 }
131
132 pub async fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
134 Protocol::read_async(&mut self.reader).await
135 }
136
137 pub async fn process_all<T, F, Fut>(&mut self, mut handler: F) -> Result<()>
139 where
140 T: for<'de> Deserialize<'de>,
141 F: FnMut(T) -> Fut,
142 Fut: std::future::Future<Output = Result<()>>,
143 {
144 loop {
145 match self.next_message().await {
146 Ok(message) => handler(message).await?,
147 Err(Error::ConnectionClosed) => break,
148 Err(e) => return Err(e),
149 }
150 }
151 Ok(())
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;
    use crate::messages::*;
    use std::io::Cursor;

    /// Builds a minimal `Initialize` request carrying the given id.
    fn sample_request(id: &str) -> Request {
        Request {
            message_type: "request".to_string(),
            id: id.to_string(),
            session_id: Some("session-456".to_string()),
            payload: RequestPayload::Initialize(InitializeRequest {
                working_directory: Some("/home/user".to_string()),
                environment: None,
                capabilities: None,
            }),
            metadata: None,
        }
    }

    #[test]
    fn test_serialize_deserialize() {
        let request = sample_request("test-123");

        let serialized = Protocol::serialize(&request).unwrap();
        assert!(serialized.ends_with('\n'));
        // Line framing requires that the JSON body itself contains no raw newline.
        assert_eq!(serialized.matches('\n').count(), 1);

        let deserialized: Request = Protocol::deserialize(&serialized).unwrap();
        assert_eq!(deserialized.id, request.id);
        assert_eq!(deserialized.session_id, request.session_id);
    }

    #[test]
    fn test_empty_line_error() {
        let result: Result<Request> = Protocol::deserialize("");
        assert!(matches!(result, Err(Error::Protocol(_))));
    }

    #[test]
    fn test_whitespace_only_line_error() {
        let result: Result<Request> = Protocol::deserialize("  \r\n");
        assert!(matches!(result, Err(Error::Protocol(_))));
    }

    #[test]
    fn test_invalid_json_error() {
        let result: Result<Request> = Protocol::deserialize("not valid json");
        assert!(result.is_err());
    }

    #[test]
    fn test_read_sync_eof_is_connection_closed() {
        let mut reader = Cursor::new(Vec::<u8>::new());
        let result: Result<Request> = Protocol::read_sync(&mut reader);
        assert!(matches!(result, Err(Error::ConnectionClosed)));
    }

    #[test]
    fn test_write_then_read_sync_roundtrip() {
        let mut wire = Vec::<u8>::new();
        Protocol::write_sync(&mut wire, &sample_request("a")).unwrap();

        let mut reader = Cursor::new(wire);
        let decoded: Request = Protocol::read_sync(&mut reader).unwrap();
        assert_eq!(decoded.id, "a");
    }

    #[test]
    fn test_process_all_stops_cleanly_at_eof() {
        let mut wire = Vec::<u8>::new();
        for id in ["first", "second"] {
            Protocol::write_sync(&mut wire, &sample_request(id)).unwrap();
        }

        let mut processor = StreamProcessor::new(Cursor::new(wire));
        let mut seen: Vec<String> = Vec::new();
        processor
            .process_all(|request: Request| {
                seen.push(request.id);
                Ok(())
            })
            .unwrap();
        assert_eq!(seen, ["first", "second"]);
    }
}