// elif_orm/loading/query_deduplicator.rs
use crate::error::{OrmError, OrmResult};
use serde_json::Value as JsonValue;
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
9
10#[derive(Debug, Clone)]
12pub struct QueryKey {
13 pub table: String,
15 pub query_type: String,
17 pub conditions: HashMap<String, Vec<Value>>,
19}
20
21impl QueryKey {
22 pub fn relationship(table: &str, foreign_key: &str, parent_ids: &[Value]) -> Self {
24 let mut conditions = HashMap::new();
25 conditions.insert(foreign_key.to_string(), parent_ids.to_vec());
26
27 Self {
28 table: table.to_string(),
29 query_type: "relationship".to_string(),
30 conditions,
31 }
32 }
33
34 pub fn batch_select(table: &str, ids: &[Value]) -> Self {
36 let mut conditions = HashMap::new();
37 conditions.insert("id".to_string(), ids.to_vec());
38
39 Self {
40 table: table.to_string(),
41 query_type: "batch_select".to_string(),
42 conditions,
43 }
44 }
45}
46
47impl PartialEq for QueryKey {
48 fn eq(&self, other: &Self) -> bool {
49 self.table == other.table
50 && self.query_type == other.query_type
51 && self.conditions == other.conditions
52 }
53}
54
55impl Eq for QueryKey {}
56
57impl Hash for QueryKey {
58 fn hash<H: Hasher>(&self, state: &mut H) {
59 self.table.hash(state);
60 self.query_type.hash(state);
61
62 let mut sorted_conditions: Vec<_> = self.conditions.iter().collect();
64 sorted_conditions.sort_by_key(|(k, _)| k.as_str());
65
66 for (key, values) in sorted_conditions {
67 key.hash(state);
68 for value in values {
69 serde_json::to_string(value).unwrap_or_default().hash(state);
71 }
72 }
73 }
74}
75
76#[derive(Debug)]
78struct PendingQuery {
79 result: Arc<Mutex<Option<OrmResult<Vec<JsonValue>>>>>,
81 waiter_count: usize,
83}
84
85pub struct QueryDeduplicator {
87 pending_queries: Arc<RwLock<HashMap<QueryKey, PendingQuery>>>,
89 stats: Arc<RwLock<DeduplicationStats>>,
91}
92
93impl Default for QueryDeduplicator {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
99impl QueryDeduplicator {
100 pub fn new() -> Self {
102 Self {
103 pending_queries: Arc::new(RwLock::new(HashMap::new())),
104 stats: Arc::new(RwLock::new(DeduplicationStats::default())),
105 }
106 }
107
108 pub async fn execute_deduplicated<F, Fut>(
111 &self,
112 query_key: QueryKey,
113 execute_fn: F,
114 ) -> OrmResult<Vec<JsonValue>>
115 where
116 F: FnOnce() -> Fut,
117 Fut: std::future::Future<Output = OrmResult<Vec<JsonValue>>>,
118 {
119 {
121 let mut pending = self.pending_queries.write().await;
122 if let Some(pending_query) = pending.get_mut(&query_key) {
123 pending_query.waiter_count += 1;
125 let result_mutex = pending_query.result.clone();
126
127 let mut stats = self.stats.write().await;
129 stats.queries_deduplicated += 1;
130 drop(stats);
131 drop(pending);
132
133 let mut result_guard = result_mutex.lock().await;
135 while result_guard.is_none() {
136 drop(result_guard);
138 tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
139 result_guard = result_mutex.lock().await;
140 }
141
142 result_guard
144 .as_ref()
145 .unwrap()
146 .clone()
147 .map_err(|e| OrmError::Query(e.to_string()))
148 } else {
149 let result_mutex = Arc::new(Mutex::new(None));
151 pending.insert(
152 query_key.clone(),
153 PendingQuery {
154 result: result_mutex.clone(),
155 waiter_count: 1,
156 },
157 );
158
159 let mut stats = self.stats.write().await;
161 stats.unique_queries_executed += 1;
162 drop(stats);
163 drop(pending);
164
165 let result = execute_fn().await;
167
168 let mut pending = self.pending_queries.write().await;
170 if let Some(pending_query) = pending.get(&query_key) {
171 let mut result_guard = pending_query.result.lock().await;
172 *result_guard = Some(result.clone());
173 }
174 pending.remove(&query_key);
175
176 result
177 }
178 }
179 }
180
181 pub async fn stats(&self) -> DeduplicationStats {
183 self.stats.read().await.clone()
184 }
185
186 pub async fn reset_stats(&self) {
188 let mut stats = self.stats.write().await;
189 *stats = DeduplicationStats::default();
190 }
191
192 pub async fn has_pending_queries(&self) -> bool {
194 !self.pending_queries.read().await.is_empty()
195 }
196
197 pub async fn pending_query_count(&self) -> usize {
199 self.pending_queries.read().await.len()
200 }
201}
202
/// Counters collected by [`QueryDeduplicator`].
#[derive(Clone, Debug, Default)]
pub struct DeduplicationStats {
    /// Number of queries that were actually executed.
    pub unique_queries_executed: usize,
    /// Number of calls served by sharing another caller's in-flight execution.
    pub queries_deduplicated: usize,
    /// Number of executions recorded as saved through deduplication.
    pub queries_saved: usize,
}

impl DeduplicationStats {
    /// Fraction of all recorded calls that were deduplicated, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when nothing has been recorded yet.
    pub fn deduplication_ratio(&self) -> f64 {
        match self.total_queries() {
            0 => 0.0,
            total => self.queries_deduplicated as f64 / total as f64,
        }
    }

    /// Total calls recorded: executed plus deduplicated.
    pub fn total_queries(&self) -> usize {
        self.queries_deduplicated + self.unique_queries_executed
    }
}

impl Display for DeduplicationStats {
    /// Renders a one-line summary with the dedup rate as a percentage (one decimal).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rate_percent = self.deduplication_ratio() * 100.0;
        write!(
            f,
            "QueryDeduplicator Stats: {} unique queries, {} deduplicated ({:.1}% dedup rate)",
            self.unique_queries_executed, self.queries_deduplicated, rate_percent
        )
    }
}