// elif_orm/loading/query_deduplicator.rs

1use crate::{
2    error::{OrmError, OrmResult},
3};
4use serde_json::Value as JsonValue;
5use serde_json::Value;
6use std::collections::{HashMap, HashSet};
7use std::fmt::Display;
8use std::hash::{Hash, Hasher};
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock};
11
12/// Represents a unique query that can be deduplicated
13#[derive(Debug, Clone)]
14pub struct QueryKey {
15    /// Table being queried
16    pub table: String,
17    /// Type of query (e.g., "select", "relationship")
18    pub query_type: String,
19    /// Conditions or parameters that make this query unique
20    pub conditions: HashMap<String, Vec<Value>>,
21}
22
23impl QueryKey {
24    /// Create a new query key for a relationship query
25    pub fn relationship(
26        table: &str,
27        foreign_key: &str,
28        parent_ids: &[Value],
29    ) -> Self {
30        let mut conditions = HashMap::new();
31        conditions.insert(foreign_key.to_string(), parent_ids.to_vec());
32        
33        Self {
34            table: table.to_string(),
35            query_type: "relationship".to_string(),
36            conditions,
37        }
38    }
39
40    /// Create a new query key for a batch select query
41    pub fn batch_select(table: &str, ids: &[Value]) -> Self {
42        let mut conditions = HashMap::new();
43        conditions.insert("id".to_string(), ids.to_vec());
44        
45        Self {
46            table: table.to_string(),
47            query_type: "batch_select".to_string(),
48            conditions,
49        }
50    }
51}
52
53impl PartialEq for QueryKey {
54    fn eq(&self, other: &Self) -> bool {
55        self.table == other.table
56            && self.query_type == other.query_type
57            && self.conditions == other.conditions
58    }
59}
60
61impl Eq for QueryKey {}
62
63impl Hash for QueryKey {
64    fn hash<H: Hasher>(&self, state: &mut H) {
65        self.table.hash(state);
66        self.query_type.hash(state);
67        
68        // Sort conditions for consistent hashing
69        let mut sorted_conditions: Vec<_> = self.conditions.iter().collect();
70        sorted_conditions.sort_by_key(|(k, _)| k.as_str());
71        
72        for (key, values) in sorted_conditions {
73            key.hash(state);
74            for value in values {
75                // Hash the JSON representation for consistency
76                serde_json::to_string(value).unwrap_or_default().hash(state);
77            }
78        }
79    }
80}
81
82/// Tracks pending queries to enable deduplication
83#[derive(Debug)]
84struct PendingQuery {
85    /// The result future that will be shared among all waiters
86    result: Arc<Mutex<Option<OrmResult<Vec<JsonValue>>>>>,
87    /// Number of requests waiting for this query
88    waiter_count: usize,
89}
90
91/// Query deduplicator that prevents executing identical queries multiple times
92pub struct QueryDeduplicator {
93    /// Currently executing queries
94    pending_queries: Arc<RwLock<HashMap<QueryKey, PendingQuery>>>,
95    /// Statistics about deduplication
96    stats: Arc<RwLock<DeduplicationStats>>,
97}
98
99impl QueryDeduplicator {
100    /// Create a new query deduplicator
101    pub fn new() -> Self {
102        Self {
103            pending_queries: Arc::new(RwLock::new(HashMap::new())),
104            stats: Arc::new(RwLock::new(DeduplicationStats::default())),
105        }
106    }
107
108    /// Execute a query with deduplication
109    /// If an identical query is already running, wait for its result instead
110    pub async fn execute_deduplicated<F, Fut>(
111        &self,
112        query_key: QueryKey,
113        execute_fn: F,
114    ) -> OrmResult<Vec<JsonValue>>
115    where
116        F: FnOnce() -> Fut,
117        Fut: std::future::Future<Output = OrmResult<Vec<JsonValue>>>,
118    {
119        // Check if query is already pending
120        {
121            let mut pending = self.pending_queries.write().await;
122            if let Some(pending_query) = pending.get_mut(&query_key) {
123                // Query is already running, increment waiter count
124                pending_query.waiter_count += 1;
125                let result_mutex = pending_query.result.clone();
126                
127                // Update stats
128                let mut stats = self.stats.write().await;
129                stats.queries_deduplicated += 1;
130                drop(stats);
131                drop(pending);
132                
133                // Wait for the result
134                let mut result_guard = result_mutex.lock().await;
135                while result_guard.is_none() {
136                    // Release lock and wait
137                    drop(result_guard);
138                    tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
139                    result_guard = result_mutex.lock().await;
140                }
141                
142                // Clone the result and return
143                return result_guard
144                    .as_ref()
145                    .unwrap()
146                    .as_ref()
147                    .map(|v| v.clone())
148                    .map_err(|e| OrmError::Query(e.to_string()));
149            } else {
150                // New query, add to pending
151                let result_mutex = Arc::new(Mutex::new(None));
152                pending.insert(
153                    query_key.clone(),
154                    PendingQuery {
155                        result: result_mutex.clone(),
156                        waiter_count: 1,
157                    },
158                );
159                
160                // Update stats
161                let mut stats = self.stats.write().await;
162                stats.unique_queries_executed += 1;
163                drop(stats);
164                drop(pending);
165                
166                // Execute the query
167                let result = execute_fn().await;
168                
169                // Store result and clean up
170                let mut pending = self.pending_queries.write().await;
171                if let Some(pending_query) = pending.get(&query_key) {
172                    let mut result_guard = pending_query.result.lock().await;
173                    *result_guard = Some(result.clone());
174                }
175                pending.remove(&query_key);
176                
177                return result;
178            }
179        }
180    }
181
182    /// Get deduplication statistics
183    pub async fn stats(&self) -> DeduplicationStats {
184        self.stats.read().await.clone()
185    }
186
187    /// Reset statistics
188    pub async fn reset_stats(&self) {
189        let mut stats = self.stats.write().await;
190        *stats = DeduplicationStats::default();
191    }
192
193    /// Check if any queries are currently pending
194    pub async fn has_pending_queries(&self) -> bool {
195        !self.pending_queries.read().await.is_empty()
196    }
197
198    /// Get the number of pending queries
199    pub async fn pending_query_count(&self) -> usize {
200        self.pending_queries.read().await.len()
201    }
202}
203
/// Counters describing the activity of a [`QueryDeduplicator`].
#[derive(Debug, Clone, Default)]
pub struct DeduplicationStats {
    /// How many queries were actually executed.
    pub unique_queries_executed: usize,
    /// How many calls were answered by an identical query that was already running.
    pub queries_deduplicated: usize,
    /// Total queries saved by deduplication. This counter is stored on its own and is
    /// not derived from the other two.
    pub queries_saved: usize,
}

impl DeduplicationStats {
    /// Fraction of processed calls that were deduplicated, in `0.0..=1.0`.
    /// Returns `0.0` when nothing has been processed yet.
    pub fn deduplication_ratio(&self) -> f64 {
        match self.total_queries() {
            0 => 0.0,
            total => self.queries_deduplicated as f64 / total as f64,
        }
    }

    /// Total calls processed: executed queries plus deduplicated ones.
    pub fn total_queries(&self) -> usize {
        self.unique_queries_executed + self.queries_deduplicated
    }
}

impl Display for DeduplicationStats {
    /// Render a one-line summary; the dedup rate is a percentage with one decimal place.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rate_percent = self.deduplication_ratio() * 100.0;
        write!(
            f,
            "QueryDeduplicator Stats: {} unique queries, {} deduplicated ({:.1}% dedup rate)",
            self.unique_queries_executed, self.queries_deduplicated, rate_percent
        )
    }
}
243