amq/persistence/
engine.rs

1use tokio::{sync::mpsc, time::interval};
2
3use crate::{
4    message::{MessageBox, MessageStatus},
5    persistence::{
6        config::PersistenceConfig, snapshot::SnapshotManager, wal::WalEntry, WalManager,
7    },
8};
9
10#[derive(Clone)]
11pub struct PersistenceEngine {
12    #[allow(dead_code)]
13    config: PersistenceConfig,
14    wal_manager: WalManager,
15    snapshot_manager: SnapshotManager,
16    #[allow(dead_code)]
17    snapshot_trigger: mpsc::Sender<()>,
18}
19
20impl PersistenceEngine {
21    pub async fn new(config: PersistenceConfig) -> Result<Self, Box<dyn std::error::Error>> {
22        let wal_manager = WalManager::new(config.clone()).await?;
23        let snapshot_manager = SnapshotManager::new(config.clone()).await?;
24
25        let (snapshot_trigger, mut snapshot_receiver) = mpsc::channel::<()>(1);
26
27        let engine = Self {
28            config: config.clone(),
29            wal_manager,
30            snapshot_manager,
31            snapshot_trigger: snapshot_trigger.clone(),
32        };
33
34        let _snapshot_manager = engine.snapshot_manager.clone();
35        let wal_manager = engine.wal_manager.clone();
36        let config = config.clone();
37
38        tokio::spawn(async move {
39            let mut snapshot_timer = interval(config.snapshot_interval);
40            snapshot_timer.tick().await;
41
42            loop {
43                tokio::select! {
44                    _ = snapshot_timer.tick() => {
45                        let count = wal_manager.get_entry_count().await;
46                        if count >= config.snapshot_wal_threshold {
47                            let _ = snapshot_trigger.send(()).await;
48                        }
49                    }
50                    Some(_) = snapshot_receiver.recv() => {
51                    }
52                }
53            }
54        });
55
56        Ok(engine)
57    }
58
59    pub async fn append_message(&self, message: &MessageBox) -> Result<(), Box<dyn std::error::Error>> {
60        self.wal_manager
61            .append(WalEntry::AddMessage(message.clone()))
62            .await
63    }
64
65    pub async fn update_message_status(
66        &self,
67        id: u64,
68        status: MessageStatus,
69    ) -> Result<(), Box<dyn std::error::Error>> {
70        self.wal_manager
71            .append(WalEntry::UpdateStatus { id, status })
72            .await
73    }
74
75    pub async fn remove_message(&self, id: u64) -> Result<(), Box<dyn std::error::Error>> {
76        self.wal_manager.append(WalEntry::RemoveMessage(id)).await
77    }
78
79    pub async fn create_snapshot(
80        &self,
81        messages: &[MessageBox],
82        next_message_id: u64,
83        next_connection_id: u64,
84    ) -> Result<u64, Box<dyn std::error::Error>> {
85        let timestamp = self
86            .snapshot_manager
87            .create_snapshot(messages, next_message_id, next_connection_id)
88            .await?;
89
90        self.wal_manager.cleanup_old_wals(timestamp).await?;
91
92        Ok(timestamp)
93    }
94
95    pub async fn recover(
96        &self,
97    ) -> Result<(Vec<MessageBox>, u64, u64), Box<dyn std::error::Error>> {
98        let mut messages = Vec::new();
99        let mut next_message_id = 0;
100        let mut next_connection_id = 0;
101
102        if let Some(snapshot) = self.snapshot_manager.load_latest_snapshot().await? {
103            messages = snapshot.messages;
104            next_message_id = snapshot.next_message_id;
105            next_connection_id = snapshot.next_connection_id;
106        }
107
108        let wal_entries = self.wal_manager.recover().await?;
109
110        for entry in wal_entries {
111            match entry {
112                WalEntry::AddMessage(msg) => {
113                    messages.push(msg);
114                }
115                WalEntry::UpdateStatus { id, status } => {
116                    if let Some(msg) = messages.iter_mut().find(|m| m.id == id) {
117                        msg.status = status;
118                    }
119                }
120                WalEntry::AddHistory { id, history } => {
121                    if let Some(msg) = messages.iter_mut().find(|m| m.id == id) {
122                        msg.history.push(history);
123                    }
124                }
125                WalEntry::RemoveMessage(id) => {
126                    messages.retain(|m| m.id != id);
127                }
128            }
129        }
130
131        Ok((messages, next_message_id, next_connection_id))
132    }
133
134    pub async fn shutdown(&self) -> Result<(), Box<dyn std::error::Error>> {
135        Ok(())
136    }
137}