Skip to main content

copilot_sdk/
transport.rs

1// Copyright (c) 2026 Elias Bachaalany
2// SPDX-License-Identifier: MIT
3
4//! Transport layer for the Copilot SDK.
5//!
6//! Provides async byte I/O and Content-Length message framing (LSP-style).
7
8use crate::error::{CopilotError, Result};
9use std::future::Future;
10use std::pin::Pin;
11use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
12
13// =============================================================================
14// Transport Trait
15// =============================================================================
16
/// Async transport for raw byte I/O.
///
/// Implementations provide the underlying byte stream (stdio pipes, TCP sockets, etc.).
/// The transport is responsible for reading/writing raw bytes; framing is handled
/// separately by `MessageFramer`.
///
/// Every method returns a boxed `Send` future instead of being an `async fn`; with
/// these signatures the trait can also be used as a `dyn Transport` object.
pub trait Transport: Send + Sync {
    /// Read up to `buf.len()` bytes into buffer.
    /// Returns the number of bytes read (0 indicates EOF).
    ///
    /// `MessageFramer` treats `Ok(0)` as end of stream. The implementations in this
    /// module return `CopilotError::ConnectionClosed` once `close` has run.
    fn read<'a>(
        &'a mut self,
        buf: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = Result<usize>> + Send + 'a>>;

    /// Write all bytes to the transport.
    ///
    /// `MessageFramer::write_message` sends each complete frame (header + body)
    /// through a single call to this method.
    fn write<'a>(
        &'a mut self,
        data: &'a [u8],
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;

    /// Close the transport.
    ///
    /// NOTE(review): the implementations in this module only clear an "open" flag
    /// when the returned future is polled; they do not release OS handles.
    fn close(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>>;

    /// Check if transport is open.
    fn is_open(&self) -> bool;
}
42
43// =============================================================================
44// Stdio Transport
45// =============================================================================
46
/// Transport using stdin/stdout of a child process.
///
/// Outgoing bytes are written to the child's stdin; incoming bytes are read from the
/// child's stdout through a tokio `BufReader`.
pub struct StdioTransport {
    /// Pipe to the child's standard input (outgoing bytes).
    stdin: tokio::process::ChildStdin,
    /// Buffered pipe from the child's standard output (incoming bytes).
    stdout: BufReader<tokio::process::ChildStdout>,
    /// Set to `false` by `close`; afterwards `read`/`write` fail with
    /// `CopilotError::ConnectionClosed`.
    open: bool,
}
53
54impl StdioTransport {
55    /// Create a new stdio transport from child process handles.
56    pub fn new(stdin: tokio::process::ChildStdin, stdout: tokio::process::ChildStdout) -> Self {
57        Self {
58            stdin,
59            stdout: BufReader::new(stdout),
60            open: true,
61        }
62    }
63
64    /// Split into separate read and write handles.
65    pub fn split(self) -> (tokio::process::ChildStdin, tokio::process::ChildStdout) {
66        (self.stdin, self.stdout.into_inner())
67    }
68}
69
70impl Transport for StdioTransport {
71    fn read<'a>(
72        &'a mut self,
73        buf: &'a mut [u8],
74    ) -> Pin<Box<dyn Future<Output = Result<usize>> + Send + 'a>> {
75        Box::pin(async move {
76            if !self.open {
77                return Err(CopilotError::ConnectionClosed);
78            }
79            self.stdout.read(buf).await.map_err(CopilotError::Transport)
80        })
81    }
82
83    fn write<'a>(
84        &'a mut self,
85        data: &'a [u8],
86    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
87        Box::pin(async move {
88            if !self.open {
89                return Err(CopilotError::ConnectionClosed);
90            }
91            self.stdin
92                .write_all(data)
93                .await
94                .map_err(CopilotError::Transport)?;
95            // Flush may fail on Windows with pipes, but data is still written
96            let _ = self.stdin.flush().await;
97            Ok(())
98        })
99    }
100
101    fn close(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
102        Box::pin(async move {
103            self.open = false;
104            Ok(())
105        })
106    }
107
108    fn is_open(&self) -> bool {
109        self.open
110    }
111}
112
113// =============================================================================
114// Content-Length Message Framer (LSP-style)
115// =============================================================================
116
/// Handles Content-Length header framing for JSON-RPC messages.
///
/// Message format:
/// ```text
/// Content-Length: <length>\r\n
/// \r\n
/// <json-rpc-message>
/// ```
///
/// This is the standard LSP (Language Server Protocol) framing used by
/// StreamJsonRpc's HeaderDelimitedMessageHandler.
///
/// `<length>` is the body size in bytes (UTF-8). When reading:
/// - header names are matched case-insensitively;
/// - headers other than `Content-Length` are ignored;
/// - bare `\n` line endings are accepted as well as `\r\n`.
pub struct MessageFramer<T: Transport> {
    /// Underlying byte transport used for both reads and writes.
    transport: T,
    /// Staging buffer for header bytes read from `transport`. Bytes read past the
    /// blank line that ends the headers stay here and are consumed first by the
    /// body read.
    buffer: Vec<u8>,
    /// Index of the next unconsumed byte in `buffer`.
    buffer_pos: usize,
    /// Number of valid bytes in `buffer`; `buffer_pos..buffer_len` is unconsumed.
    buffer_len: usize,
}
134
/// Message writer for framing outgoing messages.
///
/// Counterpart to `MessageFramer::write_message` that writes to any tokio
/// `AsyncWrite` instead of a `Transport`.
pub struct MessageWriter<W> {
    /// Destination for framed messages.
    writer: W,
}
139
140impl<W> MessageWriter<W>
141where
142    W: AsyncWrite + Unpin + Send,
143{
144    /// Create a new message writer.
145    pub fn new(writer: W) -> Self {
146        Self { writer }
147    }
148
149    /// Write a message with Content-Length framing.
150    pub async fn write_message(&mut self, message: &str) -> Result<()> {
151        let frame = format!("Content-Length: {}\r\n\r\n{}", message.len(), message);
152        self.writer
153            .write_all(frame.as_bytes())
154            .await
155            .map_err(CopilotError::Transport)?;
156        // Flush may fail on Windows with pipes, but data is still written.
157        let _ = self.writer.flush().await;
158        Ok(())
159    }
160}
161
/// Message reader for parsing incoming framed messages.
///
/// Parses the same Content-Length framing as `MessageFramer::read_message`, but reads
/// from any tokio `AsyncRead` instead of a `Transport`.
pub struct MessageReader<R> {
    /// Source stream, wrapped in a tokio `BufReader`.
    reader: BufReader<R>,
    /// Staging buffer for header bytes. Bytes read past the blank line that ends
    /// the headers stay here and are consumed first by the body read.
    buffer: Vec<u8>,
    /// Index of the next unconsumed byte in `buffer`.
    buffer_pos: usize,
    /// Number of valid bytes in `buffer`; `buffer_pos..buffer_len` is unconsumed.
    buffer_len: usize,
}
169
170impl<R> MessageReader<R>
171where
172    R: AsyncRead + Unpin + Send,
173{
174    /// Create a new message reader.
175    pub fn new(reader: R) -> Self {
176        Self {
177            reader: BufReader::new(reader),
178            buffer: vec![0u8; 4096],
179            buffer_pos: 0,
180            buffer_len: 0,
181        }
182    }
183
184    /// Read a complete framed message.
185    pub async fn read_message(&mut self) -> Result<String> {
186        // Read headers until empty line
187        let mut content_length: Option<usize> = None;
188
189        loop {
190            let line = self.read_line().await?;
191
192            // Empty line signals end of headers
193            if line.is_empty() {
194                break;
195            }
196
197            // Parse Content-Length header (case-insensitive)
198            let lower_line = line.to_lowercase();
199            if let Some(value) = lower_line.strip_prefix("content-length:") {
200                let value_str = value.trim();
201                content_length = Some(value_str.parse().map_err(|_| {
202                    CopilotError::Protocol(format!("Invalid Content-Length value: {}", value_str))
203                })?);
204            }
205        }
206
207        let content_length = content_length
208            .ok_or_else(|| CopilotError::Protocol("Missing Content-Length header".into()))?;
209
210        // Read the message body
211        let mut message = vec![0u8; content_length];
212        self.read_exact(&mut message).await?;
213
214        String::from_utf8(message)
215            .map_err(|e| CopilotError::Protocol(format!("Invalid UTF-8 in message: {}", e)))
216    }
217
218    /// Read exactly n bytes.
219    async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
220        let mut total_read = 0;
221
222        // First, use any buffered data
223        while total_read < buf.len() && self.buffer_pos < self.buffer_len {
224            buf[total_read] = self.buffer[self.buffer_pos];
225            total_read += 1;
226            self.buffer_pos += 1;
227        }
228
229        // Read remaining directly from reader
230        while total_read < buf.len() {
231            let bytes_read = self
232                .reader
233                .read(&mut buf[total_read..])
234                .await
235                .map_err(CopilotError::Transport)?;
236            if bytes_read == 0 {
237                return Err(CopilotError::ConnectionClosed);
238            }
239            total_read += bytes_read;
240        }
241
242        Ok(())
243    }
244
245    /// Read a single line (up to \r\n or \n).
246    async fn read_line(&mut self) -> Result<String> {
247        let mut line = String::new();
248
249        loop {
250            // Refill buffer if empty
251            if self.buffer_pos >= self.buffer_len {
252                self.fill_buffer(1).await?;
253                if self.buffer_len == 0 {
254                    return Err(CopilotError::ConnectionClosed);
255                }
256            }
257
258            let c = self.buffer[self.buffer_pos] as char;
259            self.buffer_pos += 1;
260
261            if c == '\n' {
262                // Remove trailing \r if present
263                if line.ends_with('\r') {
264                    line.pop();
265                }
266                return Ok(line);
267            }
268
269            line.push(c);
270        }
271    }
272
273    /// Fill buffer with at least min_bytes.
274    async fn fill_buffer(&mut self, min_bytes: usize) -> Result<()> {
275        // Compact buffer if needed
276        if self.buffer_pos > 0 {
277            if self.buffer_pos < self.buffer_len {
278                self.buffer.copy_within(self.buffer_pos..self.buffer_len, 0);
279                self.buffer_len -= self.buffer_pos;
280            } else {
281                self.buffer_len = 0;
282            }
283            self.buffer_pos = 0;
284        }
285
286        // Read more data
287        while self.buffer_len < min_bytes {
288            let bytes_read = self
289                .reader
290                .read(&mut self.buffer[self.buffer_len..])
291                .await?;
292
293            if bytes_read == 0 {
294                // EOF - return what we have
295                return Ok(());
296            }
297
298            self.buffer_len += bytes_read;
299        }
300
301        Ok(())
302    }
303}
304
305impl<T: Transport> MessageFramer<T> {
306    /// Create a new message framer wrapping a transport.
307    pub fn new(transport: T) -> Self {
308        Self {
309            transport,
310            buffer: vec![0u8; 4096],
311            buffer_pos: 0,
312            buffer_len: 0,
313        }
314    }
315
316    /// Read a complete framed message.
317    ///
318    /// Returns the message content (without headers).
319    pub async fn read_message(&mut self) -> Result<String> {
320        // Read headers until empty line
321        let mut content_length: Option<usize> = None;
322
323        loop {
324            let line = self.read_line().await?;
325
326            // Empty line signals end of headers
327            if line.is_empty() {
328                break;
329            }
330
331            // Parse Content-Length header (case-insensitive)
332            let lower_line = line.to_lowercase();
333            if let Some(value) = lower_line.strip_prefix("content-length:") {
334                let value_str = value.trim();
335                content_length = Some(value_str.parse().map_err(|_| {
336                    CopilotError::Protocol(format!("Invalid Content-Length value: {}", value_str))
337                })?);
338            }
339            // Ignore other headers (e.g., Content-Type)
340        }
341
342        let content_length = content_length
343            .ok_or_else(|| CopilotError::Protocol("Missing Content-Length header".into()))?;
344
345        // Read the message body
346        let mut message = vec![0u8; content_length];
347        self.read_exact(&mut message).await?;
348
349        String::from_utf8(message)
350            .map_err(|e| CopilotError::Protocol(format!("Invalid UTF-8 in message: {}", e)))
351    }
352
353    /// Write a message with Content-Length framing.
354    pub async fn write_message(&mut self, message: &str) -> Result<()> {
355        let frame = format!("Content-Length: {}\r\n\r\n{}", message.len(), message);
356        self.transport.write(frame.as_bytes()).await
357    }
358
359    /// Get a reference to the underlying transport.
360    pub fn transport(&self) -> &T {
361        &self.transport
362    }
363
364    /// Get a mutable reference to the underlying transport.
365    pub fn transport_mut(&mut self) -> &mut T {
366        &mut self.transport
367    }
368
369    /// Consume the framer and return the transport.
370    pub fn into_transport(self) -> T {
371        self.transport
372    }
373
374    // =========================================================================
375    // Private helpers
376    // =========================================================================
377
378    /// Read exactly n bytes.
379    async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
380        let mut total_read = 0;
381
382        // First, use any buffered data
383        while total_read < buf.len() && self.buffer_pos < self.buffer_len {
384            buf[total_read] = self.buffer[self.buffer_pos];
385            total_read += 1;
386            self.buffer_pos += 1;
387        }
388
389        // Read remaining directly from transport
390        while total_read < buf.len() {
391            let bytes_read = self.transport.read(&mut buf[total_read..]).await?;
392            if bytes_read == 0 {
393                return Err(CopilotError::ConnectionClosed);
394            }
395            total_read += bytes_read;
396        }
397
398        Ok(())
399    }
400
401    /// Read a single line (up to \r\n or \n).
402    async fn read_line(&mut self) -> Result<String> {
403        let mut line = String::new();
404
405        loop {
406            // Refill buffer if empty
407            if self.buffer_pos >= self.buffer_len {
408                self.fill_buffer(1).await?;
409                if self.buffer_len == 0 {
410                    return Err(CopilotError::ConnectionClosed);
411                }
412            }
413
414            let c = self.buffer[self.buffer_pos] as char;
415            self.buffer_pos += 1;
416
417            if c == '\n' {
418                // Remove trailing \r if present
419                if line.ends_with('\r') {
420                    line.pop();
421                }
422                return Ok(line);
423            }
424
425            line.push(c);
426        }
427    }
428
429    /// Fill buffer with at least min_bytes.
430    async fn fill_buffer(&mut self, min_bytes: usize) -> Result<()> {
431        // Compact buffer if needed
432        if self.buffer_pos > 0 {
433            if self.buffer_pos < self.buffer_len {
434                self.buffer.copy_within(self.buffer_pos..self.buffer_len, 0);
435                self.buffer_len -= self.buffer_pos;
436            } else {
437                self.buffer_len = 0;
438            }
439            self.buffer_pos = 0;
440        }
441
442        // Read more data
443        while self.buffer_len < min_bytes {
444            let bytes_read = self
445                .transport
446                .read(&mut self.buffer[self.buffer_len..])
447                .await?;
448
449            if bytes_read == 0 {
450                // EOF - return what we have
451                return Ok(());
452            }
453
454            self.buffer_len += bytes_read;
455        }
456
457        Ok(())
458    }
459}
460
461// =============================================================================
462// In-Memory Transport (for testing)
463// =============================================================================
464
/// In-memory transport for testing.
///
/// Reads are served from a fixed byte vector; writes are appended to a separate
/// vector that can be inspected with `written_data`.
#[cfg(test)]
pub struct MemoryTransport {
    /// Bytes handed out by `read`, in order.
    read_data: Vec<u8>,
    /// Offset of the next byte of `read_data` to hand out.
    read_pos: usize,
    /// Everything passed to `write` so far.
    write_data: Vec<u8>,
    /// Cleared by `close`.
    open: bool,
}

#[cfg(test)]
impl MemoryTransport {
    /// Create an open transport whose reads yield `read_data` and whose write log
    /// starts empty.
    pub fn new(read_data: Vec<u8>) -> Self {
        let write_data = Vec::new();
        MemoryTransport {
            read_data,
            read_pos: 0,
            write_data,
            open: true,
        }
    }

    /// Return everything written to the transport so far.
    pub fn written_data(&self) -> &[u8] {
        self.write_data.as_slice()
    }
}
491
#[cfg(test)]
impl Transport for MemoryTransport {
    /// Copy the next unread bytes of `read_data` into `buf`.
    ///
    /// Returns `Ok(0)` once the data is exhausted, and
    /// `CopilotError::ConnectionClosed` after `close`.
    fn read<'a>(
        &'a mut self,
        buf: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = Result<usize>> + Send + 'a>> {
        Box::pin(async move {
            if !self.open {
                return Err(CopilotError::ConnectionClosed);
            }
            let unread = &self.read_data[self.read_pos..];
            let count = buf.len().min(unread.len());
            buf[..count].copy_from_slice(&unread[..count]);
            self.read_pos += count;
            Ok(count)
        })
    }

    /// Append `data` to the write log, or fail with
    /// `CopilotError::ConnectionClosed` after `close`.
    fn write<'a>(
        &'a mut self,
        data: &'a [u8],
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
        Box::pin(async move {
            if self.open {
                self.write_data.extend_from_slice(data);
                Ok(())
            } else {
                Err(CopilotError::ConnectionClosed)
            }
        })
    }

    /// Mark the transport closed once the returned future is polled.
    fn close(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
        Box::pin(async move {
            self.open = false;
            Ok(())
        })
    }

    /// Report whether `close` has not yet taken effect.
    fn is_open(&self) -> bool {
        self.open
    }
}
534
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_read_message() {
        let data = b"Content-Length: 13\r\n\r\n{\"test\":true}";
        let transport = MemoryTransport::new(data.to_vec());
        let mut framer = MessageFramer::new(transport);

        let message = framer.read_message().await.unwrap();
        assert_eq!(message, "{\"test\":true}");
    }

    #[tokio::test]
    async fn test_read_message_lf_only() {
        // Some implementations use LF only
        let data = b"Content-Length: 13\n\n{\"test\":true}";
        let transport = MemoryTransport::new(data.to_vec());
        let mut framer = MessageFramer::new(transport);

        let message = framer.read_message().await.unwrap();
        assert_eq!(message, "{\"test\":true}");
    }

    #[tokio::test]
    async fn test_read_message_with_extra_headers() {
        let data = b"Content-Type: application/json\r\nContent-Length: 13\r\n\r\n{\"test\":true}";
        let transport = MemoryTransport::new(data.to_vec());
        let mut framer = MessageFramer::new(transport);

        let message = framer.read_message().await.unwrap();
        assert_eq!(message, "{\"test\":true}");
    }

    #[tokio::test]
    async fn test_write_message() {
        let transport = MemoryTransport::new(Vec::new());
        let mut framer = MessageFramer::new(transport);

        framer.write_message("{\"test\":true}").await.unwrap();

        let written = framer.transport().written_data();
        assert_eq!(written, b"Content-Length: 13\r\n\r\n{\"test\":true}");
    }

    #[tokio::test]
    async fn test_read_multiple_messages() {
        let data =
            b"Content-Length: 13\r\n\r\n{\"test\":true}Content-Length: 14\r\n\r\n{\"test\":false}";
        let transport = MemoryTransport::new(data.to_vec());
        let mut framer = MessageFramer::new(transport);

        let msg1 = framer.read_message().await.unwrap();
        assert_eq!(msg1, "{\"test\":true}");

        let msg2 = framer.read_message().await.unwrap();
        assert_eq!(msg2, "{\"test\":false}");
    }

    #[tokio::test]
    async fn test_missing_content_length() {
        let data = b"Content-Type: application/json\r\n\r\n{\"test\":true}";
        let transport = MemoryTransport::new(data.to_vec());
        let mut framer = MessageFramer::new(transport);

        let result = framer.read_message().await;
        assert!(result.is_err());
        if let Err(CopilotError::Protocol(msg)) = result {
            assert!(msg.contains("Missing Content-Length"));
        } else {
            panic!("Expected Protocol error");
        }
    }

    #[tokio::test]
    async fn test_case_insensitive_header() {
        let data = b"content-length: 13\r\n\r\n{\"test\":true}";
        let transport = MemoryTransport::new(data.to_vec());
        let mut framer = MessageFramer::new(transport);

        let message = framer.read_message().await.unwrap();
        assert_eq!(message, "{\"test\":true}");
    }

    #[tokio::test]
    async fn test_transport_closed() {
        let mut transport = MemoryTransport::new(Vec::new());
        transport.close().await.unwrap();

        let mut buf = [0u8; 10];
        let result = transport.read(&mut buf).await;
        assert!(matches!(result, Err(CopilotError::ConnectionClosed)));
    }

    #[tokio::test]
    async fn test_invalid_content_length() {
        let data = b"Content-Length: abc\r\n\r\n{}";
        let transport = MemoryTransport::new(data.to_vec());
        let mut framer = MessageFramer::new(transport);

        match framer.read_message().await {
            Err(CopilotError::Protocol(msg)) => {
                assert!(msg.contains("Invalid Content-Length value: abc"));
            }
            _ => panic!("Expected Protocol error"),
        }
    }

    #[tokio::test]
    async fn test_zero_length_message() {
        let data = b"Content-Length: 0\r\n\r\n";
        let transport = MemoryTransport::new(data.to_vec());
        let mut framer = MessageFramer::new(transport);

        let message = framer.read_message().await.unwrap();
        assert_eq!(message, "");
    }

    #[tokio::test]
    async fn test_eof_before_headers() {
        let transport = MemoryTransport::new(Vec::new());
        let mut framer = MessageFramer::new(transport);

        let result = framer.read_message().await;
        assert!(matches!(result, Err(CopilotError::ConnectionClosed)));
    }

    #[tokio::test]
    async fn test_eof_in_body() {
        // Header promises 20 bytes but only 13 follow.
        let data = b"Content-Length: 20\r\n\r\n{\"test\":true}";
        let transport = MemoryTransport::new(data.to_vec());
        let mut framer = MessageFramer::new(transport);

        let result = framer.read_message().await;
        assert!(matches!(result, Err(CopilotError::ConnectionClosed)));
    }

    #[tokio::test]
    async fn test_invalid_utf8_body() {
        let data = b"Content-Length: 2\r\n\r\n\xff\xfe";
        let transport = MemoryTransport::new(data.to_vec());
        let mut framer = MessageFramer::new(transport);

        match framer.read_message().await {
            Err(CopilotError::Protocol(msg)) => assert!(msg.contains("Invalid UTF-8")),
            _ => panic!("Expected Protocol error"),
        }
    }

    #[tokio::test]
    async fn test_message_writer_reader_round_trip() {
        let mut writer = MessageWriter::new(Vec::<u8>::new());
        writer.write_message("{\"a\":1}").await.unwrap();
        // Multi-byte character: Content-Length must count bytes, not chars.
        writer.write_message("{\"b\":\"é\"}").await.unwrap();

        let bytes = writer.writer;
        let mut reader = MessageReader::new(&bytes[..]);
        assert_eq!(reader.read_message().await.unwrap(), "{\"a\":1}");
        assert_eq!(reader.read_message().await.unwrap(), "{\"b\":\"é\"}");
    }
}