Skip to main content

ralph_workflow/cloud/
heartbeat.rs

//! Heartbeat background task for cloud mode.
//!
//! Runs a background thread that periodically sends heartbeat signals to the
//! cloud API, indicating that the container is still alive during long-running
//! operations.
5
6use super::CloudReporter;
7use std::sync::mpsc;
8use std::sync::Arc;
9use std::thread::{self, JoinHandle};
10use std::time::Duration;
11
/// Guard owning the heartbeat background task.
///
/// Created by [`HeartbeatGuard::start`]. The heartbeat runs for as long as the
/// guard is alive and is stopped automatically when the guard is dropped.
/// Binding the result to `_` (or discarding it) drops the guard immediately,
/// which is why the type is `#[must_use]`.
#[derive(Debug)]
#[must_use = "the heartbeat stops as soon as the guard is dropped; bind it to a variable such as `_guard`, not `_`"]
pub struct HeartbeatGuard {
    /// Sends the stop signal to the worker thread; taken on drop.
    stop_tx: Option<mpsc::Sender<()>>,
    /// Receives the worker's completion signal so drop can join without blocking.
    done_rx: Option<mpsc::Receiver<()>>,
    /// Worker thread handle; joined if the worker exits promptly on drop, otherwise detached.
    handle: Option<JoinHandle<()>>,
}
20
21impl HeartbeatGuard {
22    /// Start a heartbeat background task.
23    ///
24    /// The task will send heartbeats at the specified interval until
25    /// the guard is dropped or the token is cancelled.
26    pub fn start(reporter: Arc<dyn CloudReporter>, interval: Duration) -> Self {
27        let (stop_tx, stop_rx) = mpsc::channel::<()>();
28        let (done_tx, done_rx) = mpsc::channel::<()>();
29
30        let handle = thread::spawn(move || {
31            loop {
32                match stop_rx.recv_timeout(interval) {
33                    Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => break,
34                    Err(mpsc::RecvTimeoutError::Timeout) => {
35                        // Ignore heartbeat errors - graceful degradation
36                        let _ = reporter.heartbeat();
37                    }
38                }
39            }
40
41            // Signal completion so Drop can join without blocking.
42            let _ = done_tx.send(());
43        });
44
45        Self {
46            stop_tx: Some(stop_tx),
47            done_rx: Some(done_rx),
48            handle: Some(handle),
49        }
50    }
51}
52
53impl Drop for HeartbeatGuard {
54    fn drop(&mut self) {
55        const DROP_JOIN_TIMEOUT: Duration = Duration::from_millis(50);
56
57        if let Some(tx) = self.stop_tx.take() {
58            let _ = tx.send(());
59        }
60
61        // Best-effort: join promptly when possible, but never block pipeline shutdown.
62        // If the heartbeat call is stalled (e.g., network timeout), detach the thread.
63        if let Some(done_rx) = self.done_rx.take() {
64            if done_rx.recv_timeout(DROP_JOIN_TIMEOUT).is_ok() {
65                if let Some(handle) = self.handle.take() {
66                    let _ = handle.join();
67                }
68                return;
69            }
70        }
71
72        // Detach worker if it didn't exit quickly.
73        let _ = self.handle.take();
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cloud::mock::MockCloudReporter;
    use std::time::Instant;

    /// The worker keeps sending heartbeats while the guard is alive.
    #[test]
    fn test_heartbeat_sends_periodic_signals() {
        let reporter = Arc::new(MockCloudReporter::new());
        let reporter_clone = Arc::clone(&reporter);

        let _guard = HeartbeatGuard::start(reporter_clone, Duration::from_millis(25));

        let deadline = Instant::now() + Duration::from_millis(750);
        while reporter.heartbeat_count() < 3 && Instant::now() < deadline {
            thread::sleep(Duration::from_millis(10));
        }

        let count = reporter.heartbeat_count();
        assert!(count >= 3, "Expected at least 3 heartbeats, got {count}");
    }

    /// No further heartbeats are sent once the guard has been dropped.
    #[test]
    fn test_heartbeat_stops_on_drop() {
        let reporter = Arc::new(MockCloudReporter::new());
        let reporter_clone = Arc::clone(&reporter);

        {
            let _guard = HeartbeatGuard::start(reporter_clone, Duration::from_millis(25));
            thread::sleep(Duration::from_millis(100));
        } // guard dropped here

        let count_at_drop = reporter.heartbeat_count();
        thread::sleep(Duration::from_millis(100));
        let count_after_drop = reporter.heartbeat_count();

        assert_eq!(
            count_at_drop, count_after_drop,
            "Heartbeats should stop after guard is dropped"
        );
    }

    /// Dropping the guard wakes the worker instead of waiting out the interval.
    #[test]
    fn test_drop_does_not_block_for_full_interval() {
        let reporter = Arc::new(MockCloudReporter::new());
        let reporter_clone = Arc::clone(&reporter);

        let guard = HeartbeatGuard::start(reporter_clone, Duration::from_secs(5));
        // Give the worker a chance to enter its sleep.
        thread::sleep(Duration::from_millis(50));

        // Time only the drop itself, not the setup above.
        let start = Instant::now();
        drop(guard);
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_millis(500),
            "drop should return promptly; elapsed={elapsed:?}"
        );
    }

    /// Dropping the guard must not wait for a heartbeat call that is stuck.
    #[test]
    fn test_drop_does_not_block_when_heartbeat_call_is_stalled() {
        use crate::cloud::types::{CloudError, PipelineResult, ProgressUpdate};
        use std::sync::mpsc;

        struct BlockingReporter {
            entered_tx: mpsc::Sender<()>,
        }

        impl CloudReporter for BlockingReporter {
            fn report_progress(&self, _update: &ProgressUpdate) -> Result<(), CloudError> {
                Ok(())
            }

            fn heartbeat(&self) -> Result<(), CloudError> {
                let _ = self.entered_tx.send(());
                // Simulate network stall (longer than drop timeout).
                thread::sleep(Duration::from_millis(300));
                Ok(())
            }

            fn report_completion(&self, _result: &PipelineResult) -> Result<(), CloudError> {
                Ok(())
            }
        }

        let (tx, rx) = mpsc::channel::<()>();
        let reporter = Arc::new(BlockingReporter { entered_tx: tx });
        let reporter_clone = Arc::clone(&reporter);

        let guard = HeartbeatGuard::start(reporter_clone, Duration::from_millis(1));
        // Wait, outside the timed region, until the worker is inside the stalled
        // call. Failing here (rather than ignoring a timeout) ensures the test
        // really exercises the stalled path.
        rx.recv_timeout(Duration::from_secs(5))
            .expect("worker should enter the stalled heartbeat call");

        let start = Instant::now();
        drop(guard);
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_millis(150),
            "drop should not block on stalled heartbeat; elapsed={elapsed:?}"
        );
    }
}