Skip to main content

forge_runtime/realtime/
invalidation.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::{RwLock, mpsc};
6use tokio::time::Instant;
7
8use forge_core::realtime::{Change, QueryGroupId};
9
10use super::manager::SubscriptionManager;
11
/// Configuration for the invalidation engine.
///
/// Uses debouncing to batch rapid changes into single re-executions per group.
/// This prevents "thundering herd" scenarios where a batch insert triggers
/// N subscription refreshes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidationConfig {
    /// Quiet period in milliseconds: a pending group becomes ready once no new
    /// change has touched it for this long.
    pub debounce_ms: u64,
    /// Upper bound in milliseconds on how long a group may stay pending,
    /// measured from its first change, even if changes keep arriving.
    pub max_debounce_ms: u64,
    /// Whether to coalesce changes by table.
    ///
    /// NOTE(review): not read by `InvalidationEngine` in this module; confirm
    /// whether another component honors it.
    pub coalesce_by_table: bool,
    /// Number of pending query groups (not individual changes) at which all
    /// pending groups are flushed immediately, bypassing the debounce window.
    pub max_buffer_size: usize,
}

impl Default for InvalidationConfig {
    /// 50 ms quiet period, 200 ms maximum wait, table coalescing enabled,
    /// forced flush at 1000 pending groups.
    fn default() -> Self {
        Self {
            debounce_ms: 50,
            max_debounce_ms: 200,
            coalesce_by_table: true,
            max_buffer_size: 1000,
        }
    }
}
39
40/// Pending invalidation for a query group.
41#[derive(Debug)]
42struct PendingInvalidation {
43    #[allow(dead_code)]
44    group_id: QueryGroupId,
45    changed_tables: HashSet<String>,
46    first_change: Instant,
47    last_change: Instant,
48}
49
50/// Engine for determining which query groups need re-execution.
51/// Operates on groups (not individual subscriptions) for O(groups) cost.
52pub struct InvalidationEngine {
53    subscription_manager: Arc<SubscriptionManager>,
54    #[allow(dead_code)]
55    config: InvalidationConfig,
56    /// Pending invalidations per query group.
57    pending: Arc<RwLock<HashMap<QueryGroupId, PendingInvalidation>>>,
58    #[allow(dead_code)]
59    invalidation_tx: mpsc::Sender<Vec<QueryGroupId>>,
60    #[allow(dead_code)]
61    invalidation_rx: Arc<RwLock<mpsc::Receiver<Vec<QueryGroupId>>>>,
62}
63
64impl InvalidationEngine {
65    /// Create a new invalidation engine.
66    pub fn new(subscription_manager: Arc<SubscriptionManager>, config: InvalidationConfig) -> Self {
67        let (invalidation_tx, invalidation_rx) = mpsc::channel(1024);
68
69        Self {
70            subscription_manager,
71            config,
72            pending: Arc::new(RwLock::new(HashMap::new())),
73            invalidation_tx,
74            invalidation_rx: Arc::new(RwLock::new(invalidation_rx)),
75        }
76    }
77
78    /// Process a database change. Finds affected groups (not subscriptions).
79    pub async fn process_change(&self, change: Change) {
80        let affected = self.subscription_manager.find_affected_groups(&change);
81
82        if affected.is_empty() {
83            return;
84        }
85
86        tracing::debug!(
87            table = %change.table,
88            affected_groups = affected.len(),
89            "Found affected groups for change"
90        );
91
92        let now = Instant::now();
93        let mut pending = self.pending.write().await;
94
95        for group_id in affected {
96            let entry = pending
97                .entry(group_id)
98                .or_insert_with(|| PendingInvalidation {
99                    group_id,
100                    changed_tables: HashSet::new(),
101                    first_change: now,
102                    last_change: now,
103                });
104
105            entry.changed_tables.insert(change.table.clone());
106            entry.last_change = now;
107        }
108
109        if pending.len() >= self.config.max_buffer_size {
110            drop(pending);
111            self.flush_all().await;
112        }
113    }
114
115    /// Check for groups that need to be invalidated (debounce expired).
116    pub async fn check_pending(&self) -> Vec<QueryGroupId> {
117        let now = Instant::now();
118        let debounce = Duration::from_millis(self.config.debounce_ms);
119        let max_debounce = Duration::from_millis(self.config.max_debounce_ms);
120
121        let mut pending = self.pending.write().await;
122        let mut ready = Vec::new();
123
124        pending.retain(|_, inv| {
125            let since_last = now.duration_since(inv.last_change);
126            let since_first = now.duration_since(inv.first_change);
127
128            if since_last >= debounce || since_first >= max_debounce {
129                ready.push(inv.group_id);
130                false
131            } else {
132                true
133            }
134        });
135
136        ready
137    }
138
139    /// Flush all pending invalidations immediately.
140    pub async fn flush_all(&self) -> Vec<QueryGroupId> {
141        let mut pending = self.pending.write().await;
142        let ready: Vec<QueryGroupId> = pending.keys().copied().collect();
143        pending.clear();
144        ready
145    }
146
147    /// Run the invalidation check loop.
148    pub async fn run(&self) {
149        let check_interval = Duration::from_millis(self.config.debounce_ms / 2);
150
151        loop {
152            tokio::time::sleep(check_interval).await;
153
154            let ready = self.check_pending().await;
155            if !ready.is_empty() && self.invalidation_tx.send(ready).await.is_err() {
156                break;
157            }
158        }
159    }
160
161    /// Get pending count for monitoring.
162    pub async fn pending_count(&self) -> usize {
163        self.pending.read().await.len()
164    }
165
166    /// Get statistics about the invalidation engine.
167    pub async fn stats(&self) -> InvalidationStats {
168        let pending = self.pending.read().await;
169
170        let mut tables_pending = HashSet::new();
171        for inv in pending.values() {
172            tables_pending.extend(inv.changed_tables.iter().cloned());
173        }
174
175        InvalidationStats {
176            pending_groups: pending.len(),
177            pending_tables: tables_pending.len(),
178        }
179    }
180}
181
/// Snapshot of the invalidation engine's buffered state, for monitoring.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct InvalidationStats {
    /// Number of query groups currently pending invalidation.
    pub pending_groups: usize,
    /// Number of distinct tables across all pending groups.
    pub pending_tables: usize,
}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an engine with default configuration over a fresh subscription manager.
    fn new_engine() -> InvalidationEngine {
        let subscription_manager = Arc::new(SubscriptionManager::new(50));
        InvalidationEngine::new(subscription_manager, InvalidationConfig::default())
    }

    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.debounce_ms, 50);
        assert_eq!(config.max_debounce_ms, 200);
        assert!(config.coalesce_by_table);
        assert_eq!(config.max_buffer_size, 1000);
    }

    #[tokio::test]
    async fn test_invalidation_engine_creation() {
        let engine = new_engine();

        assert_eq!(engine.pending_count().await, 0);

        let stats = engine.stats().await;
        assert_eq!(stats.pending_groups, 0);
        assert_eq!(stats.pending_tables, 0);
    }

    #[tokio::test]
    async fn test_invalidation_flush_all() {
        let engine = new_engine();

        let flushed = engine.flush_all().await;
        assert!(flushed.is_empty());
        assert_eq!(engine.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_check_pending_on_empty_engine() {
        let engine = new_engine();

        assert!(engine.check_pending().await.is_empty());
        assert_eq!(engine.pending_count().await, 0);
    }
}