Skip to main content

kratazone/
exec.rs

1use std::{collections::HashMap, process::Stdio};
2
3use crate::childwait::ChildWait;
4use anyhow::{anyhow, Result};
5use krata::idm::{
6    client::IdmClientStreamResponseHandle,
7    internal::{
8        exec_stream_request_update::Update, request::Request as RequestType,
9        ExecStreamResponseUpdate,
10    },
11    internal::{response::Response as ResponseType, Request, Response},
12};
13use libc::c_int;
14use pty_process::{Pty, Size};
15use tokio::process::Child;
16use tokio::{
17    io::{AsyncReadExt, AsyncWriteExt},
18    join,
19    process::Command,
20    select,
21};
22use tokio_util::sync::CancellationToken;
23
24pub struct ZoneExecTask {
25    pub wait: ChildWait,
26    pub handle: IdmClientStreamResponseHandle<Request>,
27}
28
29impl ZoneExecTask {
30    pub async fn run(&self) -> Result<()> {
31        let mut receiver = self.handle.take().await?;
32
33        let Some(ref request) = self.handle.initial.request else {
34            return Err(anyhow!("request was empty"));
35        };
36
37        let RequestType::ExecStream(update) = request else {
38            return Err(anyhow!("request was not an exec update"));
39        };
40
41        let Some(Update::Start(ref start)) = update.update else {
42            return Err(anyhow!("first request did not contain a start update"));
43        };
44
45        let mut cmd = start.command.clone();
46        if cmd.is_empty() {
47            return Err(anyhow!("command line was empty"));
48        }
49        let exe = cmd.remove(0);
50        let mut env = HashMap::new();
51        for entry in &start.environment {
52            env.insert(entry.key.clone(), entry.value.clone());
53        }
54
55        if !env.contains_key("PATH") {
56            env.insert(
57                "PATH".to_string(),
58                "/bin:/usr/bin:/usr/local/bin:/sbin:/usr/sbin".to_string(),
59            );
60        }
61
62        let dir = if start.working_directory.is_empty() {
63            "/".to_string()
64        } else {
65            start.working_directory.clone()
66        };
67
68        let mut wait_subscription = self.wait.subscribe().await?;
69
70        let code: c_int;
71        if start.tty {
72            let pty = Pty::new().map_err(|error| anyhow!("unable to allocate pty: {}", error))?;
73            let size = start
74                .terminal_size
75                .map(|x| Size::new(x.rows as u16, x.columns as u16))
76                .unwrap_or_else(|| Size::new(24, 80));
77            pty.resize(size)?;
78            let pts = pty
79                .pts()
80                .map_err(|error| anyhow!("unable to allocate pts: {}", error))?;
81            let child = std::panic::catch_unwind(move || {
82                let pts = pts;
83                pty_process::Command::new(exe)
84                    .args(cmd)
85                    .envs(env)
86                    .current_dir(dir)
87                    .spawn(&pts)
88            })
89            .map_err(|_| anyhow!("internal error"))
90            .map_err(|error| anyhow!("failed to spawn: {}", error))??;
91            let mut child = ChildDropGuard {
92                inner: child,
93                kill: true,
94            };
95            let pid = child
96                .inner
97                .id()
98                .ok_or_else(|| anyhow!("pid is not provided"))?;
99            let (mut read, mut write) = pty.into_split();
100            let pty_read_handle = self.handle.clone();
101            let pty_read_task = tokio::task::spawn(async move {
102                let mut stdout_buffer = vec![0u8; 8 * 1024];
103                loop {
104                    let Ok(size) = read.read(&mut stdout_buffer).await else {
105                        break;
106                    };
107                    if size > 0 {
108                        let response = Response {
109                            response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate {
110                                exited: false,
111                                exit_code: 0,
112                                error: String::new(),
113                                stdout: stdout_buffer[0..size].to_vec(),
114                                stderr: vec![],
115                            })),
116                        };
117                        let _ = pty_read_handle.respond(response).await;
118                    } else {
119                        break;
120                    }
121                }
122            });
123
124            let cancel = CancellationToken::new();
125            let stdin_cancel = cancel.clone();
126            let stdin_task = tokio::task::spawn(async move {
127                loop {
128                    let Some(request) = receiver.recv().await else {
129                        stdin_cancel.cancel();
130                        break;
131                    };
132
133                    let Some(RequestType::ExecStream(update)) = request.request else {
134                        continue;
135                    };
136
137                    match update.update {
138                        Some(Update::Stdin(update)) => {
139                            if !update.data.is_empty()
140                                && write.write_all(&update.data).await.is_err()
141                            {
142                                break;
143                            }
144
145                            if update.closed {
146                                break;
147                            }
148                        }
149                        Some(Update::TerminalResize(size)) => {
150                            let _ = write.resize(Size::new(size.rows as u16, size.columns as u16));
151                        }
152                        _ => {
153                            continue;
154                        }
155                    }
156                }
157            });
158
159            code = loop {
160                select! {
161                    result = wait_subscription.recv() => match result {
162                        Ok(event) => {
163                            if event.pid.as_raw() as u32 == pid {
164                                child.kill = false;
165                                break event.status;
166                            }
167                        }
168                        _ => {
169                            child.inner.start_kill()?;
170                            child.kill = false;
171                            break -1;
172                        }
173                    },
174                    _ = cancel.cancelled() => {
175                        child.inner.start_kill()?;
176                        child.kill = false;
177                        break -1;
178                    }
179                }
180            };
181
182            let _ = join!(pty_read_task);
183            stdin_task.abort();
184        } else {
185            let mut child = std::panic::catch_unwind(|| {
186                Command::new(exe)
187                    .args(cmd)
188                    .envs(env)
189                    .current_dir(dir)
190                    .stdin(Stdio::piped())
191                    .stdout(Stdio::piped())
192                    .stderr(Stdio::piped())
193                    .kill_on_drop(true)
194                    .spawn()
195            })
196            .map_err(|_| anyhow!("internal error"))
197            .map_err(|error| anyhow!("failed to spawn: {}", error))??;
198
199            let pid = child.id().ok_or_else(|| anyhow!("pid is not provided"))?;
200            let mut stdin = child
201                .stdin
202                .take()
203                .ok_or_else(|| anyhow!("stdin was missing"))?;
204            let mut stdout = child
205                .stdout
206                .take()
207                .ok_or_else(|| anyhow!("stdout was missing"))?;
208            let mut stderr = child
209                .stderr
210                .take()
211                .ok_or_else(|| anyhow!("stderr was missing"))?;
212
213            let stdout_handle = self.handle.clone();
214            let stdout_task = tokio::task::spawn(async move {
215                let mut stdout_buffer = vec![0u8; 8 * 1024];
216                loop {
217                    let Ok(size) = stdout.read(&mut stdout_buffer).await else {
218                        break;
219                    };
220                    if size > 0 {
221                        let response = Response {
222                            response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate {
223                                exited: false,
224                                exit_code: 0,
225                                error: String::new(),
226                                stdout: stdout_buffer[0..size].to_vec(),
227                                stderr: vec![],
228                            })),
229                        };
230                        let _ = stdout_handle.respond(response).await;
231                    } else {
232                        break;
233                    }
234                }
235            });
236
237            let stderr_handle = self.handle.clone();
238            let stderr_task = tokio::task::spawn(async move {
239                let mut stderr_buffer = vec![0u8; 8 * 1024];
240                loop {
241                    let Ok(size) = stderr.read(&mut stderr_buffer).await else {
242                        break;
243                    };
244                    if size > 0 {
245                        let response = Response {
246                            response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate {
247                                exited: false,
248                                exit_code: 0,
249                                error: String::new(),
250                                stdout: vec![],
251                                stderr: stderr_buffer[0..size].to_vec(),
252                            })),
253                        };
254                        let _ = stderr_handle.respond(response).await;
255                    } else {
256                        break;
257                    }
258                }
259            });
260
261            let cancel = CancellationToken::new();
262            let stdin_cancel = cancel.clone();
263            let stdin_task = tokio::task::spawn(async move {
264                loop {
265                    let Some(request) = receiver.recv().await else {
266                        stdin_cancel.cancel();
267                        break;
268                    };
269
270                    let Some(RequestType::ExecStream(update)) = request.request else {
271                        continue;
272                    };
273
274                    let Some(Update::Stdin(update)) = update.update else {
275                        continue;
276                    };
277
278                    if stdin.write_all(&update.data).await.is_err() {
279                        break;
280                    }
281                }
282            });
283
284            let data_task = tokio::task::spawn(async move {
285                let _ = join!(stdout_task, stderr_task);
286                stdin_task.abort();
287            });
288
289            code = loop {
290                select! {
291                    result = wait_subscription.recv() => match result {
292                        Ok(event) => {
293                            if event.pid.as_raw() as u32 == pid {
294                                break event.status;
295                            }
296                        }
297                        _ => {
298                            child.start_kill()?;
299                            break -1;
300                        }
301                    },
302                    _ = cancel.cancelled() => {
303                        child.start_kill()?;
304                        break -1;
305                    }
306                }
307            };
308            data_task.await?;
309        }
310        let response = Response {
311            response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate {
312                exited: true,
313                exit_code: code,
314                error: String::new(),
315                stdout: vec![],
316                stderr: vec![],
317            })),
318        };
319        self.handle.respond(response).await?;
320
321        Ok(())
322    }
323}
324
325struct ChildDropGuard {
326    pub inner: Child,
327    pub kill: bool,
328}
329
330impl Drop for ChildDropGuard {
331    fn drop(&mut self) {
332        if self.kill {
333            drop(self.inner.start_kill());
334        }
335    }
336}