use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::Instant;
use crate::constants::MONITOR_CHECK_INTERVAL_SECONDS;
/// The role a managed child process plays.
///
/// `ProcessManager` keeps a separate concurrency limit for each kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProcessKind {
    /// An orchestrator process; displayed as `"orchestrator"`.
    Orchestrator,
    /// A worker process; displayed as `"worker"`.
    Worker,
}
/// Renders a [`ProcessKind`] as its lowercase name (`"orchestrator"` or `"worker"`).
impl std::fmt::Display for ProcessKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the literal first, then emit it with a single write. The text
        // is written verbatim: width/padding flags on the formatter are not
        // applied.
        let name = match self {
            ProcessKind::Orchestrator => "orchestrator",
            ProcessKind::Worker => "worker",
        };
        f.write_str(name)
    }
}
/// Lifecycle state recorded for a tracked process.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProcessState {
    /// Initial state assigned on registration; the only state counted as active.
    Running,
    /// An exit with code `0` was reported.
    Finished,
    /// An exit with a non-zero code was reported.
    Crashed,
    /// The monitor task saw the process run longer than its timeout.
    TimedOut,
    /// Not assigned by the code visible alongside this enum; presumably set by
    /// whatever terminates a process deliberately — TODO confirm.
    Killed,
}
/// Bookkeeping record for one process tracked by `ProcessManager`.
pub struct ManagedProcess {
    /// Process id; also the key under which the record is stored.
    pub pid: u32,
    /// Whether this is an orchestrator or a worker.
    pub kind: ProcessKind,
    /// Repository string supplied at registration (tests use the `owner/repo`
    /// form — exact format is up to the caller).
    pub repo: String,
    /// Human-readable label; included in the timeout warning log line.
    pub label: String,
    /// Moment the process was registered (not necessarily when it was spawned).
    pub started_at: Instant,
    /// Run time, in seconds, after which the monitor marks the process `TimedOut`.
    pub timeout_seconds: u64,
    /// Current lifecycle state.
    pub state: ProcessState,
    /// Exit code, `None` until an exit has been reported.
    pub exit_code: Option<i32>,
}
/// Agent types that `check_fork_pr_gate` allows on fork pull requests even
/// when the `githubclaw-approved` label is absent.
const READ_ONLY_AGENT_TYPES: &[&str] = &["security_reviewer"];
/// Decides whether `agent_type` may act on the event in `event_payload`.
///
/// Returns `true` when any of the following holds:
///
/// * the payload has no `pull_request` key, or its value is JSON `null`;
/// * `pull_request.head.repo.fork` is not the boolean `true` (a missing or
///   non-boolean value is treated as "not a fork");
/// * one of `pull_request.labels[*].name` equals `"githubclaw-approved"`;
/// * `agent_type` is listed in `READ_ONLY_AGENT_TYPES`.
///
/// Otherwise returns `false`.
pub fn check_fork_pr_gate(event_payload: &serde_json::Value, agent_type: &str) -> bool {
    // Non-PR events are never gated.
    let pr = match event_payload
        .get("pull_request")
        .filter(|value| !value.is_null())
    {
        Some(pr) => pr,
        None => return true,
    };

    let from_fork = pr
        .get("head")
        .and_then(|head| head.get("repo"))
        .and_then(|repo| repo.get("fork"))
        .and_then(serde_json::Value::as_bool)
        == Some(true);
    if !from_fork {
        return true;
    }

    // A missing or non-array `labels` field simply yields no label names.
    let approved = pr
        .get("labels")
        .and_then(serde_json::Value::as_array)
        .into_iter()
        .flatten()
        .filter_map(|label| label.get("name").and_then(serde_json::Value::as_str))
        .any(|name| name == "githubclaw-approved");

    approved || READ_ONLY_AGENT_TYPES.contains(&agent_type)
}
/// Tracks registered processes and answers capacity questions about them.
pub struct ProcessManager {
    /// Upper bound on running processes of any kind, as checked by `has_capacity`.
    pub max_concurrent_agents: usize,
    /// Upper bound on running orchestrators, as checked by `has_capacity_for`.
    pub max_concurrent_orchestrators: usize,
    /// Upper bound on running workers, as checked by `has_capacity_for`.
    pub max_concurrent_workers: usize,
    /// Tracked processes keyed by pid, behind an async mutex so the monitor
    /// task and callers can share the table.
    processes: Arc<Mutex<HashMap<u32, ManagedProcess>>>,
}
impl ProcessManager {
    /// Creates a manager with an overall limit of `max_concurrent_agents`.
    ///
    /// The orchestrator limit is fixed at 4 and the worker limit is set to
    /// `max_concurrent_agents`.
    pub fn new(max_concurrent_agents: usize) -> Self {
        Self {
            max_concurrent_agents,
            max_concurrent_orchestrators: 4,
            max_concurrent_workers: max_concurrent_agents,
            processes: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Creates a manager with explicit per-kind limits.
    ///
    /// The overall limit is the sum of the two, saturating at `usize::MAX`.
    pub fn with_limits(max_orchestrators: usize, max_workers: usize) -> Self {
        Self {
            // Saturate rather than add: a caller passing `usize::MAX` as
            // "unlimited" must not panic (debug) or wrap to a tiny limit (release).
            max_concurrent_agents: max_orchestrators.saturating_add(max_workers),
            max_concurrent_orchestrators: max_orchestrators,
            max_concurrent_workers: max_workers,
            processes: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Counts tracked processes in the `Running` state, optionally restricted
    /// to a single `kind` (`None` counts every kind).
    async fn running_count(&self, kind: Option<ProcessKind>) -> usize {
        let procs = self.processes.lock().await;
        procs
            .values()
            .filter(|p| p.state == ProcessState::Running)
            .filter(|p| kind.map_or(true, |wanted| p.kind == wanted))
            .count()
    }

    /// Number of tracked processes currently `Running`, regardless of kind.
    pub async fn active_count(&self) -> usize {
        self.running_count(None).await
    }

    /// Number of tracked orchestrators currently `Running`.
    pub async fn active_orchestrator_count(&self) -> usize {
        self.running_count(Some(ProcessKind::Orchestrator)).await
    }

    /// Number of tracked workers currently `Running`.
    pub async fn active_worker_count(&self) -> usize {
        self.running_count(Some(ProcessKind::Worker)).await
    }

    /// Returns `true` while the total running count is below
    /// `max_concurrent_agents`.
    pub async fn has_capacity(&self) -> bool {
        self.active_count().await < self.max_concurrent_agents
    }

    /// Returns `true` while the running count for `kind` is below that kind's
    /// limit. The overall `max_concurrent_agents` limit is not consulted here.
    pub async fn has_capacity_for(&self, kind: ProcessKind) -> bool {
        let limit = match kind {
            ProcessKind::Orchestrator => self.max_concurrent_orchestrators,
            ProcessKind::Worker => self.max_concurrent_workers,
        };
        self.running_count(Some(kind)).await < limit
    }

    /// Snapshot of every tracked process as `(pid, state)` pairs.
    ///
    /// Entries in any state are included; the order follows the map's
    /// iteration order and is therefore unspecified.
    pub async fn all_processes(&self) -> Vec<(u32, ProcessState)> {
        let procs = self.processes.lock().await;
        procs.iter().map(|(&pid, p)| (pid, p.state)).collect()
    }

    /// Starts tracking `pid` as a `Running` process whose clock starts now.
    ///
    /// If `pid` is already tracked, the previous record is replaced.
    pub async fn register(
        &self,
        pid: u32,
        kind: ProcessKind,
        repo: &str,
        label: &str,
        timeout_seconds: u64,
    ) {
        let managed = ManagedProcess {
            pid,
            kind,
            repo: repo.to_string(),
            label: label.to_string(),
            started_at: Instant::now(),
            timeout_seconds,
            state: ProcessState::Running,
            exit_code: None,
        };
        self.processes.lock().await.insert(pid, managed);
    }

    /// Records the exit of `pid`: stores `exit_code` and moves the process to
    /// `Finished` (code `0`) or `Crashed` (any other code).
    ///
    /// Unknown pids are ignored. The state is overwritten whatever it was
    /// before, so a process previously marked `TimedOut` ends up `Finished`
    /// or `Crashed` once its exit is reported.
    pub async fn report_exit(&self, pid: u32, exit_code: i32) {
        if let Some(proc) = self.processes.lock().await.get_mut(&pid) {
            proc.exit_code = Some(exit_code);
            proc.state = if exit_code == 0 {
                ProcessState::Finished
            } else {
                ProcessState::Crashed
            };
        }
    }

    /// Spawns a background task that periodically flags overdue processes.
    ///
    /// Every `MONITOR_CHECK_INTERVAL_SECONDS` the task marks each `Running`
    /// process whose elapsed whole seconds exceed its `timeout_seconds` as
    /// `TimedOut` and logs a warning. It only updates the record; nothing in
    /// this method signals or terminates the process itself.
    ///
    /// The loop has no exit condition and the task owns a clone of the `Arc`,
    /// so it keeps the manager alive until the task is stopped through the
    /// returned handle (e.g. `JoinHandle::abort`) or the runtime shuts down.
    pub fn start_monitor(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
        let pm = Arc::clone(self);
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(Duration::from_secs(MONITOR_CHECK_INTERVAL_SECONDS)).await;
                // The guard is taken after the sleep and dropped at the end of
                // the iteration, so the table is never locked while sleeping.
                let mut procs = pm.processes.lock().await;
                let now = Instant::now();
                for proc in procs
                    .values_mut()
                    .filter(|p| p.state == ProcessState::Running)
                {
                    let elapsed = now.duration_since(proc.started_at).as_secs();
                    if elapsed > proc.timeout_seconds {
                        proc.state = ProcessState::TimedOut;
                        tracing::warn!(
                            "Process [{}] pid={} timed out after {}s",
                            proc.label,
                            proc.pid,
                            elapsed
                        );
                    }
                }
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::DEFAULT_MAX_CONCURRENT_AGENTS;
    use serde_json::json;

    /// Builds a minimal `pull_request` event payload with the given fork flag
    /// and label names.
    fn pr_payload(is_fork: bool, label_names: &[&str]) -> serde_json::Value {
        let labels: Vec<serde_json::Value> = label_names
            .iter()
            .map(|&name| json!({ "name": name }))
            .collect();
        json!({
            "pull_request": {
                "head": { "repo": { "fork": is_fork } },
                "labels": labels
            }
        })
    }

    // ---- check_fork_pr_gate -------------------------------------------

    #[test]
    fn fork_gate_not_a_pr_event_returns_true() {
        let payload = json!({ "action": "opened", "issue": { "number": 1 } });
        assert!(check_fork_pr_gate(&payload, "coder"));
    }

    #[test]
    fn fork_gate_non_fork_pr_returns_true() {
        let payload = pr_payload(false, &[]);
        assert!(check_fork_pr_gate(&payload, "coder"));
    }

    #[test]
    fn fork_gate_fork_pr_with_approved_label_returns_true() {
        let payload = pr_payload(true, &["githubclaw-approved"]);
        assert!(check_fork_pr_gate(&payload, "coder"));
    }

    #[test]
    fn fork_gate_fork_pr_without_label_blocks() {
        let payload = pr_payload(true, &[]);
        assert!(!check_fork_pr_gate(&payload, "coder"));
    }

    #[test]
    fn fork_gate_fork_pr_security_reviewer_always_allowed() {
        let payload = pr_payload(true, &[]);
        assert!(check_fork_pr_gate(&payload, "security_reviewer"));
    }

    #[test]
    fn fork_gate_fork_pr_coder_blocked() {
        let payload = pr_payload(true, &["some-other-label"]);
        assert!(!check_fork_pr_gate(&payload, "coder"));
    }

    // ---- ProcessManager -----------------------------------------------

    #[test]
    fn process_manager_new_with_capacity() {
        let manager = ProcessManager::new(4);
        assert_eq!(manager.max_concurrent_agents, 4);
    }

    #[tokio::test]
    async fn active_count_starts_at_zero() {
        let manager = ProcessManager::new(DEFAULT_MAX_CONCURRENT_AGENTS);
        assert_eq!(manager.active_count().await, 0);
    }

    #[tokio::test]
    async fn has_capacity_returns_true_when_empty() {
        let manager = ProcessManager::new(DEFAULT_MAX_CONCURRENT_AGENTS);
        assert!(manager.has_capacity().await);
    }

    #[tokio::test]
    async fn register_adds_process_to_tracking() {
        let manager = ProcessManager::new(4);
        assert_eq!(manager.active_count().await, 0);

        manager
            .register(1001, ProcessKind::Worker, "owner/repo", "test-worker", 3600)
            .await;

        assert_eq!(manager.active_count().await, 1);
        // Exactly one entry, and it is the process just registered.
        let snapshot = manager.all_processes().await;
        assert_eq!(snapshot, vec![(1001, ProcessState::Running)]);
    }

    #[tokio::test]
    async fn report_exit_updates_state_success() {
        let manager = ProcessManager::new(4);
        manager
            .register(2001, ProcessKind::Orchestrator, "owner/repo", "orch", 3600)
            .await;

        manager.report_exit(2001, 0).await;

        let snapshot = manager.all_processes().await;
        let (_, state) = snapshot[0];
        assert_eq!(state, ProcessState::Finished);
        assert_eq!(manager.active_count().await, 0);
    }

    #[tokio::test]
    async fn report_exit_updates_state_crash() {
        let manager = ProcessManager::new(4);
        manager
            .register(3001, ProcessKind::Worker, "owner/repo", "worker", 3600)
            .await;

        manager.report_exit(3001, 1).await;

        let snapshot = manager.all_processes().await;
        let (_, state) = snapshot[0];
        assert_eq!(state, ProcessState::Crashed);
    }

    #[tokio::test]
    async fn has_capacity_respects_registered_processes() {
        let manager = ProcessManager::new(2);
        assert!(manager.has_capacity().await);

        // One of two slots used: still room.
        manager.register(100, ProcessKind::Worker, "r", "w1", 3600).await;
        assert!(manager.has_capacity().await);

        // Both slots used: full.
        manager.register(101, ProcessKind::Worker, "r", "w2", 3600).await;
        assert!(!manager.has_capacity().await);

        // A reported exit frees a slot again.
        manager.report_exit(100, 0).await;
        assert!(manager.has_capacity().await);
    }
}