Skip to main content

forge_runtime/realtime/
invalidation.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::{RwLock, mpsc};
6use tokio::time::Instant;
7
8use forge_core::realtime::{Change, SubscriptionId};
9
10use super::manager::SubscriptionManager;
11
/// Configuration for the invalidation engine.
///
/// The invalidation engine uses a debounce algorithm to batch rapid changes
/// into single re-executions. This prevents "thundering herd" scenarios where
/// a batch insert of 1000 rows would trigger 1000 subscription refreshes.
///
/// The algorithm works as follows:
/// 1. When a change arrives, record each affected subscription as pending
/// 2. Wait for `debounce_ms` of silence (no new changes to that subscription)
/// 3. If `max_debounce_ms` passes since the first change, flush anyway
/// 4. If the number of pending subscriptions reaches `max_buffer_size`, flush
///    immediately (memory protection)
///
/// This balances latency (users want updates fast) against efficiency (batching
/// reduces database load). Default values target 50ms debounce with 200ms max
/// wait. Pending entries are only examined periodically by the engine's check
/// loop, so the worst-case delay is `max_debounce_ms` plus up to one check
/// interval, not exactly `max_debounce_ms`.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Debounce window in milliseconds.
    /// After a change, wait this long for more changes before invalidating.
    pub debounce_ms: u64,
    /// Maximum debounce wait in milliseconds.
    /// Even if changes keep arriving, invalidate after this duration.
    pub max_debounce_ms: u64,
    /// Whether to coalesce changes by table.
    /// When true, multiple changes to the same table become a single invalidation.
    ///
    /// NOTE(review): the `InvalidationEngine` in this module does not read this
    /// flag yet; changed tables are always tracked as a set per subscription.
    pub coalesce_by_table: bool,
    /// Maximum number of pending *subscriptions* (not individual changes) to
    /// buffer before forcing a flush.
    /// Prevents unbounded memory growth during high-throughput periods.
    pub max_buffer_size: usize,
}

impl Default for InvalidationConfig {
    /// 50ms debounce, 200ms maximum wait, table coalescing on, and a
    /// 1000-subscription buffer.
    fn default() -> Self {
        Self {
            debounce_ms: 50,
            max_debounce_ms: 200,
            coalesce_by_table: true,
            max_buffer_size: 1000,
        }
    }
}
53
54/// Pending invalidation for a subscription.
55#[derive(Debug)]
56struct PendingInvalidation {
57    /// Subscription ID.
58    #[allow(dead_code)]
59    subscription_id: SubscriptionId,
60    /// Tables that changed.
61    changed_tables: HashSet<String>,
62    /// When this invalidation was first queued.
63    first_change: Instant,
64    /// When the last change was received.
65    last_change: Instant,
66}
67
68/// Engine for determining which subscriptions need re-execution.
69pub struct InvalidationEngine {
70    subscription_manager: Arc<SubscriptionManager>,
71    #[allow(dead_code)]
72    config: InvalidationConfig,
73    /// Pending invalidations per subscription.
74    pending: Arc<RwLock<HashMap<SubscriptionId, PendingInvalidation>>>,
75    /// Channel for signaling invalidations.
76    #[allow(dead_code)]
77    invalidation_tx: mpsc::Sender<Vec<SubscriptionId>>,
78    #[allow(dead_code)]
79    invalidation_rx: Arc<RwLock<mpsc::Receiver<Vec<SubscriptionId>>>>,
80}
81
82impl InvalidationEngine {
83    /// Create a new invalidation engine.
84    pub fn new(subscription_manager: Arc<SubscriptionManager>, config: InvalidationConfig) -> Self {
85        let (invalidation_tx, invalidation_rx) = mpsc::channel(1024);
86
87        Self {
88            subscription_manager,
89            config,
90            pending: Arc::new(RwLock::new(HashMap::new())),
91            invalidation_tx,
92            invalidation_rx: Arc::new(RwLock::new(invalidation_rx)),
93        }
94    }
95
96    /// Process a database change.
97    pub async fn process_change(&self, change: Change) {
98        // Find affected subscriptions
99        let affected = self
100            .subscription_manager
101            .find_affected_subscriptions(&change)
102            .await;
103
104        if affected.is_empty() {
105            return;
106        }
107
108        tracing::debug!(
109            table = %change.table,
110            affected_count = affected.len(),
111            "Found affected subscriptions for change"
112        );
113
114        let now = Instant::now();
115        let mut pending = self.pending.write().await;
116
117        for sub_id in affected {
118            let entry = pending
119                .entry(sub_id)
120                .or_insert_with(|| PendingInvalidation {
121                    subscription_id: sub_id,
122                    changed_tables: HashSet::new(),
123                    first_change: now,
124                    last_change: now,
125                });
126
127            entry.changed_tables.insert(change.table.clone());
128            entry.last_change = now;
129        }
130
131        // Check if we should flush due to buffer size
132        if pending.len() >= self.config.max_buffer_size {
133            drop(pending);
134            self.flush_all().await;
135        }
136    }
137
138    /// Check for subscriptions that need to be invalidated.
139    pub async fn check_pending(&self) -> Vec<SubscriptionId> {
140        let now = Instant::now();
141        let debounce = Duration::from_millis(self.config.debounce_ms);
142        let max_debounce = Duration::from_millis(self.config.max_debounce_ms);
143
144        let mut pending = self.pending.write().await;
145        let mut ready = Vec::new();
146
147        pending.retain(|_, inv| {
148            let since_last = now.duration_since(inv.last_change);
149            let since_first = now.duration_since(inv.first_change);
150
151            // Ready if debounce window passed or max wait exceeded
152            if since_last >= debounce || since_first >= max_debounce {
153                ready.push(inv.subscription_id);
154                false // Remove from pending
155            } else {
156                true // Keep in pending
157            }
158        });
159
160        ready
161    }
162
163    /// Flush all pending invalidations immediately.
164    pub async fn flush_all(&self) -> Vec<SubscriptionId> {
165        let mut pending = self.pending.write().await;
166        let ready: Vec<SubscriptionId> = pending.keys().copied().collect();
167        pending.clear();
168        ready
169    }
170
171    /// Get the invalidation receiver for consuming invalidation events.
172    pub async fn take_receiver(&self) -> Option<mpsc::Receiver<Vec<SubscriptionId>>> {
173        let _rx_guard = self.invalidation_rx.write().await;
174        // We can only take once, so this is a simple swap
175        // In practice, you'd use a different pattern
176        None // Simplified - receiver is accessed via run loop
177    }
178
179    /// Run the invalidation check loop.
180    pub async fn run(&self) {
181        let check_interval = Duration::from_millis(self.config.debounce_ms / 2);
182
183        loop {
184            tokio::time::sleep(check_interval).await;
185
186            let ready = self.check_pending().await;
187            if !ready.is_empty() && self.invalidation_tx.send(ready).await.is_err() {
188                // Receiver dropped, stop the loop
189                break;
190            }
191        }
192    }
193
194    /// Get pending count for monitoring.
195    pub async fn pending_count(&self) -> usize {
196        self.pending.read().await.len()
197    }
198
199    /// Get statistics about the invalidation engine.
200    pub async fn stats(&self) -> InvalidationStats {
201        let pending = self.pending.read().await;
202
203        let mut tables_pending = HashSet::new();
204        for inv in pending.values() {
205            tables_pending.extend(inv.changed_tables.iter().cloned());
206        }
207
208        InvalidationStats {
209            pending_subscriptions: pending.len(),
210            pending_tables: tables_pending.len(),
211        }
212    }
213}
214
/// Statistics about the invalidation engine.
///
/// A point-in-time snapshot. `PartialEq`/`Eq` allow two snapshots to be
/// compared directly, e.g. in monitoring or tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct InvalidationStats {
    /// Number of subscriptions pending invalidation.
    pub pending_subscriptions: usize,
    /// Number of unique tables with pending changes.
    pub pending_tables: usize,
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an engine using the default configuration over a fresh manager.
    fn default_engine() -> InvalidationEngine {
        let manager = Arc::new(SubscriptionManager::new(50));
        InvalidationEngine::new(manager, InvalidationConfig::default())
    }

    #[test]
    fn test_invalidation_config_default() {
        let InvalidationConfig {
            debounce_ms,
            max_debounce_ms,
            coalesce_by_table,
            ..
        } = InvalidationConfig::default();

        assert_eq!(debounce_ms, 50);
        assert_eq!(max_debounce_ms, 200);
        assert!(coalesce_by_table);
    }

    #[tokio::test]
    async fn test_invalidation_engine_creation() {
        let engine = default_engine();
        assert_eq!(engine.pending_count().await, 0);

        let InvalidationStats {
            pending_subscriptions,
            pending_tables,
        } = engine.stats().await;
        assert_eq!(pending_subscriptions, 0);
        assert_eq!(pending_tables, 0);
    }

    #[tokio::test]
    async fn test_invalidation_flush_all() {
        let engine = default_engine();

        // Nothing has been queued, so a flush must come back empty.
        let drained = engine.flush_all().await;
        assert!(drained.is_empty());
    }
}