// ralph_workflow/cloud/heartbeat.rs
use super::CloudReporter;
use std::sync::mpsc;
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;
11
/// Owner of a background heartbeat worker thread.
///
/// A guard is created by [`HeartbeatGuard::start`]. When it is dropped, the
/// worker is asked to stop; the drop waits only briefly for the worker to
/// acknowledge and otherwise detaches it, so dropping never blocks for long.
///
/// Every field is an `Option` so that `Drop::drop`, which only has
/// `&mut self`, can move the value out with `Option::take`.
#[must_use = "the heartbeat worker is stopped as soon as the guard is dropped"]
pub struct HeartbeatGuard {
    /// Sender used to tell the worker to stop (a `()` message, or simply
    /// dropping the sender, ends the worker loop).
    stop_tx: Option<mpsc::Sender<()>>,
    /// Receiver on which the worker announces that its loop has finished.
    done_rx: Option<mpsc::Receiver<()>>,
    /// Join handle of the worker thread.
    handle: Option<JoinHandle<()>>,
}
20
21impl HeartbeatGuard {
22 pub fn start(reporter: Arc<dyn CloudReporter>, interval: Duration) -> Self {
27 let (stop_tx, stop_rx) = mpsc::channel::<()>();
28 let (done_tx, done_rx) = mpsc::channel::<()>();
29
30 let handle = thread::spawn(move || {
31 loop {
32 match stop_rx.recv_timeout(interval) {
33 Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => break,
34 Err(mpsc::RecvTimeoutError::Timeout) => {
35 let _ = reporter.heartbeat();
37 }
38 }
39 }
40
41 let _ = done_tx.send(());
43 });
44
45 Self {
46 stop_tx: Some(stop_tx),
47 done_rx: Some(done_rx),
48 handle: Some(handle),
49 }
50 }
51}
52
53impl Drop for HeartbeatGuard {
54 fn drop(&mut self) {
55 const DROP_JOIN_TIMEOUT: Duration = Duration::from_millis(50);
56
57 if let Some(tx) = self.stop_tx.take() {
58 let _ = tx.send(());
59 }
60
61 if let Some(done_rx) = self.done_rx.take() {
64 if done_rx.recv_timeout(DROP_JOIN_TIMEOUT).is_ok() {
65 if let Some(handle) = self.handle.take() {
66 let _ = handle.join();
67 }
68 return;
69 }
70 }
71
72 let _ = self.handle.take();
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cloud::mock::MockCloudReporter;
    use std::time::Instant;

    // The worker must keep calling `heartbeat()` while the guard is alive.
    #[test]
    fn test_heartbeat_sends_periodic_signals() {
        let reporter = Arc::new(MockCloudReporter::new());
        let reporter_clone = Arc::clone(&reporter);

        let _guard = HeartbeatGuard::start(reporter_clone, Duration::from_millis(25));

        // Poll with a generous deadline instead of sleeping a fixed amount, so
        // the test is fast when things work and tolerant of slow machines.
        let deadline = Instant::now() + Duration::from_millis(750);
        while reporter.heartbeat_count() < 3 && Instant::now() < deadline {
            thread::sleep(Duration::from_millis(10));
        }

        let count = reporter.heartbeat_count();
        assert!(count >= 3, "Expected at least 3 heartbeats, got {count}");
    }

    // Once the guard is dropped, the heartbeat count must stop changing.
    #[test]
    fn test_heartbeat_stops_on_drop() {
        let reporter = Arc::new(MockCloudReporter::new());
        let reporter_clone = Arc::clone(&reporter);

        {
            let _guard = HeartbeatGuard::start(reporter_clone, Duration::from_millis(25));
            thread::sleep(Duration::from_millis(100));
        }

        let count_at_drop = reporter.heartbeat_count();
        thread::sleep(Duration::from_millis(100));
        let count_after_drop = reporter.heartbeat_count();

        assert_eq!(
            count_at_drop, count_after_drop,
            "Heartbeats should stop after guard is dropped"
        );
    }

    // A stop request must interrupt the worker's wait rather than letting it
    // sleep out the remainder of a long interval.
    #[test]
    fn test_drop_does_not_block_for_full_interval() {
        let reporter = Arc::new(MockCloudReporter::new());
        let reporter_clone = Arc::clone(&reporter);

        let start = Instant::now();
        {
            let _guard = HeartbeatGuard::start(reporter_clone, Duration::from_secs(5));
            thread::sleep(Duration::from_millis(50));
        }
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_millis(500),
            "drop should return promptly; elapsed={elapsed:?}"
        );
    }

    // Dropping the guard while the worker is stuck inside `heartbeat()` must
    // give up waiting and detach the worker instead of blocking on the call.
    #[test]
    fn test_drop_does_not_block_when_heartbeat_call_is_stalled() {
        use crate::cloud::types::{CloudError, PipelineResult, ProgressUpdate};
        use std::sync::mpsc;

        // Reporter whose `heartbeat()` announces that it was entered and then
        // stalls for longer than the guard is willing to wait on drop.
        struct BlockingReporter {
            entered_tx: mpsc::Sender<()>,
        }

        impl CloudReporter for BlockingReporter {
            fn report_progress(&self, _update: &ProgressUpdate) -> Result<(), CloudError> {
                Ok(())
            }

            fn heartbeat(&self) -> Result<(), CloudError> {
                let _ = self.entered_tx.send(());
                thread::sleep(Duration::from_millis(300));
                Ok(())
            }

            fn report_completion(&self, _result: &PipelineResult) -> Result<(), CloudError> {
                Ok(())
            }
        }

        let (tx, rx) = mpsc::channel::<()>();
        let reporter = Arc::new(BlockingReporter { entered_tx: tx });
        let reporter_clone = Arc::clone(&reporter);

        let guard = HeartbeatGuard::start(reporter_clone, Duration::from_millis(1));

        // Make sure the worker really is inside the stalled call before
        // dropping; otherwise the scenario under test was never exercised.
        assert!(
            rx.recv_timeout(Duration::from_millis(250)).is_ok(),
            "heartbeat was never invoked, so the stalled-call scenario was not exercised"
        );

        // Time only the drop itself so thread start-up cost cannot skew the result.
        let start = Instant::now();
        drop(guard);
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_millis(150),
            "drop should not block on stalled heartbeat; elapsed={elapsed:?}"
        );
    }
}