1use std::collections::HashMap;
2use std::io::{ErrorKind, Write};
3use std::os::unix::prelude::CommandExt;
4use std::path::Path;
5use std::{process, thread, time};
6
7use nix::sys::signal;
8use nix::unistd::Pid;
9use oci_spec::runtime::{Hook, State as OciState};
10
11use crate::container::{State, StateConversionError};
12use crate::utils;
13
14#[derive(Debug, thiserror::Error)]
15pub enum HookError {
16 #[error("failed to execute hook command")]
17 CommandExecute(#[source] std::io::Error),
18 #[error("failed to encode container state")]
19 EncodeContainerState(#[source] serde_json::Error),
20 #[error("hook command exited with non-zero exit code: {0}")]
21 NonZeroExitCode(i32),
22 #[error("hook command was killed by a signal")]
23 Killed,
24 #[error("failed to execute hook command due to a timeout")]
25 Timeout,
26 #[error("container state is required to run hook")]
27 MissingContainerState,
28 #[error("failed to write container state to stdin")]
29 WriteContainerState(#[source] std::io::Error),
30 #[error("failed to convert state to OCI format")]
31 StateConversion(#[from] StateConversionError),
32}
33
34type Result<T> = std::result::Result<T, HookError>;
35
36pub fn run_hooks(
37 hooks: Option<&Vec<Hook>>,
38 state: Option<&State>,
39 cwd: Option<&Path>,
41 pid: Option<Pid>,
42) -> Result<()> {
43 let base_state = state.ok_or(HookError::MissingContainerState)?;
44
45 let mut oci_state = OciState::try_from(base_state)?;
49
50 if let Some(override_pid) = pid {
54 oci_state.set_pid(Some(override_pid.as_raw()));
55 }
56
57 if let Some(hooks) = hooks {
58 for hook in hooks {
59 let mut hook_command = process::Command::new(hook.path());
60
61 if let Some(cwd) = cwd {
62 hook_command.current_dir(cwd);
63 }
64
65 if let Some((arg0, args)) = hook.args().as_ref().and_then(|a| a.split_first()) {
72 tracing::debug!("run_hooks arg0: {:?}, args: {:?}", arg0, args);
73 hook_command.arg0(arg0).args(args)
74 } else {
75 hook_command.arg0(hook.path().display().to_string())
76 };
77
78 let envs: HashMap<String, String> = if let Some(env) = hook.env() {
79 utils::parse_env(env)
80 } else {
81 HashMap::new()
82 };
83 tracing::debug!("run_hooks envs: {:?}", envs);
84
85 let mut hook_process = hook_command
86 .env_clear()
87 .envs(envs)
88 .stdin(process::Stdio::piped())
89 .stdout(std::process::Stdio::null())
90 .stderr(process::Stdio::inherit())
91 .spawn()
92 .map_err(HookError::CommandExecute)?;
93 let hook_process_pid = Pid::from_raw(hook_process.id() as i32);
94 if let Some(stdin) = &mut hook_process.stdin {
97 let encoded_state =
106 serde_json::to_string(&oci_state).map_err(HookError::EncodeContainerState)?;
107 if let Err(e) = stdin.write_all(encoded_state.as_bytes()) {
108 if e.kind() != ErrorKind::BrokenPipe {
109 let _ = signal::kill(hook_process_pid, signal::Signal::SIGKILL);
112 return Err(HookError::WriteContainerState(e));
113 }
114 }
115 }
116
117 let res = if let Some(timeout_sec) = hook.timeout() {
118 let (s, r) = std::sync::mpsc::channel();
129 thread::spawn(move || {
130 let res = hook_process.wait();
131 let _ = s.send(res);
132 });
133 match r.recv_timeout(time::Duration::from_secs(timeout_sec as u64)) {
134 Ok(res) => res,
135 Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
136 let _ = signal::kill(hook_process_pid, signal::Signal::SIGKILL);
139 return Err(HookError::Timeout);
140 }
141 Err(_) => {
142 unreachable!();
143 }
144 }
145 } else {
146 hook_process.wait()
147 };
148
149 match res {
150 Ok(exit_status) => match exit_status.code() {
151 Some(0) => Ok(()),
152 Some(exit_code) => Err(HookError::NonZeroExitCode(exit_code)),
153 None => Err(HookError::Killed),
154 },
155 Err(e) => Err(HookError::CommandExecute(e)),
156 }?;
157 }
158 }
159
160 Ok(())
161}
162
#[cfg(test)]
mod test {
    use std::{env, fs};

    use anyhow::{Context, Result, bail};
    use oci_spec::runtime::HookBuilder;
    use serial_test::serial;

    use super::*;
    use crate::container::Container;

    /// Returns true if `program` exists in one of the directories listed in `PATH`.
    fn is_command_in_path(program: &str) -> bool {
        env::var_os("PATH").is_some_and(|path| {
            env::split_paths(&path).any(|dir| fs::metadata(dir.join(program)).is_ok())
        })
    }

    #[test]
    #[serial]
    fn test_run_hook() -> Result<()> {
        {
            let default_container: Container = Default::default();
            run_hooks(None, Some(&default_container.state), None, None)
                .context("Failed simple test")?;
        }

        {
            assert!(is_command_in_path("true"), "The true was not found.");
            let default_container: Container = Default::default();

            let hook = HookBuilder::default().path("true").build()?;
            let hooks = Some(vec![hook]);
            run_hooks(hooks.as_ref(), Some(&default_container.state), None, None)
                .context("Failed true")?;
        }

        {
            // The hook is executed through bash, which in turn runs printenv.
            assert!(is_command_in_path("bash"), "The bash was not found.");
            assert!(
                is_command_in_path("printenv"),
                "The printenv was not found."
            );
            let default_container: Container = Default::default();
            let hook = HookBuilder::default()
                .path("bash")
                .args(vec![
                    String::from("bash"),
                    String::from("-c"),
                    String::from("printenv key > /dev/null"),
                ])
                .env(vec![String::from("key=value")])
                .build()?;
            let hooks = Some(vec![hook]);
            run_hooks(hooks.as_ref(), Some(&default_container.state), None, None)
                .context("Failed printenv test")?;
        }

        {
            // `pwd` is a bash builtin here; bash is the binary actually executed.
            assert!(is_command_in_path("bash"), "The bash was not found.");

            let tmp = tempfile::tempdir()?;

            let default_container: Container = Default::default();
            let hook = HookBuilder::default()
                .path("bash")
                .args(vec![
                    String::from("bash"),
                    String::from("-c"),
                    format!("test $(pwd) = {:?}", tmp.path()),
                ])
                .build()?;
            let hooks = Some(vec![hook]);
            run_hooks(
                hooks.as_ref(),
                Some(&default_container.state),
                Some(tmp.path()),
                None,
            )
            .context("Failed pwd test")?;
        }

        {
            assert!(is_command_in_path("bash"), "The bash was not found.");
            let default_container: Container = Default::default();
            let expected_pid = Pid::from_raw(1000);

            let hook = HookBuilder::default()
                .path("bash")
                .args(vec![
                    String::from("bash"),
                    String::from("-c"),
                    format!("cat | grep '\"pid\":{}'", expected_pid),
                ])
                .build()?;
            let hooks = Some(vec![hook]);
            run_hooks(
                hooks.as_ref(),
                Some(&default_container.state),
                None,
                Some(expected_pid),
            )
            .context("Failed pid test")?;
        }

        Ok(())
    }

    #[test]
    #[serial]
    fn test_run_hook_timeout() -> Result<()> {
        let default_container: Container = Default::default();
        let hook = HookBuilder::default()
            .path("tail")
            .args(vec![
                String::from("tail"),
                String::from("-f"),
                String::from("/dev/null"),
            ])
            .timeout(1)
            .build()?;
        let hooks = Some(vec![hook]);
        match run_hooks(hooks.as_ref(), Some(&default_container.state), None, None) {
            Ok(_) => {
                bail!(
                    "The test expects the hook to error out with timeout. Should not execute cleanly"
                );
            }
            Err(HookError::Timeout) => {}
            Err(err) => {
                bail!(
                    "The test expects the hook to error out with timeout. Got error: {}",
                    err
                );
            }
        };

        Ok(())
    }

    #[test]
    #[serial]
    fn test_run_hook_non_zero_exit() -> Result<()> {
        assert!(is_command_in_path("false"), "The false was not found.");
        let default_container: Container = Default::default();
        let hook = HookBuilder::default().path("false").build()?;
        let hooks = Some(vec![hook]);
        match run_hooks(hooks.as_ref(), Some(&default_container.state), None, None) {
            Err(HookError::NonZeroExitCode(1)) => Ok(()),
            Ok(_) => bail!("The test expects the hook to fail with exit code 1"),
            Err(err) => bail!(
                "The test expects the hook to fail with exit code 1. Got error: {}",
                err
            ),
        }
    }

    #[test]
    fn test_run_hook_missing_state() {
        let result = run_hooks(None, None, None, None);
        assert!(matches!(result, Err(HookError::MissingContainerState)));
    }
}