claude_codes/
protocol.rs

1//! Protocol implementation for JSON lines communication
2
3use crate::error::{Error, Result};
4use crate::messages::{Event, Request, Response};
5use serde::{Deserialize, Serialize};
6use std::io::{BufRead, BufReader, Write};
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader as AsyncBufReader};
8
9/// Protocol handler for Claude Code JSON lines communication
10pub struct Protocol;
11
12impl Protocol {
13    /// Serialize a message to JSON lines format
14    pub fn serialize<T: Serialize>(message: &T) -> Result<String> {
15        let json = serde_json::to_string(message)?;
16        Ok(format!("{}\n", json))
17    }
18
19    /// Deserialize a JSON line into a message
20    pub fn deserialize<T: for<'de> Deserialize<'de>>(line: &str) -> Result<T> {
21        let trimmed = line.trim();
22        if trimmed.is_empty() {
23            return Err(Error::Protocol("Empty line".to_string()));
24        }
25        Ok(serde_json::from_str(trimmed)?)
26    }
27
28    /// Write a message to a synchronous writer
29    pub fn write_sync<W: Write, T: Serialize>(writer: &mut W, message: &T) -> Result<()> {
30        let line = Self::serialize(message)?;
31        writer.write_all(line.as_bytes())?;
32        writer.flush()?;
33        Ok(())
34    }
35
36    /// Read a message from a synchronous reader
37    pub fn read_sync<R: BufRead, T: for<'de> Deserialize<'de>>(reader: &mut R) -> Result<T> {
38        let mut line = String::new();
39        let bytes_read = reader.read_line(&mut line)?;
40        if bytes_read == 0 {
41            return Err(Error::ConnectionClosed);
42        }
43        Self::deserialize(&line)
44    }
45
46    /// Write a message to an async writer
47    pub async fn write_async<W: AsyncWriteExt + Unpin, T: Serialize>(
48        writer: &mut W,
49        message: &T,
50    ) -> Result<()> {
51        let line = Self::serialize(message)?;
52        writer.write_all(line.as_bytes()).await?;
53        writer.flush().await?;
54        Ok(())
55    }
56
57    /// Read a message from an async reader
58    pub async fn read_async<R: AsyncBufReadExt + Unpin, T: for<'de> Deserialize<'de>>(
59        reader: &mut R,
60    ) -> Result<T> {
61        let mut line = String::new();
62        let bytes_read = reader.read_line(&mut line).await?;
63        if bytes_read == 0 {
64            return Err(Error::ConnectionClosed);
65        }
66        Self::deserialize(&line)
67    }
68}
69
70/// Message envelope for routing different message types
71#[derive(Debug, Clone, Serialize, Deserialize)]
72#[serde(tag = "message_class", rename_all = "snake_case")]
73pub enum MessageEnvelope {
74    Request(Request),
75    Response(Response),
76    Event(Event),
77}
78
79/// Stream processor for handling continuous message streams
80pub struct StreamProcessor<R> {
81    reader: BufReader<R>,
82}
83
84impl<R: std::io::Read> StreamProcessor<R> {
85    /// Create a new stream processor
86    pub fn new(reader: R) -> Self {
87        Self {
88            reader: BufReader::new(reader),
89        }
90    }
91
92    /// Process the next message from the stream
93    pub fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
94        Protocol::read_sync(&mut self.reader)
95    }
96
97    /// Process all messages in the stream
98    pub fn process_all<T, F>(&mut self, mut handler: F) -> Result<()>
99    where
100        T: for<'de> Deserialize<'de>,
101        F: FnMut(T) -> Result<()>,
102    {
103        loop {
104            match self.next_message() {
105                Ok(message) => handler(message)?,
106                Err(Error::ConnectionClosed) => break,
107                Err(e) => return Err(e),
108            }
109        }
110        Ok(())
111    }
112}
113
114/// Async stream processor for handling continuous message streams
115pub struct AsyncStreamProcessor<R> {
116    reader: AsyncBufReader<R>,
117}
118
119impl<R: tokio::io::AsyncRead + Unpin> AsyncStreamProcessor<R> {
120    /// Create a new async stream processor
121    pub fn new(reader: R) -> Self {
122        Self {
123            reader: AsyncBufReader::new(reader),
124        }
125    }
126
127    /// Process the next message from the stream
128    pub async fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
129        Protocol::read_async(&mut self.reader).await
130    }
131
132    /// Process all messages in the stream
133    pub async fn process_all<T, F, Fut>(&mut self, mut handler: F) -> Result<()>
134    where
135        T: for<'de> Deserialize<'de>,
136        F: FnMut(T) -> Fut,
137        Fut: std::future::Future<Output = Result<()>>,
138    {
139        loop {
140            match self.next_message().await {
141                Ok(message) => handler(message).await?,
142                Err(Error::ConnectionClosed) => break,
143                Err(e) => return Err(e),
144            }
145        }
146        Ok(())
147    }
148}
149
150#[cfg(test)]
151mod tests {
152    use super::*;
153    use crate::messages::*;
154
155    #[test]
156    fn test_serialize_deserialize() {
157        let request = Request {
158            message_type: "request".to_string(),
159            id: "test-123".to_string(),
160            session_id: Some("session-456".to_string()),
161            payload: RequestPayload::Initialize(InitializeRequest {
162                working_directory: Some("/home/user".to_string()),
163                environment: None,
164                capabilities: None,
165            }),
166            metadata: None,
167        };
168
169        let serialized = Protocol::serialize(&request).unwrap();
170        assert!(serialized.ends_with('\n'));
171
172        let deserialized: Request = Protocol::deserialize(&serialized).unwrap();
173        assert_eq!(deserialized.id, request.id);
174    }
175
176    #[test]
177    fn test_empty_line_error() {
178        let result: Result<Request> = Protocol::deserialize("");
179        assert!(result.is_err());
180    }
181
182    #[test]
183    fn test_invalid_json_error() {
184        let result: Result<Request> = Protocol::deserialize("not valid json");
185        assert!(result.is_err());
186    }
187}