Skip to main content

gritty/
lib.rs

1pub mod client;
2pub mod connect;
3pub mod daemon;
4pub mod protocol;
5pub mod security;
6pub mod server;
7
/// Collect TERM/LANG/COLORTERM from the environment for forwarding to remote sessions.
///
/// Pairs are returned as `(name, value)` in the fixed order `TERM`, `LANG`, `COLORTERM`.
/// A variable is left out whenever [`std::env::var`] fails for it, i.e. when it is unset or
/// its value is not valid Unicode.
pub fn collect_env_vars() -> Vec<(String, String)> {
    const FORWARDED: [&str; 3] = ["TERM", "LANG", "COLORTERM"];

    let mut pairs = Vec::with_capacity(FORWARDED.len());
    for &name in &FORWARDED {
        if let Ok(value) = std::env::var(name) {
            pairs.push((String::from(name), value));
        }
    }
    pairs
}
15
16/// Spawn bidirectional relay tasks for a Unix stream channel.
17///
18/// Reader task reads from the stream and calls `on_data`/`on_close`.
19/// Writer task drains the returned sender and writes to the stream.
20pub fn spawn_channel_relay<F, G>(
21    channel_id: u32,
22    read_half: tokio::net::unix::OwnedReadHalf,
23    write_half: tokio::net::unix::OwnedWriteHalf,
24    on_data: F,
25    on_close: G,
26) -> tokio::sync::mpsc::UnboundedSender<bytes::Bytes>
27where
28    F: Fn(u32, bytes::Bytes) -> bool + Send + 'static,
29    G: Fn(u32) + Send + 'static,
30{
31    use tokio::io::{AsyncReadExt, AsyncWriteExt};
32
33    let (writer_tx, mut writer_rx) = tokio::sync::mpsc::unbounded_channel::<bytes::Bytes>();
34
35    tokio::spawn(async move {
36        let mut read_half = read_half;
37        let mut buf = vec![0u8; 8192];
38        loop {
39            match read_half.read(&mut buf).await {
40                Ok(0) | Err(_) => {
41                    on_close(channel_id);
42                    break;
43                }
44                Ok(n) => {
45                    if !on_data(channel_id, bytes::Bytes::copy_from_slice(&buf[..n])) {
46                        break;
47                    }
48                }
49            }
50        }
51    });
52
53    tokio::spawn(async move {
54        let mut write_half = write_half;
55        while let Some(data) = writer_rx.recv().await {
56            if write_half.write_all(&data).await.is_err() {
57                break;
58            }
59        }
60    });
61
62    writer_tx
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper bound on how long the async relay tests wait for any single event.
    ///
    /// Tests await real events (channel messages, socket reads) rather than sleeping for a
    /// fixed interval, so this only bounds a hang; it does not affect the passing path.
    const EVENT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);

    #[test]
    fn collect_env_vars_only_known_keys() {
        let vars = collect_env_vars();
        for (k, _) in &vars {
            assert!(["TERM", "LANG", "COLORTERM"].contains(&k.as_str()), "unexpected key: {k}");
        }
    }

    #[test]
    fn collect_env_vars_no_duplicates() {
        let vars = collect_env_vars();
        let keys: Vec<&str> = vars.iter().map(|(k, _)| k.as_str()).collect();
        let mut deduped = keys.clone();
        deduped.sort();
        deduped.dedup();
        assert_eq!(keys.len(), deduped.len());
    }

    #[test]
    fn collect_env_vars_includes_term_if_set() {
        // Check the forwarded value too, not just the presence of the key.
        if let Ok(term) = std::env::var("TERM") {
            let vars = collect_env_vars();
            assert!(vars.iter().any(|(k, v)| k == "TERM" && *v == term));
        }
    }

    #[tokio::test]
    async fn spawn_channel_relay_on_data_and_close() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::UnixStream;
        use tokio::sync::mpsc;
        use tokio::time::timeout;

        const PAYLOAD: &[u8] = b"hello";

        // Two pairs: one for read side, one for write side
        let (read_stream, mut feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, _drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        // Forward each callback invocation over a channel so the test can await it. Asserting
        // inside the callbacks would be useless: a panic in the spawned reader task is caught
        // by Tokio and would not fail this test.
        let (data_tx, mut data_rx) = mpsc::unbounded_channel::<(u32, bytes::Bytes)>();
        let (close_tx, mut close_rx) = mpsc::unbounded_channel::<u32>();

        let _writer_tx = spawn_channel_relay(
            42,
            read_half,
            write_half,
            move |ch, data| data_tx.send((ch, data)).is_ok(),
            move |ch| {
                let _ = close_tx.send(ch);
            },
        );

        // Write data to the relay's read side
        feed_stream.write_all(PAYLOAD).await.unwrap();

        // A stream read may return fewer bytes than were written, so gather chunks until the
        // whole payload has arrived instead of assuming exactly one `on_data` call.
        let mut received = Vec::new();
        while received.len() < PAYLOAD.len() {
            let (ch, chunk) = timeout(EVENT_TIMEOUT, data_rx.recv())
                .await
                .expect("timed out waiting for on_data")
                .expect("relay stopped before delivering data");
            assert_eq!(ch, 42);
            received.extend_from_slice(&chunk);
        }
        assert_eq!(received, PAYLOAD);

        // Close the feed stream -> relay read hits EOF -> triggers on_close
        drop(feed_stream);
        let closed_ch = timeout(EVENT_TIMEOUT, close_rx.recv())
            .await
            .expect("timed out waiting for on_close")
            .expect("relay exited without calling on_close");
        assert_eq!(closed_ch, 42);
    }

    #[tokio::test]
    async fn spawn_channel_relay_stops_without_close_when_on_data_declines() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::UnixStream;
        use tokio::sync::mpsc;
        use tokio::time::timeout;

        let (read_stream, mut feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, _drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let (close_tx, mut close_rx) = mpsc::unbounded_channel::<u32>();

        let _writer_tx = spawn_channel_relay(
            5,
            read_half,
            write_half,
            |_, _| false,
            move |ch| {
                let _ = close_tx.send(ch);
            },
        );

        // `feed_stream` stays alive for the whole test, so the reader cannot see EOF here.
        feed_stream.write_all(b"x").await.unwrap();

        // Both callbacks are moved into the reader task. When that task finishes, `close_tx` is
        // dropped and `recv` yields `None`. `Some(_)` would mean `on_close` was invoked.
        let outcome = timeout(EVENT_TIMEOUT, close_rx.recv())
            .await
            .expect("reader task did not exit after on_data returned false");
        assert_eq!(outcome, None);
    }

    #[tokio::test]
    async fn spawn_channel_relay_writer_sends_data() {
        use tokio::io::AsyncReadExt;
        use tokio::net::UnixStream;
        use tokio::time::timeout;

        // Two pairs: one for read side, one for write side
        let (read_stream, _feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, mut drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let writer_tx = spawn_channel_relay(7, read_half, write_half, |_, _| true, |_| {});

        writer_tx.send(bytes::Bytes::from_static(b"hello")).unwrap();

        // `read_exact` tolerates the payload arriving across more than one read.
        let mut buf = [0u8; 5];
        timeout(EVENT_TIMEOUT, drain_stream.read_exact(&mut buf))
            .await
            .expect("timed out waiting for relayed data")
            .unwrap();

        assert_eq!(&buf, b"hello");
    }
}