Skip to main content

openhawk_core/
agent_manager.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::{Arc, Mutex};
4use std::time::{Duration, Instant};
5
6use rusqlite::{params, Connection};
7use tokio::process::{Child, Command};
8use uuid::Uuid;
9
10use crate::error::HawkError;
11use crate::manifest::AgentManifest;
12use crate::types::{AgentStatus, LifecycleState, ProcessId};
13
/// Result type returned by the agent manager, with [`HawkError`] as the error type.
pub type Result<T> = std::result::Result<T, HawkError>;
15
/// In-memory bookkeeping for one spawned agent, kept next to its [`Child`] handle in
/// the [`AgentManager`] agent table.
#[derive(Debug)]
pub struct AgentRecord {
    /// OS process id of the agent's child process (also its key in the agent table).
    pub pid: ProcessId,
    /// Agent name, copied from the manifest at spawn time.
    pub name: String,
    /// Most recent lifecycle state recorded by the manager.
    pub state: LifecycleState,
    /// Monotonic timestamp taken at spawn; `list` reports uptime from it.
    pub started_at: Instant,
    /// UUID of the `sessions` row created for this agent at spawn time.
    pub session_id: String,
    /// The manifest the agent was spawned from.
    pub manifest: AgentManifest,
}
25
/// Spawns agent processes and drives their lifecycle (stop, pause, resume, token-budget
/// enforcement), mirroring state changes into a SQLite database.
pub struct AgentManager {
    /// Shared SQLite connection, guarded by a blocking `std::sync::Mutex`.
    pub db: Arc<Mutex<Connection>>,
    /// Live agents keyed by PID: the child process handle plus its bookkeeping record.
    agents: Arc<Mutex<HashMap<ProcessId, (Child, AgentRecord)>>>,
}
30
31impl AgentManager {
32    pub fn new(db: Connection) -> Self {
33        Self {
34            db: Arc::new(Mutex::new(db)),
35            agents: Arc::new(Mutex::new(HashMap::new())),
36        }
37    }
38
39    pub async fn spawn(&self, manifest: AgentManifest) -> Result<ProcessId> {
40        if manifest.info.entry_command.trim().is_empty() {
41            return Err(HawkError::InvalidManifest("entry_command must not be empty".to_string()));
42        }
43
44        let session_id = Uuid::new_v4().to_string();
45        let parts: Vec<&str> = manifest.info.entry_command.split_whitespace().collect();
46        let (program, args) = parts.split_first().ok_or_else(|| {
47            HawkError::InvalidManifest("entry_command is empty".to_string())
48        })?;
49
50        let child = Command::new(program)
51            .args(args)
52            .spawn()
53            .map_err(HawkError::Io)?;
54
55        let pid = child.id().ok_or_else(|| {
56            HawkError::Io(std::io::Error::new(std::io::ErrorKind::Other, "could not get child PID"))
57        })?;
58
59        let record = AgentRecord {
60            pid,
61            name: manifest.info.name.clone(),
62            state: LifecycleState::Running,
63            started_at: Instant::now(),
64            session_id: session_id.clone(),
65            manifest: manifest.clone(),
66        };
67
68        {
69            let db = self.db.lock().unwrap();
70            db.execute(
71                "INSERT OR IGNORE INTO sessions (id, started_at, status) VALUES (?1, datetime('now'), 'Active')",
72                params![session_id],
73            )
74            .map_err(|e| HawkError::Database(e.to_string()))?;
75
76            db.execute(
77                "INSERT INTO agents (pid, name, state, manifest_path, started_at, session_id) \
78                 VALUES (?1, ?2, ?3, ?4, datetime('now'), ?5)",
79                params![pid, manifest.info.name, "Running", manifest.info.entry_command, session_id],
80            )
81            .map_err(|e| HawkError::Database(e.to_string()))?;
82        }
83
84        self.agents.lock().unwrap().insert(pid, (child, record));
85        Ok(pid)
86    }
87
88    pub async fn stop(&self, pid: ProcessId) -> Result<StopResult> {
89        if !self.agents.lock().unwrap().contains_key(&pid) {
90            return Err(HawkError::NotFound(format!("agent {pid}")));
91        }
92
93        self.update_state(pid, LifecycleState::Stopping)?;
94        send_term(pid)?;
95
96        let deadline = Instant::now() + Duration::from_secs(5);
97        let mut graceful = false;
98        while Instant::now() < deadline {
99            tokio::time::sleep(Duration::from_millis(100)).await;
100            let exited = {
101                let mut agents = self.agents.lock().unwrap();
102                if let Some((child, _)) = agents.get_mut(&pid) {
103                    child.try_wait().map_err(HawkError::Io)?.is_some()
104                } else {
105                    true
106                }
107            };
108            if exited {
109                graceful = true;
110                break;
111            }
112        }
113
114        if !graceful {
115            let mut agents = self.agents.lock().unwrap();
116            if let Some((child, _)) = agents.get_mut(&pid) {
117                child.kill().await.map_err(HawkError::Io)?;
118            }
119            drop(agents);
120            self.log_forced_termination(pid)?;
121        }
122
123        self.agents.lock().unwrap().remove(&pid);
124        self.update_state_db(pid, LifecycleState::Stopped)?;
125
126        Ok(StopResult { pid, forced: !graceful })
127    }
128
129    pub fn pause(&self, pid: ProcessId) -> Result<()> {
130        self.require_agent(pid)?;
131        send_stop(pid)?;
132        self.update_state(pid, LifecycleState::Paused)
133    }
134
135    pub fn resume(&self, pid: ProcessId) -> Result<()> {
136        self.require_agent(pid)?;
137        send_cont(pid)?;
138        self.update_state(pid, LifecycleState::Running)
139    }
140
141    pub fn list(&self) -> Vec<AgentStatus> {
142        self.agents
143            .lock()
144            .unwrap()
145            .values()
146            .map(|(_, rec)| AgentStatus {
147                pid: rec.pid,
148                name: rec.name.clone(),
149                state: rec.state.clone(),
150                uptime: rec.started_at.elapsed(),
151                cpu_percent: 0.0,
152                memory_bytes: 0,
153                open_fds: 0,
154            })
155            .collect()
156    }
157
158    pub fn get_state(&self, pid: ProcessId) -> Option<LifecycleState> {
159        self.agents.lock().unwrap().get(&pid).map(|(_, rec)| rec.state.clone())
160    }
161
162    /// Check if agent has exceeded its token budget; if so, pause it and return true.
163    pub fn enforce_budget(&self, pid: ProcessId) -> Result<bool> {
164        let budget = {
165            let agents = self.agents.lock().unwrap();
166            agents.get(&pid).map(|(_, rec)| rec.manifest.llm.budget_tokens)
167        };
168        let Some(budget) = budget else { return Ok(false) };
169        if budget == 0 {
170            return Ok(false);
171        }
172        let db = self.db.lock().unwrap();
173        let total: i64 = db
174            .query_row(
175                "SELECT COALESCE(SUM(prompt_tokens + completion_tokens), 0) \
176                 FROM token_usage WHERE agent_pid = ?1",
177                params![pid],
178                |row| row.get(0),
179            )
180            .map_err(|e| HawkError::Database(e.to_string()))?;
181        drop(db);
182        if total as u64 > budget {
183            send_stop(pid)?;
184            self.update_state(pid, LifecycleState::Paused)?;
185            return Ok(true);
186        }
187        Ok(false)
188    }
189
190    fn require_agent(&self, pid: ProcessId) -> Result<()> {
191        if self.agents.lock().unwrap().contains_key(&pid) {
192            Ok(())
193        } else {
194            Err(HawkError::NotFound(format!("agent {pid}")))
195        }
196    }
197
198    fn update_state(&self, pid: ProcessId, state: LifecycleState) -> Result<()> {
199        if let Some((_, rec)) = self.agents.lock().unwrap().get_mut(&pid) {
200            rec.state = state.clone();
201        }
202        self.update_state_db(pid, state)
203    }
204
205    fn update_state_db(&self, pid: ProcessId, state: LifecycleState) -> Result<()> {
206        let state_str = lifecycle_str(&state);
207        self.db
208            .lock()
209            .unwrap()
210            .execute("UPDATE agents SET state = ?1 WHERE pid = ?2", params![state_str, pid])
211            .map_err(|e| HawkError::Database(e.to_string()))?;
212        Ok(())
213    }
214
215    fn log_forced_termination(&self, pid: ProcessId) -> Result<()> {
216        self.db
217            .lock()
218            .unwrap()
219            .execute(
220                "INSERT INTO healing_events \
221                 (agent_pid, timestamp, original_error, adjustment, outcome, attempt_number) \
222                 VALUES (?1, datetime('now'), 'graceful stop timeout', 'force kill', 'Success', 1)",
223                params![pid],
224            )
225            .map_err(|e| HawkError::Database(e.to_string()))?;
226        Ok(())
227    }
228}
229
230fn lifecycle_str(s: &LifecycleState) -> &'static str {
231    match s {
232        LifecycleState::Starting => "Starting",
233        LifecycleState::Running => "Running",
234        LifecycleState::Paused => "Paused",
235        LifecycleState::Stopping => "Stopping",
236        LifecycleState::Stopped => "Stopped",
237        LifecycleState::Failed => "Failed",
238    }
239}
240
/// Outcome of stopping an agent with `AgentManager::stop`.
#[derive(Debug)]
pub struct StopResult {
    /// PID of the agent that was stopped.
    pub pid: ProcessId,
    /// `true` if the agent did not exit within `stop`'s grace period after the
    /// termination signal and had to be force-killed. `false` if it exited on its own.
    pub forced: bool,
}
246
247// ── platform signal helpers ───────────────────────────────────────────────────
248
249#[cfg(target_family = "unix")]
250fn send_term(pid: ProcessId) -> Result<()> {
251    use nix::sys::signal::{kill, Signal};
252    use nix::unistd::Pid;
253    kill(Pid::from_raw(pid as i32), Signal::SIGTERM)
254        .map_err(|e| HawkError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))
255}
256
257#[cfg(target_family = "unix")]
258fn send_stop(pid: ProcessId) -> Result<()> {
259    use nix::sys::signal::{kill, Signal};
260    use nix::unistd::Pid;
261    kill(Pid::from_raw(pid as i32), Signal::SIGSTOP)
262        .map_err(|e| HawkError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))
263}
264
265#[cfg(target_family = "unix")]
266fn send_cont(pid: ProcessId) -> Result<()> {
267    use nix::sys::signal::{kill, Signal};
268    use nix::unistd::Pid;
269    kill(Pid::from_raw(pid as i32), Signal::SIGCONT)
270        .map_err(|e| HawkError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))
271}
272
// Non-unix fallbacks: no signal is delivered and every call reports success. As a
// result, `AgentManager::stop` ends up force-killing after its grace period, and
// pause/resume only change the recorded state.

/// Non-unix stand-in for SIGTERM; does nothing.
#[cfg(not(target_family = "unix"))]
fn send_term(_pid: ProcessId) -> Result<()> {
    Result::Ok(())
}

/// Non-unix stand-in for SIGSTOP; does nothing.
#[cfg(not(target_family = "unix"))]
fn send_stop(_pid: ProcessId) -> Result<()> {
    Result::Ok(())
}

/// Non-unix stand-in for SIGCONT; does nothing.
#[cfg(not(target_family = "unix"))]
fn send_cont(_pid: ProcessId) -> Result<()> {
    Result::Ok(())
}
279
280pub fn snapshot_dir() -> PathBuf {
281    dirs_next::data_local_dir()
282        .unwrap_or_else(|| PathBuf::from("."))
283        .join("hawk")
284        .join("snapshots")
285}