elif_orm/loading/batch_loader/mod.rs
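//! Batch loading utilities for relationship resolution.
//!
//! `BatchLoader` splits large ID lists into chunks of at most
//! `BatchConfig::max_batch_size`, issues one `ANY(ARRAY[...])` query per chunk,
//! and can deduplicate repeated relationship queries through an in-memory cache
//! guarded by a `tokio::sync::RwLock`.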

use crate::{
    error::{OrmError, OrmResult},
    model::Model,
    query::QueryBuilder,
};
use serde_json::{Value, Value as JsonValue};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

pub mod config;
pub mod row_conversion;

pub use config::{BatchConfig, CacheStats};

/// Result of a batch load operation
#[derive(Debug)]
pub struct BatchLoadResult {
    /// Loaded records grouped by model type and ID
    pub records: HashMap<String, HashMap<Value, JsonValue>>,
    /// Number of queries executed
    pub query_count: usize,
    /// Total records loaded
    pub record_count: usize,
}

/// Batch loader for efficient relationship loading
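///
/// A minimal usage sketch (not a compiled doctest); it assumes a `User` type that
/// implements `Model` and an already-connected `sqlx::PgPool`, neither of which is
/// defined in this module.
///
/// ```ignore
/// let loader = BatchLoader::with_config(BatchConfig::default());
/// let ids = vec![serde_json::json!(1), serde_json::json!(2)];
/// let users = loader.load_batch::<User>(ids, "users", &pool).await?;
/// assert!(users.len() <= 2);
/// ```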
#[derive(Clone)]
pub struct BatchLoader {
    config: BatchConfig,
    query_cache: Arc<RwLock<HashMap<String, Vec<JsonValue>>>>,
}

impl BatchLoader {
    /// Create a new batch loader with default configuration
    pub fn new() -> Self {
        Self::with_config(BatchConfig::default())
    }

    /// Create a new batch loader with custom configuration
    pub fn with_config(config: BatchConfig) -> Self {
        Self {
            config,
            query_cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Load multiple records in batches
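    ///
    /// IDs are split into chunks of at most `BatchConfig::max_batch_size` and one
    /// query is executed per chunk; an empty `ids` vector returns an empty result
    /// without touching the database.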
    pub async fn load_batch<M: Model>(
        &self,
        ids: Vec<Value>,
        table: &str,
        connection: &sqlx::PgPool,
    ) -> OrmResult<Vec<JsonValue>> {
        if ids.is_empty() {
            return Ok(Vec::new());
        }

        let mut all_results = Vec::new();

        for chunk in ids.chunks(self.config.max_batch_size) {
            let results = self.execute_batch_query(chunk, table, connection).await?;
            all_results.extend(results);
        }

        Ok(all_results)
    }

    /// Execute a single batch query
    async fn execute_batch_query(
        &self,
        ids: &[Value],
        table: &str,
        connection: &sqlx::PgPool,
    ) -> OrmResult<Vec<JsonValue>> {
        // Build batch query using ANY() for efficient IN clause
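        // For three IDs the generated clause has the shape
        //   id = ANY(ARRAY[$1, $2, $3])
        // with one positional bind appended per ID below.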
        let id_values: Vec<String> = ids
            .iter()
            .enumerate()
            .map(|(i, _)| format!("${}", i + 1))
            .collect();

        let query = QueryBuilder::<()>::new()
            .from(table)
            .where_raw(&format!("id = ANY(ARRAY[{}])", id_values.join(", ")));

        let (sql, _params) = query.to_sql_with_params();
        let mut db_query = sqlx::query(&sql);

        // Bind all ID values
        for id in ids {
            db_query = match id {
                Value::Null => db_query.bind(None::<i32>),
                Value::Bool(b) => db_query.bind(b),
                Value::Number(n) => {
                    if let Some(i) = n.as_i64() {
                        db_query.bind(i)
                    } else if let Some(f) = n.as_f64() {
                        db_query.bind(f)
                    } else {
                        return Err(OrmError::Query("Invalid number type".into()));
                    }
                }
                Value::String(s) => db_query.bind(s.as_str()),
                _ => return Err(OrmError::Query("Unsupported ID type".into())),
            };
        }

        let rows = db_query.fetch_all(connection).await.map_err(|e| {
            OrmError::Database(format!("Batch query failed: {}", e))
        })?;

        // Convert rows to JSON values
        let mut results = Vec::new();
        for row in rows {
            let json_row = self.row_to_json(&row).map_err(|e| {
                OrmError::Database(format!("Failed to convert row to JSON: {}", e))
            })?;
            results.push(json_row);
        }

        Ok(results)
    }

    /// Load relationships in batches with deduplication
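    ///
    /// Returns one entry per requested parent ID (parents without related rows map to
    /// an empty `Vec`). When `BatchConfig::deduplicate_queries` is enabled, results are
    /// cached under a `parent_type:relationship_name:parent_ids` key and reused.
    ///
    /// A call-site sketch (illustrative only), assuming a connected `pool` and a
    /// `loader` built as above:
    ///
    /// ```ignore
    /// let posts_by_user = loader
    ///     .load_relationships("User", vec![json!(1), json!(2)], "posts", "user_id", "posts", &pool)
    ///     .await?;
    /// assert!(posts_by_user.contains_key(&json!(1)));
    /// ```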
    pub async fn load_relationships(
        &self,
        parent_type: &str,
        parent_ids: Vec<Value>,
        relationship_name: &str,
        foreign_key: &str,
        related_table: &str,
        connection: &sqlx::PgPool,
    ) -> OrmResult<HashMap<Value, Vec<JsonValue>>> {
        if parent_ids.is_empty() {
            return Ok(HashMap::new());
        }

        // Check cache if deduplication is enabled
        let cache_key = format!("{}:{}:{:?}", parent_type, relationship_name, parent_ids);

        if self.config.deduplicate_queries {
            let cache = self.query_cache.read().await;
            if let Some(cached_results) = cache.get(&cache_key) {
                return self.group_by_parent_id(cached_results.clone(), foreign_key, &parent_ids);
            }
        }

        // Use ANY() for efficient IN clause
        let parent_id_values: Vec<String> = parent_ids
            .iter()
            .enumerate()
            .map(|(i, _)| format!("${}", i + 1))
            .collect();

        // Execute batch query for relationships
        let query = QueryBuilder::<()>::new()
            .from(related_table)
            .where_raw(&format!(
                "{} = ANY(ARRAY[{}])",
                foreign_key,
                parent_id_values.join(", ")
            ));

        let (sql, _params) = query.to_sql_with_params();
        let mut db_query = sqlx::query(&sql);

        // Bind parent IDs
        for parent_id in &parent_ids {
            db_query = match parent_id {
                Value::Null => db_query.bind(None::<i32>),
                Value::Bool(b) => db_query.bind(b),
                Value::Number(n) => {
                    if let Some(i) = n.as_i64() {
                        db_query.bind(i)
                    } else if let Some(f) = n.as_f64() {
                        db_query.bind(f)
                    } else {
                        return Err(OrmError::Query("Invalid number type".into()));
                    }
                }
                Value::String(s) => db_query.bind(s.as_str()),
                _ => return Err(OrmError::Query("Unsupported ID type".into())),
            };
        }

        let rows = db_query.fetch_all(connection).await.map_err(|e| {
            OrmError::Database(format!("Relationship batch query failed: {}", e))
        })?;

        // Convert to JSON values
        let mut results = Vec::new();
        for row in rows {
            let json_row = self.row_to_json(&row).map_err(|e| {
                OrmError::Database(format!("Failed to convert row to JSON: {}", e))
            })?;
            results.push(json_row);
        }

        // Cache results if deduplication is enabled
        if self.config.deduplicate_queries {
            let mut cache = self.query_cache.write().await;
            cache.insert(cache_key, results.clone());
        }

        // Group results by parent ID
        self.group_by_parent_id(results, foreign_key, &parent_ids)
    }

    /// Load nested relationships with deep optimization
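    ///
    /// The `relationship_path` is a dot-separated chain such as `"posts.comments.user"`.
    /// Each segment is resolved as one batched pass, and the `id` values produced by
    /// one level feed the next; table and foreign-key names per segment come from
    /// `get_relationship_mapping`.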
    pub async fn load_nested_relationships(
        &self,
        root_table: &str,
        root_ids: Vec<Value>,
        relationship_path: &str,
        connection: &sqlx::PgPool,
    ) -> OrmResult<HashMap<Value, JsonValue>> {
        if root_ids.is_empty() || relationship_path.is_empty() {
            return Ok(HashMap::new());
        }

        // Parse relationship path (e.g., "posts.comments.user")
        let relations: Vec<&str> = relationship_path.split('.').collect();
        let mut current_ids = root_ids.clone();
        let mut results: HashMap<Value, JsonValue> = HashMap::new();

        // Process each level of nesting
        for (depth, relation) in relations.iter().enumerate() {
            if current_ids.is_empty() {
                break;
            }

            // Determine table and foreign key based on relationship type
            let (related_table, foreign_key) = self.get_relationship_mapping(relation)?;

            // Load current level relationships in optimized batches
            let level_results = self.load_relationships_optimized(
                &format!("level_{}", depth),
                current_ids,
                relation,
                &foreign_key,
                &related_table,
                connection,
            ).await?;

            // Update current IDs for next level
            current_ids = level_results
                .values()
                .flatten()
                .filter_map(|record| record.get("id").cloned())
                .collect();

            // Merge results with proper nesting
            self.merge_nested_results(&mut results, level_results, depth == 0);
        }

        Ok(results)
    }

    /// Load relationships with advanced optimization strategies
    async fn load_relationships_optimized(
        &self,
        parent_type: &str,
        parent_ids: Vec<Value>,
        relationship_name: &str,
        foreign_key: &str,
        related_table: &str,
        connection: &sqlx::PgPool,
    ) -> OrmResult<HashMap<Value, Vec<JsonValue>>> {
        // Use smaller batch sizes for nested queries to avoid memory issues
        let optimal_batch_size = std::cmp::min(self.config.max_batch_size, 50);
        let mut all_results: HashMap<Value, Vec<JsonValue>> = HashMap::new();

        // Process in optimized chunks
        for chunk in parent_ids.chunks(optimal_batch_size) {
            let chunk_results = self.load_relationships(
                parent_type,
                chunk.to_vec(),
                relationship_name,
                foreign_key,
                related_table,
                connection,
            ).await?;

            // Merge chunk results
            for (parent_id, relations) in chunk_results {
                all_results.entry(parent_id).or_insert_with(Vec::new).extend(relations);
            }
        }

        Ok(all_results)
    }

    /// Get relationship mapping for a relation name
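    ///
    /// Known relation names are mapped explicitly; anything else falls back to the
    /// naive pluralisation convention, e.g. `"tag"` -> `("tags", "tag_id")`.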
    fn get_relationship_mapping(&self, relation: &str) -> OrmResult<(String, String)> {
        // This would normally use relationship metadata
        // For now, we'll use convention-based mapping
        match relation {
            "posts" => Ok(("posts".to_string(), "user_id".to_string())),
            "comments" => Ok(("comments".to_string(), "post_id".to_string())),
            "user" => Ok(("users".to_string(), "user_id".to_string())),
            "profile" => Ok(("profiles".to_string(), "user_id".to_string())),
            _ => Ok((format!("{}s", relation), format!("{}_id", relation))),
        }
    }

    /// Merge nested results with proper hierarchical structure
    fn merge_nested_results(
        &self,
        target: &mut HashMap<Value, JsonValue>,
        source: HashMap<Value, Vec<JsonValue>>,
        is_root: bool,
    ) {
        for (parent_id, relations) in source {
            if is_root {
                // For root level, create the initial structure
                let parent_id_copy = parent_id.clone();
                target.insert(parent_id, serde_json::json!({
                    "id": parent_id_copy,
                    "relations": relations
                }));
            } else {
                // For nested levels, update existing structure
                if let Some(existing) = target.get_mut(&parent_id) {
                    if let Some(obj) = existing.as_object_mut() {
                        obj.insert("nested_relations".to_string(), serde_json::json!(relations));
                    }
                }
            }
        }
    }

    /// Group results by parent ID
    fn group_by_parent_id(
        &self,
        results: Vec<JsonValue>,
        foreign_key: &str,
        parent_ids: &[Value],
    ) -> OrmResult<HashMap<Value, Vec<JsonValue>>> {
        let mut grouped: HashMap<Value, Vec<JsonValue>> = HashMap::new();

        // Initialize with empty vecs for all parent IDs
        for parent_id in parent_ids {
            grouped.insert(parent_id.clone(), Vec::new());
        }

        // Group results by foreign key value; the FK value is itself a JSON value
        // and is used directly as the grouping key.
        for result in results {
            if let Some(fk_value) = result.get(foreign_key) {
                let parent_id = fk_value.clone();

                grouped
                    .entry(parent_id)
                    .or_insert_with(Vec::new)
                    .push(result);
            }
        }

        Ok(grouped)
    }

    /// Clear the query cache
    pub async fn clear_cache(&self) {
        let mut cache = self.query_cache.write().await;
        cache.clear();
    }

    /// Get cache statistics
    pub async fn cache_stats(&self) -> CacheStats {
        let cache = self.query_cache.read().await;
        CacheStats {
            cached_queries: cache.len(),
            total_cached_records: cache.values().map(|v| v.len()).sum(),
        }
    }
}

#[cfg(test)]
mod tests;