Skip to main content

aft/bash_background/
watches.rs

1use std::collections::{HashMap, HashSet};
2use std::fs::File;
3use std::io::{Read, Seek, SeekFrom};
4use std::path::Path;
5
6use regex::{Regex, RegexBuilder};
7use serde::Serialize;
8
/// Upper bound on the number of watches that may be registered for one task at a time.
const MAX_WATCHES_PER_TASK: usize = 8;
/// Size of the context window taken from before a match, counted in characters (not bytes).
const CONTEXT_BEFORE: usize = 100;
/// Size of the context window taken from after a match, counted in characters (not bytes).
const CONTEXT_AFTER: usize = 500;
12
13#[derive(Debug, Clone)]
14pub struct WatchSpec {
15    pub watch_id: String,
16    pub task_id: String,
17    pub pattern: WatchPattern,
18    pub once: bool,
19}
20
21#[derive(Debug, Clone)]
22pub enum WatchPattern {
23    Substring(String),
24    Regex(Regex),
25}
26
27impl WatchPattern {
28    pub fn regex(pattern: &str) -> Result<Self, regex::Error> {
29        RegexBuilder::new(pattern)
30            .multi_line(true)
31            .build()
32            .map(Self::Regex)
33    }
34}
35
36#[derive(Debug, Clone, Serialize)]
37pub struct PatternMatch {
38    pub watch_id: String,
39    pub task_id: String,
40    pub match_text: String,
41    pub match_offset: u64,
42    pub context: String,
43    pub once: bool,
44}
45
46#[derive(Debug, Default)]
47pub struct WatchRegistry {
48    watches: HashMap<String, Vec<WatchSpec>>,
49    scan_cursors: HashMap<String, u64>,
50    controlled_tasks: HashSet<String>,
51    matched_tasks: HashSet<String>,
52    next_watch: u64,
53}
54
55impl WatchRegistry {
56    pub fn register(
57        &mut self,
58        task_id: String,
59        pattern: WatchPattern,
60        once: bool,
61    ) -> Result<String, &'static str> {
62        let watches = self.watches.entry(task_id.clone()).or_default();
63        if watches.len() >= MAX_WATCHES_PER_TASK {
64            return Err("too_many_watches");
65        }
66        self.controlled_tasks.insert(task_id.clone());
67        self.next_watch = self.next_watch.wrapping_add(1);
68        let watch_id = format!("watch-{:08x}", self.next_watch);
69        watches.push(WatchSpec {
70            watch_id: watch_id.clone(),
71            task_id,
72            pattern,
73            once,
74        });
75        Ok(watch_id)
76    }
77
78    pub fn unregister(&mut self, task_id: &str, watch_id: &str) {
79        if let Some(watches) = self.watches.get_mut(task_id) {
80            watches.retain(|watch| watch.watch_id != watch_id);
81            if watches.is_empty() {
82                self.watches.remove(task_id);
83            }
84        }
85    }
86
87    pub fn clear_task(&mut self, task_id: &str) {
88        self.watches.remove(task_id);
89        self.controlled_tasks.remove(task_id);
90        self.matched_tasks.remove(task_id);
91        let prefix = format!("{task_id}:");
92        self.scan_cursors
93            .retain(|key, _| key != task_id && !key.starts_with(&prefix));
94    }
95
96    pub fn has_controlled_task(&self, task_id: &str) -> bool {
97        self.controlled_tasks.contains(task_id)
98    }
99
100    pub fn has_matched_task(&self, task_id: &str) -> bool {
101        self.matched_tasks.contains(task_id)
102    }
103
104    pub fn active_count(&self, task_id: &str) -> usize {
105        self.watches.get(task_id).map_or(0, Vec::len)
106    }
107
108    pub fn prime_file_cursor(&mut self, cursor_key: &str, path: &Path) {
109        if self.scan_cursors.contains_key(cursor_key) {
110            return;
111        }
112        let len = File::open(path)
113            .and_then(|file| file.metadata())
114            .map(|metadata| metadata.len())
115            .unwrap_or(0);
116        self.scan_cursors.insert(cursor_key.to_string(), len);
117    }
118
119    pub fn set_file_cursor(&mut self, cursor_key: &str, offset: u64) {
120        self.scan_cursors.insert(cursor_key.to_string(), offset);
121    }
122
123    pub fn scan_file_new_bytes(
124        &mut self,
125        cursor_key: &str,
126        task_id: &str,
127        path: &Path,
128    ) -> Vec<PatternMatch> {
129        if self.active_count(task_id) == 0 {
130            return Vec::new();
131        }
132        let Ok(mut file) = File::open(path) else {
133            return Vec::new();
134        };
135        let cursor = self
136            .scan_cursors
137            .get(cursor_key)
138            .copied()
139            .unwrap_or_else(|| {
140                // Start at current EOF so a newly registered watch does not match old spill content.
141                file.metadata().map(|m| m.len()).unwrap_or(0)
142            });
143        if file.seek(SeekFrom::Start(cursor)).is_err() {
144            return Vec::new();
145        }
146        let mut bytes = Vec::new();
147        if file.read_to_end(&mut bytes).is_err() || bytes.is_empty() {
148            self.scan_cursors.insert(cursor_key.to_string(), cursor);
149            return Vec::new();
150        }
151        let next = cursor.saturating_add(bytes.len() as u64);
152        self.scan_cursors.insert(cursor_key.to_string(), next);
153        self.scan_new_bytes_at(task_id, &bytes, cursor)
154    }
155
156    pub fn scan_new_bytes(&mut self, task_id: &str, bytes: &[u8]) -> Vec<PatternMatch> {
157        let base = self.scan_cursors.get(task_id).copied().unwrap_or(0);
158        self.scan_cursors
159            .insert(task_id.to_string(), base.saturating_add(bytes.len() as u64));
160        self.scan_new_bytes_at(task_id, bytes, base)
161    }
162
163    fn scan_new_bytes_at(
164        &mut self,
165        task_id: &str,
166        bytes: &[u8],
167        base_offset: u64,
168    ) -> Vec<PatternMatch> {
169        let Some(watches) = self.watches.get(task_id).cloned() else {
170            return Vec::new();
171        };
172        let text = String::from_utf8_lossy(bytes);
173        let mut matches = Vec::new();
174        let mut remove_once = Vec::new();
175        for watch in watches {
176            if let Some((start, end, matched)) = find_match(&watch.pattern, &text) {
177                self.matched_tasks.insert(task_id.to_string());
178                matches.push(PatternMatch {
179                    watch_id: watch.watch_id.clone(),
180                    task_id: watch.task_id.clone(),
181                    match_text: matched,
182                    match_offset: base_offset.saturating_add(start as u64),
183                    context: context_snippet(&text, start, end),
184                    once: watch.once,
185                });
186                if watch.once {
187                    remove_once.push(watch.watch_id);
188                }
189            }
190        }
191        for watch_id in remove_once {
192            self.unregister(task_id, &watch_id);
193        }
194        matches
195    }
196}
197
198fn find_match(pattern: &WatchPattern, text: &str) -> Option<(usize, usize, String)> {
199    match pattern {
200        WatchPattern::Substring(needle) => text.find(needle).map(|start| {
201            let end = start + needle.len();
202            (start, end, needle.clone())
203        }),
204        WatchPattern::Regex(regex) => regex
205            .find(text)
206            .map(|m| (m.start(), m.end(), m.as_str().to_string())),
207    }
208}
209
210fn context_snippet(text: &str, start: usize, end: usize) -> String {
211    let before_start = text[..start]
212        .char_indices()
213        .rev()
214        .nth(CONTEXT_BEFORE)
215        .map(|(idx, _)| idx)
216        .unwrap_or(0);
217    let after_end = text[end..]
218        .char_indices()
219        .nth(CONTEXT_AFTER)
220        .map(|(idx, _)| end + idx)
221        .unwrap_or(text.len());
222    text[before_start..after_end].replace('\r', "")
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Task id shared by every test in this module.
    const TASK: &str = "bash-1";

    /// Builds a registry that holds exactly one watch on `TASK`.
    fn registry_watching(pattern: WatchPattern, once: bool) -> WatchRegistry {
        let mut registry = WatchRegistry::default();
        registry.register(TASK.to_string(), pattern, once).unwrap();
        registry
    }

    /// Registers a one-shot substring watch for `"x"` on `TASK`.
    fn watch_x(registry: &mut WatchRegistry) -> Result<String, &'static str> {
        registry.register(TASK.into(), WatchPattern::Substring("x".into()), true)
    }

    #[test]
    fn once_watch_self_removes_after_match() {
        let mut registry = registry_watching(WatchPattern::Substring("READY".into()), true);

        let hits = registry.scan_new_bytes(TASK, b"READY\n");

        assert_eq!(hits.len(), 1);
        assert_eq!(registry.active_count(TASK), 0);
    }

    #[test]
    fn sticky_watch_fires_multiple_times() {
        let mut registry = registry_watching(WatchPattern::Substring("READY".into()), false);

        // Two separate chunks, each containing the needle, must each produce one hit.
        for _ in 0..2 {
            let hits = registry.scan_new_bytes(TASK, b"READY\n");
            assert_eq!(hits.len(), 1);
        }

        assert_eq!(registry.active_count(TASK), 1);
    }

    #[test]
    fn cap_8_watches_per_task_rejects_9th() {
        let mut registry = WatchRegistry::default();

        for _ in 0..8 {
            watch_x(&mut registry).unwrap();
        }

        assert_eq!(watch_x(&mut registry), Err("too_many_watches"));
    }

    #[test]
    fn regex_pattern_matches_with_capture() {
        let pattern = WatchPattern::regex("port (\\d+)").unwrap();
        let mut registry = registry_watching(pattern, true);

        let hits = registry.scan_new_bytes(TASK, b"listening on port 3000\n");

        assert_eq!(hits[0].match_text, "port 3000");
    }
}