// forge_runtime/realtime/invalidation.rs

use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::RwLock;
use tokio::time::Instant;

use forge_core::realtime::{Change, QueryGroupId};

use super::manager::SubscriptionManager;

/// Debounce and buffering limits used by the invalidation engine.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Quiet period in milliseconds: a pending group becomes ready once this
    /// long has passed since its most recently recorded change.
    pub debounce_ms: u64,
    /// Upper bound in milliseconds on how long a group may stay pending after
    /// its first change, even if further changes keep arriving.
    pub max_debounce_ms: u64,
    /// If `true`, each further change to an already-pending group adds its
    /// table to that entry and restarts the debounce window. If `false`, an
    /// existing pending entry is left unchanged.
    pub coalesce_by_table: bool,
    /// Number of pending groups at which `process_change` stops waiting on the
    /// debounce windows and forces the buffer to be flushed early.
    pub max_buffer_size: usize,
}

impl Default for InvalidationConfig {
    /// 50 ms debounce, 200 ms cap, table coalescing on, 1000-group buffer.
    fn default() -> Self {
        InvalidationConfig {
            debounce_ms: 50,
            max_debounce_ms: 200,
            coalesce_by_table: true,
            max_buffer_size: 1000,
        }
    }
}

43#[derive(Debug)]
45struct PendingInvalidation {
46 group_id: QueryGroupId,
47 changed_tables: HashSet<String>,
48 first_change: Instant,
49 last_change: Instant,
50}
51
52pub struct InvalidationEngine {
55 subscription_manager: Arc<SubscriptionManager>,
56 config: InvalidationConfig,
57 pending: Arc<RwLock<HashMap<QueryGroupId, PendingInvalidation>>>,
59}
60
61impl InvalidationEngine {
62 pub fn new(subscription_manager: Arc<SubscriptionManager>, config: InvalidationConfig) -> Self {
64 Self {
65 subscription_manager,
66 config,
67 pending: Arc::new(RwLock::new(HashMap::new())),
68 }
69 }
70
71 pub async fn process_change(&self, change: Change) {
73 let affected = self.subscription_manager.find_affected_groups(&change);
74
75 if affected.is_empty() {
76 return;
77 }
78
79 tracing::debug!(
80 table = %change.table,
81 affected_groups = affected.len(),
82 "Found affected groups for change"
83 );
84
85 let now = Instant::now();
86 let mut pending = self.pending.write().await;
87
88 for group_id in affected {
89 if self.config.coalesce_by_table {
90 let entry = pending
94 .entry(group_id)
95 .or_insert_with(|| PendingInvalidation {
96 group_id,
97 changed_tables: HashSet::new(),
98 first_change: now,
99 last_change: now,
100 });
101
102 entry.changed_tables.insert(change.table.clone());
103 entry.last_change = now;
104 } else {
105 pending
108 .entry(group_id)
109 .or_insert_with(|| PendingInvalidation {
110 group_id,
111 changed_tables: HashSet::from([change.table.clone()]),
112 first_change: now,
113 last_change: now,
114 });
115 }
116 }
117
118 if pending.len() >= self.config.max_buffer_size {
119 drop(pending);
120 self.flush_all().await;
121 }
122 }
123
124 pub async fn check_pending(&self) -> Vec<QueryGroupId> {
126 let now = Instant::now();
127 let debounce = Duration::from_millis(self.config.debounce_ms);
128 let max_debounce = Duration::from_millis(self.config.max_debounce_ms);
129
130 let mut pending = self.pending.write().await;
131 let mut ready = Vec::new();
132
133 pending.retain(|_, inv| {
134 let since_last = now.duration_since(inv.last_change);
135 let since_first = now.duration_since(inv.first_change);
136
137 if since_last >= debounce || since_first >= max_debounce {
138 ready.push(inv.group_id);
139 false
140 } else {
141 true
142 }
143 });
144
145 ready
146 }
147
148 pub async fn flush_all(&self) -> Vec<QueryGroupId> {
150 let mut pending = self.pending.write().await;
151 let ready: Vec<QueryGroupId> = pending.keys().copied().collect();
152 pending.clear();
153 ready
154 }
155
156 pub async fn pending_count(&self) -> usize {
158 self.pending.read().await.len()
159 }
160
161 pub async fn stats(&self) -> InvalidationStats {
163 let pending = self.pending.read().await;
164
165 let mut tables_pending = HashSet::new();
166 for inv in pending.values() {
167 tables_pending.extend(inv.changed_tables.iter().cloned());
168 }
169
170 InvalidationStats {
171 pending_groups: pending.len(),
172 pending_tables: tables_pending.len(),
173 }
174 }
175}
176
/// Point-in-time summary of the invalidation engine's buffer.
#[derive(Debug, Clone, Default)]
pub struct InvalidationStats {
    /// Number of query groups currently buffered.
    pub pending_groups: usize,
    /// Number of distinct table names recorded across all buffered groups.
    pub pending_tables: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an engine over a fresh `SubscriptionManager::new(50)`.
    fn engine_with(config: InvalidationConfig) -> InvalidationEngine {
        InvalidationEngine::new(Arc::new(SubscriptionManager::new(50)), config)
    }

    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.debounce_ms, 50);
        assert_eq!(config.max_debounce_ms, 200);
        assert!(config.coalesce_by_table);
        assert_eq!(config.max_buffer_size, 1000);
    }

    #[tokio::test]
    async fn test_invalidation_engine_creation() {
        let engine = engine_with(InvalidationConfig::default());

        assert_eq!(engine.pending_count().await, 0);

        let stats = engine.stats().await;
        assert_eq!(stats.pending_groups, 0);
        assert_eq!(stats.pending_tables, 0);
    }

    #[tokio::test]
    async fn test_invalidation_flush_all() {
        let engine = engine_with(InvalidationConfig::default());

        let flushed = engine.flush_all().await;
        assert!(flushed.is_empty());
    }

    // NOTE(review): no subscriptions are registered in the two tests below, so
    // `find_affected_groups` presumably returns nothing and the change is never
    // buffered. They cover only the no-affected-groups path, not coalescing.

    #[tokio::test]
    async fn test_coalesce_by_table_enabled() {
        let engine = engine_with(InvalidationConfig {
            coalesce_by_table: true,
            debounce_ms: 0,
            ..Default::default()
        });

        let change = Change::new("users", forge_core::realtime::ChangeOperation::Insert);
        engine.process_change(change).await;
        assert_eq!(engine.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_coalesce_by_table_disabled() {
        let engine = engine_with(InvalidationConfig {
            coalesce_by_table: false,
            debounce_ms: 0,
            ..Default::default()
        });

        let change = Change::new("users", forge_core::realtime::ChangeOperation::Insert);
        engine.process_change(change).await;
        assert_eq!(engine.pending_count().await, 0);
    }
}