1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::sys::termios::{self, FlushArg, LocalFlags, SetArg, SpecialCharacterIndices, Termios};
5use std::collections::HashMap;
6use std::io::{self, Read, Write};
7use std::ops::ControlFlow;
8use std::os::fd::{AsFd, AsRawFd, BorrowedFd};
9use std::path::Path;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU32, Ordering};
12use std::time::Duration;
13use tokio::io::unix::AsyncFd;
14use tokio::net::UnixStream;
15use tokio::signal::unix::{SignalKind, signal};
16use tokio::sync::mpsc;
17use tokio::time::Instant;
18
/// Outcome of one relay session over a single server connection.
enum RelayExit {
    /// Stop the client and exit with this status code.
    Exit(i32),
    /// The server connection was lost or a reconnect was requested; the caller
    /// should try to reattach.
    Disconnected,
}
26use tokio_util::codec::Framed;
27use tracing::{debug, info};
28
/// Help text printed for the `~?` escape sequence.
///
/// It is written while the terminal is in raw mode, so every line ends in an
/// explicit `\r\n`. Each trailing `\` continuation drops the following line's
/// leading whitespace from the literal.
const ESCAPE_HELP: &[u8] = b"\r\nSupported escape sequences:\r\n\
 ~. - detach from session\r\n\
 ~R - force reconnect\r\n\
 ~^Z - suspend client\r\n\
 ~# - session status and RTT\r\n\
 ~? - this message\r\n\
 ~~ - send the escape character by typing it twice\r\n\
(Note that escapes are only recognized immediately after newline.)\r\n";
39
/// Where the escape recognizer currently is within a line of keyboard input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EscapeState {
    /// Mid-line: `~` is ordinary data.
    Normal,
    /// At the start of a line (or of input): a `~` here may start an escape.
    AfterNewline,
    /// A line-initial `~` has been seen and is held back until the next byte.
    AfterTilde,
}

/// One item produced by [`EscapeProcessor::process`], in input order.
#[derive(Debug, PartialEq, Eq)]
enum EscapeAction {
    /// Bytes to forward to the session unchanged.
    Data(Vec<u8>),
    /// `~.` — detach from the session.
    Detach,
    /// `~R` — drop the connection and reconnect.
    Reconnect,
    /// `~^Z` — suspend the client.
    Suspend,
    /// `~#` — show session status.
    Status,
    /// `~?` — show the escape help.
    Help,
}

/// Recognizes ssh-style `~` escape sequences in keyboard input.
///
/// State is kept between calls, so a sequence split across two reads
/// (e.g. `~` in one chunk and `.` in the next) is still recognized.
struct EscapeProcessor {
    state: EscapeState,
}

impl EscapeProcessor {
    /// Start as if at the beginning of a line, so an escape can be typed right away.
    fn new() -> Self {
        Self { state: EscapeState::AfterNewline }
    }

    /// Split `input` into pass-through data and escape commands.
    ///
    /// `~` is special only right after `\r`/`\n` (or at the very start). `~~`
    /// sends one literal `~`. `~` followed by any unrecognized byte sends both
    /// bytes unchanged. After `Detach` or `Reconnect` the rest of `input` is
    /// discarded, because the caller stops relaying at that point.
    fn process(&mut self, input: &[u8]) -> Vec<EscapeAction> {
        let mut actions = Vec::new();
        let mut pending = Vec::new();

        for &byte in input {
            match self.state {
                EscapeState::Normal => {
                    if matches!(byte, b'\n' | b'\r') {
                        self.state = EscapeState::AfterNewline;
                    }
                    pending.push(byte);
                }
                EscapeState::AfterNewline => match byte {
                    b'~' => {
                        // Hold the tilde back until we know whether it starts an escape.
                        Self::flush(&mut pending, &mut actions);
                        self.state = EscapeState::AfterTilde;
                    }
                    b'\n' | b'\r' => pending.push(byte),
                    _ => {
                        self.state = EscapeState::Normal;
                        pending.push(byte);
                    }
                },
                EscapeState::AfterTilde => {
                    let command = match byte {
                        b'.' => Some(EscapeAction::Detach),
                        b'R' => Some(EscapeAction::Reconnect),
                        0x1a => Some(EscapeAction::Suspend),
                        b'#' => Some(EscapeAction::Status),
                        b'?' => Some(EscapeAction::Help),
                        _ => None,
                    };

                    if let Some(action) = command {
                        Self::flush(&mut pending, &mut actions);
                        let stops_relay =
                            matches!(action, EscapeAction::Detach | EscapeAction::Reconnect);
                        actions.push(action);
                        if stops_relay {
                            // Reset to the initial state. The processor is reused after a
                            // reconnect; staying in `AfterTilde` would turn the next
                            // keystroke into an escape command and leak a stray `~`.
                            self.state = EscapeState::AfterNewline;
                            return actions;
                        }
                        self.state = EscapeState::Normal;
                        continue;
                    }

                    // Not a command: the held-back `~` becomes ordinary data.
                    pending.push(b'~');
                    if byte != b'~' {
                        pending.push(byte);
                    }
                    self.state = if matches!(byte, b'\n' | b'\r') {
                        EscapeState::AfterNewline
                    } else {
                        EscapeState::Normal
                    };
                }
            }
        }

        Self::flush(&mut pending, &mut actions);
        actions
    }

    /// Move any buffered data bytes into `actions` as a single `Data` item.
    fn flush(pending: &mut Vec<u8>, actions: &mut Vec<EscapeAction>) {
        if !pending.is_empty() {
            actions.push(EscapeAction::Data(std::mem::take(pending)));
        }
    }
}
159
160fn suspend(raw_guard: &RawModeGuard, nb_guard: &NonBlockGuard) -> anyhow::Result<()> {
161 termios::tcsetattr(raw_guard.fd, SetArg::TCSAFLUSH, &raw_guard.original)?;
163 let _ = nix::fcntl::fcntl(nb_guard.fd, nix::fcntl::FcntlArg::F_SETFL(nb_guard.original_flags));
164
165 nix::sys::signal::kill(nix::unistd::Pid::from_raw(0), nix::sys::signal::Signal::SIGTSTP)?;
166
167 let _ = nix::fcntl::fcntl(
169 nb_guard.fd,
170 nix::fcntl::FcntlArg::F_SETFL(nb_guard.original_flags | nix::fcntl::OFlag::O_NONBLOCK),
171 );
172 let mut raw = raw_guard.original.clone();
173 termios::cfmakeraw(&mut raw);
174 termios::tcsetattr(raw_guard.fd, SetArg::TCSAFLUSH, &raw)?;
175 Ok(())
176}
177
/// Longest a single frame send may take before `timed_send` gives up and
/// reports failure.
const SEND_TIMEOUT: Duration = Duration::from_millis(5_000);
179
180struct NonBlockGuard {
181 fd: BorrowedFd<'static>,
182 original_flags: nix::fcntl::OFlag,
183}
184
185impl NonBlockGuard {
186 fn set(fd: BorrowedFd<'static>) -> nix::Result<Self> {
187 let flags = nix::fcntl::fcntl(fd, nix::fcntl::FcntlArg::F_GETFL)?;
188 let original_flags = nix::fcntl::OFlag::from_bits_truncate(flags);
189 nix::fcntl::fcntl(
190 fd,
191 nix::fcntl::FcntlArg::F_SETFL(original_flags | nix::fcntl::OFlag::O_NONBLOCK),
192 )?;
193 Ok(Self { fd, original_flags })
194 }
195}
196
197impl Drop for NonBlockGuard {
198 fn drop(&mut self) {
199 let _ = nix::fcntl::fcntl(self.fd, nix::fcntl::FcntlArg::F_SETFL(self.original_flags));
200 }
201}
202
203struct RawModeGuard {
204 fd: BorrowedFd<'static>,
205 original: Termios,
206}
207
208impl RawModeGuard {
209 fn enter(fd: BorrowedFd<'static>) -> nix::Result<Self> {
210 let original = termios::tcgetattr(fd)?;
211 let mut raw = original.clone();
212 termios::cfmakeraw(&mut raw);
213 termios::tcsetattr(fd, SetArg::TCSAFLUSH, &raw)?;
214 Ok(Self { fd, original })
215 }
216}
217
218impl Drop for RawModeGuard {
219 fn drop(&mut self) {
220 let _ = termios::tcsetattr(self.fd, SetArg::TCSAFLUSH, &self.original);
221 }
222}
223
224struct SuppressInputGuard {
227 fd: BorrowedFd<'static>,
228 original: Termios,
229}
230
231impl SuppressInputGuard {
232 fn enter(fd: BorrowedFd<'static>) -> nix::Result<Self> {
233 let original = termios::tcgetattr(fd)?;
234 let mut modified = original.clone();
235 modified.local_flags.remove(LocalFlags::ECHO | LocalFlags::ICANON);
236 modified.control_chars[SpecialCharacterIndices::VMIN as usize] = 1;
237 modified.control_chars[SpecialCharacterIndices::VTIME as usize] = 0;
238 termios::tcsetattr(fd, SetArg::TCSAFLUSH, &modified)?;
239 Ok(Self { fd, original })
240 }
241}
242
243impl Drop for SuppressInputGuard {
244 fn drop(&mut self) {
245 let _ = termios::tcflush(self.fd, FlushArg::TCIFLUSH);
246 let _ = termios::tcsetattr(self.fd, SetArg::TCSAFLUSH, &self.original);
247 }
248}
249
250async fn write_stdout_async(fd: &AsyncFd<std::os::fd::OwnedFd>, data: &[u8]) -> io::Result<()> {
253 let mut written = 0;
254 while written < data.len() {
255 let mut guard = fd.writable().await?;
256 match guard
257 .try_io(|inner| nix::unistd::write(inner, &data[written..]).map_err(io::Error::from))
258 {
259 Ok(Ok(n)) => written += n,
260 Ok(Err(e)) => return Err(e),
261 Err(_would_block) => continue,
262 }
263 }
264 Ok(())
265}
266
267pub fn format_size(bytes: u64) -> String {
269 humansize::format_size(bytes, humansize::BINARY)
270}
271
/// Put `[text]` on its own line, styled with the SGR parameters `sgr` and
/// followed by an attribute reset.
fn bracketed_line(sgr: &str, text: &str) -> String {
    format!("\r\n\x1b[{sgr}m[{text}]\x1b[0m\r\n")
}

/// Dim yellow informational notice.
fn status_msg(text: &str) -> String {
    bracketed_line("2;33", text)
}

/// Green success notice.
fn success_msg(text: &str) -> String {
    bracketed_line("32", text)
}

/// Red error notice.
fn error_msg(text: &str) -> String {
    bracketed_line("31", text)
}
283
284fn get_terminal_size() -> (u16, u16) {
285 terminal_size::terminal_size().map(|(w, h)| (w.0, h.0)).unwrap_or((80, 24))
286}
287
288async fn timed_send(framed: &mut Framed<UnixStream, FrameCodec>, frame: Frame) -> bool {
290 match tokio::time::timeout(SEND_TIMEOUT, framed.send(frame)).await {
291 Ok(Ok(())) => true,
292 Ok(Err(e)) => {
293 debug!("send error: {e}");
294 false
295 }
296 Err(_) => {
297 debug!("send timed out");
298 false
299 }
300 }
301}
302
/// Default gap between heartbeat pings.
const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// Default silence (no pong) after which the connection is considered dead.
const DEFAULT_HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(15_000);
305
306enum AgentEvent {
308 Data { channel_id: u32, data: Bytes },
309 Closed { channel_id: u32 },
310}
311
312enum ClientTunnelEvent {
314 Accepted { channel_id: u32, writer_tx: mpsc::Sender<Bytes> },
315 Data { channel_id: u32, data: Bytes },
316 Closed { channel_id: u32 },
317}
318
319enum ClientPortForwardEvent {
321 Accepted { forward_id: u32, channel_id: u32, writer_tx: mpsc::Sender<Bytes> },
322 Data { channel_id: u32, data: Bytes },
323 Closed { channel_id: u32 },
324}
325
326struct ClientPortForwardState {
328 listener_handle: Option<tokio::task::JoinHandle<()>>,
329 target_port: u16,
330}
331
332struct ClientAgentState {
334 channels: HashMap<u32, mpsc::Sender<Bytes>>,
335}
336
337impl ClientAgentState {
338 fn new() -> Self {
339 Self { channels: HashMap::new() }
340 }
341
342 fn teardown(&mut self) {
343 self.channels.clear();
344 }
345}
346
347struct ClientTunnelState {
349 listener: Option<tokio::task::JoinHandle<()>>,
350 channels: HashMap<u32, mpsc::Sender<Bytes>>,
351 next_channel_id: Arc<AtomicU32>,
352}
353
354impl ClientTunnelState {
355 fn new() -> Self {
356 Self {
357 listener: None,
358 channels: HashMap::new(),
359 next_channel_id: Arc::new(AtomicU32::new(0)),
360 }
361 }
362
363 fn teardown(&mut self) {
364 self.channels.clear();
365 if let Some(handle) = self.listener.take() {
366 handle.abort();
367 }
368 }
369}
370
371struct ClientPortForwardTable {
373 forwards: HashMap<u32, ClientPortForwardState>,
374 channels: HashMap<u32, (u32, mpsc::Sender<Bytes>)>,
375 next_channel_id: std::sync::Arc<std::sync::atomic::AtomicU32>,
376}
377
378impl ClientPortForwardTable {
379 fn new() -> Self {
380 Self {
381 forwards: HashMap::new(),
382 channels: HashMap::new(),
383 next_channel_id: std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)),
384 }
385 }
386
387 fn teardown(&mut self) {
388 for (_, fwd) in self.forwards.drain() {
389 if let Some(h) = fwd.listener_handle {
390 h.abort();
391 }
392 }
393 self.channels.clear();
394 }
395}
396
397async fn send_init_frames(
400 framed: &mut Framed<UnixStream, FrameCodec>,
401 env_vars: &[(String, String)],
402 forward_agent: bool,
403 agent_socket: Option<&str>,
404 forward_open: bool,
405 redraw: bool,
406) -> bool {
407 if !env_vars.is_empty() && !timed_send(framed, Frame::Env { vars: env_vars.to_vec() }).await {
408 return false;
409 }
410 if forward_agent && agent_socket.is_some() && !timed_send(framed, Frame::AgentForward).await {
411 return false;
412 }
413 if forward_open && !timed_send(framed, Frame::OpenForward).await {
414 return false;
415 }
416 let (cols, rows) = get_terminal_size();
417 if !timed_send(framed, Frame::Resize { cols, rows }).await {
418 return false;
419 }
420 if redraw && !timed_send(framed, Frame::Data(Bytes::from_static(b"\x0c"))).await {
421 return false;
422 }
423 true
424}
425
426struct ClientRelay<'a> {
429 async_stdout: &'a AsyncFd<std::os::fd::OwnedFd>,
430 agent: &'a mut ClientAgentState,
431 agent_event_tx: &'a mpsc::UnboundedSender<AgentEvent>,
432 agent_socket: Option<&'a str>,
433 tunnel: &'a mut ClientTunnelState,
434 tunnel_event_tx: &'a mpsc::UnboundedSender<ClientTunnelEvent>,
435 oauth_redirect: bool,
436 oauth_timeout: u64,
437 pf: &'a mut ClientPortForwardTable,
438 pf_event_tx: &'a mpsc::UnboundedSender<ClientPortForwardEvent>,
439 last_pong: &'a mut Instant,
440 last_ping_sent: &'a mut Instant,
441 last_rtt: &'a mut Option<Duration>,
442 connected_at: Instant,
443 bytes_relayed: &'a mut u64,
444}
445
446impl ClientRelay<'_> {
447 async fn handle_server_frame(
448 &mut self,
449 framed: &mut Framed<UnixStream, FrameCodec>,
450 frame: Option<Result<Frame, io::Error>>,
451 ) -> Result<ControlFlow<RelayExit>, anyhow::Error> {
452 match frame {
453 Some(Ok(Frame::Data(data))) => {
454 debug!(len = data.len(), "socket → stdout");
455 *self.bytes_relayed += data.len() as u64;
456 write_stdout_async(self.async_stdout, &data).await?;
457 }
458 Some(Ok(Frame::Pong)) => {
459 *self.last_rtt = Some(self.last_ping_sent.elapsed());
460 debug!(rtt_ms = self.last_rtt.unwrap().as_secs_f64() * 1000.0, "pong received");
461 *self.last_pong = Instant::now();
462 }
463 Some(Ok(Frame::Exit { code })) => {
464 debug!(code, "server sent exit");
465 return Ok(ControlFlow::Break(RelayExit::Exit(code)));
466 }
467 Some(Ok(Frame::Detached)) => {
468 info!("detached by another client");
469 self.agent.teardown();
470 self.tunnel.teardown();
471 self.pf.teardown();
472 write_stdout_async(self.async_stdout, status_msg("detached").as_bytes()).await?;
473 return Ok(ControlFlow::Break(RelayExit::Exit(0)));
474 }
475 Some(Ok(Frame::AgentOpen { channel_id })) => {
476 if let Some(sock_path) = self.agent_socket {
477 match tokio::net::UnixStream::connect(sock_path).await {
478 Ok(stream) => {
479 let (read_half, write_half) = stream.into_split();
480 let data_tx = self.agent_event_tx.clone();
481 let close_tx = self.agent_event_tx.clone();
482 let writer_tx = crate::spawn_channel_relay(
483 channel_id,
484 read_half,
485 write_half,
486 move |id, data| {
487 data_tx.send(AgentEvent::Data { channel_id: id, data }).is_ok()
488 },
489 move |id| {
490 let _ = close_tx.send(AgentEvent::Closed { channel_id: id });
491 },
492 );
493 self.agent.channels.insert(channel_id, writer_tx);
494 }
495 Err(e) => {
496 debug!("failed to connect to local agent: {e}");
497 let _ = timed_send(framed, Frame::AgentClose { channel_id }).await;
498 }
499 }
500 } else {
501 let _ = timed_send(framed, Frame::AgentClose { channel_id }).await;
502 }
503 }
504 Some(Ok(Frame::AgentData { channel_id, data })) => {
505 if let Some(tx) = self.agent.channels.get(&channel_id) {
506 let _ = tx.send(data).await;
507 }
508 }
509 Some(Ok(Frame::AgentClose { channel_id })) => {
510 self.agent.channels.remove(&channel_id);
511 }
512 Some(Ok(Frame::OpenUrl { url })) => {
513 if url.starts_with("http://") || url.starts_with("https://") {
514 debug!("opening URL locally: {url}");
515 tokio::task::spawn_blocking(move || {
516 let _ = opener::open(&url);
517 });
518 } else {
519 debug!("rejected non-http(s) URL: {url}");
520 }
521 }
522 Some(Ok(Frame::TunnelListen { port })) => {
523 if !self.oauth_redirect {
524 debug!(port, "tunnel: oauth-redirect disabled, declining");
525 let _ = timed_send(framed, Frame::TunnelClose { channel_id: 0 }).await;
526 } else {
527 match std::net::TcpListener::bind(("127.0.0.1", port)) {
529 Ok(std_listener) => {
530 debug!(port, "tunnel: bound local port");
531 std_listener.set_nonblocking(true).ok();
532 let listener = tokio::net::TcpListener::from_std(std_listener).unwrap();
533 let tx = self.tunnel_event_tx.clone();
534 let timeout = self.oauth_timeout;
535 let next_id = Arc::clone(&self.tunnel.next_channel_id);
536 self.tunnel.listener = Some(tokio::spawn(async move {
537 let deadline =
538 tokio::time::Instant::now() + Duration::from_secs(timeout);
539 loop {
540 let accept =
541 tokio::time::timeout_at(deadline, listener.accept()).await;
542 match accept {
543 Ok(Ok((stream, _))) => {
544 let channel_id =
545 next_id.fetch_add(1, Ordering::Relaxed);
546 let (read_half, write_half) = stream.into_split();
547 let (writer_tx, mut writer_rx) =
548 mpsc::channel::<Bytes>(crate::CHANNEL_RELAY_BUFFER);
549
550 tokio::spawn(async move {
552 use tokio::io::AsyncWriteExt;
553 let mut writer = write_half;
554 while let Some(data) = writer_rx.recv().await {
555 if writer.write_all(&data).await.is_err() {
556 break;
557 }
558 }
559 });
560
561 let _ = tx.send(ClientTunnelEvent::Accepted {
562 channel_id,
563 writer_tx,
564 });
565
566 let reader_tx = tx.clone();
569 tokio::spawn(async move {
570 use tokio::io::AsyncReadExt;
571 let mut read_half = read_half;
572 let mut buf = vec![0u8; 4096];
573 loop {
574 match read_half.read(&mut buf).await {
575 Ok(0) | Err(_) => {
576 let _ = reader_tx.send(
577 ClientTunnelEvent::Closed {
578 channel_id,
579 },
580 );
581 break;
582 }
583 Ok(n) => {
584 let data =
585 Bytes::copy_from_slice(&buf[..n]);
586 if reader_tx
587 .send(ClientTunnelEvent::Data {
588 channel_id,
589 data,
590 })
591 .is_err()
592 {
593 break;
594 }
595 }
596 }
597 }
598 });
599 }
600 _ => {
601 debug!(port, "tunnel: accept timed out or failed");
602 break;
603 }
604 }
605 }
606 }));
607 }
608 Err(e) => {
609 debug!(port, "tunnel: bind failed: {e}");
610 let _ = timed_send(framed, Frame::TunnelClose { channel_id: 0 }).await;
611 }
612 }
613 }
614 }
615 Some(Ok(Frame::SendOffer { file_count, total_bytes })) => {
616 let size_str = format_size(total_bytes);
617 let s = if file_count == 1 { "" } else { "s" };
618 write_stdout_async(
619 self.async_stdout,
620 status_msg(&format!("gritty: receiving {file_count} file{s} ({size_str})"))
621 .as_bytes(),
622 )
623 .await?;
624 }
625 Some(Ok(Frame::SendDone)) => {
626 write_stdout_async(
627 self.async_stdout,
628 success_msg("gritty: transfer complete").as_bytes(),
629 )
630 .await?;
631 }
632 Some(Ok(Frame::SendCancel { reason })) => {
633 write_stdout_async(
634 self.async_stdout,
635 error_msg(&format!("gritty: transfer cancelled: {reason}")).as_bytes(),
636 )
637 .await?;
638 }
639 Some(Ok(Frame::TunnelData { channel_id, data })) => {
640 if let Some(tx) = self.tunnel.channels.get(&channel_id) {
641 let _ = tx.send(data).await;
642 }
643 }
644 Some(Ok(Frame::TunnelClose { channel_id })) => {
645 self.tunnel.channels.remove(&channel_id);
646 }
647 Some(Ok(Frame::PortForwardListen { forward_id, listen_port, target_port })) => {
649 match std::net::TcpListener::bind(("127.0.0.1", listen_port)) {
650 Ok(std_listener) => {
651 debug!(forward_id, listen_port, "port forward: bound local port");
652 std_listener.set_nonblocking(true).ok();
653 let listener = tokio::net::TcpListener::from_std(std_listener).unwrap();
654 let tx = self.pf_event_tx.clone();
655 let nid = self.pf.next_channel_id.clone();
656 let handle = tokio::spawn(async move {
657 loop {
658 let (stream, _) = match listener.accept().await {
659 Ok(conn) => conn,
660 Err(_) => break,
661 };
662 let channel_id =
663 nid.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
664 let (read_half, write_half) = stream.into_split();
665 let data_tx = tx.clone();
666 let close_tx = tx.clone();
667 let writer_tx = crate::spawn_channel_relay(
668 channel_id,
669 read_half,
670 write_half,
671 move |id, data| {
672 data_tx
673 .send(ClientPortForwardEvent::Data {
674 channel_id: id,
675 data,
676 })
677 .is_ok()
678 },
679 move |id| {
680 let _ = close_tx.send(ClientPortForwardEvent::Closed {
681 channel_id: id,
682 });
683 },
684 );
685 if tx
686 .send(ClientPortForwardEvent::Accepted {
687 forward_id,
688 channel_id,
689 writer_tx,
690 })
691 .is_err()
692 {
693 break;
694 }
695 }
696 });
697 self.pf.forwards.insert(
698 forward_id,
699 ClientPortForwardState { listener_handle: Some(handle), target_port },
700 );
701 if !timed_send(framed, Frame::PortForwardReady { forward_id }).await {
702 return Ok(ControlFlow::Break(RelayExit::Disconnected));
703 }
704 }
705 Err(e) => {
706 debug!(forward_id, listen_port, "port forward: bind failed: {e}");
707 let _ = timed_send(framed, Frame::PortForwardStop { forward_id }).await;
708 }
709 }
710 }
711 Some(Ok(Frame::PortForwardOpen { forward_id, channel_id, target_port })) => {
713 if self.pf.forwards.contains_key(&forward_id) || forward_id == u32::MAX {
714 match tokio::net::TcpStream::connect(("127.0.0.1", target_port)).await {
716 Ok(stream) => {
717 let (read_half, write_half) = stream.into_split();
718 let data_tx = self.pf_event_tx.clone();
719 let close_tx = self.pf_event_tx.clone();
720 let writer_tx = crate::spawn_channel_relay(
721 channel_id,
722 read_half,
723 write_half,
724 move |id, data| {
725 data_tx
726 .send(ClientPortForwardEvent::Data { channel_id: id, data })
727 .is_ok()
728 },
729 move |id| {
730 let _ = close_tx
731 .send(ClientPortForwardEvent::Closed { channel_id: id });
732 },
733 );
734 self.pf.channels.insert(channel_id, (forward_id, writer_tx));
735 }
736 Err(e) => {
737 debug!(channel_id, target_port, "pf connect failed: {e}");
738 let _ =
739 timed_send(framed, Frame::PortForwardClose { channel_id }).await;
740 }
741 }
742 }
743 }
744 Some(Ok(Frame::PortForwardData { channel_id, data })) => {
746 if let Some((_, tx)) = self.pf.channels.get(&channel_id) {
747 let _ = tx.send(data).await;
748 }
749 }
750 Some(Ok(Frame::PortForwardClose { channel_id })) => {
752 self.pf.channels.remove(&channel_id);
753 }
754 Some(Ok(Frame::PortForwardStop { forward_id })) => {
756 if let Some(fwd) = self.pf.forwards.remove(&forward_id) {
757 if let Some(h) = fwd.listener_handle {
758 h.abort();
759 }
760 }
761 self.pf.channels.retain(|_, (fid, _)| *fid != forward_id);
763 }
764 Some(Ok(_)) => {} Some(Err(e)) => {
766 debug!("server connection error: {e}");
767 return Ok(ControlFlow::Break(RelayExit::Disconnected));
768 }
769 None => {
770 debug!("server disconnected");
771 return Ok(ControlFlow::Break(RelayExit::Disconnected));
772 }
773 }
774 Ok(ControlFlow::Continue(()))
775 }
776
777 async fn handle_agent_event(
778 &mut self,
779 framed: &mut Framed<UnixStream, FrameCodec>,
780 event: Option<AgentEvent>,
781 ) -> bool {
782 match event {
783 Some(AgentEvent::Data { channel_id, data }) => {
784 if self.agent.channels.contains_key(&channel_id)
785 && !timed_send(framed, Frame::AgentData { channel_id, data }).await
786 {
787 return false;
788 }
789 }
790 Some(AgentEvent::Closed { channel_id }) => {
791 if self.agent.channels.remove(&channel_id).is_some()
792 && !timed_send(framed, Frame::AgentClose { channel_id }).await
793 {
794 return false;
795 }
796 }
797 None => {} }
799 true
800 }
801
802 async fn handle_tunnel_event(
803 &mut self,
804 framed: &mut Framed<UnixStream, FrameCodec>,
805 event: Option<ClientTunnelEvent>,
806 ) -> bool {
807 match event {
808 Some(ClientTunnelEvent::Accepted { channel_id, writer_tx }) => {
809 self.tunnel.channels.insert(channel_id, writer_tx);
810 if !timed_send(framed, Frame::TunnelOpen { channel_id }).await {
811 return false;
812 }
813 }
814 Some(ClientTunnelEvent::Data { channel_id, data }) => {
815 if !timed_send(framed, Frame::TunnelData { channel_id, data }).await {
816 return false;
817 }
818 }
819 Some(ClientTunnelEvent::Closed { channel_id }) => {
820 self.tunnel.channels.remove(&channel_id);
821 if !timed_send(framed, Frame::TunnelClose { channel_id }).await {
822 return false;
823 }
824 }
825 None => {}
826 }
827 true
828 }
829
830 async fn handle_pf_event(
831 &mut self,
832 framed: &mut Framed<UnixStream, FrameCodec>,
833 event: Option<ClientPortForwardEvent>,
834 ) -> bool {
835 match event {
836 Some(ClientPortForwardEvent::Accepted { forward_id, channel_id, writer_tx }) => {
837 if let Some(fwd) = self.pf.forwards.get(&forward_id) {
838 let target_port = fwd.target_port;
839 self.pf.channels.insert(channel_id, (forward_id, writer_tx));
840 if !timed_send(
841 framed,
842 Frame::PortForwardOpen { forward_id, channel_id, target_port },
843 )
844 .await
845 {
846 return false;
847 }
848 }
849 }
850 Some(ClientPortForwardEvent::Data { channel_id, data }) => {
851 if self.pf.channels.contains_key(&channel_id)
852 && !timed_send(framed, Frame::PortForwardData { channel_id, data }).await
853 {
854 return false;
855 }
856 }
857 Some(ClientPortForwardEvent::Closed { channel_id }) => {
858 if self.pf.channels.remove(&channel_id).is_some()
859 && !timed_send(framed, Frame::PortForwardClose { channel_id }).await
860 {
861 return false;
862 }
863 }
864 None => {}
865 }
866 true
867 }
868}
869
870#[allow(clippy::too_many_arguments)]
872async fn relay(
873 framed: &mut Framed<UnixStream, FrameCodec>,
874 async_stdin: &AsyncFd<io::Stdin>,
875 async_stdout: &AsyncFd<std::os::fd::OwnedFd>,
876 sigwinch: &mut tokio::signal::unix::Signal,
877 buf: &mut [u8],
878 mut escape: Option<&mut EscapeProcessor>,
879 raw_guard: &RawModeGuard,
880 nb_guard: &NonBlockGuard,
881 agent_socket: Option<&str>,
882 oauth_redirect: bool,
883 oauth_timeout: u64,
884 session: &str,
885 hb_interval: Duration,
886 hb_timeout: Duration,
887) -> anyhow::Result<RelayExit> {
888 let mut sigterm = signal(SignalKind::terminate())?;
889 let mut sighup = signal(SignalKind::hangup())?;
890
891 let mut heartbeat_interval = tokio::time::interval(hb_interval);
892 heartbeat_interval.reset(); let mut last_pong = Instant::now();
894 let mut last_ping_sent = Instant::now();
895 let mut last_rtt: Option<Duration> = None;
896
897 let mut agent = ClientAgentState::new();
899 let (agent_event_tx, mut agent_event_rx) = mpsc::unbounded_channel::<AgentEvent>();
900
901 let mut tunnel = ClientTunnelState::new();
903 let (tunnel_event_tx, mut tunnel_event_rx) = mpsc::unbounded_channel::<ClientTunnelEvent>();
904
905 let (pf_event_tx, mut pf_event_rx) = mpsc::unbounded_channel::<ClientPortForwardEvent>();
907 let mut pf = ClientPortForwardTable::new();
908
909 let mut bytes_relayed = 0u64;
910 let mut relay = ClientRelay {
911 async_stdout,
912 agent: &mut agent,
913 agent_event_tx: &agent_event_tx,
914 agent_socket,
915 tunnel: &mut tunnel,
916 tunnel_event_tx: &tunnel_event_tx,
917 oauth_redirect,
918 oauth_timeout,
919 pf: &mut pf,
920 pf_event_tx: &pf_event_tx,
921 last_pong: &mut last_pong,
922 last_ping_sent: &mut last_ping_sent,
923 last_rtt: &mut last_rtt,
924 connected_at: Instant::now(),
925 bytes_relayed: &mut bytes_relayed,
926 };
927
928 loop {
929 tokio::select! {
930 ready = async_stdin.readable() => {
931 let mut guard = ready?;
932 match guard.try_io(|inner| inner.get_ref().read(buf)) {
933 Ok(Ok(0)) => {
934 debug!("stdin EOF");
935 return Ok(RelayExit::Exit(0));
936 }
937 Ok(Ok(n)) => {
938 debug!(len = n, "stdin → socket");
939 if let Some(ref mut esc) = escape {
940 for action in esc.process(&buf[..n]) {
941 match action {
942 EscapeAction::Data(data) => {
943 if !timed_send(framed, Frame::Data(Bytes::from(data))).await {
944 return Ok(RelayExit::Disconnected);
945 }
946 }
947 EscapeAction::Detach => {
948 write_stdout_async(async_stdout, status_msg("detached").as_bytes()).await?;
949 return Ok(RelayExit::Exit(0));
950 }
951 EscapeAction::Reconnect => {
952 write_stdout_async(async_stdout, status_msg("force reconnect").as_bytes()).await?;
953 return Ok(RelayExit::Disconnected);
954 }
955 EscapeAction::Suspend => {
956 suspend(raw_guard, nb_guard)?;
957 let (cols, rows) = get_terminal_size();
959 if !timed_send(framed, Frame::Resize { cols, rows }).await {
960 return Ok(RelayExit::Disconnected);
961 }
962 }
963 EscapeAction::Status => {
964 let rtt_str = match *relay.last_rtt {
965 Some(d) => format!("{:.1}ms", d.as_secs_f64() * 1000.0),
966 None => "n/a".to_string(),
967 };
968 let uptime = relay.connected_at.elapsed();
969 let uptime_str = if uptime.as_secs() >= 3600 {
970 format!(
971 "{}h {}m {}s",
972 uptime.as_secs() / 3600,
973 (uptime.as_secs() % 3600) / 60,
974 uptime.as_secs() % 60,
975 )
976 } else if uptime.as_secs() >= 60 {
977 format!(
978 "{}m {}s",
979 uptime.as_secs() / 60,
980 uptime.as_secs() % 60,
981 )
982 } else {
983 format!("{}s", uptime.as_secs())
984 };
985 let bytes_str = format_size(*relay.bytes_relayed);
986 let agent_info = if relay.agent_socket.is_some() {
987 format!(
988 "on ({} channels)",
989 relay.agent.channels.len()
990 )
991 } else {
992 "off".to_string()
993 };
994 let open_str = if relay.oauth_redirect { "on" } else { "off" };
995 let mut pf_lines = Vec::new();
996 for (&fwd_id, fwd) in &relay.pf.forwards {
997 let ch_count = relay.pf.channels.values()
998 .filter(|(fid, _)| *fid == fwd_id)
999 .count();
1000 pf_lines.push(format!(
1001 " :{} ({} connections)",
1002 fwd.target_port,
1003 ch_count,
1004 ));
1005 }
1006 let tunnel_str = if !relay.tunnel.channels.is_empty() {
1007 format!("active ({} channels)", relay.tunnel.channels.len())
1008 } else if relay.tunnel.listener.is_some() {
1009 "listening".to_string()
1010 } else {
1011 "idle".to_string()
1012 };
1013 let mut status = format!(
1014 "\r\n\x1b[2;33m[gritty status]\r\n\
1015 \x1b[0m\x1b[2m session: {session}\r\n\
1016 \x1b[0m\x1b[2m rtt: {rtt_str}\r\n\
1017 \x1b[0m\x1b[2m connected: {uptime_str}\r\n\
1018 \x1b[0m\x1b[2m bytes relayed: {bytes_str}\r\n\
1019 \x1b[0m\x1b[2m agent forwarding: {agent_info}\r\n\
1020 \x1b[0m\x1b[2m open forwarding: {open_str}\r\n\
1021 \x1b[0m\x1b[2m oauth tunnel: {tunnel_str}\r\n",
1022 );
1023 for line in &pf_lines {
1024 status.push_str(&format!(
1025 "\x1b[0m\x1b[2m port forward{line}\r\n"
1026 ));
1027 }
1028 status.push_str("\x1b[0m");
1029 write_stdout_async(
1030 async_stdout,
1031 status.as_bytes(),
1032 ).await?;
1033 }
1034 EscapeAction::Help => {
1035 write_stdout_async(async_stdout, ESCAPE_HELP).await?;
1036 }
1037 }
1038 }
1039 } else if !timed_send(framed, Frame::Data(Bytes::copy_from_slice(&buf[..n]))).await {
1040 return Ok(RelayExit::Disconnected);
1041 }
1042 }
1043 Ok(Err(e)) => return Err(e.into()),
1044 Err(_would_block) => continue,
1045 }
1046 }
1047
1048 frame = framed.next() => {
1049 if let ControlFlow::Break(exit) = relay.handle_server_frame(framed, frame).await? {
1050 return Ok(exit);
1051 }
1052 }
1053
1054 event = agent_event_rx.recv() => {
1055 if !relay.handle_agent_event(framed, event).await {
1056 return Ok(RelayExit::Disconnected);
1057 }
1058 }
1059
1060 event = tunnel_event_rx.recv() => {
1061 if !relay.handle_tunnel_event(framed, event).await {
1062 return Ok(RelayExit::Disconnected);
1063 }
1064 }
1065
1066 event = pf_event_rx.recv() => {
1067 if !relay.handle_pf_event(framed, event).await {
1068 return Ok(RelayExit::Disconnected);
1069 }
1070 }
1071
1072 _ = sigwinch.recv() => {
1073 let (cols, rows) = get_terminal_size();
1074 debug!(cols, rows, "SIGWINCH → resize");
1075 if !timed_send(framed, Frame::Resize { cols, rows }).await {
1076 return Ok(RelayExit::Disconnected);
1077 }
1078 }
1079
1080 _ = heartbeat_interval.tick() => {
1081 if relay.last_pong.elapsed() > hb_timeout {
1082 debug!("heartbeat timeout");
1083 return Ok(RelayExit::Disconnected);
1084 }
1085 *relay.last_ping_sent = Instant::now();
1086 if !timed_send(framed, Frame::Ping).await {
1087 return Ok(RelayExit::Disconnected);
1088 }
1089 }
1090
1091 _ = sigterm.recv() => {
1092 debug!("SIGTERM received, exiting");
1093 return Ok(RelayExit::Exit(1));
1094 }
1095
1096 _ = sighup.recv() => {
1097 debug!("SIGHUP received, exiting");
1098 return Ok(RelayExit::Exit(1));
1099 }
1100 }
1101 }
1102}
1103
1104#[allow(clippy::too_many_arguments)]
1105pub async fn run(
1106 session: &str,
1107 mut framed: Framed<UnixStream, FrameCodec>,
1108 redraw: bool,
1109 ctl_path: &Path,
1110 env_vars: Vec<(String, String)>,
1111 no_escape: bool,
1112 forward_agent: bool,
1113 forward_open: bool,
1114 oauth_redirect: bool,
1115 oauth_timeout: u64,
1116 heartbeat_interval: u64,
1117 heartbeat_timeout: u64,
1118) -> anyhow::Result<i32> {
1119 let stdin = io::stdin();
1120 let stdin_fd = stdin.as_fd();
1121 let stdin_borrowed: BorrowedFd<'static> =
1123 unsafe { BorrowedFd::borrow_raw(stdin_fd.as_raw_fd()) };
1124 let raw_guard = RawModeGuard::enter(stdin_borrowed)?;
1125
1126 let nb_guard = NonBlockGuard::set(stdin_borrowed)?;
1129 let async_stdin = AsyncFd::new(io::stdin())?;
1130 let stdout_fd = crate::security::checked_dup(io::stdout().as_raw_fd())?;
1132 let async_stdout = AsyncFd::new(stdout_fd)?;
1133 let mut sigwinch = signal(SignalKind::window_change())?;
1134 let mut buf = vec![0u8; 4096];
1135 let mut current_redraw = redraw;
1136 let mut current_env = env_vars;
1137 let mut escape = if no_escape { None } else { Some(EscapeProcessor::new()) };
1138 let agent_socket = if forward_agent { std::env::var("SSH_AUTH_SOCK").ok() } else { None };
1139
1140 loop {
1141 let result = if send_init_frames(
1142 &mut framed,
1143 ¤t_env,
1144 forward_agent,
1145 agent_socket.as_deref(),
1146 forward_open,
1147 current_redraw,
1148 )
1149 .await
1150 {
1151 relay(
1152 &mut framed,
1153 &async_stdin,
1154 &async_stdout,
1155 &mut sigwinch,
1156 &mut buf,
1157 escape.as_mut(),
1158 &raw_guard,
1159 &nb_guard,
1160 agent_socket.as_deref(),
1161 oauth_redirect,
1162 oauth_timeout,
1163 session,
1164 Duration::from_secs(heartbeat_interval),
1165 Duration::from_secs(heartbeat_timeout),
1166 )
1167 .await?
1168 } else {
1169 RelayExit::Disconnected
1170 };
1171 match result {
1172 RelayExit::Exit(code) => return Ok(code),
1173 RelayExit::Disconnected => {
1174 current_env.clear();
1176 let reconnect_started = Instant::now();
1178 write_stdout_async(&async_stdout, status_msg("reconnecting...").as_bytes()).await?;
1179
1180 loop {
1181 tokio::select! {
1183 _ = tokio::time::sleep(Duration::from_secs(1)) => {}
1184 _ = async_stdin.readable() => {
1185 let mut peek = [0u8; 1];
1186 match async_stdin.get_ref().read(&mut peek) {
1187 Ok(1) if peek[0] == 0x03 => {
1188 write_stdout_async(&async_stdout, b"\r\n").await?;
1189 return Ok(1);
1190 }
1191 _ => {}
1192 }
1193 continue;
1194 }
1195 }
1196
1197 let elapsed = reconnect_started.elapsed().as_secs();
1199 write_stdout_async(
1200 &async_stdout,
1201 format!("\r{}", status_msg(&format!("reconnecting... {elapsed}s")))
1202 .as_bytes(),
1203 )
1204 .await?;
1205
1206 let stream = match UnixStream::connect(ctl_path).await {
1207 Ok(s) => s,
1208 Err(_) => continue,
1209 };
1210
1211 let mut new_framed = Framed::new(stream, FrameCodec);
1212 if crate::handshake(&mut new_framed).await.is_err() {
1213 continue;
1214 }
1215 if new_framed
1216 .send(Frame::Attach { session: session.to_string() })
1217 .await
1218 .is_err()
1219 {
1220 continue;
1221 }
1222
1223 match new_framed.next().await {
1224 Some(Ok(Frame::Ok)) => {
1225 write_stdout_async(
1226 &async_stdout,
1227 success_msg("reconnected").as_bytes(),
1228 )
1229 .await?;
1230 framed = new_framed;
1231 current_redraw = true;
1232 break;
1233 }
1234 Some(Ok(Frame::Error { message })) => {
1235 write_stdout_async(
1236 &async_stdout,
1237 error_msg(&format!("session gone: {message}")).as_bytes(),
1238 )
1239 .await?;
1240 return Ok(1);
1241 }
1242 _ => continue,
1243 }
1244 }
1245 }
1246 }
1247 }
1248}
1249
1250pub async fn tail(
1254 session: &str,
1255 mut framed: Framed<UnixStream, FrameCodec>,
1256 ctl_path: &Path,
1257) -> anyhow::Result<i32> {
1258 let stdin_fd = unsafe { BorrowedFd::borrow_raw(libc::STDIN_FILENO) };
1260 let _input_guard = SuppressInputGuard::enter(stdin_fd).ok();
1261
1262 tokio::task::spawn_blocking(|| {
1264 let mut buf = [0u8; 64];
1265 let mut belled = false;
1266 loop {
1267 match io::stdin().read(&mut buf) {
1268 Ok(0) | Err(_) => break,
1269 Ok(_) if !belled => {
1270 let _ = io::stderr().write_all(b"\x07");
1271 let _ = io::stderr().flush();
1272 belled = true;
1273 }
1274 _ => {}
1275 }
1276 }
1277 });
1278
1279 let mut heartbeat_interval = tokio::time::interval(DEFAULT_HEARTBEAT_INTERVAL);
1280 heartbeat_interval.reset();
1281 let mut last_pong = Instant::now();
1282 let mut sigint = signal(SignalKind::interrupt())?;
1283 let mut sigterm = signal(SignalKind::terminate())?;
1284 let mut sighup = signal(SignalKind::hangup())?;
1285 let mut stdout = tokio::io::stdout();
1286
1287 let code = 'outer: loop {
1288 let result = 'relay: loop {
1289 tokio::select! {
1290 frame = framed.next() => {
1291 match frame {
1292 Some(Ok(Frame::Data(data))) => {
1293 use tokio::io::AsyncWriteExt;
1294 stdout.write_all(&data).await?;
1295 }
1296 Some(Ok(Frame::Pong)) => {
1297 last_pong = Instant::now();
1298 }
1299 Some(Ok(Frame::Exit { code })) => {
1300 break 'relay Some(code);
1301 }
1302 Some(Ok(_)) => {}
1303 Some(Err(e)) => {
1304 debug!("tail connection error: {e}");
1305 break 'relay None;
1306 }
1307 None => {
1308 debug!("tail server disconnected");
1309 break 'relay None;
1310 }
1311 }
1312 }
1313 _ = heartbeat_interval.tick() => {
1314 if last_pong.elapsed() > DEFAULT_HEARTBEAT_TIMEOUT {
1315 debug!("tail heartbeat timeout");
1316 break 'relay None;
1317 }
1318 if framed.send(Frame::Ping).await.is_err() {
1319 break 'relay None;
1320 }
1321 }
1322 _ = sigint.recv() => {
1323 break 'outer 0;
1324 }
1325 _ = sigterm.recv() => {
1326 break 'outer 1;
1327 }
1328 _ = sighup.recv() => {
1329 break 'outer 1;
1330 }
1331 }
1332 };
1333
1334 match result {
1335 Some(code) => break code,
1336 None => {
1337 let reconnect_started = Instant::now();
1338 eprintln!("\x1b[2;33m[reconnecting...]\x1b[0m");
1339 loop {
1340 tokio::time::sleep(Duration::from_secs(1)).await;
1341 let elapsed = reconnect_started.elapsed().as_secs();
1342 eprint!("\r\x1b[2;33m[reconnecting... {elapsed}s]\x1b[0m");
1343
1344 let stream = match UnixStream::connect(ctl_path).await {
1345 Ok(s) => s,
1346 Err(_) => continue,
1347 };
1348
1349 let mut new_framed = Framed::new(stream, FrameCodec);
1350 if crate::handshake(&mut new_framed).await.is_err() {
1351 continue;
1352 }
1353 if new_framed.send(Frame::Tail { session: session.to_string() }).await.is_err()
1354 {
1355 continue;
1356 }
1357
1358 match new_framed.next().await {
1359 Some(Ok(Frame::Ok)) => {
1360 eprintln!("\x1b[32m[reconnected]\x1b[0m");
1361 framed = new_framed;
1362 heartbeat_interval.reset();
1363 last_pong = Instant::now();
1364 break;
1365 }
1366 Some(Ok(Frame::Error { message })) => {
1367 eprintln!("\x1b[31m[session gone: {message}]\x1b[0m");
1368 break 'outer 1;
1369 }
1370 _ => continue,
1371 }
1372 }
1373 }
1374 }
1375 };
1376
1377 {
1380 use tokio::io::AsyncWriteExt;
1381 let _ = stdout.write_all(b"\x1b[0m\x1b[?25h").await;
1382 }
1383 Ok(code)
1384}
1385
#[cfg(test)]
mod tests {
    use super::*;

    /// A processor that has already seen ordinary input, so `~` is only special after an
    /// explicit `\n` or `\r`.
    fn mid_line() -> EscapeProcessor {
        EscapeProcessor { state: EscapeState::Normal }
    }

    /// Shorthand for a pass-through data action.
    fn data(bytes: &[u8]) -> EscapeAction {
        EscapeAction::Data(bytes.to_vec())
    }

    #[test]
    fn new_starts_at_line_start() {
        assert_eq!(EscapeProcessor::new().state, EscapeState::AfterNewline);
    }

    #[test]
    fn normal_passthrough() {
        let mut ep = EscapeProcessor::new();
        assert_eq!(ep.process(b"hello"), vec![data(b"hello")]);
    }

    #[test]
    fn tilde_after_newline_detach() {
        assert_eq!(mid_line().process(b"\n~."), vec![data(b"\n"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_after_cr_detach() {
        assert_eq!(mid_line().process(b"\r~."), vec![data(b"\r"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_not_after_newline() {
        assert_eq!(mid_line().process(b"a~."), vec![data(b"a~.")]);
    }

    #[test]
    fn initial_state_detach() {
        let mut ep = EscapeProcessor::new();
        assert_eq!(ep.process(b"~."), vec![EscapeAction::Detach]);
    }

    #[test]
    fn tilde_suspend() {
        assert_eq!(mid_line().process(b"\n~\x1a"), vec![data(b"\n"), EscapeAction::Suspend]);
    }

    #[test]
    fn tilde_status() {
        assert_eq!(mid_line().process(b"\n~#"), vec![data(b"\n"), EscapeAction::Status]);
    }

    #[test]
    fn tilde_reconnect() {
        assert_eq!(mid_line().process(b"\n~R"), vec![data(b"\n"), EscapeAction::Reconnect]);
    }

    #[test]
    fn tilde_reconnect_stops_processing() {
        assert_eq!(
            mid_line().process(b"\n~Rremaining"),
            vec![data(b"\n"), EscapeAction::Reconnect]
        );
    }

    #[test]
    fn tilde_help() {
        assert_eq!(mid_line().process(b"\n~?"), vec![data(b"\n"), EscapeAction::Help]);
    }

    #[test]
    fn double_tilde() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~~"), vec![data(b"\n"), data(b"~")]);
        assert_eq!(ep.state, EscapeState::Normal);
    }

    #[test]
    fn tilde_unknown_char() {
        assert_eq!(mid_line().process(b"\n~x"), vec![data(b"\n"), data(b"~x")]);
    }

    #[test]
    fn split_across_reads() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n"), vec![data(b"\n")]);
        assert!(ep.process(b"~").is_empty());
        assert_eq!(ep.process(b"."), vec![EscapeAction::Detach]);
    }

    #[test]
    fn split_tilde_then_normal() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n"), vec![data(b"\n")]);
        assert!(ep.process(b"~").is_empty());
        assert_eq!(ep.process(b"a"), vec![data(b"~a")]);
    }

    #[test]
    fn multiple_escapes_one_buffer() {
        assert_eq!(
            mid_line().process(b"\n~?\n~."),
            vec![data(b"\n"), EscapeAction::Help, data(b"\n"), EscapeAction::Detach]
        );
    }

    #[test]
    fn consecutive_newlines() {
        assert_eq!(mid_line().process(b"\n\n\n~."), vec![data(b"\n\n\n"), EscapeAction::Detach]);
    }

    #[test]
    fn detach_stops_processing() {
        assert_eq!(
            mid_line().process(b"\n~.remaining"),
            vec![data(b"\n"), EscapeAction::Detach]
        );
    }

    #[test]
    fn tilde_then_newline() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~\n"), vec![data(b"\n"), data(b"~\n")]);
        assert_eq!(ep.state, EscapeState::AfterNewline);
    }

    #[test]
    fn empty_input() {
        let mut ep = EscapeProcessor::new();
        assert!(ep.process(b"").is_empty());
    }

    #[test]
    fn only_tilde_buffered() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~"), vec![data(b"\n")]);
        assert_eq!(ep.state, EscapeState::AfterTilde);
        assert_eq!(ep.process(b"."), vec![EscapeAction::Detach]);
    }
}