1use crate::error::Result;
4use azoth::AzothDb;
5use azoth_core::traits::canonical::{CanonicalReadTxn, CanonicalStore};
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use std::time::Duration;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12pub enum RetentionPolicy {
13 KeepAll,
15
16 KeepDays(u64),
18
19 KeepCount(u64),
21}
22
/// Outcome of a single compaction pass over one stream.
#[derive(Debug, Default, Clone)]
pub struct CompactionStats {
    /// Number of events removed by the pass.
    pub deleted: u64,

    /// The slowest consumer's `position - 1` (its last acknowledged event), or `None` when
    /// the stream has no consumers or the slowest one is still at position 0.
    pub min_consumer_position: Option<u64>,

    /// Cutoff implied by the retention policy alone.
    pub policy_cutoff: u64,

    /// Cutoff actually used after clamping `policy_cutoff` to the consumers' progress.
    pub actual_cutoff: u64,
}
38
39impl CompactionStats {
40 pub fn empty() -> Self {
41 Self::default()
42 }
43}
44
45pub struct RetentionManager {
47 db: Arc<AzothDb>,
48}
49
50impl RetentionManager {
51 pub fn new(db: Arc<AzothDb>) -> Self {
53 Self { db }
54 }
55
56 pub fn set_retention(&self, stream: &str, policy: RetentionPolicy) -> Result<()> {
58 let key = format!("bus:stream:{}:retention", stream).into_bytes();
59 let policy_bytes = serde_json::to_vec(&policy)?;
60
61 azoth::Transaction::new(&self.db)
62 .keys(vec![key.clone()])
63 .execute(|ctx| {
64 ctx.set(&key, &azoth::TypedValue::Bytes(policy_bytes))?;
65 Ok(())
66 })?;
67
68 Ok(())
69 }
70
71 pub fn get_retention(&self, stream: &str) -> Result<RetentionPolicy> {
73 let key = format!("bus:stream:{}:retention", stream).into_bytes();
74 let txn = self.db.canonical().read_txn()?;
75
76 match txn.get_state(&key)? {
77 Some(bytes) => {
78 let value = azoth::TypedValue::from_bytes(&bytes)?;
79 let policy_bytes = match value {
80 azoth::TypedValue::Bytes(b) => b,
81 _ => {
82 return Err(crate::error::BusError::InvalidState(
83 "Retention policy must be bytes".into(),
84 ))
85 }
86 };
87 Ok(serde_json::from_slice(&policy_bytes)?)
88 }
89 None => Ok(RetentionPolicy::KeepAll), }
91 }
92
93 fn find_min_consumer_cursor(&self, stream: &str) -> Result<Option<u64>> {
98 let bus = crate::EventBus::new(self.db.clone());
101 let consumers = bus.list_consumers(stream)?;
102
103 if consumers.is_empty() {
104 return Ok(None);
105 }
106
107 let min_position = consumers.iter().map(|c| c.position).min().unwrap();
111
112 if min_position == 0 {
114 return Ok(None);
115 }
116
117 Ok(Some(min_position - 1))
119 }
120
121 pub fn compact(&self, stream: &str) -> Result<CompactionStats> {
131 let policy = self.get_retention(stream)?;
132
133 let policy_cutoff = match policy {
135 RetentionPolicy::KeepAll => {
136 return Ok(CompactionStats::empty());
137 }
138 RetentionPolicy::KeepDays(_days) => {
139 return Ok(CompactionStats::empty());
143 }
144 RetentionPolicy::KeepCount(max) => {
145 let meta = self.db.canonical().meta()?;
146 let head = meta.next_event_id;
147 head.saturating_sub(max)
148 }
149 };
150
151 let min_consumer_position = self.find_min_consumer_cursor(stream)?;
153
154 let actual_cutoff = match min_consumer_position {
156 Some(min_pos) => policy_cutoff.min(min_pos),
157 None => policy_cutoff, };
159
160 let deleted = if actual_cutoff > 0 {
162 0
167 } else {
168 0
169 };
170
171 Ok(CompactionStats {
172 deleted,
173 min_consumer_position,
174 policy_cutoff,
175 actual_cutoff,
176 })
177 }
178
179 pub fn compact_all(&self) -> Result<Vec<(String, CompactionStats)>> {
181 let prefix = b"bus:stream:";
182
183 let mut results = Vec::new();
184 let mut iter = self.db.canonical().scan_prefix(prefix)?;
185
186 while let Some((key, _value)) = iter.next()? {
187 let key_str = String::from_utf8_lossy(&key);
188
189 if key_str.ends_with(":retention") {
191 let parts: Vec<&str> = key_str.split(':').collect();
192 if parts.len() >= 3 {
193 let stream = parts[2];
194 let stats = self.compact(stream)?;
195 results.push((stream.to_string(), stats));
196 }
197 }
198 }
199
200 Ok(results)
201 }
202
203 pub async fn run_continuous(self: Arc<Self>, interval: Duration) {
207 let mut ticker = tokio::time::interval(interval);
208
209 loop {
210 ticker.tick().await;
211
212 match self.compact_all() {
213 Ok(results) => {
214 for (stream, stats) in results {
215 if stats.deleted > 0 {
216 tracing::info!(
217 stream = %stream,
218 deleted = stats.deleted,
219 actual_cutoff = stats.actual_cutoff,
220 "Compacted stream"
221 );
222 }
223 }
224 }
225 Err(e) => {
226 tracing::error!(error = ?e, "Compaction failed");
227 }
228 }
229 }
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use azoth::Transaction;
    use tempfile::TempDir;

    /// Opens a fresh database inside a temporary directory.
    ///
    /// The `TempDir` is returned alongside the handle; callers bind it (as `_temp`) so the
    /// directory is not cleaned up while the test is still using the database.
    fn test_db() -> (Arc<AzothDb>, TempDir) {
        let temp = TempDir::new().unwrap();
        let db = AzothDb::open(temp.path()).unwrap();
        (Arc::new(db), temp)
    }

    /// Logs `count` events with types `"{stream}:event{i}"` and payloads `"data{i}"`,
    /// all within a single transaction.
    fn publish_events(db: &AzothDb, stream: &str, count: usize) -> Result<()> {
        Transaction::new(db).execute(|ctx| {
            for i in 0..count {
                let event_type = format!("{}:event{}", stream, i);
                ctx.log(&event_type, &format!("data{}", i))?;
            }
            Ok(())
        })?;
        Ok(())
    }

    /// A stored policy round-trips through `set_retention` / `get_retention`.
    #[test]
    fn test_retention_policy_storage() {
        let (db, _temp) = test_db();
        let mgr = RetentionManager::new(db);

        mgr.set_retention("test", RetentionPolicy::KeepDays(7))
            .unwrap();

        let policy = mgr.get_retention("test").unwrap();
        assert_eq!(policy, RetentionPolicy::KeepDays(7));
    }

    /// A stream with no stored policy reports `KeepAll`.
    #[test]
    fn test_retention_default_keep_all() {
        let (db, _temp) = test_db();
        let mgr = RetentionManager::new(db);

        let policy = mgr.get_retention("nonexistent").unwrap();
        assert_eq!(policy, RetentionPolicy::KeepAll);
    }

    /// With no consumers there is no cursor constraint.
    #[test]
    fn test_find_min_consumer_cursor_no_consumers() {
        let (db, _temp) = test_db();
        let mgr = RetentionManager::new(db.clone());

        let min = mgr.find_min_consumer_cursor("test").unwrap();
        assert_eq!(min, None);
    }

    /// The slowest consumer determines the cursor: c2 acks 3 events, so the result is 2.
    #[test]
    fn test_find_min_consumer_cursor_with_consumers() {
        let (db, _temp) = test_db();
        let mgr = RetentionManager::new(db.clone());

        publish_events(&db, "test", 10).unwrap();

        // The consumers (and this bus handle) are dropped at the end of this scope.
        {
            let bus = crate::EventBus::new(db.clone());
            let mut c1 = bus.subscribe("test", "c1").unwrap();
            let mut c2 = bus.subscribe("test", "c2").unwrap();

            // c1 reads and acknowledges up to 5 events.
            for _ in 0..5 {
                if let Some(event) = c1.next().unwrap() {
                    c1.ack(event.id).unwrap();
                }
            }

            // c2 reads and acknowledges up to 3 events, making it the slowest consumer.
            for _ in 0..3 {
                if let Some(event) = c2.next().unwrap() {
                    c2.ack(event.id).unwrap();
                }
            }
        }
        // NOTE(review): the purpose of this short sleep is not visible from here; presumably
        // it lets cursor updates settle after the consumers are dropped. Confirm before removing.
        std::thread::sleep(std::time::Duration::from_millis(10));

        let min = mgr.find_min_consumer_cursor("test").unwrap();
        assert_eq!(min, Some(2));
    }

    /// KeepCount(50) over 100 events with no consumers gives a cutoff of 100 - 50 = 50.
    #[test]
    fn test_compact_keep_count() {
        let (db, _temp) = test_db();
        let mgr = RetentionManager::new(db.clone());

        publish_events(&db, "test", 100).unwrap();

        mgr.set_retention("test", RetentionPolicy::KeepCount(50))
            .unwrap();

        let stats = mgr.compact("test").unwrap();

        assert_eq!(stats.policy_cutoff, 50);
        assert_eq!(stats.actual_cutoff, 50);
    }

    /// The policy cutoff (100 - 30 = 70) is clamped to the slow consumer's last ack (60).
    #[test]
    fn test_compact_respects_slow_consumers() {
        let (db, _temp) = test_db();
        let mgr = RetentionManager::new(db.clone());

        publish_events(&db, "test", 100).unwrap();

        // The consumer stays alive for the rest of the test, including during `compact`.
        let bus = crate::EventBus::new(db.clone());
        let mut consumer = bus.subscribe("test", "slow").unwrap();

        // Acknowledge 61 events, so the last acknowledged event is 60.
        for _ in 0..61 {
            if let Some(event) = consumer.next().unwrap() {
                consumer.ack(event.id).unwrap();
            }
        }

        mgr.set_retention("test", RetentionPolicy::KeepCount(30))
            .unwrap();

        let stats = mgr.compact("test").unwrap();

        assert_eq!(stats.policy_cutoff, 70);
        assert_eq!(stats.min_consumer_position, Some(60));
        assert_eq!(stats.actual_cutoff, 60);
    }

    /// Every stream with a stored policy is compacted.
    ///
    /// KeepCount is measured against the database-wide next event id: 80 here, after
    /// 50 + 30 events. The cutoffs are therefore 80 - 20 = 60 and 80 - 10 = 70.
    #[test]
    fn test_compact_all() {
        let (db, _temp) = test_db();
        let mgr = RetentionManager::new(db.clone());

        publish_events(&db, "stream1", 50).unwrap();
        publish_events(&db, "stream2", 30).unwrap();

        mgr.set_retention("stream1", RetentionPolicy::KeepCount(20))
            .unwrap();
        mgr.set_retention("stream2", RetentionPolicy::KeepCount(10))
            .unwrap();

        let results = mgr.compact_all().unwrap();

        assert_eq!(results.len(), 2);

        let stream1_stats = results
            .iter()
            .find(|(name, _)| name == "stream1")
            .map(|(_, stats)| stats)
            .unwrap();
        let stream2_stats = results
            .iter()
            .find(|(name, _)| name == "stream2")
            .map(|(_, stats)| stats)
            .unwrap();

        assert_eq!(stream1_stats.policy_cutoff, 60);
        assert_eq!(stream2_stats.policy_cutoff, 70);
    }
}