Skip to main content

rmux_client/
control.rs

1//! Blocking tmux-compatible control-mode client transport.
2
3use std::io::{self, Read, Write};
4use std::sync::mpsc;
5use std::thread;
6
7use rmux_ipc::BlockingLocalStream;
8#[cfg(windows)]
9use rmux_proto::CONTROL_STDIN_EOF_MARKER;
10use rmux_proto::{
11    ClientTerminalContext, ControlMode, ControlModeRequest, Request, Response, CONTROL_CONTROL_END,
12    CONTROL_CONTROL_START,
13};
14#[cfg(windows)]
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16#[cfg(windows)]
17use tokio::sync::mpsc as tokio_mpsc;
18
19use crate::{
20    connection::{read_response_frame_exact, Connection, ControlModeUpgrade, ControlTransition},
21    ClientError,
22};
23
24impl Connection {
25    /// Requests a control-mode upgrade and, on success, yields the raw local
26    /// stream for tmux-compatible text control traffic.
27    pub fn begin_control_mode(
28        mut self,
29        mode: ControlMode,
30        client_terminal: ClientTerminalContext,
31    ) -> Result<ControlTransition, ClientError> {
32        self.write_request(&Request::ControlMode(ControlModeRequest {
33            mode,
34            client_terminal,
35        }))?;
36        let response = read_response_frame_exact(self.stream_mut())?;
37
38        match response {
39            Response::ControlMode(response) => Ok(ControlTransition::Upgraded(
40                self.into_control_upgrade(response)?,
41            )),
42            other => Ok(ControlTransition::Rejected(other)),
43        }
44    }
45}
46
47/// Drives a control-mode session using the process stdio streams.
48pub fn drive_control_mode(
49    upgrade: ControlModeUpgrade,
50    initial_commands: &[String],
51) -> Result<(), ClientError> {
52    let stdin = io::stdin();
53    let stdout = io::stdout();
54    drive_control_mode_with_stdio(upgrade, initial_commands, stdin, stdout)
55}
56
57/// Drives a control-mode session using explicit input and output streams.
58pub fn drive_control_mode_with_stdio<R, W>(
59    upgrade: ControlModeUpgrade,
60    initial_commands: &[String],
61    input: R,
62    mut output: W,
63) -> Result<(), ClientError>
64where
65    R: Read + Send + 'static,
66    W: Write + Send,
67{
68    let mode = upgrade.mode();
69    if mode.is_control_control() {
70        output
71            .write_all(CONTROL_CONTROL_START.as_bytes())
72            .map_err(ClientError::Io)?;
73        output.flush().map_err(ClientError::Io)?;
74    }
75
76    let stream = upgrade.into_stream();
77    let copy_result = drive_control_stream(stream, initial_commands, input, &mut output);
78    if copy_result.is_ok() && output_needs_suffix(mode) {
79        output
80            .write_all(CONTROL_CONTROL_END.as_bytes())
81            .map_err(ClientError::Io)?;
82        output.flush().map_err(ClientError::Io)?;
83    }
84
85    copy_result
86}
87
88#[cfg(unix)]
89fn drive_control_stream<R, W>(
90    stream: BlockingLocalStream,
91    initial_commands: &[String],
92    mut input: R,
93    output: &mut W,
94) -> Result<(), ClientError>
95where
96    R: Read + Send + 'static,
97    W: Write + Send,
98{
99    write_initial_commands(&stream, initial_commands)?;
100    ensure_blocking(&stream).map_err(ClientError::Io)?;
101    let mut writer = stream.try_clone().map_err(ClientError::Io)?;
102    let (stdin_done_tx, stdin_done_rx) = mpsc::sync_channel(1);
103    let stdin_thread = thread::spawn(move || {
104        let result = io::copy(&mut input, &mut writer).map(|_| ());
105        let _ = shutdown_write(&writer);
106        let _ = stdin_done_tx.send(result);
107    });
108
109    let copy_result = copy_control_output(stream, output).map_err(ClientError::Io);
110    let stdin_result = poll_input_thread(&stdin_done_rx)?;
111    if stdin_result.is_some() {
112        stdin_thread
113            .join()
114            .map_err(|_| ClientError::Io(io::Error::other("control input thread panicked")))?;
115    }
116
117    copy_result?;
118    if let Some(stdin_result) = stdin_result {
119        stdin_result.map_err(ClientError::Io)?;
120    }
121    Ok(())
122}
123
/// Maximum number of stdin chunks buffered between the blocking input thread
/// and the async control driver. Each chunk is one `Vec<u8>` message.
#[cfg(windows)]
const CONTROL_STDIN_QUEUE_CAPACITY: usize = 256;
/// Maximum number of server-output chunks buffered between the async control
/// driver and the blocking stdout writer thread.
#[cfg(windows)]
const CONTROL_STDOUT_QUEUE_CAPACITY: usize = 256;
128
/// Windows implementation of the control session.
///
/// Blocking stdio is bridged to the async pipe with two bounded queues:
/// - a plain thread reads `input` into the stdin queue;
/// - a scoped thread drains the stdout queue into `output`;
/// - the async driver runs on the stream's runtime in between.
///
/// # Errors
///
/// Returns the first failure among the helper threads, the stdout writer,
/// the async driver, and stdin reading. A stdout write failure is reported
/// in preference to the secondary BrokenPipe it causes in the driver.
#[cfg(windows)]
fn drive_control_stream<R, W>(
    stream: BlockingLocalStream,
    initial_commands: &[String],
    input: R,
    output: &mut W,
) -> Result<(), ClientError>
where
    R: Read + Send + 'static,
    W: Write + Send,
{
    let (input_tx, input_rx) = tokio_mpsc::channel(CONTROL_STDIN_QUEUE_CAPACITY);
    let (output_tx, output_rx) = tokio_mpsc::channel(CONTROL_STDOUT_QUEUE_CAPACITY);
    let (stdin_done_tx, stdin_done_rx) = mpsc::sync_channel(1);
    let stdin_thread = thread::spawn(move || {
        let result = copy_control_input(input, input_tx);
        let _ = stdin_done_tx.send(result);
    });

    let (pipe, runtime) = stream.into_async_parts();
    let copy_result = thread::scope(|scope| {
        let output_thread = scope.spawn(move || write_queued_control_output(output, output_rx));
        let drive_result = runtime.block_on(drive_async_control(
            pipe,
            initial_commands,
            input_rx,
            output_tx,
        ));
        let output_result = output_thread
            .join()
            .map_err(|_| ClientError::Io(io::Error::other("control output thread panicked")))?;

        // If writing `output` fails, the writer thread returns and drops its
        // receiver. The driver then only sees a generic BrokenPipe ("control
        // output writer stopped"), so report the writer's own error first to
        // keep the root cause visible.
        output_result.map_err(ClientError::Io)?;
        drive_result.map_err(ClientError::Io)
    });
    let stdin_result = poll_input_thread(&stdin_done_rx)?;

    // Only join a thread that has already reported; one still blocked on
    // stdin is left detached.
    if stdin_result.is_some() {
        stdin_thread
            .join()
            .map_err(|_| ClientError::Io(io::Error::other("control input thread panicked")))?;
    }

    copy_result?;
    if let Some(stdin_result) = stdin_result {
        stdin_result.map_err(ClientError::Io)?;
    }
    Ok(())
}
180
181fn output_needs_suffix(mode: ControlMode) -> bool {
182    mode.is_control_control()
183}
184
185fn poll_input_thread(
186    stdin_done_rx: &mpsc::Receiver<io::Result<()>>,
187) -> Result<Option<io::Result<()>>, ClientError> {
188    match stdin_done_rx.try_recv() {
189        Ok(result) => Ok(Some(result)),
190        Err(mpsc::TryRecvError::Empty) => Ok(None),
191        Err(mpsc::TryRecvError::Disconnected) => Err(ClientError::Io(io::Error::other(
192            "control input thread terminated unexpectedly",
193        ))),
194    }
195}
196
197#[cfg(unix)]
198fn write_initial_commands(
199    stream: &BlockingLocalStream,
200    initial_commands: &[String],
201) -> Result<(), ClientError> {
202    if initial_commands.is_empty() {
203        return Ok(());
204    }
205
206    let mut writer = stream.try_clone().map_err(ClientError::Io)?;
207    for command in initial_commands {
208        writer
209            .write_all(command.as_bytes())
210            .and_then(|()| writer.write_all(b"\n"))
211            .map_err(ClientError::Io)?;
212    }
213    Ok(())
214}
215
/// Copies everything the server writes on `stream` to `output` until EOF.
///
/// Each chunk is flushed immediately so control-mode notifications reach the
/// consumer without waiting for a buffer to fill. Reads interrupted by a
/// signal (`ErrorKind::Interrupted`) are retried instead of ending the
/// session, matching the stdin side's handling.
///
/// # Errors
///
/// Returns the first non-`Interrupted` read error, or any write/flush error
/// from `output`.
#[cfg(unix)]
fn copy_control_output(mut stream: impl Read, output: &mut impl Write) -> io::Result<()> {
    let mut buffer = [0_u8; 8192];

    loop {
        let bytes_read = match stream.read(&mut buffer) {
            Ok(0) => return Ok(()),
            Ok(bytes_read) => bytes_read,
            Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
            Err(error) => return Err(error),
        };
        output.write_all(&buffer[..bytes_read])?;
        output.flush()?;
    }
}
229
230#[cfg(unix)]
231fn ensure_blocking(stream: &BlockingLocalStream) -> io::Result<()> {
232    stream.set_nonblocking(false)
233}
234
235#[cfg(unix)]
236fn shutdown_write(stream: &BlockingLocalStream) -> io::Result<()> {
237    stream.shutdown(std::net::Shutdown::Write)
238}
239
/// Reads `input` in chunks and queues each chunk on `input_tx` for the async
/// control driver.
///
/// Stops cleanly (`Ok(())`) at end of input, or when the receiving side has
/// been dropped. Reads interrupted by a signal are retried.
///
/// # Errors
///
/// Returns the first read error other than `ErrorKind::Interrupted`.
#[cfg(windows)]
fn copy_control_input<R>(mut input: R, input_tx: tokio_mpsc::Sender<Vec<u8>>) -> io::Result<()>
where
    R: Read,
{
    let mut chunk = [0_u8; 8192];
    loop {
        let filled = match input.read(&mut chunk) {
            Ok(0) => break,
            Ok(count) => count,
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        };

        // A failed send means the driver has gone away; that ends forwarding
        // without being treated as an error.
        if input_tx.blocking_send(chunk[..filled].to_vec()).is_err() {
            break;
        }
    }
    Ok(())
}
262
/// Runs the async side of a Windows control-mode session over `stream`.
///
/// `initial_commands` are written (newline-terminated) and flushed first.
/// The loop then services two sources concurrently:
/// - **Stdin.** Chunks from `input_rx` are written to the server as-is. When
///   the channel closes, the driver writes `CONTROL_STDIN_EOF_MARKER` and a
///   newline, flushes and shuts down the writer, and then disables this
///   branch.
/// - **Server output.** Chunks are forwarded to `output_tx`. The function
///   returns `Ok(())` on a zero-length read, or right after forwarding a
///   chunk that completed a `%exit` line.
///
/// # Errors
///
/// Returns the first I/O error from reading or writing `stream`. It also
/// returns the `BrokenPipe` error produced when the output writer no longer
/// accepts chunks.
#[cfg(windows)]
async fn drive_async_control<Stream>(
    stream: Stream,
    initial_commands: &[String],
    mut input_rx: tokio_mpsc::Receiver<Vec<u8>>,
    output_tx: tokio_mpsc::Sender<Vec<u8>>,
) -> io::Result<()>
where
    Stream: AsyncRead + AsyncWrite + Unpin,
{
    // Buffers partial lines so `%exit` is detected even across read boundaries.
    let mut completion_tracker = ControlCompletionTracker::default();
    // Set once stdin has ended; disables the input branch below.
    let mut input_closed = false;
    let (mut reader, mut writer) = tokio::io::split(stream);
    write_async_initial_commands(&mut writer, initial_commands).await?;
    let mut buffer = [0_u8; 8192];

    loop {
        tokio::select! {
            input = input_rx.recv(), if !input_closed => {
                match input {
                    Some(bytes) => {
                        writer.write_all(&bytes).await?;
                    }
                    None => {
                        // Stdin ended (all senders dropped). Announce it with the
                        // marker line, then close the write half.
                        // NOTE(review): the explicit marker presumably exists
                        // because the server cannot observe a write-side
                        // shutdown on this transport — confirm.
                        writer.write_all(CONTROL_STDIN_EOF_MARKER.as_bytes()).await?;
                        writer.write_all(b"\n").await?;
                        writer.flush().await?;
                        writer.shutdown().await?;
                        input_closed = true;
                    }
                }
            }
            bytes_read = reader.read(&mut buffer) => {
                let bytes_read = bytes_read?;
                if bytes_read == 0 {
                    // The server closed its side of the stream.
                    return Ok(());
                }
                let observed = completion_tracker.observe(&buffer[..bytes_read]);
                // Forward before checking for exit so the `%exit` line itself
                // reaches the output.
                send_control_output(&output_tx, &buffer[..bytes_read]).await?;
                if observed.exited {
                    return Ok(());
                }
            }
        }
    }
}
309
/// Writes queued output chunks to `output` on a plain (non-async) thread.
///
/// Flushes after every chunk and returns `Ok(())` once all senders have been
/// dropped.
///
/// # Errors
///
/// Stops at the first write or flush error. Returning drops `output_rx`, so
/// later sends from the async driver fail.
#[cfg(windows)]
fn write_queued_control_output<W>(
    output: &mut W,
    mut output_rx: tokio_mpsc::Receiver<Vec<u8>>,
) -> io::Result<()>
where
    W: Write,
{
    loop {
        let Some(chunk) = output_rx.blocking_recv() else {
            return Ok(());
        };
        output.write_all(&chunk)?;
        output.flush()?;
    }
}
324
/// Queues a copy of `bytes` for the stdout writer thread, waiting for room in
/// the bounded queue.
///
/// # Errors
///
/// Returns `ErrorKind::BrokenPipe` once the writer thread has dropped its
/// receiver.
#[cfg(windows)]
async fn send_control_output(
    output_tx: &tokio_mpsc::Sender<Vec<u8>>,
    bytes: &[u8],
) -> io::Result<()> {
    let chunk = bytes.to_vec();
    match output_tx.send(chunk).await {
        Ok(()) => Ok(()),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::BrokenPipe,
            "control output writer stopped",
        )),
    }
}
335
/// Watches server output for the `%exit` line that ends a control session.
///
/// Output arrives in arbitrary chunks, so an incomplete trailing line is kept
/// in `pending` until its newline shows up in a later chunk.
#[cfg(windows)]
#[derive(Debug, Default)]
struct ControlCompletionTracker {
    /// Bytes received after the last newline seen so far.
    pending: Vec<u8>,
}

#[cfg(windows)]
impl ControlCompletionTracker {
    /// Feeds `bytes` into the tracker and reports whether any line completed
    /// by this chunk is a `%exit` line.
    ///
    /// All complete lines are examined in one pass over the buffered data and
    /// then removed with a single drain. A chunk containing many lines
    /// therefore costs linear time with no per-line allocation.
    fn observe(&mut self, bytes: &[u8]) -> ControlOutputObservation {
        self.pending.extend_from_slice(bytes);
        let Some(last_newline) = self.pending.iter().rposition(|byte| *byte == b'\n') else {
            // No line has been completed yet; keep buffering.
            return ControlOutputObservation::default();
        };

        // `split_inclusive` keeps each line's trailing '\n', which is the form
        // `is_control_exit_line` expects.
        let exited = self.pending[..=last_newline]
            .split_inclusive(|byte| *byte == b'\n')
            .any(is_control_exit_line);
        self.pending.drain(..=last_newline);
        ControlOutputObservation { exited }
    }
}

/// Result of feeding one chunk of server output to a
/// [`ControlCompletionTracker`].
#[cfg(windows)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
struct ControlOutputObservation {
    /// `true` when the observed chunk completed a `%exit` line.
    exited: bool,
}

/// Returns whether `line` is a `%exit` line.
///
/// `line` includes its trailing `\n`. A match is either exactly `%exit\n` or
/// `%exit` followed by a space and further text.
#[cfg(windows)]
fn is_control_exit_line(line: &[u8]) -> bool {
    line == b"%exit\n" || line.starts_with(b"%exit ")
}
367
/// Writes each initial command followed by a newline, then flushes `writer`.
///
/// The flush happens even when `initial_commands` is empty.
#[cfg(windows)]
async fn write_async_initial_commands<Writer>(
    writer: &mut Writer,
    initial_commands: &[String],
) -> io::Result<()>
where
    Writer: AsyncWrite + Unpin,
{
    for line in initial_commands.iter().map(String::as_bytes) {
        writer.write_all(line).await?;
        writer.write_all(b"\n").await?;
    }
    writer.flush().await
}
383
/// Unix-only tests that drive the blocking control-mode loop over
/// `UnixStream` socket pairs.
#[cfg(all(test, unix))]
mod tests {
    use std::io::{Cursor, Write};
    use std::sync::mpsc;
    use std::time::Duration;

    use rmux_proto::{ControlMode, ControlModeResponse};

    use super::drive_control_mode_with_stdio;
    use crate::connection::ControlModeUpgrade;

    /// Control-control output must start with `CONTROL_CONTROL_START`, carry
    /// the server's bytes, and end with `CONTROL_CONTROL_END`.
    #[test]
    fn control_control_mode_wraps_output_with_dcs_sequences() {
        let (left, right) = std::os::unix::net::UnixStream::pair().expect("socket pair");
        // `right` plays the server. It writes `%exit` and is dropped when the
        // thread ends, so the client then reads EOF.
        let writer = std::thread::spawn(move || {
            let mut right = right;
            right.write_all(b"%exit\n").expect("write output");
        });

        let mut output = Vec::new();
        drive_control_mode_with_stdio(
            ControlModeUpgrade {
                response: ControlModeResponse {
                    mode: ControlMode::ControlControl,
                },
                stream: left,
            },
            &[],
            // Empty stdin: the input thread reaches EOF immediately.
            Cursor::new(Vec::<u8>::new()),
            &mut output,
        )
        .expect("control mode succeeds");
        writer.join().expect("writer thread");

        let rendered = String::from_utf8(output).expect("utf8");
        assert!(rendered.starts_with(rmux_proto::CONTROL_CONTROL_START));
        assert!(rendered.contains("%exit\n"));
        assert!(rendered.ends_with(rmux_proto::CONTROL_CONTROL_END));
    }

    /// The session must finish once the server side closes, even while stdin
    /// is still open and blocked in a read.
    #[test]
    fn control_mode_returns_after_server_exit_without_waiting_for_input_eof() {
        let (left, right) = std::os::unix::net::UnixStream::pair().expect("socket pair");
        // Stdin is a socket whose writer stays alive, so reads from it block
        // until `input_writer` is dropped below.
        let (input_reader, input_writer) =
            std::os::unix::net::UnixStream::pair().expect("input socket pair");
        let server = std::thread::spawn(move || {
            let mut right = right;
            right.write_all(b"%exit\n").expect("write exit");
        });
        // Run the session on a worker so the test can bound how long it waits.
        let (done_tx, done_rx) = mpsc::channel();
        let worker = std::thread::spawn(move || {
            let mut output = Vec::new();
            let result = drive_control_mode_with_stdio(
                ControlModeUpgrade {
                    response: ControlModeResponse {
                        mode: ControlMode::Plain,
                    },
                    stream: left,
                },
                &[],
                input_reader,
                &mut output,
            );
            done_tx
                .send((result, output))
                .expect("report control mode result");
        });

        // Wait (bounded) for the result *before* closing stdin. Closing it
        // first would let an implementation that waits for stdin EOF pass.
        let done = done_rx.recv_timeout(Duration::from_secs(1));
        drop(input_writer);
        worker.join().expect("worker thread");
        server.join().expect("server thread");

        let (result, output) = done.expect("control mode should exit promptly");
        result.expect("control mode succeeds");
        // Plain mode adds no framing: output is exactly the server's bytes.
        assert_eq!(String::from_utf8(output).expect("utf8"), "%exit\n");
    }
}
462
/// Windows-only tests that exercise `drive_async_control` directly against an
/// in-memory `tokio::io::duplex` stream.
#[cfg(all(test, windows))]
mod windows_tests {
    use super::drive_async_control;
    use rmux_proto::CONTROL_STDIN_EOF_MARKER;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::sync::mpsc as tokio_mpsc;

    /// After stdin ends, the driver must send the EOF marker line and keep
    /// forwarding server output until `%exit`.
    #[tokio::test]
    async fn control_input_eof_shutdowns_writer_and_waits_for_exit() -> std::io::Result<()> {
        let (client, mut server) = tokio::io::duplex(4096);
        // Queue one stdin chunk, then drop the sender so the driver sees EOF.
        let (input_tx, input_rx) = tokio_mpsc::channel::<Vec<u8>>(1);
        input_tx
            .send(b"list-sessions\n".to_vec())
            .await
            .expect("send input");
        drop(input_tx);
        let (output_tx, output_rx) = tokio_mpsc::channel::<Vec<u8>>(4);

        let drive = drive_async_control(client, &[], input_rx, output_tx);
        let server_peer = async {
            // The command must arrive followed by the EOF marker line.
            let expected_input = format!("list-sessions\n{CONTROL_STDIN_EOF_MARKER}\n");
            let mut received = Vec::new();
            let mut buffer = [0_u8; 32];
            while received.len() < expected_input.len() {
                let bytes_read = server.read(&mut buffer).await?;
                assert_ne!(bytes_read, 0, "client closed before sending command");
                received.extend_from_slice(&buffer[..bytes_read]);
            }
            assert_eq!(received, expected_input.as_bytes());
            server
                .write_all(b"%begin 1 1 1\n%end 1 1 1\n%exit\n")
                .await?;
            Ok::<(), std::io::Error>(())
        };
        // Completes once the driver returns and drops `output_tx`.
        let output = collect_control_output(output_rx);

        let (_, _, output) = tokio::try_join!(drive, server_peer, output)?;
        assert_eq!(output, b"%begin 1 1 1\n%end 1 1 1\n%exit\n");
        Ok(())
    }

    /// A `%exit` written separately after a completed command must still be
    /// forwarded before the driver returns.
    #[tokio::test]
    async fn control_input_eof_drains_exit_after_completed_command() -> std::io::Result<()> {
        let (client, mut server) = tokio::io::duplex(4096);
        let (input_tx, input_rx) = tokio_mpsc::channel::<Vec<u8>>(1);
        input_tx
            .send(b"list-sessions\n".to_vec())
            .await
            .expect("send input");
        drop(input_tx);
        let (output_tx, output_rx) = tokio_mpsc::channel::<Vec<u8>>(4);

        let drive = drive_async_control(client, &[], input_rx, output_tx);
        let server_peer = async {
            let mut received = Vec::new();
            let mut buffer = [0_u8; 32];
            while !received.ends_with(b"\n") {
                let bytes_read = server.read(&mut buffer).await?;
                assert_ne!(bytes_read, 0, "client closed before sending command");
                received.extend_from_slice(&buffer[..bytes_read]);
            }
            server.write_all(b"%begin 1 1 1\n%end 1 1 1\n").await?;
            // Give the driver a chance to consume the first write before
            // `%exit` is sent.
            tokio::task::yield_now().await;
            server.write_all(b"%exit\n").await?;
            Ok::<(), std::io::Error>(())
        };
        let output = collect_control_output(output_rx);

        let (_, _, output) = tokio::try_join!(drive, server_peer, output)?;
        assert_eq!(output, b"%begin 1 1 1\n%end 1 1 1\n%exit\n");
        Ok(())
    }

    /// Concatenates every chunk received on `output_rx` until all senders are
    /// dropped.
    async fn collect_control_output(
        mut output_rx: tokio_mpsc::Receiver<Vec<u8>>,
    ) -> std::io::Result<Vec<u8>> {
        let mut output = Vec::new();
        while let Some(bytes) = output_rx.recv().await {
            output.extend_from_slice(&bytes);
        }
        Ok(output)
    }
}