claude_codes/
protocol.rs

1//! Protocol implementation for JSON lines communication
2
3use crate::error::{Error, Result};
4use crate::messages::{Event, Request, Response};
5use serde::{Deserialize, Serialize};
6use std::io::{BufRead, BufReader, Write};
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader as AsyncBufReader};
8use tracing::debug;
9
10/// Protocol handler for Claude Code JSON lines communication
11pub struct Protocol;
12
13impl Protocol {
14    /// Serialize a message to JSON lines format
15    pub fn serialize<T: Serialize>(message: &T) -> Result<String> {
16        let json = serde_json::to_string(message)?;
17        Ok(format!("{}\n", json))
18    }
19
20    /// Deserialize a JSON line into a message
21    pub fn deserialize<T: for<'de> Deserialize<'de>>(line: &str) -> Result<T> {
22        let trimmed = line.trim();
23        if trimmed.is_empty() {
24            return Err(Error::Protocol("Empty line".to_string()));
25        }
26        Ok(serde_json::from_str(trimmed)?)
27    }
28
29    /// Write a message to a synchronous writer
30    pub fn write_sync<W: Write, T: Serialize>(writer: &mut W, message: &T) -> Result<()> {
31        let line = Self::serialize(message)?;
32        debug!("[PROTOCOL] Sending: {}", line.trim());
33        writer.write_all(line.as_bytes())?;
34        writer.flush()?;
35        Ok(())
36    }
37
38    /// Read a message from a synchronous reader
39    pub fn read_sync<R: BufRead, T: for<'de> Deserialize<'de>>(reader: &mut R) -> Result<T> {
40        let mut line = String::new();
41        let bytes_read = reader.read_line(&mut line)?;
42        if bytes_read == 0 {
43            return Err(Error::ConnectionClosed);
44        }
45        debug!("[PROTOCOL] Received: {}", line.trim());
46        Self::deserialize(&line)
47    }
48
49    /// Write a message to an async writer
50    pub async fn write_async<W: AsyncWriteExt + Unpin, T: Serialize>(
51        writer: &mut W,
52        message: &T,
53    ) -> Result<()> {
54        let line = Self::serialize(message)?;
55        debug!("[PROTOCOL] Sending async: {}", line.trim());
56        writer.write_all(line.as_bytes()).await?;
57        writer.flush().await?;
58        Ok(())
59    }
60
61    /// Read a message from an async reader
62    pub async fn read_async<R: AsyncBufReadExt + Unpin, T: for<'de> Deserialize<'de>>(
63        reader: &mut R,
64    ) -> Result<T> {
65        let mut line = String::new();
66        let bytes_read = reader.read_line(&mut line).await?;
67        if bytes_read == 0 {
68            return Err(Error::ConnectionClosed);
69        }
70        debug!("[PROTOCOL] Received async: {}", line.trim());
71        Self::deserialize(&line)
72    }
73}
74
75/// Message envelope for routing different message types
76#[derive(Debug, Clone, Serialize, Deserialize)]
77#[serde(tag = "message_class", rename_all = "snake_case")]
78pub enum MessageEnvelope {
79    Request(Request),
80    Response(Response),
81    Event(Event),
82}
83
84/// Stream processor for handling continuous message streams
85pub struct StreamProcessor<R> {
86    reader: BufReader<R>,
87}
88
89impl<R: std::io::Read> StreamProcessor<R> {
90    /// Create a new stream processor
91    pub fn new(reader: R) -> Self {
92        Self {
93            reader: BufReader::new(reader),
94        }
95    }
96
97    /// Process the next message from the stream
98    pub fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
99        Protocol::read_sync(&mut self.reader)
100    }
101
102    /// Process all messages in the stream
103    pub fn process_all<T, F>(&mut self, mut handler: F) -> Result<()>
104    where
105        T: for<'de> Deserialize<'de>,
106        F: FnMut(T) -> Result<()>,
107    {
108        loop {
109            match self.next_message() {
110                Ok(message) => handler(message)?,
111                Err(Error::ConnectionClosed) => break,
112                Err(e) => return Err(e),
113            }
114        }
115        Ok(())
116    }
117}
118
119/// Async stream processor for handling continuous message streams
120pub struct AsyncStreamProcessor<R> {
121    reader: AsyncBufReader<R>,
122}
123
124impl<R: tokio::io::AsyncRead + Unpin> AsyncStreamProcessor<R> {
125    /// Create a new async stream processor
126    pub fn new(reader: R) -> Self {
127        Self {
128            reader: AsyncBufReader::new(reader),
129        }
130    }
131
132    /// Process the next message from the stream
133    pub async fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
134        Protocol::read_async(&mut self.reader).await
135    }
136
137    /// Process all messages in the stream
138    pub async fn process_all<T, F, Fut>(&mut self, mut handler: F) -> Result<()>
139    where
140        T: for<'de> Deserialize<'de>,
141        F: FnMut(T) -> Fut,
142        Fut: std::future::Future<Output = Result<()>>,
143    {
144        loop {
145            match self.next_message().await {
146                Ok(message) => handler(message).await?,
147                Err(Error::ConnectionClosed) => break,
148                Err(e) => return Err(e),
149            }
150        }
151        Ok(())
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::*;

    /// A request survives a serialize/deserialize round trip and is framed
    /// with a trailing newline.
    #[test]
    fn test_serialize_deserialize() {
        let init = InitializeRequest {
            working_directory: Some("/home/user".to_string()),
            environment: None,
            capabilities: None,
        };
        let original = Request {
            message_type: "request".to_string(),
            id: "test-123".to_string(),
            session_id: Some("session-456".to_string()),
            payload: RequestPayload::Initialize(init),
            metadata: None,
        };

        let line = Protocol::serialize(&original).unwrap();
        assert!(line.ends_with('\n'));

        let round_tripped = Protocol::deserialize::<Request>(&line).unwrap();
        assert_eq!(round_tripped.id, original.id);
    }

    /// An empty line is rejected rather than parsed.
    #[test]
    fn test_empty_line_error() {
        let outcome = Protocol::deserialize::<Request>("");
        assert!(outcome.is_err());
    }

    /// Text that is not JSON is rejected.
    #[test]
    fn test_invalid_json_error() {
        let outcome = Protocol::deserialize::<Request>("not valid json");
        assert!(outcome.is_err());
    }
}