// forge_runtime/realtime/adaptive.rs

use std::collections::{HashMap, HashSet};
use std::time::Duration;

use tokio::sync::RwLock;

use forge_core::realtime::TrackingMode;
/// Tuning knobs for [`AdaptiveTracker`]'s choice between row- and table-level tracking.
#[derive(Debug, Clone)]
pub struct AdaptiveTrackingConfig {
    /// A table switches to table-level tracking once its number of row subscriptions
    /// rises above this value.
    pub row_threshold: usize,
    /// A table switches back to row-level tracking once its number of row subscriptions
    /// falls below this value. Counts between the two thresholds leave the mode unchanged,
    /// so a table does not flap between modes.
    pub table_threshold: usize,
    /// Upper bound on the number of distinct row ids tracked individually per table.
    pub max_tracked_rows: usize,
    /// Interval for periodic re-evaluation. Not read by the tracker itself; presumably
    /// consumed by whatever schedules `AdaptiveTracker::evaluate` (TODO confirm).
    pub evaluation_interval: Duration,
}

impl Default for AdaptiveTrackingConfig {
    /// Table mode above 100 row subscriptions, row mode again below 50, at most 10 000
    /// tracked rows per table, and a 60-second evaluation interval.
    fn default() -> Self {
        Self {
            row_threshold: 100,
            table_threshold: 50,
            max_tracked_rows: 10_000,
            evaluation_interval: Duration::from_secs(60),
        }
    }
}
31
32pub struct AdaptiveTracker {
37 config: AdaptiveTrackingConfig,
38 table_modes: RwLock<HashMap<String, TrackingMode>>,
40 tracked_rows: RwLock<HashMap<String, HashSet<String>>>,
42 subscription_counts: RwLock<HashMap<String, usize>>,
44 row_subscription_counts: RwLock<HashMap<String, usize>>,
46}
47
48impl AdaptiveTracker {
49 pub fn new(config: AdaptiveTrackingConfig) -> Self {
51 Self {
52 config,
53 table_modes: RwLock::new(HashMap::new()),
54 tracked_rows: RwLock::new(HashMap::new()),
55 subscription_counts: RwLock::new(HashMap::new()),
56 row_subscription_counts: RwLock::new(HashMap::new()),
57 }
58 }
59
60 pub async fn record_subscription(&self, table: &str, row_ids: Option<Vec<String>>) {
62 {
64 let mut counts = self.subscription_counts.write().await;
65 *counts.entry(table.to_string()).or_insert(0) += 1;
66 }
67
68 if let Some(ids) = row_ids {
70 let mut tracked = self.tracked_rows.write().await;
71 let rows = tracked.entry(table.to_string()).or_default();
72 let mut row_counts = self.row_subscription_counts.write().await;
73
74 for id in ids {
75 if rows.len() < self.config.max_tracked_rows {
76 rows.insert(id);
77 *row_counts.entry(table.to_string()).or_insert(0) += 1;
78 }
79 }
80 }
81
82 self.evaluate_table(table).await;
84 }
85
86 pub async fn remove_subscription(&self, table: &str, row_ids: Option<Vec<String>>) {
88 {
90 let mut counts = self.subscription_counts.write().await;
91 if let Some(count) = counts.get_mut(table) {
92 *count = count.saturating_sub(1);
93 }
94 }
95
96 if let Some(ids) = row_ids {
98 let mut tracked = self.tracked_rows.write().await;
99 if let Some(rows) = tracked.get_mut(table) {
100 let mut row_counts = self.row_subscription_counts.write().await;
101 for id in ids {
102 if rows.remove(&id)
103 && let Some(count) = row_counts.get_mut(table)
104 {
105 *count = count.saturating_sub(1);
106 }
107 }
108 }
109 }
110
111 self.evaluate_table(table).await;
113 }
114
115 pub async fn evaluate_table(&self, table: &str) {
117 let subscription_count = {
118 let counts = self.subscription_counts.read().await;
119 *counts.get(table).unwrap_or(&0)
120 };
121
122 let row_count = {
123 let row_counts = self.row_subscription_counts.read().await;
124 *row_counts.get(table).unwrap_or(&0)
125 };
126
127 let new_mode = if subscription_count == 0 {
128 TrackingMode::None
129 } else if row_count > self.config.row_threshold {
130 TrackingMode::Table
131 } else if row_count < self.config.table_threshold {
132 TrackingMode::Row
133 } else {
134 let modes = self.table_modes.read().await;
136 modes.get(table).copied().unwrap_or(TrackingMode::Row)
137 };
138
139 let mut modes = self.table_modes.write().await;
140 let old_mode = modes.get(table).copied();
141
142 if old_mode != Some(new_mode) {
143 modes.insert(table.to_string(), new_mode);
144 tracing::debug!(
145 table = %table,
146 old_mode = ?old_mode,
147 new_mode = ?new_mode,
148 subscription_count = subscription_count,
149 row_count = row_count,
150 "Tracking mode changed"
151 );
152 }
153 }
154
155 pub async fn evaluate(&self) {
157 let tables: Vec<String> = {
158 let counts = self.subscription_counts.read().await;
159 counts.keys().cloned().collect()
160 };
161
162 for table in tables {
163 self.evaluate_table(&table).await;
164 }
165 }
166
167 pub async fn should_invalidate(&self, table: &str, row_id: &str) -> bool {
169 let mode = {
170 let modes = self.table_modes.read().await;
171 modes.get(table).copied().unwrap_or(TrackingMode::None)
172 };
173
174 match mode {
175 TrackingMode::None => false,
176 TrackingMode::Table | TrackingMode::Adaptive => true,
177 TrackingMode::Row => {
178 let tracked = self.tracked_rows.read().await;
179 tracked
180 .get(table)
181 .map(|rows| rows.contains(row_id))
182 .unwrap_or(false)
183 }
184 }
185 }
186
187 pub async fn get_mode(&self, table: &str) -> TrackingMode {
189 let modes = self.table_modes.read().await;
190 modes.get(table).copied().unwrap_or(TrackingMode::None)
191 }
192
193 pub async fn stats(&self) -> AdaptiveTrackingStats {
195 let modes = self.table_modes.read().await;
196 let tracked = self.tracked_rows.read().await;
197 let counts = self.subscription_counts.read().await;
198
199 let tables_by_mode =
200 |mode: TrackingMode| -> usize { modes.values().filter(|&&m| m == mode).count() };
201
202 let total_tracked_rows: usize = tracked.values().map(|rows| rows.len()).sum();
203
204 AdaptiveTrackingStats {
205 tables_none: tables_by_mode(TrackingMode::None),
206 tables_row: tables_by_mode(TrackingMode::Row),
207 tables_table: tables_by_mode(TrackingMode::Table),
208 total_tracked_rows,
209 total_subscriptions: counts.values().sum(),
210 }
211 }
212
213 pub async fn clear(&self) {
215 self.table_modes.write().await.clear();
216 self.tracked_rows.write().await.clear();
217 self.subscription_counts.write().await.clear();
218 self.row_subscription_counts.write().await.clear();
219 }
220}
221
/// Point-in-time summary produced by `AdaptiveTracker::stats`.
#[derive(Debug, Clone, Default)]
pub struct AdaptiveTrackingStats {
    /// Number of known tables currently in `TrackingMode::None`.
    pub tables_none: usize,
    /// Number of tables currently in `TrackingMode::Row`.
    pub tables_row: usize,
    /// Number of tables currently in `TrackingMode::Table`.
    pub tables_table: usize,
    /// Distinct tracked row ids, summed over all tables.
    pub total_tracked_rows: usize,
    /// Live subscriptions, summed over all tables.
    pub total_subscriptions: usize,
}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_adaptive_tracker_creation() {
        let config = AdaptiveTrackingConfig::default();
        let tracker = AdaptiveTracker::new(config);

        let stats = tracker.stats().await;
        assert_eq!(stats.tables_none, 0);
        assert_eq!(stats.total_tracked_rows, 0);
    }

    #[tokio::test]
    async fn test_subscription_tracking() {
        let config = AdaptiveTrackingConfig {
            row_threshold: 5,
            table_threshold: 2,
            ..Default::default()
        };
        let tracker = AdaptiveTracker::new(config);

        tracker
            .record_subscription("users", Some(vec!["user-1".to_string()]))
            .await;

        let mode = tracker.get_mode("users").await;
        assert_eq!(mode, TrackingMode::Row);

        // The subscribed row is tracked...
        assert!(tracker.should_invalidate("users", "user-1").await);
        // ...while an unrelated row in the same table is not.
        assert!(!tracker.should_invalidate("users", "user-2").await);
    }

    #[tokio::test]
    async fn test_mode_switch_to_table() {
        let config = AdaptiveTrackingConfig {
            row_threshold: 3,
            table_threshold: 1,
            ..Default::default()
        };
        let tracker = AdaptiveTracker::new(config);

        // Five single-row subscriptions push the row count past `row_threshold`.
        for i in 0..5 {
            tracker
                .record_subscription("users", Some(vec![format!("user-{}", i)]))
                .await;
        }

        let mode = tracker.get_mode("users").await;
        assert_eq!(mode, TrackingMode::Table);

        // In table mode, any row change invalidates, even for rows never subscribed to.
        assert!(tracker.should_invalidate("users", "user-999").await);
    }

    #[tokio::test]
    async fn test_removing_last_subscription_disables_tracking() {
        let tracker = AdaptiveTracker::new(AdaptiveTrackingConfig::default());
        let rows = Some(vec!["user-1".to_string()]);

        tracker.record_subscription("users", rows.clone()).await;
        tracker.remove_subscription("users", rows).await;

        assert_eq!(tracker.get_mode("users").await, TrackingMode::None);
        assert!(!tracker.should_invalidate("users", "user-1").await);

        let stats = tracker.stats().await;
        assert_eq!(stats.tables_none, 1);
        assert_eq!(stats.total_tracked_rows, 0);
        assert_eq!(stats.total_subscriptions, 0);
    }

    #[tokio::test]
    async fn test_clear_resets_state() {
        let tracker = AdaptiveTracker::new(AdaptiveTrackingConfig::default());
        tracker
            .record_subscription("users", Some(vec!["user-1".to_string()]))
            .await;

        tracker.clear().await;

        assert_eq!(tracker.get_mode("users").await, TrackingMode::None);
        let stats = tracker.stats().await;
        assert_eq!(stats.tables_row, 0);
        assert_eq!(stats.total_tracked_rows, 0);
        assert_eq!(stats.total_subscriptions, 0);
    }
}