use crate::capability::{Capability, Context, Output};
use crate::processes::ProcessSnapshot;
use crate::{Error, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;
#[cfg(test)]
use std::process::Command;
/// Reads the start time of process `pid` from `/proc/<pid>/stat`.
///
/// The value returned is the `starttime` field (field 22 in proc(5)). It stays
/// constant for the lifetime of a process, which makes it usable as a
/// fingerprint for detecting PID reuse.
///
/// Returns `None` if the stat file cannot be read, has no closing `)`, has
/// fewer fields than expected, or the field is not a valid `u64`.
fn get_process_start_time(pid: u32) -> Option<u64> {
    let stat = std::fs::read_to_string(format!("/proc/{}/stat", pid)).ok()?;
    // The command name (field 2) is wrapped in parentheses and may itself
    // contain spaces or ')', so the remaining fields are located relative to
    // the LAST ')'. A checked slice avoids a panic on truncated content.
    let after_comm = stat.get(stat.rfind(')')? + 1..)?;
    // After the command name, `state` (field 3) is index 0, so `starttime`
    // (field 22) is index 19.
    after_comm.split_whitespace().nth(19)?.parse::<u64>().ok()
}
/// Reads the start time of `pid`, retrying when the first read fails.
///
/// Up to three attempts are made through `get_process_start_time`. The first
/// runs immediately; the second and third are preceded by pauses of 20 ms and
/// 40 ms respectively. The first successful read is returned; `None` means all
/// three attempts failed.
fn get_process_start_time_retry(pid: u32) -> Option<u64> {
    (0u32..3).find_map(|attempt| {
        if attempt != 0 {
            // Exponential backoff: 10 ms * 2^attempt.
            let pause_ms = 10u64 * (1u64 << attempt);
            std::thread::sleep(std::time::Duration::from_millis(pause_ms));
        }
        get_process_start_time(pid)
    })
}
/// Returns the raw contents of `/proc/<pid>/cgroup`.
///
/// Any read error (missing process, permission denied, invalid UTF-8) is
/// collapsed into `None`.
fn get_process_cgroup(pid: u32) -> Option<String> {
    let cgroup_file = std::path::Path::new("/proc")
        .join(pid.to_string())
        .join("cgroup");
    match std::fs::read_to_string(cgroup_file) {
        Ok(contents) => Some(contents),
        Err(_) => None,
    }
}
/// Reports whether the contents of a `/proc/<pid>/cgroup` file place the
/// process in a systemd-managed system cgroup.
///
/// Each line of that file has the form
/// `hierarchy-ID:controller-list:cgroup-path` (see cgroups(7)). Only the
/// cgroup-path field is inspected: on cgroup v1 / hybrid hosts every process
/// carries a `N:name=systemd:/...` line, so matching the whole line against
/// `"systemd"` would classify every process on the machine as a service.
///
/// A process matches when any of its cgroup paths contains `/system.slice/`,
/// `/init.scope`, or `systemd`.
fn is_systemd_service(cgroup: &str) -> bool {
    cgroup.lines().any(|line| {
        // `splitn(3, ..)` keeps any ':' that appears inside the path itself.
        // A line that does not have three fields is matched as a whole, which
        // errs on the side of protecting the process.
        let path = line.splitn(3, ':').nth(2).unwrap_or(line);
        path.contains("/system.slice/")
            || path.contains("/init.scope")
            || path.contains("systemd")
    })
}
/// Builds the sorted, de-duplicated list of PIDs that must never be signalled.
///
/// The list contains:
/// * PID 1 and PID 2 (init and kthreadd),
/// * this process itself,
/// * this process's parent, session leader and process-group leader, as
///   reported by `/proc/<self>/status`,
/// * every process whose `/proc/<pid>/cgroup` is classified as a systemd
///   service by `is_systemd_service`.
///
/// Every `/proc` lookup is best-effort: anything unreadable or unparsable is
/// skipped rather than treated as an error.
fn protected_pids() -> Vec<u32> {
    let self_pid = std::process::id();
    let mut pids = vec![1, 2, self_pid];

    if let Ok(status) = std::fs::read_to_string(format!("/proc/{}/status", self_pid)) {
        // Returns the first numeric column of the first line starting with `key`.
        let field = |key: &str| -> Option<u32> {
            status
                .lines()
                .find(|line| line.starts_with(key))
                .and_then(|line| line.split_whitespace().nth(1))
                .and_then(|value| value.parse::<u32>().ok())
        };

        pids.extend(field("PPid:"));
        // The status file exposes the session ID as `NSsid:`; it has no
        // `Sid:` line, so that key would never match anything.
        pids.extend(field("NSsid:").filter(|&sid| sid != 0));
        pids.extend(field("NSpgid:").filter(|&pgid| pgid != 0));
    }

    if let Ok(entries) = std::fs::read_dir("/proc") {
        pids.extend(
            entries
                .flatten()
                .filter_map(|entry| entry.file_name().into_string().ok())
                // Only the numeric directory names under /proc are processes.
                .filter_map(|name| name.parse::<u32>().ok())
                .filter(|&pid| match get_process_cgroup(pid) {
                    Some(cgroup) => is_systemd_service(&cgroup),
                    None => false,
                }),
        );
    }

    pids.sort_unstable();
    pids.dedup();
    pids
}
/// Arguments accepted by the `Kill` capability, deserialized from the JSON
/// value passed to `validate` and `execute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KillArgs {
/// PID of the process to signal.
pub pid: u32,
/// Signal number to send. When absent, `execute` falls back to 15;
/// `validate` accepts only 1-31 or 64.
pub signal: Option<i32>,
}
/// Stateless capability that sends a signal to a process identified by PID.
/// All behaviour lives in its `Capability` implementation.
pub struct Kill;
/// Converts a user-supplied PID into a `pid_t` that addresses exactly one
/// process.
///
/// `kill(2)` gives special meaning to non-positive PIDs (0 = caller's process
/// group, -1 = every process, other negatives = a process group), and a `u32`
/// above `pid_t::MAX` would wrap to a negative value when cast. Both cases
/// yield `None`.
fn signalable_pid(pid: u32) -> Option<libc::pid_t> {
    if pid == 0 || pid > libc::pid_t::MAX as u32 {
        None
    } else {
        Some(pid as libc::pid_t)
    }
}

impl Capability for Kill {
    /// Name under which this capability is registered.
    fn name(&self) -> &'static str {
        "Kill"
    }

    /// Human-readable summary of what the capability does.
    fn description(&self) -> &'static str {
        "Terminate a process by PID. Protects critical system processes (init, kthreadd, self). Supports custom signals."
    }

    /// JSON schema of the accepted arguments: a required `pid` (integer >= 1)
    /// and an optional `signal` (1-31 or 64).
    fn schema(&self) -> Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "pid": { "type": "integer", "minimum": 1 },
                "signal": {
                    "type": "integer",
                    "anyOf": [
                        { "minimum": 1, "maximum": 31 },
                        { "enum": [64] }
                    ]
                }
            },
            "required": ["pid"]
        })
    }

    /// Checks that `args` deserializes into `KillArgs`, that `pid` addresses
    /// a single process, and that `signal`, if given, is 1-31 or 64.
    ///
    /// # Errors
    /// Returns `Error::SchemaValidationFailed` on any violation.
    fn validate(&self, args: &Value) -> Result<()> {
        let args: KillArgs = serde_json::from_value(args.clone())
            .map_err(|e| Error::SchemaValidationFailed(e.to_string()))?;

        if signalable_pid(args.pid).is_none() {
            return Err(Error::SchemaValidationFailed(format!(
                "Invalid pid {}: must be between 1 and {}",
                args.pid,
                libc::pid_t::MAX
            )));
        }

        if let Some(signal) = args.signal {
            if !(1..=31).contains(&signal) && signal != 64 {
                return Err(Error::SchemaValidationFailed(format!(
                    "Invalid signal {}: must be 1-31 or 64 (POSIX signals)",
                    signal
                )));
            }
        }
        Ok(())
    }

    /// Sends the requested signal (default 15) to `pid` and reports whether
    /// the process is gone afterwards.
    ///
    /// Steps, in order:
    /// 1. Reject PIDs that do not address a single process and PIDs listed by
    ///    `protected_pids`.
    /// 2. In dry-run mode, return immediately without inspecting or touching
    ///    the target.
    /// 3. Look the process up in a `ProcessSnapshot`; a missing process yields
    ///    an unsuccessful `Output` rather than an error.
    /// 4. Record the start time, send the signal, wait 500 ms, and take a
    ///    second snapshot to see whether a non-zombie process with that PID
    ///    remains.
    ///
    /// `killed` is true only if the signal was delivered, no live process
    /// holds the PID afterwards, and the PID was not taken over by a process
    /// with a different start time.
    ///
    /// # Errors
    /// Returns `Error::ExecutionFailed` if the arguments cannot be
    /// deserialized, the PID is out of range, or the PID is protected.
    fn execute(&self, args: &Value, ctx: &Context) -> Result<Output> {
        let args: KillArgs = serde_json::from_value(args.clone())
            .map_err(|e| Error::ExecutionFailed(e.to_string()))?;

        // Refuse anything kill(2) would interpret as a group or broadcast
        // target before doing any other work.
        let target = match signalable_pid(args.pid) {
            Some(target) => target,
            None => {
                return Err(Error::ExecutionFailed(format!(
                    "Invalid pid {}: must be between 1 and {}",
                    args.pid,
                    libc::pid_t::MAX
                )));
            }
        };

        let protected = protected_pids();
        if protected.contains(&args.pid) {
            return Err(Error::ExecutionFailed(format!(
                "PID {} is a protected system process (protected: {:?})",
                args.pid, protected
            )));
        }

        let signal = args.signal.unwrap_or(15);

        if ctx.dry_run {
            // Deliberately no process lookup here: a dry run must not reveal
            // anything about the target.
            return Ok(Output {
                success: true,
                data: serde_json::json!({
                    "pid": args.pid,
                    "killed": false,
                    "dry_run": true,
                    "signal": signal,
                }),
                message: Some(format!("DRY RUN: would kill PID {}", args.pid)),
            });
        }

        let process_before = ProcessSnapshot::capture();
        let found = process_before
            .processes
            .iter()
            .find(|p| p.pid == args.pid)
            .map(|p| (p.command.clone(), p.user.clone()));
        let (command, user) = match found {
            Some(info) => info,
            None => {
                return Ok(Output {
                    success: false,
                    data: serde_json::json!({
                        "pid": args.pid,
                        "killed": false,
                        "reason": "Process not found"
                    }),
                    message: Some(format!("Process {} not found", args.pid)),
                });
            }
        };

        let start_time_before = get_process_start_time_retry(args.pid);

        // SAFETY: `kill` takes two plain integers and dereferences no memory.
        // `target` is strictly positive, so exactly one process is addressed.
        let kill_result = unsafe { libc::kill(target, signal) };
        let success = kill_result == 0;
        // errno must be read before any other call can overwrite it.
        let stderr_str = if success {
            String::new()
        } else {
            std::io::Error::last_os_error().to_string()
        };

        // Give the signal time to take effect before re-inspecting.
        std::thread::sleep(Duration::from_millis(500));
        ProcessSnapshot::clear_cache();
        let process_after = ProcessSnapshot::capture();

        // A zombie still has a process-table entry but is already dead.
        let process_still_exists = process_after
            .processes
            .iter()
            .any(|p| p.pid == args.pid && !p.stat.starts_with('Z'));

        // The PID counts as reused only when a process is present both before
        // and after with a different start time. A process that has simply
        // vanished (no start time afterwards) is the normal outcome of a
        // successful kill and must not be reported as reuse.
        let pid_reused = match (start_time_before, get_process_start_time_retry(args.pid)) {
            (Some(before_time), Some(after_time)) => before_time != after_time,
            _ => false,
        };

        let killed_success = success && !process_still_exists && !pid_reused;

        let message = if killed_success {
            format!("Killed process {} (signal {})", args.pid, signal)
        } else if pid_reused {
            format!(
                "PID {} was reused by a different process (start time changed)",
                args.pid
            )
        } else if !success {
            format!("Failed to kill process {}: {}", args.pid, stderr_str)
        } else {
            format!("Process {} still exists after signal {}", args.pid, signal)
        };

        Ok(Output {
            success: killed_success,
            data: serde_json::json!({
                "pid": args.pid,
                "killed": killed_success,
                "signal": signal,
                "command": command,
                "user": user,
                // Already empty when the signal was delivered.
                "stderr": stderr_str,
                "pid_reused": pid_reused,
                "process_before": {
                    "count": process_before.summary.total_processes,
                    "zombies": process_before.summary.zombie_count
                },
                "process_after": {
                    "count": process_after.summary.total_processes,
                    "zombies": process_after.summary.zombie_count
                }
            }),
            message: Some(message),
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the `Kill` capability and its `/proc` helpers.
    //! Several tests spawn a real `sleep` child and rely on a Linux `/proc`.

    use super::*;
    use crate::capability::Capability;
    use std::thread;
    use std::time::Duration;

    /// Builds the `Context` used by every test; only `dry_run` varies.
    fn test_context(dry_run: bool) -> Context {
        Context {
            dry_run,
            job_id: "test".into(),
            working_dir: std::env::current_dir().unwrap(),
        }
    }

    #[test]
    fn test_kill_schema() {
        let schema = Kill.schema();
        assert_eq!(schema["type"].as_str(), Some("object"));
        assert_eq!(schema["required"], serde_json::json!(["pid"]));

        let pid = &schema["properties"]["pid"];
        assert_eq!(pid["type"].as_str(), Some("integer"));
        assert_eq!(pid["minimum"].as_i64(), Some(1));

        let signal = &schema["properties"]["signal"];
        assert_eq!(signal["type"].as_str(), Some("integer"));
        assert!(
            signal["anyOf"].is_array(),
            "signal should list its allowed ranges"
        );
    }

    #[test]
    fn test_kill_protected_pid() {
        let cap = Kill;
        let result = cap.execute(&serde_json::json!({ "pid": 1 }), &test_context(false));
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("protected system process"));
    }

    #[test]
    fn test_kill_self_protected() {
        let cap = Kill;
        let self_pid = std::process::id();
        let result = cap.execute(
            &serde_json::json!({ "pid": self_pid }),
            &test_context(false),
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("protected"));
    }

    #[test]
    fn test_kill_nonexistent() {
        let cap = Kill;
        let result = cap
            .execute(&serde_json::json!({ "pid": 999999 }), &test_context(false))
            .unwrap();
        assert!(!result.success);
        assert!(result.data["killed"].as_bool() == Some(false));
    }

    #[test]
    fn test_kill_dry_run() {
        let cap = Kill;
        let result = cap
            .execute(&serde_json::json!({ "pid": 999998 }), &test_context(true))
            .unwrap();
        assert!(result.success);
        assert!(result.data["dry_run"].as_bool() == Some(true));
        assert!(result.data["killed"].as_bool() == Some(false));
    }

    #[test]
    fn test_kill_actual_process() {
        let mut child = Command::new("sleep").arg("60").spawn().unwrap();
        let pid = child.id();
        thread::sleep(Duration::from_millis(100));

        let pre_check = Command::new("kill").arg("-0").arg(pid.to_string()).output();
        assert!(
            pre_check.unwrap().status.success(),
            "Process should exist before kill"
        );

        // Make sure `execute` sees the freshly spawned child.
        ProcessSnapshot::clear_cache();
        let cap = Kill;
        let result = cap
            .execute(
                &serde_json::json!({ "pid": pid, "signal": 9 }),
                &test_context(false),
            )
            .unwrap();
        assert!(
            result.data["killed"].as_bool() == Some(true),
            "Kill failed: {:?}",
            result.data
        );
        assert!(
            result.data["signal"].as_i64() == Some(9),
            "Should use SIGKILL"
        );

        // Reap the child so the PID no longer refers to a zombie.
        let _ = child.wait();
        let post_check = Command::new("kill").arg("-0").arg(pid.to_string()).output();
        let still_alive = post_check.map(|o| o.status.success()).unwrap_or(false);
        assert!(
            !still_alive,
            "Process {} should be dead after kill and reap",
            pid
        );
    }

    #[test]
    fn test_get_process_start_time() {
        let mut child = Command::new("sleep").arg("60").spawn().unwrap();
        let pid = child.id();
        let start_time = get_process_start_time(pid);
        assert!(
            start_time.is_some(),
            "Should be able to read start time for running process"
        );
        let start_time2 = get_process_start_time(pid);
        assert_eq!(start_time, start_time2, "Start time should be stable");
        child.kill().ok();
        let _ = child.wait();
    }

    #[test]
    fn test_get_process_start_time_nonexistent() {
        let result = get_process_start_time(999999);
        assert!(result.is_none(), "Non-existent PID should return None");
    }

    #[test]
    fn test_signal_validation_rejects_negative() {
        let cap = Kill;
        let result = cap.validate(&serde_json::json!({ "pid": 999998, "signal": -1 }));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid signal"));
    }

    #[test]
    fn test_signal_validation_rejects_zero() {
        let cap = Kill;
        let result = cap.validate(&serde_json::json!({ "pid": 999998, "signal": 0 }));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid signal"));
    }

    #[test]
    fn test_signal_validation_rejects_out_of_range() {
        let cap = Kill;
        let result = cap.validate(&serde_json::json!({ "pid": 999998, "signal": 32 }));
        assert!(result.is_err());
    }

    #[test]
    fn test_signal_validation_accepts_valid_signals() {
        let cap = Kill;
        for sig in [1, 9, 15, 31, 64] {
            let result = cap.validate(&serde_json::json!({ "pid": 999998, "signal": sig }));
            assert!(result.is_ok(), "Signal {} should be valid", sig);
        }
    }

    #[test]
    fn test_dry_run_hides_process_info() {
        let cap = Kill;
        let result = cap
            .execute(&serde_json::json!({ "pid": 999998 }), &test_context(true))
            .unwrap();
        assert!(result.success);
        assert!(result.data["dry_run"].as_bool() == Some(true));
        assert!(
            result.data.get("command").is_none(),
            "dry-run must not expose command"
        );
        assert!(
            result.data.get("user").is_none(),
            "dry-run must not expose user"
        );
        assert!(
            result.data.get("process_exists").is_none(),
            "dry-run must not expose process_exists"
        );
    }

    #[test]
    fn test_protected_pids_includes_self_and_parent() {
        let protected = protected_pids();
        let self_pid = std::process::id();
        let parent_pid = std::os::unix::process::parent_id();
        assert!(protected.contains(&1), "PID 1 should be protected");
        assert!(protected.contains(&2), "PID 2 should be protected");
        assert!(protected.contains(&self_pid), "self PID should be protected");
        assert!(
            protected.contains(&parent_pid),
            "parent PID should be protected"
        );
    }

    #[test]
    fn test_get_process_start_time_retry() {
        let mut child = Command::new("sleep").arg("60").spawn().unwrap();
        let pid = child.id();
        let result = get_process_start_time_retry(pid);
        assert!(result.is_some(), "Should read start time for running process");
        child.kill().ok();
        let _ = child.wait();

        let result = get_process_start_time_retry(999999);
        assert!(result.is_none(), "Non-existent PID should return None");
    }
}