Skip to main content

sqlmodel_query/
eager.rs

//! Eager loading infrastructure for relationships.
//!
//! This module provides the `EagerLoader` builder for configuring which
//! relationships to load with a query. Eager loading fetches related
//! objects in the same query using SQL JOINs.
6
7use sqlmodel_core::{Model, RelationshipInfo, RelationshipKind, Value};
8use std::marker::PhantomData;
9
10/// Builder for eager loading configuration.
11///
12/// # Example
13///
14/// ```ignore
15/// let heroes = select!(Hero)
16///     .eager(EagerLoader::new().include("team"))
17///     .all_eager(&conn)
18///     .await?;
19/// ```
20#[derive(Debug, Clone)]
21pub struct EagerLoader<T: Model> {
22    /// Relationships to eager-load.
23    includes: Vec<IncludePath>,
24    /// Model type marker.
25    _marker: PhantomData<T>,
26}
27
/// One relationship to eager-load, plus any relationships to load beneath it.
#[derive(Clone, Debug)]
pub struct IncludePath {
    /// Name of the relationship on the parent.
    pub relationship: &'static str,
    /// Further relationship paths to load below this one.
    pub nested: Vec<Self>,
}
36
37impl IncludePath {
38    /// Create a new include path for a single relationship.
39    #[must_use]
40    pub fn new(relationship: &'static str) -> Self {
41        Self {
42            relationship,
43            nested: Vec::new(),
44        }
45    }
46
47    /// Add a nested relationship to load.
48    #[must_use]
49    pub fn nest(mut self, path: IncludePath) -> Self {
50        self.nested.push(path);
51        self
52    }
53}
54
55impl<T: Model> EagerLoader<T> {
56    /// Create a new empty eager loader.
57    #[must_use]
58    pub fn new() -> Self {
59        Self {
60            includes: Vec::new(),
61            _marker: PhantomData,
62        }
63    }
64
65    /// Include a relationship in eager loading.
66    ///
67    /// # Example
68    ///
69    /// ```ignore
70    /// EagerLoader::<Hero>::new().include("team")
71    /// ```
72    #[must_use]
73    pub fn include(mut self, relationship: &'static str) -> Self {
74        self.includes.push(IncludePath::new(relationship));
75        self
76    }
77
78    /// Include a nested relationship (e.g., "team.headquarters").
79    ///
80    /// # Example
81    ///
82    /// ```ignore
83    /// EagerLoader::<Hero>::new().include_nested("team.headquarters")
84    /// ```
85    #[must_use]
86    pub fn include_nested(mut self, path: &'static str) -> Self {
87        // Handle empty or whitespace-only paths
88        let path = path.trim();
89        if path.is_empty() {
90            return self;
91        }
92
93        let parts: Vec<&'static str> = path.split('.').collect();
94        // split('.') on non-empty string always returns at least one element
95        // but we should still guard against [""] from paths like "."
96        if parts.iter().all(|p| p.is_empty()) {
97            return self;
98        }
99
100        // Filter out empty parts (handles cases like "team..headquarters")
101        let parts: Vec<&'static str> = parts.into_iter().filter(|p| !p.is_empty()).collect();
102        if parts.is_empty() {
103            return self;
104        }
105
106        // Build nested IncludePath structure
107        let include = Self::build_nested_path(&parts);
108        self.includes.push(include);
109        self
110    }
111
112    /// Build a nested IncludePath from path parts.
113    fn build_nested_path(parts: &[&'static str]) -> IncludePath {
114        if parts.len() == 1 {
115            IncludePath::new(parts[0])
116        } else {
117            let mut path = IncludePath::new(parts[0]);
118            path.nested.push(Self::build_nested_path(&parts[1..]));
119            path
120        }
121    }
122
123    /// Get the include paths.
124    #[must_use]
125    pub fn includes(&self) -> &[IncludePath] {
126        &self.includes
127    }
128
129    /// Check if any relationships are included.
130    #[must_use]
131    pub fn has_includes(&self) -> bool {
132        !self.includes.is_empty()
133    }
134}
135
136impl<T: Model> Default for EagerLoader<T> {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142/// Find a relationship by name in a model's RELATIONSHIPS.
143#[must_use]
144pub fn find_relationship<M: Model>(name: &str) -> Option<&'static RelationshipInfo> {
145    M::RELATIONSHIPS.iter().find(|r| r.name == name)
146}
147
148/// Generate a JOIN clause for a relationship.
149#[must_use]
150pub fn build_join_clause(
151    parent_table: &str,
152    rel: &RelationshipInfo,
153    _param_offset: usize,
154) -> (String, Vec<Value>) {
155    let params = Vec::new();
156
157    // Get the primary key column name from the relationship, defaulting to "id"
158    let remote_pk = rel.remote_key.unwrap_or("id");
159
160    let sql = match rel.kind {
161        RelationshipKind::ManyToOne | RelationshipKind::OneToOne => {
162            // LEFT JOIN related_table ON parent.fk = related.pk
163            let local_key = rel.local_key.unwrap_or("id");
164            format!(
165                " LEFT JOIN {} ON {}.{} = {}.{}",
166                rel.related_table, parent_table, local_key, rel.related_table, remote_pk
167            )
168        }
169        RelationshipKind::OneToMany => {
170            // LEFT JOIN related_table ON related.fk = parent.pk
171            // For OneToMany, remote_key is the FK on the related table pointing to us
172            let fk_on_related = rel.remote_key.unwrap_or("id");
173            // And we need local_key as our PK (default "id")
174            let local_pk = rel.local_key.unwrap_or("id");
175            format!(
176                " LEFT JOIN {} ON {}.{} = {}.{}",
177                rel.related_table, rel.related_table, fk_on_related, parent_table, local_pk
178            )
179        }
180        RelationshipKind::ManyToMany => {
181            // LEFT JOIN link_table ON parent.pk = link.local_col
182            // LEFT JOIN related_table ON link.remote_col = related.pk
183            if let Some(link) = &rel.link_table {
184                let local_pk = rel.local_key.unwrap_or("id");
185                format!(
186                    " LEFT JOIN {} ON {}.{} = {}.{} LEFT JOIN {} ON {}.{} = {}.{}",
187                    link.table_name,
188                    parent_table,
189                    local_pk,
190                    link.table_name,
191                    link.local_column,
192                    rel.related_table,
193                    link.table_name,
194                    link.remote_column,
195                    rel.related_table,
196                    remote_pk
197                )
198            } else {
199                String::new()
200            }
201        }
202    };
203
204    (sql, params)
205}
206
/// Build a comma-separated select list with table-qualified aliases.
///
/// Each column `col` is rendered as `table.col AS table__col`, and entries
/// are joined with `", "`. An empty `columns` slice yields an empty string.
#[must_use]
pub fn build_aliased_columns(table_name: &str, columns: &[&str]) -> String {
    let mut select_list = String::new();
    for (position, column) in columns.iter().enumerate() {
        // The separator goes between entries, never before the first.
        if position != 0 {
            select_list.push_str(", ");
        }
        select_list.push_str(&format!(
            "{table}.{col} AS {table}__{col}",
            table = table_name,
            col = column
        ));
    }
    select_list
}
218
#[cfg(test)]
mod tests {
    use super::*;
    use sqlmodel_core::{Error, FieldInfo, Model, Result, Row, Value};

    // Minimal model used only for its static metadata; row conversion is
    // never exercised by these tests.
    #[derive(Debug, Clone)]
    struct TestHero;

    impl Model for TestHero {
        const TABLE_NAME: &'static str = "heroes";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const RELATIONSHIPS: &'static [RelationshipInfo] =
            &[
                RelationshipInfo::new("team", "teams", RelationshipKind::ManyToOne)
                    .local_key("team_id"),
            ];

        fn fields() -> &'static [FieldInfo] {
            &[]
        }

        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![]
        }

        fn from_row(_row: &Row) -> Result<Self> {
            Err(Error::Custom("not used".to_string()))
        }

        fn primary_key_value(&self) -> Vec<Value> {
            vec![]
        }

        fn is_new(&self) -> bool {
            true
        }
    }

    #[test]
    fn test_eager_loader_new() {
        let loader = EagerLoader::<TestHero>::new();
        assert!(!loader.has_includes());
        assert!(loader.includes().is_empty());
    }

    #[test]
    fn test_eager_loader_include() {
        let loader = EagerLoader::<TestHero>::new().include("team");
        assert!(loader.has_includes());
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
    }

    #[test]
    fn test_eager_loader_multiple_includes() {
        let loader = EagerLoader::<TestHero>::new()
            .include("team")
            .include("powers");
        assert_eq!(loader.includes().len(), 2);
        // Includes are kept in the order they were requested.
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[1].relationship, "powers");
    }

    #[test]
    fn test_eager_loader_include_nested() {
        let loader = EagerLoader::<TestHero>::new().include_nested("team.headquarters");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[0].nested.len(), 1);
        assert_eq!(loader.includes()[0].nested[0].relationship, "headquarters");
    }

    #[test]
    fn test_eager_loader_include_deeply_nested() {
        let loader =
            EagerLoader::<TestHero>::new().include_nested("team.headquarters.city.country");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[0].nested[0].relationship, "headquarters");
        assert_eq!(
            loader.includes()[0].nested[0].nested[0].relationship,
            "city"
        );
        assert_eq!(
            loader.includes()[0].nested[0].nested[0].nested[0].relationship,
            "country"
        );
    }

    #[test]
    fn test_eager_loader_include_nested_single_segment() {
        let loader = EagerLoader::<TestHero>::new().include_nested("team");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert!(loader.includes()[0].nested.is_empty());
    }

    #[test]
    fn test_eager_loader_include_nested_ignores_unusable_paths() {
        // Paths with no non-empty segment must not add an include.
        for path in ["", "   ", ".", "..", " . "].iter().copied() {
            let loader = EagerLoader::<TestHero>::new().include_nested(path);
            assert!(!loader.has_includes(), "path {:?} should add nothing", path);
        }
    }

    #[test]
    fn test_eager_loader_include_nested_skips_empty_segments() {
        let loader = EagerLoader::<TestHero>::new().include_nested("team..headquarters");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[0].nested.len(), 1);
        assert_eq!(loader.includes()[0].nested[0].relationship, "headquarters");

        let loader = EagerLoader::<TestHero>::new().include_nested(".team.");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert!(loader.includes()[0].nested.is_empty());
    }

    #[test]
    fn test_find_relationship() {
        let rel = find_relationship::<TestHero>("team");
        assert!(rel.is_some());
        assert_eq!(rel.unwrap().name, "team");
        assert_eq!(rel.unwrap().related_table, "teams");
    }

    #[test]
    fn test_find_relationship_not_found() {
        let rel = find_relationship::<TestHero>("nonexistent");
        assert!(rel.is_none());
    }

    #[test]
    fn test_build_join_many_to_one() {
        let rel = RelationshipInfo::new("team", "teams", RelationshipKind::ManyToOne)
            .local_key("team_id");

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert_eq!(sql, " LEFT JOIN teams ON heroes.team_id = teams.id");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_one_to_one() {
        // OneToOne shares the ManyToOne arm: the foreign key is on the parent.
        let rel = RelationshipInfo::new("profile", "profiles", RelationshipKind::OneToOne)
            .local_key("profile_id");

        let (sql, params) = build_join_clause("users", &rel, 0);

        assert_eq!(sql, " LEFT JOIN profiles ON users.profile_id = profiles.id");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_one_to_many() {
        let rel = RelationshipInfo::new("heroes", "heroes", RelationshipKind::OneToMany)
            .remote_key("team_id");

        let (sql, params) = build_join_clause("teams", &rel, 0);

        assert_eq!(sql, " LEFT JOIN heroes ON heroes.team_id = teams.id");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_many_to_many() {
        let rel =
            RelationshipInfo::new("powers", "powers", RelationshipKind::ManyToMany).link_table(
                sqlmodel_core::LinkTableInfo::new("hero_powers", "hero_id", "power_id"),
            );

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        // Pin the whole fragment so a wrong ON clause or join order fails.
        assert_eq!(
            sql,
            " LEFT JOIN hero_powers ON heroes.id = hero_powers.hero_id LEFT JOIN powers ON hero_powers.power_id = powers.id"
        );
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_many_to_many_without_link_table() {
        // Without link-table metadata there is nothing to join through.
        let rel = RelationshipInfo::new("powers", "powers", RelationshipKind::ManyToMany);

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert!(sql.is_empty());
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_aliased_columns() {
        let result = build_aliased_columns("heroes", &["id", "name", "team_id"]);
        assert_eq!(
            result,
            "heroes.id AS heroes__id, heroes.name AS heroes__name, heroes.team_id AS heroes__team_id"
        );
    }

    #[test]
    fn test_build_aliased_columns_empty() {
        assert_eq!(build_aliased_columns("heroes", &[]), "");
    }

    #[test]
    fn test_eager_loader_default() {
        let loader: EagerLoader<TestHero> = EagerLoader::default();
        assert!(!loader.has_includes());
    }

    #[test]
    fn test_include_path_new() {
        let path = IncludePath::new("team");
        assert_eq!(path.relationship, "team");
        assert!(path.nested.is_empty());
    }

    #[test]
    fn test_include_path_nest() {
        let path = IncludePath::new("team").nest(IncludePath::new("headquarters"));
        assert_eq!(path.nested.len(), 1);
        assert_eq!(path.nested[0].relationship, "headquarters");
    }
}
383}