use crate::error::{OrmError, OrmResult};
use serde_json::Value as JsonValue;
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
/// Identity of a query for deduplication purposes.
///
/// `PartialEq`, `Eq` and `Hash` are implemented manually. Hashing visits
/// `conditions` in sorted key order so that it agrees with the (order-independent)
/// `HashMap` equality used by `PartialEq`.
#[derive(Debug, Clone)]
pub struct QueryKey {
    /// Table the query targets.
    pub table: String,
    /// Kind of query, e.g. `"relationship"` or `"batch_select"` as set by the constructors.
    pub query_type: String,
    /// Condition name (such as a foreign-key column or `"id"`) mapped to its values.
    pub conditions: HashMap<String, Vec<Value>>,
}
impl QueryKey {
    /// Builds the key for a `"relationship"` query on `table`, with a single
    /// condition mapping `foreign_key` to `parent_ids`.
    pub fn relationship(table: &str, foreign_key: &str, parent_ids: &[Value]) -> Self {
        Self::single_condition(table, "relationship", foreign_key, parent_ids)
    }

    /// Builds the key for a `"batch_select"` query on `table`, with a single
    /// condition mapping `"id"` to `ids`.
    pub fn batch_select(table: &str, ids: &[Value]) -> Self {
        Self::single_condition(table, "batch_select", "id", ids)
    }

    /// Shared constructor for keys that carry exactly one condition.
    fn single_condition(table: &str, query_type: &str, column: &str, values: &[Value]) -> Self {
        let conditions: HashMap<String, Vec<Value>> =
            std::iter::once((column.to_owned(), values.to_vec())).collect();
        Self {
            table: table.to_owned(),
            query_type: query_type.to_owned(),
            conditions,
        }
    }
}
impl PartialEq for QueryKey {
    /// Keys are equal when table, query type and the full condition map all match.
    /// `HashMap` equality ignores insertion/iteration order.
    fn eq(&self, other: &Self) -> bool {
        (&self.table, &self.query_type, &self.conditions)
            == (&other.table, &other.query_type, &other.conditions)
    }
}

// Marker impl: the field-wise equality above is treated as a total equivalence.
impl Eq for QueryKey {}
impl Hash for QueryKey {
fn hash<H: Hasher>(&self, state: &mut H) {
self.table.hash(state);
self.query_type.hash(state);
let mut sorted_conditions: Vec<_> = self.conditions.iter().collect();
sorted_conditions.sort_by_key(|(k, _)| k.as_str());
for (key, values) in sorted_conditions {
key.hash(state);
for value in values {
serde_json::to_string(value).unwrap_or_default().hash(state);
}
}
}
}
/// Bookkeeping for one in-flight query, stored in `QueryDeduplicator::pending_queries`.
#[derive(Debug)]
struct PendingQuery {
    // Shared slot in which the executing caller stores its outcome; `None` until then.
    result: Arc<Mutex<Option<OrmResult<Vec<JsonValue>>>>>,
    // Callers attached to this entry, including the executing one (starts at 1).
    waiter_count: usize,
}
/// Coalesces concurrent executions of identical queries (as identified by
/// `QueryKey`) so that one caller runs the query and the others share its result.
pub struct QueryDeduplicator {
    // In-flight queries by key: an entry is inserted by the caller that executes the
    // query and removed when that execution finishes.
    pending_queries: Arc<RwLock<HashMap<QueryKey, PendingQuery>>>,
    // Running counters exposed through `stats()` / `reset_stats()`.
    stats: Arc<RwLock<DeduplicationStats>>,
}
impl Default for QueryDeduplicator {
fn default() -> Self {
Self::new()
}
}
impl QueryDeduplicator {
pub fn new() -> Self {
Self {
pending_queries: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(DeduplicationStats::default())),
}
}
pub async fn execute_deduplicated<F, Fut>(
&self,
query_key: QueryKey,
execute_fn: F,
) -> OrmResult<Vec<JsonValue>>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = OrmResult<Vec<JsonValue>>>,
{
{
let mut pending = self.pending_queries.write().await;
if let Some(pending_query) = pending.get_mut(&query_key) {
pending_query.waiter_count += 1;
let result_mutex = pending_query.result.clone();
let mut stats = self.stats.write().await;
stats.queries_deduplicated += 1;
drop(stats);
drop(pending);
let mut result_guard = result_mutex.lock().await;
while result_guard.is_none() {
drop(result_guard);
tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
result_guard = result_mutex.lock().await;
}
result_guard
.as_ref()
.unwrap()
.clone()
.map_err(|e| OrmError::Query(e.to_string()))
} else {
let result_mutex = Arc::new(Mutex::new(None));
pending.insert(
query_key.clone(),
PendingQuery {
result: result_mutex.clone(),
waiter_count: 1,
},
);
let mut stats = self.stats.write().await;
stats.unique_queries_executed += 1;
drop(stats);
drop(pending);
let result = execute_fn().await;
let mut pending = self.pending_queries.write().await;
if let Some(pending_query) = pending.get(&query_key) {
let mut result_guard = pending_query.result.lock().await;
*result_guard = Some(result.clone());
}
pending.remove(&query_key);
result
}
}
}
pub async fn stats(&self) -> DeduplicationStats {
self.stats.read().await.clone()
}
pub async fn reset_stats(&self) {
let mut stats = self.stats.write().await;
*stats = DeduplicationStats::default();
}
pub async fn has_pending_queries(&self) -> bool {
!self.pending_queries.read().await.is_empty()
}
pub async fn pending_query_count(&self) -> usize {
self.pending_queries.read().await.len()
}
}
/// Counters describing how effective query deduplication has been.
#[derive(Debug, Clone, Default)]
pub struct DeduplicationStats {
    /// Queries that were executed rather than shared.
    pub unique_queries_executed: usize,
    /// Calls that reused another call's in-flight result.
    pub queries_deduplicated: usize,
    /// Additional counter. It is not included in `total_queries`,
    /// `deduplication_ratio`, or the `Display` output.
    pub queries_saved: usize,
}

impl DeduplicationStats {
    /// Fraction (in `0.0..=1.0`) of all counted queries that were deduplicated.
    /// Returns `0.0` when nothing has been counted yet.
    pub fn deduplication_ratio(&self) -> f64 {
        match self.total_queries() {
            0 => 0.0,
            total => self.queries_deduplicated as f64 / total as f64,
        }
    }

    /// Executed plus deduplicated queries.
    pub fn total_queries(&self) -> usize {
        self.queries_deduplicated + self.unique_queries_executed
    }
}

impl Display for DeduplicationStats {
    /// One-line summary; the dedup rate is shown as a percentage with one decimal.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let dedup_percent = 100.0 * self.deduplication_ratio();
        write!(
            f,
            "QueryDeduplicator Stats: {} unique queries, {} deduplicated ({:.1}% dedup rate)",
            self.unique_queries_executed, self.queries_deduplicated, dedup_percent
        )
    }
}