agent_client_protocol_tokio/
lib.rs1mod acp_agent;
8
9pub use acp_agent::{AcpAgent, LineDirection};
10use agent_client_protocol::{ByteStreams, ConnectTo, Role};
11use std::sync::Arc;
12use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
13
14pub struct Stdio {
15 debug_callback: Option<Arc<dyn Fn(&str, LineDirection) + Send + Sync + 'static>>,
16}
17
18impl std::fmt::Debug for Stdio {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 f.debug_struct("Stdio").finish_non_exhaustive()
21 }
22}
23
24impl Stdio {
25 #[must_use]
26 pub fn new() -> Self {
27 Self {
28 debug_callback: None,
29 }
30 }
31
32 #[must_use]
33 pub fn with_debug<F>(mut self, callback: F) -> Self
34 where
35 F: Fn(&str, LineDirection) + Send + Sync + 'static,
36 {
37 self.debug_callback = Some(Arc::new(callback));
38 self
39 }
40}
41
42impl Default for Stdio {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl<Counterpart: Role> ConnectTo<Counterpart> for Stdio {
49 async fn connect_to(
50 self,
51 client: impl ConnectTo<Counterpart::Counterpart>,
52 ) -> Result<(), agent_client_protocol::Error> {
53 if let Some(callback) = self.debug_callback {
54 use futures::AsyncBufReadExt;
55 use futures::AsyncWriteExt;
56 use futures::StreamExt;
57 use futures::io::BufReader;
58
59 let stdin = tokio::io::stdin();
61 let stdout = tokio::io::stdout();
62
63 let incoming_callback = callback.clone();
65 let incoming_lines = Box::pin(BufReader::new(stdin.compat()).lines().inspect(
66 move |result| {
67 if let Ok(line) = result {
68 incoming_callback(line, LineDirection::Stdin);
69 }
70 },
71 ))
72 as std::pin::Pin<Box<dyn futures::Stream<Item = std::io::Result<String>> + Send>>;
73
74 let outgoing_sink = Box::pin(futures::sink::unfold(
76 (stdout.compat_write(), callback),
77 async move |(mut writer, callback), line: String| {
78 callback(&line, LineDirection::Stdout);
79 let mut bytes = line.into_bytes();
80 bytes.push(b'\n');
81 writer.write_all(&bytes).await?;
82 Ok::<_, std::io::Error>((writer, callback))
83 },
84 ))
85 as std::pin::Pin<Box<dyn futures::Sink<String, Error = std::io::Error> + Send>>;
86
87 ConnectTo::<Counterpart>::connect_to(
88 agent_client_protocol::Lines::new(outgoing_sink, incoming_lines),
89 client,
90 )
91 .await
92 } else {
93 ConnectTo::<Counterpart>::connect_to(
95 ByteStreams::new(
96 tokio::io::stdout().compat_write(),
97 tokio::io::stdin().compat(),
98 ),
99 client,
100 )
101 .await
102 }
103 }
104}