elif_orm/loading/batch_loader/
mod.rs

1use crate::{
2    error::{OrmError, OrmResult},
3    model::Model,
4    query::QueryBuilder,
5};
6use serde_json::Value as JsonValue;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
/// Loader configuration (`BatchConfig`) and cache statistics (`CacheStats`).
pub mod config;
/// Row-to-JSON conversion support (presumably provides the `row_to_json`
/// method used by the loader below — TODO confirm).
pub mod row_conversion;

/// Re-export the commonly used configuration types at the loader root.
pub use config::{BatchConfig, CacheStats};
16
/// Result of a batch load operation.
///
/// NOTE(review): not constructed anywhere in this module — presumably built
/// by callers that aggregate several `load_*` calls; verify against usage.
#[derive(Debug)]
pub struct BatchLoadResult {
    /// Loaded records grouped by model type and ID (outer key is presumably
    /// the model/table name, inner key the record ID — TODO confirm).
    pub records: HashMap<String, HashMap<Value, JsonValue>>,
    /// Number of queries executed
    pub query_count: usize,
    /// Total records loaded
    pub record_count: usize,
}
27
/// Batch loader for efficient relationship loading.
///
/// Cloning is cheap: the query cache lives behind an [`Arc`], so all clones
/// share one cache.
#[derive(Clone)]
pub struct BatchLoader {
    /// Tuning options (`max_batch_size`, `deduplicate_queries`, …).
    config: BatchConfig,
    /// Raw query results keyed by a request-identifying string; guarded by an
    /// async `RwLock` and shared across clones.
    query_cache: Arc<RwLock<HashMap<String, Vec<JsonValue>>>>,
}
34
35impl Default for BatchLoader {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl BatchLoader {
42    /// Create a new batch loader with default configuration
43    pub fn new() -> Self {
44        Self::with_config(BatchConfig::default())
45    }
46
47    /// Create a new batch loader with custom configuration
48    pub fn with_config(config: BatchConfig) -> Self {
49        Self {
50            config,
51            query_cache: Arc::new(RwLock::new(HashMap::new())),
52        }
53    }
54
55    /// Load multiple records in batches
56    pub async fn load_batch<M: Model>(
57        &self,
58        ids: Vec<Value>,
59        table: &str,
60        connection: &sqlx::PgPool,
61    ) -> OrmResult<Vec<JsonValue>> {
62        if ids.is_empty() {
63            return Ok(Vec::new());
64        }
65
66        let mut all_results = Vec::new();
67        let chunks: Vec<_> = ids.chunks(self.config.max_batch_size).collect();
68
69        for chunk in chunks {
70            let results = self.execute_batch_query(chunk, table, connection).await?;
71            all_results.extend(results);
72        }
73
74        Ok(all_results)
75    }
76
77    /// Execute a single batch query
78    async fn execute_batch_query(
79        &self,
80        ids: &[Value],
81        table: &str,
82        connection: &sqlx::PgPool,
83    ) -> OrmResult<Vec<JsonValue>> {
84        // Build batch query using ANY() for efficient IN clause
85        let id_values: Vec<String> = ids
86            .iter()
87            .enumerate()
88            .map(|(i, _)| format!("${}", i + 1))
89            .collect();
90
91        let query = QueryBuilder::<()>::new()
92            .from(table)
93            .where_raw(&format!("id = ANY(ARRAY[{}])", id_values.join(", ")));
94
95        let (sql, _params) = query.to_sql_with_params();
96        let mut db_query = sqlx::query(&sql);
97
98        // Bind all ID values
99        for id in ids {
100            db_query = match id {
101                Value::Null => db_query.bind(None::<i32>),
102                Value::Bool(b) => db_query.bind(b),
103                Value::Number(n) => {
104                    if let Some(i) = n.as_i64() {
105                        db_query.bind(i)
106                    } else if let Some(f) = n.as_f64() {
107                        db_query.bind(f)
108                    } else {
109                        return Err(OrmError::Query("Invalid number type".into()));
110                    }
111                }
112                Value::String(s) => db_query.bind(s.as_str()),
113                _ => return Err(OrmError::Query("Unsupported ID type".into())),
114            };
115        }
116
117        let rows = db_query
118            .fetch_all(connection)
119            .await
120            .map_err(|e| OrmError::Database(format!("Batch query failed: {}", e)))?;
121
122        // Convert rows to JSON values
123        let mut results = Vec::new();
124        for row in rows {
125            let json_row = self
126                .row_to_json(&row)
127                .map_err(|e| OrmError::Database(format!("Failed to convert row to JSON: {}", e)))?;
128            results.push(json_row);
129        }
130
131        Ok(results)
132    }
133
134    /// Load relationships in batches with deduplication
135    pub async fn load_relationships(
136        &self,
137        parent_type: &str,
138        parent_ids: Vec<Value>,
139        relationship_name: &str,
140        foreign_key: &str,
141        related_table: &str,
142        connection: &sqlx::PgPool,
143    ) -> OrmResult<HashMap<Value, Vec<JsonValue>>> {
144        if parent_ids.is_empty() {
145            return Ok(HashMap::new());
146        }
147
148        // Check cache if deduplication is enabled
149        let cache_key = format!("{}:{}:{:?}", parent_type, relationship_name, parent_ids);
150
151        if self.config.deduplicate_queries {
152            let cache = self.query_cache.read().await;
153            if let Some(cached_results) = cache.get(&cache_key) {
154                return self.group_by_parent_id(cached_results.clone(), foreign_key, &parent_ids);
155            }
156        }
157
158        // Use ANY() for efficient IN clause
159        let parent_id_values: Vec<String> = parent_ids
160            .iter()
161            .enumerate()
162            .map(|(i, _)| format!("${}", i + 1))
163            .collect();
164
165        // Execute batch query for relationships
166        let query = QueryBuilder::<()>::new()
167            .from(related_table)
168            .where_raw(&format!(
169                "{} = ANY(ARRAY[{}])",
170                foreign_key,
171                parent_id_values.join(", ")
172            ));
173
174        let (sql, _params) = query.to_sql_with_params();
175        let mut db_query = sqlx::query(&sql);
176
177        // Bind parent IDs
178        for parent_id in &parent_ids {
179            db_query = match parent_id {
180                Value::Null => db_query.bind(None::<i32>),
181                Value::Bool(b) => db_query.bind(b),
182                Value::Number(n) => {
183                    if let Some(i) = n.as_i64() {
184                        db_query.bind(i)
185                    } else if let Some(f) = n.as_f64() {
186                        db_query.bind(f)
187                    } else {
188                        return Err(OrmError::Query("Invalid number type".into()));
189                    }
190                }
191                Value::String(s) => db_query.bind(s.as_str()),
192                _ => return Err(OrmError::Query("Unsupported ID type".into())),
193            };
194        }
195
196        let rows = db_query
197            .fetch_all(connection)
198            .await
199            .map_err(|e| OrmError::Database(format!("Relationship batch query failed: {}", e)))?;
200
201        // Convert to JSON values
202        let mut results = Vec::new();
203        for row in rows {
204            let json_row = self
205                .row_to_json(&row)
206                .map_err(|e| OrmError::Database(format!("Failed to convert row to JSON: {}", e)))?;
207            results.push(json_row);
208        }
209
210        // Cache results if deduplication is enabled
211        if self.config.deduplicate_queries {
212            let mut cache = self.query_cache.write().await;
213            cache.insert(cache_key, results.clone());
214        }
215
216        // Group results by parent ID
217        self.group_by_parent_id(results, foreign_key, &parent_ids)
218    }
219
220    /// Load nested relationships with deep optimization
221    pub async fn load_nested_relationships(
222        &self,
223        _root_table: &str,
224        root_ids: Vec<Value>,
225        relationship_path: &str,
226        connection: &sqlx::PgPool,
227    ) -> OrmResult<HashMap<Value, JsonValue>> {
228        if root_ids.is_empty() || relationship_path.is_empty() {
229            return Ok(HashMap::new());
230        }
231
232        // Parse relationship path (e.g., "posts.comments.user")
233        let relations: Vec<&str> = relationship_path.split('.').collect();
234        let mut current_ids = root_ids.clone();
235        let mut results: HashMap<Value, JsonValue> = HashMap::new();
236
237        // Process each level of nesting
238        for (depth, relation) in relations.iter().enumerate() {
239            if current_ids.is_empty() {
240                break;
241            }
242
243            // Determine table and foreign key based on relationship type
244            let (related_table, foreign_key) = self.get_relationship_mapping(relation)?;
245
246            // Load current level relationships in optimized batches
247            let level_results = self
248                .load_relationships_optimized(
249                    &format!("level_{}", depth),
250                    current_ids,
251                    relation,
252                    &foreign_key,
253                    &related_table,
254                    connection,
255                )
256                .await?;
257
258            // Update current IDs for next level
259            current_ids = level_results
260                .values()
261                .flatten()
262                .filter_map(|record| record.get("id").cloned())
263                .collect();
264
265            // Merge results with proper nesting
266            self.merge_nested_results(&mut results, level_results, depth == 0);
267        }
268
269        Ok(results)
270    }
271
272    /// Load relationships with advanced optimization strategies
273    async fn load_relationships_optimized(
274        &self,
275        parent_type: &str,
276        parent_ids: Vec<Value>,
277        relationship_name: &str,
278        foreign_key: &str,
279        related_table: &str,
280        connection: &sqlx::PgPool,
281    ) -> OrmResult<HashMap<Value, Vec<JsonValue>>> {
282        // Use smaller batch sizes for nested queries to avoid memory issues
283        let optimal_batch_size = std::cmp::min(self.config.max_batch_size, 50);
284        let mut all_results: HashMap<Value, Vec<JsonValue>> = HashMap::new();
285
286        // Process in optimized chunks
287        for chunk in parent_ids.chunks(optimal_batch_size) {
288            let chunk_results = self
289                .load_relationships(
290                    parent_type,
291                    chunk.to_vec(),
292                    relationship_name,
293                    foreign_key,
294                    related_table,
295                    connection,
296                )
297                .await?;
298
299            // Merge chunk results
300            for (parent_id, relations) in chunk_results {
301                all_results.entry(parent_id).or_default().extend(relations);
302            }
303        }
304
305        Ok(all_results)
306    }
307
308    /// Get relationship mapping for a relation name
309    fn get_relationship_mapping(&self, relation: &str) -> OrmResult<(String, String)> {
310        // This would normally use relationship metadata
311        // For now, we'll use convention-based mapping
312        match relation {
313            "posts" => Ok(("posts".to_string(), "user_id".to_string())),
314            "comments" => Ok(("comments".to_string(), "post_id".to_string())),
315            "user" => Ok(("users".to_string(), "user_id".to_string())),
316            "profile" => Ok(("profiles".to_string(), "user_id".to_string())),
317            _ => Ok((format!("{}s", relation), format!("{}_id", relation))),
318        }
319    }
320
321    /// Merge nested results with proper hierarchical structure
322    fn merge_nested_results(
323        &self,
324        target: &mut HashMap<Value, JsonValue>,
325        source: HashMap<Value, Vec<JsonValue>>,
326        is_root: bool,
327    ) {
328        for (parent_id, relations) in source {
329            if is_root {
330                // For root level, create the initial structure
331                let parent_id_copy = parent_id.clone();
332                target.insert(
333                    parent_id,
334                    serde_json::json!({
335                        "id": parent_id_copy,
336                        "relations": relations
337                    }),
338                );
339            } else {
340                // For nested levels, update existing structure
341                if let Some(existing) = target.get_mut(&parent_id) {
342                    if let Some(obj) = existing.as_object_mut() {
343                        obj.insert("nested_relations".to_string(), serde_json::json!(relations));
344                    }
345                }
346            }
347        }
348    }
349
350    /// Group results by parent ID
351    fn group_by_parent_id(
352        &self,
353        results: Vec<JsonValue>,
354        foreign_key: &str,
355        parent_ids: &[Value],
356    ) -> OrmResult<HashMap<Value, Vec<JsonValue>>> {
357        let mut grouped: HashMap<Value, Vec<JsonValue>> = HashMap::new();
358
359        // Initialize with empty vecs for all parent IDs
360        for parent_id in parent_ids {
361            grouped.insert(parent_id.clone(), Vec::new());
362        }
363
364        // Group results by foreign key value
365        for result in results {
366            if let Some(fk_value) = result.get(foreign_key) {
367                let parent_id = serde_json::from_value(fk_value.clone()).unwrap_or(Value::Null);
368
369                grouped.entry(parent_id).or_default().push(result);
370            }
371        }
372
373        Ok(grouped)
374    }
375
376    /// Clear the query cache
377    pub async fn clear_cache(&self) {
378        let mut cache = self.query_cache.write().await;
379        cache.clear();
380    }
381
382    /// Get cache statistics
383    pub async fn cache_stats(&self) -> CacheStats {
384        let cache = self.query_cache.read().await;
385        CacheStats {
386            cached_queries: cache.len(),
387            total_cached_records: cache.values().map(|v| v.len()).sum(),
388        }
389    }
390}
391
392#[cfg(test)]
393mod tests;