forge_runtime/realtime/
invalidation.rs1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::{RwLock, mpsc};
6use tokio::time::Instant;
7
8use forge_core::realtime::{Change, SubscriptionId};
9
10use super::manager::SubscriptionManager;
11
/// Tuning knobs for the invalidation engine's debouncing and buffering.
#[derive(Clone, Debug)]
pub struct InvalidationConfig {
    /// Quiet period in milliseconds: a pending invalidation is released once
    /// this much time has passed since the most recent change that touched it.
    pub debounce_ms: u64,
    /// Upper bound in milliseconds on how long an invalidation may be held
    /// back, measured from the first change that touched it.
    pub max_debounce_ms: u64,
    /// NOTE(review): presumably toggles per-table coalescing of changes; this
    /// flag is not read anywhere in the visible engine code — confirm usage.
    pub coalesce_by_table: bool,
    /// Number of buffered subscriptions at which `process_change` flushes the
    /// whole pending buffer instead of waiting for the debounce window.
    pub max_buffer_size: usize,
}
42
43impl Default for InvalidationConfig {
44 fn default() -> Self {
45 Self {
46 debounce_ms: 50,
47 max_debounce_ms: 200,
48 coalesce_by_table: true,
49 max_buffer_size: 1000,
50 }
51 }
52}
53
54#[derive(Debug)]
56struct PendingInvalidation {
57 #[allow(dead_code)]
59 subscription_id: SubscriptionId,
60 changed_tables: HashSet<String>,
62 first_change: Instant,
64 last_change: Instant,
66}
67
68pub struct InvalidationEngine {
70 subscription_manager: Arc<SubscriptionManager>,
71 #[allow(dead_code)]
72 config: InvalidationConfig,
73 pending: Arc<RwLock<HashMap<SubscriptionId, PendingInvalidation>>>,
75 #[allow(dead_code)]
77 invalidation_tx: mpsc::Sender<Vec<SubscriptionId>>,
78 #[allow(dead_code)]
79 invalidation_rx: Arc<RwLock<mpsc::Receiver<Vec<SubscriptionId>>>>,
80}
81
82impl InvalidationEngine {
83 pub fn new(subscription_manager: Arc<SubscriptionManager>, config: InvalidationConfig) -> Self {
85 let (invalidation_tx, invalidation_rx) = mpsc::channel(1024);
86
87 Self {
88 subscription_manager,
89 config,
90 pending: Arc::new(RwLock::new(HashMap::new())),
91 invalidation_tx,
92 invalidation_rx: Arc::new(RwLock::new(invalidation_rx)),
93 }
94 }
95
96 pub async fn process_change(&self, change: Change) {
98 let affected = self
100 .subscription_manager
101 .find_affected_subscriptions(&change)
102 .await;
103
104 if affected.is_empty() {
105 return;
106 }
107
108 tracing::debug!(
109 table = %change.table,
110 affected_count = affected.len(),
111 "Found affected subscriptions for change"
112 );
113
114 let now = Instant::now();
115 let mut pending = self.pending.write().await;
116
117 for sub_id in affected {
118 let entry = pending
119 .entry(sub_id)
120 .or_insert_with(|| PendingInvalidation {
121 subscription_id: sub_id,
122 changed_tables: HashSet::new(),
123 first_change: now,
124 last_change: now,
125 });
126
127 entry.changed_tables.insert(change.table.clone());
128 entry.last_change = now;
129 }
130
131 if pending.len() >= self.config.max_buffer_size {
133 drop(pending);
134 self.flush_all().await;
135 }
136 }
137
138 pub async fn check_pending(&self) -> Vec<SubscriptionId> {
140 let now = Instant::now();
141 let debounce = Duration::from_millis(self.config.debounce_ms);
142 let max_debounce = Duration::from_millis(self.config.max_debounce_ms);
143
144 let mut pending = self.pending.write().await;
145 let mut ready = Vec::new();
146
147 pending.retain(|_, inv| {
148 let since_last = now.duration_since(inv.last_change);
149 let since_first = now.duration_since(inv.first_change);
150
151 if since_last >= debounce || since_first >= max_debounce {
153 ready.push(inv.subscription_id);
154 false } else {
156 true }
158 });
159
160 ready
161 }
162
163 pub async fn flush_all(&self) -> Vec<SubscriptionId> {
165 let mut pending = self.pending.write().await;
166 let ready: Vec<SubscriptionId> = pending.keys().copied().collect();
167 pending.clear();
168 ready
169 }
170
171 pub async fn take_receiver(&self) -> Option<mpsc::Receiver<Vec<SubscriptionId>>> {
173 let _rx_guard = self.invalidation_rx.write().await;
174 None }
178
179 pub async fn run(&self) {
181 let check_interval = Duration::from_millis(self.config.debounce_ms / 2);
182
183 loop {
184 tokio::time::sleep(check_interval).await;
185
186 let ready = self.check_pending().await;
187 if !ready.is_empty() && self.invalidation_tx.send(ready).await.is_err() {
188 break;
190 }
191 }
192 }
193
194 pub async fn pending_count(&self) -> usize {
196 self.pending.read().await.len()
197 }
198
199 pub async fn stats(&self) -> InvalidationStats {
201 let pending = self.pending.read().await;
202
203 let mut tables_pending = HashSet::new();
204 for inv in pending.values() {
205 tables_pending.extend(inv.changed_tables.iter().cloned());
206 }
207
208 InvalidationStats {
209 pending_subscriptions: pending.len(),
210 pending_tables: tables_pending.len(),
211 }
212 }
213}
214
/// Point-in-time counters describing the engine's pending buffer.
#[derive(Clone, Debug, Default)]
pub struct InvalidationStats {
    /// How many subscriptions had a buffered invalidation.
    pub pending_subscriptions: usize,
    /// How many distinct table names were recorded across those entries.
    pub pending_tables: usize,
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default configuration is pinned, including the
    /// buffer limit that triggers a forced flush.
    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.debounce_ms, 50);
        assert_eq!(config.max_debounce_ms, 200);
        assert!(config.coalesce_by_table);
        assert_eq!(config.max_buffer_size, 1000);
    }

    /// A freshly built engine has nothing buffered.
    #[tokio::test]
    async fn test_invalidation_engine_creation() {
        let subscription_manager = Arc::new(SubscriptionManager::new(50));
        let engine = InvalidationEngine::new(subscription_manager, InvalidationConfig::default());

        assert_eq!(engine.pending_count().await, 0);

        let stats = engine.stats().await;
        assert_eq!(stats.pending_subscriptions, 0);
        assert_eq!(stats.pending_tables, 0);
    }

    /// Flushing an empty engine yields no ids and leaves the buffer empty.
    #[tokio::test]
    async fn test_invalidation_flush_all() {
        let subscription_manager = Arc::new(SubscriptionManager::new(50));
        let engine = InvalidationEngine::new(subscription_manager, InvalidationConfig::default());

        let flushed = engine.flush_all().await;
        assert!(flushed.is_empty());
        assert_eq!(engine.pending_count().await, 0);
    }
}
257}