use std::{
collections::HashMap,
ops::Deref,
sync::{
Arc, Mutex,
atomic::{AtomicI64, Ordering},
},
time::Duration,
};
use dashmap::DashMap;
use redis::{Script, aio::ConnectionManager};
use tokio::time::Instant;
use crate::{
ActivityTracker, CounterComparator, DistkitError, DistkitRedisKey, EPOCH_CHANGE_INTERVAL,
RedisKeyGenerator, RedisKeyGeneratorTypeKey,
counter::{CounterError, CounterOptions, CounterTrait},
execute_pipeline_with_script_retry, mutex_lock,
};
/// Maximum number of commits sent to redis in one pipeline invocation.
const MAX_BATCH_SIZE: usize = 100;
/// Reads one counter field from the container hash.
/// Returns `{key, value}`, defaulting to 0 for missing/non-numeric fields.
const GET_LUA: &str = r#"
local container_key = KEYS[1]
local key = KEYS[2]
return {key, tonumber(redis.call('HGET', container_key, key)) or 0}
"#;
/// Applies a buffered local delta to the remote counter via HINCRBY.
const COMMIT_STATE_LUA: &str = r#"
local container_key = KEYS[1]
local key = KEYS[2]
local count = tonumber(ARGV[1]) or 0
redis.call('HINCRBY', container_key, key, count)
"#;
/// Deletes one counter field and returns its last remote value
/// (0 when the field did not exist).
const DEL_LUA: &str = r#"
local container_key = KEYS[1]
local key = KEYS[2]
local total = redis.call('HGET', container_key, key) or 0
redis.call('HDEL', container_key, key)
return total
"#;
/// Drops the whole container hash, removing every counter at once.
const CLEAR_LUA: &str = r#"
local container_key = KEYS[1]
redis.call('DEL', container_key)
"#;
/// One buffered delta waiting to be applied to redis for a single counter key.
#[derive(Debug)]
struct Commit {
    // Counter key (field name inside the container hash).
    key: DistkitRedisKey,
    // Signed increment to apply remotely via HINCRBY.
    delta: i64,
}
/// Per-key local state: the last value observed in redis plus a locally
/// buffered delta that has not been flushed yet. The locally visible total
/// is `remote_total + delta`.
#[derive(Debug)]
struct SingleStore {
    // Last counter value read from (or folded into) the remote side.
    remote_total: AtomicI64,
    // Locally accumulated increments not yet handed to the flush batch.
    delta: AtomicI64,
    // When `remote_total` was last refreshed from redis.
    last_updated: Mutex<Instant>,
    // When `delta` was last moved into the commit batch; None before first flush.
    last_flushed: Mutex<Option<Instant>>,
}
/// A lag-tolerant ("lax") distributed counter backed by a redis hash.
///
/// Reads and writes operate on locally cached per-key state and are
/// reconciled with redis at most every `allowed_lag` by a background flush
/// task, trading strict cross-process consistency for fewer round-trips.
#[derive(Debug)]
pub struct LaxCounter {
    connection_manager: ConnectionManager,
    // Produces the container hash key for this counter's prefix.
    key_generator: RedisKeyGenerator,
    // Per-key cached remote totals and pending local deltas.
    store: DashMap<DistkitRedisKey, SingleStore>,
    // Per-key async locks serializing refresh/delete for the same key.
    locks: DashMap<DistkitRedisKey, Arc<tokio::sync::Mutex<()>>>,
    get_script: Script,
    // Maximum staleness tolerated before re-reading a key from redis.
    allowed_lag: Duration,
    commit_state_script: Script,
    del_script: Script,
    clear_script: Script,
    // Commits collected by the flush task, pending delivery to redis.
    batch: tokio::sync::Mutex<Vec<Commit>>,
    // Tracks recent use; the flush task parks while the counter is idle.
    activity: Arc<ActivityTracker>,
}
impl LaxCounter {
/// Builds a `LaxCounter` from `options` and starts its background flush task.
///
/// Must be called from within a tokio runtime, since the flush task is
/// spawned immediately.
pub fn new(options: CounterOptions) -> Arc<Self> {
    let CounterOptions {
        prefix,
        connection_manager,
        allowed_lag,
    } = options;
    let counter = Arc::new(Self {
        connection_manager,
        key_generator: RedisKeyGenerator::new(prefix, RedisKeyGeneratorTypeKey::Lax),
        store: DashMap::default(),
        locks: DashMap::default(),
        get_script: Script::new(GET_LUA),
        allowed_lag,
        commit_state_script: Script::new(COMMIT_STATE_LUA),
        del_script: Script::new(DEL_LUA),
        clear_script: Script::new(CLEAR_LUA),
        batch: tokio::sync::Mutex::new(Vec::new()),
        activity: ActivityTracker::new(EPOCH_CHANGE_INTERVAL),
    });
    counter.run_flush_task();
    counter
}
/// Spawns the background task that periodically moves per-key deltas into
/// the shared commit batch and flushes it to redis.
///
/// The task holds only a `Weak` reference to the counter, so it terminates
/// once the counter is dropped; while the counter is idle it parks on the
/// activity watch channel instead of ticking.
fn run_flush_task(self: &Arc<Self>) {
    tokio::spawn({
        let allowed_lag = self.allowed_lag;
        // Weak handle: the task must not keep the counter alive.
        let counter = Arc::downgrade(self);
        let mut is_active_watch = self.activity.subscribe();
        async move {
            let mut interval = tokio::time::interval(allowed_lag);
            // Consume the interval's immediate first tick so the loop
            // always waits a full period before flushing.
            interval.tick().await;
            loop {
                let is_active = {
                    let Some(counter) = counter.upgrade() else {
                        break;
                    };
                    counter.activity.get_is_active()
                };
                // Idle: sleep until activity resumes. A closed watch channel
                // means the tracker (and counter) is gone.
                if !is_active && is_active_watch.changed().await.is_err() {
                    break;
                }
                interval.tick().await;
                let counter = match counter.upgrade() {
                    Some(counter) => counter,
                    None => break,
                };
                let mut batch = counter.batch.lock().await;
                for entry in counter.store.iter() {
                    let key = entry.key();
                    let store = entry.value();
                    // Nothing pending for this key.
                    if store.delta.load(Ordering::Acquire) == 0 {
                        continue;
                    }
                    // Skip keys that were flushed within the lag window.
                    let last_flushed = mutex_lock(&store.last_flushed, "last_flushed")
                        .map(|el| *el)
                        .unwrap_or(None);
                    if let Some(last_flushed) = last_flushed
                        && last_flushed.elapsed() < allowed_lag
                    {
                        continue;
                    }
                    // Claim the pending delta and fold it into the cached
                    // remote total so local readers keep seeing the same sum
                    // while the commit is in flight.
                    let delta = store.delta.swap(0, Ordering::AcqRel);
                    store.remote_total.fetch_add(delta, Ordering::AcqRel);
                    let Ok(mut last_flushed) = mutex_lock(&store.last_flushed, "last_flushed")
                    else {
                        continue;
                    };
                    *last_flushed = Some(Instant::now());
                    batch.push(Commit {
                        key: key.clone(),
                        delta,
                    });
                }
                if let Err(err) = counter.flush_to_redis(&mut batch, MAX_BATCH_SIZE).await {
                    tracing::error!("Failed to flush to redis: {err:?}");
                    continue;
                }
            }
        }
    });
}
/// Drains `batch` into redis in chunks of at most `max_batch_size` commits.
///
/// Chunks that were committed successfully are always removed from `batch`,
/// even when a later chunk fails. The previous implementation returned early
/// on error without draining, so already-committed chunks stayed in the
/// batch and were re-applied (double-counted) on the next flush.
///
/// # Errors
/// Returns `CounterError::CommitToRedisFailed` (as `DistkitError`) for the
/// first chunk that could not be committed.
async fn flush_to_redis(
    &self,
    batch: &mut Vec<Commit>,
    max_batch_size: usize,
) -> Result<(), DistkitError> {
    if batch.is_empty() {
        return Ok(());
    }
    let mut processed = 0;
    let mut outcome = Ok(());
    while processed < batch.len() {
        let end = (processed + max_batch_size).min(batch.len());
        match self.batch_commit_state(&batch[processed..end]).await {
            Ok(()) => processed = end,
            Err(err) => {
                // Stop here; the failed chunk (and everything after it)
                // stays in the batch for a retry on the next flush.
                outcome = Err(CounterError::CommitToRedisFailed(format!("{err:?}")).into());
                break;
            }
        }
    }
    // Remove only the successfully committed prefix.
    batch.drain(..processed);
    outcome
}
/// Pipelines one `commit_state_script` (HINCRBY) invocation per commit via
/// the project's `execute_pipeline_with_script_retry` helper.
async fn batch_commit_state(&self, commits: &[Commit]) -> Result<(), DistkitError> {
    let mut conn = self.connection_manager.clone();
    let script = &self.commit_state_script;
    execute_pipeline_with_script_retry::<(), _, _>(&mut conn, script, commits, |commit| {
        // KEYS = [container hash, field]; ARGV = [delta].
        let mut inv = script.key(self.key_generator.container_key());
        inv.key(commit.key.as_str());
        inv.arg(commit.delta);
        inv
    })
    .await
}
/// Ensures a fresh `SingleStore` exists for `key`, reading the remote total
/// from redis when the local snapshot is missing or older than `allowed_lag`.
///
/// Serialized per key via `get_or_create_lock`, so concurrent callers for
/// the same key perform at most one redis read.
async fn ensure_valid_state(&self, key: &DistkitRedisKey) -> Result<(), DistkitError> {
    let lock = self.get_or_create_lock(key).await;
    let _guard = lock.lock().await;
    // Another task may have refreshed the key while we waited for the lock.
    {
        let store = self.store.get(key);
        if let Some(ref store) = store
            && let SingleStore { last_updated, .. } = store.deref()
            && mutex_lock(last_updated, "last_updated")?.elapsed() < self.allowed_lag
        {
            return Ok(());
        }
    }
    let mut conn = self.connection_manager.clone();
    let (_, remote_total): (String, i64) = self
        .get_script
        .key(self.key_generator.container_key())
        .key(key.as_str())
        .invoke_async(&mut conn)
        .await?;
    // entry() covers both the present and the absent case with a single
    // lookup, replacing the previous insert-then-get-then-expect sequence
    // (which had an avoidable panic path).
    let store = self
        .store
        .entry(key.clone())
        .or_insert_with(|| SingleStore {
            remote_total: AtomicI64::new(remote_total),
            delta: AtomicI64::new(0),
            last_updated: Mutex::new(Instant::now()),
            last_flushed: Mutex::new(None),
        });
    store.remote_total.store(remote_total, Ordering::Release);
    *mutex_lock(&store.last_updated, "last_updated")? = Instant::now();
    Ok(())
}
/// Returns the per-key async lock for `key`, creating it on first use.
async fn get_or_create_lock(&self, key: &DistkitRedisKey) -> Arc<tokio::sync::Mutex<()>> {
    // Fast path: a read lookup avoids taking the shard write lock.
    match self.locks.get(key) {
        Some(existing) => Arc::clone(existing.value()),
        None => {
            let entry = self
                .locks
                .entry(key.clone())
                .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())));
            Arc::clone(entry.value())
        }
    }
}
/// Refreshes every key in `keys` whose local snapshot is missing or stale,
/// flushing pending commits first so the reads observe them.
///
/// # Errors
/// Propagates flush, pipeline, and mutex-poisoning errors.
async fn batch_refresh_stale(&self, keys: &[&DistkitRedisKey]) -> Result<(), DistkitError> {
    if keys.is_empty() {
        return Ok(());
    }
    // Collect keys that have no local store or were last flushed longer
    // than `allowed_lag` ago.
    let mut stale_keys = Vec::with_capacity(keys.len());
    for key in keys {
        let Some(store) = self.store.get(*key) else {
            stale_keys.push(*key);
            continue;
        };
        if let Ok(last_flushed) = mutex_lock(&store.last_flushed, "last_flushed")
            && let Some(last_flushed) = last_flushed.deref()
            && last_flushed.elapsed() < self.allowed_lag
        {
            continue;
        }
        stale_keys.push(*key);
    }
    if stale_keys.is_empty() {
        // Everything is fresh locally — skip the flush and the (empty)
        // pipeline round-trip entirely.
        return Ok(());
    }
    // Push pending local deltas first so the GETs below include them.
    let mut batch = self.batch.lock().await;
    self.flush_to_redis(&mut batch, MAX_BATCH_SIZE).await?;
    let mut conn = self.connection_manager.clone();
    let script = &self.get_script;
    let raw: Vec<(String, i64)> =
        execute_pipeline_with_script_retry(&mut conn, script, &stale_keys, |key| {
            let mut inv = script.key(self.key_generator.container_key());
            inv.key(key.as_str());
            inv
        })
        .await?;
    let map: HashMap<String, i64> = raw.into_iter().collect();
    for key in stale_keys {
        // Keys absent from the redis reply default to 0.
        let remote_total = map.get(key.as_str()).copied().unwrap_or(0);
        // entry() handles both the present and the absent case with a
        // single lookup, mirroring ensure_valid_state.
        let store = self
            .store
            .entry((*key).clone())
            .or_insert_with(|| SingleStore {
                remote_total: AtomicI64::new(remote_total),
                delta: AtomicI64::new(0),
                last_updated: Mutex::new(Instant::now()),
                last_flushed: Mutex::new(None),
            });
        store.remote_total.store(remote_total, Ordering::Release);
        *mutex_lock(&store.last_updated, "last_updated")? = Instant::now();
    }
    Ok(())
}
}
#[async_trait::async_trait]
impl CounterTrait for LaxCounter {
/// Unconditionally increments `key` by `count` and returns the new value.
async fn inc(&self, key: &DistkitRedisKey, count: i64) -> Result<i64, DistkitError> {
    // `Nil` comparator always matches, so this is an unconditional inc_if.
    let (new_value, _previous) = self.inc_if(key, CounterComparator::Nil, count).await?;
    Ok(new_value)
}
/// Increments `key` by `count` (which may be negative) when `comparator`
/// matches the locally visible value.
///
/// Returns `(new_value, previous_value)`; when the comparator does not
/// match, both elements are the unchanged current value.
async fn inc_if(
    &self,
    key: &DistkitRedisKey,
    comparator: CounterComparator,
    count: i64,
) -> Result<(i64, i64), DistkitError> {
    self.activity.signal();
    // Use the cached entry when fresh; otherwise refresh from redis first.
    let store = match self.store.get(key) {
        Some(store)
            if mutex_lock(&store.last_updated, "last_updated")?.elapsed()
                < self.allowed_lag =>
        {
            store
        }
        Some(store) => {
            // Drop the read guard before awaiting to avoid holding a
            // DashMap shard lock across ensure_valid_state.
            drop(store);
            self.ensure_valid_state(key).await?;
            self.store.get(key).expect("store should be present here")
        }
        None => {
            self.ensure_valid_state(key).await?;
            self.store.get(key).expect("store should be present here")
        }
    };
    let remote_total = store.remote_total.load(Ordering::Acquire);
    let current = remote_total + store.delta.load(Ordering::Acquire);
    if !comparator.matches(current) {
        return Ok((current, current));
    }
    // fetch_add handles negative counts directly; the previous
    // positive/negative split was redundant and its `count.abs()` path
    // panicked on `i64::MIN` in debug builds.
    let prev_delta = store.delta.fetch_add(count, Ordering::AcqRel);
    Ok((remote_total + prev_delta + count, current))
}
/// Decrements `key` by `count` and returns the new value.
/// Implemented as a negated increment.
async fn dec(&self, key: &DistkitRedisKey, count: i64) -> Result<i64, DistkitError> {
    // NOTE(review): `-count` overflows for `count == i64::MIN` (panics in
    // debug builds) — confirm callers never pass that value.
    self.inc(key, -count).await
}
/// Returns the locally visible value for `key`: the cached remote total
/// plus any unflushed local delta, refreshing from redis when stale.
async fn get(&self, key: &DistkitRedisKey) -> Result<i64, DistkitError> {
    self.activity.signal();
    let store = match self.store.get(key) {
        // Fast path: the snapshot is younger than allowed_lag — keep the
        // read guard obtained here and use it directly.
        Some(store)
            if mutex_lock(&store.last_updated, "last_updated")?.elapsed()
                < self.allowed_lag =>
        {
            store
        }
        Some(store) => {
            // Release the shard read guard before awaiting.
            drop(store);
            self.ensure_valid_state(key).await?;
            self.store.get(key).expect("store should be present here")
        }
        None => {
            self.ensure_valid_state(key).await?;
            self.store.get(key).expect("store should be present here")
        }
    };
    let delta = store.delta.load(Ordering::Acquire);
    let total = store.remote_total.load(Ordering::Acquire) + delta;
    Ok(total)
}
/// Unconditionally sets `key` to `count` and returns the new value.
async fn set(&self, key: &DistkitRedisKey, count: i64) -> Result<i64, DistkitError> {
    // `Nil` comparator always matches, so this is an unconditional set_if.
    let (new_value, _previous) = self.set_if(key, CounterComparator::Nil, count).await?;
    Ok(new_value)
}
/// Sets `key` to `count` when `comparator` matches the locally visible
/// value. Returns `(new_value, previous_value)`; when the comparator does
/// not match, both elements are the unchanged current value.
async fn set_if(
    &self,
    key: &DistkitRedisKey,
    comparator: CounterComparator,
    count: i64,
) -> Result<(i64, i64), DistkitError> {
    self.activity.signal();
    let store = match self.store.get(key) {
        // Fast path: local snapshot is still within allowed_lag.
        Some(store)
            if mutex_lock(&store.last_updated, "last_updated")?.elapsed()
                < self.allowed_lag =>
        {
            store
        }
        Some(store) => {
            // Release the shard read guard before awaiting.
            drop(store);
            self.ensure_valid_state(key).await?;
            self.store.get(key).expect("store should be present here")
        }
        None => {
            self.ensure_valid_state(key).await?;
            self.store.get(key).expect("store should be present here")
        }
    };
    let remote_total = store.remote_total.load(Ordering::Acquire);
    let current = remote_total + store.delta.load(Ordering::Acquire);
    if !comparator.matches(current) {
        return Ok((current, current));
    }
    // Make remote_total + delta == count. NOTE(review): this store can race
    // with the flush task swapping `delta`; presumably acceptable under the
    // counter's lax-consistency contract — confirm.
    store.delta.store(count - remote_total, Ordering::Release);
    Ok((count, current))
}
/// Deletes `key` locally and remotely, returning the final value
/// (remote value plus any unflushed local delta); 0 if the key was absent
/// locally.
async fn del(&self, key: &DistkitRedisKey) -> Result<i64, DistkitError> {
    self.activity.signal();
    // Serialize with refreshes of the same key.
    let lock = self.get_or_create_lock(key).await;
    let _guard = lock.lock().await;
    {
        // Drop any queued commits for this key so they are not re-applied
        // to redis after the delete.
        let mut batch = self.batch.lock().await;
        batch.retain(|commit| commit.key != *key);
    }
    let Some((_key, store)) = self.store.remove(key) else {
        return Ok(0);
    };
    // NOTE(review): the entry in `self.locks` is never removed, so the locks
    // map grows with the number of distinct keys ever used — verify that is
    // acceptable for expected key cardinality.
    let mut conn = self.connection_manager.clone();
    let total: i64 = self
        .del_script
        .key(self.key_generator.container_key())
        .key(key.as_str())
        .invoke_async(&mut conn)
        .await?;
    // Include the delta that never reached redis.
    let total = total + store.delta.swap(0, Ordering::AcqRel);
    Ok(total)
}
/// Removes every counter: clears local caches and the pending commit
/// batch, then deletes the remote container hash.
async fn clear(&self) -> Result<(), DistkitError> {
    self.activity.signal();
    // Drop local state first so no stale deltas survive the remote DEL.
    self.store.clear();
    self.batch.lock().await.clear();
    let mut conn = self.connection_manager.clone();
    let _: () = self
        .clear_script
        .key(self.key_generator.container_key())
        .invoke_async(&mut conn)
        .await?;
    Ok(())
}
/// Returns the locally visible total for each requested key, refreshing
/// stale entries from redis first.
async fn get_all<'k>(
    &self,
    keys: &[&'k DistkitRedisKey],
) -> Result<Vec<(&'k DistkitRedisKey, i64)>, DistkitError> {
    if keys.is_empty() {
        return Ok(vec![]);
    }
    self.activity.signal();
    self.batch_refresh_stale(keys).await?;
    let mut totals = Vec::with_capacity(keys.len());
    for key in keys {
        let store = self.store.get(*key).expect("store populated after refresh");
        let total =
            store.remote_total.load(Ordering::Acquire) + store.delta.load(Ordering::Acquire);
        totals.push((*key, total));
    }
    Ok(totals)
}
/// Applies every `(key, count)` increment unconditionally; returns the new
/// value per key.
async fn inc_all<'k>(
    &self,
    updates: &[(&'k DistkitRedisKey, i64)],
) -> Result<Vec<(&'k DistkitRedisKey, i64)>, DistkitError> {
    // Delegate to the conditional variant with an always-matching comparator.
    let mut conditional_updates = Vec::with_capacity(updates.len());
    for (key, count) in updates {
        conditional_updates.push((*key, CounterComparator::Nil, *count));
    }
    let results = self.inc_all_if(&conditional_updates).await?;
    Ok(results
        .into_iter()
        .map(|(key, new_value, _previous)| (key, new_value))
        .collect())
}
/// Conditionally increments each `(key, comparator, count)` entry; returns
/// `(key, new_value, previous_value)` per entry. Non-matching entries are
/// returned unchanged with `new_value == previous_value`.
async fn inc_all_if<'k>(
    &self,
    updates: &[(&'k DistkitRedisKey, CounterComparator, i64)],
) -> Result<Vec<(&'k DistkitRedisKey, i64, i64)>, DistkitError> {
    if updates.is_empty() {
        return Ok(vec![]);
    }
    self.activity.signal();
    // Refresh all touched keys in one round-trip before applying deltas.
    let keys: Vec<&DistkitRedisKey> = updates.iter().map(|(key, _, _)| *key).collect();
    self.batch_refresh_stale(&keys).await?;
    updates
        .iter()
        .map(|(key, comparator, count)| {
            let store = self.store.get(*key).expect("store populated after refresh");
            let remote_total = store.remote_total.load(Ordering::Acquire);
            let current = remote_total + store.delta.load(Ordering::Acquire);
            if comparator.matches(current) {
                // fetch_add handles negative counts directly; the previous
                // positive/negative split panicked on `i64::MIN.abs()` in
                // debug builds (mirrors the fix in inc_if).
                let prev_delta = store.delta.fetch_add(*count, Ordering::AcqRel);
                Ok((*key, remote_total + prev_delta + *count, current))
            } else {
                Ok((*key, current, current))
            }
        })
        .collect()
}
/// Sets every `(key, count)` pair unconditionally; returns the new value
/// per key.
async fn set_all<'k>(
    &self,
    updates: &[(&'k DistkitRedisKey, i64)],
) -> Result<Vec<(&'k DistkitRedisKey, i64)>, DistkitError> {
    // Delegate to the conditional variant with an always-matching comparator.
    let mut conditional_updates = Vec::with_capacity(updates.len());
    for (key, count) in updates {
        conditional_updates.push((*key, CounterComparator::Nil, *count));
    }
    let results = self.set_all_if(&conditional_updates).await?;
    Ok(results
        .into_iter()
        .map(|(key, new_value, _previous)| (key, new_value))
        .collect())
}
/// Conditionally sets each `(key, comparator, count)` entry; returns
/// `(key, new_value, previous_value)` per entry. Non-matching entries are
/// returned unchanged with `new_value == previous_value`.
async fn set_all_if<'k>(
    &self,
    updates: &[(&'k DistkitRedisKey, CounterComparator, i64)],
) -> Result<Vec<(&'k DistkitRedisKey, i64, i64)>, DistkitError> {
    if updates.is_empty() {
        return Ok(vec![]);
    }
    self.activity.signal();
    // Refresh all touched keys in one round-trip before applying the sets.
    let keys: Vec<&DistkitRedisKey> = updates.iter().map(|(key, _, _)| *key).collect();
    self.batch_refresh_stale(&keys).await?;
    let mut results = Vec::with_capacity(updates.len());
    for (key, comparator, count) in updates {
        let store = self.store.get(*key).expect("store populated after refresh");
        let remote_total = store.remote_total.load(Ordering::Acquire);
        let current = remote_total + store.delta.load(Ordering::Acquire);
        if comparator.matches(current) {
            // Make remote_total + delta == count.
            store.delta.store(count - remote_total, Ordering::Release);
            results.push((*key, *count, current));
        } else {
            results.push((*key, current, current));
        }
    }
    Ok(results)
}
}