distant_core/client/
process.rs

1use std::path::PathBuf;
2use std::sync::Arc;
3
4use distant_net::client::Mailbox;
5use distant_net::common::{Request, Response};
6use log::*;
7use tokio::io;
8use tokio::sync::mpsc::error::{TryRecvError, TrySendError};
9use tokio::sync::{mpsc, RwLock};
10use tokio::task::JoinHandle;
11
12use crate::client::DistantChannel;
13use crate::constants::CLIENT_PIPE_CAPACITY;
14use crate::protocol::{self, Cmd, Environment, ProcessId, PtySize};
15
/// Final result of a remote process: its exit status together with the stdout and stderr
/// bytes gathered for it (see [`RemoteProcess::output`])
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RemoteOutput {
    /// Whether the process reported success
    pub success: bool,
    /// Exit code reported for the process, if one was provided
    pub code: Option<i32>,
    /// Bytes gathered from the process' stdout
    pub stdout: Vec<u8>,
    /// Bytes gathered from the process' stderr
    pub stderr: Vec<u8>,
}
23
/// Exit status of a remote process
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct RemoteStatus {
    /// Whether the process reported success
    pub success: bool,
    /// Exit code reported for the process, if one was provided
    pub code: Option<i32>,
}
29
30impl From<(bool, Option<i32>)> for RemoteStatus {
31    fn from((success, code): (bool, Option<i32>)) -> Self {
32        Self { success, code }
33    }
34}
35
/// Outcome recorded once a remote process is no longer being tracked: either its
/// [`RemoteStatus`] or the I/O error that prevented obtaining one
type StatusResult = io::Result<RemoteStatus>;
37
/// A [`RemoteProcess`] builder providing support to configure
/// before spawning the process on a remote machine
pub struct RemoteCommand {
    /// PTY size to request for the process; `None` spawns without a PTY request
    pty: Option<PtySize>,
    /// Environment variables sent along with the spawn request
    environment: Environment,
    /// Alternative working directory sent along with the spawn request, if any
    current_dir: Option<PathBuf>,
}
45
46impl Default for RemoteCommand {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl RemoteCommand {
53    /// Creates a new set of options for a remote process
54    pub fn new() -> Self {
55        Self {
56            pty: None,
57            environment: Environment::new(),
58            current_dir: None,
59        }
60    }
61
62    /// Configures the process to leverage a PTY with the specified size
63    pub fn pty(&mut self, pty: Option<PtySize>) -> &mut Self {
64        self.pty = pty;
65        self
66    }
67
68    /// Replaces the existing environment variables with the given collection
69    pub fn environment(&mut self, environment: Environment) -> &mut Self {
70        self.environment = environment;
71        self
72    }
73
74    /// Configures the process with an alternative current directory
75    pub fn current_dir(&mut self, current_dir: Option<PathBuf>) -> &mut Self {
76        self.current_dir = current_dir;
77        self
78    }
79
80    /// Spawns the specified process on the remote machine using the given `channel` and `cmd`
81    pub async fn spawn(
82        &mut self,
83        mut channel: DistantChannel,
84        cmd: impl Into<String>,
85    ) -> io::Result<RemoteProcess> {
86        let cmd = cmd.into();
87
88        // Submit our run request and get back a mailbox for responses
89        let mut mailbox = channel
90            .mail(Request::new(protocol::Msg::Single(
91                protocol::Request::ProcSpawn {
92                    cmd: Cmd::from(cmd),
93                    pty: self.pty,
94                    environment: self.environment.clone(),
95                    current_dir: self.current_dir.clone(),
96                },
97            )))
98            .await?;
99
100        // Wait until we get the first response, and get id from proc started
101        let (id, origin_id) = match mailbox.next().await {
102            Some(res) => {
103                let origin_id = res.origin_id;
104                match res.payload {
105                    protocol::Msg::Single(protocol::Response::ProcSpawned { id }) => {
106                        (id, origin_id)
107                    }
108                    protocol::Msg::Single(protocol::Response::Error(x)) => return Err(x.into()),
109                    protocol::Msg::Single(x) => {
110                        return Err(io::Error::new(
111                            io::ErrorKind::InvalidData,
112                            format!("Got response type of {}", x.as_ref()),
113                        ))
114                    }
115                    protocol::Msg::Batch(_) => {
116                        return Err(io::Error::new(
117                            io::ErrorKind::InvalidData,
118                            "Got batch instead of single response",
119                        ));
120                    }
121                }
122            }
123            None => return Err(io::Error::from(io::ErrorKind::ConnectionAborted)),
124        };
125
126        // Create channels for our stdin/stdout/stderr
127        let (stdin_tx, stdin_rx) = mpsc::channel(CLIENT_PIPE_CAPACITY);
128        let (stdout_tx, stdout_rx) = mpsc::channel(CLIENT_PIPE_CAPACITY);
129        let (stderr_tx, stderr_rx) = mpsc::channel(CLIENT_PIPE_CAPACITY);
130        let (resize_tx, resize_rx) = mpsc::channel(1);
131
132        // Used to terminate request task, either explicitly by the process or internally
133        // by the response task when it terminates
134        let (kill_tx, kill_rx) = mpsc::channel(1);
135        let kill_tx_2 = kill_tx.clone();
136
137        // Now we spawn a task to handle future responses that are async
138        // such as ProcStdout, ProcStderr, and ProcDone
139        let (abort_res_task_tx, mut abort_res_task_rx) = mpsc::channel::<()>(1);
140        let res_task = tokio::spawn(async move {
141            tokio::select! {
142                _ = abort_res_task_rx.recv() => {
143                    panic!("killed");
144                }
145                res = process_incoming_responses(id, mailbox, stdout_tx, stderr_tx, kill_tx_2) => {
146                    res
147                }
148            }
149        });
150
151        // Spawn a task that takes stdin from our channel and forwards it to the remote process
152        let (abort_req_task_tx, mut abort_req_task_rx) = mpsc::channel::<()>(1);
153        let req_task = tokio::spawn(async move {
154            tokio::select! {
155                _ = abort_req_task_rx.recv() => {
156                    panic!("killed");
157                }
158                res = process_outgoing_requests( id, channel, stdin_rx, resize_rx, kill_rx) => {
159                    res
160                }
161            }
162        });
163
164        let status = Arc::new(RwLock::new(None));
165        let status_2 = Arc::clone(&status);
166        let wait_task = tokio::spawn(async move {
167            let res = match tokio::try_join!(req_task, res_task) {
168                Ok((_, res)) => res.map(RemoteStatus::from),
169                Err(x) => Err(io::Error::new(io::ErrorKind::Interrupted, x)),
170            };
171            status_2.write().await.replace(res);
172        });
173
174        Ok(RemoteProcess {
175            id,
176            origin_id,
177            abort_req_task_tx,
178            abort_res_task_tx,
179            stdin: Some(RemoteStdin(stdin_tx)),
180            stdout: Some(RemoteStdout(stdout_rx)),
181            stderr: Some(RemoteStderr(stderr_rx)),
182            resizer: RemoteProcessResizer(resize_tx),
183            killer: RemoteProcessKiller(kill_tx),
184            wait_task,
185            status,
186        })
187    }
188}
189
/// Represents a process on a remote machine
#[derive(Debug)]
pub struct RemoteProcess {
    /// Id of the process
    id: ProcessId,

    /// Id used to map back to mailbox
    origin_id: String,

    /// Sender used to signal the request-forwarding task to stop
    abort_req_task_tx: mpsc::Sender<()>,

    /// Sender used to signal the response-handling task to stop
    abort_res_task_tx: mpsc::Sender<()>,

    /// Sender for stdin
    pub stdin: Option<RemoteStdin>,

    /// Receiver for stdout
    pub stdout: Option<RemoteStdout>,

    /// Receiver for stderr
    pub stderr: Option<RemoteStderr>,

    /// Sender for resize events
    resizer: RemoteProcessResizer,

    /// Sender for kill events
    killer: RemoteProcessKiller,

    /// Task that waits for the process to complete
    wait_task: JoinHandle<()>,

    /// Handles the success and exit code for a completed process
    status: Arc<RwLock<Option<StatusResult>>>,
}
226
227impl RemoteProcess {
228    /// Returns the id of the running process
229    pub fn id(&self) -> ProcessId {
230        self.id
231    }
232
233    /// Returns the id of the request that spawned this process
234    pub fn origin_id(&self) -> &str {
235        &self.origin_id
236    }
237
238    /// Checks if the process has completed, returning the exit status if it has, without
239    /// consuming the process itself. Note that this does not include join errors that can
240    /// occur when aborting and instead converts any error to a status of false. To acquire
241    /// the actual error, you must call `wait`
242    pub async fn status(&self) -> Option<RemoteStatus> {
243        self.status.read().await.as_ref().map(|x| match x {
244            Ok(status) => *status,
245            Err(_) => RemoteStatus {
246                success: false,
247                code: None,
248            },
249        })
250    }
251
252    /// Waits for the process to terminate, returning the success status and an optional exit code
253    pub async fn wait(self) -> io::Result<RemoteStatus> {
254        // Wait for the process to complete before we try to get the status
255        let _ = self.wait_task.await;
256
257        // NOTE: If we haven't received an exit status, this lines up with the UnexpectedEof error
258        self.status
259            .write()
260            .await
261            .take()
262            .unwrap_or_else(|| Err(errors::unexpected_eof()))
263    }
264
265    /// Waits for the process to terminate, returning the success status, an optional exit code,
266    /// and any remaining stdout and stderr (if still attached to the process)
267    pub async fn output(mut self) -> io::Result<RemoteOutput> {
268        let maybe_stdout = self.stdout.take();
269        let maybe_stderr = self.stderr.take();
270
271        let status = self.wait().await?;
272
273        let mut stdout = Vec::new();
274        if let Some(mut reader) = maybe_stdout {
275            while let Ok(data) = reader.read().await {
276                stdout.extend(&data);
277            }
278        }
279
280        let mut stderr = Vec::new();
281        if let Some(mut reader) = maybe_stderr {
282            while let Ok(data) = reader.read().await {
283                stderr.extend(&data);
284            }
285        }
286
287        Ok(RemoteOutput {
288            success: status.success,
289            code: status.code,
290            stdout,
291            stderr,
292        })
293    }
294
295    /// Resizes the pty of the remote process if it is attached to one
296    pub async fn resize(&self, size: PtySize) -> io::Result<()> {
297        self.resizer.resize(size).await
298    }
299
300    /// Clones a copy of the remote process pty resizer
301    pub fn clone_resizer(&self) -> RemoteProcessResizer {
302        self.resizer.clone()
303    }
304
305    /// Submits a kill request for the running process
306    pub async fn kill(&mut self) -> io::Result<()> {
307        self.killer.kill().await
308    }
309
310    /// Clones a copy of the remote process killer
311    pub fn clone_killer(&self) -> RemoteProcessKiller {
312        self.killer.clone()
313    }
314
315    /// Aborts the process by forcing its response task to shutdown, which means that a call
316    /// to `wait` will return an error. Note that this does **not** send a kill request, so if
317    /// you want to be nice you should send the request before aborting.
318    pub fn abort(&self) {
319        let _ = self.abort_req_task_tx.try_send(());
320        let _ = self.abort_res_task_tx.try_send(());
321    }
322}
323
/// A handle to the channel to resize the pty of a remote process
#[derive(Clone, Debug)]
pub struct RemoteProcessResizer(mpsc::Sender<PtySize>);
327
328impl RemoteProcessResizer {
329    /// Resizes the pty of the remote process if it is attached to one
330    pub async fn resize(&self, size: PtySize) -> io::Result<()> {
331        self.0
332            .send(size)
333            .await
334            .map_err(|_| errors::dead_channel())?;
335        Ok(())
336    }
337}
338
/// A handle to the channel to kill a remote process; each message sent through it is a
/// request to kill the process
#[derive(Clone, Debug)]
pub struct RemoteProcessKiller(mpsc::Sender<()>);
342
343impl RemoteProcessKiller {
344    /// Submits a kill request for the running process
345    pub async fn kill(&mut self) -> io::Result<()> {
346        self.0.send(()).await.map_err(|_| errors::dead_channel())?;
347        Ok(())
348    }
349}
350
/// A handle to a remote process' standard input (stdin), wrapping the sending half of the
/// channel that carries stdin bytes
#[derive(Clone, Debug)]
pub struct RemoteStdin(mpsc::Sender<Vec<u8>>);
354
355impl RemoteStdin {
356    /// Creates a disconnected remote stdin
357    pub fn disconnected() -> Self {
358        Self(mpsc::channel(1).0)
359    }
360
361    /// Tries to write to the stdin of the remote process, returning ok if immediately
362    /// successful, `WouldBlock` if would need to wait to send data, and `BrokenPipe`
363    /// if stdin has been closed
364    pub fn try_write(&mut self, data: impl Into<Vec<u8>>) -> io::Result<()> {
365        match self.0.try_send(data.into()) {
366            Ok(data) => Ok(data),
367            Err(TrySendError::Full(_)) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
368            Err(TrySendError::Closed(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
369        }
370    }
371
372    /// Same as `try_write`, but with a string
373    pub fn try_write_str(&mut self, data: impl Into<String>) -> io::Result<()> {
374        self.try_write(data.into().into_bytes())
375    }
376
377    /// Writes data to the stdin of a specific remote process
378    pub async fn write(&mut self, data: impl Into<Vec<u8>>) -> io::Result<()> {
379        self.0
380            .send(data.into())
381            .await
382            .map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x))
383    }
384
385    /// Same as `write`, but with a string
386    pub async fn write_str(&mut self, data: impl Into<String>) -> io::Result<()> {
387        self.write(data.into().into_bytes()).await
388    }
389
390    /// Checks if stdin has been closed
391    pub fn is_closed(&self) -> bool {
392        self.0.is_closed()
393    }
394}
395
/// A handle to a remote process' standard output (stdout), wrapping the receiving half of
/// the channel that carries stdout bytes
#[derive(Debug)]
pub struct RemoteStdout(mpsc::Receiver<Vec<u8>>);
399
400impl RemoteStdout {
401    /// Tries to receive latest stdout for a remote process, yielding `None`
402    /// if no stdout is available, and `BrokenPipe` if stdout has been closed
403    pub fn try_read(&mut self) -> io::Result<Option<Vec<u8>>> {
404        match self.0.try_recv() {
405            Ok(data) => Ok(Some(data)),
406            Err(TryRecvError::Empty) => Ok(None),
407            Err(TryRecvError::Disconnected) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
408        }
409    }
410
411    /// Same as `try_read`, but returns a string
412    pub fn try_read_string(&mut self) -> io::Result<Option<String>> {
413        self.try_read().and_then(|x| match x {
414            Some(data) => String::from_utf8(data)
415                .map(Some)
416                .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x)),
417            None => Ok(None),
418        })
419    }
420
421    /// Retrieves the latest stdout for a specific remote process, and `BrokenPipe` if stdout has
422    /// been closed
423    pub async fn read(&mut self) -> io::Result<Vec<u8>> {
424        self.0
425            .recv()
426            .await
427            .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
428    }
429
430    /// Same as `read`, but returns a string
431    pub async fn read_string(&mut self) -> io::Result<String> {
432        self.read().await.and_then(|data| {
433            String::from_utf8(data).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))
434        })
435    }
436}
437
/// A handle to a remote process' standard error (stderr), wrapping the receiving half of
/// the channel that carries stderr bytes
#[derive(Debug)]
pub struct RemoteStderr(mpsc::Receiver<Vec<u8>>);
441
442impl RemoteStderr {
443    /// Tries to receive latest stderr for a remote process, yielding `None`
444    /// if no stderr is available, and `BrokenPipe` if stderr has been closed
445    pub fn try_read(&mut self) -> io::Result<Option<Vec<u8>>> {
446        match self.0.try_recv() {
447            Ok(data) => Ok(Some(data)),
448            Err(TryRecvError::Empty) => Ok(None),
449            Err(TryRecvError::Disconnected) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
450        }
451    }
452
453    /// Same as `try_read`, but returns a string
454    pub fn try_read_string(&mut self) -> io::Result<Option<String>> {
455        self.try_read().and_then(|x| match x {
456            Some(data) => String::from_utf8(data)
457                .map(Some)
458                .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x)),
459            None => Ok(None),
460        })
461    }
462
463    /// Retrieves the latest stderr for a specific remote process, and `BrokenPipe` if stderr has
464    /// been closed
465    pub async fn read(&mut self) -> io::Result<Vec<u8>> {
466        self.0
467            .recv()
468            .await
469            .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
470    }
471
472    /// Same as `read`, but returns a string
473    pub async fn read_string(&mut self) -> io::Result<String> {
474        self.read().await.and_then(|data| {
475            String::from_utf8(data).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))
476        })
477    }
478}
479
480/// Helper function that loops, processing outgoing stdin requests to a remote process as well as
481/// supporting a kill request to terminate the remote process
482async fn process_outgoing_requests(
483    id: ProcessId,
484    mut channel: DistantChannel,
485    mut stdin_rx: mpsc::Receiver<Vec<u8>>,
486    mut resize_rx: mpsc::Receiver<PtySize>,
487    mut kill_rx: mpsc::Receiver<()>,
488) -> io::Result<()> {
489    let result = loop {
490        tokio::select! {
491            data = stdin_rx.recv() => {
492                match data {
493                    Some(data) => channel.fire(
494                        Request::new(
495                            protocol::Msg::Single(protocol::Request::ProcStdin { id, data })
496                        )
497                    ).await?,
498                    None => break Err(errors::dead_channel()),
499                }
500            }
501            size = resize_rx.recv() => {
502                match size {
503                    Some(size) => channel.fire(
504                        Request::new(
505                            protocol::Msg::Single(protocol::Request::ProcResizePty { id, size })
506                        )
507                    ).await?,
508                    None => break Err(errors::dead_channel()),
509                }
510            }
511            msg = kill_rx.recv() => {
512                if msg.is_some() {
513                    channel.fire(Request::new(
514                        protocol::Msg::Single(protocol::Request::ProcKill { id })
515                    )).await?;
516                    break Ok(());
517                } else {
518                    break Err(errors::dead_channel());
519                }
520            }
521        }
522    };
523
524    trace!("Process outgoing channel closed");
525    result
526}
527
528/// Helper function that loops, processing incoming stdout & stderr requests from a remote process
529async fn process_incoming_responses(
530    proc_id: ProcessId,
531    mut mailbox: Mailbox<Response<protocol::Msg<protocol::Response>>>,
532    stdout_tx: mpsc::Sender<Vec<u8>>,
533    stderr_tx: mpsc::Sender<Vec<u8>>,
534    kill_tx: mpsc::Sender<()>,
535) -> io::Result<(bool, Option<i32>)> {
536    while let Some(res) = mailbox.next().await {
537        let payload = res.payload.into_vec();
538
539        // Check if any of the payload data is the termination
540        let exit_status = payload.iter().find_map(|data| match data {
541            protocol::Response::ProcDone { id, success, code } if *id == proc_id => {
542                Some((*success, *code))
543            }
544            _ => None,
545        });
546
547        // Next, check for stdout/stderr and send them along our channels
548        // TODO: What should we do about unexpected data? For now, just ignore
549        for data in payload {
550            match data {
551                protocol::Response::ProcStdout { id, data } if id == proc_id => {
552                    let _ = stdout_tx.send(data).await;
553                }
554                protocol::Response::ProcStderr { id, data } if id == proc_id => {
555                    let _ = stderr_tx.send(data).await;
556                }
557                _ => {}
558            }
559        }
560
561        // If we got a termination, then exit accordingly
562        if let Some((success, code)) = exit_status {
563            // Flag that the other task should conclude
564            let _ = kill_tx.try_send(());
565
566            return Ok((success, code));
567        }
568    }
569
570    // Flag that the other task should conclude
571    let _ = kill_tx.try_send(());
572
573    trace!("Process incoming channel closed");
574    Err(errors::unexpected_eof())
575}
576
/// Constructors for the I/O errors shared by the process helpers
mod errors {
    use std::io::{Error, ErrorKind};

    /// `BrokenPipe` error reported when one of the internal channels has been closed
    pub fn dead_channel() -> Error {
        let message = "Channel is dead";
        Error::new(ErrorKind::BrokenPipe, message)
    }

    /// Bare `UnexpectedEof` error reported when no exit status was obtained
    pub fn unexpected_eof() -> Error {
        ErrorKind::UnexpectedEof.into()
    }
}
588
589#[cfg(test)]
590mod tests {
591    use std::time::Duration;
592
593    use distant_net::common::{FramedTransport, InmemoryTransport, Response};
594    use distant_net::Client;
595    use test_log::test;
596
597    use super::*;
598    use crate::client::DistantClient;
599    use crate::protocol::{Error, ErrorKind};
600
601    fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
602        let (t1, t2) = FramedTransport::pair(100);
603        (t1, Client::spawn_inmemory(t2, Default::default()))
604    }
605
606    #[test(tokio::test)]
607    async fn spawn_should_return_invalid_data_if_received_batch_response() {
608        let (mut transport, session) = make_session();
609
610        // Create a task for process spawning as we need to handle the request and a response
611        // in a separate async block
612        let spawn_task = tokio::spawn(async move {
613            RemoteCommand::new()
614                .spawn(session.clone_channel(), String::from("cmd arg"))
615                .await
616        });
617
618        // Wait until we get the request from the session
619        let req: Request<protocol::Msg<protocol::Request>> =
620            transport.read_frame_as().await.unwrap().unwrap();
621
622        // Send back a response through the session
623        transport
624            .write_frame_for(&Response::new(
625                req.id,
626                protocol::Msg::Batch(vec![protocol::Response::ProcSpawned { id: 1 }]),
627            ))
628            .await
629            .unwrap();
630
631        // Get the spawn result and verify
632        match spawn_task.await.unwrap() {
633            Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
634            x => panic!("Unexpected result: {:?}", x),
635        }
636    }
637
638    #[test(tokio::test)]
639    async fn spawn_should_return_invalid_data_if_did_not_get_a_indicator_that_process_started() {
640        let (mut transport, session) = make_session();
641
642        // Create a task for process spawning as we need to handle the request and a response
643        // in a separate async block
644        let spawn_task = tokio::spawn(async move {
645            RemoteCommand::new()
646                .spawn(session.clone_channel(), String::from("cmd arg"))
647                .await
648        });
649
650        // Wait until we get the request from the session
651        let req: Request<protocol::Msg<protocol::Request>> =
652            transport.read_frame_as().await.unwrap().unwrap();
653
654        // Send back a response through the session
655        transport
656            .write_frame_for(&Response::new(
657                req.id,
658                protocol::Msg::Single(protocol::Response::Error(Error {
659                    kind: ErrorKind::BrokenPipe,
660                    description: String::from("some error"),
661                })),
662            ))
663            .await
664            .unwrap();
665
666        // Get the spawn result and verify
667        match spawn_task.await.unwrap() {
668            Err(x) if x.kind() == io::ErrorKind::BrokenPipe => {}
669            x => panic!("Unexpected result: {:?}", x),
670        }
671    }
672
673    #[test(tokio::test)]
674    async fn kill_should_return_error_if_internal_tasks_already_completed() {
675        let (mut transport, session) = make_session();
676
677        // Create a task for process spawning as we need to handle the request and a response
678        // in a separate async block
679        let spawn_task = tokio::spawn(async move {
680            RemoteCommand::new()
681                .spawn(session.clone_channel(), String::from("cmd arg"))
682                .await
683        });
684
685        // Wait until we get the request from the session
686        let req: Request<protocol::Msg<protocol::Request>> =
687            transport.read_frame_as().await.unwrap().unwrap();
688
689        // Send back a response through the session
690        let id = 12345;
691        transport
692            .write_frame_for(&Response::new(
693                req.id,
694                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
695            ))
696            .await
697            .unwrap();
698
699        // Receive the process and then abort it to make kill fail
700        let mut proc = spawn_task.await.unwrap().unwrap();
701        proc.abort();
702
703        // Ensure that the other tasks are aborted before continuing
704        tokio::task::yield_now().await;
705
706        match proc.kill().await {
707            Err(x) if x.kind() == io::ErrorKind::BrokenPipe => {}
708            x => panic!("Unexpected result: {:?}", x),
709        }
710    }
711
712    #[test(tokio::test)]
713    async fn kill_should_send_proc_kill_request_and_then_cause_stdin_forwarding_to_close() {
714        let (mut transport, session) = make_session();
715
716        // Create a task for process spawning as we need to handle the request and a response
717        // in a separate async block
718        let spawn_task = tokio::spawn(async move {
719            RemoteCommand::new()
720                .spawn(session.clone_channel(), String::from("cmd arg"))
721                .await
722        });
723
724        // Wait until we get the request from the session
725        let req: Request<protocol::Msg<protocol::Request>> =
726            transport.read_frame_as().await.unwrap().unwrap();
727
728        // Send back a response through the session
729        let id = 12345;
730        transport
731            .write_frame_for(&Response::new(
732                req.id,
733                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
734            ))
735            .await
736            .unwrap();
737
738        // Receive the process and then kill it
739        let mut proc = spawn_task.await.unwrap().unwrap();
740        assert!(proc.kill().await.is_ok(), "Failed to send kill request");
741
742        // Verify the kill request was sent
743        let req: Request<protocol::Msg<protocol::Request>> =
744            transport.read_frame_as().await.unwrap().unwrap();
745        match req.payload {
746            protocol::Msg::Single(protocol::Request::ProcKill { id: proc_id }) => {
747                assert_eq!(proc_id, id)
748            }
749            x => panic!("Unexpected request: {:?}", x),
750        }
751
752        // Verify we can no longer write to stdin anymore
753        assert_eq!(
754            proc.stdin
755                .as_mut()
756                .unwrap()
757                .write("some stdin")
758                .await
759                .unwrap_err()
760                .kind(),
761            io::ErrorKind::BrokenPipe
762        );
763    }
764
765    #[test(tokio::test)]
766    async fn stdin_should_be_forwarded_from_receiver_field() {
767        let (mut transport, session) = make_session();
768
769        // Create a task for process spawning as we need to handle the request and a response
770        // in a separate async block
771        let spawn_task = tokio::spawn(async move {
772            RemoteCommand::new()
773                .spawn(session.clone_channel(), String::from("cmd arg"))
774                .await
775        });
776
777        // Wait until we get the request from the session
778        let req: Request<protocol::Msg<protocol::Request>> =
779            transport.read_frame_as().await.unwrap().unwrap();
780
781        // Send back a response through the session
782        let id = 12345;
783        transport
784            .write_frame_for(&Response::new(
785                req.id,
786                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
787            ))
788            .await
789            .unwrap();
790
791        // Receive the process and then send stdin
792        let mut proc = spawn_task.await.unwrap().unwrap();
793        proc.stdin
794            .as_mut()
795            .unwrap()
796            .write("some input")
797            .await
798            .unwrap();
799
800        // Verify that a request is made through the session
801        let req: Request<protocol::Msg<protocol::Request>> =
802            transport.read_frame_as().await.unwrap().unwrap();
803        match req.payload {
804            protocol::Msg::Single(protocol::Request::ProcStdin { id, data }) => {
805                assert_eq!(id, 12345);
806                assert_eq!(data, b"some input");
807            }
808            x => panic!("Unexpected request: {:?}", x),
809        }
810    }
811
812    #[test(tokio::test)]
813    async fn stdout_should_be_forwarded_to_receiver_field() {
814        let (mut transport, session) = make_session();
815
816        // Create a task for process spawning as we need to handle the request and a response
817        // in a separate async block
818        let spawn_task = tokio::spawn(async move {
819            RemoteCommand::new()
820                .spawn(session.clone_channel(), String::from("cmd arg"))
821                .await
822        });
823
824        // Wait until we get the request from the session
825        let req: Request<protocol::Msg<protocol::Request>> =
826            transport.read_frame_as().await.unwrap().unwrap();
827
828        // Send back a response through the session
829        let id = 12345;
830        transport
831            .write_frame_for(&Response::new(
832                req.id.clone(),
833                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
834            ))
835            .await
836            .unwrap();
837
838        // Receive the process and then read stdout
839        let mut proc = spawn_task.await.unwrap().unwrap();
840
841        transport
842            .write_frame_for(&Response::new(
843                req.id,
844                protocol::Msg::Single(protocol::Response::ProcStdout {
845                    id,
846                    data: b"some out".to_vec(),
847                }),
848            ))
849            .await
850            .unwrap();
851
852        let out = proc.stdout.as_mut().unwrap().read().await.unwrap();
853        assert_eq!(out, b"some out");
854    }
855
856    #[test(tokio::test)]
857    async fn stderr_should_be_forwarded_to_receiver_field() {
858        let (mut transport, session) = make_session();
859
860        // Create a task for process spawning as we need to handle the request and a response
861        // in a separate async block
862        let spawn_task = tokio::spawn(async move {
863            RemoteCommand::new()
864                .spawn(session.clone_channel(), String::from("cmd arg"))
865                .await
866        });
867
868        // Wait until we get the request from the session
869        let req: Request<protocol::Msg<protocol::Request>> =
870            transport.read_frame_as().await.unwrap().unwrap();
871
872        // Send back a response through the session
873        let id = 12345;
874        transport
875            .write_frame_for(&Response::new(
876                req.id.clone(),
877                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
878            ))
879            .await
880            .unwrap();
881
882        // Receive the process and then read stderr
883        let mut proc = spawn_task.await.unwrap().unwrap();
884
885        transport
886            .write_frame_for(&Response::new(
887                req.id,
888                protocol::Msg::Single(protocol::Response::ProcStderr {
889                    id,
890                    data: b"some err".to_vec(),
891                }),
892            ))
893            .await
894            .unwrap();
895
896        let out = proc.stderr.as_mut().unwrap().read().await.unwrap();
897        assert_eq!(out, b"some err");
898    }
899
900    #[test(tokio::test)]
901    async fn status_should_return_none_if_not_done() {
902        let (mut transport, session) = make_session();
903
904        // Create a task for process spawning as we need to handle the request and a response
905        // in a separate async block
906        let spawn_task = tokio::spawn(async move {
907            RemoteCommand::new()
908                .spawn(session.clone_channel(), String::from("cmd arg"))
909                .await
910        });
911
912        // Wait until we get the request from the session
913        let req: Request<protocol::Msg<protocol::Request>> =
914            transport.read_frame_as().await.unwrap().unwrap();
915
916        // Send back a response through the session
917        let id = 12345;
918        transport
919            .write_frame_for(&Response::new(
920                req.id,
921                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
922            ))
923            .await
924            .unwrap();
925
926        // Receive the process and then check its status
927        let proc = spawn_task.await.unwrap().unwrap();
928
929        let result = proc.status().await;
930        assert_eq!(result, None, "Unexpectedly got proc status: {:?}", result);
931    }
932
933    #[test(tokio::test)]
934    async fn status_should_return_false_for_success_if_internal_tasks_fail() {
935        let (mut transport, session) = make_session();
936
937        // Create a task for process spawning as we need to handle the request and a response
938        // in a separate async block
939        let spawn_task = tokio::spawn(async move {
940            RemoteCommand::new()
941                .spawn(session.clone_channel(), String::from("cmd arg"))
942                .await
943        });
944
945        // Wait until we get the request from the session
946        let req: Request<protocol::Msg<protocol::Request>> =
947            transport.read_frame_as().await.unwrap().unwrap();
948
949        // Send back a response through the session
950        let id = 12345;
951        transport
952            .write_frame_for(&Response::new(
953                req.id,
954                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
955            ))
956            .await
957            .unwrap();
958
959        // Receive the process and then abort it to make internal tasks fail
960        let proc = spawn_task.await.unwrap().unwrap();
961        proc.abort();
962
963        // Wait a bit to ensure the other tasks abort
964        tokio::time::sleep(Duration::from_millis(100)).await;
965
966        // Peek at the status to confirm the result
967        let result = proc.status().await;
968        match result {
969            Some(status) => {
970                assert!(!status.success, "Status unexpectedly reported success");
971                assert!(
972                    status.code.is_none(),
973                    "Status unexpectedly reported exit code"
974                );
975            }
976            x => panic!("Unexpected result: {:?}", x),
977        }
978    }
979
980    #[test(tokio::test)]
981    async fn status_should_return_process_status_when_done() {
982        let (mut transport, session) = make_session();
983
984        // Create a task for process spawning as we need to handle the request and a response
985        // in a separate async block
986        let spawn_task = tokio::spawn(async move {
987            RemoteCommand::new()
988                .spawn(session.clone_channel(), String::from("cmd arg"))
989                .await
990        });
991
992        // Wait until we get the request from the session
993        let req: Request<protocol::Msg<protocol::Request>> =
994            transport.read_frame_as().await.unwrap().unwrap();
995
996        // Send back a response through the session
997        let id = 12345;
998        transport
999            .write_frame_for(&Response::new(
1000                req.id.clone(),
1001                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
1002            ))
1003            .await
1004            .unwrap();
1005
1006        // Receive the process and then spawn a task for it to complete
1007        let proc = spawn_task.await.unwrap().unwrap();
1008
1009        // Send a process completion response to pass along exit status and conclude wait
1010        transport
1011            .write_frame_for(&Response::new(
1012                req.id,
1013                protocol::Msg::Single(protocol::Response::ProcDone {
1014                    id,
1015                    success: true,
1016                    code: Some(123),
1017                }),
1018            ))
1019            .await
1020            .unwrap();
1021
1022        // Wait a bit to ensure the status gets transmitted
1023        tokio::time::sleep(Duration::from_millis(100)).await;
1024
1025        // Finally, verify that we complete and get the expected results
1026        assert_eq!(
1027            proc.status().await,
1028            Some(RemoteStatus {
1029                success: true,
1030                code: Some(123)
1031            })
1032        );
1033    }
1034
1035    #[test(tokio::test)]
1036    async fn wait_should_return_error_if_internal_tasks_fail() {
1037        let (mut transport, session) = make_session();
1038
1039        // Create a task for process spawning as we need to handle the request and a response
1040        // in a separate async block
1041        let spawn_task = tokio::spawn(async move {
1042            RemoteCommand::new()
1043                .spawn(session.clone_channel(), String::from("cmd arg"))
1044                .await
1045        });
1046
1047        // Wait until we get the request from the session
1048        let req: Request<protocol::Msg<protocol::Request>> =
1049            transport.read_frame_as().await.unwrap().unwrap();
1050
1051        // Send back a response through the session
1052        let id = 12345;
1053        transport
1054            .write_frame_for(&Response::new(
1055                req.id,
1056                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
1057            ))
1058            .await
1059            .unwrap();
1060
1061        // Receive the process and then abort it to make internal tasks fail
1062        let proc = spawn_task.await.unwrap().unwrap();
1063        proc.abort();
1064
1065        match proc.wait().await {
1066            Err(x) if x.kind() == io::ErrorKind::Interrupted => {}
1067            x => panic!("Unexpected result: {:?}", x),
1068        }
1069    }
1070
1071    #[test(tokio::test)]
1072    async fn wait_should_return_error_if_connection_terminates_before_receiving_done_response() {
1073        let (mut transport, session) = make_session();
1074
1075        // Create a task for process spawning as we need to handle the request and a response
1076        // in a separate async block
1077        let spawn_task = tokio::spawn(async move {
1078            RemoteCommand::new()
1079                .spawn(session.clone_channel(), String::from("cmd arg"))
1080                .await
1081        });
1082
1083        // Wait until we get the request from the session
1084        let req: Request<protocol::Msg<protocol::Request>> =
1085            transport.read_frame_as().await.unwrap().unwrap();
1086
1087        // Send back a response through the session
1088        let id = 12345;
1089        transport
1090            .write_frame_for(&Response::new(
1091                req.id,
1092                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
1093            ))
1094            .await
1095            .unwrap();
1096
1097        // Receive the process and then terminate session connection
1098        let proc = spawn_task.await.unwrap().unwrap();
1099
1100        // Ensure that the spawned task gets a chance to wait on stdout/stderr
1101        tokio::task::yield_now().await;
1102
1103        drop(transport);
1104
1105        // Ensure that the other tasks are cancelled before continuing
1106        tokio::task::yield_now().await;
1107
1108        match proc.wait().await {
1109            Err(x) if x.kind() == io::ErrorKind::UnexpectedEof => {}
1110            x => panic!("Unexpected result: {:?}", x),
1111        }
1112    }
1113
1114    #[test(tokio::test)]
1115    async fn receiving_done_response_should_result_in_wait_returning_exit_information() {
1116        let (mut transport, session) = make_session();
1117
1118        // Create a task for process spawning as we need to handle the request and a response
1119        // in a separate async block
1120        let spawn_task = tokio::spawn(async move {
1121            RemoteCommand::new()
1122                .spawn(session.clone_channel(), String::from("cmd arg"))
1123                .await
1124        });
1125
1126        // Wait until we get the request from the session
1127        let req: Request<protocol::Msg<protocol::Request>> =
1128            transport.read_frame_as().await.unwrap().unwrap();
1129
1130        // Send back a response through the session
1131        let id = 12345;
1132        transport
1133            .write_frame_for(&Response::new(
1134                req.id.clone(),
1135                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
1136            ))
1137            .await
1138            .unwrap();
1139
1140        // Receive the process and then spawn a task for it to complete
1141        let proc = spawn_task.await.unwrap().unwrap();
1142        let proc_wait_task = tokio::spawn(proc.wait());
1143
1144        // Send a process completion response to pass along exit status and conclude wait
1145        transport
1146            .write_frame_for(&Response::new(
1147                req.id,
1148                protocol::Msg::Single(protocol::Response::ProcDone {
1149                    id,
1150                    success: false,
1151                    code: Some(123),
1152                }),
1153            ))
1154            .await
1155            .unwrap();
1156
1157        // Finally, verify that we complete and get the expected results
1158        assert_eq!(
1159            proc_wait_task.await.unwrap().unwrap(),
1160            RemoteStatus {
1161                success: false,
1162                code: Some(123)
1163            }
1164        );
1165    }
1166
1167    #[test(tokio::test)]
1168    async fn receiving_done_response_should_result_in_output_returning_exit_information() {
1169        let (mut transport, session) = make_session();
1170
1171        // Create a task for process spawning as we need to handle the request and a response
1172        // in a separate async block
1173        let spawn_task = tokio::spawn(async move {
1174            RemoteCommand::new()
1175                .spawn(session.clone_channel(), String::from("cmd arg"))
1176                .await
1177        });
1178
1179        // Wait until we get the request from the session
1180        let req: Request<protocol::Msg<protocol::Request>> =
1181            transport.read_frame_as().await.unwrap().unwrap();
1182
1183        // Send back a response through the session
1184        let id = 12345;
1185        transport
1186            .write_frame_for(&Response::new(
1187                req.id.clone(),
1188                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
1189            ))
1190            .await
1191            .unwrap();
1192
1193        // Receive the process and then spawn a task for it to complete
1194        let proc = spawn_task.await.unwrap().unwrap();
1195        let proc_output_task = tokio::spawn(proc.output());
1196
1197        // Send some stdout
1198        transport
1199            .write_frame_for(&Response::new(
1200                req.id.clone(),
1201                protocol::Msg::Single(protocol::Response::ProcStdout {
1202                    id,
1203                    data: b"some out".to_vec(),
1204                }),
1205            ))
1206            .await
1207            .unwrap();
1208
1209        // Send some stderr
1210        transport
1211            .write_frame_for(&Response::new(
1212                req.id.clone(),
1213                protocol::Msg::Single(protocol::Response::ProcStderr {
1214                    id,
1215                    data: b"some err".to_vec(),
1216                }),
1217            ))
1218            .await
1219            .unwrap();
1220
1221        // Send a process completion response to pass along exit status and conclude wait
1222        transport
1223            .write_frame_for(&Response::new(
1224                req.id,
1225                protocol::Msg::Single(protocol::Response::ProcDone {
1226                    id,
1227                    success: false,
1228                    code: Some(123),
1229                }),
1230            ))
1231            .await
1232            .unwrap();
1233
1234        // Finally, verify that we complete and get the expected results
1235        assert_eq!(
1236            proc_output_task.await.unwrap().unwrap(),
1237            RemoteOutput {
1238                success: false,
1239                code: Some(123),
1240                stdout: b"some out".to_vec(),
1241                stderr: b"some err".to_vec(),
1242            }
1243        );
1244    }
1245}