1use std::collections::HashMap;
4
5use crate::filter::FilterValue;
6use crate::traits::QueryEngine;
7
8use super::include::IncludeSpec;
9use super::spec::RelationSpec;
10
/// Strategy a [`RelationLoader`] can be configured with for loading relations.
///
/// `Separate` is the default, both through `Default` and through
/// `RelationLoader::new`.
///
/// NOTE(review): no code visible in this file branches on the strategy yet;
/// the per-variant descriptions below are inferred from the names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RelationLoadStrategy {
    /// Presumably: fetch related rows with their own follow-up query.
    #[default]
    Separate,
    /// Presumably: fetch related rows through a SQL join with the parent query.
    Join,
    /// Presumably: defer fetching until the relation is first accessed.
    Lazy,
}

impl RelationLoadStrategy {
    /// Returns `true` only for [`RelationLoadStrategy::Separate`].
    pub fn is_separate(&self) -> bool {
        *self == Self::Separate
    }

    /// Returns `true` only for [`RelationLoadStrategy::Join`].
    pub fn is_join(&self) -> bool {
        *self == Self::Join
    }

    /// Returns `true` only for [`RelationLoadStrategy::Lazy`].
    pub fn is_lazy(&self) -> bool {
        *self == Self::Lazy
    }
}
39
40pub struct RelationLoader<E: QueryEngine> {
42 engine: E,
43 strategy: RelationLoadStrategy,
44 batch_size: usize,
45}
46
47impl<E: QueryEngine> RelationLoader<E> {
48 pub fn new(engine: E) -> Self {
50 Self {
51 engine,
52 strategy: RelationLoadStrategy::Separate,
53 batch_size: 100,
54 }
55 }
56
57 pub fn with_strategy(mut self, strategy: RelationLoadStrategy) -> Self {
59 self.strategy = strategy;
60 self
61 }
62
63 pub fn with_batch_size(mut self, size: usize) -> Self {
65 self.batch_size = size;
66 self
67 }
68
69 pub fn engine(&self) -> &E {
71 &self.engine
72 }
73
74 pub fn build_one_to_many_query(
76 &self,
77 spec: &RelationSpec,
78 include: &IncludeSpec,
79 parent_ids: &[FilterValue],
80 ) -> (String, Vec<FilterValue>) {
81 let mut sql = format!(
82 "SELECT * FROM {} WHERE {} IN (",
83 spec.related_table,
84 spec.references.first().unwrap_or(&"id".to_string())
85 );
86
87 let placeholders: Vec<_> = (1..=parent_ids.len())
88 .map(|i| format!("${}", i))
89 .collect();
90 sql.push_str(&placeholders.join(", "));
91 sql.push(')');
92
93 if let Some(ref filter) = include.filter {
95 let (filter_sql, filter_params) = filter.to_sql(parent_ids.len());
96 sql.push_str(" AND ");
97 sql.push_str(&filter_sql);
98
99 let mut params = parent_ids.to_vec();
100 params.extend(filter_params);
101 return (sql, params);
102 }
103
104 if let Some(ref order) = include.order_by {
106 sql.push_str(" ORDER BY ");
107 sql.push_str(&order.to_sql());
108 }
109
110 if let Some(ref pagination) = include.pagination {
112 let pagination_sql = pagination.to_sql();
113 if !pagination_sql.is_empty() {
114 sql.push(' ');
115 sql.push_str(&pagination_sql);
116 }
117 }
118
119 (sql, parent_ids.to_vec())
120 }
121
122 pub fn build_many_to_one_query(
124 &self,
125 spec: &RelationSpec,
126 child_foreign_keys: &[FilterValue],
127 ) -> (String, Vec<FilterValue>) {
128 let default_pk = "id".to_string();
129 let pk = spec.references.first().unwrap_or(&default_pk);
130
131 let mut sql = format!(
132 "SELECT * FROM {} WHERE {} IN (",
133 spec.related_table, pk
134 );
135
136 let placeholders: Vec<_> = (1..=child_foreign_keys.len())
137 .map(|i| format!("${}", i))
138 .collect();
139 sql.push_str(&placeholders.join(", "));
140 sql.push(')');
141
142 (sql, child_foreign_keys.to_vec())
143 }
144
145 pub fn build_many_to_many_query(
147 &self,
148 spec: &RelationSpec,
149 include: &IncludeSpec,
150 parent_ids: &[FilterValue],
151 ) -> (String, Vec<FilterValue>) {
152 let jt = spec.join_table.as_ref().expect("many-to-many requires join table");
153
154 let mut sql = format!(
155 "SELECT t.*, jt.{} as _parent_id FROM {} t \
156 INNER JOIN {} jt ON t.{} = jt.{} \
157 WHERE jt.{} IN (",
158 jt.source_column,
159 spec.related_table,
160 jt.table_name,
161 spec.references.first().unwrap_or(&"id".to_string()),
162 jt.target_column,
163 jt.source_column
164 );
165
166 let placeholders: Vec<_> = (1..=parent_ids.len())
167 .map(|i| format!("${}", i))
168 .collect();
169 sql.push_str(&placeholders.join(", "));
170 sql.push(')');
171
172 if let Some(ref order) = include.order_by {
174 sql.push_str(" ORDER BY ");
175 sql.push_str(&order.to_sql());
176 }
177
178 (sql, parent_ids.to_vec())
179 }
180}
181
182impl<E: QueryEngine + Clone> Clone for RelationLoader<E> {
183 fn clone(&self) -> Self {
184 Self {
185 engine: self.engine.clone(),
186 strategy: self.strategy,
187 batch_size: self.batch_size,
188 }
189 }
190}
191
/// Loaded related records grouped under a string key, one `Vec` per key.
///
/// NOTE(review): presumably keyed by the stringified parent id — confirm
/// against the code that populates it.
pub type RelationLoadResult<T> = HashMap<String, Vec<T>>;
194
195#[derive(Debug)]
197pub struct BatchLoadContext {
198 pub parent_ids: Vec<FilterValue>,
200 pub group_by_field: String,
202}
203
204impl BatchLoadContext {
205 pub fn new(parent_ids: Vec<FilterValue>, group_by_field: impl Into<String>) -> Self {
207 Self {
208 parent_ids,
209 group_by_field: group_by_field.into(),
210 }
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{QueryError, QueryResult};
    use crate::traits::{BoxFuture, Model};

    /// Minimal model definition kept alongside the mock engine.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name"];
    }

    /// Engine stub that ignores its SQL and parameters.
    ///
    /// Multi-row reads return an empty `Vec`, and optional reads return
    /// `None`. Single-row reads and inserts fail with a not-found error, and
    /// count-returning calls report 0.
    #[derive(Clone)]
    struct MockEngine;

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    /// Loader wired to the stub engine with default settings.
    fn mock_loader() -> RelationLoader<MockEngine> {
        RelationLoader::new(MockEngine)
    }

    #[test]
    fn test_relation_load_strategy() {
        let (separate, join, lazy) = (
            RelationLoadStrategy::Separate,
            RelationLoadStrategy::Join,
            RelationLoadStrategy::Lazy,
        );
        assert!(separate.is_separate());
        assert!(join.is_join());
        assert!(lazy.is_lazy());
    }

    #[test]
    fn test_one_to_many_query() {
        let spec = RelationSpec::one_to_many("posts", "Post", "posts").references(["author_id"]);
        let include = IncludeSpec::new("posts");
        let parent_ids = [FilterValue::Int(1), FilterValue::Int(2)];

        let (sql, params) = mock_loader().build_one_to_many_query(&spec, &include, &parent_ids);

        for fragment in ["SELECT * FROM posts", "WHERE author_id IN"] {
            assert!(sql.contains(fragment), "`{}` not found in `{}`", fragment, sql);
        }
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_many_to_one_query() {
        let spec = RelationSpec::many_to_one("author", "User", "users").references(["id"]);
        let foreign_keys = [FilterValue::Int(1), FilterValue::Int(2)];

        let (sql, params) = mock_loader().build_many_to_one_query(&spec, &foreign_keys);

        for fragment in ["SELECT * FROM users", "WHERE id IN"] {
            assert!(sql.contains(fragment), "`{}` not found in `{}`", fragment, sql);
        }
        assert_eq!(params.len(), 2);
    }
}
334