1use std::collections::HashMap;
4
5use crate::filter::FilterValue;
6use crate::traits::QueryEngine;
7
8use super::include::IncludeSpec;
9use super::spec::RelationSpec;
10
/// Selects how related records should be fetched.
///
/// `Separate` is the default variant. Nothing in this file branches on the
/// strategy, so the precise runtime meaning of each variant is defined by the
/// code that consumes it.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum RelationLoadStrategy {
    /// Default strategy (presumably: one additional query per relation — confirm
    /// against the consumer).
    #[default]
    Separate,
    /// Presumably: fetch the relation through a SQL join — confirm against the consumer.
    Join,
    /// Presumably: defer fetching until the relation is accessed — confirm against
    /// the consumer.
    Lazy,
}

impl RelationLoadStrategy {
    /// Returns `true` when the strategy is [`RelationLoadStrategy::Separate`].
    pub fn is_separate(&self) -> bool {
        *self == Self::Separate
    }

    /// Returns `true` when the strategy is [`RelationLoadStrategy::Join`].
    pub fn is_join(&self) -> bool {
        *self == Self::Join
    }

    /// Returns `true` when the strategy is [`RelationLoadStrategy::Lazy`].
    pub fn is_lazy(&self) -> bool {
        *self == Self::Lazy
    }
}
39
40pub struct RelationLoader<E: QueryEngine> {
42 engine: E,
43 strategy: RelationLoadStrategy,
44 batch_size: usize,
45}
46
47impl<E: QueryEngine> RelationLoader<E> {
48 pub fn new(engine: E) -> Self {
50 Self {
51 engine,
52 strategy: RelationLoadStrategy::Separate,
53 batch_size: 100,
54 }
55 }
56
57 pub fn with_strategy(mut self, strategy: RelationLoadStrategy) -> Self {
59 self.strategy = strategy;
60 self
61 }
62
63 pub fn with_batch_size(mut self, size: usize) -> Self {
65 self.batch_size = size;
66 self
67 }
68
69 pub fn engine(&self) -> &E {
71 &self.engine
72 }
73
74 pub fn build_one_to_many_query(
76 &self,
77 spec: &RelationSpec,
78 include: &IncludeSpec,
79 parent_ids: &[FilterValue],
80 ) -> (String, Vec<FilterValue>) {
81 let mut sql = format!(
82 "SELECT * FROM {} WHERE {} IN (",
83 spec.related_table,
84 spec.references.first().unwrap_or(&"id".to_string())
85 );
86
87 let placeholders: Vec<_> = (1..=parent_ids.len()).map(|i| format!("${}", i)).collect();
88 sql.push_str(&placeholders.join(", "));
89 sql.push(')');
90
91 if let Some(ref filter) = include.filter {
93 let (filter_sql, filter_params) = filter.to_sql(parent_ids.len());
94 sql.push_str(" AND ");
95 sql.push_str(&filter_sql);
96
97 let mut params = parent_ids.to_vec();
98 params.extend(filter_params);
99 return (sql, params);
100 }
101
102 if let Some(ref order) = include.order_by {
104 sql.push_str(" ORDER BY ");
105 sql.push_str(&order.to_sql());
106 }
107
108 if let Some(ref pagination) = include.pagination {
110 let pagination_sql = pagination.to_sql();
111 if !pagination_sql.is_empty() {
112 sql.push(' ');
113 sql.push_str(&pagination_sql);
114 }
115 }
116
117 (sql, parent_ids.to_vec())
118 }
119
120 pub fn build_many_to_one_query(
122 &self,
123 spec: &RelationSpec,
124 child_foreign_keys: &[FilterValue],
125 ) -> (String, Vec<FilterValue>) {
126 let default_pk = "id".to_string();
127 let pk = spec.references.first().unwrap_or(&default_pk);
128
129 let mut sql = format!("SELECT * FROM {} WHERE {} IN (", spec.related_table, pk);
130
131 let placeholders: Vec<_> = (1..=child_foreign_keys.len())
132 .map(|i| format!("${}", i))
133 .collect();
134 sql.push_str(&placeholders.join(", "));
135 sql.push(')');
136
137 (sql, child_foreign_keys.to_vec())
138 }
139
140 pub fn build_many_to_many_query(
142 &self,
143 spec: &RelationSpec,
144 include: &IncludeSpec,
145 parent_ids: &[FilterValue],
146 ) -> (String, Vec<FilterValue>) {
147 let jt = spec
148 .join_table
149 .as_ref()
150 .expect("many-to-many requires join table");
151
152 let mut sql = format!(
153 "SELECT t.*, jt.{} as _parent_id FROM {} t \
154 INNER JOIN {} jt ON t.{} = jt.{} \
155 WHERE jt.{} IN (",
156 jt.source_column,
157 spec.related_table,
158 jt.table_name,
159 spec.references.first().unwrap_or(&"id".to_string()),
160 jt.target_column,
161 jt.source_column
162 );
163
164 let placeholders: Vec<_> = (1..=parent_ids.len()).map(|i| format!("${}", i)).collect();
165 sql.push_str(&placeholders.join(", "));
166 sql.push(')');
167
168 if let Some(ref order) = include.order_by {
170 sql.push_str(" ORDER BY ");
171 sql.push_str(&order.to_sql());
172 }
173
174 (sql, parent_ids.to_vec())
175 }
176}
177
178impl<E: QueryEngine + Clone> Clone for RelationLoader<E> {
179 fn clone(&self) -> Self {
180 Self {
181 engine: self.engine.clone(),
182 strategy: self.strategy,
183 batch_size: self.batch_size,
184 }
185 }
186}
187
/// Loaded related records of type `T`, grouped into lists under string keys.
///
/// NOTE(review): the keys are presumably the owning parent's identifier
/// rendered as a string — confirm against the code that fills this map.
pub type RelationLoadResult<T> = HashMap<String, Vec<T>>;
190
191#[derive(Debug)]
193pub struct BatchLoadContext {
194 pub parent_ids: Vec<FilterValue>,
196 pub group_by_field: String,
198}
199
200impl BatchLoadContext {
201 pub fn new(parent_ids: Vec<FilterValue>, group_by_field: impl Into<String>) -> Self {
203 Self {
204 parent_ids,
205 group_by_field: group_by_field.into(),
206 }
207 }
208}
209
// Unit tests for the strategy predicates and the SQL produced by the query
// builders. No query is ever executed; the engine below is a stub.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{QueryError, QueryResult};
    use crate::traits::{BoxFuture, Model};

    // Minimal `Model` implementation. None of the tests below reference it.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name"];
    }

    // Stub engine: every method ignores its SQL and parameters and resolves to
    // a fixed value, so a `RelationLoader` can be constructed in the tests.
    #[derive(Clone)]
    struct MockEngine;

    impl QueryEngine for MockEngine {
        // Always resolves to an empty result set.
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        // Always resolves to a not-found error.
        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        // Always resolves to `None`.
        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        // Always resolves to a not-found error.
        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        // Always resolves to an empty list.
        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        // Always resolves to zero.
        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        // Always resolves to zero.
        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        // Always resolves to zero.
        fn count(&self, _sql: &str, _params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    // Each predicate returns true for its own variant.
    #[test]
    fn test_relation_load_strategy() {
        assert!(RelationLoadStrategy::Separate.is_separate());
        assert!(RelationLoadStrategy::Join.is_join());
        assert!(RelationLoadStrategy::Lazy.is_lazy());
    }

    // With no filter, ordering or pagination on the include, the query selects
    // from the related table by the referenced column and binds one parameter
    // per parent id.
    #[test]
    fn test_one_to_many_query() {
        let loader = RelationLoader::new(MockEngine);
        let spec = RelationSpec::one_to_many("posts", "Post", "posts").references(["author_id"]);
        let include = IncludeSpec::new("posts");
        let parent_ids = vec![FilterValue::Int(1), FilterValue::Int(2)];

        let (sql, params) = loader.build_one_to_many_query(&spec, &include, &parent_ids);

        assert!(sql.contains("SELECT * FROM posts"));
        assert!(sql.contains("WHERE author_id IN"));
        assert_eq!(params.len(), 2);
    }

    // The many-to-one query selects from the related table by the referenced
    // column and binds one parameter per foreign key.
    #[test]
    fn test_many_to_one_query() {
        let loader = RelationLoader::new(MockEngine);
        let spec = RelationSpec::many_to_one("author", "User", "users").references(["id"]);
        let foreign_keys = vec![FilterValue::Int(1), FilterValue::Int(2)];

        let (sql, params) = loader.build_many_to_one_query(&spec, &foreign_keys);

        assert!(sql.contains("SELECT * FROM users"));
        assert!(sql.contains("WHERE id IN"));
        assert_eq!(params.len(), 2);
    }
}