Skip to main content

nucleus/container/
lifecycle.rs

1use crate::container::{ContainerState, ContainerStateManager};
2use crate::error::{NucleusError, Result};
3use nix::sys::signal::{kill, Signal};
4use nix::unistd::Pid;
5use nix::unistd::Uid;
6use std::thread;
7use std::time::Duration;
8use tracing::{info, warn};
9
10/// Container lifecycle operations (stop, kill, delete)
11pub struct ContainerLifecycle;
12
13impl ContainerLifecycle {
14    fn ensure_container_access(state: &ContainerState) -> Result<()> {
15        let current_uid = Uid::effective().as_raw();
16        if current_uid == 0 || current_uid == state.creator_uid {
17            return Ok(());
18        }
19
20        Err(NucleusError::PermissionDenied(format!(
21            "container {} owned by UID {}, caller is UID {}",
22            state.id, state.creator_uid, current_uid
23        )))
24    }
25
26    /// Stop a container gracefully: SIGTERM, wait for timeout, then SIGKILL
27    pub fn stop(state: &ContainerState, timeout_secs: u64) -> Result<()> {
28        Self::ensure_container_access(state)?;
29
30        if !state.is_running() {
31            info!("Container {} is already stopped", state.id);
32            return Ok(());
33        }
34
35        let pid = Pid::from_raw(state.pid as i32);
36
37        // Verify PID is still alive before sending signal.
38        // kill(pid, None) sends signal 0 — a no-op that returns ESRCH if the
39        // PID doesn't exist, protecting against PID recycling TOCTOU races.
40        if let Err(e) = kill(pid, None) {
41            if e == nix::errno::Errno::ESRCH {
42                info!("Process already exited");
43                return Ok(());
44            }
45        }
46
47        // Send SIGTERM
48        info!(
49            "Sending SIGTERM to container {} (PID {})",
50            state.id, state.pid
51        );
52        if let Err(e) = kill(pid, Signal::SIGTERM) {
53            if e == nix::errno::Errno::ESRCH {
54                info!("Process already exited");
55                return Ok(());
56            }
57            return Err(NucleusError::ExecError(format!(
58                "Failed to send SIGTERM: {}",
59                e
60            )));
61        }
62
63        // Wait for process to exit
64        let poll_interval = Duration::from_millis(100);
65        let deadline = Duration::from_secs(timeout_secs);
66        let mut elapsed = Duration::ZERO;
67
68        while elapsed < deadline {
69            if !state.is_running() {
70                info!("Container {} stopped gracefully", state.id);
71                return Ok(());
72            }
73            thread::sleep(poll_interval);
74            elapsed += poll_interval;
75        }
76
77        // Force kill
78        warn!(
79            "Container {} did not stop after {}s, sending SIGKILL",
80            state.id, timeout_secs
81        );
82        if let Err(e) = kill(pid, Signal::SIGKILL) {
83            if e == nix::errno::Errno::ESRCH {
84                return Ok(());
85            }
86            return Err(NucleusError::ExecError(format!(
87                "Failed to send SIGKILL: {}",
88                e
89            )));
90        }
91
92        Ok(())
93    }
94
95    /// Send an arbitrary signal to a container
96    pub fn kill_container(state: &ContainerState, signal: Signal) -> Result<()> {
97        Self::ensure_container_access(state)?;
98
99        if !state.is_running() {
100            return Err(NucleusError::ContainerNotRunning(format!(
101                "Container {} is not running",
102                state.id
103            )));
104        }
105
106        let pid = Pid::from_raw(state.pid as i32);
107        info!(
108            "Sending {:?} to container {} (PID {})",
109            signal, state.id, state.pid
110        );
111
112        kill(pid, signal).map_err(|e| {
113            NucleusError::ExecError(format!("Failed to send signal {:?}: {}", signal, e))
114        })?;
115
116        Ok(())
117    }
118
119    /// Remove a stopped container's state
120    pub fn remove(
121        state_mgr: &ContainerStateManager,
122        state: &ContainerState,
123        force: bool,
124    ) -> Result<()> {
125        Self::ensure_container_access(state)?;
126
127        if state.is_running() {
128            if force {
129                info!("Force removing running container {}", state.id);
130                Self::stop(state, 5)?;
131            } else {
132                return Err(NucleusError::ExecError(format!(
133                    "Container {} is still running. Stop it first or use --force",
134                    state.id
135                )));
136            }
137        }
138
139        // Clean up cgroup directory if present
140        if let Some(ref cgroup_path) = state.cgroup_path {
141            let cgroup = std::path::Path::new(cgroup_path);
142            if cgroup.exists() {
143                if let Err(e) = std::fs::remove_dir_all(cgroup) {
144                    warn!(
145                        "Failed to remove cgroup {}: {} (may still have processes)",
146                        cgroup_path, e
147                    );
148                } else {
149                    info!("Removed cgroup {}", cgroup_path);
150                }
151            }
152        }
153
154        state_mgr.delete_state(&state.id)?;
155        info!("Removed container {}", state.id);
156        Ok(())
157    }
158}
159
160/// Parse a signal name or number string into a Signal
161pub fn parse_signal(s: &str) -> Result<Signal> {
162    // Try numeric
163    if let Ok(num) = s.parse::<i32>() {
164        return Signal::try_from(num)
165            .map_err(|_| NucleusError::ConfigError(format!("Invalid signal number: {}", num)));
166    }
167
168    // Normalize: uppercase and strip optional "SIG" prefix
169    let upper = s.to_ascii_uppercase();
170    let normalized = upper.strip_prefix("SIG").unwrap_or(&upper);
171
172    match normalized {
173        "ABRT" | "IOT" => Ok(Signal::SIGABRT),
174        "ALRM" => Ok(Signal::SIGALRM),
175        "BUS" => Ok(Signal::SIGBUS),
176        "CHLD" | "CLD" => Ok(Signal::SIGCHLD),
177        "CONT" => Ok(Signal::SIGCONT),
178        "FPE" => Ok(Signal::SIGFPE),
179        "HUP" => Ok(Signal::SIGHUP),
180        "ILL" => Ok(Signal::SIGILL),
181        "INT" => Ok(Signal::SIGINT),
182        "IO" | "POLL" => Ok(Signal::SIGIO),
183        "KILL" => Ok(Signal::SIGKILL),
184        "PIPE" => Ok(Signal::SIGPIPE),
185        "PROF" => Ok(Signal::SIGPROF),
186        "PWR" => Ok(Signal::SIGPWR),
187        "QUIT" => Ok(Signal::SIGQUIT),
188        "SEGV" => Ok(Signal::SIGSEGV),
189        "STKFLT" => Ok(Signal::SIGSTKFLT),
190        "STOP" => Ok(Signal::SIGSTOP),
191        "SYS" => Ok(Signal::SIGSYS),
192        "TERM" => Ok(Signal::SIGTERM),
193        "TRAP" => Ok(Signal::SIGTRAP),
194        "TSTP" => Ok(Signal::SIGTSTP),
195        "TTIN" => Ok(Signal::SIGTTIN),
196        "TTOU" => Ok(Signal::SIGTTOU),
197        "URG" => Ok(Signal::SIGURG),
198        "USR1" => Ok(Signal::SIGUSR1),
199        "USR2" => Ok(Signal::SIGUSR2),
200        "VTALRM" => Ok(Signal::SIGVTALRM),
201        "WINCH" => Ok(Signal::SIGWINCH),
202        "XCPU" => Ok(Signal::SIGXCPU),
203        "XFSZ" => Ok(Signal::SIGXFSZ),
204        _ => Err(NucleusError::ConfigError(format!("Unknown signal: {}", s))),
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::container::ContainerStateParams;

    /// Build a minimal container state for the access-check tests.
    fn sample_state() -> ContainerState {
        ContainerState::new(ContainerStateParams {
            id: "testid".to_string(),
            name: "testname".to_string(),
            pid: 12345,
            command: vec!["/bin/true".to_string()],
            memory_limit: None,
            cpu_limit: None,
            using_gvisor: false,
            rootless: true,
            cgroup_path: None,
        })
    }

    #[test]
    fn test_parse_signal_by_name() {
        assert_eq!(parse_signal("TERM").unwrap(), Signal::SIGTERM);
        assert_eq!(parse_signal("SIGTERM").unwrap(), Signal::SIGTERM);
        assert_eq!(parse_signal("KILL").unwrap(), Signal::SIGKILL);
        assert_eq!(parse_signal("SIGKILL").unwrap(), Signal::SIGKILL);
        assert_eq!(parse_signal("INT").unwrap(), Signal::SIGINT);
        assert_eq!(parse_signal("HUP").unwrap(), Signal::SIGHUP);
    }

    #[test]
    fn test_parse_signal_by_number() {
        assert_eq!(parse_signal("15").unwrap(), Signal::SIGTERM);
        assert_eq!(parse_signal("9").unwrap(), Signal::SIGKILL);
        assert_eq!(parse_signal("2").unwrap(), Signal::SIGINT);
    }

    #[test]
    fn test_parse_signal_case_insensitive() {
        assert_eq!(parse_signal("term").unwrap(), Signal::SIGTERM);
        assert_eq!(parse_signal("sigterm").unwrap(), Signal::SIGTERM);
        assert_eq!(parse_signal("Term").unwrap(), Signal::SIGTERM);
    }

    #[test]
    fn test_parse_signal_all_standard_names() {
        // Every name (and alias) accepted by parse_signal, without the
        // "SIG" prefix.
        let cases = [
            ("ABRT", Signal::SIGABRT),
            ("IOT", Signal::SIGABRT),
            ("ALRM", Signal::SIGALRM),
            ("BUS", Signal::SIGBUS),
            ("CHLD", Signal::SIGCHLD),
            ("CLD", Signal::SIGCHLD),
            ("CONT", Signal::SIGCONT),
            ("FPE", Signal::SIGFPE),
            ("HUP", Signal::SIGHUP),
            ("ILL", Signal::SIGILL),
            ("INT", Signal::SIGINT),
            ("IO", Signal::SIGIO),
            ("POLL", Signal::SIGIO),
            ("KILL", Signal::SIGKILL),
            ("PIPE", Signal::SIGPIPE),
            ("PROF", Signal::SIGPROF),
            ("PWR", Signal::SIGPWR),
            ("QUIT", Signal::SIGQUIT),
            ("SEGV", Signal::SIGSEGV),
            ("STKFLT", Signal::SIGSTKFLT),
            ("STOP", Signal::SIGSTOP),
            ("SYS", Signal::SIGSYS),
            ("TERM", Signal::SIGTERM),
            ("TRAP", Signal::SIGTRAP),
            ("TSTP", Signal::SIGTSTP),
            ("TTIN", Signal::SIGTTIN),
            ("TTOU", Signal::SIGTTOU),
            ("URG", Signal::SIGURG),
            ("USR1", Signal::SIGUSR1),
            ("USR2", Signal::SIGUSR2),
            ("VTALRM", Signal::SIGVTALRM),
            ("WINCH", Signal::SIGWINCH),
            ("XCPU", Signal::SIGXCPU),
            ("XFSZ", Signal::SIGXFSZ),
        ];
        for (name, expected) in cases {
            assert_eq!(
                parse_signal(name).unwrap(),
                expected,
                "parse_signal({name}) failed"
            );
            // Also with SIG prefix
            let prefixed = format!("SIG{name}");
            assert_eq!(
                parse_signal(&prefixed).unwrap(),
                expected,
                "parse_signal({prefixed}) failed"
            );
        }
    }

    #[test]
    fn test_parse_signal_invalid() {
        assert!(parse_signal("INVALID").is_err());
        assert!(parse_signal("999").is_err());
        // Edge cases: empty input, bare prefix, and out-of-range numbers.
        assert!(parse_signal("").is_err());
        assert!(parse_signal("SIG").is_err());
        assert!(parse_signal("0").is_err());
        assert!(parse_signal("-1").is_err());
    }

    #[test]
    fn test_access_check_owner_allowed() {
        let uid = Uid::effective().as_raw();
        let mut state = sample_state();
        // Override creator to match current caller for this test.
        state.creator_uid = uid;
        assert!(ContainerLifecycle::ensure_container_access(&state).is_ok());
    }

    #[test]
    fn test_access_check_other_uid() {
        let uid = Uid::effective().as_raw();
        let mut state = sample_state();
        // Any UID different from the caller's.
        state.creator_uid = uid.wrapping_add(1);
        let result = ContainerLifecycle::ensure_container_access(&state);
        if uid == 0 {
            // Root may operate on any container.
            assert!(result.is_ok());
        } else {
            assert!(matches!(result, Err(NucleusError::PermissionDenied(_))));
        }
    }
}