Skip to main content

aft/bash_background/
watches.rs

1use std::collections::{HashMap, HashSet};
2use std::fs::File;
3use std::io::{Read, Seek, SeekFrom};
4use std::path::Path;
5
6use regex::{Regex, RegexBuilder};
7use serde::Serialize;
8
/// Upper bound on simultaneously registered watches for one task;
/// `WatchRegistry::register` rejects further registrations once it is reached.
const MAX_WATCHES_PER_TASK: usize = 8;
/// Size, in characters, of the leading context window used by `context_snippet`.
const CONTEXT_BEFORE: usize = 100;
/// Size, in characters, of the trailing context window used by `context_snippet`.
const CONTEXT_AFTER: usize = 500;
/// Maximum number of trailing bytes of scanned data retained per cursor key so
/// that a pattern split across two scans can still be matched.
const SCAN_OVERLAP_BYTES: usize = 8 * 1024;
13
/// One registered pattern watch attached to a single task.
#[derive(Debug, Clone)]
pub struct WatchSpec {
    /// Identifier handed back by `WatchRegistry::register` (`watch-` followed by hex digits).
    pub watch_id: String,
    /// Task whose output this watch is applied to.
    pub task_id: String,
    /// What to look for in the task's output.
    pub pattern: WatchPattern,
    /// When `true`, the watch is unregistered after its first match.
    pub once: bool,
}
21
/// The two kinds of pattern a watch can search for.
#[derive(Debug, Clone)]
pub enum WatchPattern {
    /// Literal text search; an empty needle never matches (see `find_match`).
    Substring(String),
    /// Compiled regular expression; `WatchPattern::regex` builds one in multi-line mode.
    Regex(Regex),
}
27
28impl WatchPattern {
29    pub fn regex(pattern: &str) -> Result<Self, regex::Error> {
30        RegexBuilder::new(pattern)
31            .multi_line(true)
32            .build()
33            .map(Self::Regex)
34    }
35}
36
/// A single hit produced by scanning task output against a registered watch.
#[derive(Debug, Clone, Serialize)]
pub struct PatternMatch {
    /// Identifier of the watch that fired.
    pub watch_id: String,
    /// Task the watch belongs to.
    pub task_id: String,
    /// The exact text that matched the pattern.
    pub match_text: String,
    /// Offset of the match start, computed as the scan's base offset plus the
    /// match's start index in the scanned text.
    pub match_offset: u64,
    /// Text surrounding the match (see `context_snippet`), with `\r` removed.
    pub context: String,
    /// Copy of the watch's `once` flag at the time it fired.
    pub once: bool,
}
46
/// Bookkeeping for all pattern watches and the incremental scan state that
/// goes with them.
#[derive(Debug, Default)]
pub struct WatchRegistry {
    /// Active watches, keyed by task id.
    watches: HashMap<String, Vec<WatchSpec>>,
    /// Offset at which the next scan starts, keyed by cursor key
    /// (`scan_new_bytes` uses the task id itself as the key).
    scan_cursors: HashMap<String, u64>,
    /// Trailing bytes of already scanned data, keyed by cursor key, prepended to
    /// the next scan so patterns can span scan boundaries.
    scan_overlaps: HashMap<String, Vec<u8>>,
    /// Tasks that have had a watch registered and have not been cleared since.
    controlled_tasks: HashSet<String>,
    /// Tasks for which at least one watch has matched since the last clear.
    matched_tasks: HashSet<String>,
    /// Counter used to mint watch ids; incremented (wrapping) on every registration.
    next_watch: u64,
}
56
57impl WatchRegistry {
58    pub fn register(
59        &mut self,
60        task_id: String,
61        pattern: WatchPattern,
62        once: bool,
63    ) -> Result<String, &'static str> {
64        let watches = self.watches.entry(task_id.clone()).or_default();
65        if watches.len() >= MAX_WATCHES_PER_TASK {
66            return Err("too_many_watches");
67        }
68        self.controlled_tasks.insert(task_id.clone());
69        self.next_watch = self.next_watch.wrapping_add(1);
70        let watch_id = format!("watch-{:08x}", self.next_watch);
71        watches.push(WatchSpec {
72            watch_id: watch_id.clone(),
73            task_id,
74            pattern,
75            once,
76        });
77        Ok(watch_id)
78    }
79
80    pub fn unregister(&mut self, task_id: &str, watch_id: &str) {
81        if let Some(watches) = self.watches.get_mut(task_id) {
82            watches.retain(|watch| watch.watch_id != watch_id);
83            if watches.is_empty() {
84                self.watches.remove(task_id);
85            }
86        }
87    }
88
89    pub fn clear_task(&mut self, task_id: &str) {
90        self.watches.remove(task_id);
91        self.controlled_tasks.remove(task_id);
92        self.matched_tasks.remove(task_id);
93        let prefix = format!("{task_id}:");
94        self.scan_cursors
95            .retain(|key, _| key != task_id && !key.starts_with(&prefix));
96        self.scan_overlaps
97            .retain(|key, _| key != task_id && !key.starts_with(&prefix));
98    }
99
100    pub fn has_controlled_task(&self, task_id: &str) -> bool {
101        self.controlled_tasks.contains(task_id)
102    }
103
104    pub fn has_matched_task(&self, task_id: &str) -> bool {
105        self.matched_tasks.contains(task_id)
106    }
107
108    pub fn active_count(&self, task_id: &str) -> usize {
109        self.watches.get(task_id).map_or(0, Vec::len)
110    }
111
112    pub fn prime_file_cursor(&mut self, cursor_key: &str, path: &Path) {
113        if self.scan_cursors.contains_key(cursor_key) {
114            return;
115        }
116        let len = File::open(path)
117            .and_then(|file| file.metadata())
118            .map(|metadata| metadata.len())
119            .unwrap_or(0);
120        self.scan_cursors.insert(cursor_key.to_string(), len);
121    }
122
123    pub fn set_file_cursor(&mut self, cursor_key: &str, offset: u64) {
124        self.scan_cursors.insert(cursor_key.to_string(), offset);
125        self.scan_overlaps.remove(cursor_key);
126    }
127
128    pub fn scan_file_new_bytes(
129        &mut self,
130        cursor_key: &str,
131        task_id: &str,
132        path: &Path,
133    ) -> Vec<PatternMatch> {
134        if self.active_count(task_id) == 0 {
135            return Vec::new();
136        }
137        let Ok(mut file) = File::open(path) else {
138            return Vec::new();
139        };
140        let cursor = self
141            .scan_cursors
142            .get(cursor_key)
143            .copied()
144            .unwrap_or_else(|| {
145                // Start at current EOF so a newly registered watch does not match old spill content.
146                file.metadata().map(|m| m.len()).unwrap_or(0)
147            });
148        if file.seek(SeekFrom::Start(cursor)).is_err() {
149            return Vec::new();
150        }
151        let mut bytes = Vec::new();
152        if file.read_to_end(&mut bytes).is_err() || bytes.is_empty() {
153            self.scan_cursors.insert(cursor_key.to_string(), cursor);
154            return Vec::new();
155        }
156        let next = cursor.saturating_add(bytes.len() as u64);
157        self.scan_cursors.insert(cursor_key.to_string(), next);
158        self.scan_new_bytes_at(cursor_key, task_id, &bytes, cursor)
159    }
160
161    pub fn scan_new_bytes(&mut self, task_id: &str, bytes: &[u8]) -> Vec<PatternMatch> {
162        let base = self.scan_cursors.get(task_id).copied().unwrap_or(0);
163        self.scan_cursors
164            .insert(task_id.to_string(), base.saturating_add(bytes.len() as u64));
165        self.scan_new_bytes_at(task_id, task_id, bytes, base)
166    }
167
168    fn scan_new_bytes_at(
169        &mut self,
170        cursor_key: &str,
171        task_id: &str,
172        bytes: &[u8],
173        base_offset: u64,
174    ) -> Vec<PatternMatch> {
175        let Some(watches) = self.watches.get(task_id).cloned() else {
176            return Vec::new();
177        };
178        let overlap = self
179            .scan_overlaps
180            .get(cursor_key)
181            .cloned()
182            .unwrap_or_default();
183        let prefix_len = overlap.len();
184        let mut scan_bytes = Vec::with_capacity(prefix_len.saturating_add(bytes.len()));
185        scan_bytes.extend_from_slice(&overlap);
186        scan_bytes.extend_from_slice(bytes);
187        let text = String::from_utf8_lossy(&scan_bytes);
188        let scan_base_offset = base_offset.saturating_sub(prefix_len as u64);
189        let mut matches = Vec::new();
190        let mut remove_once = Vec::new();
191        for watch in watches {
192            if let Some((start, end, matched)) = find_match(&watch.pattern, &text, prefix_len) {
193                self.matched_tasks.insert(task_id.to_string());
194                matches.push(PatternMatch {
195                    watch_id: watch.watch_id.clone(),
196                    task_id: watch.task_id.clone(),
197                    match_text: matched,
198                    match_offset: scan_base_offset.saturating_add(start as u64),
199                    context: context_snippet(&text, start, end),
200                    once: watch.once,
201                });
202                if watch.once {
203                    remove_once.push(watch.watch_id);
204                }
205            }
206        }
207        for watch_id in remove_once {
208            self.unregister(task_id, &watch_id);
209        }
210        let keep = scan_bytes.len().min(SCAN_OVERLAP_BYTES);
211        self.scan_overlaps.insert(
212            cursor_key.to_string(),
213            scan_bytes[scan_bytes.len().saturating_sub(keep)..].to_vec(),
214        );
215        matches
216    }
217}
218
219fn find_match(
220    pattern: &WatchPattern,
221    text: &str,
222    min_end_exclusive: usize,
223) -> Option<(usize, usize, String)> {
224    match pattern {
225        WatchPattern::Substring(needle) => {
226            if needle.is_empty() {
227                return None;
228            }
229            let mut search_start = min_end_exclusive.saturating_sub(needle.len().saturating_sub(1));
230            while search_start > 0 && !text.is_char_boundary(search_start) {
231                search_start -= 1;
232            }
233            text.get(search_start..).and_then(|tail| {
234                tail.find(needle).and_then(|relative_start| {
235                    let start = search_start + relative_start;
236                    let end = start + needle.len();
237                    (end > min_end_exclusive).then(|| (start, end, needle.clone()))
238                })
239            })
240        }
241        WatchPattern::Regex(regex) => regex
242            .find_iter(text)
243            .find(|m| m.end() > min_end_exclusive)
244            .map(|m| (m.start(), m.end(), m.as_str().to_string())),
245    }
246}
247
248fn context_snippet(text: &str, start: usize, end: usize) -> String {
249    let before_start = text[..start]
250        .char_indices()
251        .rev()
252        .nth(CONTEXT_BEFORE)
253        .map(|(idx, _)| idx)
254        .unwrap_or(0);
255    let after_end = text[end..]
256        .char_indices()
257        .nth(CONTEXT_AFTER)
258        .map(|(idx, _)| end + idx)
259        .unwrap_or(text.len());
260    text[before_start..after_end].replace('\r', "")
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Task id shared by every test in this module.
    const TASK: &str = "bash-1";

    /// Shorthand for a literal-text pattern.
    fn substring(needle: &str) -> WatchPattern {
        WatchPattern::Substring(needle.into())
    }

    /// Builds a registry holding exactly one watch on [`TASK`].
    fn registry_with(pattern: WatchPattern, once: bool) -> WatchRegistry {
        let mut registry = WatchRegistry::default();
        registry.register(TASK.to_string(), pattern, once).unwrap();
        registry
    }

    #[test]
    fn once_watch_self_removes_after_match() {
        let mut registry = registry_with(substring("READY"), true);

        let hits = registry.scan_new_bytes(TASK, b"READY\n");

        assert_eq!(hits.len(), 1);
        assert_eq!(registry.active_count(TASK), 0);
    }

    #[test]
    fn sticky_watch_fires_multiple_times() {
        let mut registry = registry_with(substring("READY"), false);

        for _ in 0..2 {
            assert_eq!(registry.scan_new_bytes(TASK, b"READY\n").len(), 1);
        }
        assert_eq!(registry.active_count(TASK), 1);
    }

    #[test]
    fn cap_8_watches_per_task_rejects_9th() {
        let mut registry = WatchRegistry::default();
        for _ in 0..8 {
            registry
                .register(TASK.into(), substring("x"), true)
                .unwrap();
        }

        let ninth = registry.register(TASK.into(), substring("x"), true);

        assert_eq!(ninth, Err("too_many_watches"));
    }

    #[test]
    fn regex_pattern_matches_with_capture() {
        let pattern = WatchPattern::regex("port (\\d+)").unwrap();
        let mut registry = registry_with(pattern, true);

        let hits = registry.scan_new_bytes(TASK, b"listening on port 3000\n");

        assert_eq!(hits[0].match_text, "port 3000");
    }

    #[test]
    fn substring_pattern_can_span_scans() {
        let mut registry = registry_with(substring("READY"), true);

        // First half alone must not fire.
        assert!(registry.scan_new_bytes(TASK, b"RE").is_empty());
        let hits = registry.scan_new_bytes(TASK, b"ADY\n");

        assert_eq!(hits.len(), 1);
        let hit = &hits[0];
        assert_eq!(hit.match_text, "READY");
        assert_eq!(hit.match_offset, 0);
    }

    #[test]
    fn regex_pattern_can_span_scans() {
        let pattern = WatchPattern::regex("ready: \\d{4}").unwrap();
        let mut registry = registry_with(pattern, true);

        // Only one of the four required digits has arrived so far.
        assert!(registry
            .scan_new_bytes(TASK, b"prefix ready: 4")
            .is_empty());
        let hits = registry.scan_new_bytes(TASK, b"242\n");

        assert_eq!(hits.len(), 1);
        let hit = &hits[0];
        assert_eq!(hit.match_text, "ready: 4242");
        assert_eq!(hit.match_offset, 7);
    }

    #[test]
    fn overlap_does_not_repeat_fully_previous_match() {
        let mut registry = registry_with(substring("READY"), false);

        assert_eq!(registry.scan_new_bytes(TASK, b"READY").len(), 1);
        // The retained overlap still contains "READY" but it was already reported.
        assert!(registry.scan_new_bytes(TASK, b"\n").is_empty());
    }
}