Skip to main content

athena_observability/
client_stats.rs

1//! Batched writes for gateway-facing client statistics tables.
2//!
3//! This module is runtime-portable: it only needs a Postgres pool and Tokio.
4
5use std::collections::{HashMap, HashSet};
6use std::time::Duration;
7
8use sqlx::PgPool;
9use tokio::sync::mpsc;
10
/// Configuration for [`ClientStatsBatcher`].
#[derive(Clone, Debug)]
pub struct ClientStatsBatcherConfig {
    /// Capacity of the bounded command channel between enqueuers and the
    /// background worker. `ClientStatsBatcher::spawn` raises values below 1
    /// to 1. When the channel is full, new deltas are dropped with a warning.
    pub channel_capacity: usize,
    /// Period of the worker's timer; each tick flushes whatever has been
    /// accumulated since the previous flush.
    pub flush_interval: Duration,
    /// Flush immediately once the number of pending accumulators reaches this
    /// value. The count is per-client entries plus per-(client, table,
    /// operation) entries; pending `last_seen` touches are not counted.
    pub flush_max_clients: usize,
}
19
20impl Default for ClientStatsBatcherConfig {
21    fn default() -> Self {
22        Self {
23            channel_capacity: 50_000,
24            flush_interval: Duration::from_millis(250),
25            flush_max_clients: 2_000,
26        }
27    }
28}
29
/// Messages sent from [`ClientStatsBatcher`] handles to the background worker.
///
/// Each variant carries deltas (increments) that `BatcherState::apply` merges
/// into the pending in-memory accumulators.
#[derive(Debug)]
enum BatcherCommand {
    /// Increments for one client's request counters.
    RequestDelta {
        client_name: String,
        /// Added to the total request count.
        d_req: i64,
        /// Added to the successful request count.
        d_succ: i64,
        /// Added to the failed request count.
        d_fail: i64,
        /// Added to the cached request count.
        d_cached: i64,
    },
    /// Increment for one client's operation counter.
    OperationDelta {
        client_name: String,
        /// Added to the total operation count.
        d_op: i64,
    },
    /// Increments for one (client, table, operation) counter pair.
    TableDelta {
        client_name: String,
        table_name: String,
        operation: String,
        /// Added to the total operation count for this key.
        d_total: i64,
        /// Added to the error operation count for this key.
        d_err: i64,
    },
    /// Marks a client so its `last_seen_at` is touched on the next flush.
    LastSeen {
        client_name: String,
    },
}
54
/// Pending (not yet flushed) counter increments for a single client.
///
/// All fields start at zero and are summed from `RequestDelta` and
/// `OperationDelta` commands until the next flush.
#[derive(Default, Clone)]
struct ClientAccum {
    /// Pending increment for `total_requests`.
    d_req: i64,
    /// Pending increment for `successful_requests`.
    d_succ: i64,
    /// Pending increment for `failed_requests`.
    d_fail: i64,
    /// Pending increment for `total_cached_requests`.
    d_cached: i64,
    /// Pending increment for `total_operations`.
    d_op: i64,
}
63
/// In-memory accumulators owned by the background worker between flushes.
#[derive(Default)]
struct BatcherState {
    /// Pending per-client counter increments, keyed by client name.
    clients: HashMap<String, ClientAccum>,
    /// Pending `(total, error)` increments keyed by
    /// `(client_name, table_name, operation)`.
    tables: HashMap<(String, String, String), (i64, i64)>,
    /// Client names whose `last_seen_at` should be touched on the next flush;
    /// a set, so repeated touches within one batch collapse to one.
    last_seen: HashSet<String>,
}
70
71impl BatcherState {
72    fn apply(&mut self, cmd: BatcherCommand) {
73        match cmd {
74            BatcherCommand::RequestDelta {
75                client_name,
76                d_req,
77                d_succ,
78                d_fail,
79                d_cached,
80            } => {
81                let entry: &mut ClientAccum = self.clients.entry(client_name).or_default();
82                entry.d_req += d_req;
83                entry.d_succ += d_succ;
84                entry.d_fail += d_fail;
85                entry.d_cached += d_cached;
86            }
87            BatcherCommand::OperationDelta { client_name, d_op } => {
88                let entry: &mut ClientAccum = self.clients.entry(client_name).or_default();
89                entry.d_op += d_op;
90            }
91            BatcherCommand::TableDelta {
92                client_name,
93                table_name,
94                operation,
95                d_total,
96                d_err,
97            } => {
98                let key: (String, String, String) = (client_name, table_name, operation);
99                let totals: &mut (i64, i64) = self.tables.entry(key).or_insert((0, 0));
100                totals.0 += d_total;
101                totals.1 += d_err;
102            }
103            BatcherCommand::LastSeen { client_name } => {
104                self.last_seen.insert(client_name);
105            }
106        }
107    }
108
109    fn len_clients(&self) -> usize {
110        self.clients.len() + self.tables.len()
111    }
112
113    fn is_empty(&self) -> bool {
114        self.clients.is_empty() && self.tables.is_empty() && self.last_seen.is_empty()
115    }
116
117    async fn flush_all(&mut self, pool: &PgPool) {
118        if self.is_empty() {
119            return;
120        }
121
122        for (name, acc) in std::mem::take(&mut self.clients) {
123            if acc.d_req == 0 && acc.d_op == 0 {
124                continue;
125            }
126
127            if let Err(err) = sqlx::query(
128                r#"
129                INSERT INTO client_statistics (
130                    client_name,
131                    total_requests,
132                    successful_requests,
133                    failed_requests,
134                    total_cached_requests,
135                    total_operations,
136                    last_request_at,
137                    last_operation_at
138                )
139                VALUES ($1, $2, $3, $4, $5, $6,
140                    CASE WHEN $7::boolean THEN now() ELSE NULL END,
141                    CASE WHEN $8::boolean THEN now() ELSE NULL END
142                )
143                ON CONFLICT (client_name) DO UPDATE
144                SET total_requests = client_statistics.total_requests + EXCLUDED.total_requests,
145                    successful_requests = client_statistics.successful_requests
146                        + EXCLUDED.successful_requests,
147                    failed_requests = client_statistics.failed_requests + EXCLUDED.failed_requests,
148                    total_cached_requests = client_statistics.total_cached_requests
149                        + EXCLUDED.total_cached_requests,
150                    total_operations = client_statistics.total_operations + EXCLUDED.total_operations,
151                    last_request_at = CASE
152                        WHEN EXCLUDED.total_requests > 0 THEN now()
153                        ELSE client_statistics.last_request_at
154                    END,
155                    last_operation_at = CASE
156                        WHEN EXCLUDED.total_operations > 0 THEN now()
157                        ELSE client_statistics.last_operation_at
158                    END,
159                    updated_at = now()
160                "#,
161            )
162            .bind(&name)
163            .bind(acc.d_req)
164            .bind(acc.d_succ)
165            .bind(acc.d_fail)
166            .bind(acc.d_cached)
167            .bind(acc.d_op)
168            .bind(acc.d_req > 0)
169            .bind(acc.d_op > 0)
170            .execute(pool)
171            .await
172            {
173                tracing::error!(
174                    error = %err,
175                    client = %name,
176                    "client_stats_batcher: flush client_statistics failed"
177                );
178            }
179        }
180
181        for ((client_name, table_name, operation), (d_total, d_err)) in
182            std::mem::take(&mut self.tables)
183        {
184            if d_total == 0 && d_err == 0 {
185                continue;
186            }
187
188            if let Err(err) = sqlx::query(
189                r#"
190                INSERT INTO client_table_statistics (
191                    client_name,
192                    table_name,
193                    operation,
194                    total_operations,
195                    error_operations,
196                    last_operation_at
197                )
198                VALUES ($1, $2, $3, $4, $5, now())
199                ON CONFLICT (client_name, table_name, operation) DO UPDATE
200                SET total_operations = client_table_statistics.total_operations
201                    + EXCLUDED.total_operations,
202                    error_operations = client_table_statistics.error_operations
203                        + EXCLUDED.error_operations,
204                    last_operation_at = now(),
205                    updated_at = now()
206                "#,
207            )
208            .bind(&client_name)
209            .bind(&table_name)
210            .bind(&operation)
211            .bind(d_total)
212            .bind(d_err)
213            .execute(pool)
214            .await
215            {
216                tracing::error!(
217                    error = %err,
218                    client = %client_name,
219                    table = %table_name,
220                    operation = %operation,
221                    "client_stats_batcher: flush client_table_statistics failed"
222                );
223            }
224        }
225
226        let names: Vec<String> = std::mem::take(&mut self.last_seen).into_iter().collect();
227        if names.is_empty() {
228            return;
229        }
230
231        if let Err(err) = sqlx::query(
232            r#"
233            UPDATE athena_clients
234            SET last_seen_at = now(),
235                updated_at = now()
236            WHERE deleted_at IS NULL
237              AND lower(client_name) IN (SELECT lower(x) FROM unnest($1::text[]) AS t(x))
238            "#,
239        )
240        .bind(names.as_slice())
241        .execute(pool)
242        .await
243        {
244            tracing::error!(
245                error = %err,
246                clients = ?names,
247                "client_stats_batcher: batch last_seen update failed"
248            );
249        }
250    }
251}
252
253async fn run_worker(
254    mut rx: mpsc::Receiver<BatcherCommand>,
255    pool: PgPool,
256    config: ClientStatsBatcherConfig,
257) {
258    let mut tick: tokio::time::Interval = tokio::time::interval(config.flush_interval);
259    tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
260    let mut state: BatcherState = BatcherState::default();
261
262    loop {
263        tokio::select! {
264            biased;
265            cmd = rx.recv() => {
266                match cmd {
267                    Some(command) => {
268                        state.apply(command);
269                        if state.len_clients() >= config.flush_max_clients {
270                            state.flush_all(&pool).await;
271                        }
272                    }
273                    None => {
274                        state.flush_all(&pool).await;
275                        return;
276                    }
277                }
278            }
279            _ = tick.tick() => {
280                state.flush_all(&pool).await;
281            }
282        }
283    }
284}
285
/// Handle for enqueueing client statistics and catalog touches without blocking request paths.
///
/// Cloning the handle clones the underlying channel sender, so all clones
/// feed the same background worker. Enqueue methods never wait: if the
/// channel is full the delta is dropped and a warning is logged.
#[derive(Clone)]
pub struct ClientStatsBatcher {
    /// Sending half of the bounded command channel read by the worker task.
    tx: mpsc::Sender<BatcherCommand>,
}
291
292impl ClientStatsBatcher {
293    /// Spawns the background flusher task.
294    pub fn spawn(pool: PgPool, config: ClientStatsBatcherConfig) -> Self {
295        let cap: usize = config.channel_capacity.max(1);
296        let (tx, rx) = mpsc::channel(cap);
297        tokio::spawn(run_worker(rx, pool, config));
298        Self { tx }
299    }
300
301    fn try_send(&self, cmd: BatcherCommand) {
302        match self.tx.try_send(cmd) {
303            Err(mpsc::error::TrySendError::Full(_)) => {
304                tracing::warn!(
305                    target: "athena_rs::client_stats_batcher",
306                    "client stats batcher channel full; dropping delta"
307                );
308            }
309            Err(mpsc::error::TrySendError::Closed(_)) => {}
310            Ok(()) => {}
311        }
312    }
313
314    /// Enqueues one request-level delta.
315    pub fn try_enqueue_request_stats(&self, client_name: &str, status_code: i32, cached: bool) {
316        let d_succ: i64 = i64::from((200..400).contains(&status_code));
317        let d_fail: i64 = i64::from(status_code >= 400);
318        let d_cached: i64 = i64::from(cached);
319        self.try_send(BatcherCommand::RequestDelta {
320            client_name: client_name.to_string(),
321            d_req: 1,
322            d_succ,
323            d_fail,
324            d_cached,
325        });
326    }
327
328    /// Enqueues one operation-level delta.
329    pub fn try_enqueue_operation_stats(&self, client_name: &str) {
330        self.try_send(BatcherCommand::OperationDelta {
331            client_name: client_name.to_string(),
332            d_op: 1,
333        });
334    }
335
336    /// Enqueues one table/operation delta.
337    pub fn try_enqueue_table_stats(
338        &self,
339        client_name: &str,
340        table_name: &str,
341        operation: &str,
342        is_error: bool,
343    ) {
344        self.try_send(BatcherCommand::TableDelta {
345            client_name: client_name.to_string(),
346            table_name: table_name.to_string(),
347            operation: operation.to_string(),
348            d_total: 1,
349            d_err: i64::from(is_error),
350        });
351    }
352
353    /// Enqueues a debounced `last_seen_at` touch.
354    pub fn try_enqueue_last_seen(&self, client_name: &str) {
355        self.try_send(BatcherCommand::LastSeen {
356            client_name: client_name.to_string(),
357        });
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Request and operation deltas for the same client sum into one accumulator.
    #[test]
    fn merge_deltas_accumulates() {
        let commands = vec![
            BatcherCommand::RequestDelta {
                client_name: String::from("c1"),
                d_req: 1,
                d_succ: 1,
                d_fail: 0,
                d_cached: 0,
            },
            BatcherCommand::RequestDelta {
                client_name: String::from("c1"),
                d_req: 1,
                d_succ: 0,
                d_fail: 1,
                d_cached: 1,
            },
            BatcherCommand::OperationDelta {
                client_name: String::from("c1"),
                d_op: 3,
            },
        ];

        let mut pending = BatcherState::default();
        for command in commands {
            pending.apply(command);
        }

        let merged = pending.clients.get("c1").expect("client accumulator");
        assert_eq!(merged.d_req, 2);
        assert_eq!(merged.d_succ, 1);
        assert_eq!(merged.d_fail, 1);
        assert_eq!(merged.d_cached, 1);
        assert_eq!(merged.d_op, 3);
    }

    /// Touching the same client twice leaves a single pending entry.
    #[test]
    fn last_seen_dedupes_per_flush_batch() {
        let mut pending = BatcherState::default();
        for _ in 0..2 {
            pending.apply(BatcherCommand::LastSeen {
                client_name: String::from("c"),
            });
        }
        assert_eq!(pending.last_seen.len(), 1);
    }

    /// Table deltas with an identical key sum their totals and error counts.
    #[test]
    fn table_merge_accumulates() {
        let mut pending = BatcherState::default();
        for d_err in [0_i64, 1].iter().copied() {
            pending.apply(BatcherCommand::TableDelta {
                client_name: String::from("c"),
                table_name: String::from("t"),
                operation: String::from("insert"),
                d_total: 1,
                d_err,
            });
        }

        let key = (String::from("c"), String::from("t"), String::from("insert"));
        assert_eq!(pending.tables.get(&key).copied(), Some((2, 1)));
    }
}
426}