Skip to main content

borer_core/stream/
stats.rs

1use std::{
2    io,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9use crate::store::store;
10
11/// Stream wrapper that records per-connection traffic in the shared store.
12#[derive(Debug)]
13pub struct StatsStream<S> {
14    inner: S,
15    conn_id: String,
16}
17
18impl<S> StatsStream<S> {
19    /// Create a tracked stream and register it in the shared traffic store.
20    pub fn new(
21        inner: S,
22        user: impl Into<String>,
23        hash: impl Into<String>,
24        peer_addr: impl Into<String>,
25        req_addr: impl Into<String>,
26        padding: bool,
27    ) -> Self {
28        let conn_id = gen_conn_id();
29        let user: String = user.into();
30        let hash: String = hash.into();
31        let peer_addr: String = peer_addr.into();
32        let req_addr: String = req_addr.into();
33
34        store().insert_conn(&user, &hash, &conn_id, &peer_addr, &req_addr, padding);
35        StatsStream { inner, conn_id }
36    }
37}
38
39impl<S> Drop for StatsStream<S> {
40    fn drop(&mut self) {
41        store().delete_conn(&self.conn_id);
42    }
43}
44
45impl<S> AsyncRead for StatsStream<S>
46where
47    S: AsyncRead + Unpin,
48{
49    fn poll_read(
50        self: Pin<&mut Self>,
51        cx: &mut Context<'_>,
52        buf: &mut ReadBuf,
53    ) -> Poll<io::Result<()>> {
54        let conn_id = self.conn_id.clone();
55        let me = self.get_mut();
56        let filled_before = buf.filled().len();
57        let poll = Pin::new(&mut me.inner).poll_read(cx, buf);
58        if let Poll::Ready(Ok(())) = poll {
59            let filled_after = buf.filled().len();
60            let bytes_read = filled_after.saturating_sub(filled_before);
61            store().add_up(&conn_id, bytes_read);
62            debug!("stats => read buf: {bytes_read}");
63        }
64        poll
65    }
66}
67
68impl<S> AsyncWrite for StatsStream<S>
69where
70    S: AsyncWrite + Unpin,
71{
72    fn poll_write(
73        mut self: Pin<&mut Self>,
74        cx: &mut Context<'_>,
75        buf: &[u8],
76    ) -> Poll<io::Result<usize>> {
77        let conn_id = self.conn_id.clone();
78        debug!("stats => write buf: {}", buf.len());
79        match Pin::new(&mut self.inner).poll_write(cx, buf) {
80            Poll::Ready(Ok(bytes_written)) => {
81                store().add_down(&conn_id, bytes_written);
82                Poll::Ready(Ok(bytes_written))
83            }
84            other => other,
85        }
86    }
87
88    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
89        Pin::new(&mut self.inner).poll_flush(cx)
90    }
91
92    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
93        Pin::new(&mut self.inner).poll_shutdown(cx)
94    }
95}
96
97pub(crate) fn gen_conn_id() -> String {
98    let uuid = uuid::Uuid::new_v4().as_u128();
99    format!("{:032x}", uuid)
100}
101
#[cfg(test)]
mod tests {
    use std::{
        io,
        pin::Pin,
        sync::Arc,
        task::{Context, Poll},
        time::{SystemTime, UNIX_EPOCH},
    };

    use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};

    use super::{StatsStream, gen_conn_id};
    use crate::store::{Connection, Store, store};

    /// Nanosecond timestamp used to give each test its own user/hash keys in
    /// the shared store.
    fn unique_suffix() -> u128 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos()
    }

    /// Fetches the connections registered for `user` and asserts there is
    /// exactly one.
    fn only_conn(store: &Arc<Store>, user: &str) -> Connection {
        let mut matches = store.get_conns_by_user(user).unwrap();
        assert_eq!(matches.len(), 1);
        matches.pop().unwrap()
    }

    #[test]
    fn gen_conn_id_is_32_char_hex() {
        let id = gen_conn_id();

        assert_eq!(id.len(), 32);
        assert!(id.bytes().all(|b| b.is_ascii_hexdigit()));
        assert!(!id.bytes().any(|b| b.is_ascii_uppercase()));
    }

    #[tokio::test]
    async fn stats_stream_tracks_read_write_and_drop() {
        let store = store();
        let stamp = unique_suffix();
        let user = format!("alice-{stamp}");
        let hash = format!("hash-{stamp}");
        let (client, mut peer) = tokio::io::duplex(64);
        let mut stream = StatsStream::new(
            client,
            user.clone(),
            hash.clone(),
            "127.0.0.1:1".to_string(),
            "example.com:443".to_string(),
            false,
        );

        let registered = only_conn(&store, &user);
        assert_eq!(registered.user, user);
        assert_eq!(registered.hash, hash);

        // Wrapper -> peer goes through `poll_write` (`add_down`).
        stream.write_all(b"hello").await.unwrap();
        let mut received = [0u8; 5];
        peer.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, b"hello");

        // Peer -> wrapper goes through `poll_read` (`add_up`).
        peer.write_all(b"world").await.unwrap();
        let mut sent = [0u8; 5];
        stream.read_exact(&mut sent).await.unwrap();
        assert_eq!(&sent, b"world");

        let after_io = only_conn(&store, &registered.user);
        assert_eq!(after_io.traffic.down, 5);
        assert_eq!(after_io.traffic.up, 5);
        let totals = store.get_traffic_by_user(&after_io.user).unwrap();
        assert_eq!(totals.down, 5);
        assert_eq!(totals.up, 5);

        drop(stream);
        assert_eq!(store.get_conns_by_user(&after_io.user).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn stats_stream_counts_only_bytes_written_by_inner_stream() {
        let store = store();
        let stamp = unique_suffix();
        let user = format!("writer-{stamp}");
        let mut stream = StatsStream::new(
            PartialWriter { max_write: 2 },
            user.clone(),
            format!("hash-{stamp}"),
            "127.0.0.1:1".to_string(),
            "example.com:443".to_string(),
            false,
        );

        let accepted = stream.write(b"hello").await.unwrap();

        assert_eq!(accepted, 2);
        assert_eq!(only_conn(&store, &user).traffic.down, 2);
    }

    /// Writer that accepts at most `max_write` bytes per call, to exercise
    /// partial-write accounting.
    struct PartialWriter {
        max_write: usize,
    }

    impl AsyncWrite for PartialWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let accepted = std::cmp::min(buf.len(), self.max_write);
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }
}