Skip to main content

gritty/
lib.rs

1pub mod client;
2pub mod config;
3pub mod connect;
4pub mod daemon;
5pub mod protocol;
6pub mod security;
7pub mod server;
8
/// Gather the terminal-related environment variables to forward to a remote session.
///
/// Looks up `TERM`, `LANG` and `COLORTERM`, in that order. It returns a `(name, value)`
/// pair for each one that [`std::env::var`] can read. Variables that are unset or whose
/// value is not valid Unicode are silently skipped.
pub fn collect_env_vars() -> Vec<(String, String)> {
    const FORWARDED: [&str; 3] = ["TERM", "LANG", "COLORTERM"];

    let mut forwarded = Vec::new();
    for name in FORWARDED {
        if let Ok(value) = std::env::var(name) {
            forwarded.push((name.to_owned(), value));
        }
    }
    forwarded
}
16
17/// Spawn bidirectional relay tasks for a Unix stream channel.
18///
19/// Reader task reads from the stream and calls `on_data`/`on_close`.
20/// Writer task drains the returned sender and writes to the stream.
21pub fn spawn_channel_relay<F, G>(
22    channel_id: u32,
23    read_half: tokio::net::unix::OwnedReadHalf,
24    write_half: tokio::net::unix::OwnedWriteHalf,
25    on_data: F,
26    on_close: G,
27) -> tokio::sync::mpsc::UnboundedSender<bytes::Bytes>
28where
29    F: Fn(u32, bytes::Bytes) -> bool + Send + 'static,
30    G: Fn(u32) + Send + 'static,
31{
32    use tokio::io::{AsyncReadExt, AsyncWriteExt};
33
34    let (writer_tx, mut writer_rx) = tokio::sync::mpsc::unbounded_channel::<bytes::Bytes>();
35
36    tokio::spawn(async move {
37        let mut read_half = read_half;
38        let mut buf = vec![0u8; 8192];
39        loop {
40            match read_half.read(&mut buf).await {
41                Ok(0) | Err(_) => {
42                    on_close(channel_id);
43                    break;
44                }
45                Ok(n) => {
46                    if !on_data(channel_id, bytes::Bytes::copy_from_slice(&buf[..n])) {
47                        break;
48                    }
49                }
50            }
51        }
52    });
53
54    tokio::spawn(async move {
55        let mut write_half = write_half;
56        while let Some(data) = writer_rx.recv().await {
57            if write_half.write_all(&data).await.is_err() {
58                break;
59            }
60        }
61    });
62
63    writer_tx
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper bound on how long a relay test waits for an expected event before failing.
    const EVENT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);

    #[test]
    fn collect_env_vars_only_known_keys() {
        let vars = collect_env_vars();
        for (k, _) in &vars {
            assert!(["TERM", "LANG", "COLORTERM"].contains(&k.as_str()), "unexpected key: {k}");
        }
    }

    #[test]
    fn collect_env_vars_no_duplicates() {
        let vars = collect_env_vars();
        let keys: Vec<&str> = vars.iter().map(|(k, _)| k.as_str()).collect();
        let mut deduped = keys.clone();
        deduped.sort();
        deduped.dedup();
        assert_eq!(keys.len(), deduped.len());
    }

    #[test]
    fn collect_env_vars_includes_term_if_set() {
        // Check the forwarded value too, not just the presence of the key.
        if let Ok(term) = std::env::var("TERM") {
            let vars = collect_env_vars();
            assert!(vars.contains(&("TERM".to_string(), term)));
        }
    }

    #[tokio::test]
    async fn spawn_channel_relay_on_data_and_close() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::UnixStream;
        use tokio::sync::mpsc::unbounded_channel;
        use tokio::time::timeout;

        // Two pairs: one for the relay's read side, one for its write side.
        let (read_stream, mut feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, _drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        // The callbacks run on the relay's reader task. Forwarding what they observe to the
        // test body means:
        // - assertions run here (a panic inside a spawned task would not fail the test);
        // - the test waits on the actual events instead of on fixed sleeps.
        let (data_tx, mut data_rx) = unbounded_channel::<(u32, bytes::Bytes)>();
        let (close_tx, mut close_rx) = unbounded_channel::<u32>();

        let _writer_tx = spawn_channel_relay(
            42,
            read_half,
            write_half,
            move |ch, data| {
                let _ = data_tx.send((ch, data));
                true
            },
            move |ch| {
                let _ = close_tx.send(ch);
            },
        );

        // Write data to the relay's read side.
        feed_stream.write_all(b"hello").await.unwrap();

        // A stream socket may deliver the payload across several reads; accumulate until complete.
        let mut received: Vec<u8> = Vec::new();
        while received.len() < b"hello".len() {
            let (ch, chunk) = timeout(EVENT_TIMEOUT, data_rx.recv())
                .await
                .expect("timed out waiting for on_data")
                .expect("relay dropped on_data before delivering the payload");
            assert_eq!(ch, 42);
            received.extend_from_slice(&chunk);
        }
        assert_eq!(&received[..], b"hello");

        // Closing the feed stream gives the relay's read half EOF, which triggers on_close.
        drop(feed_stream);
        let closed_ch = timeout(EVENT_TIMEOUT, close_rx.recv())
            .await
            .expect("timed out waiting for on_close")
            .expect("relay dropped on_close without calling it");
        assert_eq!(closed_ch, 42);
    }

    #[tokio::test]
    async fn spawn_channel_relay_writer_sends_data() {
        use tokio::io::AsyncReadExt;
        use tokio::net::UnixStream;
        use tokio::time::timeout;

        // Two pairs: one for the relay's read side, one for its write side.
        let (read_stream, _feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, mut drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let writer_tx = spawn_channel_relay(7, read_half, write_half, |_, _| true, |_| {});

        writer_tx.send(bytes::Bytes::from_static(b"hello")).unwrap();

        // read_exact tolerates the payload arriving across several reads.
        let mut buf = [0u8; 5];
        timeout(EVENT_TIMEOUT, drain_stream.read_exact(&mut buf))
            .await
            .expect("timed out waiting for relayed data")
            .unwrap();

        assert_eq!(&buf, b"hello");
    }
}