Skip to main content

ralph_workflow/cloud/
runtime.rs

1use std::sync::mpsc;
2use std::sync::Arc;
3use std::thread::{self, JoinHandle};
4use std::time::Duration;
5
6use crate::cloud::types::{heartbeat_drop_join_timeout, heartbeat_should_join_thread};
7use crate::cloud::CloudReporter;
8
9pub mod io_redaction {
10    use std::sync::LazyLock;
11
12    pub static BEARER_TOKEN_RE: LazyLock<regex::Regex> =
13        LazyLock::new(|| regex::Regex::new(r"(?i)(bearer\s+)\S+").expect("valid regex"));
14
15    pub static COMMON_QUERY_RE: LazyLock<regex::Regex> = LazyLock::new(|| {
16        const KEYS: [&str; 5] = [
17            "access_token=",
18            "token=",
19            "password=",
20            "passwd=",
21            "oauth_token=",
22        ];
23        let pattern = format!("(?i)({})([^&\\s]*)", KEYS.join("|"));
24        regex::Regex::new(&pattern).expect("valid regex")
25    });
26
27    pub static TOKEN_LIKE_RE: LazyLock<regex::Regex> = LazyLock::new(|| {
28        const PREFIXES: [&str; 6] = ["ghp_", "github_pat_", "glpat-", "xoxb-", "xapp-", "ya29."];
29        let pattern = format!(
30            "({})[A-Za-z0-9_\\-\\.]+",
31            PREFIXES
32                .iter()
33                .map(|&s| regex::escape(s))
34                .collect::<Vec<_>>()
35                .join("|")
36        );
37        regex::Regex::new(&pattern).expect("valid regex")
38    });
39
40    pub fn redact_bearer_tokens(input: &str) -> String {
41        BEARER_TOKEN_RE
42            .replace_all(input, "$1<redacted>")
43            .to_string()
44    }
45
46    pub fn redact_common_query_params(input: &str) -> String {
47        COMMON_QUERY_RE
48            .replace_all(input, |caps: &regex::Captures| {
49                let key = caps.get(1).map_or("", |m| m.as_str());
50                format!("{}<redacted>", key)
51            })
52            .to_string()
53    }
54
55    pub fn redact_token_like_substrings(input: &str) -> String {
56        TOKEN_LIKE_RE.replace_all(input, "<redacted>").to_string()
57    }
58}
59
/// Owner of a background heartbeat thread.
///
/// The guard keeps the channel endpoints and the join handle needed to shut
/// the thread down. Every field is an `Option` so that `Drop` can move the
/// value out of `&mut self` with `take()`.
#[must_use = "dropping the guard stops the heartbeat thread"]
pub struct HeartbeatGuard {
    /// Sender used to ask the heartbeat thread to stop.
    stop_tx: Option<mpsc::Sender<()>>,
    /// Receiver on which the heartbeat thread announces that it has finished.
    done_rx: Option<mpsc::Receiver<()>>,
    /// Join handle of the spawned heartbeat thread.
    handle: Option<JoinHandle<()>>,
}
65
66impl HeartbeatGuard {
67    pub fn start(reporter: Arc<dyn CloudReporter>, interval: Duration) -> Self {
68        let (stop_tx, stop_rx) = mpsc::channel::<()>();
69        let (done_tx, done_rx) = mpsc::channel::<()>();
70
71        let handle = thread::spawn(move || {
72            std::iter::successors(Some(interval), |_| Some(interval))
73                .filter_map(|timeout| match stop_rx.recv_timeout(timeout) {
74                    Err(mpsc::RecvTimeoutError::Timeout) => Some(()),
75                    _ => None,
76                })
77                .for_each(|_| {
78                    let _ = reporter.heartbeat();
79                });
80
81            let _ = done_tx.send(());
82        });
83
84        Self {
85            stop_tx: Some(stop_tx),
86            done_rx: Some(done_rx),
87            handle: Some(handle),
88        }
89    }
90}
91
92impl Drop for HeartbeatGuard {
93    fn drop(&mut self) {
94        let timeout = heartbeat_drop_join_timeout();
95
96        if let Some(tx) = self.stop_tx.take() {
97            let _ = tx.send(());
98        }
99
100        if let (Some(rx), Some(h)) = (self.done_rx.take(), self.handle.take()) {
101            let done_received = rx.recv_timeout(timeout).is_ok();
102            if heartbeat_should_join_thread(done_received) {
103                let _ = h.join();
104            }
105        }
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cloud::mock::MockCloudReporter;
    use std::time::Instant;

    /// With a 25 ms interval, at least three heartbeats must be observed
    /// before a 750 ms polling budget runs out.
    #[test]
    fn test_heartbeat_sends_periodic_signals() {
        let mock = Arc::new(MockCloudReporter::new());
        // Kept as its own binding so the coercion to the trait object happens
        // at the call site.
        let mock_for_guard = Arc::clone(&mock);

        let _guard = HeartbeatGuard::start(mock_for_guard, Duration::from_millis(25));

        let budget = Duration::from_millis(750);
        let polling_since = Instant::now();
        loop {
            if mock.heartbeat_count() >= 3 || polling_since.elapsed() >= budget {
                break;
            }
            thread::sleep(Duration::from_millis(10));
        }

        let count = mock.heartbeat_count();
        assert!(count >= 3, "Expected at least 3 heartbeats, got {count}");
    }

    /// Once the guard has been dropped, the heartbeat counter must not move.
    #[test]
    fn test_heartbeat_stops_on_drop() {
        let mock = Arc::new(MockCloudReporter::new());
        let mock_for_guard = Arc::clone(&mock);

        let guard = HeartbeatGuard::start(mock_for_guard, Duration::from_millis(25));
        thread::sleep(Duration::from_millis(100));
        drop(guard);

        let count_at_drop = mock.heartbeat_count();
        thread::sleep(Duration::from_millis(100));
        let count_after_drop = mock.heartbeat_count();

        assert_eq!(
            count_at_drop, count_after_drop,
            "Heartbeats should stop after guard is dropped"
        );
    }

    /// Dropping the guard must not wait out a long (5 s) heartbeat interval.
    #[test]
    fn test_drop_does_not_block_for_full_interval() {
        let mock = Arc::new(MockCloudReporter::new());
        let mock_for_guard = Arc::clone(&mock);

        let start = Instant::now();
        let guard = HeartbeatGuard::start(mock_for_guard, Duration::from_secs(5));
        thread::sleep(Duration::from_millis(50));
        drop(guard);
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_millis(500),
            "drop should return promptly; elapsed={elapsed:?}"
        );
    }

    /// Dropping the guard must return quickly even while the worker is stuck
    /// inside a slow `heartbeat()` call.
    #[test]
    fn test_drop_does_not_block_when_heartbeat_call_is_stalled() {
        use crate::cloud::types::{CloudError, PipelineResult, ProgressUpdate};
        use std::sync::mpsc;

        /// Reporter whose `heartbeat` announces that it was entered and then
        /// sleeps for 300 ms.
        struct BlockingReporter {
            entered_tx: mpsc::Sender<()>,
        }

        impl CloudReporter for BlockingReporter {
            fn report_progress(&self, _update: &ProgressUpdate) -> Result<(), CloudError> {
                Ok(())
            }

            fn heartbeat(&self) -> Result<(), CloudError> {
                let _ = self.entered_tx.send(());
                thread::sleep(Duration::from_millis(300));
                Ok(())
            }

            fn report_completion(&self, _result: &PipelineResult) -> Result<(), CloudError> {
                Ok(())
            }
        }

        let (entered_tx, entered_rx) = mpsc::channel::<()>();
        let reporter = Arc::new(BlockingReporter { entered_tx });
        let reporter_for_guard = Arc::clone(&reporter);

        let start = Instant::now();
        let guard = HeartbeatGuard::start(reporter_for_guard, Duration::from_millis(1));
        // Wait (bounded) until the worker is inside `heartbeat`, then drop.
        let _ = entered_rx.recv_timeout(Duration::from_millis(250));
        drop(guard);
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_millis(150),
            "drop should not block on stalled heartbeat; elapsed={elapsed:?}"
        );
    }
}
210}