Skip to main content

tonic_debug/
connection.rs

1//! Connection lifecycle observability.
2//!
3//! Provides a Tower layer that tracks TCP/HTTP2 connection events — new
4//! connections, disconnections, and errors — giving operators visibility into
5//! client connectivity issues.
6
7use std::{
8    fmt,
9    future::Future,
10    net::SocketAddr,
11    pin::Pin,
12    sync::{
13        atomic::{AtomicU64, Ordering},
14        Arc,
15    },
16    task::{Context, Poll},
17};
18
19use tower_layer::Layer;
20use tower_service::Service;
21
/// Cheaply clonable handle to a set of connection counters.
///
/// Every clone points at the same counters (they live behind an [`Arc`]), so an
/// update made through one handle is visible through all of them.
#[derive(Clone, Debug)]
pub struct ConnectionMetrics {
    inner: Arc<ConnectionMetricsInner>,
}

/// The atomic counters behind a [`ConnectionMetrics`] handle.
#[derive(Debug)]
struct ConnectionMetricsInner {
    /// Connections accepted since creation (only ever incremented).
    total_connections: AtomicU64,
    /// Connections currently considered open.
    active_connections: AtomicU64,
    /// Connection errors observed since creation.
    connection_errors: AtomicU64,
}
37
38impl ConnectionMetrics {
39    /// Create a new `ConnectionMetrics` instance.
40    pub fn new() -> Self {
41        Self {
42            inner: Arc::new(ConnectionMetricsInner {
43                total_connections: AtomicU64::new(0),
44                active_connections: AtomicU64::new(0),
45                connection_errors: AtomicU64::new(0),
46            }),
47        }
48    }
49
50    /// Get the total number of connections ever accepted.
51    pub fn total_connections(&self) -> u64 {
52        self.inner.total_connections.load(Ordering::Relaxed)
53    }
54
55    /// Get the number of currently active connections.
56    pub fn active_connections(&self) -> u64 {
57        self.inner.active_connections.load(Ordering::Relaxed)
58    }
59
60    /// Get the total number of connection errors.
61    pub fn connection_errors(&self) -> u64 {
62        self.inner.connection_errors.load(Ordering::Relaxed)
63    }
64
65    fn on_connect(&self) {
66        self.inner.total_connections.fetch_add(1, Ordering::Relaxed);
67        self.inner
68            .active_connections
69            .fetch_add(1, Ordering::Relaxed);
70    }
71
72    fn on_disconnect(&self) {
73        self.inner
74            .active_connections
75            .fetch_sub(1, Ordering::Relaxed);
76    }
77
78    fn on_error(&self) {
79        self.inner.connection_errors.fetch_add(1, Ordering::Relaxed);
80    }
81}
82
83impl Default for ConnectionMetrics {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
89impl fmt::Display for ConnectionMetrics {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        write!(
92            f,
93            "connections(total={}, active={}, errors={})",
94            self.total_connections(),
95            self.active_connections(),
96            self.connection_errors()
97        )
98    }
99}
100
101/// A Tower layer that wraps a `MakeService` (or any per-connection service)
102/// to track and log connection lifecycle events.
103///
104/// This should be applied at the server level, wrapping the service factory
105/// that produces per-connection services.
106#[derive(Debug, Clone)]
107pub struct ConnectionTrackerLayer {
108    metrics: ConnectionMetrics,
109}
110
111impl ConnectionTrackerLayer {
112    /// Create a new `ConnectionTrackerLayer`.
113    pub fn new() -> Self {
114        Self {
115            metrics: ConnectionMetrics::new(),
116        }
117    }
118
119    /// Create a new `ConnectionTrackerLayer` with shared metrics.
120    pub fn with_metrics(metrics: ConnectionMetrics) -> Self {
121        Self { metrics }
122    }
123
124    /// Get a reference to the connection metrics.
125    pub fn metrics(&self) -> &ConnectionMetrics {
126        &self.metrics
127    }
128}
129
130impl Default for ConnectionTrackerLayer {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136impl<S> Layer<S> for ConnectionTrackerLayer {
137    type Service = ConnectionTrackerService<S>;
138
139    fn layer(&self, inner: S) -> Self::Service {
140        ConnectionTrackerService {
141            inner,
142            metrics: self.metrics.clone(),
143        }
144    }
145}
146
147/// A Tower service that tracks connection lifecycle events.
148///
149/// When used as the outer service in a tonic server, each call to this
150/// service represents a new connection being established.
151#[derive(Debug, Clone)]
152pub struct ConnectionTrackerService<S> {
153    inner: S,
154    metrics: ConnectionMetrics,
155}
156
157impl<S> ConnectionTrackerService<S> {
158    /// Get a reference to the connection metrics.
159    pub fn metrics(&self) -> &ConnectionMetrics {
160        &self.metrics
161    }
162}
163
164impl<S, Target> Service<Target> for ConnectionTrackerService<S>
165where
166    S: Service<Target> + Clone + Send + 'static,
167    S::Response: Send + 'static,
168    S::Error: fmt::Display + Send + 'static,
169    S::Future: Send + 'static,
170    Target: fmt::Debug + Send + 'static,
171{
172    type Response = S::Response;
173    type Error = S::Error;
174    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
175
176    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177        self.inner.poll_ready(cx)
178    }
179
180    fn call(&mut self, target: Target) -> Self::Future {
181        let metrics = self.metrics.clone();
182        let mut inner = self.inner.clone();
183        std::mem::swap(&mut self.inner, &mut inner);
184
185        metrics.on_connect();
186
187        tracing::info!(
188            peer = ?target,
189            active_connections = metrics.active_connections(),
190            total_connections = metrics.total_connections(),
191            "⚡ New connection established"
192        );
193
194        Box::pin(async move {
195            let result = inner.call(target).await;
196            match &result {
197                Ok(_) => {
198                    metrics.on_disconnect();
199                    tracing::info!(
200                        active_connections = metrics.active_connections(),
201                        "🔌 Connection closed"
202                    );
203                }
204                Err(e) => {
205                    metrics.on_error();
206                    metrics.on_disconnect();
207                    tracing::error!(
208                        error = %e,
209                        active_connections = metrics.active_connections(),
210                        connection_errors = metrics.connection_errors(),
211                        "❌ Connection error"
212                    );
213                }
214            }
215            result
216        })
217    }
218}
219
220/// A guard that decrements active connections when dropped.
221///
222/// Useful for tracking connection lifetimes in scenarios where the service
223/// response outlives the initial call (e.g., long-lived streaming connections).
224#[derive(Debug)]
225pub struct ConnectionGuard {
226    metrics: ConnectionMetrics,
227    peer: Option<SocketAddr>,
228}
229
230impl ConnectionGuard {
231    /// Create a new connection guard that will track a connection's lifetime.
232    pub fn new(metrics: ConnectionMetrics, peer: Option<SocketAddr>) -> Self {
233        metrics.on_connect();
234        tracing::info!(
235            peer = ?peer,
236            active_connections = metrics.active_connections(),
237            total_connections = metrics.total_connections(),
238            "⚡ New connection established"
239        );
240        Self { metrics, peer }
241    }
242}
243
244impl Drop for ConnectionGuard {
245    fn drop(&mut self) {
246        self.metrics.on_disconnect();
247        tracing::info!(
248            peer = ?self.peer,
249            active_connections = self.metrics.active_connections(),
250            "🔌 Connection closed"
251        );
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads `(total, active, errors)` from `metrics` in one call.
    fn counts(metrics: &ConnectionMetrics) -> (u64, u64, u64) {
        (
            metrics.total_connections(),
            metrics.active_connections(),
            metrics.connection_errors(),
        )
    }

    #[test]
    fn test_connection_metrics() {
        let metrics = ConnectionMetrics::new();
        assert_eq!(counts(&metrics), (0, 0, 0));

        // Each connect bumps both the lifetime total and the open count.
        for expected in 1..=2 {
            metrics.on_connect();
            assert_eq!(metrics.total_connections(), expected);
            assert_eq!(metrics.active_connections(), expected);
        }

        metrics.on_disconnect();
        assert_eq!(metrics.active_connections(), 1);

        metrics.on_error();
        assert_eq!(metrics.connection_errors(), 1);
    }

    #[test]
    fn test_connection_metrics_display() {
        let metrics = ConnectionMetrics::new();
        metrics.on_connect();

        let rendered = metrics.to_string();
        for fragment in &["total=1", "active=1", "errors=0"] {
            assert!(
                rendered.contains(*fragment),
                "missing {:?} in {:?}",
                fragment,
                rendered
            );
        }
    }

    #[test]
    fn test_metrics_shared_across_clones() {
        let original = ConnectionMetrics::new();
        let copy = original.clone();

        original.on_connect();
        assert_eq!(copy.active_connections(), 1);

        copy.on_connect();
        assert_eq!(original.active_connections(), 2);
    }

    #[test]
    fn test_connection_guard_drop() {
        let metrics = ConnectionMetrics::new();

        let guard = ConnectionGuard::new(metrics.clone(), None);
        assert_eq!(metrics.active_connections(), 1);

        // Dropping the guard records the disconnect but keeps the lifetime total.
        drop(guard);
        assert_eq!(metrics.active_connections(), 0);
        assert_eq!(metrics.total_connections(), 1);
    }
}