Skip to main content

forge_runtime/realtime/
adaptive.rs

1use std::collections::{HashMap, HashSet};
2use std::time::Duration;
3
4use tokio::sync::RwLock;
5
6use forge_core::realtime::TrackingMode;
7
/// Configuration for adaptive tracking.
///
/// `table_threshold` and `row_threshold` form a hysteresis band:
///
/// - A table switches to table-level tracking once its row count rises above `row_threshold`.
/// - It returns to row-level tracking only after the count falls below `table_threshold`.
/// - While the count is inside the band, the current mode is kept.
///
/// For this reason `table_threshold` should be set no higher than `row_threshold`.
#[derive(Debug, Clone)]
pub struct AdaptiveTrackingConfig {
    /// Row count above which a table switches from row-level to table-level
    /// tracking.
    pub row_threshold: usize,
    /// Row count below which a table switches back from table-level to
    /// row-level tracking.
    pub table_threshold: usize,
    /// Maximum number of rows to track per table.
    pub max_tracked_rows: usize,
    /// How often to re-evaluate tracking mode.
    pub evaluation_interval: Duration,
}
20
21impl Default for AdaptiveTrackingConfig {
22    fn default() -> Self {
23        Self {
24            row_threshold: 100,
25            table_threshold: 50,
26            max_tracked_rows: 10_000,
27            evaluation_interval: Duration::from_secs(60),
28        }
29    }
30}
31
32/// Adaptive tracker that switches between table and row-level tracking.
33///
34/// When few subscriptions exist for a table, track at row level.
35/// When many subscriptions exist, switch to table level.
36pub struct AdaptiveTracker {
37    config: AdaptiveTrackingConfig,
38    /// Current tracking mode per table.
39    table_modes: RwLock<HashMap<String, TrackingMode>>,
40    /// Rows being tracked per table.
41    tracked_rows: RwLock<HashMap<String, HashSet<String>>>,
42    /// Subscription count per table.
43    subscription_counts: RwLock<HashMap<String, usize>>,
44    /// Row subscription count per table.
45    row_subscription_counts: RwLock<HashMap<String, usize>>,
46}
47
48impl AdaptiveTracker {
49    /// Create a new adaptive tracker.
50    pub fn new(config: AdaptiveTrackingConfig) -> Self {
51        Self {
52            config,
53            table_modes: RwLock::new(HashMap::new()),
54            tracked_rows: RwLock::new(HashMap::new()),
55            subscription_counts: RwLock::new(HashMap::new()),
56            row_subscription_counts: RwLock::new(HashMap::new()),
57        }
58    }
59
60    /// Record a subscription for a table.
61    pub async fn record_subscription(&self, table: &str, row_ids: Option<Vec<String>>) {
62        // Update subscription counts
63        {
64            let mut counts = self.subscription_counts.write().await;
65            *counts.entry(table.to_string()).or_insert(0) += 1;
66        }
67
68        // Track specific rows if provided
69        if let Some(ids) = row_ids {
70            let mut tracked = self.tracked_rows.write().await;
71            let rows = tracked.entry(table.to_string()).or_default();
72            let mut row_counts = self.row_subscription_counts.write().await;
73
74            for id in ids {
75                if rows.len() < self.config.max_tracked_rows {
76                    rows.insert(id);
77                    *row_counts.entry(table.to_string()).or_insert(0) += 1;
78                }
79            }
80        }
81
82        // Evaluate if mode should change
83        self.evaluate_table(table).await;
84    }
85
86    /// Remove a subscription.
87    pub async fn remove_subscription(&self, table: &str, row_ids: Option<Vec<String>>) {
88        // Update subscription counts
89        {
90            let mut counts = self.subscription_counts.write().await;
91            if let Some(count) = counts.get_mut(table) {
92                *count = count.saturating_sub(1);
93            }
94        }
95
96        // Remove tracked rows if provided
97        if let Some(ids) = row_ids {
98            let mut tracked = self.tracked_rows.write().await;
99            if let Some(rows) = tracked.get_mut(table) {
100                let mut row_counts = self.row_subscription_counts.write().await;
101                for id in ids {
102                    if rows.remove(&id)
103                        && let Some(count) = row_counts.get_mut(table)
104                    {
105                        *count = count.saturating_sub(1);
106                    }
107                }
108            }
109        }
110
111        // Evaluate if mode should change
112        self.evaluate_table(table).await;
113    }
114
115    /// Evaluate and potentially switch tracking mode for a table.
116    pub async fn evaluate_table(&self, table: &str) {
117        let subscription_count = {
118            let counts = self.subscription_counts.read().await;
119            *counts.get(table).unwrap_or(&0)
120        };
121
122        let row_count = {
123            let row_counts = self.row_subscription_counts.read().await;
124            *row_counts.get(table).unwrap_or(&0)
125        };
126
127        let new_mode = if subscription_count == 0 {
128            TrackingMode::None
129        } else if row_count > self.config.row_threshold {
130            TrackingMode::Table
131        } else if row_count < self.config.table_threshold {
132            TrackingMode::Row
133        } else {
134            // Stay in current mode (hysteresis)
135            let modes = self.table_modes.read().await;
136            modes.get(table).copied().unwrap_or(TrackingMode::Row)
137        };
138
139        let mut modes = self.table_modes.write().await;
140        let old_mode = modes.get(table).copied();
141
142        if old_mode != Some(new_mode) {
143            modes.insert(table.to_string(), new_mode);
144            tracing::debug!(
145                table = %table,
146                old_mode = ?old_mode,
147                new_mode = ?new_mode,
148                subscription_count = subscription_count,
149                row_count = row_count,
150                "Tracking mode changed"
151            );
152        }
153    }
154
155    /// Evaluate all tables.
156    pub async fn evaluate(&self) {
157        let tables: Vec<String> = {
158            let counts = self.subscription_counts.read().await;
159            counts.keys().cloned().collect()
160        };
161
162        for table in tables {
163            self.evaluate_table(&table).await;
164        }
165    }
166
167    /// Check if a change should be invalidated.
168    pub async fn should_invalidate(&self, table: &str, row_id: &str) -> bool {
169        let mode = {
170            let modes = self.table_modes.read().await;
171            modes.get(table).copied().unwrap_or(TrackingMode::None)
172        };
173
174        match mode {
175            TrackingMode::None => false,
176            TrackingMode::Table | TrackingMode::Adaptive => true,
177            TrackingMode::Row => {
178                let tracked = self.tracked_rows.read().await;
179                tracked
180                    .get(table)
181                    .map(|rows| rows.contains(row_id))
182                    .unwrap_or(false)
183            }
184        }
185    }
186
187    /// Get the current tracking mode for a table.
188    pub async fn get_mode(&self, table: &str) -> TrackingMode {
189        let modes = self.table_modes.read().await;
190        modes.get(table).copied().unwrap_or(TrackingMode::None)
191    }
192
193    /// Get tracking statistics.
194    pub async fn stats(&self) -> AdaptiveTrackingStats {
195        let modes = self.table_modes.read().await;
196        let tracked = self.tracked_rows.read().await;
197        let counts = self.subscription_counts.read().await;
198
199        let tables_by_mode =
200            |mode: TrackingMode| -> usize { modes.values().filter(|&&m| m == mode).count() };
201
202        let total_tracked_rows: usize = tracked.values().map(|rows| rows.len()).sum();
203
204        AdaptiveTrackingStats {
205            tables_none: tables_by_mode(TrackingMode::None),
206            tables_row: tables_by_mode(TrackingMode::Row),
207            tables_table: tables_by_mode(TrackingMode::Table),
208            total_tracked_rows,
209            total_subscriptions: counts.values().sum(),
210        }
211    }
212
213    /// Clear all tracking state.
214    pub async fn clear(&self) {
215        self.table_modes.write().await.clear();
216        self.tracked_rows.write().await.clear();
217        self.subscription_counts.write().await.clear();
218        self.row_subscription_counts.write().await.clear();
219    }
220}
221
/// Statistics about adaptive tracking.
///
/// A point-in-time snapshot of plain counters. It derives `PartialEq`/`Eq` so
/// snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AdaptiveTrackingStats {
    /// Tables with no tracking (`TrackingMode::None`).
    pub tables_none: usize,
    /// Tables with row-level tracking.
    pub tables_row: usize,
    /// Tables with table-level tracking.
    pub tables_table: usize,
    /// Total rows being tracked, summed over all tables.
    pub total_tracked_rows: usize,
    /// Total active subscriptions, summed over all tables.
    pub total_subscriptions: usize,
}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_adaptive_tracker_creation() {
        let config = AdaptiveTrackingConfig::default();
        let tracker = AdaptiveTracker::new(config);

        let stats = tracker.stats().await;
        assert_eq!(stats.tables_none, 0);
        assert_eq!(stats.total_tracked_rows, 0);
    }

    #[tokio::test]
    async fn test_subscription_tracking() {
        let config = AdaptiveTrackingConfig {
            row_threshold: 5,
            table_threshold: 2,
            ..Default::default()
        };
        let tracker = AdaptiveTracker::new(config);

        // Add a row subscription
        tracker
            .record_subscription("users", Some(vec!["user-1".to_string()]))
            .await;

        let mode = tracker.get_mode("users").await;
        assert_eq!(mode, TrackingMode::Row);

        // Should invalidate for tracked row
        assert!(tracker.should_invalidate("users", "user-1").await);
        // Should not invalidate for untracked row
        assert!(!tracker.should_invalidate("users", "user-2").await);
        // Tables without subscriptions never invalidate
        assert!(!tracker.should_invalidate("orders", "order-1").await);
    }

    #[tokio::test]
    async fn test_mode_switch_to_table() {
        let config = AdaptiveTrackingConfig {
            row_threshold: 3,
            table_threshold: 1,
            ..Default::default()
        };
        let tracker = AdaptiveTracker::new(config);

        // Add many row subscriptions
        for i in 0..5 {
            tracker
                .record_subscription("users", Some(vec![format!("user-{}", i)]))
                .await;
        }

        let mode = tracker.get_mode("users").await;
        assert_eq!(mode, TrackingMode::Table);

        // Should invalidate for any row in table mode
        assert!(tracker.should_invalidate("users", "user-999").await);
    }

    #[tokio::test]
    async fn test_hysteresis_switch_back_to_row() {
        let config = AdaptiveTrackingConfig {
            row_threshold: 3,
            table_threshold: 2,
            ..Default::default()
        };
        let tracker = AdaptiveTracker::new(config);

        for i in 0..4 {
            tracker
                .record_subscription("users", Some(vec![format!("user-{}", i)]))
                .await;
        }
        assert_eq!(tracker.get_mode("users").await, TrackingMode::Table);

        // Inside the band [table_threshold, row_threshold] the mode sticks.
        tracker
            .remove_subscription("users", Some(vec!["user-3".to_string()]))
            .await;
        tracker
            .remove_subscription("users", Some(vec!["user-2".to_string()]))
            .await;
        assert_eq!(tracker.get_mode("users").await, TrackingMode::Table);

        // Dropping below table_threshold returns to row-level tracking.
        tracker
            .remove_subscription("users", Some(vec!["user-1".to_string()]))
            .await;
        assert_eq!(tracker.get_mode("users").await, TrackingMode::Row);
        assert!(tracker.should_invalidate("users", "user-0").await);
        assert!(!tracker.should_invalidate("users", "user-1").await);
    }

    #[tokio::test]
    async fn test_removing_last_subscription_disables_tracking() {
        let tracker = AdaptiveTracker::new(AdaptiveTrackingConfig::default());

        tracker
            .record_subscription("users", Some(vec!["user-1".to_string()]))
            .await;
        tracker
            .remove_subscription("users", Some(vec!["user-1".to_string()]))
            .await;

        assert_eq!(tracker.get_mode("users").await, TrackingMode::None);
        assert!(!tracker.should_invalidate("users", "user-1").await);

        let stats = tracker.stats().await;
        assert_eq!(stats.total_subscriptions, 0);
        assert_eq!(stats.total_tracked_rows, 0);
        assert_eq!(stats.tables_none, 1);
    }

    #[tokio::test]
    async fn test_clear_resets_state() {
        let tracker = AdaptiveTracker::new(AdaptiveTrackingConfig::default());

        tracker
            .record_subscription("users", Some(vec!["user-1".to_string()]))
            .await;
        tracker.clear().await;

        assert_eq!(tracker.get_mode("users").await, TrackingMode::None);
        let stats = tracker.stats().await;
        assert_eq!(stats.tables_row, 0);
        assert_eq!(stats.total_subscriptions, 0);
        assert_eq!(stats.total_tracked_rows, 0);
    }
}