use std::fmt::Debug;
use std::future::Future;
use std::io;
use std::sync::{Arc, Weak};
use std::time::Duration;

use futures::channel::mpsc;
use futures::future::{AbortHandle, Abortable, Aborted, BoxFuture};
use futures::stream::BoxStream;
use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt};
use log::{debug, trace};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_util::codec::{FramedRead, FramedWrite};

use crate::codec::{ClientCodec, InputChunk, OutputChunk};
use crate::execution::{child_channel, send_to_io, ChildInput, ChildOutput, Command, ExitCode};
use crate::Config;

pub struct Child {
    ///
    /// A stream of outputs from the remote child process.
    ///
    /// As with the stdio handles of `std::process::Child`, you should `take` this stream to
    /// avoid partial moves:
    ///   let output_stream = child.output_stream.take().unwrap();
    ///
    pub output_stream: Option<BoxStream<'static, ChildOutput>>,
    ///
    /// A future for the exit code of the remote process.
    ///
    exit_code: Option<BoxFuture<'static, Result<ExitCode, io::Error>>>,
    ///
    /// A callable to shut down the write half of the connection upon request.
    ///
    shutdown: Option<BoxFuture<'static, ()>>,
    ///
    /// A handle to cancel the background task managing the connection when the Child is dropped.
    ///
    abort_handle: AbortHandle,
}

impl Child {
    ///
    /// Closes the write half of the connection to the server, which will trigger cancellation
    /// in well-behaved servers. Because the read half of the connection will still be open, a
    /// well-behaved server/Nail will render teardown information before exiting.
    ///
    /// Dropping the Child instance also triggers cancellation, but closes both the read and write
    /// halves of the connection at the same time (which does not allow for orderly shutdown of
    /// the server).
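    ///
    /// A typical orderly-cancellation sequence (a sketch, assuming a `child: Child` obtained
    /// from `handle_connection`):
    ///   child.shutdown().await;
    ///   let exit_code = child.wait().await?;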
    ///
    pub async fn shutdown(&mut self) {
        if let Some(shutdown) = self.shutdown.take() {
            shutdown.await;
        }
    }

    ///
    /// Waits for the Child to exit, and returns its ExitCode.
    ///
    pub async fn wait(mut self) -> Result<ExitCode, io::Error> {
        // This method may only be called once, so it's safe to take the exit code unconditionally.
        self.exit_code.take().unwrap().await
    }
}

impl Drop for Child {
    fn drop(&mut self) {
        self.abort_handle.abort();
    }
}

///
/// Implements the client side of a single connection on the given socket.
///
/// The `input_stream` is lazily instantiated because servers accept input only optionally, and
/// clients should not begin reading stdin from their callers unless the server will accept it.
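///
/// A minimal call-site sketch (illustrative only; `child_channel` comes from
/// `crate::execution`), where the future is awaited only once the server requests stdin:
///   let (stdin_write, stdin_read) = child_channel::<ChildInput>();
///   let child = handle_connection(config, socket, cmd, async move { stdin_read }).await?;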
///
pub async fn handle_connection(
    config: Config,
    socket: TcpStream,
    cmd: Command,
    open_input_stream: impl Future<Output = mpsc::Receiver<ChildInput>> + Send + 'static,
) -> Result<Child, io::Error> {
    socket.set_nodelay(true)?;
    let (read, write) = socket.into_split();
    execute(config, read, write, cmd, open_input_stream).await
}

///
/// Converts a Command into the initialization chunks for the nailgun protocol. Note: order matters.
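///
/// For example, a Command with command "ls", args ["-l"], env [("TERM", "dumb")], and
/// working_dir "/tmp" (illustrative values) becomes:
///   Argument("-l"), Environment{TERM=dumb}, WorkingDir("/tmp"), Command("ls")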
///
fn command_as_chunks(cmd: Command) -> Vec<InputChunk> {
    let Command {
        command,
        args,
        env,
        working_dir,
    } = cmd;

    let mut chunks = Vec::new();
    chunks.extend(args.into_iter().map(InputChunk::Argument));
    chunks.extend(
        env.into_iter()
            .map(|(key, val)| InputChunk::Environment { key, val }),
    );
    chunks.push(InputChunk::WorkingDir(working_dir));
    chunks.push(InputChunk::Command(command));
    chunks
}

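///
/// Sends the initialization chunks for the given Command, optionally starts a heartbeat task,
/// and spawns a background task to handle stdio and the exit code, returning a `Child` handle.
///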
async fn execute<R, W>(
    config: Config,
    read: R,
    write: W,
    cmd: Command,
    open_cli_read: impl Future<Output = mpsc::Receiver<ChildInput>> + Send + 'static,
) -> Result<Child, io::Error>
where
    R: AsyncRead + Debug + Unpin + Send + 'static,
    W: AsyncWrite + Debug + Unpin + Send + 'static,
{
    let server_read = FramedRead::new(read, ClientCodec);
    let mut server_write = FramedWrite::new(write, ClientCodec);

    // Send all of the init chunks.
    let mut init_chunks = futures::stream::iter(command_as_chunks(cmd).into_iter().map(Ok))
        .inspect(|i| debug!("nails client sending initialization chunk {:?}", i));
    server_write
        .send_all(&mut init_chunks)
        .map_err(|e| {
            io_err(&format!(
                "Could not send initialization chunks to the server. Got: {}",
                e
            ))
        })
        .await?;
    let server_write = Arc::new(Mutex::new(Some(server_write)));

    // Calls to shutdown will drop the write half of the socket.
    let shutdown = {
        let server_write = server_write.clone();
        async move {
            // Take and drop the write half of the connection (if it has not already been dropped).
            let _ = server_write.lock().await.take();
        }
    };

    // If configured, spawn a task to send heartbeats.
    if let Some(heartbeat_frequency) = config.heartbeat_frequency {
        let _join = tokio::spawn(heartbeat_sender(
            Arc::downgrade(&server_write),
            heartbeat_frequency,
        ));
    }

    // Then handle stdio until we receive an ExitCode, or until the Child is dropped.
    let (cli_write, cli_read) = child_channel::<ChildOutput>();
    let (abort_handle, exit_code) = {
        // We spawn the execution of the process onto a background task to ensure that it starts
        // running even if a consumer of the Child instance chooses to completely consume the stdio
        // output_stream before interacting with the exit code (rather than `join`ing them).
        //
        // We wrap in Abortable so that dropping the Child instance cancels the background task.
        let (abort_handle, abort_registration) = AbortHandle::new_pair();
        let stdio_task = handle_stdio(server_read, server_write.clone(), cli_write, open_cli_read);
        let exit_code_result = tokio::spawn(Abortable::new(stdio_task, abort_registration));
        let exit_code = async move {
            match exit_code_result.await.unwrap() {
                Ok(res) => res,
                Err(Aborted) => Err(io::Error::new(
                    io::ErrorKind::ConnectionAborted,
                    "The connection was canceled because the Child was dropped",
                )),
            }
        }
        .boxed();
        (abort_handle, exit_code)
    };
    Ok(Child {
        output_stream: Some(cli_read.boxed()),
        exit_code: Some(exit_code),
        shutdown: Some(shutdown.boxed()),
        abort_handle,
    })
}

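///
/// Drives the read half of the connection: forwards Stdout/Stderr chunks to the client, lazily
/// spawns a stdin-forwarding task the first time the server sends `StartReadingStdin`, and
/// completes with the `ExitCode` carried by the server's `Exit` chunk.
///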
async fn handle_stdio<S: ServerSink>(
    mut server_read: impl Stream<Item = Result<OutputChunk, io::Error>> + Unpin,
    server_write: Arc<Mutex<Option<S>>>,
    mut cli_write: mpsc::Sender<ChildOutput>,
    open_cli_read: impl Future<Output = mpsc::Receiver<ChildInput>>,
) -> Result<ExitCode, io::Error> {
    let mut stdin_inputs = Some((server_write, open_cli_read));
    while let Some(output_chunk) = server_read.next().await {
        match output_chunk? {
            OutputChunk::Stderr(bytes) => {
                trace!("nails client got {} bytes of stderr.", bytes.len());
                cli_write
                    .send(ChildOutput::Stderr(bytes))
                    .map_err(send_to_io)
                    .await?;
            }
            OutputChunk::Stdout(bytes) => {
                trace!("nails client got {} bytes of stdout.", bytes.len());
                cli_write
                    .send(ChildOutput::Stdout(bytes))
                    .map_err(send_to_io)
                    .await?;
            }
            OutputChunk::StartReadingStdin => {
                // We spawn a task to send stdin after receiving `StartReadingStdin`, but only
                // once: some servers (ours included, optionally) have a `noisy_stdin` behaviour
                // where they ask for more input after every Stdin chunk.
                if let Some((server_write, open_cli_read)) = stdin_inputs.take() {
                    debug!("nails client will start sending stdin.");
                    let _join = tokio::spawn(stdin_sender(server_write, open_cli_read.await));
                }
            }
            OutputChunk::Exit(code) => {
                trace!("nails client got exit code: {}", code);
                return Ok(ExitCode(code));
            }
        }
    }
    Err(io_err(
        "The connection was closed before the server returned an exit code.",
    ))
}

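///
/// Forwards `ChildInput::Stdin` chunks from the client to the server, sending `StdinEOF` once
/// the client's input stream is exhausted. Exits early if the write half of the connection has
/// already been dropped (e.g. via `Child::shutdown`).
///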
async fn stdin_sender<S: ServerSink>(
    server_write: Arc<Mutex<Option<S>>>,
    mut cli_read: mpsc::Receiver<ChildInput>,
) -> Result<(), io::Error> {
    while let Some(input_chunk) = cli_read.next().await {
        if let Some(ref mut server_write) = *server_write.lock().await {
            match input_chunk {
                ChildInput::Stdin(bytes) => {
                    trace!("nails client sending {} bytes of stdin.", bytes.len());
                    server_write.send(InputChunk::Stdin(bytes)).await?;
                }
            }
        } else {
            break;
        }
    }

    if let Some(ref mut server_write) = *server_write.lock().await {
        server_write.send(InputChunk::StdinEOF).await?;
    }
    Ok(())
}

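///
/// Sends `Heartbeat` chunks at (at least) the configured frequency for as long as the
/// connection is alive. Holds only a `Weak` handle to the write half, so the task exits once
/// the connection (or its write half) has been dropped.
///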
async fn heartbeat_sender<S: ServerSink>(
    server_write: Weak<Mutex<Option<S>>>,
    heartbeat_frequency: Duration,
) -> Result<(), io::Error> {
    loop {
        // Wait for a fraction of the desired interval (which, from the client's perspective, is
        // a minimum frequency: sending more often is fine).
        tokio::time::sleep(heartbeat_frequency / 4).await;

        // Then, if the connection might still be alive...
        if let Some(server_write) = server_write.upgrade() {
            let mut server_write = server_write.lock().await;
            if let Some(ref mut server_write) = *server_write {
                server_write.send(InputChunk::Heartbeat).await?;
            } else {
                break Ok(());
            }
        } else {
            break Ok(());
        }
    }
}

fn io_err(e: &str) -> io::Error {
    io::Error::new(io::ErrorKind::Other, e)
}

///
/// TODO: See https://users.rust-lang.org/t/why-cant-type-aliases-be-used-for-traits/10002/4
///
#[cfg_attr(rustfmt, rustfmt_skip)]
trait ServerSink: Debug + Sink<InputChunk, Error = io::Error> + Unpin + Send + 'static {}
#[cfg_attr(rustfmt, rustfmt_skip)]
impl<T> ServerSink for T where T: Debug + Sink<InputChunk, Error = io::Error> + Unpin + Send + 'static {}
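
// The blanket impl above makes `ServerSink` act as a trait alias: any type meeting the bounds
// implements it automatically, so functions in this module can simply take `S: ServerSink`.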