Skip to main content

gritty/
lib.rs

1pub mod client;
2pub mod config;
3pub mod connect;
4pub mod daemon;
5pub mod protocol;
6pub mod security;
7pub mod server;
8
9/// Perform a protocol version handshake with the server.
10///
11/// Sends Hello with our PROTOCOL_VERSION, expects HelloAck with the
12/// negotiated version (min of client and server).
13pub async fn handshake(
14    framed: &mut tokio_util::codec::Framed<tokio::net::UnixStream, protocol::FrameCodec>,
15) -> anyhow::Result<u16> {
16    use futures_util::{SinkExt, StreamExt};
17    framed.send(protocol::Frame::Hello { version: protocol::PROTOCOL_VERSION }).await?;
18    match protocol::Frame::expect_from(framed.next().await)? {
19        protocol::Frame::HelloAck { version } => Ok(version),
20        protocol::Frame::Error { message } => anyhow::bail!("handshake rejected: {message}"),
21        other => anyhow::bail!("expected HelloAck, got {other:?}"),
22    }
23}
24
/// Gather the terminal-related environment variables to forward to a remote
/// session: `TERM`, `LANG` and `COLORTERM`, in that order.
///
/// A variable is omitted when `std::env::var` fails for it, i.e. when it is
/// unset or its value is not valid Unicode. Each key appears at most once.
pub fn collect_env_vars() -> Vec<(String, String)> {
    const FORWARDED_KEYS: [&str; 3] = ["TERM", "LANG", "COLORTERM"];

    let mut pairs = Vec::with_capacity(FORWARDED_KEYS.len());
    for key in FORWARDED_KEYS {
        if let Ok(value) = std::env::var(key) {
            pairs.push((key.to_owned(), value));
        }
    }
    pairs
}
32
33/// Spawn bidirectional relay tasks for a Unix stream channel.
34///
35/// Reader task reads from the stream and calls `on_data`/`on_close`.
36/// Writer task drains the returned sender and writes to the stream.
37pub fn spawn_channel_relay<F, G>(
38    channel_id: u32,
39    read_half: tokio::net::unix::OwnedReadHalf,
40    write_half: tokio::net::unix::OwnedWriteHalf,
41    on_data: F,
42    on_close: G,
43) -> tokio::sync::mpsc::UnboundedSender<bytes::Bytes>
44where
45    F: Fn(u32, bytes::Bytes) -> bool + Send + 'static,
46    G: Fn(u32) + Send + 'static,
47{
48    use tokio::io::{AsyncReadExt, AsyncWriteExt};
49
50    let (writer_tx, mut writer_rx) = tokio::sync::mpsc::unbounded_channel::<bytes::Bytes>();
51
52    tokio::spawn(async move {
53        let mut read_half = read_half;
54        let mut buf = vec![0u8; 8192];
55        loop {
56            match read_half.read(&mut buf).await {
57                Ok(0) | Err(_) => {
58                    on_close(channel_id);
59                    break;
60                }
61                Ok(n) => {
62                    if !on_data(channel_id, bytes::Bytes::copy_from_slice(&buf[..n])) {
63                        break;
64                    }
65                }
66            }
67        }
68    });
69
70    tokio::spawn(async move {
71        let mut write_half = write_half;
72        while let Some(data) = writer_rx.recv().await {
73            if write_half.write_all(&data).await.is_err() {
74                break;
75            }
76        }
77    });
78
79    writer_tx
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// The keys `collect_env_vars` is expected to forward.
    const FORWARDED: [&str; 3] = ["TERM", "LANG", "COLORTERM"];

    /// Upper bound on how long a test waits for a relay task to make progress.
    const WAIT: std::time::Duration = std::time::Duration::from_secs(2);

    #[test]
    fn collect_env_vars_only_known_keys() {
        for (k, _) in collect_env_vars() {
            assert!(FORWARDED.contains(&k.as_str()), "unexpected key: {k}");
        }
    }

    #[test]
    fn collect_env_vars_no_duplicates() {
        let vars = collect_env_vars();
        let mut keys: Vec<&str> = vars.iter().map(|(k, _)| k.as_str()).collect();
        let total = keys.len();
        keys.sort_unstable();
        keys.dedup();
        assert_eq!(keys.len(), total);
    }

    #[test]
    fn collect_env_vars_includes_term_if_set() {
        if let Ok(term) = std::env::var("TERM") {
            let vars = collect_env_vars();
            assert!(vars.iter().any(|(k, v)| k == "TERM" && *v == term));
        }
    }

    #[tokio::test]
    async fn spawn_channel_relay_on_data_and_close() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::UnixStream;
        use tokio::sync::mpsc::unbounded_channel;

        // Two pairs: one for the relay's read side, one for its write side.
        let (read_stream, mut feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, _drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        // Callbacks report through channels, so the test awaits events instead
        // of sleeping, and assertions run on the test task (a panic inside a
        // spawned relay task would not fail the test directly).
        let (data_tx, mut data_rx) = unbounded_channel::<(u32, bytes::Bytes)>();
        let (close_tx, mut close_rx) = unbounded_channel::<u32>();

        let _writer_tx = spawn_channel_relay(
            42,
            read_half,
            write_half,
            move |ch, data| data_tx.send((ch, data)).is_ok(),
            move |ch| {
                let _ = close_tx.send(ch);
            },
        );

        // Data written to the feed side is delivered through on_data.
        feed_stream.write_all(b"hello").await.unwrap();
        let (ch, data) = tokio::time::timeout(WAIT, data_rx.recv())
            .await
            .expect("timed out waiting for on_data")
            .expect("on_data closure dropped before delivering data");
        assert_eq!(ch, 42);
        assert_eq!(&data[..], b"hello");

        // Closing the feed side produces EOF on the relay's read half -> on_close.
        drop(feed_stream);
        let closed = tokio::time::timeout(WAIT, close_rx.recv())
            .await
            .expect("timed out waiting for on_close")
            .expect("reader finished without calling on_close");
        assert_eq!(closed, 42);

        // on_close is sent after every on_data call, so nothing else can arrive.
        assert!(data_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn spawn_channel_relay_stops_without_close_when_on_data_declines() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::UnixStream;
        use tokio::sync::mpsc::unbounded_channel;

        let (read_stream, mut feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, _drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let (data_tx, mut data_rx) = unbounded_channel::<u32>();
        let (close_tx, mut close_rx) = unbounded_channel::<u32>();

        let _writer_tx = spawn_channel_relay(
            9,
            read_half,
            write_half,
            move |ch, _data| {
                let _ = data_tx.send(ch);
                false
            },
            move |ch| {
                let _ = close_tx.send(ch);
            },
        );

        feed_stream.write_all(b"x").await.unwrap();
        let ch = tokio::time::timeout(WAIT, data_rx.recv())
            .await
            .expect("timed out waiting for on_data")
            .expect("on_data closure dropped before delivering data");
        assert_eq!(ch, 9);

        // When the reader task ends, it drops on_close (and with it close_tx)
        // without calling it, so recv() yields None rather than a channel id.
        let closed = tokio::time::timeout(WAIT, close_rx.recv())
            .await
            .expect("timed out waiting for the reader task to finish");
        assert_eq!(closed, None);
    }

    #[tokio::test]
    async fn spawn_channel_relay_writer_sends_data() {
        use tokio::io::AsyncReadExt;
        use tokio::net::UnixStream;

        // Two pairs: one for the relay's read side, one for its write side.
        let (read_stream, _feed_stream) = UnixStream::pair().unwrap();
        let (write_stream, mut drain_stream) = UnixStream::pair().unwrap();
        let (read_half, _) = read_stream.into_split();
        let (_, write_half) = write_stream.into_split();

        let writer_tx = spawn_channel_relay(7, read_half, write_half, |_, _| true, |_| {});
        writer_tx.send(bytes::Bytes::from_static(b"hello")).unwrap();

        let mut buf = [0u8; 32];
        let n = tokio::time::timeout(WAIT, drain_stream.read(&mut buf))
            .await
            .expect("timed out waiting for relayed data")
            .unwrap();
        assert_eq!(&buf[..n], b"hello");
    }
}