1use std::{
2 collections::{HashMap, HashSet},
3 ops::{Deref, DerefMut},
4 sync::Arc,
5 time::{SystemTime, UNIX_EPOCH},
6};
7
8use dashmap::DashMap;
9use once_cell::sync::Lazy;
10use serde::Serialize;
11
12static STORE: Lazy<Arc<Store>> = Lazy::new(Store::arc);
13
14pub fn store() -> Arc<Store> {
16 STORE.clone()
17}
18
/// In-memory index of users, connections and traffic.
///
/// Every map is a `DashMap`, so all operations take `&self` and one instance
/// can be shared through an `Arc` (see `Store::arc`).
#[derive(Debug)]
pub struct Store {
    /// user -> set of hashes that user has registered connections with.
    user_map: DashMap<String, HashSet<String>>,
    /// hash -> traffic accumulated by every connection made with that hash.
    /// Entries are not touched when a connection is deleted.
    hash_traffic_map: DashMap<String, Traffic>,
    /// hash -> ids of the connections registered under that hash.
    conns_map: DashMap<String, HashSet<String>>,
    /// connection id -> full connection record (including its own traffic).
    conn_map: DashMap<String, Connection>,
}
27
/// Pair of traffic counters; the store only ever increases them.
///
/// NOTE(review): units are whatever callers pass to `add_up`/`add_down`
/// (presumably bytes — confirm at the call sites).
#[derive(Clone, Debug, Default, Serialize)]
pub struct Traffic {
    /// Upstream counter, incremented by `add_up`.
    pub up: usize,
    /// Downstream counter, incremented by `add_down`.
    pub down: usize,
}
34
/// Per-user summary produced by `Store::get_traffic_all`.
#[derive(Clone, Debug, Default, Serialize)]
pub struct UserStats {
    /// Upstream traffic summed over all of the user's hashes.
    pub up: usize,
    /// Downstream traffic summed over all of the user's hashes.
    pub down: usize,
    /// Number of the user's connections currently present in the store.
    pub conn_count: usize,
}
42
/// A single registered connection, as recorded by `Store::insert_conn`.
#[derive(Clone, Debug, Serialize)]
pub struct Connection {
    /// Caller-supplied id; key of the store's connection map (re-inserting the
    /// same id replaces the earlier record).
    pub conn_id: String,
    /// User the connection belongs to.
    pub user: String,
    /// Hash the connection was made with; traffic is also aggregated per hash.
    pub hash: String,
    /// Caller-supplied address string — presumably the client's peer address (confirm).
    pub peer_addr: String,
    /// Caller-supplied address string — presumably the requested target (confirm).
    pub req_addr: String,
    /// Traffic of this connection alone.
    pub traffic: Traffic,
    /// Caller-supplied flag; the store records it but never interprets it.
    pub padding: bool,
    /// Unix timestamp in seconds taken when the connection was inserted.
    /// NOTE(review): the field name is part of the serialized output, so fixing
    /// the spelling (`created_at`) would change what consumers see.
    pub create_at: u64,
}
55
56impl Store {
57 pub fn arc() -> Arc<Self> {
58 Arc::new(Store {
59 user_map: DashMap::new(),
60 hash_traffic_map: DashMap::new(),
61 conns_map: DashMap::new(),
62 conn_map: DashMap::new(),
63 })
64 }
65
66 pub fn get_traffic_all(&self) -> HashMap<String, UserStats> {
67 self.user_map
68 .iter()
69 .filter_map(|m| {
70 let user = m.key().clone();
71 let conns = self.get_conns_by_user(&user);
72 let conn_count = conns.map(|v| v.len()).unwrap_or(0);
73 self.get_traffic_by_user(&user).map(|t| {
74 (
75 user,
76 UserStats {
77 up: t.up,
78 down: t.down,
79 conn_count,
80 },
81 )
82 })
83 })
84 .collect()
85 }
86
87 pub fn get_conns_all(&self) -> HashMap<String, Vec<Connection>> {
88 self.user_map
89 .iter()
90 .filter_map(|m| {
91 let user = m.key().clone();
92 self.get_conns_by_user(&user).map(|conns| (user, conns))
93 })
94 .collect()
95 }
96
97 pub fn get_traffic_by_user(&self, user: &str) -> Option<Traffic> {
98 if let Some(hashs) = self.user_map.get(user) {
99 let (up, down) = hashs
100 .iter()
101 .map(|h| self.get_traffic_by_hash(h))
102 .filter(|t| t.is_some())
103 .map(|t| t.map(|t| (t.get_up(), t.get_down())).unwrap_or((0, 0)))
104 .reduce(|(up1, down1), (up2, down2)| (up1 + up2, down1 + down2))
105 .unwrap_or((0, 0));
106 Some(Traffic { up, down })
107 } else {
108 None
109 }
110 }
111
112 pub fn get_traffic_by_hash(&self, hash: &str) -> Option<Traffic> {
113 self.hash_traffic_map.get(hash).map(|t| t.deref().clone())
114 }
115
116 pub fn get_conns_by_user(&self, user: &str) -> Option<Vec<Connection>> {
117 if let Some(hashs) = self.user_map.get(user) {
118 let conns = hashs
119 .iter()
120 .filter_map(|h| self.get_conns_by_hash(h))
121 .flatten()
122 .collect();
123 Some(conns)
124 } else {
125 None
126 }
127 }
128
129 pub fn get_conns_by_hash(&self, hash: &str) -> Option<Vec<Connection>> {
130 if let Some(conn_ids) = self.conns_map.get(hash) {
131 let conns = conn_ids
132 .iter()
133 .filter_map(|conn_id| self.conn_map.get(conn_id))
134 .map(|c| c.clone())
135 .collect();
136 Some(conns)
137 } else {
138 None
139 }
140 }
141
142 pub fn insert_conn<S>(
143 &self,
144 user: S,
145 hash: S,
146 conn_id: S,
147 peer_addr: S,
148 req_addr: S,
149 padding: bool,
150 ) where
151 S: Into<String>,
152 {
153 let user = user.into();
154 let hash = hash.into();
155 let conn_id = conn_id.into();
156 let timestamp = SystemTime::now()
157 .duration_since(UNIX_EPOCH)
158 .expect("system time failed!")
159 .as_secs();
160 let conn = Connection {
161 user: user.to_string(),
162 hash: hash.to_string(),
163 conn_id: conn_id.to_string(),
164 peer_addr: peer_addr.into(),
165 req_addr: req_addr.into(),
166 traffic: Default::default(),
167 padding,
168 create_at: timestamp,
169 };
170 self.insert_hash(&user, &hash);
171 self.insert_conns(&hash, &conn_id);
172 self.conn_map.insert(conn_id, conn);
173 }
174
175 fn insert_hash(&self, user: &str, hash: &str) {
176 if let Some(mut v) = self.user_map.get_mut(user) {
177 v.insert(hash.to_string());
178 } else {
179 self.user_map
180 .insert(user.to_string(), HashSet::from_iter([hash.to_string()]));
181 }
182 }
183
184 fn insert_conns(&self, hash: &str, conn_id: &str) {
185 if let Some(mut conns) = self.conns_map.get_mut(hash) {
186 conns.insert(conn_id.to_string());
187 } else {
188 self.conns_map
189 .insert(hash.to_string(), HashSet::from_iter([conn_id.to_string()]));
190 }
191 }
192
193 pub fn delete_conn(&self, conn_id: &str) {
194 if let Some((_, conn)) = self.conn_map.remove(conn_id) {
195 if let Some(mut conns) = self.conns_map.get_mut(&conn.hash) {
196 conns.retain(|v| v != conn_id);
197 }
198 }
199 }
200
201 pub fn add_up(&self, conn_id: &str, v: usize) {
202 if let Some(mut conn) = self.conn_map.get_mut(conn_id) {
203 let hash = &conn.hash;
204 self.add_up_by_hash(hash, v);
205 conn.deref_mut().traffic.add_up(v);
206 }
207 }
208
209 pub fn add_down(&self, conn_id: &str, v: usize) {
210 if let Some(mut conn) = self.conn_map.get_mut(conn_id) {
211 let hash = &conn.hash;
212 self.add_donw_by_hash(hash, v);
213 conn.deref_mut().traffic.add_down(v);
214 }
215 }
216
217 fn add_up_by_hash(&self, hash: &str, v: usize) {
218 if let Some(mut t) = self.hash_traffic_map.get_mut(hash) {
219 t.add_up(v);
220 } else {
221 self.hash_traffic_map
222 .insert(hash.to_string(), Traffic { up: v, down: 0 });
223 }
224 }
225
226 fn add_donw_by_hash(&self, hash: &str, v: usize) {
227 if let Some(mut t) = self.hash_traffic_map.get_mut(hash) {
228 t.add_down(v);
229 } else {
230 self.hash_traffic_map
231 .insert(hash.to_string(), Traffic { up: 0, down: v });
232 }
233 }
234}
235
236impl Traffic {
237 pub fn get_up(&self) -> usize {
238 self.up
239 }
240 pub fn get_down(&self) -> usize {
241 self.down
242 }
243 pub fn add_up(&mut self, v: usize) {
244 self.up += v
245 }
246 pub fn add_down(&mut self, v: usize) {
247 self.down += v
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::Store;

    /// Traffic of several hashes is summed per user; conn_count sees both conns.
    #[test]
    fn insert_and_aggregate_traffic_by_user() {
        let store = Store::arc();
        store.insert_conn(
            "alice",
            "hash-a",
            "conn-1",
            "127.0.0.1:1",
            "example.com:443",
            false,
        );
        store.insert_conn(
            "alice",
            "hash-b",
            "conn-2",
            "127.0.0.1:2",
            "example.com:80",
            true,
        );
        store.add_up("conn-1", 5);
        store.add_down("conn-1", 7);
        store.add_up("conn-2", 11);

        let traffic = store.get_traffic_by_user("alice").unwrap();
        let all = store.get_traffic_all();

        assert_eq!(traffic.up, 16);
        assert_eq!(traffic.down, 7);
        assert_eq!(all.get("alice").unwrap().conn_count, 2);
    }

    /// Deleting a connection removes it from the per-hash view only.
    #[test]
    fn get_conns_by_hash_and_delete_conn_updates_views() {
        let store = Store::arc();
        store.insert_conn(
            "alice",
            "hash-a",
            "conn-1",
            "127.0.0.1:1",
            "example.com:443",
            false,
        );
        store.insert_conn(
            "alice",
            "hash-a",
            "conn-2",
            "127.0.0.1:2",
            "example.com:80",
            true,
        );

        let before = store.get_conns_by_hash("hash-a").unwrap();
        store.delete_conn("conn-1");
        let after = store.get_conns_by_hash("hash-a").unwrap();

        assert_eq!(before.len(), 2);
        assert_eq!(after.len(), 1);
        assert_eq!(after[0].conn_id, "conn-2");
    }

    #[test]
    fn unknown_user_returns_none() {
        let store = Store::arc();

        assert!(store.get_traffic_by_user("missing").is_none());
        assert!(store.get_conns_by_user("missing").is_none());
    }

    /// Per-hash traffic and the all-users connection view reflect updates.
    #[test]
    fn traffic_by_hash_and_conns_all() {
        let store = Store::arc();
        store.insert_conn(
            "alice",
            "hash-a",
            "conn-1",
            "127.0.0.1:1",
            "example.com:443",
            false,
        );
        store.insert_conn(
            "bob",
            "hash-b",
            "conn-2",
            "127.0.0.1:2",
            "example.com:80",
            true,
        );
        store.add_up("conn-1", 3);
        store.add_down("conn-2", 9);

        let a = store.get_traffic_by_hash("hash-a").unwrap();
        let b = store.get_traffic_by_hash("hash-b").unwrap();
        assert_eq!((a.up, a.down), (3, 0));
        assert_eq!((b.up, b.down), (0, 9));
        assert!(store.get_traffic_by_hash("hash-c").is_none());

        let all = store.get_conns_all();
        assert_eq!(all.len(), 2);
        assert_eq!(all["alice"][0].traffic.up, 3);
        assert_eq!(all["bob"][0].traffic.down, 9);
        assert!(all["bob"][0].padding);
    }

    /// Operations on ids that were never inserted leave the store untouched.
    #[test]
    fn unknown_conn_operations_are_noops() {
        let store = Store::arc();
        store.add_up("missing", 3);
        store.add_down("missing", 4);
        store.delete_conn("missing");

        assert!(store.get_traffic_all().is_empty());
        assert!(store.get_conns_all().is_empty());
        assert!(store.get_traffic_by_hash("missing").is_none());
    }

    /// Traffic stays attributed to the user after its connection is deleted.
    #[test]
    fn deleted_conn_keeps_hash_traffic_but_drops_count() {
        let store = Store::arc();
        store.insert_conn(
            "alice",
            "hash-a",
            "conn-1",
            "127.0.0.1:1",
            "example.com:443",
            false,
        );
        store.add_up("conn-1", 4);
        store.add_down("conn-1", 6);
        store.delete_conn("conn-1");

        let all = store.get_traffic_all();
        let alice = all.get("alice").unwrap();
        assert_eq!((alice.up, alice.down, alice.conn_count), (4, 6, 0));
        assert!(store.get_conns_by_user("alice").unwrap().is_empty());
    }
}