1pub mod client;
2pub mod config;
3pub mod connect;
4pub mod daemon;
5pub mod protocol;
6pub mod security;
7pub mod server;
8
9pub async fn handshake(
14 framed: &mut tokio_util::codec::Framed<tokio::net::UnixStream, protocol::FrameCodec>,
15) -> anyhow::Result<u16> {
16 use futures_util::{SinkExt, StreamExt};
17 framed.send(protocol::Frame::Hello { version: protocol::PROTOCOL_VERSION }).await?;
18 match protocol::Frame::expect_from(framed.next().await)? {
19 protocol::Frame::HelloAck { version } => Ok(version),
20 protocol::Frame::Error { message } => anyhow::bail!("handshake rejected: {message}"),
21 other => anyhow::bail!("expected HelloAck, got {other:?}"),
22 }
23}
24
/// Collects the terminal-related environment variables of this process.
///
/// Looks up `TERM`, `LANG` and `COLORTERM` (in that order) and returns a
/// `(name, value)` pair for each one that `std::env::var` can read.
/// Variables that are unset, or whose value is not valid Unicode, are
/// skipped, so the result holds between zero and three entries.
pub fn collect_env_vars() -> Vec<(String, String)> {
    const FORWARDED: [&str; 3] = ["TERM", "LANG", "COLORTERM"];

    let mut pairs = Vec::new();
    for name in &FORWARDED {
        // `var` fails for both "not set" and "not Unicode"; either way the
        // variable is simply left out.
        if let Ok(value) = std::env::var(name) {
            pairs.push((name.to_string(), value));
        }
    }
    pairs
}
32
33pub fn spawn_channel_relay<F, G>(
38 channel_id: u32,
39 read_half: tokio::net::unix::OwnedReadHalf,
40 write_half: tokio::net::unix::OwnedWriteHalf,
41 on_data: F,
42 on_close: G,
43) -> tokio::sync::mpsc::UnboundedSender<bytes::Bytes>
44where
45 F: Fn(u32, bytes::Bytes) -> bool + Send + 'static,
46 G: Fn(u32) + Send + 'static,
47{
48 use tokio::io::{AsyncReadExt, AsyncWriteExt};
49
50 let (writer_tx, mut writer_rx) = tokio::sync::mpsc::unbounded_channel::<bytes::Bytes>();
51
52 tokio::spawn(async move {
53 let mut read_half = read_half;
54 let mut buf = vec![0u8; 8192];
55 loop {
56 match read_half.read(&mut buf).await {
57 Ok(0) | Err(_) => {
58 on_close(channel_id);
59 break;
60 }
61 Ok(n) => {
62 if !on_data(channel_id, bytes::Bytes::copy_from_slice(&buf[..n])) {
63 break;
64 }
65 }
66 }
67 }
68 });
69
70 tokio::spawn(async move {
71 let mut write_half = write_half;
72 while let Some(data) = writer_rx.recv().await {
73 if write_half.write_all(&data).await.is_err() {
74 break;
75 }
76 }
77 });
78
79 writer_tx
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper bound on how long a relay test waits for an expected event.
    /// Reached only on failure; a passing test returns as soon as the event
    /// arrives.
    const EVENT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);

    /// Only the three whitelisted variable names may ever be returned.
    #[test]
    fn collect_env_vars_only_known_keys() {
        let vars = collect_env_vars();
        for (k, _) in &vars {
            assert!(["TERM", "LANG", "COLORTERM"].contains(&k.as_str()), "unexpected key: {k}");
        }
    }

    /// Each variable name appears at most once in the result.
    #[test]
    fn collect_env_vars_no_duplicates() {
        let vars = collect_env_vars();
        let keys: Vec<&str> = vars.iter().map(|(k, _)| k.as_str()).collect();
        let mut deduped = keys.clone();
        deduped.sort();
        deduped.dedup();
        assert_eq!(keys.len(), deduped.len());
    }

    /// When `TERM` is readable in the test environment it must be reported.
    #[test]
    fn collect_env_vars_includes_term_if_set() {
        if std::env::var("TERM").is_ok() {
            let vars = collect_env_vars();
            assert!(vars.iter().any(|(k, _)| k == "TERM"));
        }
    }

    /// Data written to the peer socket reaches `on_data`, and closing the
    /// peer triggers `on_close`.
    ///
    /// The callbacks forward what they observe over channels so the test can
    /// await the events (bounded by [`EVENT_TIMEOUT`]) instead of sleeping,
    /// and so every assertion runs in the test body rather than inside the
    /// detached relay task, where a panic could not fail the test directly.
    #[tokio::test]
    async fn spawn_channel_relay_on_data_and_close() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::UnixStream;
        use tokio::sync::mpsc;
        use tokio::time::timeout;

        // `feed_stream` is the peer of the socket the relay reads from.
        let (read_stream, mut feed_stream) = UnixStream::pair().unwrap();
        // `_drain_stream` only keeps the relay's write socket connected.
        let (write_stream, _drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let (data_tx, mut data_rx) = mpsc::unbounded_channel::<(u32, bytes::Bytes)>();
        let (close_tx, mut close_rx) = mpsc::unbounded_channel::<u32>();

        let _writer_tx = spawn_channel_relay(
            42,
            read_half,
            write_half,
            move |ch, data| {
                let _ = data_tx.send((ch, data));
                true
            },
            move |ch| {
                let _ = close_tx.send(ch);
            },
        );

        feed_stream.write_all(b"hello").await.unwrap();

        let (ch, data) = timeout(EVENT_TIMEOUT, data_rx.recv())
            .await
            .expect("timed out waiting for on_data")
            .expect("relay stopped before delivering any data");
        assert_eq!(ch, 42);
        assert_eq!(&data[..], b"hello");

        // Closing the peer makes the relay's next read hit end-of-stream.
        drop(feed_stream);

        let closed_ch = timeout(EVENT_TIMEOUT, close_rx.recv())
            .await
            .expect("timed out waiting for on_close")
            .expect("relay stopped without calling on_close");
        assert_eq!(closed_ch, 42);

        // The reader task owns `on_data` (and therefore `data_tx`); once it
        // has exited the channel closes, so `None` here means both that no
        // extra chunk was delivered and that the task really terminated.
        let leftover = timeout(EVENT_TIMEOUT, data_rx.recv())
            .await
            .expect("timed out waiting for the reader task to exit");
        assert!(leftover.is_none(), "unexpected extra chunk: {leftover:?}");
    }

    /// Bytes sent through the returned sender are written to the socket.
    #[tokio::test]
    async fn spawn_channel_relay_writer_sends_data() {
        use tokio::io::AsyncReadExt;
        use tokio::net::UnixStream;

        // `_feed_stream` stays alive so the reader task does not see EOF.
        let (read_stream, _feed_stream) = UnixStream::pair().unwrap();
        // `drain_stream` is the peer of the socket the relay writes to.
        let (write_stream, mut drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let writer_tx = spawn_channel_relay(7, read_half, write_half, |_, _| true, |_| {});

        writer_tx.send(bytes::Bytes::from_static(b"hello")).unwrap();

        let mut buf = vec![0u8; 32];
        let n = tokio::time::timeout(EVENT_TIMEOUT, drain_stream.read(&mut buf))
            .await
            .expect("timed out waiting for relayed bytes")
            .unwrap();

        assert_eq!(&buf[..n], b"hello");
    }
}