use crate::application::ports::Storage;
use crate::application::registry::EventState;
use crate::domain::signature::EventSignature;
use redis::aio::ConnectionManager;
use redis::{AsyncCommands, Client, RedisError};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Wire representation of an [`EventState`] as stored in Redis.
///
/// Encoded with `bincode` (see `RedisStorage::get`/`set`), which is positional:
/// reordering, inserting or removing fields changes the stored byte layout and may
/// make previously written entries undecodable (such entries are deleted on read).
///
/// `Instant` has no serializable representation, so timestamps are stored as
/// offsets (whole seconds plus sub-second nanoseconds) from a base `Instant`
/// captured by the writing handle when it connected.
/// NOTE(review): that base is process-local, so offsets written by one process are
/// re-anchored against a different base when read by another — confirm that
/// cross-process sharing tolerates this.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableEventState {
    /// Policy carried over unchanged from the event state.
    policy: crate::domain::policy::Policy,
    /// Suppressed-occurrence count at snapshot time.
    suppressed_count: usize,
    /// Whole seconds from the base instant to the first suppression.
    first_suppressed_secs: u64,
    /// Sub-second nanoseconds of the first-suppression offset.
    first_suppressed_nanos: u32,
    /// Whole seconds from the base instant to the most recent suppression.
    last_suppressed_secs: u64,
    /// Sub-second nanoseconds of the most-recent-suppression offset.
    last_suppressed_nanos: u32,
}
impl SerializableEventState {
    /// Snapshots `state` into its wire form.
    ///
    /// Both timestamps are stored as offsets from `base_instant`. A timestamp
    /// earlier than `base_instant` yields a zero offset, because
    /// `Instant::duration_since` saturates on current Rust toolchains.
    fn from_event_state(state: &EventState, base_instant: Instant) -> Self {
        let counter = state.counter.snapshot();
        let first_duration = counter.first_suppressed.duration_since(base_instant);
        let last_duration = counter.last_suppressed.duration_since(base_instant);
        Self {
            policy: state.policy.clone(),
            suppressed_count: counter.suppressed_count,
            first_suppressed_secs: first_duration.as_secs(),
            first_suppressed_nanos: first_duration.subsec_nanos(),
            last_suppressed_secs: last_duration.as_secs(),
            last_suppressed_nanos: last_duration.subsec_nanos(),
        }
    }

    /// Rebuilds an [`EventState`] by re-anchoring the stored offsets at `base_instant`.
    ///
    /// Offsets that cannot be represented are replaced by `base_instant` instead of
    /// panicking. These come from corrupt payloads that still decode, or from a
    /// writer with a very different base instant.
    fn to_event_state(&self, base_instant: Instant) -> EventState {
        let first_suppressed = Self::anchor(
            base_instant,
            self.first_suppressed_secs,
            self.first_suppressed_nanos,
        );
        let last_suppressed = Self::anchor(
            base_instant,
            self.last_suppressed_secs,
            self.last_suppressed_nanos,
        );
        EventState::from_snapshot(
            self.policy.clone(),
            self.suppressed_count,
            first_suppressed,
            last_suppressed,
        )
    }

    /// Returns `base_instant + (secs, nanos)`, or `base_instant` if that overflows.
    ///
    /// Both `Duration::new` (nanosecond carry) and `Instant + Duration` panic on
    /// overflow, so this uses checked arithmetic throughout.
    fn anchor(base_instant: Instant, secs: u64, nanos: u32) -> Instant {
        Duration::from_secs(secs)
            .checked_add(Duration::from_nanos(u64::from(nanos)))
            .and_then(|offset| base_instant.checked_add(offset))
            .unwrap_or(base_instant)
    }
}
/// Tuning knobs for [`RedisStorage`].
#[derive(Debug, Clone)]
pub struct RedisStorageConfig {
    /// Expiry applied to a key each time it is written (sent to Redis in whole seconds).
    pub ttl: Duration,
    /// String prepended to every signature to form its Redis key; also the prefix of
    /// the `SCAN MATCH` pattern used for bulk operations.
    pub key_prefix: String,
}

impl Default for RedisStorageConfig {
    /// A one-hour TTL and the `tracing-throttle:` key prefix.
    fn default() -> Self {
        const ONE_HOUR_SECS: u64 = 60 * 60;
        let key_prefix = String::from("tracing-throttle:");
        Self {
            ttl: Duration::from_secs(ONE_HOUR_SECS),
            key_prefix,
        }
    }
}
/// Storage backend that keeps per-signature [`EventState`] in Redis.
///
/// Each entry lives under `{key_prefix}{signature}` and is written with the
/// configured TTL.
pub struct RedisStorage {
    /// Lock-guarded connection manager, shared (via `Arc`) by every clone of this handle.
    connection: Arc<RwLock<ConnectionManager>>,
    /// TTL and key-prefix settings.
    config: RedisStorageConfig,
    /// Captured when the handle connected; stored timestamps are offsets from this instant.
    base_instant: Instant,
}
impl fmt::Debug for RedisStorage {
    /// Prints only the configuration; the connection and base instant are elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("RedisStorage");
        builder.field("config", &self.config);
        builder.finish_non_exhaustive()
    }
}
impl RedisStorage {
pub async fn connect(url: &str) -> Result<Self, RedisError> {
Self::connect_with_config(url, RedisStorageConfig::default()).await
}
pub async fn connect_with_config(
url: &str,
config: RedisStorageConfig,
) -> Result<Self, RedisError> {
let client = Client::open(url)?;
let connection = ConnectionManager::new(client).await?;
Ok(Self {
connection: Arc::new(RwLock::new(connection)),
config,
base_instant: Instant::now(),
})
}
fn key(&self, signature: &EventSignature) -> String {
format!("{}{}", self.config.key_prefix, signature)
}
async fn get(&self, signature: &EventSignature) -> Result<Option<EventState>, RedisError> {
let key = self.key(signature);
let mut conn = self.connection.write().await;
let bytes: Option<Vec<u8>> = conn.get(&key).await?;
if let Some(bytes) = bytes {
if let Ok(serializable) = bincode::deserialize::<SerializableEventState>(&bytes) {
Ok(Some(serializable.to_event_state(self.base_instant)))
} else {
let _: () = conn.del(&key).await?;
Ok(None)
}
} else {
Ok(None)
}
}
async fn set(&self, signature: &EventSignature, state: &EventState) -> Result<(), RedisError> {
let key = self.key(signature);
let serializable = SerializableEventState::from_event_state(state, self.base_instant);
if let Ok(bytes) = bincode::serialize(&serializable) {
let mut conn = self.connection.write().await;
let ttl_secs = self.config.ttl.as_secs();
conn.set_ex::<_, _, ()>(&key, bytes, ttl_secs).await?;
}
Ok(())
}
}
impl Clone for RedisStorage {
    /// Returns a handle sharing the same connection and base instant.
    ///
    /// The configuration is copied.
    fn clone(&self) -> Self {
        let Self {
            connection,
            config,
            base_instant,
        } = self;
        Self {
            connection: Arc::clone(connection),
            config: config.clone(),
            base_instant: *base_instant,
        }
    }
}
impl Storage<EventSignature, EventState> for RedisStorage {
fn with_entry_mut<F, R>(
&self,
key: EventSignature,
factory: impl FnOnce() -> EventState,
accessor: F,
) -> R
where
F: FnOnce(&mut EventState) -> R,
{
if let Ok(handle) = tokio::runtime::Handle::try_current() {
tokio::task::block_in_place(|| {
handle.block_on(async {
let mut state = match self.get(&key).await {
Ok(Some(state)) => state,
Ok(None) | Err(_) => factory(),
};
let result = accessor(&mut state);
if let Err(e) = self.set(&key, &state).await {
tracing::warn!(
error = %e,
signature = %key.as_hash(),
"Failed to persist event state to Redis"
);
}
result
})
})
} else {
let rt = tokio::runtime::Runtime::new().expect("Failed to create tokio runtime");
rt.block_on(async {
let mut state = match self.get(&key).await {
Ok(Some(state)) => state,
Ok(None) | Err(_) => factory(),
};
let result = accessor(&mut state);
if let Err(e) = self.set(&key, &state).await {
tracing::warn!(
error = %e,
signature = %key.as_hash(),
"Failed to persist event state to Redis"
);
}
result
})
}
}
fn len(&self) -> usize {
0
}
fn is_empty(&self) -> bool {
false
}
fn clear(&self) {
let pattern = format!("{}*", self.config.key_prefix);
let clear_fn = async {
let mut conn = self.connection.write().await;
let mut cursor = 0;
loop {
let (new_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(&pattern)
.arg("COUNT")
.arg(100)
.query_async(&mut *conn)
.await
.unwrap_or((0, vec![]));
if !keys.is_empty() {
let _: Result<(), RedisError> = conn.del(&keys).await;
}
if new_cursor == 0 {
break;
}
cursor = new_cursor;
}
};
if let Ok(handle) = tokio::runtime::Handle::try_current() {
tokio::task::block_in_place(|| handle.block_on(clear_fn));
} else {
let rt = tokio::runtime::Runtime::new().expect("Failed to create tokio runtime");
rt.block_on(clear_fn);
}
}
fn for_each<F>(&self, mut f: F)
where
F: FnMut(&EventSignature, &EventState),
{
let pattern = format!("{}*", self.config.key_prefix);
let for_each_fn = async {
let mut conn = self.connection.write().await;
let mut cursor = 0;
loop {
let (new_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(&pattern)
.arg("COUNT")
.arg(100)
.query_async(&mut *conn)
.await
.unwrap_or((0, vec![]));
for key in keys {
if let Some(sig_str) = key.strip_prefix(&self.config.key_prefix) {
if let Ok(sig_hash) = sig_str.parse::<u64>() {
let signature = EventSignature::from_hash(sig_hash);
if let Ok(Some(state)) = self.get(&signature).await {
f(&signature, &state);
}
}
}
}
if new_cursor == 0 {
break;
}
cursor = new_cursor;
}
};
if let Ok(handle) = tokio::runtime::Handle::try_current() {
tokio::task::block_in_place(|| handle.block_on(for_each_fn));
} else {
let rt = tokio::runtime::Runtime::new().expect("Failed to create tokio runtime");
rt.block_on(for_each_fn);
}
}
fn retain<F>(&self, mut f: F)
where
F: FnMut(&EventSignature, &mut EventState) -> bool,
{
let pattern = format!("{}*", self.config.key_prefix);
let retain_fn = async {
let mut conn = self.connection.write().await;
let mut cursor = 0;
loop {
let (new_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(&pattern)
.arg("COUNT")
.arg(100)
.query_async(&mut *conn)
.await
.unwrap_or((0, vec![]));
for key in keys {
if let Some(sig_str) = key.strip_prefix(&self.config.key_prefix) {
if let Ok(sig_hash) = sig_str.parse::<u64>() {
let signature = EventSignature::from_hash(sig_hash);
if let Ok(Some(mut state)) = self.get(&signature).await {
if !f(&signature, &mut state) {
if let Err(e) = conn.del::<_, ()>(&key).await {
tracing::warn!(
error = %e,
key = %key,
"Failed to delete key from Redis during retain"
);
}
} else {
if let Err(e) = self.set(&signature, &state).await {
tracing::warn!(
error = %e,
signature = %signature.as_hash(),
"Failed to update key in Redis during retain"
);
}
}
}
}
}
}
if new_cursor == 0 {
break;
}
cursor = new_cursor;
}
};
if let Ok(handle) = tokio::runtime::Handle::try_current() {
tokio::task::block_in_place(|| handle.block_on(retain_fn));
} else {
let rt = tokio::runtime::Runtime::new().expect("Failed to create tokio runtime");
rt.block_on(retain_fn);
}
}
}