Skip to main content

gritty/
lib.rs

1pub mod client;
2pub mod config;
3pub mod connect;
4pub mod daemon;
5pub mod protocol;
6pub mod security;
7pub mod server;
8pub mod table;
9
10/// Perform a protocol version handshake with the server.
11///
12/// Sends Hello with our PROTOCOL_VERSION, expects HelloAck with the
13/// negotiated version (min of client and server).
14pub async fn handshake(
15    framed: &mut tokio_util::codec::Framed<tokio::net::UnixStream, protocol::FrameCodec>,
16) -> anyhow::Result<u16> {
17    use futures_util::{SinkExt, StreamExt};
18    framed.send(protocol::Frame::Hello { version: protocol::PROTOCOL_VERSION }).await?;
19    match protocol::Frame::expect_from(framed.next().await)? {
20        protocol::Frame::HelloAck { version } => Ok(version),
21        protocol::Frame::Error { message } => anyhow::bail!("handshake rejected: {message}"),
22        other => anyhow::bail!("expected HelloAck, got {other:?}"),
23    }
24}
25
/// Gather the terminal-related environment variables forwarded to remote sessions.
///
/// Checks `TERM`, `LANG` and `COLORTERM`, in that order, and returns a
/// `(name, value)` pair for each one `std::env::var` can read. Variables that
/// are unset or not valid Unicode are skipped.
pub fn collect_env_vars() -> Vec<(String, String)> {
    const FORWARDED: [&str; 3] = ["TERM", "LANG", "COLORTERM"];

    let mut vars = Vec::with_capacity(FORWARDED.len());
    for name in FORWARDED {
        if let Ok(value) = std::env::var(name) {
            vars.push((name.to_owned(), value));
        }
    }
    vars
}
33
34/// Spawn bidirectional relay tasks for a stream channel.
35///
36/// Reader task reads from the stream and calls `on_data`/`on_close`.
37/// Writer task drains the returned sender and writes to the stream.
38/// Channel buffer size for `spawn_channel_relay` writer channels.
39/// At 8KB per read, 256 entries ≈ 2MB per channel.
40const CHANNEL_RELAY_BUFFER: usize = 256;
41
42pub fn spawn_channel_relay<R, W, F, G>(
43    channel_id: u32,
44    read_half: R,
45    write_half: W,
46    on_data: F,
47    on_close: G,
48) -> tokio::sync::mpsc::Sender<bytes::Bytes>
49where
50    R: tokio::io::AsyncRead + Unpin + Send + 'static,
51    W: tokio::io::AsyncWrite + Unpin + Send + 'static,
52    F: Fn(u32, bytes::Bytes) -> bool + Send + 'static,
53    G: Fn(u32) + Send + 'static,
54{
55    use tokio::io::{AsyncReadExt, AsyncWriteExt};
56
57    let (writer_tx, mut writer_rx) =
58        tokio::sync::mpsc::channel::<bytes::Bytes>(CHANNEL_RELAY_BUFFER);
59
60    tokio::spawn(async move {
61        let mut read_half = read_half;
62        let mut buf = vec![0u8; 8192];
63        loop {
64            match read_half.read(&mut buf).await {
65                Ok(0) | Err(_) => {
66                    on_close(channel_id);
67                    break;
68                }
69                Ok(n) => {
70                    if !on_data(channel_id, bytes::Bytes::copy_from_slice(&buf[..n])) {
71                        break;
72                    }
73                }
74            }
75        }
76    });
77
78    tokio::spawn(async move {
79        let mut write_half = write_half;
80        while let Some(data) = writer_rx.recv().await {
81            if write_half.write_all(&data).await.is_err() {
82                break;
83            }
84        }
85        // Graceful half-close: send FIN instead of RST
86        let _ = write_half.shutdown().await;
87    });
88
89    writer_tx
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a relay task to make progress.
    const TEST_TIMEOUT: Duration = Duration::from_secs(2);

    #[test]
    fn collect_env_vars_only_known_keys() {
        let vars = collect_env_vars();
        for (k, _) in &vars {
            assert!(["TERM", "LANG", "COLORTERM"].contains(&k.as_str()), "unexpected key: {k}");
        }
    }

    #[test]
    fn collect_env_vars_no_duplicates() {
        let vars = collect_env_vars();
        let keys: Vec<&str> = vars.iter().map(|(k, _)| k.as_str()).collect();
        let mut deduped = keys.clone();
        deduped.sort();
        deduped.dedup();
        assert_eq!(keys.len(), deduped.len());
    }

    #[test]
    fn collect_env_vars_includes_term_if_set() {
        if std::env::var("TERM").is_ok() {
            let vars = collect_env_vars();
            assert!(vars.iter().any(|(k, _)| k == "TERM"));
        }
    }

    #[tokio::test]
    async fn spawn_channel_relay_on_data_and_close() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::UnixStream;
        use tokio::sync::mpsc;

        // Two pairs: one for read side, one for write side
        let (read_stream, mut feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, _drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        // Forward every callback invocation to the test through channels. The
        // test can then await them (bounded by TEST_TIMEOUT) instead of
        // sleeping for a fixed interval. Assertions also run on the test task,
        // where a panic actually fails the test; a panic inside the relay's
        // spawned task would not.
        let (data_tx, mut data_rx) = mpsc::unbounded_channel::<(u32, bytes::Bytes)>();
        let (close_tx, mut close_rx) = mpsc::unbounded_channel::<u32>();

        let _writer_tx = spawn_channel_relay(
            42,
            read_half,
            write_half,
            move |ch, data| data_tx.send((ch, data)).is_ok(),
            move |ch| {
                let _ = close_tx.send(ch);
            },
        );

        // Write data to the relay's read side
        feed_stream.write_all(b"hello").await.unwrap();
        let (ch, data) = tokio::time::timeout(TEST_TIMEOUT, data_rx.recv())
            .await
            .expect("timed out waiting for on_data")
            .expect("relay dropped on_data before delivering data");
        assert_eq!(ch, 42);
        assert_eq!(&data[..], b"hello");

        // Close the feed stream -> triggers on_close
        drop(feed_stream);
        let closed_ch = tokio::time::timeout(TEST_TIMEOUT, close_rx.recv())
            .await
            .expect("timed out waiting for on_close")
            .expect("relay dropped on_close without calling it");
        assert_eq!(closed_ch, 42);

        // Only the single "hello" chunk should ever have been delivered.
        assert!(data_rx.try_recv().is_err(), "unexpected extra on_data call");
    }

    #[tokio::test]
    async fn spawn_channel_relay_writer_sends_data() {
        use tokio::io::AsyncReadExt;
        use tokio::net::UnixStream;

        // Two pairs: one for read side, one for write side
        let (read_stream, _feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, mut drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let writer_tx = spawn_channel_relay(7, read_half, write_half, |_, _| true, |_| {});

        writer_tx.try_send(bytes::Bytes::from_static(b"hello")).unwrap();

        let mut buf = vec![0u8; 32];
        let n = tokio::time::timeout(TEST_TIMEOUT, drain_stream.read(&mut buf))
            .await
            .unwrap()
            .unwrap();

        assert_eq!(&buf[..n], b"hello");
    }

    #[tokio::test]
    async fn spawn_channel_relay_writer_half_close() {
        use tokio::io::AsyncReadExt;
        use tokio::net::UnixStream;

        let (read_stream, _feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, mut drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let writer_tx = spawn_channel_relay(7, read_half, write_half, |_, _| true, |_| {});

        // Send data then drop the sender (triggers half-close)
        writer_tx.try_send(bytes::Bytes::from_static(b"request")).unwrap();
        drop(writer_tx);

        // Read the data
        let mut buf = vec![0u8; 32];
        let n = tokio::time::timeout(TEST_TIMEOUT, drain_stream.read(&mut buf))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(&buf[..n], b"request");

        // Read again -- should get EOF (graceful shutdown), not error
        let n = tokio::time::timeout(TEST_TIMEOUT, drain_stream.read(&mut buf))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(n, 0, "expected EOF from graceful half-close");
    }
}