use dioxus::{core::ReactiveContext, prelude::*};
use std::{
collections::{HashMap, HashSet},
sync::{Arc, Mutex, atomic::AtomicBool},
time::Duration,
};
#[cfg(not(target_family = "wasm"))]
use tokio::time;
#[cfg(target_family = "wasm")]
use wasmtimer::tokio as time;
/// Set of reactive contexts subscribed to a single refresh key.
type ReactiveContextSet = Arc<Mutex<HashSet<ReactiveContext>>>;
/// Map from refresh key to the contexts subscribed to it. Each per-key set sits behind its
/// own `Arc<Mutex<..>>`, separate from the mutex guarding this outer map.
type ReactiveContextRegistry = Arc<Mutex<HashMap<String, ReactiveContextSet>>>;
/// Kind of background loop managed by `RefreshRegistry`.
///
/// The variant's `Debug` name is part of a task's registry key (`"{key}:{task_type:?}"`),
/// so each refresh key has at most one registered task of each kind.
///
/// Fieldless and cheap, so it is `Copy`; `Eq`/`Hash` allow it to be used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskType {
    /// Fixed-interval loop, started by `RefreshRegistry::start_interval_task`.
    IntervalRefresh,
    /// Staleness-check loop, started by `RefreshRegistry::start_stale_check_task`; ticks at
    /// a quarter of the requested interval, clamped to 1s..=30s.
    StaleCheck,
    /// Cache-cleanup loop. NOTE(review): not started by any code visible in this module —
    /// confirm where it is used.
    CacheCleanup,
    /// Cache-expiration loop; ticks at a quarter of the requested interval, clamped to
    /// 1s..=30s.
    CacheExpiration,
}
/// Registered periodic tasks keyed by `"{key}:{task_type:?}"`. Each entry stores the task
/// type, the interval the task was requested with (not the derived tick interval), and the
/// flag that cancels its loop.
type PeriodicTaskRegistry = Arc<Mutex<HashMap<String, (TaskType, Duration, Arc<AtomicBool>)>>>;
/// Number of refreshes triggered so far, per refresh key.
type RefreshCounterMap = Arc<Mutex<HashMap<String, u64>>>;
/// Refresh keys currently marked as revalidating.
type RevalidationKeySet = Arc<Mutex<HashSet<String>>>;

/// Shared bookkeeping for refresh keys: refresh counters, subscribed reactive contexts,
/// periodic background tasks, and in-flight revalidation markers.
///
/// Every field is an `Arc`, so a cloned registry shares all of its state with the original.
#[derive(Default, Clone)]
pub struct RefreshRegistry {
    /// Refreshes triggered so far, per key.
    refresh_counters: RefreshCounterMap,
    /// Reactive contexts to mark dirty when a key is refreshed.
    reactive_contexts: ReactiveContextRegistry,
    /// Periodic background tasks together with their cancel flags.
    periodic_tasks: PeriodicTaskRegistry,
    /// Keys whose revalidation is currently marked as in progress.
    ongoing_revalidations: RevalidationKeySet,
}
impl RefreshRegistry {
    /// Creates an empty registry: no counters, subscribers, periodic tasks, or in-flight
    /// revalidations.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns how many refreshes have been triggered for `key`.
    ///
    /// Keys that were never refreshed report `0`. Every key also reports `0` if the counter
    /// mutex is poisoned.
    pub fn get_refresh_count(&self, key: &str) -> u64 {
        self.refresh_counters
            .lock()
            .map(|counters| counters.get(key).copied().unwrap_or(0))
            .unwrap_or(0)
    }

    /// Subscribes `reactive_context` so that it is marked dirty on every refresh of `key`.
    ///
    /// Contexts are stored in a set, so subscribing the same context twice has no extra
    /// effect. Silently does nothing if a registry mutex is poisoned.
    pub fn subscribe_to_refresh(&self, key: &str, reactive_context: ReactiveContext) {
        if let Ok(mut contexts) = self.reactive_contexts.lock() {
            let key_contexts = contexts
                .entry(key.to_string())
                .or_insert_with(|| Arc::new(Mutex::new(HashSet::new())));
            if let Ok(mut context_set) = key_contexts.lock() {
                context_set.insert(reactive_context);
            }
        }
    }

    /// Increments the refresh counter for `key` and marks every subscriber of `key` dirty.
    ///
    /// Subscribers whose `mark_dirty` returns `false` are removed from the key's set. Per
    /// dioxus's `ReactiveContext::mark_dirty` docs, that return value means the context has
    /// been dropped (confirm for the dioxus version in use). Without this pruning, every
    /// unmounted subscriber would stay registered forever.
    pub fn trigger_refresh(&self, key: &str) {
        if let Ok(mut counters) = self.refresh_counters.lock() {
            *counters.entry(key.to_string()).or_insert(0) += 1;
        }
        // Clone the per-key set handle out of the registry so that the registry-wide lock is
        // released before any context is notified.
        let key_contexts = self
            .reactive_contexts
            .lock()
            .ok()
            .and_then(|contexts| contexts.get(key).cloned());
        if let Some(key_contexts) = key_contexts {
            if let Ok(mut context_set) = key_contexts.lock() {
                context_set.retain(|reactive_context| reactive_context.mark_dirty());
            }
        }
    }

    /// Triggers a refresh for every key the registry currently knows about.
    ///
    /// A key is known if it has a refresh counter or a subscriber entry. Contexts that
    /// subscribed to a key that was never refreshed are therefore notified too.
    pub fn clear_all(&self) {
        let mut keys: HashSet<String> = HashSet::new();
        if let Ok(counters) = self.refresh_counters.lock() {
            keys.extend(counters.keys().cloned());
        }
        if let Ok(contexts) = self.reactive_contexts.lock() {
            keys.extend(contexts.keys().cloned());
        }
        // Both guards are released above; `trigger_refresh` re-acquires the locks per key.
        for key in keys {
            self.trigger_refresh(&key);
        }
    }

    /// Starts a background loop that repeatedly calls `task_fn` for `(key, task_type)`.
    ///
    /// At most one task is registered per `(key, task_type)` pair. If a task already exists:
    /// - For [`TaskType::IntervalRefresh`], a new loop replaces the old one (which is
    ///   cancelled) only when `interval` is shorter than the registered interval.
    /// - For every other type, the call does nothing.
    ///
    /// [`TaskType::StaleCheck`] and [`TaskType::CacheExpiration`] tick every `interval / 4`,
    /// clamped to 1s..=30s. Other types tick every `interval`. The loop sleeps before each
    /// call, so `task_fn` is not invoked immediately.
    ///
    /// NOTE(review): the loop is started with dioxus `spawn`, which ties it to the current
    /// component scope. If that scope is dropped without `stop_periodic_task`, the entry stays
    /// registered and blocks new tasks for this key — confirm callers always stop on unmount.
    pub fn start_periodic_task<F>(
        &self,
        key: &str,
        task_type: TaskType,
        interval: Duration,
        task_fn: F,
    ) where
        F: Fn() + 'static,
    {
        if let Ok(mut tasks) = self.periodic_tasks.lock() {
            // The registry key embeds the task type, so an exact lookup is the complete
            // dedupe check. A prefix scan on "{key}:" would also match unrelated keys such as
            // "{key}:1".
            let task_key = format!("{key}:{task_type:?}");
            let should_create_new_task = match tasks.get(&task_key) {
                None => true,
                Some((_, current_interval, cancel_flag))
                    if task_type == TaskType::IntervalRefresh && interval < *current_interval =>
                {
                    // Stop the slower loop; its entry is overwritten by the insert below.
                    cancel_flag.store(true, std::sync::atomic::Ordering::SeqCst);
                    true
                }
                Some(_) => false,
            };
            if should_create_new_task {
                let actual_interval = match task_type {
                    TaskType::StaleCheck | TaskType::CacheExpiration => {
                        (interval / 4).clamp(Duration::from_secs(1), Duration::from_secs(30))
                    }
                    TaskType::IntervalRefresh | TaskType::CacheCleanup => interval,
                };
                let cancel_flag = Arc::new(AtomicBool::new(false));
                let task_cancel_flag = Arc::clone(&cancel_flag);
                spawn(async move {
                    loop {
                        if task_cancel_flag.load(std::sync::atomic::Ordering::SeqCst) {
                            break;
                        }
                        time::sleep(actual_interval).await;
                        // Re-check after sleeping so a cancelled task never fires again.
                        if task_cancel_flag.load(std::sync::atomic::Ordering::SeqCst) {
                            break;
                        }
                        task_fn();
                    }
                });
                // Store the requested interval (not `actual_interval`) for later comparisons.
                tasks.insert(task_key, (task_type, interval, cancel_flag));
            }
        }
    }

    /// Starts a [`TaskType::IntervalRefresh`] loop calling `refresh_fn` every `interval`.
    ///
    /// See [`start_periodic_task`](Self::start_periodic_task) for the dedupe rules.
    pub fn start_interval_task<F>(&self, key: &str, interval: Duration, refresh_fn: F)
    where
        F: Fn() + 'static,
    {
        self.start_periodic_task(key, TaskType::IntervalRefresh, interval, refresh_fn);
    }

    /// Starts a [`TaskType::StaleCheck`] loop for `key`, derived from `stale_time`.
    ///
    /// See [`start_periodic_task`](Self::start_periodic_task) for tick timing and dedupe rules.
    pub fn start_stale_check_task<F>(&self, key: &str, stale_time: Duration, stale_check_fn: F)
    where
        F: Fn() + 'static,
    {
        self.start_periodic_task(key, TaskType::StaleCheck, stale_time, stale_check_fn);
    }

    /// Cancels and unregisters the `(key, task_type)` task, if one exists.
    ///
    /// The loop observes the cancel flag after its current sleep and exits without calling
    /// its task function again.
    pub fn stop_periodic_task(&self, key: &str, task_type: TaskType) {
        if let Ok(mut tasks) = self.periodic_tasks.lock() {
            let task_key = format!("{key}:{task_type:?}");
            if let Some((_, _, cancel_flag)) = tasks.remove(&task_key) {
                cancel_flag.store(true, std::sync::atomic::Ordering::SeqCst);
            }
        }
    }

    /// Stops the [`TaskType::IntervalRefresh`] loop for `key`, if any.
    pub fn stop_interval_task(&self, key: &str) {
        self.stop_periodic_task(key, TaskType::IntervalRefresh);
    }

    /// Stops the [`TaskType::StaleCheck`] loop for `key`, if any.
    pub fn stop_stale_check_task(&self, key: &str) {
        self.stop_periodic_task(key, TaskType::StaleCheck);
    }

    /// Returns `true` if a revalidation for `key` is marked as in progress.
    ///
    /// Returns `false` if the mutex is poisoned.
    pub fn is_revalidation_in_progress(&self, key: &str) -> bool {
        self.ongoing_revalidations
            .lock()
            .map(|revalidations| revalidations.contains(key))
            .unwrap_or(false)
    }

    /// Marks `key` as revalidating.
    ///
    /// Returns `true` only if this call claimed the slot. It returns `false` when a
    /// revalidation was already in progress, or when the mutex is poisoned.
    pub fn start_revalidation(&self, key: &str) -> bool {
        self.ongoing_revalidations
            .lock()
            .map(|mut revalidations| {
                // Check first so that no `String` is allocated when the key is already present.
                !revalidations.contains(key) && revalidations.insert(key.to_string())
            })
            .unwrap_or(false)
    }

    /// Clears the in-progress marker for `key`. Does nothing if `key` was not marked.
    pub fn complete_revalidation(&self, key: &str) {
        if let Ok(mut revalidations) = self.ongoing_revalidations.lock() {
            revalidations.remove(key);
        }
    }

    /// Returns a snapshot of the registry's sizes.
    ///
    /// `refresh_count` is the number of keys with a counter, not the sum of the counts. Any
    /// poisoned mutex contributes `0`.
    pub fn stats(&self) -> RefreshRegistryStats {
        RefreshRegistryStats {
            refresh_count: self.refresh_counters.lock().map(|c| c.len()).unwrap_or(0),
            context_count: self.reactive_contexts.lock().map(|c| c.len()).unwrap_or(0),
            task_count: self.periodic_tasks.lock().map(|t| t.len()).unwrap_or(0),
            revalidation_count: self
                .ongoing_revalidations
                .lock()
                .map(|r| r.len())
                .unwrap_or(0),
        }
    }

    /// Removes empty (or poisoned) subscriber sets and clears every in-progress
    /// revalidation marker.
    ///
    /// Clearing the markers lets `start_revalidation` succeed again for every key, even for
    /// revalidations that may still be running. Periodic tasks and refresh counters are left
    /// untouched.
    pub fn cleanup(&self) -> RefreshCleanupStats {
        let mut stats = RefreshCleanupStats::default();
        if let Ok(mut contexts) = self.reactive_contexts.lock() {
            let before = contexts.len();
            contexts.retain(|_, context_set| {
                context_set
                    .lock()
                    .map(|set| !set.is_empty())
                    .unwrap_or(false)
            });
            stats.contexts_removed = before - contexts.len();
        }
        if let Ok(mut revalidations) = self.ongoing_revalidations.lock() {
            stats.revalidations_cleared = revalidations.len();
            revalidations.clear();
        }
        stats
    }
}
/// Point-in-time snapshot of a `RefreshRegistry`'s bookkeeping sizes, as returned by
/// `RefreshRegistry::stats`.
///
/// Derives `PartialEq`/`Eq` so that snapshots can be compared directly, e.g. in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RefreshRegistryStats {
    /// Number of keys that have a refresh counter (not the sum of the counts).
    pub refresh_count: usize,
    /// Number of keys with a subscriber-set entry.
    pub context_count: usize,
    /// Number of registered periodic tasks.
    pub task_count: usize,
    /// Number of keys whose revalidation is marked as in progress.
    pub revalidation_count: usize,
}
/// Summary of what `RefreshRegistry::cleanup` removed.
///
/// Derives `PartialEq`/`Eq` so that results can be compared directly, e.g. in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RefreshCleanupStats {
    /// Number of per-key subscriber entries dropped because their set was empty or poisoned.
    pub contexts_removed: usize,
    /// Number of in-progress revalidation markers that were cleared.
    pub revalidations_cleared: usize,
}