Skip to main content

windows_erg/process/
tree.rs

1//! Process tree operations (kill tree, ancestors, etc.).
2
3use std::collections::{HashMap, HashSet};
4
5use super::processes::Process;
6use super::types::{ProcessId, ProcessInfo};
7use crate::error::Result;
8
9impl Process {
10    /// Kill this process and all its descendants.
11    ///
12    /// This will recursively kill all child processes first, then the parent.
13    pub fn kill_tree(&self) -> Result<()> {
14        Self::kill_tree_by_id(self.id())
15    }
16
17    /// Kill a process and all its descendants by process ID.
18    pub fn kill_tree_by_id(pid: ProcessId) -> Result<()> {
19        let mut buffer = Vec::with_capacity(8192);
20        Self::kill_tree_by_id_with_buffer(pid, &mut buffer)
21    }
22
23    /// Kill a process tree using a reusable output buffer.
24    pub fn kill_tree_by_id_with_buffer(
25        pid: ProcessId,
26        out_processes: &mut Vec<ProcessInfo>,
27    ) -> Result<()> {
28        Self::list_with_buffer(out_processes)?;
29        let tree = build_process_tree(out_processes);
30
31        // Collect all descendants
32        let mut to_kill = HashSet::new();
33        collect_descendants(pid, &tree, &mut to_kill);
34        to_kill.insert(pid);
35
36        // Kill in reverse order (children before parents)
37        let mut kill_order: Vec<_> = to_kill.into_iter().collect();
38        kill_order.sort_by_key(|&pid| std::cmp::Reverse(tree_depth(pid, &tree)));
39
40        for kill_pid in kill_order {
41            // Ignore errors - process might have already exited
42            let _ = Self::kill_by_id(kill_pid);
43        }
44
45        Ok(())
46    }
47
48    /// Find the root ancestor of a process and kill the entire tree.
49    ///
50    /// This walks up the parent chain to find the topmost process,
51    /// then kills that entire tree.
52    pub fn kill_tree_from_root(pid: ProcessId) -> Result<()> {
53        let mut buffer = Vec::with_capacity(8192);
54        Self::kill_tree_from_root_with_buffer(pid, &mut buffer)
55    }
56
57    /// Kill tree from root using a reusable output buffer.
58    pub fn kill_tree_from_root_with_buffer(
59        pid: ProcessId,
60        out_processes: &mut Vec<ProcessInfo>,
61    ) -> Result<()> {
62        Self::list_with_buffer(out_processes)?;
63        let tree = build_process_tree(out_processes);
64
65        // Find root ancestor
66        let root = find_root_ancestor(pid, &tree);
67
68        // Kill entire tree from root
69        Self::kill_tree_by_id_with_buffer(root, out_processes)
70    }
71}
72
73/// Build a parent->children mapping.
74fn build_process_tree(processes: &[ProcessInfo]) -> HashMap<ProcessId, Vec<ProcessId>> {
75    let mut tree: HashMap<ProcessId, Vec<ProcessId>> = HashMap::new();
76
77    for proc in processes {
78        if let Some(parent_pid) = proc.parent_pid {
79            tree.entry(parent_pid).or_default().push(proc.pid);
80        }
81    }
82
83    tree
84}
85
86/// Recursively collect all descendants of a process.
87fn collect_descendants(
88    pid: ProcessId,
89    tree: &HashMap<ProcessId, Vec<ProcessId>>,
90    result: &mut HashSet<ProcessId>,
91) {
92    if let Some(children) = tree.get(&pid) {
93        for &child in children {
94            if result.insert(child) {
95                collect_descendants(child, tree, result);
96            }
97        }
98    }
99}
100
101/// Calculate depth of a process in the tree (for kill ordering).
102fn tree_depth(pid: ProcessId, tree: &HashMap<ProcessId, Vec<ProcessId>>) -> usize {
103    if let Some(children) = tree.get(&pid) {
104        1 + children
105            .iter()
106            .map(|&child| tree_depth(child, tree))
107            .max()
108            .unwrap_or(0)
109    } else {
110        0
111    }
112}
113
114/// Find the root ancestor by walking up the parent chain.
115fn find_root_ancestor(
116    mut pid: ProcessId,
117    parent_map: &HashMap<ProcessId, Vec<ProcessId>>,
118) -> ProcessId {
119    // Build reverse map (child -> parent)
120    let mut child_to_parent: HashMap<ProcessId, ProcessId> = HashMap::new();
121    for (&parent, children) in parent_map {
122        for &child in children {
123            child_to_parent.insert(child, parent);
124        }
125    }
126
127    // Walk up to root
128    let mut visited = HashSet::new();
129    while let Some(&parent) = child_to_parent.get(&pid) {
130        if !visited.insert(pid) {
131            // Cycle detected, return current
132            break;
133        }
134        pid = parent;
135    }
136
137    pid
138}