1pub mod client;
2pub mod config;
3pub mod connect;
4pub mod daemon;
5pub mod protocol;
6pub mod security;
7pub mod server;
8
/// Returns the subset of `TERM`, `LANG` and `COLORTERM` that is set in the current
/// process environment, as `(name, value)` pairs in that fixed order.
///
/// A variable is omitted when it is unset or when its value is not valid Unicode
/// (both cases make [`std::env::var`] return `Err`).
pub fn collect_env_vars() -> Vec<(String, String)> {
    const FORWARDED_KEYS: [&str; 3] = ["TERM", "LANG", "COLORTERM"];

    let mut pairs = Vec::new();
    for name in FORWARDED_KEYS {
        if let Ok(value) = std::env::var(name) {
            pairs.push((name.to_owned(), value));
        }
    }
    pairs
}
16
17pub fn spawn_channel_relay<F, G>(
22 channel_id: u32,
23 read_half: tokio::net::unix::OwnedReadHalf,
24 write_half: tokio::net::unix::OwnedWriteHalf,
25 on_data: F,
26 on_close: G,
27) -> tokio::sync::mpsc::UnboundedSender<bytes::Bytes>
28where
29 F: Fn(u32, bytes::Bytes) -> bool + Send + 'static,
30 G: Fn(u32) + Send + 'static,
31{
32 use tokio::io::{AsyncReadExt, AsyncWriteExt};
33
34 let (writer_tx, mut writer_rx) = tokio::sync::mpsc::unbounded_channel::<bytes::Bytes>();
35
36 tokio::spawn(async move {
37 let mut read_half = read_half;
38 let mut buf = vec![0u8; 8192];
39 loop {
40 match read_half.read(&mut buf).await {
41 Ok(0) | Err(_) => {
42 on_close(channel_id);
43 break;
44 }
45 Ok(n) => {
46 if !on_data(channel_id, bytes::Bytes::copy_from_slice(&buf[..n])) {
47 break;
48 }
49 }
50 }
51 }
52 });
53
54 tokio::spawn(async move {
55 let mut write_half = write_half;
56 while let Some(data) = writer_rx.recv().await {
57 if write_half.write_all(&data).await.is_err() {
58 break;
59 }
60 }
61 });
62
63 writer_tx
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    use std::time::Duration;

    /// Upper bound for any single asynchronous step in these tests. It is generous so slow
    /// CI machines don't fail spuriously, but finite so a broken relay cannot hang the suite.
    const STEP_TIMEOUT: Duration = Duration::from_secs(2);

    /// The only keys `collect_env_vars` is expected to report.
    const KNOWN_KEYS: [&str; 3] = ["TERM", "LANG", "COLORTERM"];

    #[test]
    fn collect_env_vars_only_known_keys() {
        let vars = collect_env_vars();
        for (k, _) in &vars {
            assert!(KNOWN_KEYS.contains(&k.as_str()), "unexpected key: {k}");
        }
    }

    #[test]
    fn collect_env_vars_no_duplicates() {
        let vars = collect_env_vars();
        let keys: Vec<&str> = vars.iter().map(|(k, _)| k.as_str()).collect();
        let mut deduped = keys.clone();
        deduped.sort();
        deduped.dedup();
        assert_eq!(keys.len(), deduped.len());
    }

    #[test]
    fn collect_env_vars_includes_term_if_set() {
        // Only meaningful when TERM is present (and valid Unicode) in this environment.
        if let Ok(term) = std::env::var("TERM") {
            let vars = collect_env_vars();
            assert!(vars.iter().any(|(k, v)| k == "TERM" && *v == term));
        }
    }

    #[tokio::test]
    async fn spawn_channel_relay_on_data_and_close() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::UnixStream;
        use tokio::sync::mpsc::unbounded_channel;
        use tokio::time::timeout;

        let (read_stream, mut feed_stream) = UnixStream::pair().unwrap();
        // This pair only supplies a write half; nothing is written to it in this test.
        let (write_stream, _drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        // Forward callback invocations into channels so the test can await them
        // instead of sleeping for a fixed duration.
        let (data_tx, mut data_rx) = unbounded_channel::<(u32, bytes::Bytes)>();
        let (close_tx, mut close_rx) = unbounded_channel::<u32>();

        let _writer_tx = spawn_channel_relay(
            42,
            read_half,
            write_half,
            move |ch, data| data_tx.send((ch, data)).is_ok(),
            move |ch| {
                let _ = close_tx.send(ch);
            },
        );

        feed_stream.write_all(b"hello").await.unwrap();

        // A stream socket may deliver the payload across several reads, so reassemble it.
        let mut received = Vec::new();
        while received.len() < b"hello".len() {
            let (ch, chunk) = timeout(STEP_TIMEOUT, data_rx.recv())
                .await
                .expect("timed out waiting for on_data")
                .expect("relay stopped before delivering all data");
            assert_eq!(ch, 42);
            received.extend_from_slice(&chunk);
        }
        assert_eq!(received, b"hello");

        // Closing the peer must produce exactly the EOF -> on_close path.
        drop(feed_stream);
        let closed_ch = timeout(STEP_TIMEOUT, close_rx.recv())
            .await
            .expect("timed out waiting for on_close")
            .expect("relay stopped without calling on_close");
        assert_eq!(closed_ch, 42);
    }

    #[tokio::test]
    async fn spawn_channel_relay_writer_sends_data() {
        use tokio::io::AsyncReadExt;
        use tokio::net::UnixStream;
        use tokio::time::timeout;

        // `_feed_stream` is kept alive so the reader task does not hit EOF mid-test.
        let (read_stream, _feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, mut drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let writer_tx = spawn_channel_relay(7, read_half, write_half, |_, _| true, |_| {});
        writer_tx.send(bytes::Bytes::from_static(b"hello")).unwrap();

        // `read_exact` (rather than a single `read`) tolerates the bytes arriving in pieces.
        let mut buf = [0u8; 5];
        timeout(STEP_TIMEOUT, drain_stream.read_exact(&mut buf))
            .await
            .expect("timed out waiting for relayed bytes")
            .unwrap();
        assert_eq!(&buf, b"hello");
    }
}
142
143 #[tokio::test]
144 async fn spawn_channel_relay_writer_sends_data() {
145 use tokio::io::AsyncReadExt;
146 use tokio::net::UnixStream;
147
148 let (read_stream, _feed_stream) = UnixStream::pair().unwrap();
150 let (write_stream, mut drain_stream) = UnixStream::pair().unwrap();
151 let (read_half, _) = read_stream.into_split();
152 let (_, write_half) = write_stream.into_split();
153
154 let writer_tx = spawn_channel_relay(7, read_half, write_half, |_, _| true, |_| {});
155
156 writer_tx.send(bytes::Bytes::from_static(b"hello")).unwrap();
157
158 let mut buf = vec![0u8; 32];
159 let n =
160 tokio::time::timeout(std::time::Duration::from_secs(2), drain_stream.read(&mut buf))
161 .await
162 .unwrap()
163 .unwrap();
164
165 assert_eq!(&buf[..n], b"hello");
166 }
167}