1use std::path::PathBuf;
2use std::sync::Arc;
3
4use distant_net::client::Mailbox;
5use distant_net::common::{Request, Response};
6use log::*;
7use tokio::io;
8use tokio::sync::mpsc::error::{TryRecvError, TrySendError};
9use tokio::sync::{mpsc, RwLock};
10use tokio::task::JoinHandle;
11
12use crate::client::DistantChannel;
13use crate::constants::CLIENT_PIPE_CAPACITY;
14use crate::protocol::{self, Cmd, Environment, ProcessId, PtySize};
15
/// Everything collected from a remote process by [`RemoteProcess::output`]: its exit
/// information plus the complete stdout and stderr byte streams.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RemoteOutput {
    /// Whether the remote process reported success.
    pub success: bool,
    /// Exit code reported for the process, if any.
    pub code: Option<i32>,
    /// All bytes received on the process's stdout, concatenated in arrival order.
    pub stdout: Vec<u8>,
    /// All bytes received on the process's stderr, concatenated in arrival order.
    pub stderr: Vec<u8>,
}
23
/// Exit information reported for a remote process.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct RemoteStatus {
    /// Whether the remote process reported success.
    pub success: bool,
    /// Exit code reported for the process, if any.
    pub code: Option<i32>,
}

impl From<(bool, Option<i32>)> for RemoteStatus {
    /// Builds a status from a `(success, code)` pair, the shape produced when a
    /// `ProcDone` response is received.
    fn from(pair: (bool, Option<i32>)) -> Self {
        let (success, code) = pair;
        RemoteStatus { success, code }
    }
}

/// Final outcome stored by the background wait task spawned in [`RemoteCommand::spawn`].
type StatusResult = io::Result<RemoteStatus>;
37
38pub struct RemoteCommand {
41 pty: Option<PtySize>,
42 environment: Environment,
43 current_dir: Option<PathBuf>,
44}
45
46impl Default for RemoteCommand {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl RemoteCommand {
53 pub fn new() -> Self {
55 Self {
56 pty: None,
57 environment: Environment::new(),
58 current_dir: None,
59 }
60 }
61
62 pub fn pty(&mut self, pty: Option<PtySize>) -> &mut Self {
64 self.pty = pty;
65 self
66 }
67
68 pub fn environment(&mut self, environment: Environment) -> &mut Self {
70 self.environment = environment;
71 self
72 }
73
74 pub fn current_dir(&mut self, current_dir: Option<PathBuf>) -> &mut Self {
76 self.current_dir = current_dir;
77 self
78 }
79
80 pub async fn spawn(
82 &mut self,
83 mut channel: DistantChannel,
84 cmd: impl Into<String>,
85 ) -> io::Result<RemoteProcess> {
86 let cmd = cmd.into();
87
88 let mut mailbox = channel
90 .mail(Request::new(protocol::Msg::Single(
91 protocol::Request::ProcSpawn {
92 cmd: Cmd::from(cmd),
93 pty: self.pty,
94 environment: self.environment.clone(),
95 current_dir: self.current_dir.clone(),
96 },
97 )))
98 .await?;
99
100 let (id, origin_id) = match mailbox.next().await {
102 Some(res) => {
103 let origin_id = res.origin_id;
104 match res.payload {
105 protocol::Msg::Single(protocol::Response::ProcSpawned { id }) => {
106 (id, origin_id)
107 }
108 protocol::Msg::Single(protocol::Response::Error(x)) => return Err(x.into()),
109 protocol::Msg::Single(x) => {
110 return Err(io::Error::new(
111 io::ErrorKind::InvalidData,
112 format!("Got response type of {}", x.as_ref()),
113 ))
114 }
115 protocol::Msg::Batch(_) => {
116 return Err(io::Error::new(
117 io::ErrorKind::InvalidData,
118 "Got batch instead of single response",
119 ));
120 }
121 }
122 }
123 None => return Err(io::Error::from(io::ErrorKind::ConnectionAborted)),
124 };
125
126 let (stdin_tx, stdin_rx) = mpsc::channel(CLIENT_PIPE_CAPACITY);
128 let (stdout_tx, stdout_rx) = mpsc::channel(CLIENT_PIPE_CAPACITY);
129 let (stderr_tx, stderr_rx) = mpsc::channel(CLIENT_PIPE_CAPACITY);
130 let (resize_tx, resize_rx) = mpsc::channel(1);
131
132 let (kill_tx, kill_rx) = mpsc::channel(1);
135 let kill_tx_2 = kill_tx.clone();
136
137 let (abort_res_task_tx, mut abort_res_task_rx) = mpsc::channel::<()>(1);
140 let res_task = tokio::spawn(async move {
141 tokio::select! {
142 _ = abort_res_task_rx.recv() => {
143 panic!("killed");
144 }
145 res = process_incoming_responses(id, mailbox, stdout_tx, stderr_tx, kill_tx_2) => {
146 res
147 }
148 }
149 });
150
151 let (abort_req_task_tx, mut abort_req_task_rx) = mpsc::channel::<()>(1);
153 let req_task = tokio::spawn(async move {
154 tokio::select! {
155 _ = abort_req_task_rx.recv() => {
156 panic!("killed");
157 }
158 res = process_outgoing_requests( id, channel, stdin_rx, resize_rx, kill_rx) => {
159 res
160 }
161 }
162 });
163
164 let status = Arc::new(RwLock::new(None));
165 let status_2 = Arc::clone(&status);
166 let wait_task = tokio::spawn(async move {
167 let res = match tokio::try_join!(req_task, res_task) {
168 Ok((_, res)) => res.map(RemoteStatus::from),
169 Err(x) => Err(io::Error::new(io::ErrorKind::Interrupted, x)),
170 };
171 status_2.write().await.replace(res);
172 });
173
174 Ok(RemoteProcess {
175 id,
176 origin_id,
177 abort_req_task_tx,
178 abort_res_task_tx,
179 stdin: Some(RemoteStdin(stdin_tx)),
180 stdout: Some(RemoteStdout(stdout_rx)),
181 stderr: Some(RemoteStderr(stderr_rx)),
182 resizer: RemoteProcessResizer(resize_tx),
183 killer: RemoteProcessKiller(kill_tx),
184 wait_task,
185 status,
186 })
187 }
188}
189
190#[derive(Debug)]
192pub struct RemoteProcess {
193 id: ProcessId,
195
196 origin_id: String,
198
199 abort_req_task_tx: mpsc::Sender<()>,
201
202 abort_res_task_tx: mpsc::Sender<()>,
204
205 pub stdin: Option<RemoteStdin>,
207
208 pub stdout: Option<RemoteStdout>,
210
211 pub stderr: Option<RemoteStderr>,
213
214 resizer: RemoteProcessResizer,
216
217 killer: RemoteProcessKiller,
219
220 wait_task: JoinHandle<()>,
222
223 status: Arc<RwLock<Option<StatusResult>>>,
225}
226
227impl RemoteProcess {
228 pub fn id(&self) -> ProcessId {
230 self.id
231 }
232
233 pub fn origin_id(&self) -> &str {
235 &self.origin_id
236 }
237
238 pub async fn status(&self) -> Option<RemoteStatus> {
243 self.status.read().await.as_ref().map(|x| match x {
244 Ok(status) => *status,
245 Err(_) => RemoteStatus {
246 success: false,
247 code: None,
248 },
249 })
250 }
251
252 pub async fn wait(self) -> io::Result<RemoteStatus> {
254 let _ = self.wait_task.await;
256
257 self.status
259 .write()
260 .await
261 .take()
262 .unwrap_or_else(|| Err(errors::unexpected_eof()))
263 }
264
265 pub async fn output(mut self) -> io::Result<RemoteOutput> {
268 let maybe_stdout = self.stdout.take();
269 let maybe_stderr = self.stderr.take();
270
271 let status = self.wait().await?;
272
273 let mut stdout = Vec::new();
274 if let Some(mut reader) = maybe_stdout {
275 while let Ok(data) = reader.read().await {
276 stdout.extend(&data);
277 }
278 }
279
280 let mut stderr = Vec::new();
281 if let Some(mut reader) = maybe_stderr {
282 while let Ok(data) = reader.read().await {
283 stderr.extend(&data);
284 }
285 }
286
287 Ok(RemoteOutput {
288 success: status.success,
289 code: status.code,
290 stdout,
291 stderr,
292 })
293 }
294
295 pub async fn resize(&self, size: PtySize) -> io::Result<()> {
297 self.resizer.resize(size).await
298 }
299
300 pub fn clone_resizer(&self) -> RemoteProcessResizer {
302 self.resizer.clone()
303 }
304
305 pub async fn kill(&mut self) -> io::Result<()> {
307 self.killer.kill().await
308 }
309
310 pub fn clone_killer(&self) -> RemoteProcessKiller {
312 self.killer.clone()
313 }
314
315 pub fn abort(&self) {
319 let _ = self.abort_req_task_tx.try_send(());
320 let _ = self.abort_res_task_tx.try_send(());
321 }
322}
323
324#[derive(Clone, Debug)]
326pub struct RemoteProcessResizer(mpsc::Sender<PtySize>);
327
328impl RemoteProcessResizer {
329 pub async fn resize(&self, size: PtySize) -> io::Result<()> {
331 self.0
332 .send(size)
333 .await
334 .map_err(|_| errors::dead_channel())?;
335 Ok(())
336 }
337}
338
339#[derive(Clone, Debug)]
341pub struct RemoteProcessKiller(mpsc::Sender<()>);
342
343impl RemoteProcessKiller {
344 pub async fn kill(&mut self) -> io::Result<()> {
346 self.0.send(()).await.map_err(|_| errors::dead_channel())?;
347 Ok(())
348 }
349}
350
351#[derive(Clone, Debug)]
353pub struct RemoteStdin(mpsc::Sender<Vec<u8>>);
354
355impl RemoteStdin {
356 pub fn disconnected() -> Self {
358 Self(mpsc::channel(1).0)
359 }
360
361 pub fn try_write(&mut self, data: impl Into<Vec<u8>>) -> io::Result<()> {
365 match self.0.try_send(data.into()) {
366 Ok(data) => Ok(data),
367 Err(TrySendError::Full(_)) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
368 Err(TrySendError::Closed(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
369 }
370 }
371
372 pub fn try_write_str(&mut self, data: impl Into<String>) -> io::Result<()> {
374 self.try_write(data.into().into_bytes())
375 }
376
377 pub async fn write(&mut self, data: impl Into<Vec<u8>>) -> io::Result<()> {
379 self.0
380 .send(data.into())
381 .await
382 .map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x))
383 }
384
385 pub async fn write_str(&mut self, data: impl Into<String>) -> io::Result<()> {
387 self.write(data.into().into_bytes()).await
388 }
389
390 pub fn is_closed(&self) -> bool {
392 self.0.is_closed()
393 }
394}
395
396#[derive(Debug)]
398pub struct RemoteStdout(mpsc::Receiver<Vec<u8>>);
399
400impl RemoteStdout {
401 pub fn try_read(&mut self) -> io::Result<Option<Vec<u8>>> {
404 match self.0.try_recv() {
405 Ok(data) => Ok(Some(data)),
406 Err(TryRecvError::Empty) => Ok(None),
407 Err(TryRecvError::Disconnected) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
408 }
409 }
410
411 pub fn try_read_string(&mut self) -> io::Result<Option<String>> {
413 self.try_read().and_then(|x| match x {
414 Some(data) => String::from_utf8(data)
415 .map(Some)
416 .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x)),
417 None => Ok(None),
418 })
419 }
420
421 pub async fn read(&mut self) -> io::Result<Vec<u8>> {
424 self.0
425 .recv()
426 .await
427 .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
428 }
429
430 pub async fn read_string(&mut self) -> io::Result<String> {
432 self.read().await.and_then(|data| {
433 String::from_utf8(data).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))
434 })
435 }
436}
437
438#[derive(Debug)]
440pub struct RemoteStderr(mpsc::Receiver<Vec<u8>>);
441
442impl RemoteStderr {
443 pub fn try_read(&mut self) -> io::Result<Option<Vec<u8>>> {
446 match self.0.try_recv() {
447 Ok(data) => Ok(Some(data)),
448 Err(TryRecvError::Empty) => Ok(None),
449 Err(TryRecvError::Disconnected) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
450 }
451 }
452
453 pub fn try_read_string(&mut self) -> io::Result<Option<String>> {
455 self.try_read().and_then(|x| match x {
456 Some(data) => String::from_utf8(data)
457 .map(Some)
458 .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x)),
459 None => Ok(None),
460 })
461 }
462
463 pub async fn read(&mut self) -> io::Result<Vec<u8>> {
466 self.0
467 .recv()
468 .await
469 .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
470 }
471
472 pub async fn read_string(&mut self) -> io::Result<String> {
474 self.read().await.and_then(|data| {
475 String::from_utf8(data).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))
476 })
477 }
478}
479
480async fn process_outgoing_requests(
483 id: ProcessId,
484 mut channel: DistantChannel,
485 mut stdin_rx: mpsc::Receiver<Vec<u8>>,
486 mut resize_rx: mpsc::Receiver<PtySize>,
487 mut kill_rx: mpsc::Receiver<()>,
488) -> io::Result<()> {
489 let result = loop {
490 tokio::select! {
491 data = stdin_rx.recv() => {
492 match data {
493 Some(data) => channel.fire(
494 Request::new(
495 protocol::Msg::Single(protocol::Request::ProcStdin { id, data })
496 )
497 ).await?,
498 None => break Err(errors::dead_channel()),
499 }
500 }
501 size = resize_rx.recv() => {
502 match size {
503 Some(size) => channel.fire(
504 Request::new(
505 protocol::Msg::Single(protocol::Request::ProcResizePty { id, size })
506 )
507 ).await?,
508 None => break Err(errors::dead_channel()),
509 }
510 }
511 msg = kill_rx.recv() => {
512 if msg.is_some() {
513 channel.fire(Request::new(
514 protocol::Msg::Single(protocol::Request::ProcKill { id })
515 )).await?;
516 break Ok(());
517 } else {
518 break Err(errors::dead_channel());
519 }
520 }
521 }
522 };
523
524 trace!("Process outgoing channel closed");
525 result
526}
527
528async fn process_incoming_responses(
530 proc_id: ProcessId,
531 mut mailbox: Mailbox<Response<protocol::Msg<protocol::Response>>>,
532 stdout_tx: mpsc::Sender<Vec<u8>>,
533 stderr_tx: mpsc::Sender<Vec<u8>>,
534 kill_tx: mpsc::Sender<()>,
535) -> io::Result<(bool, Option<i32>)> {
536 while let Some(res) = mailbox.next().await {
537 let payload = res.payload.into_vec();
538
539 let exit_status = payload.iter().find_map(|data| match data {
541 protocol::Response::ProcDone { id, success, code } if *id == proc_id => {
542 Some((*success, *code))
543 }
544 _ => None,
545 });
546
547 for data in payload {
550 match data {
551 protocol::Response::ProcStdout { id, data } if id == proc_id => {
552 let _ = stdout_tx.send(data).await;
553 }
554 protocol::Response::ProcStderr { id, data } if id == proc_id => {
555 let _ = stderr_tx.send(data).await;
556 }
557 _ => {}
558 }
559 }
560
561 if let Some((success, code)) = exit_status {
563 let _ = kill_tx.try_send(());
565
566 return Ok((success, code));
567 }
568 }
569
570 let _ = kill_tx.try_send(());
572
573 trace!("Process incoming channel closed");
574 Err(errors::unexpected_eof())
575}
576
/// Constructors for the errors shared by the process handles and background tasks.
mod errors {
    use std::io;

    /// `BrokenPipe` error used when an internal channel has been closed.
    pub fn dead_channel() -> io::Error {
        let kind = io::ErrorKind::BrokenPipe;
        io::Error::new(kind, "Channel is dead")
    }

    /// Message-less `UnexpectedEof` error used when responses ended before `ProcDone`.
    pub fn unexpected_eof() -> io::Error {
        io::ErrorKind::UnexpectedEof.into()
    }
}
588
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use distant_net::common::{FramedTransport, InmemoryTransport, Response};
    use distant_net::Client;
    // Makes `#[test(tokio::test)]` set up logging for each test.
    use test_log::test;

    use super::*;
    use crate::client::DistantClient;
    use crate::protocol::{Error, ErrorKind};

    /// Creates an in-memory client along with the server side of its transport.
    fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
        let (t1, t2) = FramedTransport::pair(100);
        (t1, Client::spawn_inmemory(t2, Default::default()))
    }

    /// Starts spawning `"cmd arg"` on a background task and reads the resulting spawn
    /// request from `transport`, without replying to it.
    async fn start_spawn(
        transport: &mut FramedTransport<InmemoryTransport>,
        session: DistantClient,
    ) -> (
        JoinHandle<io::Result<RemoteProcess>>,
        Request<protocol::Msg<protocol::Request>>,
    ) {
        let spawn_task = tokio::spawn(async move {
            RemoteCommand::new()
                .spawn(session.clone_channel(), String::from("cmd arg"))
                .await
        });

        let req: Request<protocol::Msg<protocol::Request>> =
            transport.read_frame_as().await.unwrap().unwrap();

        (spawn_task, req)
    }

    /// Spawns a process and completes the handshake by replying `ProcSpawned { id }`.
    ///
    /// Returns the process and the spawn request, whose id can be reused to route
    /// further responses to the process.
    async fn spawn_process(
        transport: &mut FramedTransport<InmemoryTransport>,
        session: DistantClient,
        id: ProcessId,
    ) -> (RemoteProcess, Request<protocol::Msg<protocol::Request>>) {
        let (spawn_task, req) = start_spawn(transport, session).await;

        transport
            .write_frame_for(&Response::new(
                req.id.clone(),
                protocol::Msg::Single(protocol::Response::ProcSpawned { id }),
            ))
            .await
            .unwrap();

        (spawn_task.await.unwrap().unwrap(), req)
    }

    #[test(tokio::test)]
    async fn spawn_should_return_invalid_data_if_received_batch_response() {
        let (mut transport, session) = make_session();
        let (spawn_task, req) = start_spawn(&mut transport, session).await;

        transport
            .write_frame_for(&Response::new(
                req.id,
                protocol::Msg::Batch(vec![protocol::Response::ProcSpawned { id: 1 }]),
            ))
            .await
            .unwrap();

        match spawn_task.await.unwrap() {
            Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test(tokio::test)]
    async fn spawn_should_return_invalid_data_if_did_not_get_a_indicator_that_process_started() {
        let (mut transport, session) = make_session();
        let (spawn_task, req) = start_spawn(&mut transport, session).await;

        transport
            .write_frame_for(&Response::new(
                req.id,
                protocol::Msg::Single(protocol::Response::Error(Error {
                    kind: ErrorKind::BrokenPipe,
                    description: String::from("some error"),
                })),
            ))
            .await
            .unwrap();

        match spawn_task.await.unwrap() {
            Err(x) if x.kind() == io::ErrorKind::BrokenPipe => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test(tokio::test)]
    async fn kill_should_return_error_if_internal_tasks_already_completed() {
        let (mut transport, session) = make_session();
        let (mut proc, _) = spawn_process(&mut transport, session, 12345).await;

        proc.abort();

        // Give the aborted tasks a chance to run and drop their receivers.
        tokio::task::yield_now().await;

        match proc.kill().await {
            Err(x) if x.kind() == io::ErrorKind::BrokenPipe => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test(tokio::test)]
    async fn kill_should_send_proc_kill_request_and_then_cause_stdin_forwarding_to_close() {
        let (mut transport, session) = make_session();
        let id = 12345;
        let (mut proc, _) = spawn_process(&mut transport, session, id).await;

        assert!(proc.kill().await.is_ok(), "Failed to send kill request");

        let req: Request<protocol::Msg<protocol::Request>> =
            transport.read_frame_as().await.unwrap().unwrap();
        match req.payload {
            protocol::Msg::Single(protocol::Request::ProcKill { id: proc_id }) => {
                assert_eq!(proc_id, id)
            }
            x => panic!("Unexpected request: {:?}", x),
        }

        assert_eq!(
            proc.stdin
                .as_mut()
                .unwrap()
                .write("some stdin")
                .await
                .unwrap_err()
                .kind(),
            io::ErrorKind::BrokenPipe
        );
    }

    #[test(tokio::test)]
    async fn stdin_should_be_forwarded_from_receiver_field() {
        let (mut transport, session) = make_session();
        let (mut proc, _) = spawn_process(&mut transport, session, 12345).await;

        proc.stdin
            .as_mut()
            .unwrap()
            .write("some input")
            .await
            .unwrap();

        let req: Request<protocol::Msg<protocol::Request>> =
            transport.read_frame_as().await.unwrap().unwrap();
        match req.payload {
            protocol::Msg::Single(protocol::Request::ProcStdin { id, data }) => {
                assert_eq!(id, 12345);
                assert_eq!(data, b"some input");
            }
            x => panic!("Unexpected request: {:?}", x),
        }
    }

    #[test(tokio::test)]
    async fn stdout_should_be_forwarded_to_receiver_field() {
        let (mut transport, session) = make_session();
        let id = 12345;
        let (mut proc, req) = spawn_process(&mut transport, session, id).await;

        transport
            .write_frame_for(&Response::new(
                req.id,
                protocol::Msg::Single(protocol::Response::ProcStdout {
                    id,
                    data: b"some out".to_vec(),
                }),
            ))
            .await
            .unwrap();

        let out = proc.stdout.as_mut().unwrap().read().await.unwrap();
        assert_eq!(out, b"some out");
    }

    #[test(tokio::test)]
    async fn stderr_should_be_forwarded_to_receiver_field() {
        let (mut transport, session) = make_session();
        let id = 12345;
        let (mut proc, req) = spawn_process(&mut transport, session, id).await;

        transport
            .write_frame_for(&Response::new(
                req.id,
                protocol::Msg::Single(protocol::Response::ProcStderr {
                    id,
                    data: b"some err".to_vec(),
                }),
            ))
            .await
            .unwrap();

        let out = proc.stderr.as_mut().unwrap().read().await.unwrap();
        assert_eq!(out, b"some err");
    }

    #[test(tokio::test)]
    async fn status_should_return_none_if_not_done() {
        let (mut transport, session) = make_session();
        let (proc, _) = spawn_process(&mut transport, session, 12345).await;

        let result = proc.status().await;
        assert_eq!(result, None, "Unexpectedly got proc status: {:?}", result);
    }

    #[test(tokio::test)]
    async fn status_should_return_false_for_success_if_internal_tasks_fail() {
        let (mut transport, session) = make_session();
        let (proc, _) = spawn_process(&mut transport, session, 12345).await;

        proc.abort();

        // Give the wait task time to record the failed outcome.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let result = proc.status().await;
        match result {
            Some(status) => {
                assert!(!status.success, "Status unexpectedly reported success");
                assert!(
                    status.code.is_none(),
                    "Status unexpectedly reported exit code"
                );
            }
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test(tokio::test)]
    async fn status_should_return_process_status_when_done() {
        let (mut transport, session) = make_session();
        let id = 12345;
        let (proc, req) = spawn_process(&mut transport, session, id).await;

        transport
            .write_frame_for(&Response::new(
                req.id,
                protocol::Msg::Single(protocol::Response::ProcDone {
                    id,
                    success: true,
                    code: Some(123),
                }),
            ))
            .await
            .unwrap();

        // Give the wait task time to record the outcome.
        tokio::time::sleep(Duration::from_millis(100)).await;

        assert_eq!(
            proc.status().await,
            Some(RemoteStatus {
                success: true,
                code: Some(123)
            })
        );
    }

    #[test(tokio::test)]
    async fn wait_should_return_error_if_internal_tasks_fail() {
        let (mut transport, session) = make_session();
        let (proc, _) = spawn_process(&mut transport, session, 12345).await;

        proc.abort();

        match proc.wait().await {
            Err(x) if x.kind() == io::ErrorKind::Interrupted => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test(tokio::test)]
    async fn wait_should_return_error_if_connection_terminates_before_receiving_done_response() {
        let (mut transport, session) = make_session();
        let (proc, _) = spawn_process(&mut transport, session, 12345).await;

        tokio::task::yield_now().await;

        // Closing the server side ends the process's mailbox without a `ProcDone`.
        drop(transport);

        tokio::task::yield_now().await;

        match proc.wait().await {
            Err(x) if x.kind() == io::ErrorKind::UnexpectedEof => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test(tokio::test)]
    async fn receiving_done_response_should_result_in_wait_returning_exit_information() {
        let (mut transport, session) = make_session();
        let id = 12345;
        let (proc, req) = spawn_process(&mut transport, session, id).await;

        let proc_wait_task = tokio::spawn(proc.wait());

        transport
            .write_frame_for(&Response::new(
                req.id,
                protocol::Msg::Single(protocol::Response::ProcDone {
                    id,
                    success: false,
                    code: Some(123),
                }),
            ))
            .await
            .unwrap();

        assert_eq!(
            proc_wait_task.await.unwrap().unwrap(),
            RemoteStatus {
                success: false,
                code: Some(123)
            }
        );
    }

    #[test(tokio::test)]
    async fn receiving_done_response_should_result_in_output_returning_exit_information() {
        let (mut transport, session) = make_session();
        let id = 12345;
        let (proc, req) = spawn_process(&mut transport, session, id).await;

        let proc_output_task = tokio::spawn(proc.output());

        transport
            .write_frame_for(&Response::new(
                req.id.clone(),
                protocol::Msg::Single(protocol::Response::ProcStdout {
                    id,
                    data: b"some out".to_vec(),
                }),
            ))
            .await
            .unwrap();

        transport
            .write_frame_for(&Response::new(
                req.id.clone(),
                protocol::Msg::Single(protocol::Response::ProcStderr {
                    id,
                    data: b"some err".to_vec(),
                }),
            ))
            .await
            .unwrap();

        transport
            .write_frame_for(&Response::new(
                req.id,
                protocol::Msg::Single(protocol::Response::ProcDone {
                    id,
                    success: false,
                    code: Some(123),
                }),
            ))
            .await
            .unwrap();

        assert_eq!(
            proc_output_task.await.unwrap().unwrap(),
            RemoteOutput {
                success: false,
                code: Some(123),
                stdout: b"some out".to_vec(),
                stderr: b"some err".to_vec(),
            }
        );
    }
}