elif_orm/loading/batch_loader/
mod.rs1use crate::{
2 error::{OrmError, OrmResult},
3 model::Model,
4 query::QueryBuilder,
5};
6use serde_json::Value as JsonValue;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12pub mod config;
13pub mod row_conversion;
14
15pub use config::{BatchConfig, CacheStats};
16
17#[derive(Debug)]
19pub struct BatchLoadResult {
20 pub records: HashMap<String, HashMap<Value, JsonValue>>,
22 pub query_count: usize,
24 pub record_count: usize,
26}
27
28#[derive(Clone)]
30pub struct BatchLoader {
31 config: BatchConfig,
32 query_cache: Arc<RwLock<HashMap<String, Vec<JsonValue>>>>,
33}
34
35impl Default for BatchLoader {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl BatchLoader {
42 pub fn new() -> Self {
44 Self::with_config(BatchConfig::default())
45 }
46
47 pub fn with_config(config: BatchConfig) -> Self {
49 Self {
50 config,
51 query_cache: Arc::new(RwLock::new(HashMap::new())),
52 }
53 }
54
55 pub async fn load_batch<M: Model>(
57 &self,
58 ids: Vec<Value>,
59 table: &str,
60 connection: &sqlx::PgPool,
61 ) -> OrmResult<Vec<JsonValue>> {
62 if ids.is_empty() {
63 return Ok(Vec::new());
64 }
65
66 let mut all_results = Vec::new();
67 let chunks: Vec<_> = ids.chunks(self.config.max_batch_size).collect();
68
69 for chunk in chunks {
70 let results = self.execute_batch_query(chunk, table, connection).await?;
71 all_results.extend(results);
72 }
73
74 Ok(all_results)
75 }
76
77 async fn execute_batch_query(
79 &self,
80 ids: &[Value],
81 table: &str,
82 connection: &sqlx::PgPool,
83 ) -> OrmResult<Vec<JsonValue>> {
84 let id_values: Vec<String> = ids
86 .iter()
87 .enumerate()
88 .map(|(i, _)| format!("${}", i + 1))
89 .collect();
90
91 let query = QueryBuilder::<()>::new()
92 .from(table)
93 .where_raw(&format!("id = ANY(ARRAY[{}])", id_values.join(", ")));
94
95 let (sql, _params) = query.to_sql_with_params();
96 let mut db_query = sqlx::query(&sql);
97
98 for id in ids {
100 db_query = match id {
101 Value::Null => db_query.bind(None::<i32>),
102 Value::Bool(b) => db_query.bind(b),
103 Value::Number(n) => {
104 if let Some(i) = n.as_i64() {
105 db_query.bind(i)
106 } else if let Some(f) = n.as_f64() {
107 db_query.bind(f)
108 } else {
109 return Err(OrmError::Query("Invalid number type".into()));
110 }
111 }
112 Value::String(s) => db_query.bind(s.as_str()),
113 _ => return Err(OrmError::Query("Unsupported ID type".into())),
114 };
115 }
116
117 let rows = db_query
118 .fetch_all(connection)
119 .await
120 .map_err(|e| OrmError::Database(format!("Batch query failed: {}", e)))?;
121
122 let mut results = Vec::new();
124 for row in rows {
125 let json_row = self
126 .row_to_json(&row)
127 .map_err(|e| OrmError::Database(format!("Failed to convert row to JSON: {}", e)))?;
128 results.push(json_row);
129 }
130
131 Ok(results)
132 }
133
134 pub async fn load_relationships(
136 &self,
137 parent_type: &str,
138 parent_ids: Vec<Value>,
139 relationship_name: &str,
140 foreign_key: &str,
141 related_table: &str,
142 connection: &sqlx::PgPool,
143 ) -> OrmResult<HashMap<Value, Vec<JsonValue>>> {
144 if parent_ids.is_empty() {
145 return Ok(HashMap::new());
146 }
147
148 let cache_key = format!("{}:{}:{:?}", parent_type, relationship_name, parent_ids);
150
151 if self.config.deduplicate_queries {
152 let cache = self.query_cache.read().await;
153 if let Some(cached_results) = cache.get(&cache_key) {
154 return self.group_by_parent_id(cached_results.clone(), foreign_key, &parent_ids);
155 }
156 }
157
158 let parent_id_values: Vec<String> = parent_ids
160 .iter()
161 .enumerate()
162 .map(|(i, _)| format!("${}", i + 1))
163 .collect();
164
165 let query = QueryBuilder::<()>::new()
167 .from(related_table)
168 .where_raw(&format!(
169 "{} = ANY(ARRAY[{}])",
170 foreign_key,
171 parent_id_values.join(", ")
172 ));
173
174 let (sql, _params) = query.to_sql_with_params();
175 let mut db_query = sqlx::query(&sql);
176
177 for parent_id in &parent_ids {
179 db_query = match parent_id {
180 Value::Null => db_query.bind(None::<i32>),
181 Value::Bool(b) => db_query.bind(b),
182 Value::Number(n) => {
183 if let Some(i) = n.as_i64() {
184 db_query.bind(i)
185 } else if let Some(f) = n.as_f64() {
186 db_query.bind(f)
187 } else {
188 return Err(OrmError::Query("Invalid number type".into()));
189 }
190 }
191 Value::String(s) => db_query.bind(s.as_str()),
192 _ => return Err(OrmError::Query("Unsupported ID type".into())),
193 };
194 }
195
196 let rows = db_query
197 .fetch_all(connection)
198 .await
199 .map_err(|e| OrmError::Database(format!("Relationship batch query failed: {}", e)))?;
200
201 let mut results = Vec::new();
203 for row in rows {
204 let json_row = self
205 .row_to_json(&row)
206 .map_err(|e| OrmError::Database(format!("Failed to convert row to JSON: {}", e)))?;
207 results.push(json_row);
208 }
209
210 if self.config.deduplicate_queries {
212 let mut cache = self.query_cache.write().await;
213 cache.insert(cache_key, results.clone());
214 }
215
216 self.group_by_parent_id(results, foreign_key, &parent_ids)
218 }
219
220 pub async fn load_nested_relationships(
222 &self,
223 _root_table: &str,
224 root_ids: Vec<Value>,
225 relationship_path: &str,
226 connection: &sqlx::PgPool,
227 ) -> OrmResult<HashMap<Value, JsonValue>> {
228 if root_ids.is_empty() || relationship_path.is_empty() {
229 return Ok(HashMap::new());
230 }
231
232 let relations: Vec<&str> = relationship_path.split('.').collect();
234 let mut current_ids = root_ids.clone();
235 let mut results: HashMap<Value, JsonValue> = HashMap::new();
236
237 for (depth, relation) in relations.iter().enumerate() {
239 if current_ids.is_empty() {
240 break;
241 }
242
243 let (related_table, foreign_key) = self.get_relationship_mapping(relation)?;
245
246 let level_results = self
248 .load_relationships_optimized(
249 &format!("level_{}", depth),
250 current_ids,
251 relation,
252 &foreign_key,
253 &related_table,
254 connection,
255 )
256 .await?;
257
258 current_ids = level_results
260 .values()
261 .flatten()
262 .filter_map(|record| record.get("id").cloned())
263 .collect();
264
265 self.merge_nested_results(&mut results, level_results, depth == 0);
267 }
268
269 Ok(results)
270 }
271
272 async fn load_relationships_optimized(
274 &self,
275 parent_type: &str,
276 parent_ids: Vec<Value>,
277 relationship_name: &str,
278 foreign_key: &str,
279 related_table: &str,
280 connection: &sqlx::PgPool,
281 ) -> OrmResult<HashMap<Value, Vec<JsonValue>>> {
282 let optimal_batch_size = std::cmp::min(self.config.max_batch_size, 50);
284 let mut all_results: HashMap<Value, Vec<JsonValue>> = HashMap::new();
285
286 for chunk in parent_ids.chunks(optimal_batch_size) {
288 let chunk_results = self
289 .load_relationships(
290 parent_type,
291 chunk.to_vec(),
292 relationship_name,
293 foreign_key,
294 related_table,
295 connection,
296 )
297 .await?;
298
299 for (parent_id, relations) in chunk_results {
301 all_results.entry(parent_id).or_default().extend(relations);
302 }
303 }
304
305 Ok(all_results)
306 }
307
308 fn get_relationship_mapping(&self, relation: &str) -> OrmResult<(String, String)> {
310 match relation {
313 "posts" => Ok(("posts".to_string(), "user_id".to_string())),
314 "comments" => Ok(("comments".to_string(), "post_id".to_string())),
315 "user" => Ok(("users".to_string(), "user_id".to_string())),
316 "profile" => Ok(("profiles".to_string(), "user_id".to_string())),
317 _ => Ok((format!("{}s", relation), format!("{}_id", relation))),
318 }
319 }
320
321 fn merge_nested_results(
323 &self,
324 target: &mut HashMap<Value, JsonValue>,
325 source: HashMap<Value, Vec<JsonValue>>,
326 is_root: bool,
327 ) {
328 for (parent_id, relations) in source {
329 if is_root {
330 let parent_id_copy = parent_id.clone();
332 target.insert(
333 parent_id,
334 serde_json::json!({
335 "id": parent_id_copy,
336 "relations": relations
337 }),
338 );
339 } else {
340 if let Some(existing) = target.get_mut(&parent_id) {
342 if let Some(obj) = existing.as_object_mut() {
343 obj.insert("nested_relations".to_string(), serde_json::json!(relations));
344 }
345 }
346 }
347 }
348 }
349
350 fn group_by_parent_id(
352 &self,
353 results: Vec<JsonValue>,
354 foreign_key: &str,
355 parent_ids: &[Value],
356 ) -> OrmResult<HashMap<Value, Vec<JsonValue>>> {
357 let mut grouped: HashMap<Value, Vec<JsonValue>> = HashMap::new();
358
359 for parent_id in parent_ids {
361 grouped.insert(parent_id.clone(), Vec::new());
362 }
363
364 for result in results {
366 if let Some(fk_value) = result.get(foreign_key) {
367 let parent_id = serde_json::from_value(fk_value.clone()).unwrap_or(Value::Null);
368
369 grouped.entry(parent_id).or_default().push(result);
370 }
371 }
372
373 Ok(grouped)
374 }
375
376 pub async fn clear_cache(&self) {
378 let mut cache = self.query_cache.write().await;
379 cache.clear();
380 }
381
382 pub async fn cache_stats(&self) -> CacheStats {
384 let cache = self.query_cache.read().await;
385 CacheStats {
386 cached_queries: cache.len(),
387 total_cached_records: cache.values().map(|v| v.len()).sum(),
388 }
389 }
390}
391
392#[cfg(test)]
393mod tests;