1pub mod client;
2pub mod connect;
3pub mod daemon;
4pub mod protocol;
5pub mod security;
6pub mod server;
7
/// Reads `TERM`, `LANG` and `COLORTERM` from the current process environment.
///
/// Returns one `(name, value)` pair per variable that `std::env::var` can read,
/// always in the order `TERM`, `LANG`, `COLORTERM`. A variable is left out if it
/// is unset, or if its value is not valid Unicode (`std::env::var` returns an
/// error in that case).
pub fn collect_env_vars() -> Vec<(String, String)> {
    const KEYS: [&str; 3] = ["TERM", "LANG", "COLORTERM"];

    let mut pairs = Vec::new();
    for key in KEYS {
        if let Ok(value) = std::env::var(key) {
            pairs.push((String::from(key), value));
        }
    }
    pairs
}
15
16pub fn spawn_channel_relay<F, G>(
21 channel_id: u32,
22 read_half: tokio::net::unix::OwnedReadHalf,
23 write_half: tokio::net::unix::OwnedWriteHalf,
24 on_data: F,
25 on_close: G,
26) -> tokio::sync::mpsc::UnboundedSender<bytes::Bytes>
27where
28 F: Fn(u32, bytes::Bytes) -> bool + Send + 'static,
29 G: Fn(u32) + Send + 'static,
30{
31 use tokio::io::{AsyncReadExt, AsyncWriteExt};
32
33 let (writer_tx, mut writer_rx) = tokio::sync::mpsc::unbounded_channel::<bytes::Bytes>();
34
35 tokio::spawn(async move {
36 let mut read_half = read_half;
37 let mut buf = vec![0u8; 8192];
38 loop {
39 match read_half.read(&mut buf).await {
40 Ok(0) | Err(_) => {
41 on_close(channel_id);
42 break;
43 }
44 Ok(n) => {
45 if !on_data(channel_id, bytes::Bytes::copy_from_slice(&buf[..n])) {
46 break;
47 }
48 }
49 }
50 }
51 });
52
53 tokio::spawn(async move {
54 let mut write_half = write_half;
55 while let Some(data) = writer_rx.recv().await {
56 if write_half.write_all(&data).await.is_err() {
57 break;
58 }
59 }
60 });
61
62 writer_tx
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on how long any single relay event may take before a test is
    /// treated as hung. It is generous so that slow CI machines do not flake, and
    /// it is only reached on failure: the tests wait on events, not on sleeps.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(2);

    #[test]
    fn collect_env_vars_only_known_keys() {
        let vars = collect_env_vars();
        for (k, _) in &vars {
            assert!(["TERM", "LANG", "COLORTERM"].contains(&k.as_str()), "unexpected key: {k}");
        }
    }

    #[test]
    fn collect_env_vars_no_duplicates() {
        let vars = collect_env_vars();
        let keys: Vec<&str> = vars.iter().map(|(k, _)| k.as_str()).collect();
        let mut deduped = keys.clone();
        deduped.sort();
        deduped.dedup();
        assert_eq!(keys.len(), deduped.len());
    }

    #[test]
    fn collect_env_vars_includes_term_if_set() {
        if std::env::var("TERM").is_ok() {
            let vars = collect_env_vars();
            assert!(vars.iter().any(|(k, _)| k == "TERM"));
        }
    }

    #[tokio::test]
    async fn spawn_channel_relay_on_data_and_close() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::UnixStream;
        use tokio::sync::mpsc::unbounded_channel;
        use tokio::time::timeout;

        // `feed_stream` writes into the relay's read half; the second pair only
        // supplies a write half for the relay.
        let (read_stream, mut feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, _drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        // Forward callback invocations to the test body over channels instead of
        // asserting inside the callbacks. A panic in the relay's spawned task
        // would not fail this test, and channels let us wait for each event
        // rather than sleeping for a guessed duration.
        let (data_tx, mut data_rx) = unbounded_channel::<(u32, bytes::Bytes)>();
        let (close_tx, mut close_rx) = unbounded_channel::<u32>();

        let _writer_tx = spawn_channel_relay(
            42,
            read_half,
            write_half,
            move |ch, data| data_tx.send((ch, data)).is_ok(),
            move |ch| {
                let _ = close_tx.send(ch);
            },
        );

        feed_stream.write_all(b"hello").await.unwrap();
        let (ch, data) = timeout(EVENT_TIMEOUT, data_rx.recv())
            .await
            .expect("timed out waiting for on_data")
            .expect("relay dropped on_data before delivering data");
        assert_eq!(ch, 42);
        assert_eq!(&data[..], b"hello");

        // Closing the peer makes the relay's read hit EOF, which must fire on_close.
        drop(feed_stream);
        let closed_ch = timeout(EVENT_TIMEOUT, close_rx.recv())
            .await
            .expect("timed out waiting for on_close")
            .expect("relay dropped on_close without calling it");
        assert_eq!(closed_ch, 42);

        // Once the reader task has finished, its on_data sender is dropped, so
        // nothing arrives beyond the single chunk checked above.
        let trailing = timeout(EVENT_TIMEOUT, data_rx.recv())
            .await
            .expect("reader task did not finish");
        assert!(trailing.is_none(), "unexpected extra chunk: {trailing:?}");
    }

    #[tokio::test]
    async fn spawn_channel_relay_stops_without_close_when_on_data_declines() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::UnixStream;
        use tokio::sync::mpsc::unbounded_channel;
        use tokio::time::timeout;

        let (read_stream, mut feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, _drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let (data_tx, mut data_rx) = unbounded_channel::<u32>();
        let (close_tx, mut close_rx) = unbounded_channel::<u32>();

        let _writer_tx = spawn_channel_relay(
            3,
            read_half,
            write_half,
            move |ch, _| {
                let _ = data_tx.send(ch);
                false
            },
            move |ch| {
                let _ = close_tx.send(ch);
            },
        );

        feed_stream.write_all(b"x").await.unwrap();
        let first = timeout(EVENT_TIMEOUT, data_rx.recv())
            .await
            .expect("timed out waiting for on_data");
        assert_eq!(first, Some(3));

        // Returning false ends the reader task, which drops both callbacks
        // (and with them both senders). on_close must not have been called.
        let after = timeout(EVENT_TIMEOUT, data_rx.recv())
            .await
            .expect("reader task did not stop");
        assert_eq!(after, None);
        let closed = timeout(EVENT_TIMEOUT, close_rx.recv())
            .await
            .expect("on_close sender was not dropped");
        assert_eq!(closed, None);
    }

    #[tokio::test]
    async fn spawn_channel_relay_writer_sends_data() {
        use tokio::io::AsyncReadExt;
        use tokio::net::UnixStream;

        // `_feed_stream` is kept alive so the relay's reader does not see EOF.
        let (read_stream, _feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, mut drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let writer_tx = spawn_channel_relay(7, read_half, write_half, |_, _| true, |_| {});

        writer_tx.send(bytes::Bytes::from_static(b"hello")).unwrap();
        writer_tx.send(bytes::Bytes::from_static(b" world")).unwrap();

        // Use read_exact rather than a single read: a stream may legally deliver
        // the payload in several pieces.
        let mut buf = [0u8; 11];
        tokio::time::timeout(EVENT_TIMEOUT, drain_stream.read_exact(&mut buf))
            .await
            .expect("timed out waiting for relayed bytes")
            .unwrap();

        assert_eq!(&buf, b"hello world");
    }
}