oclean 0.1.1

Process-cleanup wrapper for opencode sessions
use std::collections::HashSet;
use std::thread;
use std::time::Duration;

use nix::sys::signal::{self, Signal};
use nix::unistd::{getpgid, getpid, Pid};

use crate::process_tree::{discover_descendants, pgid_exists, pid_exists, process_groups};

const POLL_INTERVAL: Duration = Duration::from_millis(150);
const TERM_GRACE: Duration = Duration::from_millis(1_200);
const SWEEP_PASSES: usize = 3;

/// State for cleaning up the process tree spawned under a single root process.
///
/// Tracked PIDs and process groups accumulate across calls to `observe`, so a
/// later cleanup can still target processes that are no longer discoverable
/// from the root (e.g. after the root itself has exited).
#[derive(Debug)]
pub struct Tracker {
    /// PID whose descendants are tracked.
    root_pid: i32,
    /// Process-group ID of this wrapper process (`-1` if it could not be
    /// determined); never recorded as a group to signal.
    wrapper_pgid: i32,
    /// Descendant PIDs observed so far.
    known_pids: HashSet<i32>,
    /// Process groups of observed descendants, excluding groups `<= 1` and
    /// `wrapper_pgid`.
    known_pgids: HashSet<i32>,
}

impl Tracker {
    /// Creates a tracker for the process tree rooted at `root_pid`.
    ///
    /// The wrapper's own process-group ID is captured up front so that
    /// [`Tracker::observe`] never records it as a group to signal. If the
    /// lookup fails, `-1` is stored. `observe` already ignores groups `<= 1`,
    /// so in that (unlikely) case the wrapper's real group is not excluded.
    pub fn new(root_pid: i32) -> Self {
        let wrapper_pgid = getpgid(Some(getpid())).map_or(-1, Pid::as_raw);

        Self {
            root_pid,
            wrapper_pgid,
            known_pids: HashSet::new(),
            known_pgids: HashSet::new(),
        }
    }

    /// Refreshes the sets of tracked PIDs and process groups.
    ///
    /// First drops every tracked PID / group that no longer exists. A number
    /// that has been observed dead can later be reassigned by the kernel to an
    /// unrelated process; pruning it here keeps a later cleanup from
    /// signalling that process, and keeps the sets from growing without bound
    /// over a long session. This narrows, but cannot fully close, the reuse
    /// window between two observations.
    ///
    /// Then merges in the current descendants of `root_pid` and their process
    /// groups. Groups `<= 1` and the wrapper's own group are never recorded.
    pub fn observe(&mut self) {
        // Entries that fail these checks are already excluded by
        // `live_pids`/`live_pgids`, so dropping them never loses a target
        // that would have been signalled now.
        self.known_pids.retain(|&pid| pid_exists(pid));
        self.known_pgids.retain(|&group| pgid_exists(group));

        let descendants = discover_descendants(self.root_pid);
        self.known_pids.extend(descendants.iter().copied());

        for group in process_groups(&descendants) {
            if group <= 1 || group == self.wrapper_pgid {
                continue;
            }
            self.known_pgids.insert(group);
        }
    }

    /// Returns the tracked PIDs for which `pid_exists` currently reports true.
    fn live_pids(&self) -> HashSet<i32> {
        self.known_pids
            .iter()
            .copied()
            .filter(|pid| pid_exists(*pid))
            .collect()
    }

    /// Returns the tracked process-group IDs for which `pgid_exists`
    /// currently reports true.
    fn live_pgids(&self) -> HashSet<i32> {
        self.known_pgids
            .iter()
            .copied()
            .filter(|group| pgid_exists(*group))
            .collect()
    }
}

pub fn post_exit_cleanup(tracker: &mut Tracker) {
    cleanup_tree_with_signal(tracker, Signal::SIGTERM, false);
    thread::sleep(POLL_INTERVAL);
    cleanup_tree_with_signal(tracker, Signal::SIGKILL, false);
}

pub fn cleanup_tree(tracker: &mut Tracker) {
    cleanup_tree_with_signal(tracker, Signal::SIGTERM, true);
    thread::sleep(TERM_GRACE);
    cleanup_tree_with_signal(tracker, Signal::SIGKILL, true);
}

/// Sweeps the tracked process tree up to `SWEEP_PASSES` times.
///
/// On each pass the tree is re-observed, so processes spawned while the tree
/// is shutting down are picked up. Then `first_signal` is sent to every live
/// tracked PID, and also to every live tracked process group when `kill_groups`
/// is set. After signalling, the function sleeps `POLL_INTERVAL`. It returns as
/// soon as a pass finds nothing alive.
///
/// The same signal is used on every pass; escalation (SIGTERM → SIGKILL) is
/// the caller's job.
fn cleanup_tree_with_signal(tracker: &mut Tracker, first_signal: Signal, kill_groups: bool) {
    for _ in 0..SWEEP_PASSES {
        tracker.observe();
        let pids = tracker.live_pids();
        let groups = tracker.live_pgids();
        if pids.is_empty() && groups.is_empty() {
            return;
        }

        // Deliberately no escalation inside the sweep. Switching to SIGKILL on
        // later passes would kill anything still alive one POLL_INTERVAL after
        // the SIGTERM, before any grace period the caller waits between its
        // SIGTERM and SIGKILL phases could take effect.
        terminate_targets(&pids, &groups, first_signal, kill_groups);
        thread::sleep(POLL_INTERVAL);
    }
}

/// Sends `signal_kind` to each target, best effort.
///
/// When `kill_groups` is set, every group in `groups` is signalled first as a
/// whole; a negative PID passed to `kill` addresses a process group. Then every
/// PID in `pids` is signalled individually. IDs `<= 1` are skipped in both sets.
/// Delivery errors (for example, a target that has already exited) are ignored.
fn terminate_targets(
    pids: &HashSet<i32>,
    groups: &HashSet<i32>,
    signal_kind: Signal,
    kill_groups: bool,
) {
    let group_targets = groups
        .iter()
        .copied()
        .filter(move |&pgid| kill_groups && pgid > 1)
        .map(|pgid| -pgid);
    let pid_targets = pids.iter().copied().filter(|&pid| pid > 1);

    for target in group_targets.chain(pid_targets) {
        let _ = signal::kill(Pid::from_raw(target), signal_kind);
    }
}