// prax_query/relations/include.rs

1//! Include specifications for eager loading relations.
2
3use std::collections::HashMap;
4
5use crate::filter::Filter;
6use crate::pagination::Pagination;
7use crate::types::OrderBy;
8
9/// Specification for including a relation in a query.
10#[derive(Debug, Clone)]
11pub struct IncludeSpec {
12    /// Name of the relation to include.
13    pub relation_name: String,
14    /// Filter to apply to the related records.
15    pub filter: Option<Filter>,
16    /// Ordering for the related records.
17    pub order_by: Option<OrderBy>,
18    /// Pagination for the related records.
19    pub pagination: Option<Pagination>,
20    /// Nested includes.
21    pub nested: HashMap<String, IncludeSpec>,
22    /// Whether to include the count of related records.
23    pub include_count: bool,
24}
25
26impl IncludeSpec {
27    /// Create a new include spec for a relation.
28    pub fn new(relation_name: impl Into<String>) -> Self {
29        Self {
30            relation_name: relation_name.into(),
31            filter: None,
32            order_by: None,
33            pagination: None,
34            nested: HashMap::new(),
35            include_count: false,
36        }
37    }
38
39    /// Add a filter to the included relation.
40    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
41        self.filter = Some(filter.into());
42        self
43    }
44
45    /// Set ordering for the included relation.
46    pub fn order_by(mut self, order: impl Into<OrderBy>) -> Self {
47        self.order_by = Some(order.into());
48        self
49    }
50
51    /// Skip records in the included relation.
52    pub fn skip(mut self, n: u64) -> Self {
53        self.pagination = Some(
54            self.pagination
55                .unwrap_or_default()
56                .skip(n),
57        );
58        self
59    }
60
61    /// Take a limited number of records from the included relation.
62    pub fn take(mut self, n: u64) -> Self {
63        self.pagination = Some(
64            self.pagination
65                .unwrap_or_default()
66                .take(n),
67        );
68        self
69    }
70
71    /// Include a nested relation.
72    pub fn include(mut self, nested: IncludeSpec) -> Self {
73        self.nested.insert(nested.relation_name.clone(), nested);
74        self
75    }
76
77    /// Include the count of related records.
78    pub fn with_count(mut self) -> Self {
79        self.include_count = true;
80        self
81    }
82
83    /// Check if there are nested includes.
84    pub fn has_nested(&self) -> bool {
85        !self.nested.is_empty()
86    }
87
88    /// Get all nested include specs.
89    pub fn nested_specs(&self) -> impl Iterator<Item = &IncludeSpec> {
90        self.nested.values()
91    }
92}
93
94/// Builder for include specifications.
95///
96/// This is typically used by the generated code to provide a fluent API.
97#[derive(Debug, Clone, Default)]
98pub struct Include {
99    specs: HashMap<String, IncludeSpec>,
100}
101
102impl Include {
103    /// Create a new empty include builder.
104    pub fn new() -> Self {
105        Self::default()
106    }
107
108    /// Add a relation to include.
109    pub fn add(mut self, spec: IncludeSpec) -> Self {
110        self.specs.insert(spec.relation_name.clone(), spec);
111        self
112    }
113
114    /// Add multiple relations to include.
115    pub fn add_many(mut self, specs: impl IntoIterator<Item = IncludeSpec>) -> Self {
116        for spec in specs {
117            self.specs.insert(spec.relation_name.clone(), spec);
118        }
119        self
120    }
121
122    /// Get an include spec by relation name.
123    pub fn get(&self, relation: &str) -> Option<&IncludeSpec> {
124        self.specs.get(relation)
125    }
126
127    /// Check if a relation is included.
128    pub fn contains(&self, relation: &str) -> bool {
129        self.specs.contains_key(relation)
130    }
131
132    /// Get all include specs.
133    pub fn specs(&self) -> impl Iterator<Item = &IncludeSpec> {
134        self.specs.values()
135    }
136
137    /// Check if there are any includes.
138    pub fn is_empty(&self) -> bool {
139        self.specs.is_empty()
140    }
141
142    /// Get the number of includes.
143    pub fn len(&self) -> usize {
144        self.specs.len()
145    }
146
147    /// Merge another include into this one.
148    pub fn merge(mut self, other: Include) -> Self {
149        self.specs.extend(other.specs);
150        self
151    }
152}
153
154impl From<IncludeSpec> for Include {
155    fn from(spec: IncludeSpec) -> Self {
156        Self::new().add(spec)
157    }
158}
159
160impl FromIterator<IncludeSpec> for Include {
161    fn from_iter<T: IntoIterator<Item = IncludeSpec>>(iter: T) -> Self {
162        Self::new().add_many(iter)
163    }
164}
165
166/// Helper function to create an include spec.
167pub fn include(relation: impl Into<String>) -> IncludeSpec {
168    IncludeSpec::new(relation)
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::OrderByField;

    #[test]
    fn test_include_spec_basic() {
        let spec = IncludeSpec::new("posts");
        assert_eq!(spec.relation_name, "posts");
        assert!(spec.filter.is_none());
        assert!(spec.order_by.is_none());
    }

    #[test]
    fn test_include_spec_with_options() {
        let spec = IncludeSpec::new("posts")
            .order_by(OrderByField::desc("created_at"))
            .take(5)
            .with_count();

        assert!(spec.order_by.is_some());
        assert!(spec.pagination.is_some());
        assert!(spec.include_count);
    }

    #[test]
    fn test_include_spec_nested() {
        let spec = IncludeSpec::new("posts")
            .include(IncludeSpec::new("comments").take(10));

        assert!(spec.has_nested());
        assert!(spec.nested.contains_key("comments"));
    }

    #[test]
    fn test_include_builder() {
        let includes = Include::new()
            .add(IncludeSpec::new("posts"))
            .add(IncludeSpec::new("profile"));

        assert_eq!(includes.len(), 2);
        assert!(includes.contains("posts"));
        assert!(includes.contains("profile"));
    }

    #[test]
    fn test_include_from_iter() {
        let includes: Include = vec![
            IncludeSpec::new("posts"),
            IncludeSpec::new("comments"),
        ]
        .into_iter()
        .collect();

        assert_eq!(includes.len(), 2);
    }

    #[test]
    fn test_include_empty() {
        let includes = Include::new();

        assert!(includes.is_empty());
        assert_eq!(includes.len(), 0);
        assert!(includes.get("posts").is_none());
        assert_eq!(includes.specs().count(), 0);
    }

    #[test]
    fn test_include_add_replaces_existing_relation() {
        let includes = Include::new()
            .add(IncludeSpec::new("posts"))
            .add(IncludeSpec::new("posts").with_count());

        assert_eq!(includes.len(), 1);
        let posts = includes.get("posts").expect("posts should be included");
        assert!(posts.include_count);
    }

    #[test]
    fn test_include_merge_prefers_other() {
        let base = Include::new()
            .add(IncludeSpec::new("posts"))
            .add(IncludeSpec::new("profile"));
        let other = Include::from(IncludeSpec::new("posts").with_count());

        let merged = base.merge(other);

        assert_eq!(merged.len(), 2);
        assert!(merged.contains("profile"));
        let posts = merged.get("posts").expect("posts should be included");
        assert!(posts.include_count);
    }

    #[test]
    fn test_include_from_single_spec() {
        let includes = Include::from(IncludeSpec::new("author"));

        assert_eq!(includes.len(), 1);
        assert!(includes.contains("author"));
    }

    #[test]
    fn test_include_helper() {
        let spec = include("author");

        assert_eq!(spec.relation_name, "author");
        assert!(!spec.has_nested());
        assert!(!spec.include_count);
        assert!(spec.pagination.is_none());
    }
}
228