1pub mod client;
2pub mod config;
3pub mod connect;
4pub mod daemon;
5pub mod protocol;
6pub mod security;
7pub mod server;
8pub mod table;
9
10pub async fn handshake(
15 framed: &mut tokio_util::codec::Framed<tokio::net::UnixStream, protocol::FrameCodec>,
16) -> anyhow::Result<u16> {
17 use futures_util::{SinkExt, StreamExt};
18 framed.send(protocol::Frame::Hello { version: protocol::PROTOCOL_VERSION }).await?;
19 match protocol::Frame::expect_from(framed.next().await)? {
20 protocol::Frame::HelloAck { version } => Ok(version),
21 protocol::Frame::Error { message } => anyhow::bail!("handshake rejected: {message}"),
22 other => anyhow::bail!("expected HelloAck, got {other:?}"),
23 }
24}
25
/// Returns `(name, value)` pairs for `TERM`, `LANG` and `COLORTERM`, in that order.
///
/// A variable is skipped when `std::env::var` fails for it, i.e. when it is unset or
/// its value is not valid Unicode.
pub fn collect_env_vars() -> Vec<(String, String)> {
    const KEYS: [&str; 3] = ["TERM", "LANG", "COLORTERM"];

    let mut pairs = Vec::new();
    for &name in KEYS.iter() {
        if let Ok(value) = std::env::var(name) {
            pairs.push((String::from(name), value));
        }
    }
    pairs
}
33
/// Capacity of the bounded `tokio::sync::mpsc` channel created by
/// `spawn_channel_relay`, i.e. how many outgoing chunks may be queued for the
/// writer task before senders have to wait.
const CHANNEL_RELAY_BUFFER: usize = 256;
41
42pub fn spawn_channel_relay<R, W, F, G>(
43 channel_id: u32,
44 read_half: R,
45 write_half: W,
46 on_data: F,
47 on_close: G,
48) -> tokio::sync::mpsc::Sender<bytes::Bytes>
49where
50 R: tokio::io::AsyncRead + Unpin + Send + 'static,
51 W: tokio::io::AsyncWrite + Unpin + Send + 'static,
52 F: Fn(u32, bytes::Bytes) -> bool + Send + 'static,
53 G: Fn(u32) + Send + 'static,
54{
55 use tokio::io::{AsyncReadExt, AsyncWriteExt};
56
57 let (writer_tx, mut writer_rx) =
58 tokio::sync::mpsc::channel::<bytes::Bytes>(CHANNEL_RELAY_BUFFER);
59
60 tokio::spawn(async move {
61 let mut read_half = read_half;
62 let mut buf = vec![0u8; 8192];
63 loop {
64 match read_half.read(&mut buf).await {
65 Ok(0) | Err(_) => {
66 on_close(channel_id);
67 break;
68 }
69 Ok(n) => {
70 if !on_data(channel_id, bytes::Bytes::copy_from_slice(&buf[..n])) {
71 break;
72 }
73 }
74 }
75 }
76 });
77
78 tokio::spawn(async move {
79 let mut write_half = write_half;
80 while let Some(data) = writer_rx.recv().await {
81 if write_half.write_all(&data).await.is_err() {
82 break;
83 }
84 }
85 let _ = write_half.shutdown().await;
87 });
88
89 writer_tx
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper bound on how long a relay test waits for an asynchronous event.
    /// Tests await concrete events instead of sleeping for a fixed interval.
    const TEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);

    #[test]
    fn collect_env_vars_only_known_keys() {
        let vars = collect_env_vars();
        for (k, _) in &vars {
            assert!(["TERM", "LANG", "COLORTERM"].contains(&k.as_str()), "unexpected key: {k}");
        }
    }

    #[test]
    fn collect_env_vars_no_duplicates() {
        let vars = collect_env_vars();
        let keys: Vec<&str> = vars.iter().map(|(k, _)| k.as_str()).collect();
        let mut deduped = keys.clone();
        deduped.sort();
        deduped.dedup();
        assert_eq!(keys.len(), deduped.len());
    }

    #[test]
    fn collect_env_vars_includes_term_if_set() {
        if std::env::var("TERM").is_ok() {
            let vars = collect_env_vars();
            assert!(vars.iter().any(|(k, _)| k == "TERM"));
        }
    }

    #[tokio::test]
    async fn spawn_channel_relay_on_data_and_close() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::UnixStream;
        use tokio::sync::mpsc::unbounded_channel;

        let (read_stream, mut feed_stream) = UnixStream::pair().unwrap();
        // Keep the peer of the write side alive for the duration of the test.
        let (write_stream, _drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        // Callbacks report over channels. The test can then await each event with a
        // timeout, and assertions run in the test task rather than a spawned task,
        // where a panic would not fail the test.
        let (data_tx, mut data_rx) = unbounded_channel::<(u32, bytes::Bytes)>();
        let (close_tx, mut close_rx) = unbounded_channel::<u32>();

        let _writer_tx = spawn_channel_relay(
            42,
            read_half,
            write_half,
            move |ch, data| {
                let _ = data_tx.send((ch, data));
                true
            },
            move |ch| {
                let _ = close_tx.send(ch);
            },
        );

        feed_stream.write_all(b"hello").await.unwrap();

        // A stream socket may split the payload across several reads, so gather
        // chunks until all of it has arrived.
        let mut received = Vec::new();
        while received.len() < b"hello".len() {
            let (ch, chunk) = tokio::time::timeout(TEST_TIMEOUT, data_rx.recv())
                .await
                .expect("timed out waiting for on_data")
                .expect("on_data sender dropped before the payload arrived");
            assert_eq!(ch, 42);
            received.extend_from_slice(&chunk);
        }
        assert_eq!(&received[..], b"hello");

        // Dropping the feeding end produces EOF on the relay's read side.
        drop(feed_stream);
        let closed_ch = tokio::time::timeout(TEST_TIMEOUT, close_rx.recv())
            .await
            .expect("timed out waiting for on_close")
            .expect("on_close sender dropped without firing");
        assert_eq!(closed_ch, 42);
    }

    #[tokio::test]
    async fn spawn_channel_relay_writer_sends_data() {
        use tokio::io::AsyncReadExt;
        use tokio::net::UnixStream;

        // Keep the feeding end alive so the reader task does not hit EOF.
        let (read_stream, _feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, mut drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let writer_tx = spawn_channel_relay(7, read_half, write_half, |_, _| true, |_| {});

        writer_tx.try_send(bytes::Bytes::from_static(b"hello")).unwrap();

        let mut buf = vec![0u8; 32];
        let n = tokio::time::timeout(TEST_TIMEOUT, drain_stream.read(&mut buf))
            .await
            .unwrap()
            .unwrap();

        assert_eq!(&buf[..n], b"hello");
    }

    #[tokio::test]
    async fn spawn_channel_relay_writer_half_close() {
        use tokio::io::AsyncReadExt;
        use tokio::net::UnixStream;

        let (read_stream, _feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, mut drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let writer_tx = spawn_channel_relay(7, read_half, write_half, |_, _| true, |_| {});

        writer_tx.try_send(bytes::Bytes::from_static(b"request")).unwrap();
        // Dropping the only sender lets the writer task drain and then shut down.
        drop(writer_tx);

        let mut buf = vec![0u8; 32];
        let n = tokio::time::timeout(TEST_TIMEOUT, drain_stream.read(&mut buf))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(&buf[..n], b"request");

        let n = tokio::time::timeout(TEST_TIMEOUT, drain_stream.read(&mut buf))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(n, 0, "expected EOF from graceful half-close");
    }
}