borer_core/stream/
stats.rs1use std::{
2 io,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9use crate::store::store;
10
/// Stream wrapper that delegates all I/O to an inner transport while
/// reporting per-connection traffic to the connection store returned by
/// `store()`.
///
/// The connection is registered when the wrapper is built via
/// [`StatsStream::new`] and unregistered when the wrapper is dropped.
#[derive(Debug)]
pub struct StatsStream<S> {
    /// Underlying transport that every read/write/flush/shutdown is forwarded to.
    inner: S,
    /// Key under which this connection was registered in the store.
    conn_id: String,
}
17
18impl<S> StatsStream<S> {
19 pub fn new(
21 inner: S,
22 user: impl Into<String>,
23 hash: impl Into<String>,
24 peer_addr: impl Into<String>,
25 req_addr: impl Into<String>,
26 padding: bool,
27 ) -> Self {
28 let conn_id = gen_conn_id();
29 let user: String = user.into();
30 let hash: String = hash.into();
31 let peer_addr: String = peer_addr.into();
32 let req_addr: String = req_addr.into();
33
34 store().insert_conn(&user, &hash, &conn_id, &peer_addr, &req_addr, padding);
35 StatsStream { inner, conn_id }
36 }
37}
38
39impl<S> Drop for StatsStream<S> {
40 fn drop(&mut self) {
41 store().delete_conn(&self.conn_id);
42 }
43}
44
45impl<S> AsyncRead for StatsStream<S>
46where
47 S: AsyncRead + Unpin,
48{
49 fn poll_read(
50 self: Pin<&mut Self>,
51 cx: &mut Context<'_>,
52 buf: &mut ReadBuf,
53 ) -> Poll<io::Result<()>> {
54 let conn_id = self.conn_id.clone();
55 let me = self.get_mut();
56 let filled_before = buf.filled().len();
57 let poll = Pin::new(&mut me.inner).poll_read(cx, buf);
58 if let Poll::Ready(Ok(())) = poll {
59 let filled_after = buf.filled().len();
60 let bytes_read = filled_after.saturating_sub(filled_before);
61 store().add_up(&conn_id, bytes_read);
62 debug!("stats => read buf: {bytes_read}");
63 }
64 poll
65 }
66}
67
68impl<S> AsyncWrite for StatsStream<S>
69where
70 S: AsyncWrite + Unpin,
71{
72 fn poll_write(
73 mut self: Pin<&mut Self>,
74 cx: &mut Context<'_>,
75 buf: &[u8],
76 ) -> Poll<io::Result<usize>> {
77 let conn_id = self.conn_id.clone();
78 debug!("stats => write buf: {}", buf.len());
79 match Pin::new(&mut self.inner).poll_write(cx, buf) {
80 Poll::Ready(Ok(bytes_written)) => {
81 store().add_down(&conn_id, bytes_written);
82 Poll::Ready(Ok(bytes_written))
83 }
84 other => other,
85 }
86 }
87
88 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
89 Pin::new(&mut self.inner).poll_flush(cx)
90 }
91
92 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
93 Pin::new(&mut self.inner).poll_shutdown(cx)
94 }
95}
96
97pub(crate) fn gen_conn_id() -> String {
98 let uuid = uuid::Uuid::new_v4().as_u128();
99 format!("{:032x}", uuid)
100}
101
#[cfg(test)]
mod tests {
    use std::{
        io,
        pin::Pin,
        task::{Context, Poll},
        time::{SystemTime, UNIX_EPOCH},
    };

    use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};

    use super::{StatsStream, gen_conn_id};
    use crate::store::store;

    #[test]
    fn gen_conn_id_is_32_char_hex() {
        let id = gen_conn_id();

        assert_eq!(id.len(), 32);
        assert!(id.bytes().all(|b| b.is_ascii_hexdigit()));
        assert!(!id.bytes().any(|b| b.is_ascii_uppercase()));
    }

    #[tokio::test]
    async fn stats_stream_tracks_read_write_and_drop() {
        let db = store();
        let tag = unique_suffix();
        let user = format!("alice-{tag}");
        let hash = format!("hash-{tag}");
        let (client, mut peer) = tokio::io::duplex(64);
        let mut stream = StatsStream::new(
            client,
            user.clone(),
            hash.clone(),
            "127.0.0.1:1".to_string(),
            "example.com:443".to_string(),
            false,
        );

        // Registration happens eagerly in `new`.
        let registered = only_conn(&db, &user);
        assert_eq!(registered.user, user);
        assert_eq!(registered.hash, hash);

        // Writing through the wrapper is accounted as "down" traffic.
        stream.write_all(b"hello").await.unwrap();
        let mut received = [0u8; 5];
        peer.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, b"hello");

        // Reading through the wrapper is accounted as "up" traffic.
        peer.write_all(b"world").await.unwrap();
        let mut sent = [0u8; 5];
        stream.read_exact(&mut sent).await.unwrap();
        assert_eq!(&sent, b"world");

        let updated = only_conn(&db, &registered.user);
        assert_eq!(updated.traffic.down, 5);
        assert_eq!(updated.traffic.up, 5);
        let totals = db.get_traffic_by_user(&updated.user).unwrap();
        assert_eq!(totals.down, 5);
        assert_eq!(totals.up, 5);

        // Dropping the wrapper must unregister the connection.
        drop(stream);
        assert_eq!(db.get_conns_by_user(&updated.user).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn stats_stream_counts_only_bytes_written_by_inner_stream() {
        let db = store();
        let tag = unique_suffix();
        let user = format!("writer-{tag}");
        let hash = format!("hash-{tag}");
        let mut stream = StatsStream::new(
            PartialWriter { max_write: 2 },
            user.clone(),
            hash,
            "127.0.0.1:1".to_string(),
            "example.com:443".to_string(),
            false,
        );

        let accepted = stream.write(b"hello").await.unwrap();

        // Only the 2 bytes the inner writer accepted should be counted.
        assert_eq!(accepted, 2);
        assert_eq!(only_conn(&db, &user).traffic.down, 2);
    }

    /// Nanosecond timestamp used to give each test its own user/hash names.
    /// The store appears to be shared between tests, so distinct names keep
    /// their connections apart.
    fn unique_suffix() -> u128 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos()
    }

    /// Returns the single connection registered for `user`, asserting that
    /// exactly one exists.
    fn only_conn(
        db: &std::sync::Arc<crate::store::Store>,
        user: &str,
    ) -> crate::store::Connection {
        let mut matches = db.get_conns_by_user(user).unwrap();
        assert_eq!(matches.len(), 1);
        matches.pop().unwrap()
    }

    /// Writer that accepts at most `max_write` bytes per call, used to check
    /// that short writes are counted by what was written, not what was offered.
    struct PartialWriter {
        max_write: usize,
    }

    impl AsyncWrite for PartialWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let accepted = std::cmp::min(buf.len(), self.max_write);
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }
}