Skip to main content

trojan_analytics/
collector.rs

1//! Event collector for non-blocking event recording.
2
3use std::net::{IpAddr, SocketAddr};
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::mpsc;
8use tracing::debug;
9use trojan_config::{AnalyticsConfig, AnalyticsPrivacyConfig};
10
11use crate::event::{AuthResult, CloseReason, ConnectionEvent, Protocol, TargetType, Transport};
12
13/// Event collector for recording connection events.
14///
15/// This struct is cheap to clone and can be shared across threads.
16/// Events are sent through a bounded channel to a background writer.
17#[derive(Debug, Clone)]
18pub struct EventCollector {
19    sender: mpsc::Sender<ConnectionEvent>,
20    config: Arc<AnalyticsConfig>,
21}
22
23impl EventCollector {
24    /// Create a new event collector.
25    pub(crate) fn new(sender: mpsc::Sender<ConnectionEvent>, config: Arc<AnalyticsConfig>) -> Self {
26        Self { sender, config }
27    }
28
29    /// Record a connection event (non-blocking).
30    ///
31    /// Returns `true` if the event was queued, `false` if the buffer is full.
32    #[inline]
33    pub fn record(&self, event: ConnectionEvent) -> bool {
34        self.sender.try_send(event).is_ok()
35    }
36
37    /// Create a connection event builder for the given connection.
38    ///
39    /// The builder will automatically send the event when dropped.
40    pub fn connection(&self, conn_id: u64, peer: SocketAddr) -> ConnectionEventBuilder {
41        ConnectionEventBuilder::new(self.clone(), conn_id, peer, &self.config)
42    }
43
44    /// Check if an event should be recorded based on sampling configuration.
45    ///
46    /// Returns `true` if the event should be recorded.
47    pub fn should_sample(&self, user_id: Option<&str>) -> bool {
48        let sampling = &self.config.sampling;
49
50        // Always record specified users
51        if let Some(uid) = user_id
52            && sampling.always_record_users.iter().any(|u| u == uid)
53        {
54            return true;
55        }
56
57        // Sample based on rate
58        if sampling.rate >= 1.0 {
59            return true;
60        }
61        if sampling.rate <= 0.0 {
62            return false;
63        }
64
65        rand::random::<f64>() < sampling.rate
66    }
67
68    /// Get the privacy configuration.
69    pub fn privacy(&self) -> &AnalyticsPrivacyConfig {
70        &self.config.privacy
71    }
72
73    /// Get the server ID.
74    pub fn server_id(&self) -> Option<&str> {
75        self.config.server_id.as_deref()
76    }
77}
78
79/// Builder for constructing connection events.
80///
81/// Events are automatically sent when the builder is dropped,
82/// or can be explicitly sent with `finish()`.
83#[derive(Debug)]
84pub struct ConnectionEventBuilder {
85    collector: EventCollector,
86    event: ConnectionEvent,
87    start_time: Instant,
88    sent: bool,
89}
90
91impl ConnectionEventBuilder {
92    /// Create a new connection event builder.
93    fn new(
94        collector: EventCollector,
95        conn_id: u64,
96        peer: SocketAddr,
97        config: &AnalyticsConfig,
98    ) -> Self {
99        let peer_ip = if config.privacy.record_peer_ip {
100            peer.ip()
101        } else {
102            // Use unspecified address if not recording
103            match peer {
104                SocketAddr::V4(_) => IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
105                SocketAddr::V6(_) => IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED),
106            }
107        };
108
109        let mut event = ConnectionEvent::new(conn_id, peer_ip, peer.port());
110        event.server_id = config.server_id.clone().unwrap_or_default();
111
112        Self {
113            collector,
114            event,
115            start_time: Instant::now(),
116            sent: false,
117        }
118    }
119
120    /// Set the user ID.
121    pub fn user(mut self, user_id: impl Into<String>) -> Self {
122        let uid = user_id.into();
123        let privacy = self.collector.privacy();
124
125        self.event.user_id = if privacy.full_user_id {
126            uid
127        } else {
128            // Truncate to prefix length
129            let len = privacy.user_id_prefix_len.min(uid.len());
130            uid[..len].to_string()
131        };
132        self.event.auth_result = AuthResult::Success;
133        self
134    }
135
136    /// Set authentication as failed.
137    pub fn auth_failed(mut self) -> Self {
138        self.event.auth_result = AuthResult::Failed;
139        self
140    }
141
142    /// Set the target information.
143    pub fn target(mut self, host: impl Into<String>, port: u16, target_type: TargetType) -> Self {
144        self.event.target_host = host.into();
145        self.event.target_port = port;
146        self.event.target_type = target_type;
147        self
148    }
149
150    /// Set the SNI (Server Name Indication).
151    pub fn sni(mut self, sni: impl Into<String>) -> Self {
152        if self.collector.privacy().record_sni {
153            self.event.sni = sni.into();
154        }
155        self
156    }
157
158    /// Set the protocol type.
159    pub fn protocol(mut self, protocol: Protocol) -> Self {
160        self.event.protocol = protocol;
161        self
162    }
163
164    /// Set the transport type.
165    pub fn transport(mut self, transport: Transport) -> Self {
166        self.event.transport = transport;
167        self
168    }
169
170    /// Mark as fallback connection.
171    pub fn fallback(mut self) -> Self {
172        self.event.is_fallback = true;
173        self.event.auth_result = AuthResult::Skipped;
174        self
175    }
176
177    /// Set GeoIP fields based on lookup result and privacy precision.
178    ///
179    /// Precision levels:
180    /// - `"city"`: fill all geo fields (country, region, city, ASN, org, lat/lon)
181    /// - `"country"`: fill only country code
182    /// - `"none"` or other: no-op
183    pub fn geo(mut self, result: trojan_config::GeoResult, precision: &str) -> Self {
184        match precision {
185            "city" => {
186                self.event.peer_country = result.country;
187                self.event.peer_region = result.region;
188                self.event.peer_city = result.city;
189                self.event.peer_asn = result.asn;
190                self.event.peer_org = result.org;
191                self.event.peer_longitude = result.longitude;
192                self.event.peer_latitude = result.latitude;
193            }
194            "country" => {
195                self.event.peer_country = result.country;
196            }
197            _ => {} // "none" or unknown: no-op
198        }
199        self
200    }
201
202    /// Add bytes to the traffic counters.
203    #[inline]
204    pub fn add_bytes(&mut self, sent: u64, recv: u64) {
205        self.event.bytes_sent += sent;
206        self.event.bytes_recv += recv;
207    }
208
209    /// Add packets to the packet counters (for UDP).
210    #[inline]
211    pub fn add_packets(&mut self, sent: u64, recv: u64) {
212        self.event.packets_sent += sent;
213        self.event.packets_recv += recv;
214    }
215
216    /// Get a mutable reference to the event for direct modification.
217    pub fn event_mut(&mut self) -> &mut ConnectionEvent {
218        &mut self.event
219    }
220
221    /// Finish and send the event with the given close reason.
222    #[allow(clippy::cast_possible_truncation)]
223    pub fn finish(mut self, close_reason: CloseReason) {
224        self.event.duration_ms = self.start_time.elapsed().as_millis() as u64;
225        self.event.close_reason = close_reason;
226        self.send();
227    }
228
229    /// Send the event.
230    fn send(&mut self) {
231        if self.sent {
232            return;
233        }
234        self.sent = true;
235
236        if !self.collector.record(self.event.clone()) {
237            debug!(
238                conn_id = self.event.conn_id,
239                "analytics buffer full, event dropped"
240            );
241        }
242    }
243}
244
245impl Drop for ConnectionEventBuilder {
246    #[allow(clippy::cast_possible_truncation)]
247    fn drop(&mut self) {
248        if !self.sent {
249            self.event.duration_ms = self.start_time.elapsed().as_millis() as u64;
250            self.send();
251        }
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddrV4};
    use std::sync::Arc;
    use trojan_config::{AnalyticsConfig, AnalyticsSamplingConfig, GeoResult};

    type EventRx = mpsc::Receiver<ConnectionEvent>;

    /// Collector with analytics enabled and every other setting at its default.
    fn test_collector() -> (EventCollector, EventRx) {
        let (tx, rx) = mpsc::channel(64);
        let config = AnalyticsConfig {
            enabled: true,
            ..Default::default()
        };
        (EventCollector::new(tx, Arc::new(config)), rx)
    }

    /// Collector with the given sampling settings and a one-slot channel.
    fn sampling_collector(sampling: AnalyticsSamplingConfig) -> (EventCollector, EventRx) {
        let (tx, rx) = mpsc::channel(1);
        let config = AnalyticsConfig {
            enabled: true,
            sampling,
            ..Default::default()
        };
        (EventCollector::new(tx, Arc::new(config)), rx)
    }

    /// Fixed IPv4 peer `1.2.3.4:12345`.
    fn test_peer() -> SocketAddr {
        SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 12345).into()
    }

    #[test]
    fn geo_builder_city_precision() {
        let (collector, _rx) = test_collector();
        let builder = collector.connection(1, test_peer());

        let geo = GeoResult {
            country: "US".into(),
            region: "California".into(),
            city: "Los Angeles".into(),
            asn: 15169,
            org: "Google LLC".into(),
            longitude: -118.24,
            latitude: 34.05,
        };

        let built = builder.geo(geo, "city");
        let ev = &built.event;
        assert_eq!(ev.peer_country, "US");
        assert_eq!(ev.peer_region, "California");
        assert_eq!(ev.peer_city, "Los Angeles");
        assert_eq!(ev.peer_asn, 15169);
        assert_eq!(ev.peer_org, "Google LLC");
        assert!((ev.peer_longitude + 118.24).abs() < 0.001);
        assert!((ev.peer_latitude - 34.05).abs() < 0.001);
    }

    #[test]
    fn geo_builder_country_precision() {
        let (collector, _rx) = test_collector();
        let builder = collector.connection(2, test_peer());

        let geo = GeoResult {
            country: "CN".into(),
            region: "Shanghai".into(),
            city: "Shanghai".into(),
            asn: 4134,
            org: "China Telecom".into(),
            longitude: 121.47,
            latitude: 31.23,
        };

        let built = builder.geo(geo, "country");
        let ev = &built.event;
        assert_eq!(ev.peer_country, "CN");
        assert!(ev.peer_region.is_empty());
        assert!(ev.peer_city.is_empty());
        assert_eq!(ev.peer_asn, 0);
    }

    #[test]
    fn geo_builder_none_precision() {
        let (collector, _rx) = test_collector();
        let builder = collector.connection(3, test_peer());

        let geo = GeoResult {
            country: "JP".into(),
            region: "Tokyo".into(),
            city: "Tokyo".into(),
            asn: 2497,
            org: "IIJ".into(),
            longitude: 139.69,
            latitude: 35.69,
        };

        let built = builder.geo(geo, "none");
        let ev = &built.event;
        assert!(ev.peer_country.is_empty());
        assert!(ev.peer_region.is_empty());
        assert_eq!(ev.peer_asn, 0);
    }

    #[tokio::test]
    async fn event_builder_sends_on_finish() {
        let (collector, mut rx) = test_collector();
        collector
            .connection(10, test_peer())
            .target("example.com".to_string(), 443, TargetType::Domain)
            .protocol(Protocol::Tcp)
            .finish(CloseReason::Normal);

        let event = rx.try_recv().expect("finish() should queue the event");
        assert_eq!(event.conn_id, 10);
        assert_eq!(event.target_host, "example.com");
        assert_eq!(event.target_port, 443);
        assert_eq!(event.protocol, Protocol::Tcp);
        assert_eq!(event.close_reason, CloseReason::Normal);
    }

    #[tokio::test]
    async fn event_builder_sends_on_drop() {
        let (collector, mut rx) = test_collector();
        drop(collector.connection(20, test_peer()));

        let event = rx
            .try_recv()
            .expect("dropping the builder should queue the event");
        assert_eq!(event.conn_id, 20);
    }

    #[test]
    fn should_sample_always_record_user() {
        let (collector, _rx) = sampling_collector(AnalyticsSamplingConfig {
            rate: 0.0,
            always_record_users: vec!["vip-user".into()],
        });
        assert!(collector.should_sample(Some("vip-user")));
        assert!(!collector.should_sample(Some("normal-user")));
    }

    #[test]
    fn should_sample_rate_boundaries() {
        let (always, _rx) = sampling_collector(AnalyticsSamplingConfig {
            rate: 1.0,
            always_record_users: vec![],
        });
        assert!(always.should_sample(None));

        let (never, _rx2) = sampling_collector(AnalyticsSamplingConfig {
            rate: 0.0,
            always_record_users: vec![],
        });
        assert!(!never.should_sample(None));
    }

    #[test]
    fn user_id_truncation() {
        let (collector, _rx) = test_collector();
        let builder = collector
            .connection(30, test_peer())
            .user("abcdef1234567890");
        assert_eq!(builder.event.user_id, "abcdef12");
    }

    #[test]
    fn add_bytes_and_packets() {
        let (collector, _rx) = test_collector();
        let mut builder = collector.connection(40, test_peer());
        for (sent, recv) in [(100, 200), (50, 25)] {
            builder.add_bytes(sent, recv);
        }
        builder.add_packets(3, 5);

        let ev = &builder.event;
        assert_eq!((ev.bytes_sent, ev.bytes_recv), (150, 225));
        assert_eq!((ev.packets_sent, ev.packets_recv), (3, 5));
    }
}