forge_runtime/realtime/
invalidation.rs1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::{RwLock, mpsc};
6use tokio::time::Instant;
7
8use forge_core::realtime::{Change, QueryGroupId};
9
10use super::manager::SubscriptionManager;
11
/// Tuning knobs for how [`InvalidationEngine`] debounces and batches
/// query-group invalidations.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Quiet period in milliseconds: a pending group becomes ready once no
    /// new change has touched it for at least this long.
    pub debounce_ms: u64,
    /// Hard cap in milliseconds, measured from the first buffered change of a
    /// group, after which the group is released even if changes keep arriving.
    pub max_debounce_ms: u64,
    /// Whether invalidations should be coalesced per table.
    /// NOTE(review): no code in this file reads this flag — confirm its intended use.
    pub coalesce_by_table: bool,
    /// Number of pending groups at which the whole buffer is flushed at once
    /// instead of waiting for the debounce window.
    pub max_buffer_size: usize,
}

impl Default for InvalidationConfig {
    /// 50 ms quiet period, 200 ms cap, per-table coalescing enabled, and an
    /// immediate flush once 1000 groups are pending.
    fn default() -> Self {
        InvalidationConfig {
            max_buffer_size: 1_000,
            coalesce_by_table: true,
            max_debounce_ms: 200,
            debounce_ms: 50,
        }
    }
}
39
40#[derive(Debug)]
42struct PendingInvalidation {
43 #[allow(dead_code)]
44 group_id: QueryGroupId,
45 changed_tables: HashSet<String>,
46 first_change: Instant,
47 last_change: Instant,
48}
49
50pub struct InvalidationEngine {
53 subscription_manager: Arc<SubscriptionManager>,
54 #[allow(dead_code)]
55 config: InvalidationConfig,
56 pending: Arc<RwLock<HashMap<QueryGroupId, PendingInvalidation>>>,
58 #[allow(dead_code)]
59 invalidation_tx: mpsc::Sender<Vec<QueryGroupId>>,
60 #[allow(dead_code)]
61 invalidation_rx: Arc<RwLock<mpsc::Receiver<Vec<QueryGroupId>>>>,
62}
63
64impl InvalidationEngine {
65 pub fn new(subscription_manager: Arc<SubscriptionManager>, config: InvalidationConfig) -> Self {
67 let (invalidation_tx, invalidation_rx) = mpsc::channel(1024);
68
69 Self {
70 subscription_manager,
71 config,
72 pending: Arc::new(RwLock::new(HashMap::new())),
73 invalidation_tx,
74 invalidation_rx: Arc::new(RwLock::new(invalidation_rx)),
75 }
76 }
77
78 pub async fn process_change(&self, change: Change) {
80 let affected = self.subscription_manager.find_affected_groups(&change);
81
82 if affected.is_empty() {
83 return;
84 }
85
86 tracing::debug!(
87 table = %change.table,
88 affected_groups = affected.len(),
89 "Found affected groups for change"
90 );
91
92 let now = Instant::now();
93 let mut pending = self.pending.write().await;
94
95 for group_id in affected {
96 let entry = pending
97 .entry(group_id)
98 .or_insert_with(|| PendingInvalidation {
99 group_id,
100 changed_tables: HashSet::new(),
101 first_change: now,
102 last_change: now,
103 });
104
105 entry.changed_tables.insert(change.table.clone());
106 entry.last_change = now;
107 }
108
109 if pending.len() >= self.config.max_buffer_size {
110 drop(pending);
111 self.flush_all().await;
112 }
113 }
114
115 pub async fn check_pending(&self) -> Vec<QueryGroupId> {
117 let now = Instant::now();
118 let debounce = Duration::from_millis(self.config.debounce_ms);
119 let max_debounce = Duration::from_millis(self.config.max_debounce_ms);
120
121 let mut pending = self.pending.write().await;
122 let mut ready = Vec::new();
123
124 pending.retain(|_, inv| {
125 let since_last = now.duration_since(inv.last_change);
126 let since_first = now.duration_since(inv.first_change);
127
128 if since_last >= debounce || since_first >= max_debounce {
129 ready.push(inv.group_id);
130 false
131 } else {
132 true
133 }
134 });
135
136 ready
137 }
138
139 pub async fn flush_all(&self) -> Vec<QueryGroupId> {
141 let mut pending = self.pending.write().await;
142 let ready: Vec<QueryGroupId> = pending.keys().copied().collect();
143 pending.clear();
144 ready
145 }
146
147 pub async fn run(&self) {
149 let check_interval = Duration::from_millis(self.config.debounce_ms / 2);
150
151 loop {
152 tokio::time::sleep(check_interval).await;
153
154 let ready = self.check_pending().await;
155 if !ready.is_empty() && self.invalidation_tx.send(ready).await.is_err() {
156 break;
157 }
158 }
159 }
160
161 pub async fn pending_count(&self) -> usize {
163 self.pending.read().await.len()
164 }
165
166 pub async fn stats(&self) -> InvalidationStats {
168 let pending = self.pending.read().await;
169
170 let mut tables_pending = HashSet::new();
171 for inv in pending.values() {
172 tables_pending.extend(inv.changed_tables.iter().cloned());
173 }
174
175 InvalidationStats {
176 pending_groups: pending.len(),
177 pending_tables: tables_pending.len(),
178 }
179 }
180}
181
/// Point-in-time snapshot of an invalidation engine's pending buffer, as
/// produced by `InvalidationEngine::stats`.
#[derive(Debug, Clone, Default)]
pub struct InvalidationStats {
    /// How many query groups are currently awaiting invalidation.
    pub pending_groups: usize,
    /// How many distinct tables are referenced across all pending groups.
    pub pending_tables: usize,
}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Every default field is pinned, including the overflow threshold.
    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.debounce_ms, 50);
        assert_eq!(config.max_debounce_ms, 200);
        assert!(config.coalesce_by_table);
        assert_eq!(config.max_buffer_size, 1000);
    }

    #[tokio::test]
    async fn test_invalidation_engine_creation() {
        let subscription_manager = Arc::new(SubscriptionManager::new(50));
        let engine = InvalidationEngine::new(subscription_manager, InvalidationConfig::default());

        assert_eq!(engine.pending_count().await, 0);

        let stats = engine.stats().await;
        assert_eq!(stats.pending_groups, 0);
        assert_eq!(stats.pending_tables, 0);
    }

    #[tokio::test]
    async fn test_invalidation_flush_all() {
        let subscription_manager = Arc::new(SubscriptionManager::new(50));
        let engine = InvalidationEngine::new(subscription_manager, InvalidationConfig::default());

        let flushed = engine.flush_all().await;
        assert!(flushed.is_empty());
    }

    /// An empty buffer yields no ready groups and stays empty.
    #[tokio::test]
    async fn test_check_pending_empty() {
        let subscription_manager = Arc::new(SubscriptionManager::new(50));
        let engine = InvalidationEngine::new(subscription_manager, InvalidationConfig::default());

        assert!(engine.check_pending().await.is_empty());
        assert_eq!(engine.pending_count().await, 0);
    }
}