1use std::collections::HashMap;
4
5use crate::filter::FilterValue;
6use crate::traits::QueryEngine;
7
8use super::include::IncludeSpec;
9use super::spec::RelationSpec;
10
/// Strategy selector for how related records are loaded.
///
/// NOTE(review): the per-variant descriptions below are inferred from the variant names; the
/// loader in this module only stores the selected value — confirm against its consumers.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RelationLoadStrategy {
    /// Default variant; presumably loads each relation with its own batched query.
    #[default]
    Separate,
    /// Presumably loads relations via SQL joins on the parent query.
    Join,
    /// Presumably defers loading until the relation is accessed.
    Lazy,
}

impl RelationLoadStrategy {
    /// Returns `true` only for [`RelationLoadStrategy::Separate`].
    pub fn is_separate(&self) -> bool {
        *self == RelationLoadStrategy::Separate
    }

    /// Returns `true` only for [`RelationLoadStrategy::Join`].
    pub fn is_join(&self) -> bool {
        *self == RelationLoadStrategy::Join
    }

    /// Returns `true` only for [`RelationLoadStrategy::Lazy`].
    pub fn is_lazy(&self) -> bool {
        *self == RelationLoadStrategy::Lazy
    }
}
39
40pub struct RelationLoader<E: QueryEngine> {
42 engine: E,
43 strategy: RelationLoadStrategy,
44 batch_size: usize,
45}
46
47impl<E: QueryEngine> RelationLoader<E> {
48 pub fn new(engine: E) -> Self {
50 Self {
51 engine,
52 strategy: RelationLoadStrategy::Separate,
53 batch_size: 100,
54 }
55 }
56
57 pub fn with_strategy(mut self, strategy: RelationLoadStrategy) -> Self {
59 self.strategy = strategy;
60 self
61 }
62
63 pub fn with_batch_size(mut self, size: usize) -> Self {
65 self.batch_size = size;
66 self
67 }
68
69 pub fn engine(&self) -> &E {
71 &self.engine
72 }
73
74 pub fn build_one_to_many_query(
82 &self,
83 spec: &RelationSpec,
84 include: &IncludeSpec,
85 parent_ids: &[FilterValue],
86 ) -> (String, Vec<FilterValue>) {
87 let mut sql = format!(
88 "SELECT * FROM {} WHERE {} IN (",
89 spec.related_table,
90 spec.references.first().unwrap_or(&"id".to_string())
91 );
92
93 let placeholders: Vec<_> = (1..=parent_ids.len()).map(|i| format!("${}", i)).collect();
94 sql.push_str(&placeholders.join(", "));
95 sql.push(')');
96
97 if let Some(ref filter) = include.filter {
99 let (filter_sql, filter_params) =
100 filter.to_sql(parent_ids.len(), &crate::dialect::Postgres);
101 sql.push_str(" AND ");
102 sql.push_str(&filter_sql);
103
104 let mut params = parent_ids.to_vec();
105 params.extend(filter_params);
106 return (sql, params);
107 }
108
109 if let Some(ref order) = include.order_by {
111 sql.push_str(" ORDER BY ");
112 sql.push_str(&order.to_sql());
113 }
114
115 if let Some(ref pagination) = include.pagination {
117 let pagination_sql = pagination.to_sql();
118 if !pagination_sql.is_empty() {
119 sql.push(' ');
120 sql.push_str(&pagination_sql);
121 }
122 }
123
124 (sql, parent_ids.to_vec())
125 }
126
127 pub fn build_many_to_one_query(
129 &self,
130 spec: &RelationSpec,
131 child_foreign_keys: &[FilterValue],
132 ) -> (String, Vec<FilterValue>) {
133 let default_pk = "id".to_string();
134 let pk = spec.references.first().unwrap_or(&default_pk);
135
136 let mut sql = format!("SELECT * FROM {} WHERE {} IN (", spec.related_table, pk);
137
138 let placeholders: Vec<_> = (1..=child_foreign_keys.len())
139 .map(|i| format!("${}", i))
140 .collect();
141 sql.push_str(&placeholders.join(", "));
142 sql.push(')');
143
144 (sql, child_foreign_keys.to_vec())
145 }
146
147 pub fn build_many_to_many_query(
149 &self,
150 spec: &RelationSpec,
151 include: &IncludeSpec,
152 parent_ids: &[FilterValue],
153 ) -> (String, Vec<FilterValue>) {
154 let jt = spec
155 .join_table
156 .as_ref()
157 .expect("many-to-many requires join table");
158
159 let mut sql = format!(
160 "SELECT t.*, jt.{} as _parent_id FROM {} t \
161 INNER JOIN {} jt ON t.{} = jt.{} \
162 WHERE jt.{} IN (",
163 jt.source_column,
164 spec.related_table,
165 jt.table_name,
166 spec.references.first().unwrap_or(&"id".to_string()),
167 jt.target_column,
168 jt.source_column
169 );
170
171 let placeholders: Vec<_> = (1..=parent_ids.len()).map(|i| format!("${}", i)).collect();
172 sql.push_str(&placeholders.join(", "));
173 sql.push(')');
174
175 if let Some(ref order) = include.order_by {
177 sql.push_str(" ORDER BY ");
178 sql.push_str(&order.to_sql());
179 }
180
181 (sql, parent_ids.to_vec())
182 }
183}
184
185impl<E: QueryEngine + Clone> Clone for RelationLoader<E> {
186 fn clone(&self) -> Self {
187 Self {
188 engine: self.engine.clone(),
189 strategy: self.strategy,
190 batch_size: self.batch_size,
191 }
192 }
193}
194
/// Loaded related records grouped under `String` keys, one `Vec` of records per key.
///
/// NOTE(review): keys are presumably stringified parent IDs — confirm against the code that
/// builds these maps.
pub type RelationLoadResult<T> = HashMap<String, Vec<T>>;
197
198#[derive(Debug)]
200pub struct BatchLoadContext {
201 pub parent_ids: Vec<FilterValue>,
203 pub group_by_field: String,
205}
206
207impl BatchLoadContext {
208 pub fn new(parent_ids: Vec<FilterValue>, group_by_field: impl Into<String>) -> Self {
210 Self {
211 parent_ids,
212 group_by_field: group_by_field.into(),
213 }
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{QueryError, QueryResult};
    use crate::traits::{BoxFuture, Model};

    /// Minimal `Model` implementation (not referenced by the tests below).
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name"];
    }

    /// Query engine stub whose futures resolve immediately without a database.
    ///
    /// The stub's behavior is:
    /// - Multi-row reads and updates yield empty vectors.
    /// - `query_optional` yields `None`.
    /// - `query_one` and `execute_insert` fail with a not-found error.
    /// - Delete, raw, and count operations yield `0`.
    #[derive(Clone)]
    struct MockEngine;

    impl QueryEngine for MockEngine {
        fn dialect(&self) -> &dyn crate::dialect::SqlDialect {
            &crate::dialect::Postgres
        }

        fn query_many<T: Model + Send + 'static>(
            &self,
            _: &str,
            _: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _: &str,
            _: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _: &str,
            _: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _: &str,
            _: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _: &str,
            _: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(&self, _: &str, _: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(&self, _: &str, _: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn count(&self, _: &str, _: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    /// Loader over the stub engine with default settings.
    fn mock_loader() -> RelationLoader<MockEngine> {
        RelationLoader::new(MockEngine)
    }

    #[test]
    fn test_relation_load_strategy() {
        let (separate, join, lazy) = (
            RelationLoadStrategy::Separate,
            RelationLoadStrategy::Join,
            RelationLoadStrategy::Lazy,
        );
        assert!(separate.is_separate());
        assert!(join.is_join());
        assert!(lazy.is_lazy());
    }

    #[test]
    fn test_one_to_many_query() {
        let spec = RelationSpec::one_to_many("posts", "Post", "posts").references(["author_id"]);
        let parent_ids = [FilterValue::Int(1), FilterValue::Int(2)];

        let (sql, params) =
            mock_loader().build_one_to_many_query(&spec, &IncludeSpec::new("posts"), &parent_ids);

        for fragment in ["SELECT * FROM posts", "WHERE author_id IN"] {
            assert!(sql.contains(fragment), "missing `{}` in: {}", fragment, sql);
        }
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_many_to_one_query() {
        let spec = RelationSpec::many_to_one("author", "User", "users").references(["id"]);
        let foreign_keys = [FilterValue::Int(1), FilterValue::Int(2)];

        let (sql, params) = mock_loader().build_many_to_one_query(&spec, &foreign_keys);

        for fragment in ["SELECT * FROM users", "WHERE id IN"] {
            assert!(sql.contains(fragment), "missing `{}` in: {}", fragment, sql);
        }
        assert_eq!(params.len(), 2);
    }
}