1use std::collections::{HashMap, HashSet};
6use std::time::Duration;
7
8use sqlx::PgPool;
9use tokio::sync::mpsc;
10
/// Tuning knobs for the background client-statistics batcher.
#[derive(Debug, Clone)]
pub struct ClientStatsBatcherConfig {
    /// Capacity of the bounded command channel between handles and the worker.
    /// `ClientStatsBatcher::spawn` raises values below 1 to 1.
    pub channel_capacity: usize,
    /// Period of the worker's timer-driven flush.
    pub flush_interval: Duration,
    /// Pending-entry threshold: once the number of per-client accumulators plus
    /// per-(client, table, operation) accumulators reaches this value, the
    /// worker flushes right away instead of waiting for the timer.
    pub flush_max_clients: usize,
}

impl Default for ClientStatsBatcherConfig {
    /// Stock settings: a 50 000-slot channel, a 250 ms flush period and an
    /// early flush once 2 000 entries are pending.
    fn default() -> Self {
        const CHANNEL_CAPACITY: usize = 50_000;
        const FLUSH_INTERVAL: Duration = Duration::from_millis(250);
        const FLUSH_MAX_CLIENTS: usize = 2_000;

        ClientStatsBatcherConfig {
            channel_capacity: CHANNEL_CAPACITY,
            flush_interval: FLUSH_INTERVAL,
            flush_max_clients: FLUSH_MAX_CLIENTS,
        }
    }
}
29
/// Message carried over the batcher channel from a `ClientStatsBatcher`
/// handle to the background worker.
///
/// Every variant except `LastSeen` holds increments ("deltas") that
/// `BatcherState::apply` adds to the pending in-memory totals.
#[derive(Debug)]
enum BatcherCommand {
    /// Request-counter increments for one client (flushed to `client_statistics`).
    RequestDelta {
        client_name: String,
        // Added to `total_requests`.
        d_req: i64,
        // Added to `successful_requests`.
        d_succ: i64,
        // Added to `failed_requests`.
        d_fail: i64,
        // Added to `total_cached_requests`.
        d_cached: i64,
    },
    /// Operation-counter increment for one client (flushed to `client_statistics`).
    OperationDelta {
        client_name: String,
        // Added to `total_operations`.
        d_op: i64,
    },
    /// Per-(client, table, operation) increments (flushed to `client_table_statistics`).
    TableDelta {
        client_name: String,
        table_name: String,
        operation: String,
        // Added to `total_operations`.
        d_total: i64,
        // Added to `error_operations`.
        d_err: i64,
    },
    /// Marks a client as seen; the next flush stamps `athena_clients.last_seen_at`.
    LastSeen {
        client_name: String,
    },
}
54
/// Pending per-client counter increments that have not yet been written to
/// `client_statistics`. All fields start at zero.
#[derive(Default, Clone)]
struct ClientAccum {
    // Pending increment for `total_requests`.
    d_req: i64,
    // Pending increment for `successful_requests`.
    d_succ: i64,
    // Pending increment for `failed_requests`.
    d_fail: i64,
    // Pending increment for `total_cached_requests`.
    d_cached: i64,
    // Pending increment for `total_operations`.
    d_op: i64,
}
63
/// Everything the worker has accumulated since the last flush.
#[derive(Default)]
struct BatcherState {
    // Pending request/operation increments, keyed by client name.
    clients: HashMap<String, ClientAccum>,
    // Pending `(total, error)` increments, keyed by `(client_name, table_name, operation)`.
    tables: HashMap<(String, String, String), (i64, i64)>,
    // Client names whose `last_seen_at` should be stamped on the next flush;
    // a set, so repeated sightings within one batch collapse to one entry.
    last_seen: HashSet<String>,
}
70
71impl BatcherState {
72 fn apply(&mut self, cmd: BatcherCommand) {
73 match cmd {
74 BatcherCommand::RequestDelta {
75 client_name,
76 d_req,
77 d_succ,
78 d_fail,
79 d_cached,
80 } => {
81 let entry: &mut ClientAccum = self.clients.entry(client_name).or_default();
82 entry.d_req += d_req;
83 entry.d_succ += d_succ;
84 entry.d_fail += d_fail;
85 entry.d_cached += d_cached;
86 }
87 BatcherCommand::OperationDelta { client_name, d_op } => {
88 let entry: &mut ClientAccum = self.clients.entry(client_name).or_default();
89 entry.d_op += d_op;
90 }
91 BatcherCommand::TableDelta {
92 client_name,
93 table_name,
94 operation,
95 d_total,
96 d_err,
97 } => {
98 let key: (String, String, String) = (client_name, table_name, operation);
99 let totals: &mut (i64, i64) = self.tables.entry(key).or_insert((0, 0));
100 totals.0 += d_total;
101 totals.1 += d_err;
102 }
103 BatcherCommand::LastSeen { client_name } => {
104 self.last_seen.insert(client_name);
105 }
106 }
107 }
108
109 fn len_clients(&self) -> usize {
110 self.clients.len() + self.tables.len()
111 }
112
113 fn is_empty(&self) -> bool {
114 self.clients.is_empty() && self.tables.is_empty() && self.last_seen.is_empty()
115 }
116
117 async fn flush_all(&mut self, pool: &PgPool) {
118 if self.is_empty() {
119 return;
120 }
121
122 for (name, acc) in std::mem::take(&mut self.clients) {
123 if acc.d_req == 0 && acc.d_op == 0 {
124 continue;
125 }
126
127 if let Err(err) = sqlx::query(
128 r#"
129 INSERT INTO client_statistics (
130 client_name,
131 total_requests,
132 successful_requests,
133 failed_requests,
134 total_cached_requests,
135 total_operations,
136 last_request_at,
137 last_operation_at
138 )
139 VALUES ($1, $2, $3, $4, $5, $6,
140 CASE WHEN $7::boolean THEN now() ELSE NULL END,
141 CASE WHEN $8::boolean THEN now() ELSE NULL END
142 )
143 ON CONFLICT (client_name) DO UPDATE
144 SET total_requests = client_statistics.total_requests + EXCLUDED.total_requests,
145 successful_requests = client_statistics.successful_requests
146 + EXCLUDED.successful_requests,
147 failed_requests = client_statistics.failed_requests + EXCLUDED.failed_requests,
148 total_cached_requests = client_statistics.total_cached_requests
149 + EXCLUDED.total_cached_requests,
150 total_operations = client_statistics.total_operations + EXCLUDED.total_operations,
151 last_request_at = CASE
152 WHEN EXCLUDED.total_requests > 0 THEN now()
153 ELSE client_statistics.last_request_at
154 END,
155 last_operation_at = CASE
156 WHEN EXCLUDED.total_operations > 0 THEN now()
157 ELSE client_statistics.last_operation_at
158 END,
159 updated_at = now()
160 "#,
161 )
162 .bind(&name)
163 .bind(acc.d_req)
164 .bind(acc.d_succ)
165 .bind(acc.d_fail)
166 .bind(acc.d_cached)
167 .bind(acc.d_op)
168 .bind(acc.d_req > 0)
169 .bind(acc.d_op > 0)
170 .execute(pool)
171 .await
172 {
173 tracing::error!(
174 error = %err,
175 client = %name,
176 "client_stats_batcher: flush client_statistics failed"
177 );
178 }
179 }
180
181 for ((client_name, table_name, operation), (d_total, d_err)) in
182 std::mem::take(&mut self.tables)
183 {
184 if d_total == 0 && d_err == 0 {
185 continue;
186 }
187
188 if let Err(err) = sqlx::query(
189 r#"
190 INSERT INTO client_table_statistics (
191 client_name,
192 table_name,
193 operation,
194 total_operations,
195 error_operations,
196 last_operation_at
197 )
198 VALUES ($1, $2, $3, $4, $5, now())
199 ON CONFLICT (client_name, table_name, operation) DO UPDATE
200 SET total_operations = client_table_statistics.total_operations
201 + EXCLUDED.total_operations,
202 error_operations = client_table_statistics.error_operations
203 + EXCLUDED.error_operations,
204 last_operation_at = now(),
205 updated_at = now()
206 "#,
207 )
208 .bind(&client_name)
209 .bind(&table_name)
210 .bind(&operation)
211 .bind(d_total)
212 .bind(d_err)
213 .execute(pool)
214 .await
215 {
216 tracing::error!(
217 error = %err,
218 client = %client_name,
219 table = %table_name,
220 operation = %operation,
221 "client_stats_batcher: flush client_table_statistics failed"
222 );
223 }
224 }
225
226 let names: Vec<String> = std::mem::take(&mut self.last_seen).into_iter().collect();
227 if names.is_empty() {
228 return;
229 }
230
231 if let Err(err) = sqlx::query(
232 r#"
233 UPDATE athena_clients
234 SET last_seen_at = now(),
235 updated_at = now()
236 WHERE deleted_at IS NULL
237 AND lower(client_name) IN (SELECT lower(x) FROM unnest($1::text[]) AS t(x))
238 "#,
239 )
240 .bind(names.as_slice())
241 .execute(pool)
242 .await
243 {
244 tracing::error!(
245 error = %err,
246 clients = ?names,
247 "client_stats_batcher: batch last_seen update failed"
248 );
249 }
250 }
251}
252
253async fn run_worker(
254 mut rx: mpsc::Receiver<BatcherCommand>,
255 pool: PgPool,
256 config: ClientStatsBatcherConfig,
257) {
258 let mut tick: tokio::time::Interval = tokio::time::interval(config.flush_interval);
259 tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
260 let mut state: BatcherState = BatcherState::default();
261
262 loop {
263 tokio::select! {
264 biased;
265 cmd = rx.recv() => {
266 match cmd {
267 Some(command) => {
268 state.apply(command);
269 if state.len_clients() >= config.flush_max_clients {
270 state.flush_all(&pool).await;
271 }
272 }
273 None => {
274 state.flush_all(&pool).await;
275 return;
276 }
277 }
278 }
279 _ = tick.tick() => {
280 state.flush_all(&pool).await;
281 }
282 }
283 }
284}
285
/// Cloneable handle used to enqueue statistics for the background worker.
///
/// It only holds the sending half of the command channel; the worker is
/// started by `ClientStatsBatcher::spawn`.
#[derive(Clone)]
pub struct ClientStatsBatcher {
    // Sending half of the bounded command channel read by `run_worker`.
    tx: mpsc::Sender<BatcherCommand>,
}
291
292impl ClientStatsBatcher {
293 pub fn spawn(pool: PgPool, config: ClientStatsBatcherConfig) -> Self {
295 let cap: usize = config.channel_capacity.max(1);
296 let (tx, rx) = mpsc::channel(cap);
297 tokio::spawn(run_worker(rx, pool, config));
298 Self { tx }
299 }
300
301 fn try_send(&self, cmd: BatcherCommand) {
302 match self.tx.try_send(cmd) {
303 Err(mpsc::error::TrySendError::Full(_)) => {
304 tracing::warn!(
305 target: "athena_rs::client_stats_batcher",
306 "client stats batcher channel full; dropping delta"
307 );
308 }
309 Err(mpsc::error::TrySendError::Closed(_)) => {}
310 Ok(()) => {}
311 }
312 }
313
314 pub fn try_enqueue_request_stats(&self, client_name: &str, status_code: i32, cached: bool) {
316 let d_succ: i64 = i64::from((200..400).contains(&status_code));
317 let d_fail: i64 = i64::from(status_code >= 400);
318 let d_cached: i64 = i64::from(cached);
319 self.try_send(BatcherCommand::RequestDelta {
320 client_name: client_name.to_string(),
321 d_req: 1,
322 d_succ,
323 d_fail,
324 d_cached,
325 });
326 }
327
328 pub fn try_enqueue_operation_stats(&self, client_name: &str) {
330 self.try_send(BatcherCommand::OperationDelta {
331 client_name: client_name.to_string(),
332 d_op: 1,
333 });
334 }
335
336 pub fn try_enqueue_table_stats(
338 &self,
339 client_name: &str,
340 table_name: &str,
341 operation: &str,
342 is_error: bool,
343 ) {
344 self.try_send(BatcherCommand::TableDelta {
345 client_name: client_name.to_string(),
346 table_name: table_name.to_string(),
347 operation: operation.to_string(),
348 d_total: 1,
349 d_err: i64::from(is_error),
350 });
351 }
352
353 pub fn try_enqueue_last_seen(&self, client_name: &str) {
355 self.try_send(BatcherCommand::LastSeen {
356 client_name: client_name.to_string(),
357 });
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-request delta for `client` with the given success/fail/cached flags.
    fn request(client: &str, d_succ: i64, d_fail: i64, d_cached: i64) -> BatcherCommand {
        BatcherCommand::RequestDelta {
            client_name: client.to_string(),
            d_req: 1,
            d_succ,
            d_fail,
            d_cached,
        }
    }

    /// Single `insert` on table `t` by client `c`, with the given error delta.
    fn insert_on_t(d_err: i64) -> BatcherCommand {
        BatcherCommand::TableDelta {
            client_name: "c".to_string(),
            table_name: "t".to_string(),
            operation: "insert".to_string(),
            d_total: 1,
            d_err,
        }
    }

    // Request and operation deltas for the same client land in one accumulator.
    #[test]
    fn merge_deltas_accumulates() {
        let mut state = BatcherState::default();
        let commands = [
            request("c1", 1, 0, 0),
            request("c1", 0, 1, 1),
            BatcherCommand::OperationDelta {
                client_name: "c1".to_string(),
                d_op: 3,
            },
        ];
        for cmd in commands {
            state.apply(cmd);
        }

        let acc = state.clients.get("c1").expect("client accumulator");
        let observed = (acc.d_req, acc.d_succ, acc.d_fail, acc.d_cached, acc.d_op);
        assert_eq!(observed, (2, 1, 1, 1, 3));
    }

    // Seeing the same client twice before a flush yields a single entry.
    #[test]
    fn last_seen_dedupes_per_flush_batch() {
        let mut state = BatcherState::default();
        for _ in 0..2 {
            state.apply(BatcherCommand::LastSeen {
                client_name: "c".to_string(),
            });
        }
        assert_eq!(state.last_seen.len(), 1);
    }

    // Table deltas with the same (client, table, operation) key are summed.
    #[test]
    fn table_merge_accumulates() {
        let mut state = BatcherState::default();
        state.apply(insert_on_t(0));
        state.apply(insert_on_t(1));

        let key: (String, String, String) =
            ("c".to_string(), "t".to_string(), "insert".to_string());
        assert_eq!(state.tables.get(&key), Some(&(2, 1)));
    }
}