1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::sys::termios::{self, FlushArg, LocalFlags, SetArg, SpecialCharacterIndices, Termios};
5use std::collections::HashMap;
6use std::io::{self, Read, Write};
7use std::ops::ControlFlow;
8use std::os::fd::{AsFd, AsRawFd, BorrowedFd};
9use std::path::Path;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU32, Ordering};
12use std::time::Duration;
13use tokio::io::unix::AsyncFd;
14use tokio::net::UnixStream;
15use tokio::signal::unix::{SignalKind, signal};
16use tokio::sync::mpsc;
17use tokio::time::Instant;
18
/// Reason a single `relay` pass over one server connection came to an end.
enum RelayExit {
    /// The connection dropped (or a reconnect was requested); `run` tries to re-attach.
    Disconnected,
    /// The client should stop; carries the status `run` returns.
    Exit(i32),
}
26use tokio_util::codec::Framed;
27use tracing::{debug, info};
28
/// Text printed for the `~?` escape.
///
/// The terminal is in raw mode (`cfmakeraw`) while this is shown, so no `\n` → `\r\n`
/// translation happens and every line carries an explicit `\r\n`.
const ESCAPE_HELP: &[u8] = b"\r\nSupported escape sequences:\r\n\
 ~. - detach from session\r\n\
 ~R - force reconnect\r\n\
 ~^Z - suspend client\r\n\
 ~# - session status and RTT\r\n\
 ~? - this message\r\n\
 ~~ - send the escape character by typing it twice\r\n\
(Note that escapes are only recognized immediately after newline.)\r\n";
39
/// Where the escape recognizer is relative to the most recent line break.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EscapeState {
    /// Mid-line: `~` is ordinary data.
    Normal,
    /// At the start of a line: a `~` here may begin an escape sequence.
    AfterNewline,
    /// A line-initial `~` has been held back and the command byte is awaited.
    AfterTilde,
}

/// One item produced by [`EscapeProcessor::process`].
#[derive(Debug, PartialEq, Eq)]
enum EscapeAction {
    /// Bytes to forward to the remote session unchanged.
    Data(Vec<u8>),
    /// `~.` — detach from the session.
    Detach,
    /// `~R` — drop the connection and reconnect.
    Reconnect,
    /// `~^Z` — suspend the client.
    Suspend,
    /// `~#` — print connection status.
    Status,
    /// `~?` — print the escape help.
    Help,
}

/// ssh-style escape-sequence recognizer for keyboard input.
///
/// An escape is only recognized when `~` is the first byte of a line (after `\r`/`\n`,
/// or at the very start). State persists across calls, so a sequence split over two
/// reads is still recognized.
struct EscapeProcessor {
    state: EscapeState,
}

impl EscapeProcessor {
    /// Creates a processor positioned at the start of a line, so an escape can be typed
    /// immediately.
    fn new() -> Self {
        Self { state: EscapeState::AfterNewline }
    }

    /// Moves any accumulated data bytes into `actions` as a single `Data` entry.
    fn flush_data(actions: &mut Vec<EscapeAction>, data_buf: &mut Vec<u8>) {
        if !data_buf.is_empty() {
            actions.push(EscapeAction::Data(std::mem::take(data_buf)));
        }
    }

    /// Splits `input` into data to forward and escape commands, preserving input order.
    ///
    /// `~~` yields one literal `~`; `~` followed by an unrecognized byte forwards both
    /// bytes. `~.` and `~R` end processing of `input` (later bytes are discarded) and
    /// reset the processor to the start-of-line state, so that a processor reused after
    /// a reconnect does not interpret the next keystroke as an escape command.
    fn process(&mut self, input: &[u8]) -> Vec<EscapeAction> {
        let mut actions = Vec::new();
        let mut data_buf = Vec::new();

        for &b in input {
            match self.state {
                EscapeState::Normal => {
                    if matches!(b, b'\n' | b'\r') {
                        self.state = EscapeState::AfterNewline;
                    }
                    data_buf.push(b);
                }
                EscapeState::AfterNewline => match b {
                    b'~' => {
                        // Hold the tilde back until the next byte decides its meaning,
                        // but forward everything typed before it right away.
                        Self::flush_data(&mut actions, &mut data_buf);
                        self.state = EscapeState::AfterTilde;
                    }
                    b'\n' | b'\r' => data_buf.push(b),
                    _ => {
                        self.state = EscapeState::Normal;
                        data_buf.push(b);
                    }
                },
                EscapeState::AfterTilde => match b {
                    b'.' | b'R' => {
                        Self::flush_data(&mut actions, &mut data_buf);
                        actions.push(if b == b'.' {
                            EscapeAction::Detach
                        } else {
                            EscapeAction::Reconnect
                        });
                        // Without this reset the state would stay `AfterTilde` and the
                        // first byte typed after a reconnect would be taken as a command.
                        self.state = EscapeState::AfterNewline;
                        return actions;
                    }
                    0x1a | b'#' | b'?' => {
                        Self::flush_data(&mut actions, &mut data_buf);
                        actions.push(match b {
                            0x1a => EscapeAction::Suspend,
                            b'#' => EscapeAction::Status,
                            _ => EscapeAction::Help,
                        });
                        self.state = EscapeState::Normal;
                    }
                    b'~' => {
                        // `~~` sends a single literal tilde.
                        data_buf.push(b'~');
                        self.state = EscapeState::Normal;
                    }
                    b'\n' | b'\r' => {
                        data_buf.extend_from_slice(&[b'~', b]);
                        self.state = EscapeState::AfterNewline;
                    }
                    _ => {
                        data_buf.extend_from_slice(&[b'~', b]);
                        self.state = EscapeState::Normal;
                    }
                },
            }
        }

        Self::flush_data(&mut actions, &mut data_buf);
        actions
    }
}
159
160fn suspend(raw_guard: &RawModeGuard, nb_guard: &NonBlockGuard) -> anyhow::Result<()> {
161 termios::tcsetattr(raw_guard.fd, SetArg::TCSAFLUSH, &raw_guard.original)?;
163 let _ = nix::fcntl::fcntl(nb_guard.fd, nix::fcntl::FcntlArg::F_SETFL(nb_guard.original_flags));
164
165 nix::sys::signal::kill(nix::unistd::Pid::from_raw(0), nix::sys::signal::Signal::SIGTSTP)?;
166
167 let _ = nix::fcntl::fcntl(
169 nb_guard.fd,
170 nix::fcntl::FcntlArg::F_SETFL(nb_guard.original_flags | nix::fcntl::OFlag::O_NONBLOCK),
171 );
172 let mut raw = raw_guard.original.clone();
173 termios::cfmakeraw(&mut raw);
174 termios::tcsetattr(raw_guard.fd, SetArg::TCSAFLUSH, &raw)?;
175 Ok(())
176}
177
/// Longest a single frame write to the server may take before the connection is
/// treated as lost.
const SEND_TIMEOUT: Duration = Duration::from_millis(5_000);
179
180struct NonBlockGuard {
181 fd: BorrowedFd<'static>,
182 original_flags: nix::fcntl::OFlag,
183}
184
185impl NonBlockGuard {
186 fn set(fd: BorrowedFd<'static>) -> nix::Result<Self> {
187 let flags = nix::fcntl::fcntl(fd, nix::fcntl::FcntlArg::F_GETFL)?;
188 let original_flags = nix::fcntl::OFlag::from_bits_truncate(flags);
189 nix::fcntl::fcntl(
190 fd,
191 nix::fcntl::FcntlArg::F_SETFL(original_flags | nix::fcntl::OFlag::O_NONBLOCK),
192 )?;
193 Ok(Self { fd, original_flags })
194 }
195}
196
197impl Drop for NonBlockGuard {
198 fn drop(&mut self) {
199 let _ = nix::fcntl::fcntl(self.fd, nix::fcntl::FcntlArg::F_SETFL(self.original_flags));
200 }
201}
202
203struct RawModeGuard {
204 fd: BorrowedFd<'static>,
205 original: Termios,
206}
207
208impl RawModeGuard {
209 fn enter(fd: BorrowedFd<'static>) -> nix::Result<Self> {
210 let original = termios::tcgetattr(fd)?;
211 let mut raw = original.clone();
212 termios::cfmakeraw(&mut raw);
213 termios::tcsetattr(fd, SetArg::TCSAFLUSH, &raw)?;
214 Ok(Self { fd, original })
215 }
216}
217
218impl Drop for RawModeGuard {
219 fn drop(&mut self) {
220 let _ = termios::tcsetattr(self.fd, SetArg::TCSAFLUSH, &self.original);
221 }
222}
223
224struct SuppressInputGuard {
227 fd: BorrowedFd<'static>,
228 original: Termios,
229}
230
231impl SuppressInputGuard {
232 fn enter(fd: BorrowedFd<'static>) -> nix::Result<Self> {
233 let original = termios::tcgetattr(fd)?;
234 let mut modified = original.clone();
235 modified.local_flags.remove(LocalFlags::ECHO | LocalFlags::ICANON);
236 modified.control_chars[SpecialCharacterIndices::VMIN as usize] = 1;
237 modified.control_chars[SpecialCharacterIndices::VTIME as usize] = 0;
238 termios::tcsetattr(fd, SetArg::TCSAFLUSH, &modified)?;
239 Ok(Self { fd, original })
240 }
241}
242
243impl Drop for SuppressInputGuard {
244 fn drop(&mut self) {
245 let _ = termios::tcflush(self.fd, FlushArg::TCIFLUSH);
246 let _ = termios::tcsetattr(self.fd, SetArg::TCSAFLUSH, &self.original);
247 }
248}
249
250async fn write_stdout_async(fd: &AsyncFd<std::os::fd::OwnedFd>, data: &[u8]) -> io::Result<()> {
253 let mut written = 0;
254 while written < data.len() {
255 let mut guard = fd.writable().await?;
256 match guard
257 .try_io(|inner| nix::unistd::write(inner, &data[written..]).map_err(io::Error::from))
258 {
259 Ok(Ok(0)) => {
260 return Err(io::Error::new(io::ErrorKind::WriteZero, "stdout closed"));
261 }
262 Ok(Ok(n)) => written += n,
263 Ok(Err(e)) => return Err(e),
264 Err(_would_block) => continue,
265 }
266 }
267 Ok(())
268}
269
/// Formats a byte count for display using `humansize`'s binary (1024-based) units.
pub fn format_size(bytes: u64) -> String {
    humansize::format_size(bytes, humansize::BINARY)
}
274
/// Renders `[text]` on its own line in the given SGR attribute, then resets attributes.
fn bracketed_line(sgr: &str, text: &str) -> String {
    format!("\r\n\x1b[{sgr}m[{text}]\x1b[0m\r\n")
}

/// Dim yellow informational notice.
fn status_msg(text: &str) -> String {
    bracketed_line("2;33", text)
}

/// Green success notice.
fn success_msg(text: &str) -> String {
    bracketed_line("32", text)
}

/// Red error notice.
fn error_msg(text: &str) -> String {
    bracketed_line("31", text)
}
286
287fn get_terminal_size() -> (u16, u16) {
288 terminal_size::terminal_size().map(|(w, h)| (w.0, h.0)).unwrap_or((80, 24))
289}
290
291async fn timed_send(framed: &mut Framed<UnixStream, FrameCodec>, frame: Frame) -> bool {
293 match tokio::time::timeout(SEND_TIMEOUT, framed.send(frame)).await {
294 Ok(Ok(())) => true,
295 Ok(Err(e)) => {
296 debug!("send error: {e}");
297 false
298 }
299 Err(_) => {
300 debug!("send timed out");
301 false
302 }
303 }
304}
305
/// Default period between client heartbeat pings.
const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// Default silence (no pong) after which the connection is considered dead.
const DEFAULT_HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(15_000);
308
/// Events from local SSH-agent channel relay tasks back to the relay loop.
enum AgentEvent {
    /// Data the channel relay produced for `channel_id` (presumably read from the local
    /// agent socket via `crate::spawn_channel_relay` — TODO confirm).
    Data { channel_id: u32, data: Bytes },
    /// The channel relay for `channel_id` reported closure.
    Closed { channel_id: u32 },
}

/// Events from the OAuth-redirect tunnel accept loop and its per-connection tasks.
enum ClientTunnelEvent {
    /// A local connection was accepted; `writer_tx` feeds data into it.
    Accepted { channel_id: u32, writer_tx: mpsc::Sender<Bytes> },
    /// Bytes read from the accepted local connection.
    Data { channel_id: u32, data: Bytes },
    /// The local connection hit EOF or a read error.
    Closed { channel_id: u32 },
}

/// Events from port-forward accept loops and their channel relays.
enum ClientPortForwardEvent {
    /// A local connection to forward `forward_id` was accepted as `channel_id`.
    Accepted { forward_id: u32, channel_id: u32, writer_tx: mpsc::Sender<Bytes> },
    /// Data produced by the channel relay for `channel_id`.
    Data { channel_id: u32, data: Bytes },
    /// The channel relay for `channel_id` reported closure.
    Closed { channel_id: u32 },
}

/// One active port forward on the client.
struct ClientPortForwardState {
    /// Accept-loop task for the local listening port.
    listener_handle: Option<tokio::task::JoinHandle<()>>,
    /// Port reported to the server in `PortForwardOpen` for each accepted connection.
    target_port: u16,
}
334
335struct ClientAgentState {
337 channels: HashMap<u32, mpsc::Sender<Bytes>>,
338}
339
340impl ClientAgentState {
341 fn new() -> Self {
342 Self { channels: HashMap::new() }
343 }
344
345 fn teardown(&mut self) {
346 self.channels.clear();
347 }
348}
349
350struct ClientTunnelState {
352 listener: Option<tokio::task::JoinHandle<()>>,
353 channels: HashMap<u32, mpsc::Sender<Bytes>>,
354 next_channel_id: Arc<AtomicU32>,
355}
356
357impl ClientTunnelState {
358 fn new() -> Self {
359 Self {
360 listener: None,
361 channels: HashMap::new(),
362 next_channel_id: Arc::new(AtomicU32::new(0)),
363 }
364 }
365
366 fn teardown(&mut self) {
367 self.channels.clear();
368 if let Some(handle) = self.listener.take() {
369 handle.abort();
370 }
371 }
372}
373
374struct ClientPortForwardTable {
376 forwards: HashMap<u32, ClientPortForwardState>,
377 channels: HashMap<u32, (u32, mpsc::Sender<Bytes>)>,
378 next_channel_id: std::sync::Arc<std::sync::atomic::AtomicU32>,
379}
380
381impl ClientPortForwardTable {
382 fn new() -> Self {
383 Self {
384 forwards: HashMap::new(),
385 channels: HashMap::new(),
386 next_channel_id: std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)),
387 }
388 }
389
390 fn teardown(&mut self) {
391 for (_, fwd) in self.forwards.drain() {
392 if let Some(h) = fwd.listener_handle {
393 h.abort();
394 }
395 }
396 self.channels.clear();
397 }
398}
399
400async fn send_init_frames(
403 framed: &mut Framed<UnixStream, FrameCodec>,
404 env_vars: &[(String, String)],
405 forward_agent: bool,
406 agent_socket: Option<&str>,
407 forward_open: bool,
408 redraw: bool,
409) -> bool {
410 if !env_vars.is_empty() && !timed_send(framed, Frame::Env { vars: env_vars.to_vec() }).await {
411 return false;
412 }
413 if forward_agent && agent_socket.is_some() && !timed_send(framed, Frame::AgentForward).await {
414 return false;
415 }
416 if forward_open && !timed_send(framed, Frame::OpenForward).await {
417 return false;
418 }
419 let (cols, rows) = get_terminal_size();
420 if !timed_send(framed, Frame::Resize { cols, rows }).await {
421 return false;
422 }
423 if redraw && !timed_send(framed, Frame::Data(Bytes::from_static(b"\x0c"))).await {
424 return false;
425 }
426 true
427}
428
429struct ClientRelay<'a> {
432 async_stdout: &'a AsyncFd<std::os::fd::OwnedFd>,
433 agent: &'a mut ClientAgentState,
434 agent_event_tx: &'a mpsc::UnboundedSender<AgentEvent>,
435 agent_socket: Option<&'a str>,
436 tunnel: &'a mut ClientTunnelState,
437 tunnel_event_tx: &'a mpsc::UnboundedSender<ClientTunnelEvent>,
438 oauth_redirect: bool,
439 oauth_timeout: u64,
440 pf: &'a mut ClientPortForwardTable,
441 pf_event_tx: &'a mpsc::UnboundedSender<ClientPortForwardEvent>,
442 last_pong: &'a mut Instant,
443 last_ping_sent: &'a mut Instant,
444 last_rtt: &'a mut Option<Duration>,
445 connected_at: Instant,
446 bytes_relayed: &'a mut u64,
447}
448
449impl ClientRelay<'_> {
450 async fn handle_server_frame(
451 &mut self,
452 framed: &mut Framed<UnixStream, FrameCodec>,
453 frame: Option<Result<Frame, io::Error>>,
454 ) -> Result<ControlFlow<RelayExit>, anyhow::Error> {
455 match frame {
456 Some(Ok(Frame::Data(data))) => {
457 debug!(len = data.len(), "socket → stdout");
458 *self.bytes_relayed += data.len() as u64;
459 write_stdout_async(self.async_stdout, &data).await?;
460 }
461 Some(Ok(Frame::Pong)) => {
462 *self.last_rtt = Some(self.last_ping_sent.elapsed());
463 debug!(rtt_ms = self.last_rtt.unwrap().as_secs_f64() * 1000.0, "pong received");
464 *self.last_pong = Instant::now();
465 }
466 Some(Ok(Frame::Exit { code })) => {
467 debug!(code, "server sent exit");
468 return Ok(ControlFlow::Break(RelayExit::Exit(code)));
469 }
470 Some(Ok(Frame::Detached)) => {
471 info!("detached by another client");
472 self.agent.teardown();
473 self.tunnel.teardown();
474 self.pf.teardown();
475 write_stdout_async(self.async_stdout, status_msg("detached").as_bytes()).await?;
476 return Ok(ControlFlow::Break(RelayExit::Exit(0)));
477 }
478 Some(Ok(Frame::AgentOpen { channel_id })) => {
479 if let Some(sock_path) = self.agent_socket {
480 match tokio::net::UnixStream::connect(sock_path).await {
481 Ok(stream) => {
482 let (read_half, write_half) = stream.into_split();
483 let data_tx = self.agent_event_tx.clone();
484 let close_tx = self.agent_event_tx.clone();
485 let writer_tx = crate::spawn_channel_relay(
486 channel_id,
487 read_half,
488 write_half,
489 move |id, data| {
490 data_tx.send(AgentEvent::Data { channel_id: id, data }).is_ok()
491 },
492 move |id| {
493 let _ = close_tx.send(AgentEvent::Closed { channel_id: id });
494 },
495 );
496 self.agent.channels.insert(channel_id, writer_tx);
497 }
498 Err(e) => {
499 debug!("failed to connect to local agent: {e}");
500 let _ = timed_send(framed, Frame::AgentClose { channel_id }).await;
501 }
502 }
503 } else {
504 let _ = timed_send(framed, Frame::AgentClose { channel_id }).await;
505 }
506 }
507 Some(Ok(Frame::AgentData { channel_id, data })) => {
508 if let Some(tx) = self.agent.channels.get(&channel_id) {
509 let _ = tx.send(data).await;
510 }
511 }
512 Some(Ok(Frame::AgentClose { channel_id })) => {
513 self.agent.channels.remove(&channel_id);
514 }
515 Some(Ok(Frame::OpenUrl { url })) => {
516 if url.starts_with("http://") || url.starts_with("https://") {
517 debug!("opening URL locally: {url}");
518 tokio::task::spawn_blocking(move || {
519 let _ = opener::open(&url);
520 });
521 } else {
522 debug!("rejected non-http(s) URL: {url}");
523 }
524 }
525 Some(Ok(Frame::TunnelListen { port })) => {
526 if !self.oauth_redirect {
527 debug!(port, "tunnel: oauth-redirect disabled, declining");
528 let _ = timed_send(framed, Frame::TunnelClose { channel_id: 0 }).await;
529 } else {
530 match std::net::TcpListener::bind(("127.0.0.1", port)) {
532 Ok(std_listener) => {
533 debug!(port, "tunnel: bound local port");
534 std_listener.set_nonblocking(true).ok();
535 let listener = tokio::net::TcpListener::from_std(std_listener).unwrap();
536 let tx = self.tunnel_event_tx.clone();
537 let timeout = self.oauth_timeout;
538 let next_id = Arc::clone(&self.tunnel.next_channel_id);
539 self.tunnel.listener = Some(tokio::spawn(async move {
540 let deadline =
541 tokio::time::Instant::now() + Duration::from_secs(timeout);
542 loop {
543 let accept =
544 tokio::time::timeout_at(deadline, listener.accept()).await;
545 match accept {
546 Ok(Ok((stream, _))) => {
547 let channel_id =
548 next_id.fetch_add(1, Ordering::Relaxed);
549 let (read_half, write_half) = stream.into_split();
550 let (writer_tx, mut writer_rx) =
551 mpsc::channel::<Bytes>(crate::CHANNEL_RELAY_BUFFER);
552
553 tokio::spawn(async move {
555 use tokio::io::AsyncWriteExt;
556 let mut writer = write_half;
557 while let Some(data) = writer_rx.recv().await {
558 if writer.write_all(&data).await.is_err() {
559 break;
560 }
561 }
562 });
563
564 let _ = tx.send(ClientTunnelEvent::Accepted {
565 channel_id,
566 writer_tx,
567 });
568
569 let reader_tx = tx.clone();
572 tokio::spawn(async move {
573 use tokio::io::AsyncReadExt;
574 let mut read_half = read_half;
575 let mut buf = vec![0u8; 4096];
576 loop {
577 match read_half.read(&mut buf).await {
578 Ok(0) | Err(_) => {
579 let _ = reader_tx.send(
580 ClientTunnelEvent::Closed {
581 channel_id,
582 },
583 );
584 break;
585 }
586 Ok(n) => {
587 let data =
588 Bytes::copy_from_slice(&buf[..n]);
589 if reader_tx
590 .send(ClientTunnelEvent::Data {
591 channel_id,
592 data,
593 })
594 .is_err()
595 {
596 break;
597 }
598 }
599 }
600 }
601 });
602 }
603 _ => {
604 debug!(port, "tunnel: accept timed out or failed");
605 break;
606 }
607 }
608 }
609 }));
610 }
611 Err(e) => {
612 debug!(port, "tunnel: bind failed: {e}");
613 let _ = timed_send(framed, Frame::TunnelClose { channel_id: 0 }).await;
614 }
615 }
616 }
617 }
618 Some(Ok(Frame::SendOffer { file_count, total_bytes })) => {
619 let size_str = format_size(total_bytes);
620 let s = if file_count == 1 { "" } else { "s" };
621 write_stdout_async(
622 self.async_stdout,
623 status_msg(&format!("gritty: receiving {file_count} file{s} ({size_str})"))
624 .as_bytes(),
625 )
626 .await?;
627 }
628 Some(Ok(Frame::SendDone)) => {
629 write_stdout_async(
630 self.async_stdout,
631 success_msg("gritty: transfer complete").as_bytes(),
632 )
633 .await?;
634 }
635 Some(Ok(Frame::SendCancel { reason })) => {
636 write_stdout_async(
637 self.async_stdout,
638 error_msg(&format!("gritty: transfer cancelled: {reason}")).as_bytes(),
639 )
640 .await?;
641 }
642 Some(Ok(Frame::TunnelData { channel_id, data })) => {
643 if let Some(tx) = self.tunnel.channels.get(&channel_id) {
644 let _ = tx.send(data).await;
645 }
646 }
647 Some(Ok(Frame::TunnelClose { channel_id })) => {
648 self.tunnel.channels.remove(&channel_id);
649 }
650 Some(Ok(Frame::PortForwardListen { forward_id, listen_port, target_port })) => {
652 match std::net::TcpListener::bind(("127.0.0.1", listen_port)) {
653 Ok(std_listener) => {
654 debug!(forward_id, listen_port, "port forward: bound local port");
655 std_listener.set_nonblocking(true).ok();
656 let listener = tokio::net::TcpListener::from_std(std_listener).unwrap();
657 let tx = self.pf_event_tx.clone();
658 let nid = self.pf.next_channel_id.clone();
659 let handle = tokio::spawn(async move {
660 loop {
661 let (stream, _) = match listener.accept().await {
662 Ok(conn) => conn,
663 Err(_) => break,
664 };
665 let channel_id =
666 nid.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
667 let (read_half, write_half) = stream.into_split();
668 let data_tx = tx.clone();
669 let close_tx = tx.clone();
670 let writer_tx = crate::spawn_channel_relay(
671 channel_id,
672 read_half,
673 write_half,
674 move |id, data| {
675 data_tx
676 .send(ClientPortForwardEvent::Data {
677 channel_id: id,
678 data,
679 })
680 .is_ok()
681 },
682 move |id| {
683 let _ = close_tx.send(ClientPortForwardEvent::Closed {
684 channel_id: id,
685 });
686 },
687 );
688 if tx
689 .send(ClientPortForwardEvent::Accepted {
690 forward_id,
691 channel_id,
692 writer_tx,
693 })
694 .is_err()
695 {
696 break;
697 }
698 }
699 });
700 self.pf.forwards.insert(
701 forward_id,
702 ClientPortForwardState { listener_handle: Some(handle), target_port },
703 );
704 if !timed_send(framed, Frame::PortForwardReady { forward_id }).await {
705 return Ok(ControlFlow::Break(RelayExit::Disconnected));
706 }
707 }
708 Err(e) => {
709 debug!(forward_id, listen_port, "port forward: bind failed: {e}");
710 let _ = timed_send(framed, Frame::PortForwardStop { forward_id }).await;
711 }
712 }
713 }
714 Some(Ok(Frame::PortForwardOpen { forward_id, channel_id, target_port })) => {
716 if self.pf.forwards.contains_key(&forward_id) || forward_id == u32::MAX {
717 match tokio::net::TcpStream::connect(("127.0.0.1", target_port)).await {
719 Ok(stream) => {
720 let (read_half, write_half) = stream.into_split();
721 let data_tx = self.pf_event_tx.clone();
722 let close_tx = self.pf_event_tx.clone();
723 let writer_tx = crate::spawn_channel_relay(
724 channel_id,
725 read_half,
726 write_half,
727 move |id, data| {
728 data_tx
729 .send(ClientPortForwardEvent::Data { channel_id: id, data })
730 .is_ok()
731 },
732 move |id| {
733 let _ = close_tx
734 .send(ClientPortForwardEvent::Closed { channel_id: id });
735 },
736 );
737 self.pf.channels.insert(channel_id, (forward_id, writer_tx));
738 }
739 Err(e) => {
740 debug!(channel_id, target_port, "pf connect failed: {e}");
741 let _ =
742 timed_send(framed, Frame::PortForwardClose { channel_id }).await;
743 }
744 }
745 }
746 }
747 Some(Ok(Frame::PortForwardData { channel_id, data })) => {
749 if let Some((_, tx)) = self.pf.channels.get(&channel_id) {
750 let _ = tx.send(data).await;
751 }
752 }
753 Some(Ok(Frame::PortForwardClose { channel_id })) => {
755 self.pf.channels.remove(&channel_id);
756 }
757 Some(Ok(Frame::PortForwardStop { forward_id })) => {
759 if let Some(fwd) = self.pf.forwards.remove(&forward_id) {
760 if let Some(h) = fwd.listener_handle {
761 h.abort();
762 }
763 }
764 self.pf.channels.retain(|_, (fid, _)| *fid != forward_id);
766 }
767 Some(Ok(_)) => {} Some(Err(e)) => {
769 debug!("server connection error: {e}");
770 return Ok(ControlFlow::Break(RelayExit::Disconnected));
771 }
772 None => {
773 debug!("server disconnected");
774 return Ok(ControlFlow::Break(RelayExit::Disconnected));
775 }
776 }
777 Ok(ControlFlow::Continue(()))
778 }
779
780 async fn handle_agent_event(
781 &mut self,
782 framed: &mut Framed<UnixStream, FrameCodec>,
783 event: Option<AgentEvent>,
784 ) -> bool {
785 match event {
786 Some(AgentEvent::Data { channel_id, data }) => {
787 if self.agent.channels.contains_key(&channel_id)
788 && !timed_send(framed, Frame::AgentData { channel_id, data }).await
789 {
790 return false;
791 }
792 }
793 Some(AgentEvent::Closed { channel_id }) => {
794 if self.agent.channels.remove(&channel_id).is_some()
795 && !timed_send(framed, Frame::AgentClose { channel_id }).await
796 {
797 return false;
798 }
799 }
800 None => {} }
802 true
803 }
804
805 async fn handle_tunnel_event(
806 &mut self,
807 framed: &mut Framed<UnixStream, FrameCodec>,
808 event: Option<ClientTunnelEvent>,
809 ) -> bool {
810 match event {
811 Some(ClientTunnelEvent::Accepted { channel_id, writer_tx }) => {
812 self.tunnel.channels.insert(channel_id, writer_tx);
813 if !timed_send(framed, Frame::TunnelOpen { channel_id }).await {
814 return false;
815 }
816 }
817 Some(ClientTunnelEvent::Data { channel_id, data }) => {
818 if !timed_send(framed, Frame::TunnelData { channel_id, data }).await {
819 return false;
820 }
821 }
822 Some(ClientTunnelEvent::Closed { channel_id }) => {
823 self.tunnel.channels.remove(&channel_id);
824 if !timed_send(framed, Frame::TunnelClose { channel_id }).await {
825 return false;
826 }
827 }
828 None => {}
829 }
830 true
831 }
832
833 async fn handle_pf_event(
834 &mut self,
835 framed: &mut Framed<UnixStream, FrameCodec>,
836 event: Option<ClientPortForwardEvent>,
837 ) -> bool {
838 match event {
839 Some(ClientPortForwardEvent::Accepted { forward_id, channel_id, writer_tx }) => {
840 if let Some(fwd) = self.pf.forwards.get(&forward_id) {
841 let target_port = fwd.target_port;
842 self.pf.channels.insert(channel_id, (forward_id, writer_tx));
843 if !timed_send(
844 framed,
845 Frame::PortForwardOpen { forward_id, channel_id, target_port },
846 )
847 .await
848 {
849 return false;
850 }
851 }
852 }
853 Some(ClientPortForwardEvent::Data { channel_id, data }) => {
854 if self.pf.channels.contains_key(&channel_id)
855 && !timed_send(framed, Frame::PortForwardData { channel_id, data }).await
856 {
857 return false;
858 }
859 }
860 Some(ClientPortForwardEvent::Closed { channel_id }) => {
861 if self.pf.channels.remove(&channel_id).is_some()
862 && !timed_send(framed, Frame::PortForwardClose { channel_id }).await
863 {
864 return false;
865 }
866 }
867 None => {}
868 }
869 true
870 }
871}
872
873#[allow(clippy::too_many_arguments)]
875async fn relay(
876 framed: &mut Framed<UnixStream, FrameCodec>,
877 async_stdin: &AsyncFd<io::Stdin>,
878 async_stdout: &AsyncFd<std::os::fd::OwnedFd>,
879 sigwinch: &mut tokio::signal::unix::Signal,
880 buf: &mut [u8],
881 mut escape: Option<&mut EscapeProcessor>,
882 raw_guard: &RawModeGuard,
883 nb_guard: &NonBlockGuard,
884 agent_socket: Option<&str>,
885 oauth_redirect: bool,
886 oauth_timeout: u64,
887 session: &str,
888 hb_interval: Duration,
889 hb_timeout: Duration,
890) -> anyhow::Result<RelayExit> {
891 let mut sigterm = signal(SignalKind::terminate())?;
892 let mut sighup = signal(SignalKind::hangup())?;
893
894 let mut heartbeat_interval = tokio::time::interval(hb_interval);
895 heartbeat_interval.reset(); let mut last_pong = Instant::now();
897 let mut last_ping_sent = Instant::now();
898 let mut last_rtt: Option<Duration> = None;
899
900 let mut agent = ClientAgentState::new();
902 let (agent_event_tx, mut agent_event_rx) = mpsc::unbounded_channel::<AgentEvent>();
903
904 let mut tunnel = ClientTunnelState::new();
906 let (tunnel_event_tx, mut tunnel_event_rx) = mpsc::unbounded_channel::<ClientTunnelEvent>();
907
908 let (pf_event_tx, mut pf_event_rx) = mpsc::unbounded_channel::<ClientPortForwardEvent>();
910 let mut pf = ClientPortForwardTable::new();
911
912 let mut bytes_relayed = 0u64;
913 let mut relay = ClientRelay {
914 async_stdout,
915 agent: &mut agent,
916 agent_event_tx: &agent_event_tx,
917 agent_socket,
918 tunnel: &mut tunnel,
919 tunnel_event_tx: &tunnel_event_tx,
920 oauth_redirect,
921 oauth_timeout,
922 pf: &mut pf,
923 pf_event_tx: &pf_event_tx,
924 last_pong: &mut last_pong,
925 last_ping_sent: &mut last_ping_sent,
926 last_rtt: &mut last_rtt,
927 connected_at: Instant::now(),
928 bytes_relayed: &mut bytes_relayed,
929 };
930
931 loop {
932 tokio::select! {
933 ready = async_stdin.readable() => {
934 let mut guard = ready?;
935 match guard.try_io(|inner| inner.get_ref().read(buf)) {
936 Ok(Ok(0)) => {
937 debug!("stdin EOF");
938 return Ok(RelayExit::Exit(0));
939 }
940 Ok(Ok(n)) => {
941 debug!(len = n, "stdin → socket");
942 if let Some(ref mut esc) = escape {
943 for action in esc.process(&buf[..n]) {
944 match action {
945 EscapeAction::Data(data) => {
946 if !timed_send(framed, Frame::Data(Bytes::from(data))).await {
947 return Ok(RelayExit::Disconnected);
948 }
949 }
950 EscapeAction::Detach => {
951 write_stdout_async(async_stdout, status_msg("detached").as_bytes()).await?;
952 return Ok(RelayExit::Exit(0));
953 }
954 EscapeAction::Reconnect => {
955 write_stdout_async(async_stdout, status_msg("force reconnect").as_bytes()).await?;
956 return Ok(RelayExit::Disconnected);
957 }
958 EscapeAction::Suspend => {
959 suspend(raw_guard, nb_guard)?;
960 let (cols, rows) = get_terminal_size();
962 if !timed_send(framed, Frame::Resize { cols, rows }).await {
963 return Ok(RelayExit::Disconnected);
964 }
965 }
966 EscapeAction::Status => {
967 let rtt_str = match *relay.last_rtt {
968 Some(d) => format!("{:.1}ms", d.as_secs_f64() * 1000.0),
969 None => "n/a".to_string(),
970 };
971 let uptime = relay.connected_at.elapsed();
972 let uptime_str = if uptime.as_secs() >= 3600 {
973 format!(
974 "{}h {}m {}s",
975 uptime.as_secs() / 3600,
976 (uptime.as_secs() % 3600) / 60,
977 uptime.as_secs() % 60,
978 )
979 } else if uptime.as_secs() >= 60 {
980 format!(
981 "{}m {}s",
982 uptime.as_secs() / 60,
983 uptime.as_secs() % 60,
984 )
985 } else {
986 format!("{}s", uptime.as_secs())
987 };
988 let bytes_str = format_size(*relay.bytes_relayed);
989 let agent_info = if relay.agent_socket.is_some() {
990 format!(
991 "on ({} channels)",
992 relay.agent.channels.len()
993 )
994 } else {
995 "off".to_string()
996 };
997 let open_str = if relay.oauth_redirect { "on" } else { "off" };
998 let mut pf_lines = Vec::new();
999 for (&fwd_id, fwd) in &relay.pf.forwards {
1000 let ch_count = relay.pf.channels.values()
1001 .filter(|(fid, _)| *fid == fwd_id)
1002 .count();
1003 pf_lines.push(format!(
1004 " :{} ({} connections)",
1005 fwd.target_port,
1006 ch_count,
1007 ));
1008 }
1009 let tunnel_str = if !relay.tunnel.channels.is_empty() {
1010 format!("active ({} channels)", relay.tunnel.channels.len())
1011 } else if relay.tunnel.listener.is_some() {
1012 "listening".to_string()
1013 } else {
1014 "idle".to_string()
1015 };
1016 let mut status = format!(
1017 "\r\n\x1b[2;33m[gritty status]\r\n\
1018 \x1b[0m\x1b[2m session: {session}\r\n\
1019 \x1b[0m\x1b[2m rtt: {rtt_str}\r\n\
1020 \x1b[0m\x1b[2m connected: {uptime_str}\r\n\
1021 \x1b[0m\x1b[2m bytes relayed: {bytes_str}\r\n\
1022 \x1b[0m\x1b[2m agent forwarding: {agent_info}\r\n\
1023 \x1b[0m\x1b[2m open forwarding: {open_str}\r\n\
1024 \x1b[0m\x1b[2m oauth tunnel: {tunnel_str}\r\n",
1025 );
1026 for line in &pf_lines {
1027 status.push_str(&format!(
1028 "\x1b[0m\x1b[2m port forward{line}\r\n"
1029 ));
1030 }
1031 status.push_str("\x1b[0m");
1032 write_stdout_async(
1033 async_stdout,
1034 status.as_bytes(),
1035 ).await?;
1036 }
1037 EscapeAction::Help => {
1038 write_stdout_async(async_stdout, ESCAPE_HELP).await?;
1039 }
1040 }
1041 }
1042 } else if !timed_send(framed, Frame::Data(Bytes::copy_from_slice(&buf[..n]))).await {
1043 return Ok(RelayExit::Disconnected);
1044 }
1045 }
1046 Ok(Err(e)) => return Err(e.into()),
1047 Err(_would_block) => continue,
1048 }
1049 }
1050
1051 frame = framed.next() => {
1052 if let ControlFlow::Break(exit) = relay.handle_server_frame(framed, frame).await? {
1053 return Ok(exit);
1054 }
1055 }
1056
1057 event = agent_event_rx.recv() => {
1058 if !relay.handle_agent_event(framed, event).await {
1059 return Ok(RelayExit::Disconnected);
1060 }
1061 }
1062
1063 event = tunnel_event_rx.recv() => {
1064 if !relay.handle_tunnel_event(framed, event).await {
1065 return Ok(RelayExit::Disconnected);
1066 }
1067 }
1068
1069 event = pf_event_rx.recv() => {
1070 if !relay.handle_pf_event(framed, event).await {
1071 return Ok(RelayExit::Disconnected);
1072 }
1073 }
1074
1075 _ = sigwinch.recv() => {
1076 let (cols, rows) = get_terminal_size();
1077 debug!(cols, rows, "SIGWINCH → resize");
1078 if !timed_send(framed, Frame::Resize { cols, rows }).await {
1079 return Ok(RelayExit::Disconnected);
1080 }
1081 }
1082
1083 _ = heartbeat_interval.tick() => {
1084 if relay.last_pong.elapsed() > hb_timeout {
1085 debug!("heartbeat timeout");
1086 return Ok(RelayExit::Disconnected);
1087 }
1088 *relay.last_ping_sent = Instant::now();
1089 if !timed_send(framed, Frame::Ping).await {
1090 return Ok(RelayExit::Disconnected);
1091 }
1092 }
1093
1094 _ = sigterm.recv() => {
1095 debug!("SIGTERM received, exiting");
1096 return Ok(RelayExit::Exit(1));
1097 }
1098
1099 _ = sighup.recv() => {
1100 debug!("SIGHUP received, exiting");
1101 return Ok(RelayExit::Exit(1));
1102 }
1103 }
1104 }
1105}
1106
1107#[allow(clippy::too_many_arguments)]
1108pub async fn run(
1109 session: &str,
1110 mut framed: Framed<UnixStream, FrameCodec>,
1111 redraw: bool,
1112 ctl_path: &Path,
1113 env_vars: Vec<(String, String)>,
1114 no_escape: bool,
1115 forward_agent: bool,
1116 forward_open: bool,
1117 oauth_redirect: bool,
1118 oauth_timeout: u64,
1119 heartbeat_interval: u64,
1120 heartbeat_timeout: u64,
1121) -> anyhow::Result<i32> {
1122 let stdin = io::stdin();
1123 let stdin_fd = stdin.as_fd();
1124 let stdin_borrowed: BorrowedFd<'static> =
1126 unsafe { BorrowedFd::borrow_raw(stdin_fd.as_raw_fd()) };
1127 let raw_guard = RawModeGuard::enter(stdin_borrowed)?;
1128
1129 let nb_guard = NonBlockGuard::set(stdin_borrowed)?;
1132 let async_stdin = AsyncFd::new(io::stdin())?;
1133 let stdout_fd = crate::security::checked_dup(io::stdout().as_raw_fd())?;
1135 let async_stdout = AsyncFd::new(stdout_fd)?;
1136 let mut sigwinch = signal(SignalKind::window_change())?;
1137 let mut buf = vec![0u8; 4096];
1138 let mut current_redraw = redraw;
1139 let mut current_env = env_vars;
1140 let mut escape = if no_escape { None } else { Some(EscapeProcessor::new()) };
1141 let agent_socket = if forward_agent { std::env::var("SSH_AUTH_SOCK").ok() } else { None };
1142
1143 loop {
1144 let result = if send_init_frames(
1145 &mut framed,
1146 ¤t_env,
1147 forward_agent,
1148 agent_socket.as_deref(),
1149 forward_open,
1150 current_redraw,
1151 )
1152 .await
1153 {
1154 relay(
1155 &mut framed,
1156 &async_stdin,
1157 &async_stdout,
1158 &mut sigwinch,
1159 &mut buf,
1160 escape.as_mut(),
1161 &raw_guard,
1162 &nb_guard,
1163 agent_socket.as_deref(),
1164 oauth_redirect,
1165 oauth_timeout,
1166 session,
1167 Duration::from_secs(heartbeat_interval),
1168 Duration::from_secs(heartbeat_timeout),
1169 )
1170 .await?
1171 } else {
1172 RelayExit::Disconnected
1173 };
1174 match result {
1175 RelayExit::Exit(code) => return Ok(code),
1176 RelayExit::Disconnected => {
1177 current_env.clear();
1179 let reconnect_started = Instant::now();
1181 write_stdout_async(&async_stdout, status_msg("reconnecting...").as_bytes()).await?;
1182
1183 loop {
1184 tokio::select! {
1186 _ = tokio::time::sleep(Duration::from_secs(1)) => {}
1187 _ = async_stdin.readable() => {
1188 let mut peek = [0u8; 1];
1189 match async_stdin.get_ref().read(&mut peek) {
1190 Ok(1) if peek[0] == 0x03 => {
1191 write_stdout_async(&async_stdout, b"\r\n").await?;
1192 return Ok(1);
1193 }
1194 _ => {}
1195 }
1196 continue;
1197 }
1198 }
1199
1200 let elapsed = reconnect_started.elapsed().as_secs();
1202 write_stdout_async(
1203 &async_stdout,
1204 format!("\r{}", status_msg(&format!("reconnecting... {elapsed}s")))
1205 .as_bytes(),
1206 )
1207 .await?;
1208
1209 let stream = match UnixStream::connect(ctl_path).await {
1210 Ok(s) => s,
1211 Err(_) => continue,
1212 };
1213
1214 let mut new_framed = Framed::new(stream, FrameCodec);
1215 if crate::handshake(&mut new_framed).await.is_err() {
1216 continue;
1217 }
1218 if new_framed
1219 .send(Frame::Attach { session: session.to_string() })
1220 .await
1221 .is_err()
1222 {
1223 continue;
1224 }
1225
1226 match new_framed.next().await {
1227 Some(Ok(Frame::Ok)) => {
1228 write_stdout_async(
1229 &async_stdout,
1230 success_msg("reconnected").as_bytes(),
1231 )
1232 .await?;
1233 framed = new_framed;
1234 current_redraw = true;
1235 break;
1236 }
1237 Some(Ok(Frame::Error { message })) => {
1238 write_stdout_async(
1239 &async_stdout,
1240 error_msg(&format!("session gone: {message}")).as_bytes(),
1241 )
1242 .await?;
1243 return Ok(1);
1244 }
1245 _ => continue,
1246 }
1247 }
1248 }
1249 }
1250 }
1251}
1252
1253pub async fn tail(
1257 session: &str,
1258 mut framed: Framed<UnixStream, FrameCodec>,
1259 ctl_path: &Path,
1260) -> anyhow::Result<i32> {
1261 let stdin_fd = unsafe { BorrowedFd::borrow_raw(libc::STDIN_FILENO) };
1263 let _input_guard = SuppressInputGuard::enter(stdin_fd).ok();
1264
1265 tokio::task::spawn_blocking(|| {
1267 let mut buf = [0u8; 64];
1268 let mut belled = false;
1269 loop {
1270 match io::stdin().read(&mut buf) {
1271 Ok(0) | Err(_) => break,
1272 Ok(_) if !belled => {
1273 let _ = io::stderr().write_all(b"\x07");
1274 let _ = io::stderr().flush();
1275 belled = true;
1276 }
1277 _ => {}
1278 }
1279 }
1280 });
1281
1282 let mut heartbeat_interval = tokio::time::interval(DEFAULT_HEARTBEAT_INTERVAL);
1283 heartbeat_interval.reset();
1284 let mut last_pong = Instant::now();
1285 let mut sigint = signal(SignalKind::interrupt())?;
1286 let mut sigterm = signal(SignalKind::terminate())?;
1287 let mut sighup = signal(SignalKind::hangup())?;
1288 let mut stdout = tokio::io::stdout();
1289
1290 let code = 'outer: loop {
1291 let result = 'relay: loop {
1292 tokio::select! {
1293 frame = framed.next() => {
1294 match frame {
1295 Some(Ok(Frame::Data(data))) => {
1296 use tokio::io::AsyncWriteExt;
1297 stdout.write_all(&data).await?;
1298 }
1299 Some(Ok(Frame::Pong)) => {
1300 last_pong = Instant::now();
1301 }
1302 Some(Ok(Frame::Exit { code })) => {
1303 break 'relay Some(code);
1304 }
1305 Some(Ok(_)) => {}
1306 Some(Err(e)) => {
1307 debug!("tail connection error: {e}");
1308 break 'relay None;
1309 }
1310 None => {
1311 debug!("tail server disconnected");
1312 break 'relay None;
1313 }
1314 }
1315 }
1316 _ = heartbeat_interval.tick() => {
1317 if last_pong.elapsed() > DEFAULT_HEARTBEAT_TIMEOUT {
1318 debug!("tail heartbeat timeout");
1319 break 'relay None;
1320 }
1321 if framed.send(Frame::Ping).await.is_err() {
1322 break 'relay None;
1323 }
1324 }
1325 _ = sigint.recv() => {
1326 break 'outer 0;
1327 }
1328 _ = sigterm.recv() => {
1329 break 'outer 1;
1330 }
1331 _ = sighup.recv() => {
1332 break 'outer 1;
1333 }
1334 }
1335 };
1336
1337 match result {
1338 Some(code) => break code,
1339 None => {
1340 let reconnect_started = Instant::now();
1341 eprintln!("\x1b[2;33m[reconnecting...]\x1b[0m");
1342 loop {
1343 tokio::time::sleep(Duration::from_secs(1)).await;
1344 let elapsed = reconnect_started.elapsed().as_secs();
1345 eprint!("\r\x1b[2;33m[reconnecting... {elapsed}s]\x1b[0m");
1346
1347 let stream = match UnixStream::connect(ctl_path).await {
1348 Ok(s) => s,
1349 Err(_) => continue,
1350 };
1351
1352 let mut new_framed = Framed::new(stream, FrameCodec);
1353 if crate::handshake(&mut new_framed).await.is_err() {
1354 continue;
1355 }
1356 if new_framed.send(Frame::Tail { session: session.to_string() }).await.is_err()
1357 {
1358 continue;
1359 }
1360
1361 match new_framed.next().await {
1362 Some(Ok(Frame::Ok)) => {
1363 eprintln!("\x1b[32m[reconnected]\x1b[0m");
1364 framed = new_framed;
1365 heartbeat_interval.reset();
1366 last_pong = Instant::now();
1367 break;
1368 }
1369 Some(Ok(Frame::Error { message })) => {
1370 eprintln!("\x1b[31m[session gone: {message}]\x1b[0m");
1371 break 'outer 1;
1372 }
1373 _ => continue,
1374 }
1375 }
1376 }
1377 }
1378 };
1379
1380 {
1383 use tokio::io::AsyncWriteExt;
1384 let _ = stdout.write_all(b"\x1b[0m\x1b[?25h").await;
1385 }
1386 Ok(code)
1387}
1388
#[cfg(test)]
mod tests {
    use super::*;

    /// A processor positioned mid-line, so an escape needs a preceding newline to be recognised.
    fn mid_line() -> EscapeProcessor {
        EscapeProcessor { state: EscapeState::Normal }
    }

    /// Shorthand for a pass-through data action.
    fn data(bytes: &[u8]) -> EscapeAction {
        EscapeAction::Data(bytes.to_vec())
    }

    #[test]
    fn normal_passthrough() {
        let mut ep = EscapeProcessor::new();
        let actions = ep.process(b"hello");
        assert_eq!(actions, vec![data(b"hello")]);
    }

    #[test]
    fn tilde_after_newline_detach() {
        let mut ep = mid_line();
        let actions = ep.process(b"\n~.");
        assert_eq!(actions, vec![data(b"\n"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_after_cr_detach() {
        let mut ep = mid_line();
        let actions = ep.process(b"\r~.");
        assert_eq!(actions, vec![data(b"\r"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_not_after_newline() {
        let mut ep = mid_line();
        let actions = ep.process(b"a~.");
        assert_eq!(actions, vec![data(b"a~.")]);
    }

    #[test]
    fn initial_state_detach() {
        let mut ep = EscapeProcessor::new();
        let actions = ep.process(b"~.");
        assert_eq!(actions, vec![EscapeAction::Detach]);
    }

    #[test]
    fn tilde_suspend() {
        let mut ep = mid_line();
        let actions = ep.process(b"\n~\x1a");
        assert_eq!(actions, vec![data(b"\n"), EscapeAction::Suspend]);
    }

    #[test]
    fn tilde_status() {
        let mut ep = mid_line();
        let actions = ep.process(b"\n~#");
        assert_eq!(actions, vec![data(b"\n"), EscapeAction::Status]);
    }

    #[test]
    fn tilde_reconnect() {
        let mut ep = mid_line();
        let actions = ep.process(b"\n~R");
        assert_eq!(actions, vec![data(b"\n"), EscapeAction::Reconnect]);
    }

    #[test]
    fn tilde_reconnect_stops_processing() {
        let mut ep = mid_line();
        let actions = ep.process(b"\n~Rremaining");
        assert_eq!(actions, vec![data(b"\n"), EscapeAction::Reconnect]);
    }

    #[test]
    fn tilde_help() {
        let mut ep = mid_line();
        let actions = ep.process(b"\n~?");
        assert_eq!(actions, vec![data(b"\n"), EscapeAction::Help]);
    }

    #[test]
    fn double_tilde() {
        let mut ep = mid_line();
        let actions = ep.process(b"\n~~");
        assert_eq!(actions, vec![data(b"\n"), data(b"~")]);
        assert_eq!(ep.state, EscapeState::Normal);
    }

    #[test]
    fn tilde_unknown_char() {
        let mut ep = mid_line();
        let actions = ep.process(b"\n~x");
        assert_eq!(actions, vec![data(b"\n"), data(b"~x")]);
    }

    #[test]
    fn split_across_reads() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n"), vec![data(b"\n")]);
        assert!(ep.process(b"~").is_empty());
        assert_eq!(ep.process(b"."), vec![EscapeAction::Detach]);
    }

    #[test]
    fn split_tilde_then_normal() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n"), vec![data(b"\n")]);
        assert!(ep.process(b"~").is_empty());
        assert_eq!(ep.process(b"a"), vec![data(b"~a")]);
    }

    #[test]
    fn multiple_escapes_one_buffer() {
        let mut ep = mid_line();
        let actions = ep.process(b"\n~?\n~.");
        assert_eq!(
            actions,
            vec![data(b"\n"), EscapeAction::Help, data(b"\n"), EscapeAction::Detach]
        );
    }

    #[test]
    fn consecutive_newlines() {
        let mut ep = mid_line();
        let actions = ep.process(b"\n\n\n~.");
        assert_eq!(actions, vec![data(b"\n\n\n"), EscapeAction::Detach]);
    }

    #[test]
    fn detach_stops_processing() {
        let mut ep = mid_line();
        let actions = ep.process(b"\n~.remaining");
        assert_eq!(actions, vec![data(b"\n"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_then_newline() {
        let mut ep = mid_line();
        let actions = ep.process(b"\n~\n");
        assert_eq!(actions, vec![data(b"\n"), data(b"~\n")]);
        assert_eq!(ep.state, EscapeState::AfterNewline);
    }

    #[test]
    fn empty_input() {
        let mut ep = EscapeProcessor::new();
        assert!(ep.process(b"").is_empty());
    }

    #[test]
    fn only_tilde_buffered() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~"), vec![data(b"\n")]);
        assert_eq!(ep.state, EscapeState::AfterTilde);
        assert_eq!(ep.process(b"."), vec![EscapeAction::Detach]);
    }
}