1use crate::error::{CopilotError, Result};
9use std::future::Future;
10use std::pin::Pin;
11use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
12
/// Asynchronous, bidirectional byte transport.
///
/// Every asynchronous operation returns a boxed, `Send` future instead of
/// being declared as an `async fn`. The returned future borrows the transport
/// (and any buffer passed in) for its whole lifetime.
pub trait Transport: Send + Sync {
    /// Reads bytes into `buf` and resolves to the number of bytes stored.
    ///
    /// The framing code in this file treats a result of `0` as end of stream.
    fn read<'a>(
        &'a mut self,
        buf: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = Result<usize>> + Send + 'a>>;

    /// Writes all of `data` to the transport.
    fn write<'a>(
        &'a mut self,
        data: &'a [u8],
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;

    /// Closes the transport.
    ///
    /// The implementations in this file reject further reads and writes with
    /// `CopilotError::ConnectionClosed` once this has completed.
    fn close(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>>;

    /// Returns `true` while the transport has not been closed.
    fn is_open(&self) -> bool;
}
42
/// [`Transport`] backed by the standard streams of a spawned child process.
pub struct StdioTransport {
    /// Write half: the child's stdin.
    stdin: tokio::process::ChildStdin,
    /// Read half: the child's stdout, wrapped in a buffered reader.
    stdout: BufReader<tokio::process::ChildStdout>,
    /// Cleared by `close`; reads and writes are rejected once it is `false`.
    open: bool,
}
53
54impl StdioTransport {
55 pub fn new(stdin: tokio::process::ChildStdin, stdout: tokio::process::ChildStdout) -> Self {
57 Self {
58 stdin,
59 stdout: BufReader::new(stdout),
60 open: true,
61 }
62 }
63
64 pub fn split(self) -> (tokio::process::ChildStdin, tokio::process::ChildStdout) {
66 (self.stdin, self.stdout.into_inner())
67 }
68}
69
70impl Transport for StdioTransport {
71 fn read<'a>(
72 &'a mut self,
73 buf: &'a mut [u8],
74 ) -> Pin<Box<dyn Future<Output = Result<usize>> + Send + 'a>> {
75 Box::pin(async move {
76 if !self.open {
77 return Err(CopilotError::ConnectionClosed);
78 }
79 self.stdout.read(buf).await.map_err(CopilotError::Transport)
80 })
81 }
82
83 fn write<'a>(
84 &'a mut self,
85 data: &'a [u8],
86 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
87 Box::pin(async move {
88 if !self.open {
89 return Err(CopilotError::ConnectionClosed);
90 }
91 self.stdin
92 .write_all(data)
93 .await
94 .map_err(CopilotError::Transport)?;
95 let _ = self.stdin.flush().await;
97 Ok(())
98 })
99 }
100
101 fn close(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
102 Box::pin(async move {
103 self.open = false;
104 Ok(())
105 })
106 }
107
108 fn is_open(&self) -> bool {
109 self.open
110 }
111}
112
/// Reads and writes `Content-Length`-framed messages over a [`Transport`].
///
/// A frame is a block of header lines, an empty line, and then exactly
/// `Content-Length` bytes of UTF-8 payload.
pub struct MessageFramer<T: Transport> {
    /// Underlying byte transport.
    transport: T,
    /// Scratch buffer holding bytes read from the transport while parsing headers.
    buffer: Vec<u8>,
    /// Index of the next unread byte in `buffer`.
    buffer_pos: usize,
    /// Number of valid bytes currently stored at the front of `buffer`.
    buffer_len: usize,
}
134
/// Write half of the `Content-Length` framing, over any async writer.
pub struct MessageWriter<W> {
    /// Destination that framed messages are written to.
    writer: W,
}
139
140impl<W> MessageWriter<W>
141where
142 W: AsyncWrite + Unpin + Send,
143{
144 pub fn new(writer: W) -> Self {
146 Self { writer }
147 }
148
149 pub async fn write_message(&mut self, message: &str) -> Result<()> {
151 let frame = format!("Content-Length: {}\r\n\r\n{}", message.len(), message);
152 self.writer
153 .write_all(frame.as_bytes())
154 .await
155 .map_err(CopilotError::Transport)?;
156 let _ = self.writer.flush().await;
158 Ok(())
159 }
160}
161
/// Read half of the `Content-Length` framing, over any async reader.
pub struct MessageReader<R> {
    /// Buffered source that frames are read from.
    reader: BufReader<R>,
    /// Scratch buffer holding bytes read from `reader` while parsing headers.
    buffer: Vec<u8>,
    /// Index of the next unread byte in `buffer`.
    buffer_pos: usize,
    /// Number of valid bytes currently stored at the front of `buffer`.
    buffer_len: usize,
}
169
170impl<R> MessageReader<R>
171where
172 R: AsyncRead + Unpin + Send,
173{
174 pub fn new(reader: R) -> Self {
176 Self {
177 reader: BufReader::new(reader),
178 buffer: vec![0u8; 4096],
179 buffer_pos: 0,
180 buffer_len: 0,
181 }
182 }
183
184 pub async fn read_message(&mut self) -> Result<String> {
186 let mut content_length: Option<usize> = None;
188
189 loop {
190 let line = self.read_line().await?;
191
192 if line.is_empty() {
194 break;
195 }
196
197 let lower_line = line.to_lowercase();
199 if let Some(value) = lower_line.strip_prefix("content-length:") {
200 let value_str = value.trim();
201 content_length = Some(value_str.parse().map_err(|_| {
202 CopilotError::Protocol(format!("Invalid Content-Length value: {}", value_str))
203 })?);
204 }
205 }
206
207 let content_length = content_length
208 .ok_or_else(|| CopilotError::Protocol("Missing Content-Length header".into()))?;
209
210 let mut message = vec![0u8; content_length];
212 self.read_exact(&mut message).await?;
213
214 String::from_utf8(message)
215 .map_err(|e| CopilotError::Protocol(format!("Invalid UTF-8 in message: {}", e)))
216 }
217
218 async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
220 let mut total_read = 0;
221
222 while total_read < buf.len() && self.buffer_pos < self.buffer_len {
224 buf[total_read] = self.buffer[self.buffer_pos];
225 total_read += 1;
226 self.buffer_pos += 1;
227 }
228
229 while total_read < buf.len() {
231 let bytes_read = self
232 .reader
233 .read(&mut buf[total_read..])
234 .await
235 .map_err(CopilotError::Transport)?;
236 if bytes_read == 0 {
237 return Err(CopilotError::ConnectionClosed);
238 }
239 total_read += bytes_read;
240 }
241
242 Ok(())
243 }
244
245 async fn read_line(&mut self) -> Result<String> {
247 let mut line = String::new();
248
249 loop {
250 if self.buffer_pos >= self.buffer_len {
252 self.fill_buffer(1).await?;
253 if self.buffer_len == 0 {
254 return Err(CopilotError::ConnectionClosed);
255 }
256 }
257
258 let c = self.buffer[self.buffer_pos] as char;
259 self.buffer_pos += 1;
260
261 if c == '\n' {
262 if line.ends_with('\r') {
264 line.pop();
265 }
266 return Ok(line);
267 }
268
269 line.push(c);
270 }
271 }
272
273 async fn fill_buffer(&mut self, min_bytes: usize) -> Result<()> {
275 if self.buffer_pos > 0 {
277 if self.buffer_pos < self.buffer_len {
278 self.buffer.copy_within(self.buffer_pos..self.buffer_len, 0);
279 self.buffer_len -= self.buffer_pos;
280 } else {
281 self.buffer_len = 0;
282 }
283 self.buffer_pos = 0;
284 }
285
286 while self.buffer_len < min_bytes {
288 let bytes_read = self
289 .reader
290 .read(&mut self.buffer[self.buffer_len..])
291 .await?;
292
293 if bytes_read == 0 {
294 return Ok(());
296 }
297
298 self.buffer_len += bytes_read;
299 }
300
301 Ok(())
302 }
303}
304
305impl<T: Transport> MessageFramer<T> {
306 pub fn new(transport: T) -> Self {
308 Self {
309 transport,
310 buffer: vec![0u8; 4096],
311 buffer_pos: 0,
312 buffer_len: 0,
313 }
314 }
315
316 pub async fn read_message(&mut self) -> Result<String> {
320 let mut content_length: Option<usize> = None;
322
323 loop {
324 let line = self.read_line().await?;
325
326 if line.is_empty() {
328 break;
329 }
330
331 let lower_line = line.to_lowercase();
333 if let Some(value) = lower_line.strip_prefix("content-length:") {
334 let value_str = value.trim();
335 content_length = Some(value_str.parse().map_err(|_| {
336 CopilotError::Protocol(format!("Invalid Content-Length value: {}", value_str))
337 })?);
338 }
339 }
341
342 let content_length = content_length
343 .ok_or_else(|| CopilotError::Protocol("Missing Content-Length header".into()))?;
344
345 let mut message = vec![0u8; content_length];
347 self.read_exact(&mut message).await?;
348
349 String::from_utf8(message)
350 .map_err(|e| CopilotError::Protocol(format!("Invalid UTF-8 in message: {}", e)))
351 }
352
353 pub async fn write_message(&mut self, message: &str) -> Result<()> {
355 let frame = format!("Content-Length: {}\r\n\r\n{}", message.len(), message);
356 self.transport.write(frame.as_bytes()).await
357 }
358
359 pub fn transport(&self) -> &T {
361 &self.transport
362 }
363
364 pub fn transport_mut(&mut self) -> &mut T {
366 &mut self.transport
367 }
368
369 pub fn into_transport(self) -> T {
371 self.transport
372 }
373
374 async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
380 let mut total_read = 0;
381
382 while total_read < buf.len() && self.buffer_pos < self.buffer_len {
384 buf[total_read] = self.buffer[self.buffer_pos];
385 total_read += 1;
386 self.buffer_pos += 1;
387 }
388
389 while total_read < buf.len() {
391 let bytes_read = self.transport.read(&mut buf[total_read..]).await?;
392 if bytes_read == 0 {
393 return Err(CopilotError::ConnectionClosed);
394 }
395 total_read += bytes_read;
396 }
397
398 Ok(())
399 }
400
401 async fn read_line(&mut self) -> Result<String> {
403 let mut line = String::new();
404
405 loop {
406 if self.buffer_pos >= self.buffer_len {
408 self.fill_buffer(1).await?;
409 if self.buffer_len == 0 {
410 return Err(CopilotError::ConnectionClosed);
411 }
412 }
413
414 let c = self.buffer[self.buffer_pos] as char;
415 self.buffer_pos += 1;
416
417 if c == '\n' {
418 if line.ends_with('\r') {
420 line.pop();
421 }
422 return Ok(line);
423 }
424
425 line.push(c);
426 }
427 }
428
429 async fn fill_buffer(&mut self, min_bytes: usize) -> Result<()> {
431 if self.buffer_pos > 0 {
433 if self.buffer_pos < self.buffer_len {
434 self.buffer.copy_within(self.buffer_pos..self.buffer_len, 0);
435 self.buffer_len -= self.buffer_pos;
436 } else {
437 self.buffer_len = 0;
438 }
439 self.buffer_pos = 0;
440 }
441
442 while self.buffer_len < min_bytes {
444 let bytes_read = self
445 .transport
446 .read(&mut self.buffer[self.buffer_len..])
447 .await?;
448
449 if bytes_read == 0 {
450 return Ok(());
452 }
453
454 self.buffer_len += bytes_read;
455 }
456
457 Ok(())
458 }
459}
460
/// In-memory [`Transport`] used by the unit tests: reads are served from a
/// fixed byte vector and writes are captured for later inspection.
#[cfg(test)]
pub struct MemoryTransport {
    /// Bytes that `read` hands out.
    read_data: Vec<u8>,
    /// Offset into `read_data` of the next byte to hand out.
    read_pos: usize,
    /// Everything passed to `write` so far, in order.
    write_data: Vec<u8>,
    /// Cleared by `close`; reads and writes are rejected once it is `false`.
    open: bool,
}
473
#[cfg(test)]
impl MemoryTransport {
    /// Creates an open transport whose reads will yield `read_data` and whose
    /// write log starts out empty.
    pub fn new(read_data: Vec<u8>) -> Self {
        let write_data = Vec::new();
        Self {
            read_data,
            read_pos: 0,
            write_data,
            open: true,
        }
    }

    /// Returns every byte written to the transport so far.
    pub fn written_data(&self) -> &[u8] {
        self.write_data.as_slice()
    }
}
491
#[cfg(test)]
impl Transport for MemoryTransport {
    /// Copies as many unread bytes as fit into `buf` and advances the cursor.
    ///
    /// Resolves to `0` once all of `read_data` has been handed out, and fails
    /// with `CopilotError::ConnectionClosed` after `close`.
    fn read<'a>(
        &'a mut self,
        buf: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = Result<usize>> + Send + 'a>> {
        Box::pin(async move {
            if !self.open {
                return Err(CopilotError::ConnectionClosed);
            }
            let unread = &self.read_data[self.read_pos..];
            let count = unread.len().min(buf.len());
            buf[..count].copy_from_slice(&unread[..count]);
            self.read_pos += count;
            Ok(count)
        })
    }

    /// Appends `data` to the write log; fails with
    /// `CopilotError::ConnectionClosed` after `close`.
    fn write<'a>(
        &'a mut self,
        data: &'a [u8],
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
        Box::pin(async move {
            if self.open {
                self.write_data.extend_from_slice(data);
                Ok(())
            } else {
                Err(CopilotError::ConnectionClosed)
            }
        })
    }

    /// Marks the transport as closed.
    fn close(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
        Box::pin(async move {
            self.open = false;
            Ok(())
        })
    }

    /// Returns `true` until `close` has been called.
    fn is_open(&self) -> bool {
        self.open
    }
}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a framer whose transport will serve `input` on reads.
    fn framer_for(input: &[u8]) -> MessageFramer<MemoryTransport> {
        MessageFramer::new(MemoryTransport::new(input.to_vec()))
    }

    /// A CRLF-terminated frame is decoded to its payload.
    #[tokio::test]
    async fn test_read_message() {
        let mut framer = framer_for(b"Content-Length: 13\r\n\r\n{\"test\":true}");

        let message = framer.read_message().await.unwrap();
        assert_eq!(message, "{\"test\":true}");
    }

    /// Bare LF line endings are accepted as well.
    #[tokio::test]
    async fn test_read_message_lf_only() {
        let mut framer = framer_for(b"Content-Length: 13\n\n{\"test\":true}");

        let message = framer.read_message().await.unwrap();
        assert_eq!(message, "{\"test\":true}");
    }

    /// Headers other than `Content-Length` are skipped.
    #[tokio::test]
    async fn test_read_message_with_extra_headers() {
        let mut framer = framer_for(
            b"Content-Type: application/json\r\nContent-Length: 13\r\n\r\n{\"test\":true}",
        );

        let message = framer.read_message().await.unwrap();
        assert_eq!(message, "{\"test\":true}");
    }

    /// Written frames carry the header followed by the payload.
    #[tokio::test]
    async fn test_write_message() {
        let mut framer = framer_for(b"");

        framer.write_message("{\"test\":true}").await.unwrap();

        let written = framer.transport().written_data();
        assert_eq!(written, b"Content-Length: 13\r\n\r\n{\"test\":true}");
    }

    /// Two back-to-back frames are decoded independently.
    #[tokio::test]
    async fn test_read_multiple_messages() {
        let mut framer = framer_for(
            b"Content-Length: 13\r\n\r\n{\"test\":true}Content-Length: 14\r\n\r\n{\"test\":false}",
        );

        let first = framer.read_message().await.unwrap();
        assert_eq!(first, "{\"test\":true}");

        let second = framer.read_message().await.unwrap();
        assert_eq!(second, "{\"test\":false}");
    }

    /// A frame without `Content-Length` is a protocol error.
    #[tokio::test]
    async fn test_missing_content_length() {
        let mut framer = framer_for(b"Content-Type: application/json\r\n\r\n{\"test\":true}");

        let result = framer.read_message().await;
        assert!(result.is_err());
        match result {
            Err(CopilotError::Protocol(msg)) => {
                assert!(msg.contains("Missing Content-Length"));
            }
            _ => panic!("Expected Protocol error"),
        }
    }

    /// The header name is matched regardless of case.
    #[tokio::test]
    async fn test_case_insensitive_header() {
        let mut framer = framer_for(b"content-length: 13\r\n\r\n{\"test\":true}");

        let message = framer.read_message().await.unwrap();
        assert_eq!(message, "{\"test\":true}");
    }

    /// Reading from a closed transport fails with `ConnectionClosed`.
    #[tokio::test]
    async fn test_transport_closed() {
        let mut transport = MemoryTransport::new(Vec::new());
        transport.close().await.unwrap();

        let mut scratch = [0u8; 10];
        let outcome = transport.read(&mut scratch).await;
        assert!(matches!(outcome, Err(CopilotError::ConnectionClosed)));
    }
}