// elif_orm/loading/query_deduplicator.rs
use crate::{
    error::{OrmError, OrmResult},
};
use serde_json::Value as JsonValue;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use tokio::sync::{watch, Mutex, RwLock};
12#[derive(Debug, Clone)]
14pub struct QueryKey {
15 pub table: String,
17 pub query_type: String,
19 pub conditions: HashMap<String, Vec<Value>>,
21}
22
23impl QueryKey {
24 pub fn relationship(
26 table: &str,
27 foreign_key: &str,
28 parent_ids: &[Value],
29 ) -> Self {
30 let mut conditions = HashMap::new();
31 conditions.insert(foreign_key.to_string(), parent_ids.to_vec());
32
33 Self {
34 table: table.to_string(),
35 query_type: "relationship".to_string(),
36 conditions,
37 }
38 }
39
40 pub fn batch_select(table: &str, ids: &[Value]) -> Self {
42 let mut conditions = HashMap::new();
43 conditions.insert("id".to_string(), ids.to_vec());
44
45 Self {
46 table: table.to_string(),
47 query_type: "batch_select".to_string(),
48 conditions,
49 }
50 }
51}
52
53impl PartialEq for QueryKey {
54 fn eq(&self, other: &Self) -> bool {
55 self.table == other.table
56 && self.query_type == other.query_type
57 && self.conditions == other.conditions
58 }
59}
60
61impl Eq for QueryKey {}
62
63impl Hash for QueryKey {
64 fn hash<H: Hasher>(&self, state: &mut H) {
65 self.table.hash(state);
66 self.query_type.hash(state);
67
68 let mut sorted_conditions: Vec<_> = self.conditions.iter().collect();
70 sorted_conditions.sort_by_key(|(k, _)| k.as_str());
71
72 for (key, values) in sorted_conditions {
73 key.hash(state);
74 for value in values {
75 serde_json::to_string(value).unwrap_or_default().hash(state);
77 }
78 }
79 }
80}
81
82#[derive(Debug)]
84struct PendingQuery {
85 result: Arc<Mutex<Option<OrmResult<Vec<JsonValue>>>>>,
87 waiter_count: usize,
89}
90
91pub struct QueryDeduplicator {
93 pending_queries: Arc<RwLock<HashMap<QueryKey, PendingQuery>>>,
95 stats: Arc<RwLock<DeduplicationStats>>,
97}
98
99impl QueryDeduplicator {
100 pub fn new() -> Self {
102 Self {
103 pending_queries: Arc::new(RwLock::new(HashMap::new())),
104 stats: Arc::new(RwLock::new(DeduplicationStats::default())),
105 }
106 }
107
108 pub async fn execute_deduplicated<F, Fut>(
111 &self,
112 query_key: QueryKey,
113 execute_fn: F,
114 ) -> OrmResult<Vec<JsonValue>>
115 where
116 F: FnOnce() -> Fut,
117 Fut: std::future::Future<Output = OrmResult<Vec<JsonValue>>>,
118 {
119 {
121 let mut pending = self.pending_queries.write().await;
122 if let Some(pending_query) = pending.get_mut(&query_key) {
123 pending_query.waiter_count += 1;
125 let result_mutex = pending_query.result.clone();
126
127 let mut stats = self.stats.write().await;
129 stats.queries_deduplicated += 1;
130 drop(stats);
131 drop(pending);
132
133 let mut result_guard = result_mutex.lock().await;
135 while result_guard.is_none() {
136 drop(result_guard);
138 tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
139 result_guard = result_mutex.lock().await;
140 }
141
142 return result_guard
144 .as_ref()
145 .unwrap()
146 .as_ref()
147 .map(|v| v.clone())
148 .map_err(|e| OrmError::Query(e.to_string()));
149 } else {
150 let result_mutex = Arc::new(Mutex::new(None));
152 pending.insert(
153 query_key.clone(),
154 PendingQuery {
155 result: result_mutex.clone(),
156 waiter_count: 1,
157 },
158 );
159
160 let mut stats = self.stats.write().await;
162 stats.unique_queries_executed += 1;
163 drop(stats);
164 drop(pending);
165
166 let result = execute_fn().await;
168
169 let mut pending = self.pending_queries.write().await;
171 if let Some(pending_query) = pending.get(&query_key) {
172 let mut result_guard = pending_query.result.lock().await;
173 *result_guard = Some(result.clone());
174 }
175 pending.remove(&query_key);
176
177 return result;
178 }
179 }
180 }
181
182 pub async fn stats(&self) -> DeduplicationStats {
184 self.stats.read().await.clone()
185 }
186
187 pub async fn reset_stats(&self) {
189 let mut stats = self.stats.write().await;
190 *stats = DeduplicationStats::default();
191 }
192
193 pub async fn has_pending_queries(&self) -> bool {
195 !self.pending_queries.read().await.is_empty()
196 }
197
198 pub async fn pending_query_count(&self) -> usize {
200 self.pending_queries.read().await.len()
201 }
202}
203
/// Counters describing how effective query deduplication has been.
#[derive(Debug, Clone, Default)]
pub struct DeduplicationStats {
    /// Calls that ran their query themselves (one per distinct in-flight key).
    pub unique_queries_executed: usize,
    /// Calls that reused the result of an already in-flight query.
    pub queries_deduplicated: usize,
    /// NOTE(review): `QueryDeduplicator` never increments this counter; confirm
    /// whether anything else is meant to maintain it.
    pub queries_saved: usize,
}

impl DeduplicationStats {
    /// Returns the fraction of recorded calls that were deduplicated, or `0.0`
    /// when nothing has been recorded yet.
    pub fn deduplication_ratio(&self) -> f64 {
        match self.total_queries() {
            0 => 0.0,
            total => self.queries_deduplicated as f64 / total as f64,
        }
    }

    /// Returns the number of recorded calls: executed plus deduplicated.
    pub fn total_queries(&self) -> usize {
        self.queries_deduplicated + self.unique_queries_executed
    }
}

impl Display for DeduplicationStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let percent = self.deduplication_ratio() * 100.0;
        write!(
            f,
            "QueryDeduplicator Stats: {} unique queries, {} deduplicated ({:.1}% dedup rate)",
            self.unique_queries_executed, self.queries_deduplicated, percent
        )
    }
}