1use sqlmodel_core::{Model, RelationshipInfo, RelationshipKind, Value};
8use std::marker::PhantomData;
9
10#[derive(Debug, Clone)]
21pub struct EagerLoader<T: Model> {
22 includes: Vec<IncludePath>,
24 _marker: PhantomData<T>,
26}
27
/// One relationship to eager-load, plus any relationships to load beneath it.
///
/// A dotted path such as `"team.headquarters"` is represented as a `team` node
/// whose `nested` list holds a single `headquarters` node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncludePath {
    /// Name of the relationship to load at this level.
    pub relationship: &'static str,
    /// Child relationships to load beneath `relationship`.
    pub nested: Vec<IncludePath>,
}

impl IncludePath {
    /// Creates a leaf include for `relationship` with no nested children.
    #[must_use]
    pub fn new(relationship: &'static str) -> Self {
        Self {
            relationship,
            nested: Vec::new(),
        }
    }

    /// Appends `path` as a child of this include and returns `self` for chaining.
    #[must_use]
    pub fn nest(mut self, path: IncludePath) -> Self {
        self.nested.push(path);
        self
    }
}
54
55impl<T: Model> EagerLoader<T> {
56 #[must_use]
58 pub fn new() -> Self {
59 Self {
60 includes: Vec::new(),
61 _marker: PhantomData,
62 }
63 }
64
65 #[must_use]
73 pub fn include(mut self, relationship: &'static str) -> Self {
74 self.includes.push(IncludePath::new(relationship));
75 self
76 }
77
78 #[must_use]
86 pub fn include_nested(mut self, path: &'static str) -> Self {
87 let path = path.trim();
89 if path.is_empty() {
90 return self;
91 }
92
93 let parts: Vec<&'static str> = path.split('.').collect();
94 if parts.iter().all(|p| p.is_empty()) {
97 return self;
98 }
99
100 let parts: Vec<&'static str> = parts.into_iter().filter(|p| !p.is_empty()).collect();
102 if parts.is_empty() {
103 return self;
104 }
105
106 let include = Self::build_nested_path(&parts);
108 self.includes.push(include);
109 self
110 }
111
112 fn build_nested_path(parts: &[&'static str]) -> IncludePath {
114 if parts.len() == 1 {
115 IncludePath::new(parts[0])
116 } else {
117 let mut path = IncludePath::new(parts[0]);
118 path.nested.push(Self::build_nested_path(&parts[1..]));
119 path
120 }
121 }
122
123 #[must_use]
125 pub fn includes(&self) -> &[IncludePath] {
126 &self.includes
127 }
128
129 #[must_use]
131 pub fn has_includes(&self) -> bool {
132 !self.includes.is_empty()
133 }
134}
135
136impl<T: Model> Default for EagerLoader<T> {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142#[must_use]
144pub fn find_relationship<M: Model>(name: &str) -> Option<&'static RelationshipInfo> {
145 M::RELATIONSHIPS.iter().find(|r| r.name == name)
146}
147
148#[must_use]
150pub fn build_join_clause(
151 parent_table: &str,
152 rel: &RelationshipInfo,
153 _param_offset: usize,
154) -> (String, Vec<Value>) {
155 let params = Vec::new();
156
157 let remote_pk = rel.remote_key.unwrap_or("id");
159
160 let sql = match rel.kind {
161 RelationshipKind::ManyToOne | RelationshipKind::OneToOne => {
162 let local_key = rel.local_key.unwrap_or("id");
164 format!(
165 " LEFT JOIN {} ON {}.{} = {}.{}",
166 rel.related_table, parent_table, local_key, rel.related_table, remote_pk
167 )
168 }
169 RelationshipKind::OneToMany => {
170 let fk_on_related = rel.remote_key.unwrap_or("id");
173 let local_pk = rel.local_key.unwrap_or("id");
175 format!(
176 " LEFT JOIN {} ON {}.{} = {}.{}",
177 rel.related_table, rel.related_table, fk_on_related, parent_table, local_pk
178 )
179 }
180 RelationshipKind::ManyToMany => {
181 if let Some(link) = &rel.link_table {
184 let local_pk = rel.local_key.unwrap_or("id");
185 format!(
186 " LEFT JOIN {} ON {}.{} = {}.{} LEFT JOIN {} ON {}.{} = {}.{}",
187 link.table_name,
188 parent_table,
189 local_pk,
190 link.table_name,
191 link.local_column,
192 rel.related_table,
193 link.table_name,
194 link.remote_column,
195 rel.related_table,
196 remote_pk
197 )
198 } else {
199 String::new()
200 }
201 }
202 };
203
204 (sql, params)
205}
206
/// Renders `columns` as a comma-separated select list qualified by `table_name`.
///
/// Each column becomes `table.col AS table__col`, so columns from different
/// tables can be told apart in a joined result row. An empty `columns` slice
/// yields an empty string.
#[must_use]
pub fn build_aliased_columns(table_name: &str, columns: &[&str]) -> String {
    let mut select_list = String::new();
    for (position, column) in columns.iter().enumerate() {
        if position > 0 {
            select_list.push_str(", ");
        }
        select_list.push_str(&format!("{}.{} AS {}__{}", table_name, column, table_name, column));
    }
    select_list
}
218
#[cfg(test)]
mod tests {
    use super::*;
    use sqlmodel_core::{Error, FieldInfo, Model, Result, Row, Value};

    /// Minimal model declaring one `ManyToOne` relationship: `team` -> `teams` via `team_id`.
    #[derive(Debug, Clone)]
    struct TestHero;

    impl Model for TestHero {
        const TABLE_NAME: &'static str = "heroes";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const RELATIONSHIPS: &'static [RelationshipInfo] = &[
            RelationshipInfo::new("team", "teams", RelationshipKind::ManyToOne)
                .local_key("team_id"),
        ];

        fn fields() -> &'static [FieldInfo] {
            &[]
        }

        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![]
        }

        fn from_row(_row: &Row) -> Result<Self> {
            Err(Error::Custom("not used".to_string()))
        }

        fn primary_key_value(&self) -> Vec<Value> {
            vec![]
        }

        fn is_new(&self) -> bool {
            true
        }
    }

    #[test]
    fn test_eager_loader_new() {
        let loader = EagerLoader::<TestHero>::new();
        assert!(!loader.has_includes());
        assert!(loader.includes().is_empty());
    }

    #[test]
    fn test_eager_loader_include() {
        let loader = EagerLoader::<TestHero>::new().include("team");
        assert!(loader.has_includes());
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
    }

    #[test]
    fn test_eager_loader_multiple_includes() {
        let loader = EagerLoader::<TestHero>::new()
            .include("team")
            .include("powers");
        assert_eq!(loader.includes().len(), 2);
    }

    #[test]
    fn test_eager_loader_include_nested() {
        let loader = EagerLoader::<TestHero>::new().include_nested("team.headquarters");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[0].nested.len(), 1);
        assert_eq!(loader.includes()[0].nested[0].relationship, "headquarters");
    }

    #[test]
    fn test_eager_loader_include_deeply_nested() {
        let loader =
            EagerLoader::<TestHero>::new().include_nested("team.headquarters.city.country");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[0].nested[0].relationship, "headquarters");
        assert_eq!(
            loader.includes()[0].nested[0].nested[0].relationship,
            "city"
        );
        assert_eq!(
            loader.includes()[0].nested[0].nested[0].nested[0].relationship,
            "country"
        );
    }

    #[test]
    fn test_eager_loader_include_nested_blank_is_ignored() {
        let loader = EagerLoader::<TestHero>::new()
            .include_nested("")
            .include_nested("   ")
            .include_nested("..");
        assert!(!loader.has_includes());
    }

    #[test]
    fn test_eager_loader_include_nested_skips_empty_segments() {
        let loader = EagerLoader::<TestHero>::new().include_nested("team..headquarters.");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[0].nested.len(), 1);
        assert_eq!(loader.includes()[0].nested[0].relationship, "headquarters");
        assert!(loader.includes()[0].nested[0].nested.is_empty());
    }

    #[test]
    fn test_find_relationship() {
        let rel = find_relationship::<TestHero>("team").expect("TestHero declares `team`");
        assert_eq!(rel.name, "team");
        assert_eq!(rel.related_table, "teams");
    }

    #[test]
    fn test_find_relationship_not_found() {
        assert!(find_relationship::<TestHero>("nonexistent").is_none());
        // Lookup is an exact, case-sensitive match.
        assert!(find_relationship::<TestHero>("Team").is_none());
    }

    #[test]
    fn test_build_join_many_to_one() {
        let rel = RelationshipInfo::new("team", "teams", RelationshipKind::ManyToOne)
            .local_key("team_id");

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert_eq!(sql, " LEFT JOIN teams ON heroes.team_id = teams.id");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_one_to_one() {
        let rel = RelationshipInfo::new("profile", "profiles", RelationshipKind::OneToOne)
            .local_key("profile_id");

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert_eq!(sql, " LEFT JOIN profiles ON heroes.profile_id = profiles.id");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_one_to_many() {
        let rel = RelationshipInfo::new("heroes", "heroes", RelationshipKind::OneToMany)
            .remote_key("team_id");

        let (sql, params) = build_join_clause("teams", &rel, 0);

        assert_eq!(sql, " LEFT JOIN heroes ON heroes.team_id = teams.id");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_many_to_many() {
        let rel =
            RelationshipInfo::new("powers", "powers", RelationshipKind::ManyToMany).link_table(
                sqlmodel_core::LinkTableInfo::new("hero_powers", "hero_id", "power_id"),
            );

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert!(sql.contains("LEFT JOIN hero_powers"));
        assert!(sql.contains("LEFT JOIN powers"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_many_to_many_without_link_table_is_empty() {
        // Without a link table there is nothing to join through, so no SQL is emitted.
        let rel = RelationshipInfo::new("powers", "powers", RelationshipKind::ManyToMany);

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert!(sql.is_empty());
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_aliased_columns() {
        let result = build_aliased_columns("heroes", &["id", "name", "team_id"]);
        assert!(result.contains("heroes.id AS heroes__id"));
        assert!(result.contains("heroes.name AS heroes__name"));
        assert!(result.contains("heroes.team_id AS heroes__team_id"));
    }

    #[test]
    fn test_build_aliased_columns_exact_and_empty() {
        assert_eq!(
            build_aliased_columns("heroes", &["id", "name"]),
            "heroes.id AS heroes__id, heroes.name AS heroes__name"
        );
        assert_eq!(build_aliased_columns("heroes", &[]), "");
    }

    #[test]
    fn test_eager_loader_default() {
        let loader: EagerLoader<TestHero> = EagerLoader::default();
        assert!(!loader.has_includes());
    }

    #[test]
    fn test_include_path_new() {
        let path = IncludePath::new("team");
        assert_eq!(path.relationship, "team");
        assert!(path.nested.is_empty());
    }

    #[test]
    fn test_include_path_nest() {
        let path = IncludePath::new("team").nest(IncludePath::new("headquarters"));
        assert_eq!(path.nested.len(), 1);
        assert_eq!(path.nested[0].relationship, "headquarters");
    }
}