1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
//! Subscription lifecycle management: subscribe and unsubscribe operations.
use std::collections::HashSet;
use std::sync::atomic::Ordering;
use tokio::sync::mpsc;
use tracing::debug;
use super::SubscriptionManager;
use crate::subscription::{
extract_table_refs, Subscription, SubscriptionError, SubscriptionId, SubscriptionUpdate,
};
impl SubscriptionManager {
/// Create a new subscription for a query.
///
/// Reserves a slot against the global subscription limit, parses the query
/// to extract its table dependencies, and registers the subscription in the
/// subscription map and the per-table index.
///
/// # Arguments
///
/// * `query` - SQL query to monitor
/// * `notify_tx` - Channel to send updates to the subscriber
///
/// # Returns
///
/// The subscription ID on success, or an error if parsing fails or limits exceeded
///
/// # Errors
///
/// - `GlobalLimitExceeded` if the global subscription limit is reached
///   (checked first, before the query is parsed)
/// - `ParseError` if the query cannot be parsed or references no tables
///
/// # Example
///
/// ```text
/// let manager = SubscriptionManager::new();
/// let (tx, mut rx) = mpsc::channel(16);
///
/// let id = manager.subscribe("SELECT * FROM users".to_string(), tx)?;
/// println!("Subscribed with ID: {}", id);
/// ```
pub fn subscribe(
    &self,
    query: String,
    notify_tx: mpsc::Sender<SubscriptionUpdate>,
) -> Result<SubscriptionId, SubscriptionError> {
    // Atomically reserve a slot up front so that concurrent callers cannot
    // both pass a "count < max" check and then both register (TOCTOU).
    // `fetch_update` retries the compare-and-swap until it either succeeds
    // or the closure declines because the limit has been reached.
    let max_global = self.config.max_global;
    let reservation = self.subscription_count_atomic.fetch_update(
        Ordering::AcqRel,
        Ordering::Acquire,
        |count| if count < max_global { Some(count + 1) } else { None },
    );
    if let Err(current) = reservation {
        self.limit_exceeded_count.fetch_add(1, Ordering::Relaxed);
        return Err(SubscriptionError::GlobalLimitExceeded {
            current,
            max: max_global,
        });
    }

    // Parse the query and require at least one table dependency. Both
    // failure modes are folded into one `Result` so the reserved slot is
    // released in exactly one place.
    let parsed = self.extract_tables(&query).and_then(|tables| {
        if tables.is_empty() {
            Err(SubscriptionError::ParseError(
                "Query must reference at least one table".to_string(),
            ))
        } else {
            Ok(tables)
        }
    });
    let tables = match parsed {
        Ok(tables) => tables,
        Err(e) => {
            // Release the reserved slot: no subscription was registered.
            self.subscription_count_atomic.fetch_sub(1, Ordering::Release);
            return Err(e);
        }
    };

    // `query` is not used again, so it is moved instead of cloned.
    let subscription = Subscription::new(query, tables.clone(), notify_tx);
    let id = subscription.id;
    debug!(
        subscription_id = %id,
        tables = ?tables,
        "Creating new subscription"
    );

    // Register subscription
    self.subscriptions.insert(id, subscription);

    // Index by tables.
    // NOTE(review): names are indexed exactly as returned by `extract_tables`,
    // whereas `subscribe_for_connection` lowercases them first — confirm
    // whether the two paths are meant to differ.
    for table in tables {
        self.table_index.entry(table).or_default().insert(id);
    }
    Ok(id)
}
/// Create a new subscription for a specific connection (wire protocol)
///
/// This is the primary method for wire protocol subscriptions. It:
/// - Checks both global and per-connection limits
/// - Associates the subscription with a connection ID for cleanup
/// - Stores the wire protocol UUID for lookup
///
/// The per-connection slot is reserved first, then the global slot. If the
/// global reservation fails, the per-connection slot is given back.
///
/// # Arguments
///
/// * `query` - SQL query to monitor
/// * `notify_tx` - Channel to send updates to the subscriber
/// * `connection_id` - The connection/session ID that owns this subscription
/// * `wire_subscription_id` - The wire protocol UUID for this subscription
/// * `table_dependencies` - Pre-extracted table dependencies (from AST parsing)
/// * `filter` - Optional filter string, forwarded unchanged to
///   `Subscription::for_connection`; its meaning is defined there
///
/// # Returns
///
/// The subscription ID on success, or an error if limits exceeded
///
/// # Errors
///
/// - `ConnectionLimitExceeded` if the per-connection limit is reached
/// - `GlobalLimitExceeded` if the global subscription limit is reached
pub fn subscribe_for_connection(
    &self,
    query: String,
    notify_tx: mpsc::Sender<SubscriptionUpdate>,
    connection_id: String,
    wire_subscription_id: [u8; 16],
    table_dependencies: HashSet<String>,
    filter: Option<String>,
) -> Result<SubscriptionId, SubscriptionError> {
    // Look up (or lazily create) this connection's counter. The returned
    // map guard stays alive until the end of the function.
    let conn_count = self
        .connection_subscription_counts
        .entry(connection_id.clone())
        .or_insert_with(|| std::sync::atomic::AtomicUsize::new(0));

    // Reserve a per-connection slot; the closure declines once the limit
    // has been reached, yielding the count that was observed.
    let max_per_connection = self.config.max_per_connection;
    let conn_reservation = conn_count.fetch_update(
        Ordering::AcqRel,
        Ordering::Acquire,
        |count| if count < max_per_connection { Some(count + 1) } else { None },
    );
    if let Err(current) = conn_reservation {
        return Err(SubscriptionError::ConnectionLimitExceeded {
            current,
            max: max_per_connection,
        });
    }

    // Atomically reserve a global slot to prevent a TOCTOU race between
    // checking the limit and registering the subscription.
    let max_global = self.config.max_global;
    let global_reservation = self.subscription_count_atomic.fetch_update(
        Ordering::AcqRel,
        Ordering::Acquire,
        |count| if count < max_global { Some(count + 1) } else { None },
    );
    if let Err(current) = global_reservation {
        // Release the per-connection slot we reserved above.
        conn_count.fetch_sub(1, Ordering::Release);
        self.limit_exceeded_count.fetch_add(1, Ordering::Relaxed);
        return Err(SubscriptionError::GlobalLimitExceeded {
            current,
            max: max_global,
        });
    }

    // Create subscription with connection tracking. `query` is not used
    // again, so it is moved instead of cloned; the table set and connection
    // ID are still needed below and are therefore cloned.
    let subscription = Subscription::for_connection(
        query,
        table_dependencies.clone(),
        notify_tx,
        connection_id.clone(),
        wire_subscription_id,
        filter,
        &self.config,
    );
    let id = subscription.id;
    debug!(
        subscription_id = %id,
        connection_id = %connection_id,
        tables = ?table_dependencies,
        "Creating new subscription for connection"
    );

    // Register subscription
    self.subscriptions.insert(id, subscription);

    // Index by tables (lowercase for case-insensitive matching)
    for table in table_dependencies {
        self.table_index.entry(table.to_lowercase()).or_default().insert(id);
    }

    // Index by connection
    self.connection_index.entry(connection_id).or_default().insert(id);

    // Index by wire ID
    self.wire_id_index.insert(wire_subscription_id, id);
    Ok(id)
}
/// Remove a subscription
///
/// Unregisters the subscription, releases its global (and, if applicable,
/// per-connection) slot, and removes it from the table, connection and
/// wire-ID indexes.
///
/// # Arguments
///
/// * `id` - The subscription ID to remove
///
/// # Returns
///
/// `true` if the removed subscription was selective-eligible, `false` otherwise.
/// Also returns `false` when no subscription with this ID exists.
pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
    let subscription = match self.subscriptions.remove(&id) {
        Some((_, subscription)) => subscription,
        None => return false,
    };
    debug!(subscription_id = %id, "Removing subscription");
    let was_selective_eligible = subscription.selective_eligible;

    // Decrement the atomic counter
    self.subscription_count_atomic.fetch_sub(1, Ordering::Release);

    // Remove from table index
    for table in &subscription.tables {
        if let Some(mut ids) = self.table_index.get_mut(table) {
            ids.remove(&id);
        }
        // `subscribe_for_connection` files the ID under the lowercased table
        // name. If the stored name is spelled differently, also clear that
        // key so the ID does not linger in the index.
        let lowered = table.to_lowercase();
        if lowered != *table {
            if let Some(mut ids) = self.table_index.get_mut(&lowered) {
                ids.remove(&id);
            }
        }
    }

    // Remove from connection index if present
    if let Some(ref conn_id) = subscription.connection_id {
        if let Some(mut ids) = self.connection_index.get_mut(conn_id) {
            ids.remove(&id);
        }
        // Decrement per-connection count. The decrement is checked: if the
        // counter entry was removed and recreated while the connection was
        // being cleaned up, it may already be 0, and a plain `fetch_sub`
        // would wrap to `usize::MAX` and permanently block the connection.
        if let Some(count) = self.connection_subscription_counts.get(conn_id) {
            let _ = count.fetch_update(Ordering::AcqRel, Ordering::Acquire, |n| {
                n.checked_sub(1)
            });
        }
    }

    // Remove from wire ID index if present
    if let Some(wire_id) = subscription.wire_subscription_id {
        self.wire_id_index.remove(&wire_id);
    }
    was_selective_eligible
}
/// Remove a subscription identified by its wire protocol ID
///
/// Wire protocol clients address subscriptions by UUID rather than by the
/// internal subscription ID. This drops the UUID mapping from the wire-ID
/// index and then delegates to `unsubscribe` with the internal ID it
/// pointed to.
///
/// # Arguments
///
/// * `wire_id` - The wire protocol subscription ID (UUID bytes)
///
/// # Returns
///
/// `true` if the removed subscription was selective-eligible, `false` otherwise.
/// Returns `false` if the subscription was not found.
pub fn unsubscribe_by_wire_id(&self, wire_id: &[u8; 16]) -> bool {
    self.wire_id_index
        .remove(wire_id)
        .map_or(false, |(_, internal_id)| self.unsubscribe(internal_id))
}
/// Remove all subscriptions for a connection
///
/// This should be called when a connection closes to clean up all its
/// subscriptions. This is important for wire protocol connections.
/// The connection's per-connection counter entry is always dropped, even
/// when the connection has no indexed subscriptions.
///
/// # Arguments
///
/// * `connection_id` - The connection ID to clean up
///
/// # Returns
///
/// A tuple of (total_removed, selective_eligible_removed), where
/// `total_removed` is the number of IDs that were listed in the connection
/// index for this connection.
pub fn unsubscribe_all_for_connection(&self, connection_id: &str) -> (usize, usize) {
    let subscription_ids: Vec<SubscriptionId> =
        match self.connection_index.remove(connection_id) {
            Some((_, ids)) => ids.into_iter().collect(),
            None => {
                // `subscribe_for_connection` creates the counter entry before
                // it knows whether the subscription will be accepted, so a
                // counter can exist without any index entry. Drop it here too,
                // otherwise it would outlive the connection.
                self.connection_subscription_counts.remove(connection_id);
                return (0, 0);
            }
        };

    let count = subscription_ids.len();
    debug!(
        connection_id = %connection_id,
        subscription_count = count,
        "Removing all subscriptions for connection"
    );

    // Note: unsubscribe will try to remove from connection_index again,
    // but it will be a no-op since we already removed it
    let selective_eligible_count = subscription_ids
        .into_iter()
        .filter(|&id| self.unsubscribe(id))
        .count();

    // Clean up the per-connection count entry. This happens after the loop
    // so each `unsubscribe` call can still find and decrement the counter.
    self.connection_subscription_counts.remove(connection_id);
    (count, selective_eligible_count)
}
/// Parse `query` and collect the tables it references
///
/// # Errors
///
/// Returns `ParseError` carrying the parser's message when the query
/// cannot be parsed.
pub(crate) fn extract_tables(&self, query: &str) -> Result<HashSet<String>, SubscriptionError> {
    match vibesql_parser::Parser::parse_sql(query) {
        Ok(statement) => Ok(extract_table_refs(&statement)),
        Err(parse_failure) => Err(SubscriptionError::ParseError(parse_failure.to_string())),
    }
}
}