// forge_runtime/realtime/invalidation.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::{RwLock, mpsc};
6use tokio::time::Instant;
7
8use forge_core::realtime::{Change, SubscriptionId};
9
10use super::manager::SubscriptionManager;
11
/// Configuration for the invalidation engine.
///
/// Controls how long the engine waits before reporting a subscription as needing
/// re-execution, and how large its pending buffer may grow.
#[derive(Debug, Clone)]
pub struct InvalidationConfig {
    /// Quiet period in milliseconds: a pending subscription becomes ready once no new
    /// change has touched it for this long.
    pub debounce_ms: u64,
    /// Maximum wait in milliseconds, measured from the first queued change: a pending
    /// subscription becomes ready after this long even if changes keep arriving.
    pub max_debounce_ms: u64,
    /// Whether to coalesce changes by table.
    ///
    /// NOTE(review): not read by `InvalidationEngine` in this module; confirm whether
    /// another component relies on it.
    pub coalesce_by_table: bool,
    /// Upper bound on the number of pending *subscriptions* (the engine compares it to
    /// the size of its pending map, not to a count of individual changes). Reaching it
    /// forces the pending buffer to be flushed without waiting for debounce windows.
    pub max_buffer_size: usize,
}
24
25impl Default for InvalidationConfig {
26    fn default() -> Self {
27        Self {
28            debounce_ms: 50,
29            max_debounce_ms: 200,
30            coalesce_by_table: true,
31            max_buffer_size: 1000,
32        }
33    }
34}
35
36/// Pending invalidation for a subscription.
37#[derive(Debug)]
38struct PendingInvalidation {
39    /// Subscription ID.
40    #[allow(dead_code)]
41    subscription_id: SubscriptionId,
42    /// Tables that changed.
43    changed_tables: HashSet<String>,
44    /// When this invalidation was first queued.
45    first_change: Instant,
46    /// When the last change was received.
47    last_change: Instant,
48}
49
50/// Engine for determining which subscriptions need re-execution.
51pub struct InvalidationEngine {
52    subscription_manager: Arc<SubscriptionManager>,
53    #[allow(dead_code)]
54    config: InvalidationConfig,
55    /// Pending invalidations per subscription.
56    pending: Arc<RwLock<HashMap<SubscriptionId, PendingInvalidation>>>,
57    /// Channel for signaling invalidations.
58    #[allow(dead_code)]
59    invalidation_tx: mpsc::Sender<Vec<SubscriptionId>>,
60    #[allow(dead_code)]
61    invalidation_rx: Arc<RwLock<mpsc::Receiver<Vec<SubscriptionId>>>>,
62}
63
64impl InvalidationEngine {
65    /// Create a new invalidation engine.
66    pub fn new(subscription_manager: Arc<SubscriptionManager>, config: InvalidationConfig) -> Self {
67        let (invalidation_tx, invalidation_rx) = mpsc::channel(1024);
68
69        Self {
70            subscription_manager,
71            config,
72            pending: Arc::new(RwLock::new(HashMap::new())),
73            invalidation_tx,
74            invalidation_rx: Arc::new(RwLock::new(invalidation_rx)),
75        }
76    }
77
78    /// Process a database change.
79    pub async fn process_change(&self, change: Change) {
80        // Find affected subscriptions
81        let affected = self
82            .subscription_manager
83            .find_affected_subscriptions(&change)
84            .await;
85
86        if affected.is_empty() {
87            return;
88        }
89
90        tracing::debug!(
91            table = %change.table,
92            affected_count = affected.len(),
93            "Found affected subscriptions for change"
94        );
95
96        let now = Instant::now();
97        let mut pending = self.pending.write().await;
98
99        for sub_id in affected {
100            let entry = pending
101                .entry(sub_id)
102                .or_insert_with(|| PendingInvalidation {
103                    subscription_id: sub_id,
104                    changed_tables: HashSet::new(),
105                    first_change: now,
106                    last_change: now,
107                });
108
109            entry.changed_tables.insert(change.table.clone());
110            entry.last_change = now;
111        }
112
113        // Check if we should flush due to buffer size
114        if pending.len() >= self.config.max_buffer_size {
115            drop(pending);
116            self.flush_all().await;
117        }
118    }
119
120    /// Check for subscriptions that need to be invalidated.
121    pub async fn check_pending(&self) -> Vec<SubscriptionId> {
122        let now = Instant::now();
123        let debounce = Duration::from_millis(self.config.debounce_ms);
124        let max_debounce = Duration::from_millis(self.config.max_debounce_ms);
125
126        let mut pending = self.pending.write().await;
127        let mut ready = Vec::new();
128
129        pending.retain(|_, inv| {
130            let since_last = now.duration_since(inv.last_change);
131            let since_first = now.duration_since(inv.first_change);
132
133            // Ready if debounce window passed or max wait exceeded
134            if since_last >= debounce || since_first >= max_debounce {
135                ready.push(inv.subscription_id);
136                false // Remove from pending
137            } else {
138                true // Keep in pending
139            }
140        });
141
142        ready
143    }
144
145    /// Flush all pending invalidations immediately.
146    pub async fn flush_all(&self) -> Vec<SubscriptionId> {
147        let mut pending = self.pending.write().await;
148        let ready: Vec<SubscriptionId> = pending.keys().copied().collect();
149        pending.clear();
150        ready
151    }
152
153    /// Get the invalidation receiver for consuming invalidation events.
154    pub async fn take_receiver(&self) -> Option<mpsc::Receiver<Vec<SubscriptionId>>> {
155        let _rx_guard = self.invalidation_rx.write().await;
156        // We can only take once, so this is a simple swap
157        // In practice, you'd use a different pattern
158        None // Simplified - receiver is accessed via run loop
159    }
160
161    /// Run the invalidation check loop.
162    pub async fn run(&self) {
163        let check_interval = Duration::from_millis(self.config.debounce_ms / 2);
164
165        loop {
166            tokio::time::sleep(check_interval).await;
167
168            let ready = self.check_pending().await;
169            if !ready.is_empty() && self.invalidation_tx.send(ready).await.is_err() {
170                // Receiver dropped, stop the loop
171                break;
172            }
173        }
174    }
175
176    /// Get pending count for monitoring.
177    pub async fn pending_count(&self) -> usize {
178        self.pending.read().await.len()
179    }
180
181    /// Get statistics about the invalidation engine.
182    pub async fn stats(&self) -> InvalidationStats {
183        let pending = self.pending.read().await;
184
185        let mut tables_pending = HashSet::new();
186        for inv in pending.values() {
187            tables_pending.extend(inv.changed_tables.iter().cloned());
188        }
189
190        InvalidationStats {
191            pending_subscriptions: pending.len(),
192            pending_tables: tables_pending.len(),
193        }
194    }
195}
196
/// Point-in-time snapshot of the invalidation engine's pending state.
#[derive(Default, Clone, Debug)]
pub struct InvalidationStats {
    /// How many subscriptions are currently waiting to be invalidated.
    pub pending_subscriptions: usize,
    /// How many distinct tables appear across all pending invalidations.
    pub pending_tables: usize,
}
205
206/// Coalesces multiple changes for the same table.
207#[allow(dead_code)]
208pub struct ChangeCoalescer {
209    /// Changes grouped by table.
210    changes_by_table: HashMap<String, Vec<Change>>,
211}
212
213#[allow(dead_code)]
214impl ChangeCoalescer {
215    /// Create a new change coalescer.
216    pub fn new() -> Self {
217        Self {
218            changes_by_table: HashMap::new(),
219        }
220    }
221
222    /// Add a change.
223    pub fn add(&mut self, change: Change) {
224        self.changes_by_table
225            .entry(change.table.clone())
226            .or_default()
227            .push(change);
228    }
229
230    /// Get coalesced tables that had changes.
231    pub fn tables(&self) -> impl Iterator<Item = &str> {
232        self.changes_by_table.keys().map(|s| s.as_str())
233    }
234
235    /// Drain all changes.
236    pub fn drain(&mut self) -> HashMap<String, Vec<Change>> {
237        std::mem::take(&mut self.changes_by_table)
238    }
239
240    /// Check if empty.
241    pub fn is_empty(&self) -> bool {
242        self.changes_by_table.is_empty()
243    }
244
245    /// Count total changes.
246    pub fn len(&self) -> usize {
247        self.changes_by_table.values().map(|v| v.len()).sum()
248    }
249}
250
251impl Default for ChangeCoalescer {
252    fn default() -> Self {
253        Self::new()
254    }
255}
256
// Unit tests for the invalidation configuration, the change coalescer, and basic
// invalidation-engine construction and flushing.
#[cfg(test)]
mod tests {
    use super::*;
    use forge_core::realtime::ChangeOperation;

    // The default configuration must keep its documented tuning values.
    #[test]
    fn test_invalidation_config_default() {
        let config = InvalidationConfig::default();
        assert_eq!(config.debounce_ms, 50);
        assert_eq!(config.max_debounce_ms, 200);
        assert!(config.coalesce_by_table);
    }

    // Changes are grouped by table: `len` counts every change, while `tables` lists
    // each table once.
    #[test]
    fn test_change_coalescer() {
        let mut coalescer = ChangeCoalescer::new();
        assert!(coalescer.is_empty());

        coalescer.add(Change::new("projects".to_string(), ChangeOperation::Insert));
        coalescer.add(Change::new("projects".to_string(), ChangeOperation::Update));
        coalescer.add(Change::new("users".to_string(), ChangeOperation::Insert));

        assert_eq!(coalescer.len(), 3);

        let tables: Vec<&str> = coalescer.tables().collect();
        assert!(tables.contains(&"projects"));
        assert!(tables.contains(&"users"));
    }

    // `drain` hands back every table's changes and leaves the coalescer empty.
    #[test]
    fn test_change_coalescer_drain() {
        let mut coalescer = ChangeCoalescer::new();
        coalescer.add(Change::new("projects".to_string(), ChangeOperation::Insert));
        coalescer.add(Change::new("users".to_string(), ChangeOperation::Delete));

        let drained = coalescer.drain();
        assert!(coalescer.is_empty());
        assert_eq!(drained.len(), 2);
        assert!(drained.contains_key("projects"));
        assert!(drained.contains_key("users"));
    }

    // A freshly constructed engine has nothing pending and reports empty stats.
    #[tokio::test]
    async fn test_invalidation_engine_creation() {
        let subscription_manager = Arc::new(SubscriptionManager::new(50));
        let engine = InvalidationEngine::new(subscription_manager, InvalidationConfig::default());

        assert_eq!(engine.pending_count().await, 0);

        let stats = engine.stats().await;
        assert_eq!(stats.pending_subscriptions, 0);
        assert_eq!(stats.pending_tables, 0);
    }

    // Flushing an engine with no pending work yields no subscription IDs.
    #[tokio::test]
    async fn test_invalidation_flush_all() {
        let subscription_manager = Arc::new(SubscriptionManager::new(50));
        let engine = InvalidationEngine::new(subscription_manager, InvalidationConfig::default());

        // Flush on empty should return empty
        let flushed = engine.flush_all().await;
        assert!(flushed.is_empty());
    }
}
320}