Skip to main content

peat_protocol/command/
timeout_manager.rs

//! Timeout management for hierarchical commands
//!
//! Handles command expiration (TTL) and acknowledgment timeout tracking.
4
5use crate::error::Result;
6use peat_schema::command::v1::HierarchicalCommand;
7use std::collections::{BTreeMap, HashMap};
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10use tokio::sync::RwLock;
11
/// Acknowledgment bookkeeping for a single command.
///
/// Records which nodes are expected to acknowledge a command, which nodes
/// have acknowledged so far, and the point in time after which the
/// acknowledgment window is considered over.
///
/// The type is comparable (`PartialEq`/`Eq`) so that two snapshots of the
/// same command's status can be checked for changes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AckTimeout {
    /// ID of the command whose acknowledgments are being tracked
    pub command_id: String,
    /// IDs of the nodes that are expected to acknowledge the command
    pub expected_acks: Vec<String>,
    /// IDs of the nodes that have acknowledged the command so far
    pub received_acks: Vec<String>,
    /// Instant at which the acknowledgment window closes
    pub expires_at: SystemTime,
}
24
25/// Timeout manager for commands and acknowledgments
26///
27/// Tracks command expiration (TTL) and acknowledgment timeouts.
28/// Provides efficient lookup of expired commands via BTreeMap.
29pub struct TimeoutManager {
30    /// Commands indexed by expiration time
31    /// Key: expiration time, Value: list of command IDs expiring at that time
32    expiring_commands: Arc<RwLock<BTreeMap<SystemTime, Vec<String>>>>,
33
34    /// Acknowledgment timeout tracking
35    /// Key: command_id, Value: acknowledgment timeout info
36    ack_timeouts: Arc<RwLock<HashMap<String, AckTimeout>>>,
37}
38
39impl TimeoutManager {
40    /// Create a new timeout manager
41    pub fn new() -> Self {
42        Self {
43            expiring_commands: Arc::new(RwLock::new(BTreeMap::new())),
44            ack_timeouts: Arc::new(RwLock::new(HashMap::new())),
45        }
46    }
47
48    /// Register a command for expiration tracking
49    ///
50    /// If the command has an `expires_at` field, it will be tracked
51    /// for automatic expiration.
52    pub async fn register_expiration(&self, command: &HierarchicalCommand) -> Result<()> {
53        if let Some(expires_at) = command.expires_at.as_ref() {
54            let expiry = SystemTime::UNIX_EPOCH + Duration::from_secs(expires_at.seconds);
55            self.expiring_commands
56                .write()
57                .await
58                .entry(expiry)
59                .or_default()
60                .push(command.command_id.clone());
61        }
62        Ok(())
63    }
64
65    /// Check and process expired commands
66    ///
67    /// Returns a list of command IDs that have expired.
68    /// This should be called periodically by a background task.
69    pub async fn process_expired(&self) -> Vec<String> {
70        let now = SystemTime::now();
71        let mut expired = Vec::new();
72
73        let mut expiring = self.expiring_commands.write().await;
74
75        // Collect all expired keys (expiration times <= now)
76        let expired_keys: Vec<SystemTime> = expiring.range(..=now).map(|(k, _)| *k).collect();
77
78        // Remove and collect all expired commands
79        for key in expired_keys {
80            if let Some(commands) = expiring.remove(&key) {
81                expired.extend(commands);
82            }
83        }
84
85        expired
86    }
87
88    /// Unregister a command from expiration tracking
89    ///
90    /// Called when a command completes before expiring.
91    pub async fn unregister_expiration(&self, command_id: &str) -> Result<()> {
92        let mut expiring = self.expiring_commands.write().await;
93
94        // Remove command from all expiration time buckets
95        for (_, cmd_list) in expiring.iter_mut() {
96            cmd_list.retain(|id| id != command_id);
97        }
98
99        // Clean up empty time buckets
100        expiring.retain(|_, cmd_list| !cmd_list.is_empty());
101
102        Ok(())
103    }
104
105    /// Register an acknowledgment timeout
106    ///
107    /// Tracks expected acknowledgments for a command with a timeout.
108    pub async fn register_ack_timeout(
109        &self,
110        command_id: String,
111        expected_acks: Vec<String>,
112        timeout: Duration,
113    ) -> Result<()> {
114        let ack_timeout = AckTimeout {
115            command_id: command_id.clone(),
116            expected_acks,
117            received_acks: Vec::new(),
118            expires_at: SystemTime::now() + timeout,
119        };
120
121        self.ack_timeouts
122            .write()
123            .await
124            .insert(command_id, ack_timeout);
125
126        Ok(())
127    }
128
129    /// Record a received acknowledgment
130    ///
131    /// Updates the tracking for a command's acknowledgments.
132    /// Returns true if all expected acks have been received.
133    pub async fn record_ack(&self, command_id: &str, node_id: &str) -> bool {
134        let mut timeouts = self.ack_timeouts.write().await;
135
136        if let Some(timeout) = timeouts.get_mut(command_id) {
137            if !timeout.received_acks.contains(&node_id.to_string()) {
138                timeout.received_acks.push(node_id.to_string());
139            }
140
141            // Check if all expected acks received
142            timeout.received_acks.len() >= timeout.expected_acks.len()
143        } else {
144            false
145        }
146    }
147
148    /// Check for acknowledgment timeouts
149    ///
150    /// Returns list of command IDs that have timed out waiting for acks.
151    /// A command has timed out if:
152    /// 1. The timeout period has elapsed
153    /// 2. Not all expected acknowledgments have been received
154    pub async fn check_ack_timeouts(&self) -> Vec<String> {
155        let now = SystemTime::now();
156        let timeouts = self.ack_timeouts.read().await;
157
158        timeouts
159            .iter()
160            .filter(|(_, t)| t.expires_at <= now && t.received_acks.len() < t.expected_acks.len())
161            .map(|(id, _)| id.clone())
162            .collect()
163    }
164
165    /// Get acknowledgment status for a command
166    ///
167    /// Returns the acknowledgment tracking info if it exists.
168    pub async fn get_ack_status(&self, command_id: &str) -> Option<AckTimeout> {
169        self.ack_timeouts.read().await.get(command_id).cloned()
170    }
171
172    /// Remove acknowledgment timeout tracking
173    ///
174    /// Called when a command completes or is cancelled.
175    pub async fn unregister_ack_timeout(&self, command_id: &str) -> Result<()> {
176        self.ack_timeouts.write().await.remove(command_id);
177        Ok(())
178    }
179
180    /// Get count of commands being tracked for expiration
181    pub async fn expiration_count(&self) -> usize {
182        self.expiring_commands
183            .read()
184            .await
185            .values()
186            .map(|v| v.len())
187            .sum()
188    }
189
190    /// Get count of commands being tracked for ack timeout
191    pub async fn ack_timeout_count(&self) -> usize {
192        self.ack_timeouts.read().await.len()
193    }
194}
195
196impl Default for TimeoutManager {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use peat_schema::command::v1::{command_target::Scope, CommandTarget};
    use peat_schema::common::v1::Timestamp;
    use tokio::time::sleep;

    /// Current wall-clock time as whole seconds since the Unix epoch.
    fn unix_now_secs() -> u64 {
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Builds a minimal command whose `expires_at` is the given epoch second.
    fn create_test_command_with_ttl(
        command_id: &str,
        expires_at_seconds: u64,
    ) -> HierarchicalCommand {
        let target = CommandTarget {
            scope: Scope::Individual as i32,
            target_ids: vec!["target-1".to_string()],
        };
        let expires_at = Timestamp {
            seconds: expires_at_seconds,
            nanos: 0,
        };

        HierarchicalCommand {
            command_id: command_id.to_string(),
            originator_id: "test-node".to_string(),
            target: Some(target),
            expires_at: Some(expires_at),
            ..Default::default()
        }
    }

    /// Registers ack tracking for "cmd-1" with the given expected nodes.
    async fn track_cmd_1(manager: &TimeoutManager, nodes: &[&str], timeout: Duration) {
        let expected: Vec<String> = nodes.iter().map(|node| node.to_string()).collect();
        manager
            .register_ack_timeout("cmd-1".to_string(), expected, timeout)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_register_and_process_expired() {
        let manager = TimeoutManager::new();

        // Expiration lies ten seconds in the past.
        let stale = create_test_command_with_ttl("cmd-1", unix_now_secs() - 10);
        manager.register_expiration(&stale).await.unwrap();

        let expired = manager.process_expired().await;
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0], "cmd-1");

        // Processing must have drained the index.
        assert_eq!(manager.expiration_count().await, 0);
    }

    #[tokio::test]
    async fn test_command_not_expired_yet() {
        let manager = TimeoutManager::new();

        // Expiration lies one hour in the future.
        let pending = create_test_command_with_ttl("cmd-1", unix_now_secs() + 3600);
        manager.register_expiration(&pending).await.unwrap();

        let expired = manager.process_expired().await;
        assert_eq!(expired.len(), 0);
        assert_eq!(manager.expiration_count().await, 1);
    }

    #[tokio::test]
    async fn test_unregister_expiration() {
        let manager = TimeoutManager::new();

        let pending = create_test_command_with_ttl("cmd-1", unix_now_secs() + 3600);
        manager.register_expiration(&pending).await.unwrap();
        assert_eq!(manager.expiration_count().await, 1);

        manager.unregister_expiration("cmd-1").await.unwrap();
        assert_eq!(manager.expiration_count().await, 0);
    }

    #[tokio::test]
    async fn test_ack_timeout_registration() {
        let manager = TimeoutManager::new();
        track_cmd_1(&manager, &["node-1", "node-2"], Duration::from_secs(30)).await;

        let status = manager.get_ack_status("cmd-1").await.unwrap();
        assert_eq!(status.command_id, "cmd-1");
        assert_eq!(status.expected_acks.len(), 2);
        assert_eq!(status.received_acks.len(), 0);
    }

    #[tokio::test]
    async fn test_record_ack() {
        let manager = TimeoutManager::new();
        track_cmd_1(&manager, &["node-1", "node-2"], Duration::from_secs(30)).await;

        // One of two acks: not complete yet.
        let complete_after_first = manager.record_ack("cmd-1", "node-1").await;
        assert!(!complete_after_first);

        // Two of two acks: complete.
        let complete_after_second = manager.record_ack("cmd-1", "node-2").await;
        assert!(complete_after_second);

        let status = manager.get_ack_status("cmd-1").await.unwrap();
        assert_eq!(status.received_acks.len(), 2);
    }

    #[tokio::test]
    async fn test_ack_timeout_detection() {
        let manager = TimeoutManager::new();

        // Very short window so the test can outlive it.
        track_cmd_1(&manager, &["node-1", "node-2"], Duration::from_millis(100)).await;

        // Only one of the two expected nodes answers.
        manager.record_ack("cmd-1", "node-1").await;

        // Let the window close.
        sleep(Duration::from_millis(150)).await;

        let timed_out = manager.check_ack_timeouts().await;
        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], "cmd-1");
    }

    #[tokio::test]
    async fn test_ack_timeout_not_detected_if_all_received() {
        let manager = TimeoutManager::new();
        track_cmd_1(&manager, &["node-1", "node-2"], Duration::from_millis(100)).await;

        // Both expected nodes answer.
        manager.record_ack("cmd-1", "node-1").await;
        manager.record_ack("cmd-1", "node-2").await;

        // Let the window close.
        sleep(Duration::from_millis(150)).await;

        // Fully acknowledged commands are not reported.
        let timed_out = manager.check_ack_timeouts().await;
        assert_eq!(timed_out.len(), 0);
    }

    #[tokio::test]
    async fn test_unregister_ack_timeout() {
        let manager = TimeoutManager::new();
        track_cmd_1(&manager, &["node-1"], Duration::from_secs(30)).await;
        assert_eq!(manager.ack_timeout_count().await, 1);

        manager.unregister_ack_timeout("cmd-1").await.unwrap();
        assert_eq!(manager.ack_timeout_count().await, 0);
    }

    #[tokio::test]
    async fn test_multiple_commands_same_expiration() {
        let manager = TimeoutManager::new();

        // Both commands share one expiration second, hence one bucket.
        let shared_expiry = unix_now_secs() - 10;
        for id in ["cmd-1", "cmd-2"] {
            let cmd = create_test_command_with_ttl(id, shared_expiry);
            manager.register_expiration(&cmd).await.unwrap();
        }

        let expired = manager.process_expired().await;
        assert_eq!(expired.len(), 2);
        assert!(expired.contains(&"cmd-1".to_string()));
        assert!(expired.contains(&"cmd-2".to_string()));
    }

    #[tokio::test]
    async fn test_duplicate_ack_not_counted_twice() {
        let manager = TimeoutManager::new();
        track_cmd_1(&manager, &["node-1", "node-2"], Duration::from_secs(30)).await;

        // The same node acknowledges twice.
        manager.record_ack("cmd-1", "node-1").await;
        manager.record_ack("cmd-1", "node-1").await;

        let status = manager.get_ack_status("cmd-1").await.unwrap();
        assert_eq!(status.received_acks.len(), 1); // Should only count once
    }
}
445}