1use crate::error::{Error, Result};
4use crate::messages::{Event, Request, Response};
5use serde::{Deserialize, Serialize};
6use std::io::{BufRead, BufReader, Write};
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader as AsyncBufReader};
8
9pub struct Protocol;
11
12impl Protocol {
13 pub fn serialize<T: Serialize>(message: &T) -> Result<String> {
15 let json = serde_json::to_string(message)?;
16 Ok(format!("{}\n", json))
17 }
18
19 pub fn deserialize<T: for<'de> Deserialize<'de>>(line: &str) -> Result<T> {
21 let trimmed = line.trim();
22 if trimmed.is_empty() {
23 return Err(Error::Protocol("Empty line".to_string()));
24 }
25 Ok(serde_json::from_str(trimmed)?)
26 }
27
28 pub fn write_sync<W: Write, T: Serialize>(writer: &mut W, message: &T) -> Result<()> {
30 let line = Self::serialize(message)?;
31 writer.write_all(line.as_bytes())?;
32 writer.flush()?;
33 Ok(())
34 }
35
36 pub fn read_sync<R: BufRead, T: for<'de> Deserialize<'de>>(reader: &mut R) -> Result<T> {
38 let mut line = String::new();
39 let bytes_read = reader.read_line(&mut line)?;
40 if bytes_read == 0 {
41 return Err(Error::ConnectionClosed);
42 }
43 Self::deserialize(&line)
44 }
45
46 pub async fn write_async<W: AsyncWriteExt + Unpin, T: Serialize>(
48 writer: &mut W,
49 message: &T,
50 ) -> Result<()> {
51 let line = Self::serialize(message)?;
52 writer.write_all(line.as_bytes()).await?;
53 writer.flush().await?;
54 Ok(())
55 }
56
57 pub async fn read_async<R: AsyncBufReadExt + Unpin, T: for<'de> Deserialize<'de>>(
59 reader: &mut R,
60 ) -> Result<T> {
61 let mut line = String::new();
62 let bytes_read = reader.read_line(&mut line).await?;
63 if bytes_read == 0 {
64 return Err(Error::ConnectionClosed);
65 }
66 Self::deserialize(&line)
67 }
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72#[serde(tag = "message_class", rename_all = "snake_case")]
73pub enum MessageEnvelope {
74 Request(Request),
75 Response(Response),
76 Event(Event),
77}
78
79pub struct StreamProcessor<R> {
81 reader: BufReader<R>,
82}
83
84impl<R: std::io::Read> StreamProcessor<R> {
85 pub fn new(reader: R) -> Self {
87 Self {
88 reader: BufReader::new(reader),
89 }
90 }
91
92 pub fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
94 Protocol::read_sync(&mut self.reader)
95 }
96
97 pub fn process_all<T, F>(&mut self, mut handler: F) -> Result<()>
99 where
100 T: for<'de> Deserialize<'de>,
101 F: FnMut(T) -> Result<()>,
102 {
103 loop {
104 match self.next_message() {
105 Ok(message) => handler(message)?,
106 Err(Error::ConnectionClosed) => break,
107 Err(e) => return Err(e),
108 }
109 }
110 Ok(())
111 }
112}
113
114pub struct AsyncStreamProcessor<R> {
116 reader: AsyncBufReader<R>,
117}
118
119impl<R: tokio::io::AsyncRead + Unpin> AsyncStreamProcessor<R> {
120 pub fn new(reader: R) -> Self {
122 Self {
123 reader: AsyncBufReader::new(reader),
124 }
125 }
126
127 pub async fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
129 Protocol::read_async(&mut self.reader).await
130 }
131
132 pub async fn process_all<T, F, Fut>(&mut self, mut handler: F) -> Result<()>
134 where
135 T: for<'de> Deserialize<'de>,
136 F: FnMut(T) -> Fut,
137 Fut: std::future::Future<Output = Result<()>>,
138 {
139 loop {
140 match self.next_message().await {
141 Ok(message) => handler(message).await?,
142 Err(Error::ConnectionClosed) => break,
143 Err(e) => return Err(e),
144 }
145 }
146 Ok(())
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::*;

    /// Builds the `Initialize` request used as a fixture by the tests below.
    fn sample_request() -> Request {
        Request {
            message_type: "request".to_string(),
            id: "test-123".to_string(),
            session_id: Some("session-456".to_string()),
            payload: RequestPayload::Initialize(InitializeRequest {
                working_directory: Some("/home/user".to_string()),
                environment: None,
                capabilities: None,
            }),
            metadata: None,
        }
    }

    /// A serialized message is exactly one line and decodes back to the
    /// same request id.
    #[test]
    fn test_serialize_deserialize() {
        let request = sample_request();

        let serialized = Protocol::serialize(&request).unwrap();
        assert!(serialized.ends_with('\n'));
        // Framing relies on the terminator being the only newline.
        assert_eq!(serialized.matches('\n').count(), 1);

        let deserialized: Request = Protocol::deserialize(&serialized).unwrap();
        assert_eq!(deserialized.id, request.id);
    }

    /// Empty and whitespace-only input is rejected with `Error::Protocol`,
    /// not with a JSON parse error.
    #[test]
    fn test_empty_line_error() {
        for &blank in &["", "\n", "  \r\n"] {
            let result: Result<Request> = Protocol::deserialize(blank);
            assert!(
                matches!(result, Err(crate::error::Error::Protocol(_))),
                "blank input {:?} should be a protocol error",
                blank
            );
        }
    }

    /// Text that is not JSON is rejected.
    #[test]
    fn test_invalid_json_error() {
        let result: Result<Request> = Protocol::deserialize("not valid json");
        assert!(result.is_err());
    }

    /// A message written with `write_sync` is read back by `StreamProcessor`,
    /// and the following read reports the end of the stream.
    #[test]
    fn test_sync_round_trip_then_eof() {
        let request = sample_request();
        let mut wire = Vec::new();
        Protocol::write_sync(&mut wire, &request).unwrap();

        let mut processor = StreamProcessor::new(wire.as_slice());
        let decoded: Request = processor.next_message().unwrap();
        assert_eq!(decoded.id, request.id);

        let at_eof: Result<Request> = processor.next_message();
        assert!(matches!(
            at_eof,
            Err(crate::error::Error::ConnectionClosed)
        ));
    }

    /// `process_all` hands every message to the handler and treats end of
    /// input as success.
    #[test]
    fn test_process_all_stops_at_eof() {
        let request = sample_request();
        let mut wire = Vec::new();
        Protocol::write_sync(&mut wire, &request).unwrap();
        Protocol::write_sync(&mut wire, &request).unwrap();

        let mut seen_ids = Vec::new();
        StreamProcessor::new(wire.as_slice())
            .process_all(|message: Request| {
                seen_ids.push(message.id);
                Ok(())
            })
            .unwrap();

        assert_eq!(
            seen_ids,
            vec!["test-123".to_string(), "test-123".to_string()]
        );
    }
}