elif_orm/loading/batch_loader/mod.rs

use crate::{
    error::{OrmError, OrmResult},
    model::Model,
    query::QueryBuilder,
};
use serde_json::{Value, Value as JsonValue};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

pub mod config;
pub mod row_conversion;

pub use config::{BatchConfig, CacheStats};
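
/// Aggregate outcome of a batch-load pass: records grouped by table name and
/// keyed by ID, plus counts of the queries issued and records returned.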
#[derive(Debug)]
pub struct BatchLoadResult {
    pub records: HashMap<String, HashMap<Value, JsonValue>>,
    pub query_count: usize,
    pub record_count: usize,
}
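
/// Batches ID lookups into `id = ANY(ARRAY[...])` queries and, when
/// `deduplicate_queries` is enabled, caches relationship results so repeated
/// loads of the same parent set do not hit the database again.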
#[derive(Clone)]
pub struct BatchLoader {
    config: BatchConfig,
    query_cache: Arc<RwLock<HashMap<String, Vec<JsonValue>>>>,
}

impl BatchLoader {
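    /// Creates a loader with the default `BatchConfig`.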
    pub fn new() -> Self {
        Self::with_config(BatchConfig::default())
    }
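
    /// Creates a loader with an explicit configuration.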
    pub fn with_config(config: BatchConfig) -> Self {
        Self {
            config,
            query_cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }
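
    /// Loads all rows from `table` whose `id` matches one of `ids`, splitting
    /// the work into chunks of at most `max_batch_size` IDs per query.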
    pub async fn load_batch<M: Model>(
        &self,
        ids: Vec<Value>,
        table: &str,
        connection: &sqlx::PgPool,
    ) -> OrmResult<Vec<JsonValue>> {
        if ids.is_empty() {
            return Ok(Vec::new());
        }

        let mut all_results = Vec::new();

        for chunk in ids.chunks(self.config.max_batch_size) {
            let results = self.execute_batch_query(chunk, table, connection).await?;
            all_results.extend(results);
        }

        Ok(all_results)
    }
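
    /// Executes a single `id = ANY(ARRAY[...])` query for one chunk of IDs and
    /// converts each returned row to JSON.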
    async fn execute_batch_query(
        &self,
        ids: &[Value],
        table: &str,
        connection: &sqlx::PgPool,
    ) -> OrmResult<Vec<JsonValue>> {
        // Build one positional placeholder per ID: $1, $2, ...
        let id_values: Vec<String> = ids
            .iter()
            .enumerate()
            .map(|(i, _)| format!("${}", i + 1))
            .collect();

        let query = QueryBuilder::<()>::new()
            .from(table)
            .where_raw(&format!("id = ANY(ARRAY[{}])", id_values.join(", ")));

        let (sql, _params) = query.to_sql_with_params();
        let mut db_query = sqlx::query(&sql);

        // Bind each ID with a type that matches its JSON representation.
        for id in ids {
            db_query = match id {
                Value::Null => db_query.bind(None::<i32>),
                Value::Bool(b) => db_query.bind(b),
                Value::Number(n) => {
                    if let Some(i) = n.as_i64() {
                        db_query.bind(i)
                    } else if let Some(f) = n.as_f64() {
                        db_query.bind(f)
                    } else {
                        return Err(OrmError::Query("Invalid number type".into()));
                    }
                }
                Value::String(s) => db_query.bind(s.as_str()),
                _ => return Err(OrmError::Query("Unsupported ID type".into())),
            };
        }

        let rows = db_query.fetch_all(connection).await.map_err(|e| {
            OrmError::Database(format!("Batch query failed: {}", e))
        })?;

        let mut results = Vec::new();
        for row in rows {
            let json_row = self.row_to_json(&row).map_err(|e| {
                OrmError::Database(format!("Failed to convert row to JSON: {}", e))
            })?;
            results.push(json_row);
        }

        Ok(results)
    }
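
    /// Loads rows from `related_table` whose `foreign_key` matches any of the
    /// parent IDs and groups them by that key. When `deduplicate_queries` is
    /// enabled, results are cached per (parent type, relationship, ID set).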
    pub async fn load_relationships(
        &self,
        parent_type: &str,
        parent_ids: Vec<Value>,
        relationship_name: &str,
        foreign_key: &str,
        related_table: &str,
        connection: &sqlx::PgPool,
    ) -> OrmResult<HashMap<Value, Vec<JsonValue>>> {
        if parent_ids.is_empty() {
            return Ok(HashMap::new());
        }

        let cache_key = format!("{}:{}:{:?}", parent_type, relationship_name, parent_ids);

        // Serve from the cache when query deduplication is enabled.
        if self.config.deduplicate_queries {
            let cache = self.query_cache.read().await;
            if let Some(cached_results) = cache.get(&cache_key) {
                return self.group_by_parent_id(cached_results.clone(), foreign_key, &parent_ids);
            }
        }

        let parent_id_values: Vec<String> = parent_ids
            .iter()
            .enumerate()
            .map(|(i, _)| format!("${}", i + 1))
            .collect();

        let query = QueryBuilder::<()>::new()
            .from(related_table)
            .where_raw(&format!(
                "{} = ANY(ARRAY[{}])",
                foreign_key,
                parent_id_values.join(", ")
            ));

        let (sql, _params) = query.to_sql_with_params();
        let mut db_query = sqlx::query(&sql);

        for parent_id in &parent_ids {
            db_query = match parent_id {
                Value::Null => db_query.bind(None::<i32>),
                Value::Bool(b) => db_query.bind(b),
                Value::Number(n) => {
                    if let Some(i) = n.as_i64() {
                        db_query.bind(i)
                    } else if let Some(f) = n.as_f64() {
                        db_query.bind(f)
                    } else {
                        return Err(OrmError::Query("Invalid number type".into()));
                    }
                }
                Value::String(s) => db_query.bind(s.as_str()),
                _ => return Err(OrmError::Query("Unsupported ID type".into())),
            };
        }

        let rows = db_query.fetch_all(connection).await.map_err(|e| {
            OrmError::Database(format!("Relationship batch query failed: {}", e))
        })?;

        let mut results = Vec::new();
        for row in rows {
            let json_row = self.row_to_json(&row).map_err(|e| {
                OrmError::Database(format!("Failed to convert row to JSON: {}", e))
            })?;
            results.push(json_row);
        }

        if self.config.deduplicate_queries {
            let mut cache = self.query_cache.write().await;
            cache.insert(cache_key, results.clone());
        }

        self.group_by_parent_id(results, foreign_key, &parent_ids)
    }
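
    /// Walks a dot-separated relationship path (e.g. `"posts.comments"`),
    /// batch-loading each level and feeding the IDs it finds into the next.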
    pub async fn load_nested_relationships(
        &self,
        root_table: &str,
        root_ids: Vec<Value>,
        relationship_path: &str,
        connection: &sqlx::PgPool,
    ) -> OrmResult<HashMap<Value, JsonValue>> {
        if root_ids.is_empty() || relationship_path.is_empty() {
            return Ok(HashMap::new());
        }

        let relations: Vec<&str> = relationship_path.split('.').collect();
        let mut current_ids = root_ids.clone();
        let mut results: HashMap<Value, JsonValue> = HashMap::new();

        for (depth, relation) in relations.iter().enumerate() {
            if current_ids.is_empty() {
                break;
            }

            let (related_table, foreign_key) = self.get_relationship_mapping(relation)?;

            let level_results = self.load_relationships_optimized(
                &format!("level_{}", depth),
                current_ids,
                relation,
                &foreign_key,
                &related_table,
                connection,
            ).await?;

            // The IDs found at this level become the parent IDs for the next one.
            current_ids = level_results
                .values()
                .flatten()
                .filter_map(|record| record.get("id").cloned())
                .collect();

            self.merge_nested_results(&mut results, level_results, depth == 0);
        }

        Ok(results)
    }
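
    /// Like `load_relationships`, but caps the per-query chunk size and merges
    /// the per-chunk results into a single map keyed by parent ID.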
    async fn load_relationships_optimized(
        &self,
        parent_type: &str,
        parent_ids: Vec<Value>,
        relationship_name: &str,
        foreign_key: &str,
        related_table: &str,
        connection: &sqlx::PgPool,
    ) -> OrmResult<HashMap<Value, Vec<JsonValue>>> {
        let optimal_batch_size = std::cmp::min(self.config.max_batch_size, 50);
        let mut all_results: HashMap<Value, Vec<JsonValue>> = HashMap::new();

        for chunk in parent_ids.chunks(optimal_batch_size) {
            let chunk_results = self.load_relationships(
                parent_type,
                chunk.to_vec(),
                relationship_name,
                foreign_key,
                related_table,
                connection,
            ).await?;

            for (parent_id, relations) in chunk_results {
                all_results.entry(parent_id).or_default().extend(relations);
            }
        }

        Ok(all_results)
    }
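
    /// Maps a relationship name to its `(table, foreign_key)` pair, falling
    /// back to naive pluralization (`{relation}s`, `{relation}_id`).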
    fn get_relationship_mapping(&self, relation: &str) -> OrmResult<(String, String)> {
        match relation {
            "posts" => Ok(("posts".to_string(), "user_id".to_string())),
            "comments" => Ok(("comments".to_string(), "post_id".to_string())),
            "user" => Ok(("users".to_string(), "user_id".to_string())),
            "profile" => Ok(("profiles".to_string(), "user_id".to_string())),
            _ => Ok((format!("{}s", relation), format!("{}_id", relation))),
        }
    }
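
    /// Merges one level of relationship results into the accumulated nested
    /// map: root-level parents get a fresh entry, deeper levels are attached
    /// to an existing entry under `"nested_relations"`.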
    fn merge_nested_results(
        &self,
        target: &mut HashMap<Value, JsonValue>,
        source: HashMap<Value, Vec<JsonValue>>,
        is_root: bool,
    ) {
        for (parent_id, relations) in source {
            if is_root {
                let parent_id_copy = parent_id.clone();
                target.insert(parent_id, serde_json::json!({
                    "id": parent_id_copy,
                    "relations": relations
                }));
            } else if let Some(existing) = target.get_mut(&parent_id) {
                if let Some(obj) = existing.as_object_mut() {
                    obj.insert("nested_relations".to_string(), serde_json::json!(relations));
                }
            }
        }
    }
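
    /// Groups flat rows by the value of `foreign_key`, ensuring every requested
    /// parent ID has an entry even when it matched no rows.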
    fn group_by_parent_id(
        &self,
        results: Vec<JsonValue>,
        foreign_key: &str,
        parent_ids: &[Value],
    ) -> OrmResult<HashMap<Value, Vec<JsonValue>>> {
        let mut grouped: HashMap<Value, Vec<JsonValue>> = HashMap::new();

        // Pre-populate so parents with no related rows still appear in the map.
        for parent_id in parent_ids {
            grouped.insert(parent_id.clone(), Vec::new());
        }

        for result in results {
            if let Some(fk_value) = result.get(foreign_key) {
                let parent_id = fk_value.clone();

                grouped
                    .entry(parent_id)
                    .or_default()
                    .push(result);
            }
        }

        Ok(grouped)
    }
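
    /// Clears all cached relationship query results.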
    pub async fn clear_cache(&self) {
        let mut cache = self.query_cache.write().await;
        cache.clear();
    }
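
    /// Reports how many queries are cached and how many records they hold in total.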
    pub async fn cache_stats(&self) -> CacheStats {
        let cache = self.query_cache.read().await;
        CacheStats {
            cached_queries: cache.len(),
            total_cached_records: cache.values().map(|v| v.len()).sum(),
        }
    }
}

#[cfg(test)]
mod tests;