use std::sync::{
Arc,
atomic::{AtomicI64, AtomicU64, Ordering},
};
use dashmap::DashMap;
use redis::{Script, aio::ConnectionManager};
use crate::{
ActivityTracker, EPOCH_CHANGE_INTERVAL, RedisKey, RedisKeyGenerator, RedisKeyGeneratorTypeKey,
error::DistkitError,
icounter::{InstanceAwareCounterTrait, generate_instance_id},
};
#[derive(Debug)]
struct SingleStore {
epoch: AtomicU64,
cumulative: AtomicI64,
local_count: AtomicI64,
}
impl SingleStore {
fn new(epoch: u64, cumulative: i64, local_count: i64) -> Self {
Self {
epoch: AtomicU64::new(epoch),
cumulative: AtomicI64::new(cumulative),
local_count: AtomicI64::new(local_count),
}
}
}
const HELPERS_LUA: &str = r#"
local function now_ms()
local time_array = redis.call("TIME")
return tonumber(time_array[1]) * 1000 + math.floor(tonumber(time_array[2]) / 1000)
end
local function delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_threshold_ms, timestamp_ms)
local cutoff = timestamp_ms - dead_threshold_ms
local to_remove = redis.call("ZRANGE", instances_key, "-inf", cutoff, "BYSCORE")
for _, inst_id in ipairs(to_remove) do
local inst_count_key = prefix .. ':count:' .. inst_id
local all_keys = redis.call('SMEMBERS', keys_key)
if #all_keys > 0 then
local values = redis.call('HMGET', inst_count_key, unpack(all_keys))
for i = 1, #values do
local c = tonumber(values[i] or 0) or 0
if c ~= 0 then
redis.call('HINCRBY', cumulative_key, all_keys[i], -c)
end
end
end
redis.call('DEL', inst_count_key)
redis.call('ZREM', instances_key, inst_id)
end
end
-- Returns 1 if the instance was not previously in the ZSET (newly created or
-- was cleaned up as dead), 0 if it was already live.
local function check_and_zadd(instances_key, instance_id, ts)
local prev = redis.call('ZSCORE', instances_key, instance_id)
local created = (prev == false or prev == nil) and 1 or 0
redis.call('ZADD', instances_key, ts, instance_id)
return created
end
"#;
const INC_LUA: &str = r#"
local epoch_key = KEYS[1]
local instances_key = KEYS[2]
local cumulative_key = KEYS[3]
local keys_key = KEYS[4]
local inst_count_key = KEYS[5]
local counter_key = ARGV[1]
local delta = tonumber(ARGV[2])
local local_epoch = tonumber(ARGV[3])
local dead_threshold = tonumber(ARGV[4])
local prefix = ARGV[5]
local instance_id = ARGV[6]
local ts = now_ms()
local instance_created = check_and_zadd(instances_key, instance_id, ts)
delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_threshold, ts)
local redis_epoch = tonumber(redis.call('HGET', epoch_key, counter_key) or 0) or 0
local is_stale = (local_epoch ~= redis_epoch)
local new_inst_count
if is_stale then
redis.call('HSET', inst_count_key, counter_key, delta)
new_inst_count = delta
else
new_inst_count = tonumber(redis.call('HINCRBY', inst_count_key, counter_key, delta))
end
local new_cumulative = tonumber(redis.call('HINCRBY', cumulative_key, counter_key, delta))
redis.call('SADD', keys_key, counter_key)
return {counter_key, new_cumulative, new_inst_count, redis_epoch, instance_created}
"#;
const SET_LUA: &str = r#"
local epoch_key = KEYS[1]
local instances_key = KEYS[2]
local cumulative_key = KEYS[3]
local keys_key = KEYS[4]
local inst_count_key = KEYS[5]
local counter_key = ARGV[1]
local count = tonumber(ARGV[2])
local local_epoch = tonumber(ARGV[3])
local dead_threshold = tonumber(ARGV[4])
local prefix = ARGV[5]
local instance_id = ARGV[6]
local max_epoch = tonumber(ARGV[7])
local ts = now_ms()
local instance_created = check_and_zadd(instances_key, instance_id, ts)
delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_threshold, ts)
local old_epoch = tonumber(redis.call('HGET', epoch_key, counter_key) or 0) or 0
local new_epoch = old_epoch + 1
if new_epoch > max_epoch then
new_epoch = 0
end
redis.call('HSET', epoch_key, counter_key, new_epoch)
redis.call('HSET', cumulative_key, counter_key, count)
redis.call('HSET', inst_count_key, counter_key, count)
redis.call('SADD', keys_key, counter_key)
return {count, count, new_epoch, instance_created}
"#;
const SET_ON_INSTANCE_LUA: &str = r#"
local epoch_key = KEYS[1]
local instances_key = KEYS[2]
local cumulative_key = KEYS[3]
local keys_key = KEYS[4]
local inst_count_key = KEYS[5]
local counter_key = ARGV[1]
local count = tonumber(ARGV[2])
local local_epoch = tonumber(ARGV[3])
local dead_threshold = tonumber(ARGV[4])
local prefix = ARGV[5]
local instance_id = ARGV[6]
local ts = now_ms()
local instance_created = check_and_zadd(instances_key, instance_id, ts)
delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_threshold, ts)
local redis_epoch = tonumber(redis.call('HGET', epoch_key, counter_key) or 0) or 0
local inst_count = tonumber(redis.call('HGET', inst_count_key, counter_key) or 0) or 0
local is_stale = (local_epoch ~= redis_epoch)
local effective_old = is_stale and 0 or inst_count
local delta = count - effective_old
redis.call('HSET', inst_count_key, counter_key, count)
local new_cumulative = tonumber(redis.call('HINCRBY', cumulative_key, counter_key, delta))
redis.call('SADD', keys_key, counter_key)
return {new_cumulative, count, redis_epoch, instance_created}
"#;
const GET_LUA: &str = r#"
local epoch_key = KEYS[1]
local instances_key = KEYS[2]
local cumulative_key = KEYS[3]
local keys_key = KEYS[4]
local inst_count_key = KEYS[5]
local counter_key = ARGV[1]
local local_epoch = tonumber(ARGV[2])
local dead_threshold = tonumber(ARGV[3])
local prefix = ARGV[4]
local instance_id = ARGV[5]
local ts = now_ms()
local instance_created = check_and_zadd(instances_key, instance_id, ts)
delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_threshold, ts)
local redis_epoch = tonumber(redis.call('HGET', epoch_key, counter_key) or 0) or 0
local cumulative = tonumber(redis.call('HGET', cumulative_key, counter_key) or 0) or 0
local inst_count = tonumber(redis.call('HGET', inst_count_key, counter_key) or 0) or 0
return {cumulative, inst_count, redis_epoch, instance_created}
"#;
const DEL_LUA: &str = r#"
local epoch_key = KEYS[1]
local instances_key = KEYS[2]
local cumulative_key = KEYS[3]
local keys_key = KEYS[4]
local inst_count_key = KEYS[5]
local counter_key = ARGV[1]
local local_epoch = tonumber(ARGV[2])
local dead_threshold = tonumber(ARGV[3])
local prefix = ARGV[4]
local instance_id = ARGV[5]
local max_epoch = tonumber(ARGV[6])
local ts = now_ms()
local instance_created = check_and_zadd(instances_key, instance_id, ts)
delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_threshold, ts)
local old_cumulative = tonumber(redis.call('HGET', cumulative_key, counter_key) or 0) or 0
local old_epoch = tonumber(redis.call('HGET', epoch_key, counter_key) or 0) or 0
local new_epoch = old_epoch + 1
if new_epoch > max_epoch then
new_epoch = 0
end
redis.call('HSET', epoch_key, counter_key, new_epoch)
redis.call('HDEL', cumulative_key, counter_key)
redis.call('SREM', keys_key, counter_key)
redis.call('HDEL', inst_count_key, counter_key)
return {old_cumulative, new_epoch, instance_created}
"#;
const DEL_ON_INSTANCE_LUA: &str = r#"
local epoch_key = KEYS[1]
local instances_key = KEYS[2]
local cumulative_key = KEYS[3]
local keys_key = KEYS[4]
local inst_count_key = KEYS[5]
local counter_key = ARGV[1]
local local_epoch = tonumber(ARGV[2])
local dead_threshold = tonumber(ARGV[3])
local prefix = ARGV[4]
local instance_id = ARGV[5]
local ts = now_ms()
local instance_created = check_and_zadd(instances_key, instance_id, ts)
delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_threshold, ts)
local redis_epoch = tonumber(redis.call('HGET', epoch_key, counter_key) or 0) or 0
local inst_count = tonumber(redis.call('HGET', inst_count_key, counter_key) or 0) or 0
local is_stale = (local_epoch ~= redis_epoch)
redis.call('HDEL', inst_count_key, counter_key)
local new_cumulative
if is_stale then
new_cumulative = tonumber(redis.call('HGET', cumulative_key, counter_key) or 0) or 0
else
new_cumulative = tonumber(redis.call('HINCRBY', cumulative_key, counter_key, -inst_count))
end
return {new_cumulative, inst_count, redis_epoch, instance_created}
"#;
const CLEAR_LUA: &str = r#"
local epoch_key = KEYS[1]
local instances_key = KEYS[2]
local cumulative_key = KEYS[3]
local keys_key = KEYS[4]
local prefix = ARGV[1]
local all_instances = redis.call('ZRANGE', instances_key, 0, -1)
for _, inst_id in ipairs(all_instances) do
redis.call('DEL', prefix .. ':count:' .. inst_id)
end
redis.call('DEL', epoch_key, instances_key, cumulative_key, keys_key)
"#;
const CLEAR_ON_INSTANCE_LUA: &str = r#"
local epoch_key = KEYS[1]
local instances_key = KEYS[2]
local cumulative_key = KEYS[3]
local keys_key = KEYS[4]
local inst_count_key = KEYS[5]
local dead_threshold = tonumber(ARGV[1])
local prefix = ARGV[2]
local instance_id = ARGV[3]
local ts = now_ms()
check_and_zadd(instances_key, instance_id, ts)
delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_threshold, ts)
local all_keys = redis.call('HKEYS', inst_count_key)
if #all_keys > 0 then
local values = redis.call('HMGET', inst_count_key, unpack(all_keys))
for i = 1, #values do
local c = tonumber(values[i] or 0) or 0
if c ~= 0 then
redis.call('HINCRBY', cumulative_key, all_keys[i], -c)
end
end
end
redis.call('DEL', inst_count_key)
"#;
const MARK_ALIVE_LUA: &str = r#"
local instances_key = KEYS[1]
local epoch_key = KEYS[2]
local cumulative_key = KEYS[3]
local keys_key = KEYS[4]
local dead_threshold = tonumber(ARGV[1])
local prefix = ARGV[2]
local instance_id = ARGV[3]
local ts = now_ms()
local instance_created = check_and_zadd(instances_key, instance_id, ts)
delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_threshold, ts)
return tostring(instance_created)
"#;
const INC_IF_EPOCH_MATCHES_LUA: &str = r#"
local epoch_key = KEYS[1]
local instances_key = KEYS[2]
local cumulative_key = KEYS[3]
local keys_key = KEYS[4]
local inst_count_key = KEYS[5]
local counter_key = ARGV[1]
local recovery_count = tonumber(ARGV[2])
local local_epoch = tonumber(ARGV[3])
local dead_threshold = tonumber(ARGV[4])
local prefix = ARGV[5]
local instance_id = ARGV[6]
local ts = now_ms()
check_and_zadd(instances_key, instance_id, ts)
delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_threshold, ts)
local redis_epoch = tonumber(redis.call('HGET', epoch_key, counter_key) or 0) or 0
if redis_epoch ~= local_epoch then
-- Epoch moved while offline; contribution is stale — do not recover.
local cumulative = tonumber(redis.call('HGET', cumulative_key, counter_key) or 0) or 0
local inst_count = tonumber(redis.call('HGET', inst_count_key, counter_key) or 0) or 0
return {counter_key, cumulative, inst_count, redis_epoch}
end
-- Epoch still matches — safe to restore the contribution.
local new_inst_count = tonumber(redis.call('HINCRBY', inst_count_key, counter_key, recovery_count))
local new_cumulative = tonumber(redis.call('HINCRBY', cumulative_key, counter_key, recovery_count))
redis.call('SADD', keys_key, counter_key)
return {counter_key, new_cumulative, new_inst_count, redis_epoch}
"#;
#[derive(Debug, Clone)]
pub struct StrictInstanceAwareCounterOptions {
pub prefix: RedisKey,
pub connection_manager: ConnectionManager,
pub dead_instance_threshold_ms: u64,
}
impl StrictInstanceAwareCounterOptions {
pub fn new(prefix: RedisKey, connection_manager: ConnectionManager) -> Self {
Self {
prefix,
connection_manager,
dead_instance_threshold_ms: 30_000,
}
}
}
#[derive(Debug)]
pub struct StrictInstanceAwareCounter {
connection_manager: ConnectionManager,
key_generator: RedisKeyGenerator,
instance_id: String,
dead_instance_threshold_ms: u64,
local_store: DashMap<RedisKey, SingleStore>,
max_epoch: u64,
inc_script: Script,
set_script: Script,
set_on_instance_script: Script,
get_script: Script,
del_script: Script,
del_on_instance_script: Script,
clear_script: Script,
clear_on_instance_script: Script,
mark_alive_script: Script,
inc_if_epoch_matches_script: Script,
activity: Arc<ActivityTracker>,
}
impl StrictInstanceAwareCounter {
fn build(
key_generator: RedisKeyGenerator,
connection_manager: ConnectionManager,
dead_instance_threshold_ms: u64,
) -> Arc<Self> {
let instance_id = generate_instance_id();
let counter = Arc::new(Self {
connection_manager,
key_generator,
instance_id,
dead_instance_threshold_ms,
local_store: DashMap::default(),
max_epoch: u64::MAX / 2,
inc_script: Script::new(&format!("{HELPERS_LUA}\n{INC_LUA}")),
set_script: Script::new(&format!("{HELPERS_LUA}\n{SET_LUA}")),
set_on_instance_script: Script::new(&format!("{HELPERS_LUA}\n{SET_ON_INSTANCE_LUA}")),
get_script: Script::new(&format!("{HELPERS_LUA}\n{GET_LUA}")),
del_script: Script::new(&format!("{HELPERS_LUA}\n{DEL_LUA}")),
del_on_instance_script: Script::new(&format!("{HELPERS_LUA}\n{DEL_ON_INSTANCE_LUA}")),
clear_script: Script::new(CLEAR_LUA),
clear_on_instance_script: Script::new(&format!(
"{HELPERS_LUA}\n{CLEAR_ON_INSTANCE_LUA}"
)),
mark_alive_script: Script::new(&format!("{HELPERS_LUA}\n{MARK_ALIVE_LUA}")),
inc_if_epoch_matches_script: Script::new(&format!(
"{HELPERS_LUA}\n{INC_IF_EPOCH_MATCHES_LUA}"
)),
activity: ActivityTracker::new(EPOCH_CHANGE_INTERVAL),
});
counter.run_heartbeat_task();
counter
}
pub fn new(options: StrictInstanceAwareCounterOptions) -> Arc<Self> {
let StrictInstanceAwareCounterOptions {
prefix,
connection_manager,
dead_instance_threshold_ms,
} = options;
let key_generator = RedisKeyGenerator::new(prefix, RedisKeyGeneratorTypeKey::InstanceAware);
Self::build(
key_generator,
connection_manager,
dead_instance_threshold_ms,
)
}
pub(crate) fn new_as_lax_backend(options: StrictInstanceAwareCounterOptions) -> Arc<Self> {
let StrictInstanceAwareCounterOptions {
prefix,
connection_manager,
dead_instance_threshold_ms,
} = options;
let key_generator =
RedisKeyGenerator::new(prefix, RedisKeyGeneratorTypeKey::LaxInstanceAware);
Self::build(
key_generator,
connection_manager,
dead_instance_threshold_ms,
)
}
fn epoch_key(&self) -> String {
format!("{}:epoch", self.key_generator.container_key())
}
fn instances_key(&self) -> String {
format!("{}:instances", self.key_generator.container_key())
}
fn cumulative_key(&self) -> String {
format!("{}:cumulative", self.key_generator.container_key())
}
fn keys_key(&self) -> String {
format!("{}:keys", self.key_generator.container_key())
}
fn inst_count_key(&self) -> String {
format!(
"{}:count:{}",
self.key_generator.container_key(),
self.instance_id
)
}
fn prefix_str(&self) -> String {
self.key_generator.container_key()
}
fn get_local_epoch(&self, key: &RedisKey) -> u64 {
self.local_store
.get(key)
.map(|s| s.epoch.load(Ordering::Acquire))
.unwrap_or(0)
}
fn get_local_count(&self, key: &RedisKey) -> i64 {
self.local_store
.get(key)
.map(|s| s.local_count.load(Ordering::Acquire))
.unwrap_or(0)
}
fn update_local_store(&self, key: &RedisKey, epoch: u64, cumulative: i64, local_count: i64) {
match self.local_store.get(key) {
Some(s) => {
s.epoch.store(epoch, Ordering::Release);
s.cumulative.store(cumulative, Ordering::Release);
s.local_count.store(local_count, Ordering::Release);
}
None => {
self.local_store
.entry(key.clone())
.and_modify(|s| {
s.epoch.store(epoch, Ordering::Relaxed);
s.cumulative.store(cumulative, Ordering::Relaxed);
s.local_count.store(local_count, Ordering::Relaxed);
})
.or_insert_with(|| SingleStore::new(epoch, cumulative, local_count));
}
}
}
fn run_heartbeat_task(self: &Arc<Self>) {
let weak = Arc::downgrade(self);
let mut activity_watch = self.activity.subscribe();
tokio::spawn(async move {
let mut tick = tokio::time::interval(EPOCH_CHANGE_INTERVAL);
tick.tick().await;
loop {
tokio::select! {
changed = activity_watch.changed() => {
if changed.is_err() { break; }
let Some(c) = weak.upgrade() else { break; };
if !c.activity.get_is_active() {
let _ = c.mark_alive().await;
}
}
_ = tick.tick() => {
let Some(c) = weak.upgrade() else { break; };
if !c.activity.get_is_active() {
let _ = c.mark_alive().await;
}
}
}
}
});
}
fn build_recovery_pipeline(
&self,
chunk: &[(RedisKey, i64, u64)],
load_script: bool,
) -> redis::Pipeline {
let mut pipe = redis::Pipeline::new();
if load_script {
pipe.load_script(&self.inc_if_epoch_matches_script).ignore();
}
for (key, count, local_epoch) in chunk {
pipe.invoke_script(
self.inc_if_epoch_matches_script
.key(self.epoch_key())
.key(self.instances_key())
.key(self.cumulative_key())
.key(self.keys_key())
.key(self.inst_count_key())
.arg(key.as_str())
.arg(*count)
.arg(*local_epoch)
.arg(self.dead_instance_threshold_ms)
.arg(self.prefix_str())
.arg(&self.instance_id),
);
}
pipe
}
async fn recover_contributions_batched(
&self,
recoveries: Vec<(RedisKey, i64, u64)>,
chunk_size: usize,
) -> Result<(), DistkitError> {
if recoveries.is_empty() {
return Ok(());
}
let mut conn = self.connection_manager.clone();
let mut processed = 0;
while processed < recoveries.len() {
let end = (processed + chunk_size).min(recoveries.len());
let chunk = &recoveries[processed..end];
let results: Vec<(String, i64, i64, i64)> = {
let pipe = self.build_recovery_pipeline(chunk, false);
match pipe.query_async(&mut conn).await {
Ok(r) => r,
Err(err) => {
if err.kind() != redis::ErrorKind::Server(redis::ServerErrorKind::NoScript)
{
return Err(DistkitError::RedisError(err));
}
let pipe = self.build_recovery_pipeline(chunk, true);
pipe.query_async(&mut conn).await?
}
}
};
for (key_str, cumulative, inst_count, redis_epoch) in results {
if let Ok(key) = RedisKey::try_from(key_str) {
self.update_local_store(&key, redis_epoch as u64, cumulative, inst_count);
}
}
processed = end;
}
Ok(())
}
fn build_inc_batch_pipeline(
&self,
chunk: &[(RedisKey, i64)],
load_script: bool,
) -> redis::Pipeline {
let mut pipe = redis::Pipeline::new();
if load_script {
pipe.load_script(&self.inc_script).ignore();
}
for (key, delta) in chunk {
let local_epoch = self.get_local_epoch(key);
pipe.invoke_script(
self.inc_script
.key(self.epoch_key())
.key(self.instances_key())
.key(self.cumulative_key())
.key(self.keys_key())
.key(self.inst_count_key())
.arg(key.as_str())
.arg(*delta)
.arg(local_epoch)
.arg(self.dead_instance_threshold_ms)
.arg(self.prefix_str())
.arg(&self.instance_id),
);
}
pipe
}
pub async fn inc_batch(
&self,
increments: &mut Vec<(RedisKey, i64)>,
max_batch_size: usize,
) -> Result<Vec<(String, i64, i64)>, DistkitError> {
if increments.is_empty() {
return Ok(vec![]);
}
self.activity.signal();
let mut conn = self.connection_manager.clone();
let mut processed = 0;
let mut output: Vec<(String, i64, i64)> = Vec::with_capacity(increments.len());
while processed < increments.len() {
let end = (processed + max_batch_size).min(increments.len());
let chunk = &increments[processed..end];
let first_attempt = {
let pipe = self.build_inc_batch_pipeline(chunk, false);
pipe.query_async::<Vec<(String, i64, i64, u64, i64)>>(&mut conn)
.await
};
let chunk_results: Vec<(String, i64, i64, u64, i64)> = match first_attempt {
Ok(r) => r,
Err(err) => {
if err.kind() != redis::ErrorKind::Server(redis::ServerErrorKind::NoScript) {
return Err(DistkitError::RedisError(err));
}
let pipe = self.build_inc_batch_pipeline(chunk, true);
match pipe.query_async(&mut conn).await {
Ok(r) => r,
Err(e) => return Err(DistkitError::RedisError(e)),
}
}
};
for (key_str, cumulative, inst_count, redis_epoch, _) in chunk_results {
if let Ok(key) = RedisKey::try_from(key_str.clone()) {
self.update_local_store(&key, redis_epoch, cumulative, inst_count);
}
output.push((key_str, cumulative, inst_count));
}
processed = end;
}
increments.drain(..processed);
Ok(output)
}
#[cfg(test)]
pub(crate) async fn trigger_mark_alive(&self) -> Result<(), DistkitError> {
self.mark_alive().await
}
async fn mark_alive(&self) -> Result<(), DistkitError> {
let mut conn = self.connection_manager.clone();
let instance_created: i8 = self
.mark_alive_script
.key(self.instances_key())
.key(self.epoch_key())
.key(self.cumulative_key())
.key(self.keys_key())
.arg(self.dead_instance_threshold_ms)
.arg(self.prefix_str())
.arg(&self.instance_id)
.invoke_async(&mut conn)
.await?;
if instance_created != 0i8 {
let recoveries: Vec<(RedisKey, i64, u64)> = self
.local_store
.iter()
.filter_map(|e| {
let count = e.local_count.load(Ordering::Acquire);
let epoch = e.epoch.load(Ordering::Acquire);
if count > 0 {
Some((e.key().clone(), count, epoch))
} else {
None
}
})
.collect();
let _ = self.recover_contributions_batched(recoveries, 50).await;
}
Ok(())
}
pub fn instance_id(&self) -> &str {
&self.instance_id
}
pub async fn inc(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> {
self.activity.signal();
let mut conn = self.connection_manager.clone();
let local_epoch = self.get_local_epoch(key);
let (_, cumulative, inst_count, redis_epoch, instance_created_raw): (
String,
i64,
i64,
u64,
i64,
) = self
.inc_script
.key(self.epoch_key())
.key(self.instances_key())
.key(self.cumulative_key())
.key(self.keys_key())
.key(self.inst_count_key())
.arg(key.as_str())
.arg(count)
.arg(local_epoch)
.arg(self.dead_instance_threshold_ms)
.arg(self.prefix_str())
.arg(&self.instance_id)
.invoke_async(&mut conn)
.await?;
let instance_created = instance_created_raw != 0;
let should_recover = instance_created && local_epoch == redis_epoch;
let old_local_count = self.get_local_count(key);
self.update_local_store(key, redis_epoch, cumulative, inst_count);
if should_recover && old_local_count > 0 {
return Box::pin(self.inc(key, old_local_count)).await;
}
Ok((cumulative, inst_count))
}
pub async fn set(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> {
self.activity.signal();
let mut conn = self.connection_manager.clone();
let local_epoch = self.get_local_epoch(key);
let (cumulative, inst_count, new_epoch_raw, _): (i64, i64, u64, i64) = self
.set_script
.key(self.epoch_key())
.key(self.instances_key())
.key(self.cumulative_key())
.key(self.keys_key())
.key(self.inst_count_key())
.arg(key.as_str())
.arg(count)
.arg(local_epoch)
.arg(self.dead_instance_threshold_ms)
.arg(self.prefix_str())
.arg(&self.instance_id)
.arg(self.max_epoch)
.invoke_async(&mut conn)
.await?;
self.update_local_store(key, new_epoch_raw, cumulative, inst_count);
Ok((cumulative, inst_count))
}
pub async fn set_on_instance(
&self,
key: &RedisKey,
count: i64,
) -> Result<(i64, i64), DistkitError> {
self.activity.signal();
let mut conn = self.connection_manager.clone();
let local_epoch = self.get_local_epoch(key);
let (cumulative, inst_count, redis_epoch_raw, _): (i64, i64, u64, i64) = self
.set_on_instance_script
.key(self.epoch_key())
.key(self.instances_key())
.key(self.cumulative_key())
.key(self.keys_key())
.key(self.inst_count_key())
.arg(key.as_str())
.arg(count)
.arg(local_epoch)
.arg(self.dead_instance_threshold_ms)
.arg(self.prefix_str())
.arg(&self.instance_id)
.invoke_async(&mut conn)
.await?;
self.update_local_store(key, redis_epoch_raw, cumulative, inst_count);
Ok((cumulative, inst_count))
}
pub async fn get(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> {
self.activity.signal();
let mut conn = self.connection_manager.clone();
let local_epoch = self.get_local_epoch(key);
let (cumulative, inst_count, redis_epoch, instance_created_raw): (i64, i64, u64, i64) =
self.get_script
.key(self.epoch_key())
.key(self.instances_key())
.key(self.cumulative_key())
.key(self.keys_key())
.key(self.inst_count_key())
.arg(key.as_str())
.arg(local_epoch)
.arg(self.dead_instance_threshold_ms)
.arg(self.prefix_str())
.arg(&self.instance_id)
.invoke_async(&mut conn)
.await?;
let instance_created = instance_created_raw != 0;
let should_recover = instance_created && local_epoch == redis_epoch;
let old_local_count = self.get_local_count(key);
self.update_local_store(key, redis_epoch, cumulative, inst_count);
if should_recover && old_local_count > 0 {
return self.inc(key, old_local_count).await;
}
Ok((cumulative, inst_count))
}
pub async fn del(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> {
self.activity.signal();
let mut conn = self.connection_manager.clone();
let local_epoch = self.get_local_epoch(key);
let (old_cumulative, _, _): (i64, i64, i64) = self
.del_script
.key(self.epoch_key())
.key(self.instances_key())
.key(self.cumulative_key())
.key(self.keys_key())
.key(self.inst_count_key())
.arg(key.as_str())
.arg(local_epoch)
.arg(self.dead_instance_threshold_ms)
.arg(self.prefix_str())
.arg(&self.instance_id)
.arg(self.max_epoch)
.invoke_async(&mut conn)
.await?;
let old_inst_count = self.get_local_count(key);
self.local_store.remove(key);
Ok((old_cumulative, old_inst_count))
}
pub async fn del_on_instance(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> {
self.activity.signal();
let mut conn = self.connection_manager.clone();
let local_epoch = self.get_local_epoch(key);
let (new_cumulative, removed_count, redis_epoch, _): (i64, i64, u64, i64) = self
.del_on_instance_script
.key(self.epoch_key())
.key(self.instances_key())
.key(self.cumulative_key())
.key(self.keys_key())
.key(self.inst_count_key())
.arg(key.as_str())
.arg(local_epoch)
.arg(self.dead_instance_threshold_ms)
.arg(self.prefix_str())
.arg(&self.instance_id)
.invoke_async(&mut conn)
.await?;
self.update_local_store(key, redis_epoch, new_cumulative, 0);
Ok((new_cumulative, removed_count))
}
pub async fn clear(&self) -> Result<(), DistkitError> {
self.activity.signal();
let mut conn = self.connection_manager.clone();
let _: () = self
.clear_script
.key(self.epoch_key())
.key(self.instances_key())
.key(self.cumulative_key())
.key(self.keys_key())
.arg(self.prefix_str())
.invoke_async(&mut conn)
.await?;
self.local_store.clear();
Ok(())
}
pub async fn clear_on_instance(&self) -> Result<(), DistkitError> {
self.activity.signal();
let mut conn = self.connection_manager.clone();
let _: () = self
.clear_on_instance_script
.key(self.epoch_key())
.key(self.instances_key())
.key(self.cumulative_key())
.key(self.keys_key())
.key(self.inst_count_key())
.arg(self.dead_instance_threshold_ms)
.arg(self.prefix_str())
.arg(&self.instance_id)
.invoke_async(&mut conn)
.await?;
for entry in self.local_store.iter() {
entry.local_count.store(0, Ordering::Release);
}
Ok(())
}
}
#[async_trait::async_trait]
impl InstanceAwareCounterTrait for StrictInstanceAwareCounter {
fn instance_id(&self) -> &str {
self.instance_id()
}
async fn inc(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> {
self.inc(key, count).await
}
async fn dec(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> {
self.inc(key, -count).await
}
async fn set(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> {
self.set(key, count).await
}
async fn set_on_instance(
&self,
key: &RedisKey,
count: i64,
) -> Result<(i64, i64), DistkitError> {
self.set_on_instance(key, count).await
}
async fn get(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> {
self.get(key).await
}
async fn del(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> {
self.del(key).await
}
async fn del_on_instance(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> {
self.del_on_instance(key).await
}
async fn clear(&self) -> Result<(), DistkitError> {
self.clear().await
}
async fn clear_on_instance(&self) -> Result<(), DistkitError> {
self.clear_on_instance().await
}
}