1use std::io::{self, Read, Write};
4use std::sync::mpsc;
5use std::thread;
6
7use rmux_ipc::BlockingLocalStream;
8#[cfg(windows)]
9use rmux_proto::CONTROL_STDIN_EOF_MARKER;
10use rmux_proto::{
11 ClientTerminalContext, ControlMode, ControlModeRequest, Request, Response, CONTROL_CONTROL_END,
12 CONTROL_CONTROL_START,
13};
14#[cfg(windows)]
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16#[cfg(windows)]
17use tokio::sync::mpsc as tokio_mpsc;
18
19use crate::{
20 connection::{read_response_frame_exact, Connection, ControlModeUpgrade, ControlTransition},
21 ClientError,
22};
23
24impl Connection {
25 pub fn begin_control_mode(
28 mut self,
29 mode: ControlMode,
30 client_terminal: ClientTerminalContext,
31 ) -> Result<ControlTransition, ClientError> {
32 self.write_request(&Request::ControlMode(ControlModeRequest {
33 mode,
34 client_terminal,
35 }))?;
36 let response = read_response_frame_exact(self.stream_mut())?;
37
38 match response {
39 Response::ControlMode(response) => Ok(ControlTransition::Upgraded(
40 self.into_control_upgrade(response)?,
41 )),
42 other => Ok(ControlTransition::Rejected(other)),
43 }
44 }
45}
46
47pub fn drive_control_mode(
49 upgrade: ControlModeUpgrade,
50 initial_commands: &[String],
51) -> Result<(), ClientError> {
52 let stdin = io::stdin();
53 let stdout = io::stdout();
54 drive_control_mode_with_stdio(upgrade, initial_commands, stdin, stdout)
55}
56
57pub fn drive_control_mode_with_stdio<R, W>(
59 upgrade: ControlModeUpgrade,
60 initial_commands: &[String],
61 input: R,
62 mut output: W,
63) -> Result<(), ClientError>
64where
65 R: Read + Send + 'static,
66 W: Write + Send,
67{
68 let mode = upgrade.mode();
69 if mode.is_control_control() {
70 output
71 .write_all(CONTROL_CONTROL_START.as_bytes())
72 .map_err(ClientError::Io)?;
73 output.flush().map_err(ClientError::Io)?;
74 }
75
76 let stream = upgrade.into_stream();
77 let copy_result = drive_control_stream(stream, initial_commands, input, &mut output);
78 if copy_result.is_ok() && output_needs_suffix(mode) {
79 output
80 .write_all(CONTROL_CONTROL_END.as_bytes())
81 .map_err(ClientError::Io)?;
82 output.flush().map_err(ClientError::Io)?;
83 }
84
85 copy_result
86}
87
88#[cfg(unix)]
89fn drive_control_stream<R, W>(
90 stream: BlockingLocalStream,
91 initial_commands: &[String],
92 mut input: R,
93 output: &mut W,
94) -> Result<(), ClientError>
95where
96 R: Read + Send + 'static,
97 W: Write + Send,
98{
99 write_initial_commands(&stream, initial_commands)?;
100 ensure_blocking(&stream).map_err(ClientError::Io)?;
101 let mut writer = stream.try_clone().map_err(ClientError::Io)?;
102 let (stdin_done_tx, stdin_done_rx) = mpsc::sync_channel(1);
103 let stdin_thread = thread::spawn(move || {
104 let result = io::copy(&mut input, &mut writer).map(|_| ());
105 let _ = shutdown_write(&writer);
106 let _ = stdin_done_tx.send(result);
107 });
108
109 let copy_result = copy_control_output(stream, output).map_err(ClientError::Io);
110 let stdin_result = poll_input_thread(&stdin_done_rx)?;
111 if stdin_result.is_some() {
112 stdin_thread
113 .join()
114 .map_err(|_| ClientError::Io(io::Error::other("control input thread panicked")))?;
115 }
116
117 copy_result?;
118 if let Some(stdin_result) = stdin_result {
119 stdin_result.map_err(ClientError::Io)?;
120 }
121 Ok(())
122}
123
/// Maximum number of input chunks queued between the input-reading thread and the async driver.
///
/// Tokio `mpsc` capacity counts messages, not bytes. Each message here is one read of at most
/// 8192 bytes.
#[cfg(windows)]
const CONTROL_STDIN_QUEUE_CAPACITY: usize = 256;
/// Maximum number of output chunks queued between the async driver and the output-writing thread.
#[cfg(windows)]
const CONTROL_STDOUT_QUEUE_CAPACITY: usize = 256;
128
/// Pumps data between the local input/output and the control-mode pipe (Windows).
///
/// Three pieces of work run at once:
/// - A dedicated thread reads `input` and queues chunks for the async driver.
/// - A scoped thread writes queued server output to `output`.
/// - The async driver multiplexes both queues over the pipe on the stream's runtime.
///
/// If the input thread has not finished when the session ends, it is left running detached.
///
/// # Errors
/// Reports, in order: a panic of the output thread, a failure writing to `output`, a failure of
/// the async driver, an input-thread panic or abnormal termination, then an input read error.
#[cfg(windows)]
fn drive_control_stream<R, W>(
    stream: BlockingLocalStream,
    initial_commands: &[String],
    input: R,
    output: &mut W,
) -> Result<(), ClientError>
where
    R: Read + Send + 'static,
    W: Write + Send,
{
    let (input_tx, input_rx) = tokio_mpsc::channel(CONTROL_STDIN_QUEUE_CAPACITY);
    let (output_tx, output_rx) = tokio_mpsc::channel(CONTROL_STDOUT_QUEUE_CAPACITY);
    let (stdin_done_tx, stdin_done_rx) = mpsc::sync_channel(1);
    let stdin_thread = thread::spawn(move || {
        let result = copy_control_input(input, input_tx);
        let _ = stdin_done_tx.send(result);
    });

    let (pipe, runtime) = stream.into_async_parts();
    let copy_result = thread::scope(|scope| {
        let output_thread = scope.spawn(move || write_queued_control_output(output, output_rx));
        // `output_tx` moves into the driver future and is dropped when `block_on` returns. That
        // closes the queue so the output thread can drain it and finish.
        let drive_result = runtime
            .block_on(drive_async_control(
                pipe,
                initial_commands,
                input_rx,
                output_tx,
            ))
            .map_err(ClientError::Io);
        let output_result = output_thread
            .join()
            .map_err(|_| ClientError::Io(io::Error::other("control output thread panicked")))?;

        // If writing to `output` failed, the driver only sees a generic BrokenPipe ("control
        // output writer stopped") the next time it queues output. Report the writer's own error
        // first so the root cause is not masked.
        output_result.map_err(ClientError::Io)?;
        drive_result
    });

    // Returns `None` when the input thread is still blocked on `input`; it is then detached.
    let stdin_result = poll_input_thread(&stdin_done_rx)?;

    if stdin_result.is_some() {
        stdin_thread
            .join()
            .map_err(|_| ClientError::Io(io::Error::other("control input thread panicked")))?;
    }

    copy_result?;
    if let Some(stdin_result) = stdin_result {
        stdin_result.map_err(ClientError::Io)?;
    }
    Ok(())
}
180
/// Returns whether the session output must be terminated with `CONTROL_CONTROL_END`.
///
/// Only control-control mode frames its output.
fn output_needs_suffix(mode: ControlMode) -> bool {
    mode.is_control_control()
}
184
185fn poll_input_thread(
186 stdin_done_rx: &mpsc::Receiver<io::Result<()>>,
187) -> Result<Option<io::Result<()>>, ClientError> {
188 match stdin_done_rx.try_recv() {
189 Ok(result) => Ok(Some(result)),
190 Err(mpsc::TryRecvError::Empty) => Ok(None),
191 Err(mpsc::TryRecvError::Disconnected) => Err(ClientError::Io(io::Error::other(
192 "control input thread terminated unexpectedly",
193 ))),
194 }
195}
196
197#[cfg(unix)]
198fn write_initial_commands(
199 stream: &BlockingLocalStream,
200 initial_commands: &[String],
201) -> Result<(), ClientError> {
202 if initial_commands.is_empty() {
203 return Ok(());
204 }
205
206 let mut writer = stream.try_clone().map_err(ClientError::Io)?;
207 for command in initial_commands {
208 writer
209 .write_all(command.as_bytes())
210 .and_then(|()| writer.write_all(b"\n"))
211 .map_err(ClientError::Io)?;
212 }
213 Ok(())
214}
215
/// Copies everything the server sends on `stream` to `output` until the server closes its side.
///
/// `output` is flushed after every chunk so control-mode notifications reach the consumer
/// immediately. Reads interrupted by a signal (`ErrorKind::Interrupted`) are retried, matching
/// `std::io::copy`, instead of ending the session.
///
/// # Errors
/// Returns the first non-`Interrupted` read error, or any write/flush error on `output`.
#[cfg(unix)]
fn copy_control_output(mut stream: impl Read, output: &mut impl Write) -> io::Result<()> {
    let mut buffer = [0_u8; 8192];

    loop {
        let bytes_read = match stream.read(&mut buffer) {
            Ok(0) => return Ok(()),
            Ok(bytes_read) => bytes_read,
            Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
            Err(error) => return Err(error),
        };
        output.write_all(&buffer[..bytes_read])?;
        output.flush()?;
    }
}
229
230#[cfg(unix)]
231fn ensure_blocking(stream: &BlockingLocalStream) -> io::Result<()> {
232 stream.set_nonblocking(false)
233}
234
235#[cfg(unix)]
236fn shutdown_write(stream: &BlockingLocalStream) -> io::Result<()> {
237 stream.shutdown(std::net::Shutdown::Write)
238}
239
/// Reads `input` in chunks of up to 8192 bytes and queues each chunk on `input_tx`.
///
/// Returns `Ok(())` when `input` reaches EOF, or when the receiving side has been dropped
/// because the session ended. Reads interrupted by a signal are retried.
///
/// # Errors
/// Returns the first non-`Interrupted` read error from `input`.
#[cfg(windows)]
fn copy_control_input<R>(mut input: R, input_tx: tokio_mpsc::Sender<Vec<u8>>) -> io::Result<()>
where
    R: Read,
{
    let mut chunk = [0_u8; 8192];
    loop {
        match input.read(&mut chunk) {
            Ok(0) => break,
            Ok(len) => {
                // A closed receiver means nobody consumes input any more; stop quietly.
                if input_tx.blocking_send(chunk[..len].to_vec()).is_err() {
                    break;
                }
            }
            Err(error) if error.kind() == io::ErrorKind::Interrupted => {}
            Err(error) => return Err(error),
        }
    }
    Ok(())
}
262
/// Drives one control-mode session over an async duplex `stream`.
///
/// It first writes `initial_commands`. It then forwards queued input chunks from `input_rx` to
/// the server while relaying server output to `output_tx`.
///
/// When the input queue closes, it sends `CONTROL_STDIN_EOF_MARKER` as its own line and shuts
/// down the write half. After that it keeps reading server output.
///
/// Returns `Ok(())` when the server closes the stream (a read of 0 bytes). It also returns once
/// a chunk containing a complete `%exit` line has been relayed.
///
/// # Errors
/// Propagates I/O errors from the stream. Returns `BrokenPipe` if the output queue's receiver is
/// gone.
#[cfg(windows)]
async fn drive_async_control<Stream>(
    stream: Stream,
    initial_commands: &[String],
    mut input_rx: tokio_mpsc::Receiver<Vec<u8>>,
    output_tx: tokio_mpsc::Sender<Vec<u8>>,
) -> io::Result<()>
where
    Stream: AsyncRead + AsyncWrite + Unpin,
{
    let mut completion_tracker = ControlCompletionTracker::default();
    let mut input_closed = false;
    let (mut reader, mut writer) = tokio::io::split(stream);
    write_async_initial_commands(&mut writer, initial_commands).await?;
    let mut buffer = [0_u8; 8192];

    loop {
        tokio::select! {
            // Disabled once input has closed, so the loop then only waits on server output.
            input = input_rx.recv(), if !input_closed => {
                match input {
                    Some(bytes) => {
                        writer.write_all(&bytes).await?;
                    }
                    None => {
                        // End of input is signalled in-band with a marker line before the write
                        // half is shut down. NOTE(review): presumably the server cannot reliably
                        // observe a half-close on this transport; confirm against the server side.
                        writer.write_all(CONTROL_STDIN_EOF_MARKER.as_bytes()).await?;
                        writer.write_all(b"\n").await?;
                        writer.flush().await?;
                        writer.shutdown().await?;
                        input_closed = true;
                    }
                }
            }
            bytes_read = reader.read(&mut buffer) => {
                let bytes_read = bytes_read?;
                if bytes_read == 0 {
                    return Ok(());
                }
                let observed = completion_tracker.observe(&buffer[..bytes_read]);
                // Relay the chunk before acting on `%exit`, so the exit line itself is output.
                send_control_output(&output_tx, &buffer[..bytes_read]).await?;
                if observed.exited {
                    return Ok(());
                }
            }
        }
    }
}
309
/// Writes every queued output chunk to `output`, flushing after each one.
///
/// Blocks the calling thread. Returns `Ok(())` once every sender of `output_rx` has been dropped
/// and the queue is drained.
///
/// # Errors
/// Returns the first write or flush error. Any chunks still queued are dropped.
#[cfg(windows)]
fn write_queued_control_output<W>(
    output: &mut W,
    mut output_rx: tokio_mpsc::Receiver<Vec<u8>>,
) -> io::Result<()>
where
    W: Write,
{
    loop {
        let Some(chunk) = output_rx.blocking_recv() else {
            return Ok(());
        };
        output.write_all(&chunk)?;
        output.flush()?;
    }
}
324
/// Queues a copy of `bytes` for the output-writing thread, waiting if the queue is full.
///
/// # Errors
/// Returns `ErrorKind::BrokenPipe` if the receiver has been dropped, meaning the writer stopped.
#[cfg(windows)]
async fn send_control_output(
    output_tx: &tokio_mpsc::Sender<Vec<u8>>,
    bytes: &[u8],
) -> io::Result<()> {
    let chunk = bytes.to_vec();
    match output_tx.send(chunk).await {
        Ok(()) => Ok(()),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::BrokenPipe,
            "control output writer stopped",
        )),
    }
}
335
/// Scans server output line by line to detect the `%exit` line that ends a session.
#[cfg(windows)]
#[derive(Debug, Default)]
struct ControlCompletionTracker {
    /// Bytes received after the last newline, held until the rest of their line arrives.
    pending: Vec<u8>,
}
341
#[cfg(windows)]
impl ControlCompletionTracker {
    /// Feeds one chunk of server output and reports whether it completed an `%exit` line.
    ///
    /// Only newline-terminated lines are examined. Bytes after the last newline stay in
    /// `pending`, so a line split across chunks is recognised once its newline arrives.
    fn observe(&mut self, bytes: &[u8]) -> ControlOutputObservation {
        self.pending.extend_from_slice(bytes);
        let mut observation = ControlOutputObservation::default();

        // Walk complete lines by offset and discard them all with one `drain`. This way the
        // leftover bytes are shifted once per call rather than once per line, and no per-line
        // `Vec` is allocated.
        let mut line_start = 0;
        while let Some(offset) = self.pending[line_start..]
            .iter()
            .position(|&byte| byte == b'\n')
        {
            let line_end = line_start + offset + 1;
            if is_control_exit_line(&self.pending[line_start..line_end]) {
                observation.exited = true;
            }
            line_start = line_end;
        }
        self.pending.drain(..line_start);
        observation
    }
}
356
/// What `ControlCompletionTracker::observe` learned from one chunk of server output.
#[cfg(windows)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
struct ControlOutputObservation {
    /// True when the chunk completed at least one `%exit` line.
    exited: bool,
}
362
/// Returns whether a newline-terminated line is an `%exit` notification.
///
/// Matches either the bare `%exit\n` or `%exit` followed by a space and a reason.
#[cfg(windows)]
fn is_control_exit_line(line: &[u8]) -> bool {
    matches!(line.strip_prefix(b"%exit"), Some(b"\n") | Some([b' ', ..]))
}
367
/// Writes each initial command as its own newline-terminated line, then flushes the writer.
///
/// The flush happens even when `initial_commands` is empty.
///
/// # Errors
/// Returns the first write or flush error.
#[cfg(windows)]
async fn write_async_initial_commands<Writer>(
    writer: &mut Writer,
    initial_commands: &[String],
) -> io::Result<()>
where
    Writer: AsyncWrite + Unpin,
{
    for line in initial_commands.iter().map(String::as_bytes) {
        writer.write_all(line).await?;
        writer.write_all(b"\n").await?;
    }
    writer.flush().await
}
383
// Unix integration tests. A `UnixStream` pair stands in for the server connection.
#[cfg(all(test, unix))]
mod tests {
    use std::io::{Cursor, Write};
    use std::sync::mpsc;
    use std::time::Duration;

    use rmux_proto::{ControlMode, ControlModeResponse};

    use super::drive_control_mode_with_stdio;
    use crate::connection::ControlModeUpgrade;

    // In control-control mode the relayed server output must sit between the start and end
    // framing sequences.
    #[test]
    fn control_control_mode_wraps_output_with_dcs_sequences() {
        // `right` plays the server: it sends `%exit` and is dropped, which closes the stream.
        let (left, right) = std::os::unix::net::UnixStream::pair().expect("socket pair");
        let writer = std::thread::spawn(move || {
            let mut right = right;
            right.write_all(b"%exit\n").expect("write output");
        });

        let mut output = Vec::new();
        drive_control_mode_with_stdio(
            ControlModeUpgrade {
                response: ControlModeResponse {
                    mode: ControlMode::ControlControl,
                },
                stream: left,
            },
            &[],
            // Empty input: the input side reaches EOF immediately.
            Cursor::new(Vec::<u8>::new()),
            &mut output,
        )
        .expect("control mode succeeds");
        writer.join().expect("writer thread");

        let rendered = String::from_utf8(output).expect("utf8");
        assert!(rendered.starts_with(rmux_proto::CONTROL_CONTROL_START));
        assert!(rendered.contains("%exit\n"));
        assert!(rendered.ends_with(rmux_proto::CONTROL_CONTROL_END));
    }

    // The session must end when the server closes the connection, even though the input never
    // reaches EOF.
    #[test]
    fn control_mode_returns_after_server_exit_without_waiting_for_input_eof() {
        let (left, right) = std::os::unix::net::UnixStream::pair().expect("socket pair");
        // `input_writer` stays open while the session runs, so reads on `input_reader` block.
        let (input_reader, input_writer) =
            std::os::unix::net::UnixStream::pair().expect("input socket pair");
        let server = std::thread::spawn(move || {
            let mut right = right;
            right.write_all(b"%exit\n").expect("write exit");
        });
        // Run the session on a worker so the test can bound the wait with a timeout.
        let (done_tx, done_rx) = mpsc::channel();
        let worker = std::thread::spawn(move || {
            let mut output = Vec::new();
            let result = drive_control_mode_with_stdio(
                ControlModeUpgrade {
                    response: ControlModeResponse {
                        mode: ControlMode::Plain,
                    },
                    stream: left,
                },
                &[],
                input_reader,
                &mut output,
            );
            done_tx
                .send((result, output))
                .expect("report control mode result");
        });

        let done = done_rx.recv_timeout(Duration::from_secs(1));
        // Close the input only after the timed wait, so it cannot be what ended the session.
        // This also unblocks the input side so the worker can be joined.
        drop(input_writer);
        worker.join().expect("worker thread");
        server.join().expect("server thread");

        let (result, output) = done.expect("control mode should exit promptly");
        result.expect("control mode succeeds");
        // Plain mode relays the server output unframed.
        assert_eq!(String::from_utf8(output).expect("utf8"), "%exit\n");
    }
}
462
// Windows tests for the async driver, using an in-memory `tokio::io::duplex` as the pipe.
#[cfg(all(test, windows))]
mod windows_tests {
    use super::drive_async_control;
    use rmux_proto::CONTROL_STDIN_EOF_MARKER;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::sync::mpsc as tokio_mpsc;

    // After input closes, the driver must send the input followed by the EOF marker line. It
    // must then keep relaying server output until `%exit`.
    #[tokio::test]
    async fn control_input_eof_shutdowns_writer_and_waits_for_exit() -> std::io::Result<()> {
        let (client, mut server) = tokio::io::duplex(4096);
        let (input_tx, input_rx) = tokio_mpsc::channel::<Vec<u8>>(1);
        input_tx
            .send(b"list-sessions\n".to_vec())
            .await
            .expect("send input");
        // Dropping the only sender closes the input queue after the queued command.
        drop(input_tx);
        let (output_tx, output_rx) = tokio_mpsc::channel::<Vec<u8>>(4);

        let drive = drive_async_control(client, &[], input_rx, output_tx);
        let server_peer = async {
            let expected_input = format!("list-sessions\n{CONTROL_STDIN_EOF_MARKER}\n");
            let mut received = Vec::new();
            let mut buffer = [0_u8; 32];
            while received.len() < expected_input.len() {
                let bytes_read = server.read(&mut buffer).await?;
                assert_ne!(bytes_read, 0, "client closed before sending command");
                received.extend_from_slice(&buffer[..bytes_read]);
            }
            assert_eq!(received, expected_input.as_bytes());
            server
                .write_all(b"%begin 1 1 1\n%end 1 1 1\n%exit\n")
                .await?;
            Ok::<(), std::io::Error>(())
        };
        let output = collect_control_output(output_rx);

        let (_, _, output) = tokio::try_join!(drive, server_peer, output)?;
        assert_eq!(output, b"%begin 1 1 1\n%end 1 1 1\n%exit\n");
        Ok(())
    }

    // A completed command reply alone must not end the session; the driver keeps reading
    // until a later `%exit` arrives.
    #[tokio::test]
    async fn control_input_eof_drains_exit_after_completed_command() -> std::io::Result<()> {
        let (client, mut server) = tokio::io::duplex(4096);
        let (input_tx, input_rx) = tokio_mpsc::channel::<Vec<u8>>(1);
        input_tx
            .send(b"list-sessions\n".to_vec())
            .await
            .expect("send input");
        drop(input_tx);
        let (output_tx, output_rx) = tokio_mpsc::channel::<Vec<u8>>(4);

        let drive = drive_async_control(client, &[], input_rx, output_tx);
        let server_peer = async {
            // Read until the received data ends on a line boundary, then answer.
            let mut received = Vec::new();
            let mut buffer = [0_u8; 32];
            while !received.ends_with(b"\n") {
                let bytes_read = server.read(&mut buffer).await?;
                assert_ne!(bytes_read, 0, "client closed before sending command");
                received.extend_from_slice(&buffer[..bytes_read]);
            }
            server.write_all(b"%begin 1 1 1\n%end 1 1 1\n").await?;
            // Yield so the reply and the `%exit` line can arrive as separate reads.
            tokio::task::yield_now().await;
            server.write_all(b"%exit\n").await?;
            Ok::<(), std::io::Error>(())
        };
        let output = collect_control_output(output_rx);

        let (_, _, output) = tokio::try_join!(drive, server_peer, output)?;
        assert_eq!(output, b"%begin 1 1 1\n%end 1 1 1\n%exit\n");
        Ok(())
    }

    // Concatenates every relayed chunk until the driver drops its output sender.
    async fn collect_control_output(
        mut output_rx: tokio_mpsc::Receiver<Vec<u8>>,
    ) -> std::io::Result<Vec<u8>> {
        let mut output = Vec::new();
        while let Some(bytes) = output_rx.recv().await {
            output.extend_from_slice(&bytes);
        }
        Ok(output)
    }
}