use std::collections::{HashMap, HashSet};
use std::time::Duration;
use sqlx::PgPool;
use tokio::sync::mpsc;
/// Tuning knobs for the background client-statistics batcher.
#[derive(Clone, Debug)]
pub struct ClientStatsBatcherConfig {
    /// Bounded capacity of the worker's command channel; when full, new
    /// deltas are dropped by `try_send` rather than blocking the caller.
    pub channel_capacity: usize,
    /// How often the worker flushes accumulated deltas to Postgres.
    pub flush_interval: Duration,
    /// Flush early once roughly this many distinct accumulators are pending.
    pub flush_max_clients: usize,
}
impl Default for ClientStatsBatcherConfig {
    /// Defaults tuned for a busy service: a large channel so producers
    /// rarely hit the drop path, a sub-second flush cadence, and a cap on
    /// how many distinct accumulators may build up between flushes.
    fn default() -> Self {
        let flush_interval = Duration::from_millis(250);
        Self {
            channel_capacity: 50_000,
            flush_max_clients: 2_000,
            flush_interval,
        }
    }
}
/// Messages sent from request handlers to the batcher worker. Every field is
/// a delta to be folded into the in-memory accumulators, not an absolute value.
#[derive(Debug)]
enum BatcherCommand {
    /// Per-client HTTP request counters.
    RequestDelta {
        client_name: String,
        // Number of requests observed (always 1 from the public API).
        d_req: i64,
        // Successful / failed / cached request increments.
        d_succ: i64,
        d_fail: i64,
        d_cached: i64,
    },
    /// Per-client operation counter.
    OperationDelta {
        client_name: String,
        d_op: i64,
    },
    /// Per-(client, table, operation) counters.
    TableDelta {
        client_name: String,
        table_name: String,
        operation: String,
        d_total: i64,
        d_err: i64,
    },
    /// Mark a client as recently seen (deduped per flush window).
    LastSeen {
        client_name: String,
    },
}
/// Accumulated per-client deltas awaiting flush; all counters start at zero
/// via `Default` and are summed across commands in `BatcherState::apply`.
#[derive(Default, Clone)]
struct ClientAccum {
    d_req: i64,
    d_succ: i64,
    d_fail: i64,
    d_cached: i64,
    d_op: i64,
}
/// In-memory staging area owned by the worker task between flushes.
#[derive(Default)]
struct BatcherState {
    // Per-client request/operation accumulators, keyed by client name.
    clients: HashMap<String, ClientAccum>,
    // (client, table, operation) -> (total delta, error delta).
    tables: HashMap<(String, String, String), (i64, i64)>,
    // Clients whose last_seen_at should be bumped; HashSet dedupes repeats.
    last_seen: HashSet<String>,
}
impl BatcherState {
    /// Merge one command into the in-memory accumulators. No I/O happens
    /// here; everything is folded into the maps until the next flush.
    fn apply(&mut self, cmd: BatcherCommand) {
        match cmd {
            BatcherCommand::RequestDelta {
                client_name,
                d_req,
                d_succ,
                d_fail,
                d_cached,
            } => {
                let e = self.clients.entry(client_name).or_default();
                e.d_req += d_req;
                e.d_succ += d_succ;
                e.d_fail += d_fail;
                e.d_cached += d_cached;
            }
            BatcherCommand::OperationDelta { client_name, d_op } => {
                let e = self.clients.entry(client_name).or_default();
                e.d_op += d_op;
            }
            BatcherCommand::TableDelta {
                client_name,
                table_name,
                operation,
                d_total,
                d_err,
            } => {
                let k = (client_name, table_name, operation);
                let t = self.tables.entry(k).or_insert((0, 0));
                t.0 += d_total;
                t.1 += d_err;
            }
            BatcherCommand::LastSeen { client_name } => {
                // HashSet insert dedupes repeated sightings within one window.
                self.last_seen.insert(client_name);
            }
        }
    }
    /// Count of pending per-client plus per-table accumulators; used by the
    /// worker as the size-based flush trigger.
    /// NOTE(review): `last_seen` is deliberately (?) not counted here — the
    /// caller must account for it separately if it matters for bounding memory.
    fn len_clients(&self) -> usize {
        self.clients.len() + self.tables.len()
    }
    /// True when there is nothing at all to flush.
    fn is_empty(&self) -> bool {
        self.clients.is_empty() && self.tables.is_empty() && self.last_seen.is_empty()
    }
    /// Drain all accumulators and upsert them into Postgres.
    ///
    /// Semantics are best-effort: each map is emptied up front via
    /// `mem::take`, and a failed statement only logs — the corresponding
    /// deltas are dropped rather than retried, so errors never cause
    /// unbounded buildup.
    async fn flush_all(&mut self, pool: &PgPool) {
        if self.is_empty() {
            return;
        }
        // One UPSERT per client. $7/$8 tell the INSERT arm whether to stamp
        // last_request_at / last_operation_at; the conflict arm derives the
        // same decision from EXCLUDED totals.
        for (name, acc) in std::mem::take(&mut self.clients) {
            // The only ways an accumulator is created set d_req or d_op, so
            // an entry with both zero carries no information.
            if acc.d_req == 0 && acc.d_op == 0 {
                continue;
            }
            if let Err(err) = sqlx::query(
                r#"
INSERT INTO client_statistics (
client_name,
total_requests,
successful_requests,
failed_requests,
total_cached_requests,
total_operations,
last_request_at,
last_operation_at
)
VALUES ($1, $2, $3, $4, $5, $6,
CASE WHEN $7::boolean THEN now() ELSE NULL END,
CASE WHEN $8::boolean THEN now() ELSE NULL END
)
ON CONFLICT (client_name) DO UPDATE
SET total_requests = client_statistics.total_requests + EXCLUDED.total_requests,
successful_requests = client_statistics.successful_requests
+ EXCLUDED.successful_requests,
failed_requests = client_statistics.failed_requests + EXCLUDED.failed_requests,
total_cached_requests = client_statistics.total_cached_requests
+ EXCLUDED.total_cached_requests,
total_operations = client_statistics.total_operations + EXCLUDED.total_operations,
last_request_at = CASE
WHEN EXCLUDED.total_requests > 0 THEN now()
ELSE client_statistics.last_request_at
END,
last_operation_at = CASE
WHEN EXCLUDED.total_operations > 0 THEN now()
ELSE client_statistics.last_operation_at
END,
updated_at = now()
"#,
            )
            // Bind order must match $1..$8 above.
            .bind(&name)
            .bind(acc.d_req)
            .bind(acc.d_succ)
            .bind(acc.d_fail)
            .bind(acc.d_cached)
            .bind(acc.d_op)
            .bind(acc.d_req > 0)
            .bind(acc.d_op > 0)
            .execute(pool)
            .await
            {
                tracing::error!(
                    error = %err,
                    client = %name,
                    "client_stats_batcher: flush client_statistics failed"
                );
            }
        }
        // One UPSERT per (client, table, operation) accumulator.
        for ((client_name, table_name, operation), (d_total, d_err)) in
            std::mem::take(&mut self.tables)
        {
            if d_total == 0 && d_err == 0 {
                continue;
            }
            if let Err(err) = sqlx::query(
                r#"
INSERT INTO client_table_statistics (
client_name,
table_name,
operation,
total_operations,
error_operations,
last_operation_at
)
VALUES ($1, $2, $3, $4, $5, now())
ON CONFLICT (client_name, table_name, operation) DO UPDATE
SET total_operations = client_table_statistics.total_operations
+ EXCLUDED.total_operations,
error_operations = client_table_statistics.error_operations
+ EXCLUDED.error_operations,
last_operation_at = now(),
updated_at = now()
"#,
            )
            .bind(&client_name)
            .bind(&table_name)
            .bind(&operation)
            .bind(d_total)
            .bind(d_err)
            .execute(pool)
            .await
            {
                tracing::error!(
                    error = %err,
                    client = %client_name,
                    table = %table_name,
                    operation = %operation,
                    "client_stats_batcher: flush client_table_statistics failed"
                );
            }
        }
        // Single batched last_seen update; matching is case-insensitive on
        // both sides and skips soft-deleted clients.
        let names: Vec<String> = std::mem::take(&mut self.last_seen).into_iter().collect();
        if names.is_empty() {
            return;
        }
        if let Err(err) = sqlx::query(
            r#"
UPDATE athena_clients
SET last_seen_at = now(),
updated_at = now()
WHERE deleted_at IS NULL
AND lower(client_name) IN (SELECT lower(x) FROM unnest($1::text[]) AS t(x))
"#,
        )
        .bind(names.as_slice())
        .execute(pool)
        .await
        {
            tracing::error!(
                error = %err,
                clients = ?names,
                "client_stats_batcher: batch last_seen update failed"
            );
        }
    }
}
/// Background task: drains `rx`, merging every command into `state`, and
/// flushes to Postgres either on the periodic tick or once enough distinct
/// accumulators are pending.
///
/// The `select!` is `biased` toward the channel so merging stays cheap under
/// load — but that means a saturated channel can starve the timer branch.
/// The size-based trigger below therefore counts *all* pending state,
/// including `last_seen` (which `len_clients()` does not cover); otherwise a
/// stream consisting purely of `LastSeen` commands would never flush while
/// the tick is starved, growing the set without bound.
async fn run_worker(
    mut rx: mpsc::Receiver<BatcherCommand>,
    pool: PgPool,
    config: ClientStatsBatcherConfig,
) {
    let mut tick = tokio::time::interval(config.flush_interval);
    // If a flush overruns the interval, just delay the next tick instead of
    // bursting to catch up.
    tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    let mut state = BatcherState::default();
    loop {
        tokio::select! {
            biased;
            cmd = rx.recv() => {
                match cmd {
                    Some(c) => {
                        state.apply(c);
                        // Size-based flush: include last_seen so memory stays
                        // bounded even when the biased recv starves the tick.
                        if state.len_clients() + state.last_seen.len() >= config.flush_max_clients {
                            state.flush_all(&pool).await;
                        }
                    }
                    None => {
                        // All senders dropped: final flush, then shut down.
                        state.flush_all(&pool).await;
                        return;
                    }
                }
            }
            _ = tick.tick() => {
                state.flush_all(&pool).await;
            }
        }
    }
}
/// Cheap-to-clone handle for enqueueing statistics deltas to the background
/// worker. When every handle is dropped the worker performs a final flush
/// and exits (its `recv` returns `None`).
#[derive(Clone)]
pub struct ClientStatsBatcher {
    // Bounded channel to the worker spawned in `spawn`.
    tx: mpsc::Sender<BatcherCommand>,
}
impl ClientStatsBatcher {
pub fn spawn(pool: PgPool, config: ClientStatsBatcherConfig) -> Self {
let cap: usize = config.channel_capacity.max(1);
let (tx, rx) = mpsc::channel(cap);
tokio::spawn(run_worker(rx, pool, config));
Self { tx }
}
fn try_send(&self, cmd: BatcherCommand) {
match self.tx.try_send(cmd) {
Err(mpsc::error::TrySendError::Full(_)) => {
tracing::warn!(
target: "athena_rs::client_stats_batcher",
"client stats batcher channel full; dropping delta"
);
}
Err(mpsc::error::TrySendError::Closed(_)) => {}
Ok(()) => {}
}
}
pub fn try_enqueue_request_stats(&self, client_name: &str, status_code: i32, cached: bool) {
let d_succ: i64 = i64::from((200..400).contains(&status_code));
let d_fail: i64 = i64::from(status_code >= 400);
let d_cached: i64 = i64::from(cached);
self.try_send(BatcherCommand::RequestDelta {
client_name: client_name.to_string(),
d_req: 1,
d_succ,
d_fail,
d_cached,
});
}
pub fn try_enqueue_operation_stats(&self, client_name: &str) {
self.try_send(BatcherCommand::OperationDelta {
client_name: client_name.to_string(),
d_op: 1,
});
}
pub fn try_enqueue_table_stats(
&self,
client_name: &str,
table_name: &str,
operation: &str,
is_error: bool,
) {
self.try_send(BatcherCommand::TableDelta {
client_name: client_name.to_string(),
table_name: table_name.to_string(),
operation: operation.to_string(),
d_total: 1,
d_err: i64::from(is_error),
});
}
pub fn try_enqueue_last_seen(&self, client_name: &str) {
self.try_send(BatcherCommand::LastSeen {
client_name: client_name.to_string(),
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two request deltas plus an operation delta for the same client must
    /// fold into a single accumulator.
    #[test]
    fn merge_deltas_accumulates() {
        let mut state = BatcherState::default();
        for (succ, fail, cached) in [(1, 0, 0), (0, 1, 1)] {
            state.apply(BatcherCommand::RequestDelta {
                client_name: "c1".into(),
                d_req: 1,
                d_succ: succ,
                d_fail: fail,
                d_cached: cached,
            });
        }
        state.apply(BatcherCommand::OperationDelta {
            client_name: "c1".into(),
            d_op: 3,
        });
        let acc = state.clients.get("c1").expect("accumulator for c1");
        assert_eq!(acc.d_req, 2);
        assert_eq!(acc.d_succ, 1);
        assert_eq!(acc.d_fail, 1);
        assert_eq!(acc.d_cached, 1);
        assert_eq!(acc.d_op, 3);
    }

    /// Repeated LastSeen commands for one client collapse to a single entry.
    #[test]
    fn last_seen_dedupes_per_flush_batch() {
        let mut state = BatcherState::default();
        for _ in 0..2 {
            state.apply(BatcherCommand::LastSeen {
                client_name: "c".into(),
            });
        }
        assert_eq!(state.last_seen.len(), 1);
    }

    /// Table deltas keyed by (client, table, operation) accumulate both the
    /// total and error counters.
    #[test]
    fn table_merge_accumulates() {
        let mut state = BatcherState::default();
        for err in [0, 1] {
            state.apply(BatcherCommand::TableDelta {
                client_name: "c".into(),
                table_name: "t".into(),
                operation: "insert".into(),
                d_total: 1,
                d_err: err,
            });
        }
        let key = ("c".into(), "t".into(), "insert".into());
        assert_eq!(state.tables.get(&key), Some(&(2, 1)));
    }
}