1use crate::error::Result;
6use peat_schema::command::v1::HierarchicalCommand;
7use std::collections::{BTreeMap, HashMap};
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10use tokio::sync::RwLock;
11
/// Acknowledgement bookkeeping for a single command: which nodes must ack it,
/// which have acked so far, and the deadline for the remaining acks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AckTimeout {
    /// Identifier of the command awaiting acknowledgement.
    pub command_id: String,
    /// Node IDs from which an acknowledgement is expected.
    pub expected_acks: Vec<String>,
    /// Node IDs from which an acknowledgement has been recorded, in arrival order
    /// (`TimeoutManager::record_ack` skips repeats from the same node).
    pub received_acks: Vec<String>,
    /// Wall-clock deadline; once it passes, any still-missing acks count as timed out.
    pub expires_at: SystemTime,
}
24
25pub struct TimeoutManager {
30 expiring_commands: Arc<RwLock<BTreeMap<SystemTime, Vec<String>>>>,
33
34 ack_timeouts: Arc<RwLock<HashMap<String, AckTimeout>>>,
37}
38
39impl TimeoutManager {
40 pub fn new() -> Self {
42 Self {
43 expiring_commands: Arc::new(RwLock::new(BTreeMap::new())),
44 ack_timeouts: Arc::new(RwLock::new(HashMap::new())),
45 }
46 }
47
48 pub async fn register_expiration(&self, command: &HierarchicalCommand) -> Result<()> {
53 if let Some(expires_at) = command.expires_at.as_ref() {
54 let expiry = SystemTime::UNIX_EPOCH + Duration::from_secs(expires_at.seconds);
55 self.expiring_commands
56 .write()
57 .await
58 .entry(expiry)
59 .or_default()
60 .push(command.command_id.clone());
61 }
62 Ok(())
63 }
64
65 pub async fn process_expired(&self) -> Vec<String> {
70 let now = SystemTime::now();
71 let mut expired = Vec::new();
72
73 let mut expiring = self.expiring_commands.write().await;
74
75 let expired_keys: Vec<SystemTime> = expiring.range(..=now).map(|(k, _)| *k).collect();
77
78 for key in expired_keys {
80 if let Some(commands) = expiring.remove(&key) {
81 expired.extend(commands);
82 }
83 }
84
85 expired
86 }
87
88 pub async fn unregister_expiration(&self, command_id: &str) -> Result<()> {
92 let mut expiring = self.expiring_commands.write().await;
93
94 for (_, cmd_list) in expiring.iter_mut() {
96 cmd_list.retain(|id| id != command_id);
97 }
98
99 expiring.retain(|_, cmd_list| !cmd_list.is_empty());
101
102 Ok(())
103 }
104
105 pub async fn register_ack_timeout(
109 &self,
110 command_id: String,
111 expected_acks: Vec<String>,
112 timeout: Duration,
113 ) -> Result<()> {
114 let ack_timeout = AckTimeout {
115 command_id: command_id.clone(),
116 expected_acks,
117 received_acks: Vec::new(),
118 expires_at: SystemTime::now() + timeout,
119 };
120
121 self.ack_timeouts
122 .write()
123 .await
124 .insert(command_id, ack_timeout);
125
126 Ok(())
127 }
128
129 pub async fn record_ack(&self, command_id: &str, node_id: &str) -> bool {
134 let mut timeouts = self.ack_timeouts.write().await;
135
136 if let Some(timeout) = timeouts.get_mut(command_id) {
137 if !timeout.received_acks.contains(&node_id.to_string()) {
138 timeout.received_acks.push(node_id.to_string());
139 }
140
141 timeout.received_acks.len() >= timeout.expected_acks.len()
143 } else {
144 false
145 }
146 }
147
148 pub async fn check_ack_timeouts(&self) -> Vec<String> {
155 let now = SystemTime::now();
156 let timeouts = self.ack_timeouts.read().await;
157
158 timeouts
159 .iter()
160 .filter(|(_, t)| t.expires_at <= now && t.received_acks.len() < t.expected_acks.len())
161 .map(|(id, _)| id.clone())
162 .collect()
163 }
164
165 pub async fn get_ack_status(&self, command_id: &str) -> Option<AckTimeout> {
169 self.ack_timeouts.read().await.get(command_id).cloned()
170 }
171
172 pub async fn unregister_ack_timeout(&self, command_id: &str) -> Result<()> {
176 self.ack_timeouts.write().await.remove(command_id);
177 Ok(())
178 }
179
180 pub async fn expiration_count(&self) -> usize {
182 self.expiring_commands
183 .read()
184 .await
185 .values()
186 .map(|v| v.len())
187 .sum()
188 }
189
190 pub async fn ack_timeout_count(&self) -> usize {
192 self.ack_timeouts.read().await.len()
193 }
194}
195
196impl Default for TimeoutManager {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use peat_schema::command::v1::{command_target::Scope, CommandTarget};
    use peat_schema::common::v1::Timestamp;
    use tokio::time::sleep;

    /// Returns the current wall-clock time in whole seconds since the Unix epoch.
    fn unix_now_secs() -> u64 {
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_secs()
    }

    /// Builds a minimal command that expires at the absolute Unix time
    /// `expires_at_seconds` (a deadline, not a relative TTL).
    fn create_test_command_expiring_at(
        command_id: &str,
        expires_at_seconds: u64,
    ) -> HierarchicalCommand {
        HierarchicalCommand {
            command_id: command_id.to_string(),
            originator_id: "test-node".to_string(),
            target: Some(CommandTarget {
                scope: Scope::Individual as i32,
                target_ids: vec!["target-1".to_string()],
            }),
            expires_at: Some(Timestamp {
                seconds: expires_at_seconds,
                nanos: 0,
            }),
            ..Default::default()
        }
    }

    /// Registers an ack timeout for `command_id` that expects acks from `nodes`.
    async fn register_acks(
        manager: &TimeoutManager,
        command_id: &str,
        nodes: &[&str],
        timeout: Duration,
    ) {
        manager
            .register_ack_timeout(
                command_id.to_string(),
                nodes.iter().map(|node| node.to_string()).collect(),
                timeout,
            )
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_register_and_process_expired() {
        let manager = TimeoutManager::new();
        let expired_cmd = create_test_command_expiring_at("cmd-1", unix_now_secs() - 10);
        manager.register_expiration(&expired_cmd).await.unwrap();

        let expired = manager.process_expired().await;

        assert_eq!(expired, vec!["cmd-1".to_string()]);
        assert_eq!(manager.expiration_count().await, 0);
    }

    #[tokio::test]
    async fn test_command_not_expired_yet() {
        let manager = TimeoutManager::new();
        let future_cmd = create_test_command_expiring_at("cmd-1", unix_now_secs() + 3600);
        manager.register_expiration(&future_cmd).await.unwrap();

        let expired = manager.process_expired().await;

        assert!(expired.is_empty());
        assert_eq!(manager.expiration_count().await, 1);
    }

    #[tokio::test]
    async fn test_command_without_expiry_is_not_registered() {
        let manager = TimeoutManager::new();
        let cmd = HierarchicalCommand {
            command_id: "cmd-1".to_string(),
            ..Default::default()
        };

        manager.register_expiration(&cmd).await.unwrap();

        assert_eq!(manager.expiration_count().await, 0);
    }

    #[tokio::test]
    async fn test_unregister_expiration() {
        let manager = TimeoutManager::new();
        let cmd = create_test_command_expiring_at("cmd-1", unix_now_secs() + 3600);

        manager.register_expiration(&cmd).await.unwrap();
        assert_eq!(manager.expiration_count().await, 1);

        manager.unregister_expiration("cmd-1").await.unwrap();
        assert_eq!(manager.expiration_count().await, 0);
    }

    #[tokio::test]
    async fn test_unregister_expiration_keeps_other_commands() {
        let manager = TimeoutManager::new();
        let deadline = unix_now_secs() + 3600;

        for id in ["cmd-1", "cmd-2"] {
            let cmd = create_test_command_expiring_at(id, deadline);
            manager.register_expiration(&cmd).await.unwrap();
        }

        manager.unregister_expiration("cmd-1").await.unwrap();

        assert_eq!(manager.expiration_count().await, 1);
    }

    #[tokio::test]
    async fn test_ack_timeout_registration() {
        let manager = TimeoutManager::new();
        register_acks(&manager, "cmd-1", &["node-1", "node-2"], Duration::from_secs(30)).await;

        let status = manager.get_ack_status("cmd-1").await.unwrap();
        assert_eq!(status.command_id, "cmd-1");
        assert_eq!(status.expected_acks.len(), 2);
        assert!(status.received_acks.is_empty());
    }

    #[tokio::test]
    async fn test_record_ack() {
        let manager = TimeoutManager::new();
        register_acks(&manager, "cmd-1", &["node-1", "node-2"], Duration::from_secs(30)).await;

        // One of two expected acks: not complete yet.
        assert!(!manager.record_ack("cmd-1", "node-1").await);
        // Second expected ack completes the set.
        assert!(manager.record_ack("cmd-1", "node-2").await);

        let status = manager.get_ack_status("cmd-1").await.unwrap();
        assert_eq!(status.received_acks.len(), 2);
    }

    #[tokio::test]
    async fn test_record_ack_for_unknown_command() {
        let manager = TimeoutManager::new();

        assert!(!manager.record_ack("missing", "node-1").await);
        assert!(manager.get_ack_status("missing").await.is_none());
    }

    #[tokio::test]
    async fn test_ack_timeout_detection() {
        let manager = TimeoutManager::new();
        // Short deadline so the test can wait past it.
        register_acks(&manager, "cmd-1", &["node-1", "node-2"], Duration::from_millis(100)).await;

        manager.record_ack("cmd-1", "node-1").await;
        sleep(Duration::from_millis(150)).await;

        // node-2 never acked, so the command is reported as timed out.
        let timed_out = manager.check_ack_timeouts().await;
        assert_eq!(timed_out, vec!["cmd-1".to_string()]);
    }

    #[tokio::test]
    async fn test_ack_timeout_not_detected_if_all_received() {
        let manager = TimeoutManager::new();
        register_acks(&manager, "cmd-1", &["node-1", "node-2"], Duration::from_millis(100)).await;

        manager.record_ack("cmd-1", "node-1").await;
        manager.record_ack("cmd-1", "node-2").await;
        sleep(Duration::from_millis(150)).await;

        // Deadline passed, but every ack arrived, so nothing is reported.
        assert!(manager.check_ack_timeouts().await.is_empty());
    }

    #[tokio::test]
    async fn test_unregister_ack_timeout() {
        let manager = TimeoutManager::new();
        register_acks(&manager, "cmd-1", &["node-1"], Duration::from_secs(30)).await;
        assert_eq!(manager.ack_timeout_count().await, 1);

        manager.unregister_ack_timeout("cmd-1").await.unwrap();
        assert_eq!(manager.ack_timeout_count().await, 0);
    }

    #[tokio::test]
    async fn test_multiple_commands_same_expiration() {
        let manager = TimeoutManager::new();
        let deadline = unix_now_secs() - 10;

        // Both commands land in the same time slot.
        let cmd1 = create_test_command_expiring_at("cmd-1", deadline);
        let cmd2 = create_test_command_expiring_at("cmd-2", deadline);
        manager.register_expiration(&cmd1).await.unwrap();
        manager.register_expiration(&cmd2).await.unwrap();

        let expired = manager.process_expired().await;

        assert_eq!(expired.len(), 2);
        assert!(expired.contains(&"cmd-1".to_string()));
        assert!(expired.contains(&"cmd-2".to_string()));
    }

    #[tokio::test]
    async fn test_duplicate_ack_not_counted_twice() {
        let manager = TimeoutManager::new();
        register_acks(&manager, "cmd-1", &["node-1", "node-2"], Duration::from_secs(30)).await;

        manager.record_ack("cmd-1", "node-1").await;
        manager.record_ack("cmd-1", "node-1").await;

        let status = manager.get_ack_status("cmd-1").await.unwrap();
        assert_eq!(status.received_acks.len(), 1);
    }
}