1use std::{
7 collections::HashMap,
8 ops::Bound,
9 sync::{Arc, RwLock},
10 time::{Duration, Instant},
11};
12
13use reifydb_core::key::{Key, KeyKind};
14use reifydb_type::Result;
15
16use super::{
17 persistence::{
18 decode_object_stats_key, decode_stats, decode_type_stats_key, encode_object_stats_key, encode_stats,
19 encode_type_stats_key, object_stats_key_prefix, type_stats_key_prefix,
20 },
21 types::{ObjectId, StorageStats, Tier},
22};
23use crate::{
24 backend::{PrimitiveStorage, primitive::TableId},
25 stats::parser::extract_object_id,
26};
27
/// Tunable settings for a `StorageTracker`.
#[derive(Debug, Clone)]
pub struct StorageTrackerConfig {
    /// Minimum time that has to pass after the previous checkpoint before
    /// `StorageTracker::should_checkpoint` reports `true`.
    pub checkpoint_interval: Duration,
}
34
35impl Default for StorageTrackerConfig {
36 fn default() -> Self {
37 Self {
38 checkpoint_interval: Duration::from_secs(10),
39 }
40 }
41}
42
/// Byte sizes of the version of an entry that a write or delete supersedes.
///
/// `StorageTracker::record_write` and `StorageTracker::record_delete` forward
/// these values to the statistics so the superseded version can be accounted
/// for.
#[derive(Debug, Clone)]
pub struct PreVersionInfo {
    /// Key size, in bytes, of the superseded version.
    pub key_bytes: u64,
    /// Value size, in bytes, of the superseded version.
    pub value_bytes: u64,
}
51
52#[derive(Debug, Clone)]
57pub struct StorageTracker {
58 pub(super) inner: Arc<RwLock<StorageTrackerInner>>,
59}
60
61#[derive(Debug)]
62pub(super) struct StorageTrackerInner {
63 pub(super) by_type: HashMap<(Tier, KeyKind), StorageStats>,
65 pub(super) by_object: HashMap<(Tier, ObjectId), StorageStats>,
67 pub(super) by_tier: HashMap<Tier, StorageStats>,
69 pub(super) config: StorageTrackerConfig,
71 pub(super) last_checkpoint: Instant,
73}
74
75impl StorageTracker {
76 pub fn new(config: StorageTrackerConfig) -> Self {
78 Self {
79 inner: Arc::new(RwLock::new(StorageTrackerInner {
80 by_type: HashMap::new(),
81 by_object: HashMap::new(),
82 by_tier: HashMap::new(),
83 config,
84 last_checkpoint: Instant::now(),
85 })),
86 }
87 }
88
89 pub fn with_defaults() -> Self {
91 Self::new(StorageTrackerConfig::default())
92 }
93
94 pub fn record_write(
102 &self,
103 tier: Tier,
104 key: &[u8],
105 key_bytes: u64,
106 value_bytes: u64,
107 pre_version: Option<PreVersionInfo>,
108 ) {
109 let kind = Key::kind(key);
110 let object_id = kind.map(|k| extract_object_id(key, k));
111
112 let mut inner = self.inner.write().unwrap();
113
114 {
116 let stats = inner.by_tier.entry(tier).or_insert_with(StorageStats::new);
117 if let Some(pre) = &pre_version {
118 stats.record_update(key_bytes, value_bytes, pre.key_bytes, pre.value_bytes);
119 } else {
120 stats.record_insert(key_bytes, value_bytes);
121 }
122 }
123
124 if let Some(kind) = kind {
126 let stats = inner.by_type.entry((tier, kind)).or_insert_with(StorageStats::new);
127
128 if let Some(pre) = &pre_version {
129 stats.record_update(key_bytes, value_bytes, pre.key_bytes, pre.value_bytes);
130 } else {
131 stats.record_insert(key_bytes, value_bytes);
132 }
133 }
134
135 if let Some(object_id) = object_id {
137 let stats = inner.by_object.entry((tier, object_id)).or_insert_with(StorageStats::new);
138
139 if let Some(pre) = &pre_version {
140 stats.record_update(key_bytes, value_bytes, pre.key_bytes, pre.value_bytes);
141 } else {
142 stats.record_insert(key_bytes, value_bytes);
143 }
144 }
145 }
146
147 pub fn record_delete(&self, tier: Tier, key: &[u8], key_bytes: u64, pre_version: Option<PreVersionInfo>) {
154 let Some(pre) = pre_version else {
156 return;
157 };
158
159 let kind = Key::kind(key);
160 let object_id = kind.map(|k| extract_object_id(key, k));
161
162 let mut inner = self.inner.write().unwrap();
163
164 {
166 let stats = inner.by_tier.entry(tier).or_insert_with(StorageStats::new);
167 stats.record_delete(key_bytes, pre.key_bytes, pre.value_bytes);
168 }
169
170 if let Some(kind) = kind {
172 let stats = inner.by_type.entry((tier, kind)).or_insert_with(StorageStats::new);
173 stats.record_delete(key_bytes, pre.key_bytes, pre.value_bytes);
174 }
175
176 if let Some(object_id) = object_id {
178 if let Some(stats) = inner.by_object.get_mut(&(tier, object_id)) {
179 stats.record_delete(key_bytes, pre.key_bytes, pre.value_bytes);
180 }
181 }
182 }
183
184 pub fn record_drop(&self, tier: Tier, key: &[u8], versioned_key_bytes: u64, value_bytes: u64) {
194 let kind = Key::kind(key);
195 let object_id = kind.map(|k| extract_object_id(key, k));
196
197 let mut inner = self.inner.write().unwrap();
198
199 if let Some(stats) = inner.by_tier.get_mut(&tier) {
201 stats.record_drop(versioned_key_bytes, value_bytes);
202 }
203
204 if let Some(kind) = kind {
206 if let Some(stats) = inner.by_type.get_mut(&(tier, kind)) {
207 stats.record_drop(versioned_key_bytes, value_bytes);
208 }
209 }
210
211 if let Some(object_id) = object_id {
213 if let Some(stats) = inner.by_object.get_mut(&(tier, object_id)) {
214 stats.record_drop(versioned_key_bytes, value_bytes);
215 }
216 }
217 }
218
219 pub fn record_cdc_for_change(&self, tier: Tier, key: &[u8], value_bytes: u64, count: u64) {
227 let key_bytes = key.len() as u64;
228
229 let kind = Key::kind(key);
230 let object_id = kind.map(|k| extract_object_id(key, k));
231
232 let mut inner = self.inner.write().unwrap();
233
234 {
236 let stats = inner.by_tier.entry(tier).or_insert_with(StorageStats::new);
237 stats.record_cdc(key_bytes, value_bytes, count);
238 }
239
240 if let Some(kind) = kind {
242 let stats = inner.by_type.entry((tier, kind)).or_insert_with(StorageStats::new);
243 stats.record_cdc(key_bytes, value_bytes, count);
244 }
245
246 if let Some(object_id) = object_id {
248 let stats = inner.by_object.entry((tier, object_id)).or_insert_with(StorageStats::new);
249 stats.record_cdc(key_bytes, value_bytes, count);
250 }
251 }
252
253 pub fn record_tier_migration(
258 &self,
259 from_tier: Tier,
260 to_tier: Tier,
261 key: &[u8],
262 value_bytes: u64,
263 is_current: bool,
264 ) {
265 let key_bytes = key.len() as u64;
266
267 let kind = Key::kind(key);
268 let object_id = kind.map(|k| extract_object_id(key, k));
269
270 let mut inner = self.inner.write().unwrap();
271
272 {
274 if let Some(stats) = inner.by_tier.get_mut(&from_tier) {
276 if is_current {
277 stats.current_key_bytes = stats.current_key_bytes.saturating_sub(key_bytes);
278 stats.current_value_bytes =
279 stats.current_value_bytes.saturating_sub(value_bytes);
280 stats.current_count = stats.current_count.saturating_sub(1);
281 } else {
282 stats.historical_key_bytes =
283 stats.historical_key_bytes.saturating_sub(key_bytes);
284 stats.historical_value_bytes =
285 stats.historical_value_bytes.saturating_sub(value_bytes);
286 stats.historical_count = stats.historical_count.saturating_sub(1);
287 }
288 }
289
290 let stats = inner.by_tier.entry(to_tier).or_insert_with(StorageStats::new);
292 if is_current {
293 stats.current_key_bytes += key_bytes;
294 stats.current_value_bytes += value_bytes;
295 stats.current_count += 1;
296 } else {
297 stats.historical_key_bytes += key_bytes;
298 stats.historical_value_bytes += value_bytes;
299 stats.historical_count += 1;
300 }
301 }
302
303 if let Some(kind) = kind {
305 if let Some(stats) = inner.by_type.get_mut(&(from_tier, kind)) {
307 if is_current {
308 stats.current_key_bytes = stats.current_key_bytes.saturating_sub(key_bytes);
309 stats.current_value_bytes =
310 stats.current_value_bytes.saturating_sub(value_bytes);
311 stats.current_count = stats.current_count.saturating_sub(1);
312 } else {
313 stats.historical_key_bytes =
314 stats.historical_key_bytes.saturating_sub(key_bytes);
315 stats.historical_value_bytes =
316 stats.historical_value_bytes.saturating_sub(value_bytes);
317 stats.historical_count = stats.historical_count.saturating_sub(1);
318 }
319 }
320
321 let stats = inner.by_type.entry((to_tier, kind)).or_insert_with(StorageStats::new);
323 if is_current {
324 stats.current_key_bytes += key_bytes;
325 stats.current_value_bytes += value_bytes;
326 stats.current_count += 1;
327 } else {
328 stats.historical_key_bytes += key_bytes;
329 stats.historical_value_bytes += value_bytes;
330 stats.historical_count += 1;
331 }
332 }
333
334 if let Some(object_id) = object_id {
336 if let Some(stats) = inner.by_object.get_mut(&(from_tier, object_id)) {
338 if is_current {
339 stats.current_key_bytes = stats.current_key_bytes.saturating_sub(key_bytes);
340 stats.current_value_bytes =
341 stats.current_value_bytes.saturating_sub(value_bytes);
342 stats.current_count = stats.current_count.saturating_sub(1);
343 } else {
344 stats.historical_key_bytes =
345 stats.historical_key_bytes.saturating_sub(key_bytes);
346 stats.historical_value_bytes =
347 stats.historical_value_bytes.saturating_sub(value_bytes);
348 stats.historical_count = stats.historical_count.saturating_sub(1);
349 }
350 }
351
352 let stats = inner.by_object.entry((to_tier, object_id)).or_insert_with(StorageStats::new);
354 if is_current {
355 stats.current_key_bytes += key_bytes;
356 stats.current_value_bytes += value_bytes;
357 stats.current_count += 1;
358 } else {
359 stats.historical_key_bytes += key_bytes;
360 stats.historical_value_bytes += value_bytes;
361 stats.historical_count += 1;
362 }
363 }
364 }
365
366 pub fn should_checkpoint(&self) -> bool {
372 let inner = self.inner.read().unwrap();
373 inner.last_checkpoint.elapsed() >= inner.config.checkpoint_interval
374 }
375
376 pub async fn checkpoint_async<S: PrimitiveStorage>(&self, storage: &S) -> Result<()> {
380 storage.ensure_table(TableId::Single).await?;
382
383 let entries: Vec<(Vec<u8>, Option<Vec<u8>>)> = {
384 let inner = self.inner.read().unwrap();
385
386 let mut entries = Vec::new();
387
388 for ((tier, kind), stats) in &inner.by_type {
390 let key = encode_type_stats_key(*tier, *kind);
391 let value = encode_stats(stats);
392 entries.push((key, Some(value)));
393 }
394
395 for ((tier, object_id), stats) in &inner.by_object {
397 let key = encode_object_stats_key(*tier, *object_id);
398 let value = encode_stats(stats);
399 entries.push((key, Some(value)));
400 }
401
402 entries
403 };
404
405 storage.put(TableId::Single, entries).await?;
407
408 {
410 let mut inner = self.inner.write().unwrap();
411 inner.last_checkpoint = Instant::now();
412 }
413
414 Ok(())
415 }
416
417 pub async fn restore_async<S: PrimitiveStorage>(storage: &S, config: StorageTrackerConfig) -> Result<Self> {
421 let mut by_type: HashMap<(Tier, KeyKind), StorageStats> = HashMap::new();
422 let mut by_object: HashMap<(Tier, ObjectId), StorageStats> = HashMap::new();
423
424 let type_prefix = type_stats_key_prefix();
426 let mut end_prefix = type_prefix.clone();
427 if let Some(last) = end_prefix.last_mut() {
428 *last = last.saturating_add(1);
429 }
430
431 let batch = storage
432 .range_batch(TableId::Single, Bound::Included(type_prefix), Bound::Excluded(end_prefix), 1000)
433 .await?;
434
435 for entry in batch.entries {
436 if let Some((tier, kind)) = decode_type_stats_key(&entry.key) {
437 if let Some(value) = entry.value {
438 if let Some(stats) = decode_stats(&value) {
439 by_type.insert((tier, kind), stats);
440 }
441 }
442 }
443 }
444
445 let object_prefix = object_stats_key_prefix();
447 let mut end_prefix = object_prefix.clone();
448 if let Some(last) = end_prefix.last_mut() {
449 *last = last.saturating_add(1);
450 }
451
452 let batch = storage
453 .range_batch(TableId::Single, Bound::Included(object_prefix), Bound::Excluded(end_prefix), 1000)
454 .await?;
455
456 for entry in batch.entries {
457 if let Some((tier, object_id)) = decode_object_stats_key(&entry.key) {
458 if let Some(value) = entry.value {
459 if let Some(stats) = decode_stats(&value) {
460 by_object.insert((tier, object_id), stats);
461 }
462 }
463 }
464 }
465
466 let mut by_tier: HashMap<Tier, StorageStats> = HashMap::new();
468 for ((tier, _kind), stats) in &by_type {
469 let tier_stats = by_tier.entry(*tier).or_insert_with(StorageStats::new);
470 *tier_stats += stats.clone();
471 }
472
473 Ok(Self {
474 inner: Arc::new(RwLock::new(StorageTrackerInner {
475 by_type,
476 by_object,
477 by_tier,
478 config,
479 last_checkpoint: Instant::now(),
480 })),
481 })
482 }
483}
484
485#[cfg(test)]
486mod tests {
487 use reifydb_core::interface::SourceId;
488 use tokio::time::sleep;
489
490 use super::*;
491
492 fn make_row_key(source_id: u64, row: u64) -> Vec<u8> {
493 use reifydb_core::{interface::EncodableKey, key::RowKey};
494 use reifydb_type::RowNumber;
495
496 let key = RowKey {
497 source: SourceId::table(source_id),
498 row: RowNumber(row),
499 };
500 key.encode().to_vec()
501 }
502
/// A first write without a previous version counts as one current entry.
#[test]
fn test_tracker_insert() {
    let key = make_row_key(1, 100);
    let key_len = key.len() as u64;

    let tracker = StorageTracker::with_defaults();
    tracker.record_write(Tier::Hot, &key, key_len, 50, None);

    let totals = tracker.total_stats();
    let hot = &totals.hot;
    assert_eq!(hot.current_key_bytes, key_len);
    assert_eq!(hot.current_value_bytes, 50);
    assert_eq!(hot.current_count, 1);
    assert_eq!(hot.historical_count, 0);
}
517
/// Overwriting a key keeps one current entry and moves the old version to
/// the historical counters.
#[test]
fn test_tracker_update() {
    let tracker = StorageTracker::with_defaults();
    let key = make_row_key(1, 100);
    let key_len = key.len() as u64;

    // First version: plain insert with 50 value bytes.
    tracker.record_write(Tier::Hot, &key, key_len, 50, None);

    // Second version: 75 value bytes superseding the version above.
    let superseded = PreVersionInfo {
        key_bytes: key_len,
        value_bytes: 50,
    };
    tracker.record_write(Tier::Hot, &key, key_len, 75, Some(superseded));

    let totals = tracker.total_stats();
    let hot = &totals.hot;

    // Current counters describe the new version only.
    assert_eq!(hot.current_key_bytes, key_len);
    assert_eq!(hot.current_value_bytes, 75);
    assert_eq!(hot.current_count, 1);

    // Historical counters describe the superseded version.
    assert_eq!(hot.historical_key_bytes, key_len);
    assert_eq!(hot.historical_value_bytes, 50);
    assert_eq!(hot.historical_count, 1);
}
545
/// Deleting the only version leaves no current entry and two historical ones.
#[test]
fn test_tracker_delete() {
    let tracker = StorageTracker::with_defaults();
    let key = make_row_key(1, 100);
    let key_len = key.len() as u64;

    tracker.record_write(Tier::Hot, &key, key_len, 50, None);

    let deleted_version = PreVersionInfo {
        key_bytes: key_len,
        value_bytes: 50,
    };
    tracker.record_delete(Tier::Hot, &key, key_len, Some(deleted_version));

    let totals = tracker.total_stats();
    let hot = &totals.hot;
    assert_eq!(hot.current_count, 0);
    assert_eq!(hot.historical_count, 2);
}
569
/// Row keys of different tables are aggregated under `KeyKind::Row`.
#[test]
fn test_tracker_by_type() {
    let tracker = StorageTracker::with_defaults();

    let first = make_row_key(1, 100);
    let second = make_row_key(2, 200);

    tracker.record_write(Tier::Hot, &first, first.len() as u64, 50, None);
    tracker.record_write(Tier::Hot, &second, second.len() as u64, 60, None);

    let hot_by_type = tracker.stats_by_type(Tier::Hot);
    let rows = hot_by_type.get(&KeyKind::Row).unwrap();

    assert_eq!(rows.current_count, 2);
    assert_eq!(rows.current_value_bytes, 110);
}
587
/// Writes are attributed to the table their row key belongs to.
#[test]
fn test_tracker_per_object() {
    let tracker = StorageTracker::with_defaults();

    // (table, row, value bytes): two rows in table 1, one row in table 2.
    for (table, row, value_bytes) in [(1, 100, 50), (1, 200, 60), (2, 100, 70)] {
        let key = make_row_key(table, row);
        tracker.record_write(Tier::Hot, &key, key.len() as u64, value_bytes, None);
    }

    let table_one = tracker.stats_for_object(ObjectId::Source(SourceId::table(1))).unwrap();
    assert_eq!(table_one.hot.current_count, 2);
    assert_eq!(table_one.hot.current_value_bytes, 110);

    let table_two = tracker.stats_for_object(ObjectId::Source(SourceId::table(2))).unwrap();
    assert_eq!(table_two.hot.current_count, 1);
    assert_eq!(table_two.hot.current_value_bytes, 70);
}
614
/// Migrating the current version empties the hot tier and fills the warm one.
#[test]
fn test_tracker_tier_migration() {
    let key = make_row_key(1, 100);
    let key_len = key.len() as u64;

    let tracker = StorageTracker::with_defaults();
    tracker.record_write(Tier::Hot, &key, key_len, 50, None);
    tracker.record_tier_migration(Tier::Hot, Tier::Warm, &key, 50, true);

    let totals = tracker.total_stats();

    let hot = &totals.hot;
    assert_eq!(hot.current_count, 0);
    assert_eq!(hot.current_bytes(), 0);

    let warm = &totals.warm;
    assert_eq!(warm.current_count, 1);
    assert_eq!(warm.current_key_bytes, key_len);
    assert_eq!(warm.current_value_bytes, 50);
}
637
/// `top_objects_by_size(2)` returns the two largest tables, biggest first.
#[test]
fn test_top_objects() {
    let tracker = StorageTracker::with_defaults();

    // (table, value bytes): table 2 is the largest, table 3 the smallest.
    for (table, value_bytes) in [(1, 100), (2, 200), (3, 50)] {
        let key = make_row_key(table, 100);
        tracker.record_write(Tier::Hot, &key, key.len() as u64, value_bytes, None);
    }

    let largest = tracker.top_objects_by_size(2);
    assert_eq!(largest.len(), 2);

    assert_eq!(largest[0].0, ObjectId::Source(SourceId::table(2)));
    assert_eq!(largest[1].0, ObjectId::Source(SourceId::table(1)));
}
662
663 #[tokio::test]
668 async fn test_should_checkpoint_time_based() {
669 let config = StorageTrackerConfig {
670 checkpoint_interval: Duration::from_millis(50),
671 };
672 let tracker = StorageTracker::new(config);
673
674 assert!(!tracker.should_checkpoint());
676
677 sleep(Duration::from_millis(60)).await;
679
680 assert!(tracker.should_checkpoint());
682 }
683
684 #[tokio::test]
685 async fn test_checkpoint_and_restore_roundtrip() {
686 use crate::backend::BackendStorage;
687
688 let storage = BackendStorage::memory().await;
690
691 let config = StorageTrackerConfig {
693 checkpoint_interval: Duration::from_secs(10),
694 };
695 let tracker = StorageTracker::new(config.clone());
696
697 let key1 = make_row_key(1, 100);
699 let key2 = make_row_key(2, 200);
700 let key1_bytes = key1.len() as u64;
701 let key2_bytes = key2.len() as u64;
702 tracker.record_write(Tier::Hot, &key1, key1_bytes, 50, None);
703 tracker.record_write(Tier::Hot, &key2, key2_bytes, 100, None);
704 tracker.record_write(Tier::Warm, &key1, key1_bytes, 75, None);
705
706 tracker.checkpoint_async(&storage).await.unwrap();
708
709 let restored = StorageTracker::restore_async(&storage, config).await.unwrap();
711
712 let original_stats = tracker.total_stats();
714 let restored_stats = restored.total_stats();
715
716 assert_eq!(original_stats.hot.current_key_bytes, restored_stats.hot.current_key_bytes);
717 assert_eq!(original_stats.hot.current_value_bytes, restored_stats.hot.current_value_bytes);
718 assert_eq!(original_stats.hot.current_count, restored_stats.hot.current_count);
719 assert_eq!(original_stats.warm.current_key_bytes, restored_stats.warm.current_key_bytes);
720 assert_eq!(original_stats.warm.current_value_bytes, restored_stats.warm.current_value_bytes);
721
722 let original_by_type = tracker.stats_by_type(Tier::Hot);
724 let restored_by_type = restored.stats_by_type(Tier::Hot);
725 assert_eq!(
726 original_by_type.get(&KeyKind::Row).unwrap().current_count,
727 restored_by_type.get(&KeyKind::Row).unwrap().current_count
728 );
729
730 let source1 = ObjectId::Source(SourceId::table(1));
732 let original_obj = tracker.stats_for_object(source1).unwrap();
733 let restored_obj = restored.stats_for_object(source1).unwrap();
734 assert_eq!(original_obj.hot.current_value_bytes, restored_obj.hot.current_value_bytes);
735 }
736
737 #[tokio::test]
738 async fn test_checkpoint_resets_timer() {
739 use crate::backend::BackendStorage;
740
741 let storage = BackendStorage::memory().await;
742 let config = StorageTrackerConfig {
743 checkpoint_interval: Duration::from_millis(50),
744 };
745 let tracker = StorageTracker::new(config);
746
747 tokio::time::sleep(Duration::from_millis(60)).await;
749 assert!(tracker.should_checkpoint());
750
751 tracker.checkpoint_async(&storage).await.unwrap();
753
754 assert!(!tracker.should_checkpoint());
756
757 tokio::time::sleep(Duration::from_millis(60)).await;
759
760 assert!(tracker.should_checkpoint());
762 }
763
764 #[tokio::test]
765 async fn test_restore_empty_storage() {
766 use crate::backend::BackendStorage;
767
768 let storage = BackendStorage::memory().await;
770
771 let config = StorageTrackerConfig {
772 checkpoint_interval: Duration::from_secs(10),
773 };
774
775 let tracker = StorageTracker::restore_async(&storage, config).await.unwrap();
777 let stats = tracker.total_stats();
778
779 assert_eq!(stats.hot.current_count, 0);
780 assert_eq!(stats.warm.current_count, 0);
781 assert_eq!(stats.cold.current_count, 0);
782 }
783}