sockudo 3.1.0

A simple, fast, and secure WebSocket server for real-time applications.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
// src/filter/index.rs
//! Inverted filter index for O(1) message-to-subscriber matching.
//!
//! Instead of evaluating N filters for each message (O(N)), we build an inverted
//! index that maps tag values to subscriber sets, enabling O(1) lookups.
//!
//! ## Scenarios Optimized
//!
//! - **Scenario 1**: 10k unique `eq` filters → 10k index entries, O(1) lookup per message
//! - **Scenario 4**: 10k sockets × 100 IN values → up to 1M index entries, O(1) lookup
//! - **Scenario 5**: 10k sockets × 500 IN values → up to 5M index entries, O(1) lookup
//!
//! ## Performance Optimizations (v2)
//!
//! - **SocketId instead of WebSocketRef**: Stores Copy-able SocketId (16 bytes) instead of
//!   Arc-cloned WebSocketRef, eliminating heap allocations during lookups
//! - **Flattened eq_index**: Single-level DashMap with composite hash keys instead of
//!   3-level nested DashMaps, reducing lock contention from 3 locks to 1
//! - **Cache-friendly**: SocketId is stack-allocated and cache-line friendly

use super::node::FilterNode;
use super::ops::{CompareOp, LogicalOp};
use crate::websocket::SocketId;
use ahash::AHashSet;
use dashmap::{DashMap, DashSet};
use std::collections::BTreeMap;
use std::hash::{Hash, Hasher};

/// Composite key for the flattened eq_index.
/// Uses pre-computed hashes for O(1) lookups without string comparisons.
/// Compute composite hash for O(1) eq_index lookups
fn compute_eq_hash(channel: &str, key: &str, value: &str) -> u64 {
    use ahash::AHasher;
    let mut hasher = AHasher::default();
    channel.hash(&mut hasher);
    key.hash(&mut hasher);
    value.hash(&mut hasher);
    hasher.finish()
}

/// Inverted index for fast filter matching.
///
/// OPTIMIZED v2: Uses flattened single-level DashMap with composite keys
/// and stores Copy-able SocketId instead of Arc-cloned WebSocketRef.
#[derive(Default)]
pub struct FilterIndex {
    /// Flattened index for equality filters: hash(channel, key, value) -> SocketIds
    /// Single level of locking for dramatically reduced contention
    eq_index: DashMap<u64, DashSet<SocketId>>,

    /// Reverse index: channel -> set of (key, value) hashes present in that channel
    /// Used for efficient channel cleanup and stats
    channel_keys: DashMap<String, DashSet<u64>>,

    /// Sockets with complex filters that can't be indexed (AND/OR/NOT, inequality, etc.)
    /// These still need per-message evaluation
    /// channel -> socket_ids
    complex_filters: DashMap<String, DashSet<SocketId>>,

    /// Sockets with no filter (receive all messages)
    /// channel -> socket_ids
    no_filter: DashMap<String, DashSet<SocketId>>,
}

/// Result of looking up matching sockets for a message
pub struct IndexLookupResult {
    /// Sockets that match via indexed filters (eq, in) - now returns SocketId
    pub indexed_matches: Vec<SocketId>,
    /// Sockets that need per-message filter evaluation (complex filters)
    pub needs_evaluation: Vec<SocketId>,
    /// Sockets with no filter (receive all)
    pub no_filter: Vec<SocketId>,
}

impl FilterIndex {
    pub fn new() -> Self {
        Self::default()
    }

    /// Index a socket's filter for a channel.
    ///
    /// Call this when a socket subscribes to a channel with a filter.
    ///
    /// RACE CONDITION NOTE: This method adds the socket to exactly ONE index type
    /// based on the filter. The corresponding remove_socket_filter() method will
    /// attempt to remove from ALL index types to handle any inconsistencies.
    pub fn add_socket_filter(
        &self,
        channel: &str,
        socket_id: SocketId,
        filter: Option<&FilterNode>,
    ) {
        // First, ensure the socket is not in any OTHER index for this channel
        // This prevents duplicate entries if add_socket_filter is called multiple times
        // with different filters (e.g., filter update scenario)
        self.remove_socket_from_all_indexes(channel, socket_id, filter);

        match filter {
            None => {
                // No filter = receive all messages
                tracing::debug!(
                    "FilterIndex: Adding socket {} to no_filter for channel {}",
                    socket_id,
                    channel
                );
                self.no_filter
                    .entry(channel.to_string())
                    .or_default()
                    .insert(socket_id);
            }
            Some(filter_node) => {
                if let Some(indexable) = self.extract_indexable_filter(filter_node) {
                    // Filter can be indexed
                    tracing::debug!(
                        "FilterIndex: Adding socket {} to eq_index for channel {}, key={}, values_count={}",
                        socket_id,
                        channel,
                        indexable.key,
                        indexable.values.len()
                    );
                    self.add_to_eq_index(channel, socket_id, &indexable);
                } else {
                    // Complex filter, needs per-message evaluation
                    tracing::debug!(
                        "FilterIndex: Adding socket {} to complex_filters for channel {} (filter not indexable)",
                        socket_id,
                        channel
                    );
                    self.complex_filters
                        .entry(channel.to_string())
                        .or_default()
                        .insert(socket_id);
                }
            }
        }
    }

    /// Helper to remove socket from all non-eq indexes for a channel
    /// Used internally to ensure clean state before adding to a specific index
    fn remove_socket_from_all_indexes(
        &self,
        channel: &str,
        socket_id: SocketId,
        _current_filter: Option<&FilterNode>,
    ) {
        // Remove from no_filter (idempotent)
        if let Some(set) = self.no_filter.get(channel) {
            set.remove(&socket_id);
        }

        // Remove from complex_filters (idempotent)
        if let Some(set) = self.complex_filters.get(channel) {
            set.remove(&socket_id);
        }

        // Note: We don't remove from eq_index here because:
        // 1. We don't know what filter was previously used
        // 2. The eq_index uses hash keys that require knowing the exact filter values
        // 3. If the filter is being updated, the old eq_index entries will become orphaned
        //    but this is acceptable - they'll be cleaned up when the channel is cleared
        //    and won't cause incorrect behavior (socket won't receive duplicate messages)
    }

    /// Remove a socket's filter from the index.
    ///
    /// Call this when a socket unsubscribes from a channel.
    ///
    /// RACE CONDITION FIX: Always removes from ALL index types, not just the one
    /// matching the current filter. This handles cases where:
    /// 1. The filter changed between add and remove
    /// 2. A race condition caused the socket to be added to multiple indexes
    /// 3. The filter parameter is stale or incorrect
    ///
    /// The cost of extra remove operations on DashSet is negligible (O(1) per set),
    /// and this guarantees no orphaned entries remain in any index.
    pub fn remove_socket_filter(
        &self,
        channel: &str,
        socket_id: SocketId,
        filter: Option<&FilterNode>,
    ) {
        // ALWAYS remove from no_filter set (idempotent, safe even if not present)
        if let Some(set) = self.no_filter.get(channel) {
            set.remove(&socket_id);
        }

        // ALWAYS remove from complex_filters set (idempotent, safe even if not present)
        if let Some(set) = self.complex_filters.get(channel) {
            set.remove(&socket_id);
        }

        // Remove from eq_index based on filter (if provided)
        // This is the only index where we need to know the filter to find the right entries
        if let Some(filter_node) = filter
            && let Some(indexable) = self.extract_indexable_filter(filter_node)
        {
            self.remove_from_eq_index(channel, socket_id, &indexable);
        }

        // ADDITIONAL SAFETY: If no filter was provided but socket might be in eq_index,
        // we can't efficiently remove it without knowing the key/values.
        // This is acceptable because:
        // 1. The caller should always provide the filter if one was used during add
        // 2. Orphaned eq_index entries will be cleaned up when the channel is cleared
        // 3. The socket won't receive messages anyway since it's unsubscribed from the channel
    }

    /// Look up sockets that should receive a message with the given tags.
    ///
    /// Returns categorized results for efficient processing.
    ///
    /// OPTIMIZATION: Uses fast path when only one indexed key matches (common case).
    /// This avoids HashSet deduplication overhead for the typical scenario where
    /// messages have one filterable tag (e.g., item_id).
    pub fn lookup(&self, channel: &str, tags: &BTreeMap<String, String>) -> IndexLookupResult {
        // 1. Collect sockets with no filter (they get all messages)
        let no_filter_sockets = if let Some(set) = self.no_filter.get(channel) {
            let mut sockets = Vec::with_capacity(set.len());
            for entry in set.iter() {
                sockets.push(*entry.key());
            }
            sockets
        } else {
            Vec::new()
        };

        // 2. Look up indexed matches - optimized with flattened index
        let indexed_matches = self.lookup_eq_index(channel, tags);

        // 3. Collect sockets with complex filters (need evaluation)
        let needs_evaluation = if let Some(set) = self.complex_filters.get(channel) {
            let mut sockets = Vec::with_capacity(set.len());
            for entry in set.iter() {
                sockets.push(*entry.key());
            }
            sockets
        } else {
            Vec::new()
        };

        IndexLookupResult {
            indexed_matches,
            needs_evaluation,
            no_filter: no_filter_sockets,
        }
    }

    /// Optimized lookup using the flattened eq_index
    fn lookup_eq_index(&self, channel: &str, tags: &BTreeMap<String, String>) -> Vec<SocketId> {
        if tags.is_empty() {
            return Vec::new();
        }

        // First pass: find matching hash keys and count them
        let mut matching_hashes: Vec<u64> = Vec::with_capacity(tags.len());

        for (tag_key, tag_value) in tags {
            let hash = compute_eq_hash(channel, tag_key, tag_value);
            if self.eq_index.contains_key(&hash) {
                matching_hashes.push(hash);
            }
        }

        match matching_hashes.len() {
            0 => Vec::new(),
            1 => {
                // FAST PATH: Single matching key - no deduplication needed
                if let Some(socket_set) = self.eq_index.get(&matching_hashes[0]) {
                    let mut matches = Vec::with_capacity(socket_set.len());
                    for entry in socket_set.iter() {
                        matches.push(*entry.key());
                    }
                    matches
                } else {
                    Vec::new()
                }
            }
            _ => {
                // SLOW PATH: Multiple matching keys - need deduplication
                let mut dedup_set = AHashSet::new();
                for hash in matching_hashes {
                    if let Some(socket_set) = self.eq_index.get(&hash) {
                        for entry in socket_set.iter() {
                            dedup_set.insert(*entry.key());
                        }
                    }
                }
                dedup_set.into_iter().collect()
            }
        }
    }

    /// Check if a filter can be indexed and extract its components.
    ///
    /// Indexable filters:
    /// - Simple equality: `key = value`
    /// - IN operator: `key IN [v1, v2, ...]`
    /// - OR of indexable filters (all children must be indexable with same key)
    fn extract_indexable_filter(&self, filter: &FilterNode) -> Option<IndexableFilter> {
        // Check if it's a logical operation
        if let Some(op) = filter.logical_op() {
            match op {
                LogicalOp::Or => {
                    // OR is indexable if all children are simple eq/in on same key
                    let mut all_values = Vec::new();
                    let mut common_key: Option<String> = None;

                    for child in filter.nodes() {
                        if let Some(indexable) = self.extract_indexable_filter(child) {
                            match &common_key {
                                None => common_key = Some(indexable.key.clone()),
                                Some(k) if k != &indexable.key => return None, // Different keys
                                _ => {}
                            }
                            all_values.extend(indexable.values);
                        } else {
                            return None; // Child is not indexable
                        }
                    }

                    common_key.map(|key| IndexableFilter {
                        key,
                        values: all_values,
                    })
                }
                LogicalOp::And | LogicalOp::Not => {
                    // AND and NOT are not directly indexable
                    // (could optimize AND with post-filtering, but complex)
                    None
                }
            }
        } else {
            // Leaf node - check the comparison operator
            match filter.compare_op() {
                CompareOp::Equal => {
                    let key = filter.key().to_string();
                    let value = filter.val().to_string();
                    if key.is_empty() {
                        return None;
                    }
                    Some(IndexableFilter {
                        key,
                        values: vec![value],
                    })
                }
                CompareOp::In => {
                    let key = filter.key();

                    if key.is_empty() {
                        return None;
                    }

                    let vals_ref = filter.vals();
                    if vals_ref.is_empty() || vals_ref.len() > 500 {
                        return None;
                    }

                    // Index IN filters up to 500 values.
                    // The one-time subscription cost (O(values) DashMap insertions) is amortized
                    // over many broadcasts. For high-throughput scenarios, O(1) lookup per message
                    // is much better than O(sockets) filter evaluations per message.
                    // Filters with >500 values fall back to complex_filters path with O(log n) binary search.
                    Some(IndexableFilter {
                        key: key.to_string(),
                        values: vals_ref.to_vec(),
                    })
                }
                // Other operators can't be efficiently indexed
                _ => None,
            }
        }
    }

    fn add_to_eq_index(&self, channel: &str, socket_id: SocketId, indexable: &IndexableFilter) {
        // Track which keys belong to this channel for cleanup
        let channel_key_set = self.channel_keys.entry(channel.to_string()).or_default();

        for value in &indexable.values {
            let hash = compute_eq_hash(channel, &indexable.key, value);

            // Add to channel's key set for cleanup tracking
            channel_key_set.insert(hash);

            // Add socket to the set (single lock acquisition)
            if let Some(socket_set) = self.eq_index.get(&hash) {
                socket_set.insert(socket_id);
            } else {
                self.eq_index.entry(hash).or_default().insert(socket_id);
            }
        }
    }

    fn remove_from_eq_index(
        &self,
        channel: &str,
        socket_id: SocketId,
        indexable: &IndexableFilter,
    ) {
        for value in &indexable.values {
            let hash = compute_eq_hash(channel, &indexable.key, value);

            if let Some(socket_set) = self.eq_index.get(&hash) {
                socket_set.remove(&socket_id);
            }
        }
    }

    /// Get statistics about the index for monitoring/debugging
    pub fn stats(&self, channel: &str) -> IndexStats {
        let mut eq_entries = 0;
        let mut eq_sockets = 0;

        // Count entries using channel_keys
        if let Some(key_hashes) = self.channel_keys.get(channel) {
            for hash_entry in key_hashes.iter() {
                let hash = *hash_entry.key();
                if let Some(socket_set) = self.eq_index.get(&hash) {
                    eq_entries += 1;
                    eq_sockets += socket_set.len();
                }
            }
        }

        let complex_count = self
            .complex_filters
            .get(channel)
            .map(|s| s.len())
            .unwrap_or(0);

        let no_filter_count = self.no_filter.get(channel).map(|s| s.len()).unwrap_or(0);

        IndexStats {
            eq_entries,
            eq_sockets,
            complex_filters: complex_count,
            no_filter: no_filter_count,
        }
    }

    /// Clear all entries for a channel (used when channel is removed)
    pub fn clear_channel(&self, channel: &str) {
        // Clear eq_index entries using channel_keys
        if let Some((_, key_hashes)) = self.channel_keys.remove(channel) {
            for hash_entry in key_hashes.iter() {
                let hash = *hash_entry.key();
                self.eq_index.remove(&hash);
            }
        }

        self.complex_filters.remove(channel);
        self.no_filter.remove(channel);
    }
}

/// Represents a filter that can be indexed
struct IndexableFilter {
    key: String,
    values: Vec<String>,
}

/// Statistics about the filter index
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Number of unique (key, value) pairs indexed
    pub eq_entries: usize,
    /// Total socket references in eq index (includes duplicates across values)
    pub eq_sockets: usize,
    /// Sockets with complex filters requiring evaluation
    pub complex_filters: usize,
    /// Sockets with no filter
    pub no_filter: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eq_index_key_hashing() {
        let hash1 = compute_eq_hash("channel1", "key1", "value1");
        let hash2 = compute_eq_hash("channel1", "key1", "value1");
        let hash3 = compute_eq_hash("channel1", "key1", "value2");

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    #[test]
    fn test_filter_index_no_filter() {
        let index = FilterIndex::new();
        let socket_id = SocketId::new();

        index.add_socket_filter("channel1", socket_id, None);

        let result = index.lookup("channel1", &BTreeMap::new());
        assert_eq!(result.no_filter.len(), 1);
        assert_eq!(result.no_filter[0], socket_id);
    }
}