1use anyhow::Result;
2use rusqlite::Connection;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::{broadcast, Mutex};
6
7use crate::db::models::{Event, HookPayload, Snapshot};
8use crate::db::queries;
9#[cfg(test)]
10use crate::db::schema;
11
/// Tool names treated as file-modifying.
///
/// On a `PreToolUse` event for one of these tools, `EventProcessor::process`
/// looks up `file_path` in the tool input and reads the file's current
/// contents so a before/after snapshot can be recorded later.
const FILE_MODIFYING_TOOLS: &[&str] = &["Edit", "Write"];
13
/// Pre-image of a file captured on `PreToolUse`, kept until a matching
/// follow-up event (or a detected on-disk change) turns it into a snapshot.
struct PendingSnapshot {
    /// Path of the file the tool is about to modify, as given in the tool input.
    file_path: String,
    /// File contents read at `PreToolUse` time; `None` if the read failed
    /// (for example because the file did not exist yet).
    content_before: Option<Vec<u8>>,
    /// Database id of the `PreToolUse` event that created this entry.
    event_id: i64,
    /// When the entry was created; used to evict entries that were never resolved.
    inserted_at: std::time::Instant,
}
20
21pub struct EventProcessor {
22 conn: Arc<Mutex<Connection>>,
23 pending_pre: HashMap<String, PendingSnapshot>,
24 broadcast_tx: broadcast::Sender<Event>,
25}
26
27impl EventProcessor {
28 pub fn new(conn: Arc<Mutex<Connection>>, broadcast_tx: broadcast::Sender<Event>) -> Self {
29 Self {
30 conn,
31 pending_pre: HashMap::new(),
32 broadcast_tx,
33 }
34 }
35
36 pub async fn process(&mut self, payload: HookPayload) -> Result<()> {
37 self.flush_pending_snapshots().await?;
38
39 let event_type = payload.hook_event_name.as_deref().unwrap_or("Unknown");
40 let session_id = payload.session_id.as_deref().unwrap_or("unknown");
41 let now = chrono::Utc::now().timestamp_millis();
42
43 match event_type {
44 "SessionStart" => {
45 let session = crate::db::models::Session {
46 id: session_id.to_string(),
47 started_at: now,
48 ended_at: None,
49 cwd: payload.cwd.unwrap_or_default(),
50 model: payload.model,
51 permission_mode: payload.permission_mode,
52 };
53 let db = self.conn.lock().await;
54 queries::upsert_session(&db, &session)?;
55 }
56 "SessionEnd" => {
57 let db = self.conn.lock().await;
58 db.execute(
59 "UPDATE sessions SET ended_at = ?1 WHERE id = ?2",
60 rusqlite::params![now, session_id],
61 )?;
62 }
63 "PreToolUse" => {
64 let tool_name = payload.tool_name.as_deref().unwrap_or("");
65 let tool_use_id = payload.tool_use_id.clone();
66
67 let file_path = if FILE_MODIFYING_TOOLS.contains(&tool_name) {
69 self.extract_file_path(&payload)
70 } else {
71 None
72 };
73 let content_before = file_path.as_ref().and_then(|p| std::fs::read(p).ok());
74
75 let event = self.payload_to_event(&payload, session_id, now);
76 let db = self.conn.lock().await;
77 let event_id = queries::insert_event(&db, &event)?;
78 drop(db);
79
80 if let (Some(tuid), Some(path)) = (tool_use_id, file_path) {
82 self.pending_pre.insert(
83 tuid,
84 PendingSnapshot {
85 file_path: path,
86 content_before,
87 event_id,
88 inserted_at: std::time::Instant::now(),
89 },
90 );
91 }
92
93 let mut stored = event;
94 stored.id = event_id;
95 let _ = self.broadcast_tx.send(stored);
96 }
97 "PostToolUse" => {
98 let tool_use_id = payload.tool_use_id.clone();
99
100 let snapshot = if let Some(ref tuid) = tool_use_id {
102 if let Some(pending) = self.pending_pre.remove(tuid) {
103 let content_after = std::fs::read(&pending.file_path).ok();
104 let diff = self.compute_diff(
105 pending.content_before.as_deref(),
106 content_after.as_deref(),
107 );
108 let compressed_before = pending
109 .content_before
110 .map(|c| zstd::encode_all(c.as_slice(), 3))
111 .transpose()?;
112 let compressed_after = content_after
113 .map(|c| zstd::encode_all(c.as_slice(), 3))
114 .transpose()?;
115 Some((pending.file_path, compressed_before, compressed_after, diff))
116 } else {
117 None
118 }
119 } else {
120 None
121 };
122
123 let event = self.payload_to_event(&payload, session_id, now);
124 let db = self.conn.lock().await;
125 let event_id = queries::insert_event(&db, &event)?;
126
127 if let Some((path, compressed_before, compressed_after, diff)) = snapshot {
128 let snap = Snapshot {
129 id: 0,
130 event_id,
131 file_path: path,
132 content_before: compressed_before,
133 content_after: compressed_after,
134 diff_unified: diff,
135 };
136 queries::insert_snapshot(&db, &snap)?;
137 }
138 drop(db);
139
140 let mut stored = event;
141 stored.id = event_id;
142 let _ = self.broadcast_tx.send(stored);
143 }
144 "PostToolUseFailure" => {
145 if let Some(ref tuid) = payload.tool_use_id {
146 self.pending_pre.remove(tuid);
147 }
148 let event = self.payload_to_event(&payload, session_id, now);
149 let db = self.conn.lock().await;
150 let event_id = queries::insert_event(&db, &event)?;
151 drop(db);
152 let mut stored = event;
153 stored.id = event_id;
154 let _ = self.broadcast_tx.send(stored);
155 }
156 _ => {
157 let session = crate::db::models::Session {
158 id: session_id.to_string(),
159 started_at: now,
160 ended_at: None,
161 cwd: payload.cwd.clone().unwrap_or_default(),
162 model: None,
163 permission_mode: None,
164 };
165 let event = self.payload_to_event(&payload, session_id, now);
166 let db = self.conn.lock().await;
167 queries::upsert_session(&db, &session)?;
168 let event_id = queries::insert_event(&db, &event)?;
169 drop(db);
170 let mut stored = event;
171 stored.id = event_id;
172 let _ = self.broadcast_tx.send(stored);
173 }
174 }
175
176 Ok(())
177 }
178
179 fn payload_to_event(&self, payload: &HookPayload, session_id: &str, timestamp: i64) -> Event {
180 Event {
181 id: 0,
182 session_id: session_id.to_string(),
183 timestamp,
184 event_type: payload.hook_event_name.clone().unwrap_or_default(),
185 tool_name: payload.tool_name.clone(),
186 tool_use_id: payload.tool_use_id.clone(),
187 agent_id: payload.agent_id.clone(),
188 agent_type: payload.agent_type.clone(),
189 input_json: payload
190 .tool_input
191 .as_ref()
192 .map(|v| serde_json::to_vec(v))
193 .transpose()
194 .ok()
195 .flatten(),
196 output_json: payload
197 .tool_response
198 .as_ref()
199 .map(|s| s.as_bytes().to_vec())
200 .or_else(|| payload.tool_error.as_ref().map(|s| s.as_bytes().to_vec()))
201 .or_else(|| payload.prompt.as_ref().map(|s| s.as_bytes().to_vec()))
202 .or_else(|| {
203 payload
204 .last_assistant_message
205 .as_ref()
206 .map(|s| s.as_bytes().to_vec())
207 }),
208 }
209 }
210
211 fn extract_file_path(&self, payload: &HookPayload) -> Option<String> {
212 payload
213 .tool_input
214 .as_ref()
215 .and_then(|v| v.get("file_path"))
216 .and_then(|v| v.as_str())
217 .map(|s| s.to_string())
218 }
219
220 fn compute_diff(&self, before: Option<&[u8]>, after: Option<&[u8]>) -> String {
221 let before_str = before
222 .map(|b| String::from_utf8_lossy(b).to_string())
223 .unwrap_or_default();
224 let after_str = after
225 .map(|b| String::from_utf8_lossy(b).to_string())
226 .unwrap_or_default();
227 let diff = similar::TextDiff::from_lines(&before_str, &after_str);
228 diff.unified_diff().header("a", "b").to_string()
229 }
230
231 async fn flush_pending_snapshots(&mut self) -> Result<()> {
232 let flushed: Vec<String> = self
233 .pending_pre
234 .iter()
235 .filter_map(|(tuid, pending)| {
236 let content_after = std::fs::read(&pending.file_path).ok();
237 if content_after != pending.content_before {
238 Some(tuid.clone())
239 } else {
240 None
241 }
242 })
243 .collect();
244
245 for tuid in flushed {
246 let pending = self.pending_pre.remove(&tuid).unwrap();
247 let content_after = std::fs::read(&pending.file_path).ok();
248 let diff = self.compute_diff(
249 pending.content_before.as_deref(),
250 content_after.as_deref(),
251 );
252 let compressed_before = pending
253 .content_before
254 .map(|c| zstd::encode_all(c.as_slice(), 3))
255 .transpose()?;
256 let compressed_after = content_after
257 .map(|c| zstd::encode_all(c.as_slice(), 3))
258 .transpose()?;
259
260 let db = self.conn.lock().await;
261 let snap = Snapshot {
262 id: 0,
263 event_id: pending.event_id,
264 file_path: pending.file_path,
265 content_before: compressed_before,
266 content_after: compressed_after,
267 diff_unified: diff,
268 };
269 queries::insert_snapshot(&db, &snap)?;
270 }
271
272 Ok(())
273 }
274
275 pub fn evict_stale_entries(&mut self) {
276 let ttl = std::time::Duration::from_secs(10 * 60);
277 let now = std::time::Instant::now();
278 self.pending_pre
279 .retain(|_id, pending| now.duration_since(pending.inserted_at) < ttl);
280 }
281
282 pub fn clear_pending(&mut self) {
283 self.pending_pre.clear();
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// A payload with every field unset; tests override only what they need
    /// via struct-update syntax.
    fn default_payload() -> HookPayload {
        HookPayload {
            session_id: None,
            hook_event_name: None,
            cwd: None,
            permission_mode: None,
            model: None,
            agent_id: None,
            agent_type: None,
            tool_name: None,
            tool_input: None,
            tool_use_id: None,
            tool_response: None,
            tool_error: None,
            prompt: None,
            last_assistant_message: None,
            source: None,
        }
    }

    /// Returns an initialized in-memory database and a broadcast sender.
    fn setup() -> (Arc<Mutex<Connection>>, broadcast::Sender<Event>) {
        let raw = Connection::open_in_memory().unwrap();
        schema::initialize(&raw).unwrap();
        let (tx, _rx) = broadcast::channel(100);
        (Arc::new(Mutex::new(raw)), tx)
    }

    #[tokio::test]
    async fn test_session_start_creates_session() {
        let (conn, tx) = setup();
        let mut processor = EventProcessor::new(Arc::clone(&conn), tx);

        let session_start = HookPayload {
            session_id: Some("sess1".into()),
            hook_event_name: Some("SessionStart".into()),
            cwd: Some("/tmp/project".into()),
            model: Some("claude-sonnet".into()),
            permission_mode: Some("default".into()),
            ..default_payload()
        };
        processor.process(session_start).await.unwrap();

        let db = conn.lock().await;
        let sessions = queries::list_sessions(&db).unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].id, "sess1");
    }

    #[tokio::test]
    async fn test_post_tool_use_creates_event() {
        let (conn, tx) = setup();
        let mut processor = EventProcessor::new(Arc::clone(&conn), tx);

        // The session is created first; SessionStart itself stores no event row.
        let session_start = HookPayload {
            session_id: Some("sess1".into()),
            hook_event_name: Some("SessionStart".into()),
            cwd: Some("/tmp".into()),
            ..default_payload()
        };
        processor.process(session_start).await.unwrap();

        let post_tool_use = HookPayload {
            session_id: Some("sess1".into()),
            hook_event_name: Some("PostToolUse".into()),
            tool_name: Some("Read".into()),
            tool_use_id: Some("tu1".into()),
            ..default_payload()
        };
        processor.process(post_tool_use).await.unwrap();

        let db = conn.lock().await;
        let events = queries::list_events_for_session(&db, "sess1").unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].tool_name.as_deref(), Some("Read"));
    }
}