Skip to main content

memlink_shm/
recovery.rs

//! Crash recovery, stale-slot cleanup, and daemon liveness monitoring.
//! Uses the slot state machine, timestamps, heartbeats, and PID files.
3
4use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
5use std::sync::Arc;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7use std::thread;
8use std::fs;
9
/// One state of the slot state machine, encoded as a single byte.
///
/// The discriminants are spelled out explicitly so the `u8` encoding used by
/// `from_u8` / `as_u8` stays fixed.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SlotState {
    Empty = 0,
    Writing = 1,
    Ready = 2,
    Reading = 3,
    Done = 4,
}

impl SlotState {
    /// Decodes a raw byte into a state.
    ///
    /// Returns `None` for any value outside `0..=4`.
    fn from_u8(val: u8) -> Option<Self> {
        // Indexed by discriminant: position `i` holds the variant whose value is `i`.
        const BY_DISCRIMINANT: [SlotState; 5] = [
            SlotState::Empty,
            SlotState::Writing,
            SlotState::Ready,
            SlotState::Reading,
            SlotState::Done,
        ];
        BY_DISCRIMINANT.get(usize::from(val)).copied()
    }

    /// Encodes the state as its `u8` discriminant.
    fn as_u8(self) -> u8 {
        self as u8
    }
}
36
37pub struct AtomicSlotState {
38    inner: AtomicU8,
39}
40
41impl AtomicSlotState {
42    pub fn new(state: SlotState) -> Self {
43        Self {
44            inner: AtomicU8::new(state.as_u8()),
45        }
46    }
47
48    pub fn load(&self) -> SlotState {
49        SlotState::from_u8(self.inner.load(Ordering::Acquire))
50            .unwrap_or(SlotState::Empty)
51    }
52
53    pub fn store(&self, state: SlotState) {
54        self.inner.store(state.as_u8(), Ordering::Release);
55    }
56
57    pub fn compare_exchange(
58        &self,
59        current: SlotState,
60        new: SlotState,
61    ) -> Result<SlotState, SlotState> {
62        match self.inner.compare_exchange(
63            current.as_u8(),
64            new.as_u8(),
65            Ordering::AcqRel,
66            Ordering::Acquire,
67        ) {
68            Ok(val) => Ok(SlotState::from_u8(val).unwrap_or(SlotState::Empty)),
69            Err(val) => Err(SlotState::from_u8(val).unwrap_or(SlotState::Empty)),
70        }
71    }
72}
73
74#[repr(C)]
75pub struct SlotMetadata {
76    pub state: AtomicSlotState,
77    pub timestamp: AtomicU64,
78    pub sequence: AtomicU64,
79}
80
81impl SlotMetadata {
82    pub fn new() -> Self {
83        Self {
84            state: AtomicSlotState::new(SlotState::Empty),
85            timestamp: AtomicU64::new(current_timestamp()),
86            sequence: AtomicU64::new(0),
87        }
88    }
89
90    pub fn update_timestamp(&self) {
91        self.timestamp.store(current_timestamp(), Ordering::Release);
92    }
93
94    pub fn age_seconds(&self) -> u64 {
95        let current = current_timestamp();
96        let stored = self.timestamp.load(Ordering::Acquire);
97        current.saturating_sub(stored)
98    }
99
100    pub fn is_stale(&self, timeout_seconds: u64) -> bool {
101        let state = self.state.load();
102        match state {
103            SlotState::Writing | SlotState::Reading => {
104                self.age_seconds() > timeout_seconds
105            }
106            _ => false,
107        }
108    }
109
110    pub fn recover_stale(&self, timeout_seconds: u64) -> bool {
111        if self.is_stale(timeout_seconds) {
112            self.state.store(SlotState::Empty);
113            self.update_timestamp();
114            true
115        } else {
116            false
117        }
118    }
119}
120
121impl Default for SlotMetadata {
122    fn default() -> Self {
123        Self::new()
124    }
125}
126
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, returns 0.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock is set before 1970: treat as zero elapsed time.
        Err(_) => 0,
    }
}
133
134pub struct Heartbeat {
135    timestamp: AtomicU64,
136    interval: u64,
137    active: AtomicU8,
138}
139
140impl Heartbeat {
141    pub fn new(interval_seconds: u64) -> Self {
142        Self {
143            timestamp: AtomicU64::new(current_timestamp()),
144            interval: interval_seconds,
145            active: AtomicU8::new(1),
146        }
147    }
148
149    pub fn beat(&self) {
150        self.timestamp.store(current_timestamp(), Ordering::Release);
151    }
152
153    pub fn is_alive(&self, timeout_seconds: u64) -> bool {
154        if self.active.load(Ordering::Acquire) == 0 {
155            return false;
156        }
157        let current = current_timestamp();
158        let last = self.timestamp.load(Ordering::Acquire);
159        current.saturating_sub(last) <= timeout_seconds
160    }
161
162    pub fn stop(&self) {
163        self.active.store(0, Ordering::Release);
164    }
165
166    pub fn start_monitoring(
167        self: &Arc<Self>,
168        callback: impl Fn() + Send + Sync + 'static,
169    ) -> thread::JoinHandle<()> {
170        let heartbeat = Arc::clone(self);
171        thread::spawn(move || {
172            while heartbeat.active.load(Ordering::Acquire) == 1 {
173                thread::sleep(Duration::from_secs(heartbeat.interval));
174                if !heartbeat.is_alive(heartbeat.interval * 3) {
175                    callback();
176                    break;
177                }
178            }
179        })
180    }
181}
182
183pub struct RecoveryManager {
184    pid_path: String,
185    heartbeat: Arc<Heartbeat>,
186    active: AtomicU8,
187}
188
189impl RecoveryManager {
190    pub fn new(_shm_path: &str) -> Self {
191        let pid_path = format!("{}.pid", _shm_path);
192        Self {
193            pid_path,
194            heartbeat: Arc::new(Heartbeat::new(1)),
195            active: AtomicU8::new(0),
196        }
197    }
198
199    pub fn is_already_running(&self) -> bool {
200        if let Ok(pid_str) = fs::read_to_string(&self.pid_path) {
201            if let Ok(pid) = pid_str.trim().parse::<u32>() {
202                if process_exists(pid) {
203                    return true;
204                }
205            }
206        }
207        false
208    }
209
210    pub fn register_daemon(&self) -> Result<(), String> {
211        if self.is_already_running() {
212            return Err("Another daemon is already running".to_string());
213        }
214
215        let pid = std::process::id();
216        fs::write(&self.pid_path, pid.to_string())
217            .map_err(|e| format!("Failed to write PID file: {}", e))?;
218
219        self.active.store(1, Ordering::Release);
220        self.heartbeat.beat();
221        Ok(())
222    }
223
224    pub fn unregister_daemon(&self) {
225        self.active.store(0, Ordering::Release);
226        self.heartbeat.stop();
227        let _ = fs::remove_file(&self.pid_path);
228    }
229
230    pub fn heartbeat(&self) -> &Arc<Heartbeat> {
231        &self.heartbeat
232    }
233
234    pub fn recover_stale_slots(
235        &self,
236        slots: &[SlotMetadata],
237        timeout_seconds: u64,
238    ) -> usize {
239        let mut recovered = 0;
240        for slot in slots {
241            if slot.recover_stale(timeout_seconds) {
242                recovered += 1;
243            }
244        }
245        recovered
246    }
247
248    pub fn cleanup_orphaned_shm(path: &str) -> Result<bool, String> {
249        let pid_path = format!("{}.pid", path);
250
251        if let Ok(pid_str) = fs::read_to_string(&pid_path) {
252            if let Ok(pid) = pid_str.trim().parse::<u32>() {
253                if !process_exists(pid) {
254                    let _ = fs::remove_file(&pid_path);
255                    let _ = fs::remove_file(path);
256                    return Ok(true);
257                }
258            }
259        }
260        Ok(false)
261    }
262}
263
264impl Drop for RecoveryManager {
265    fn drop(&mut self) {
266        self.unregister_daemon();
267    }
268}
269
270fn process_exists(pid: u32) -> bool {
271    #[cfg(unix)]
272    {
273        unsafe {
274            libc::kill(pid as libc::pid_t, 0) == 0
275        }
276    }
277    #[cfg(windows)]
278    {
279        use windows::Win32::Foundation::CloseHandle;
280        use windows::Win32::System::Threading::{
281            OpenProcess, PROCESS_QUERY_INFORMATION,
282        };
283
284        match unsafe {
285            OpenProcess(PROCESS_QUERY_INFORMATION, false, pid)
286        } {
287            Ok(handle) => {
288                let _ = unsafe { CloseHandle(handle) };
289                true
290            }
291            Err(_) => false,
292        }
293    }
294}