1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::{Arc, Mutex};
4use std::time::{Duration, Instant};
5
6use rusqlite::{params, Connection};
7use tokio::process::{Child, Command};
8use uuid::Uuid;
9
10use crate::error::HawkError;
11use crate::manifest::AgentManifest;
12use crate::types::{AgentStatus, LifecycleState, ProcessId};
13
/// Result alias used by this module, with [`HawkError`] as the error type.
pub type Result<T> = std::result::Result<T, HawkError>;
15
/// In-memory bookkeeping for one agent process spawned by `AgentManager::spawn`.
#[derive(Debug)]
pub struct AgentRecord {
    /// OS process id, as reported by the child handle right after spawning.
    pub pid: ProcessId,
    /// Agent name, copied from `manifest.info.name`.
    pub name: String,
    /// Current lifecycle state; changed through `AgentManager::update_state`.
    pub state: LifecycleState,
    /// Monotonic timestamp taken at spawn; `AgentManager::list` reports `elapsed()` as uptime.
    pub started_at: Instant,
    /// Session id (a UUID v4 string) generated for this spawn and written to `sessions`.
    pub session_id: String,
    /// The manifest the agent was spawned from; `enforce_budget` reads `llm.budget_tokens`.
    pub manifest: AgentManifest,
}
25
/// Spawns, tracks, signals and persists agent processes.
pub struct AgentManager {
    /// Shared SQLite connection used for the `sessions`, `agents`, `token_usage` and
    /// `healing_events` tables.
    pub db: Arc<Mutex<Connection>>,
    /// Live agents keyed by PID: the tokio child handle plus its bookkeeping record.
    agents: Arc<Mutex<HashMap<ProcessId, (Child, AgentRecord)>>>,
}
30
31impl AgentManager {
32 pub fn new(db: Connection) -> Self {
33 Self {
34 db: Arc::new(Mutex::new(db)),
35 agents: Arc::new(Mutex::new(HashMap::new())),
36 }
37 }
38
39 pub async fn spawn(&self, manifest: AgentManifest) -> Result<ProcessId> {
40 if manifest.info.entry_command.trim().is_empty() {
41 return Err(HawkError::InvalidManifest("entry_command must not be empty".to_string()));
42 }
43
44 let session_id = Uuid::new_v4().to_string();
45 let parts: Vec<&str> = manifest.info.entry_command.split_whitespace().collect();
46 let (program, args) = parts.split_first().ok_or_else(|| {
47 HawkError::InvalidManifest("entry_command is empty".to_string())
48 })?;
49
50 let child = Command::new(program)
51 .args(args)
52 .spawn()
53 .map_err(HawkError::Io)?;
54
55 let pid = child.id().ok_or_else(|| {
56 HawkError::Io(std::io::Error::new(std::io::ErrorKind::Other, "could not get child PID"))
57 })?;
58
59 let record = AgentRecord {
60 pid,
61 name: manifest.info.name.clone(),
62 state: LifecycleState::Running,
63 started_at: Instant::now(),
64 session_id: session_id.clone(),
65 manifest: manifest.clone(),
66 };
67
68 {
69 let db = self.db.lock().unwrap();
70 db.execute(
71 "INSERT OR IGNORE INTO sessions (id, started_at, status) VALUES (?1, datetime('now'), 'Active')",
72 params![session_id],
73 )
74 .map_err(|e| HawkError::Database(e.to_string()))?;
75
76 db.execute(
77 "INSERT INTO agents (pid, name, state, manifest_path, started_at, session_id) \
78 VALUES (?1, ?2, ?3, ?4, datetime('now'), ?5)",
79 params![pid, manifest.info.name, "Running", manifest.info.entry_command, session_id],
80 )
81 .map_err(|e| HawkError::Database(e.to_string()))?;
82 }
83
84 self.agents.lock().unwrap().insert(pid, (child, record));
85 Ok(pid)
86 }
87
88 pub async fn stop(&self, pid: ProcessId) -> Result<StopResult> {
89 if !self.agents.lock().unwrap().contains_key(&pid) {
90 return Err(HawkError::NotFound(format!("agent {pid}")));
91 }
92
93 self.update_state(pid, LifecycleState::Stopping)?;
94 send_term(pid)?;
95
96 let deadline = Instant::now() + Duration::from_secs(5);
97 let mut graceful = false;
98 while Instant::now() < deadline {
99 tokio::time::sleep(Duration::from_millis(100)).await;
100 let exited = {
101 let mut agents = self.agents.lock().unwrap();
102 if let Some((child, _)) = agents.get_mut(&pid) {
103 child.try_wait().map_err(HawkError::Io)?.is_some()
104 } else {
105 true
106 }
107 };
108 if exited {
109 graceful = true;
110 break;
111 }
112 }
113
114 if !graceful {
115 let mut agents = self.agents.lock().unwrap();
116 if let Some((child, _)) = agents.get_mut(&pid) {
117 child.kill().await.map_err(HawkError::Io)?;
118 }
119 drop(agents);
120 self.log_forced_termination(pid)?;
121 }
122
123 self.agents.lock().unwrap().remove(&pid);
124 self.update_state_db(pid, LifecycleState::Stopped)?;
125
126 Ok(StopResult { pid, forced: !graceful })
127 }
128
129 pub fn pause(&self, pid: ProcessId) -> Result<()> {
130 self.require_agent(pid)?;
131 send_stop(pid)?;
132 self.update_state(pid, LifecycleState::Paused)
133 }
134
135 pub fn resume(&self, pid: ProcessId) -> Result<()> {
136 self.require_agent(pid)?;
137 send_cont(pid)?;
138 self.update_state(pid, LifecycleState::Running)
139 }
140
141 pub fn list(&self) -> Vec<AgentStatus> {
142 self.agents
143 .lock()
144 .unwrap()
145 .values()
146 .map(|(_, rec)| AgentStatus {
147 pid: rec.pid,
148 name: rec.name.clone(),
149 state: rec.state.clone(),
150 uptime: rec.started_at.elapsed(),
151 cpu_percent: 0.0,
152 memory_bytes: 0,
153 open_fds: 0,
154 })
155 .collect()
156 }
157
158 pub fn get_state(&self, pid: ProcessId) -> Option<LifecycleState> {
159 self.agents.lock().unwrap().get(&pid).map(|(_, rec)| rec.state.clone())
160 }
161
162 pub fn enforce_budget(&self, pid: ProcessId) -> Result<bool> {
164 let budget = {
165 let agents = self.agents.lock().unwrap();
166 agents.get(&pid).map(|(_, rec)| rec.manifest.llm.budget_tokens)
167 };
168 let Some(budget) = budget else { return Ok(false) };
169 if budget == 0 {
170 return Ok(false);
171 }
172 let db = self.db.lock().unwrap();
173 let total: i64 = db
174 .query_row(
175 "SELECT COALESCE(SUM(prompt_tokens + completion_tokens), 0) \
176 FROM token_usage WHERE agent_pid = ?1",
177 params![pid],
178 |row| row.get(0),
179 )
180 .map_err(|e| HawkError::Database(e.to_string()))?;
181 drop(db);
182 if total as u64 > budget {
183 send_stop(pid)?;
184 self.update_state(pid, LifecycleState::Paused)?;
185 return Ok(true);
186 }
187 Ok(false)
188 }
189
190 fn require_agent(&self, pid: ProcessId) -> Result<()> {
191 if self.agents.lock().unwrap().contains_key(&pid) {
192 Ok(())
193 } else {
194 Err(HawkError::NotFound(format!("agent {pid}")))
195 }
196 }
197
198 fn update_state(&self, pid: ProcessId, state: LifecycleState) -> Result<()> {
199 if let Some((_, rec)) = self.agents.lock().unwrap().get_mut(&pid) {
200 rec.state = state.clone();
201 }
202 self.update_state_db(pid, state)
203 }
204
205 fn update_state_db(&self, pid: ProcessId, state: LifecycleState) -> Result<()> {
206 let state_str = lifecycle_str(&state);
207 self.db
208 .lock()
209 .unwrap()
210 .execute("UPDATE agents SET state = ?1 WHERE pid = ?2", params![state_str, pid])
211 .map_err(|e| HawkError::Database(e.to_string()))?;
212 Ok(())
213 }
214
215 fn log_forced_termination(&self, pid: ProcessId) -> Result<()> {
216 self.db
217 .lock()
218 .unwrap()
219 .execute(
220 "INSERT INTO healing_events \
221 (agent_pid, timestamp, original_error, adjustment, outcome, attempt_number) \
222 VALUES (?1, datetime('now'), 'graceful stop timeout', 'force kill', 'Success', 1)",
223 params![pid],
224 )
225 .map_err(|e| HawkError::Database(e.to_string()))?;
226 Ok(())
227 }
228}
229
230fn lifecycle_str(s: &LifecycleState) -> &'static str {
231 match s {
232 LifecycleState::Starting => "Starting",
233 LifecycleState::Running => "Running",
234 LifecycleState::Paused => "Paused",
235 LifecycleState::Stopping => "Stopping",
236 LifecycleState::Stopped => "Stopped",
237 LifecycleState::Failed => "Failed",
238 }
239}
240
241#[derive(Debug)]
242pub struct StopResult {
243 pub pid: ProcessId,
244 pub forced: bool,
245}
246
247#[cfg(target_family = "unix")]
250fn send_term(pid: ProcessId) -> Result<()> {
251 use nix::sys::signal::{kill, Signal};
252 use nix::unistd::Pid;
253 kill(Pid::from_raw(pid as i32), Signal::SIGTERM)
254 .map_err(|e| HawkError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))
255}
256
257#[cfg(target_family = "unix")]
258fn send_stop(pid: ProcessId) -> Result<()> {
259 use nix::sys::signal::{kill, Signal};
260 use nix::unistd::Pid;
261 kill(Pid::from_raw(pid as i32), Signal::SIGSTOP)
262 .map_err(|e| HawkError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))
263}
264
265#[cfg(target_family = "unix")]
266fn send_cont(pid: ProcessId) -> Result<()> {
267 use nix::sys::signal::{kill, Signal};
268 use nix::unistd::Pid;
269 kill(Pid::from_raw(pid as i32), Signal::SIGCONT)
270 .map_err(|e| HawkError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))
271}
272
// Non-Unix targets have no POSIX signals. These stubs do nothing and report success, so
// pause/resume only change the recorded state there. An agent that does not exit on its
// own is eventually force-killed by `stop`.

/// No-op stand-in for SIGTERM on non-Unix targets.
#[cfg(not(target_family = "unix"))]
fn send_term(_pid: ProcessId) -> Result<()> {
    Ok(())
}

/// No-op stand-in for SIGSTOP on non-Unix targets.
#[cfg(not(target_family = "unix"))]
fn send_stop(_pid: ProcessId) -> Result<()> {
    Ok(())
}

/// No-op stand-in for SIGCONT on non-Unix targets.
#[cfg(not(target_family = "unix"))]
fn send_cont(_pid: ProcessId) -> Result<()> {
    Ok(())
}
279
280pub fn snapshot_dir() -> PathBuf {
281 dirs_next::data_local_dir()
282 .unwrap_or_else(|| PathBuf::from("."))
283 .join("hawk")
284 .join("snapshots")
285}