prax_query/relations/include.rs

1//! Include specifications for eager loading relations.
2
3use std::collections::HashMap;
4
5use crate::filter::Filter;
6use crate::pagination::Pagination;
7use crate::types::OrderBy;
8
9/// Specification for including a relation in a query.
10#[derive(Debug, Clone)]
11pub struct IncludeSpec {
12    /// Name of the relation to include.
13    pub relation_name: String,
14    /// Filter to apply to the related records.
15    pub filter: Option<Filter>,
16    /// Ordering for the related records.
17    pub order_by: Option<OrderBy>,
18    /// Pagination for the related records.
19    pub pagination: Option<Pagination>,
20    /// Nested includes.
21    pub nested: HashMap<String, IncludeSpec>,
22    /// Whether to include the count of related records.
23    pub include_count: bool,
24}
25
26impl IncludeSpec {
27    /// Create a new include spec for a relation.
28    pub fn new(relation_name: impl Into<String>) -> Self {
29        Self {
30            relation_name: relation_name.into(),
31            filter: None,
32            order_by: None,
33            pagination: None,
34            nested: HashMap::new(),
35            include_count: false,
36        }
37    }
38
39    /// Add a filter to the included relation.
40    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
41        self.filter = Some(filter.into());
42        self
43    }
44
45    /// Set ordering for the included relation.
46    pub fn order_by(mut self, order: impl Into<OrderBy>) -> Self {
47        self.order_by = Some(order.into());
48        self
49    }
50
51    /// Skip records in the included relation.
52    pub fn skip(mut self, n: u64) -> Self {
53        self.pagination = Some(self.pagination.unwrap_or_default().skip(n));
54        self
55    }
56
57    /// Take a limited number of records from the included relation.
58    pub fn take(mut self, n: u64) -> Self {
59        self.pagination = Some(self.pagination.unwrap_or_default().take(n));
60        self
61    }
62
63    /// Include a nested relation.
64    pub fn include(mut self, nested: IncludeSpec) -> Self {
65        self.nested.insert(nested.relation_name.clone(), nested);
66        self
67    }
68
69    /// Include the count of related records.
70    pub fn with_count(mut self) -> Self {
71        self.include_count = true;
72        self
73    }
74
75    /// Check if there are nested includes.
76    pub fn has_nested(&self) -> bool {
77        !self.nested.is_empty()
78    }
79
80    /// Get all nested include specs.
81    pub fn nested_specs(&self) -> impl Iterator<Item = &IncludeSpec> {
82        self.nested.values()
83    }
84}
85
86/// Builder for include specifications.
87///
88/// This is typically used by the generated code to provide a fluent API.
89#[derive(Debug, Clone, Default)]
90pub struct Include {
91    specs: HashMap<String, IncludeSpec>,
92}
93
94impl Include {
95    /// Create a new empty include builder.
96    pub fn new() -> Self {
97        Self::default()
98    }
99
100    /// Add a relation to include.
101    pub fn with(mut self, spec: IncludeSpec) -> Self {
102        self.specs.insert(spec.relation_name.clone(), spec);
103        self
104    }
105
106    /// Add multiple relations to include.
107    pub fn with_many(mut self, specs: impl IntoIterator<Item = IncludeSpec>) -> Self {
108        for spec in specs {
109            self.specs.insert(spec.relation_name.clone(), spec);
110        }
111        self
112    }
113
114    /// Get an include spec by relation name.
115    pub fn get(&self, relation: &str) -> Option<&IncludeSpec> {
116        self.specs.get(relation)
117    }
118
119    /// Check if a relation is included.
120    pub fn contains(&self, relation: &str) -> bool {
121        self.specs.contains_key(relation)
122    }
123
124    /// Get all include specs.
125    pub fn specs(&self) -> impl Iterator<Item = &IncludeSpec> {
126        self.specs.values()
127    }
128
129    /// Check if there are any includes.
130    pub fn is_empty(&self) -> bool {
131        self.specs.is_empty()
132    }
133
134    /// Get the number of includes.
135    pub fn len(&self) -> usize {
136        self.specs.len()
137    }
138
139    /// Merge another include into this one.
140    pub fn merge(mut self, other: Include) -> Self {
141        self.specs.extend(other.specs);
142        self
143    }
144}
145
146impl From<IncludeSpec> for Include {
147    fn from(spec: IncludeSpec) -> Self {
148        Self::new().with(spec)
149    }
150}
151
152impl FromIterator<IncludeSpec> for Include {
153    fn from_iter<T: IntoIterator<Item = IncludeSpec>>(iter: T) -> Self {
154        Self::new().with_many(iter)
155    }
156}
157
158/// Helper function to create an include spec.
159pub fn include(relation: impl Into<String>) -> IncludeSpec {
160    IncludeSpec::new(relation)
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::OrderByField;

    #[test]
    fn test_include_spec_basic() {
        let spec = IncludeSpec::new("posts");
        assert_eq!(spec.relation_name, "posts");
        assert!(spec.filter.is_none());
        assert!(spec.order_by.is_none());
        assert!(spec.pagination.is_none());
        assert!(!spec.has_nested());
        assert!(!spec.include_count);
    }

    #[test]
    fn test_include_spec_with_options() {
        let spec = IncludeSpec::new("posts")
            .order_by(OrderByField::desc("created_at"))
            .take(5)
            .with_count();

        assert!(spec.order_by.is_some());
        assert!(spec.pagination.is_some());
        assert!(spec.include_count);
    }

    #[test]
    fn test_include_spec_skip_sets_pagination() {
        let spec = IncludeSpec::new("posts").skip(10);
        assert!(spec.pagination.is_some());
    }

    #[test]
    fn test_include_spec_nested() {
        let spec = IncludeSpec::new("posts").include(IncludeSpec::new("comments").take(10));

        assert!(spec.has_nested());
        assert!(spec.nested.contains_key("comments"));
        assert_eq!(spec.nested_specs().count(), 1);
    }

    #[test]
    fn test_include_spec_nested_same_name_replaces() {
        // Nested specs are keyed by relation name, so the second one wins.
        let spec = IncludeSpec::new("posts")
            .include(IncludeSpec::new("comments"))
            .include(IncludeSpec::new("comments").with_count());

        assert_eq!(spec.nested.len(), 1);
        assert!(spec.nested["comments"].include_count);
    }

    #[test]
    fn test_include_builder() {
        let includes = Include::new()
            .with(IncludeSpec::new("posts"))
            .with(IncludeSpec::new("profile"));

        assert_eq!(includes.len(), 2);
        assert!(includes.contains("posts"));
        assert!(includes.contains("profile"));
        assert!(!includes.contains("comments"));
    }

    #[test]
    fn test_include_builder_replaces_duplicate_relation() {
        let includes = Include::new()
            .with(IncludeSpec::new("posts"))
            .with(IncludeSpec::new("posts").with_count());

        assert_eq!(includes.len(), 1);
        let posts = includes.get("posts").expect("posts should be included");
        assert!(posts.include_count);
    }

    #[test]
    fn test_include_merge_prefers_other() {
        let base = Include::new()
            .with(IncludeSpec::new("posts"))
            .with(IncludeSpec::new("tags"));
        let other = Include::new().with(IncludeSpec::new("posts").with_count());

        let merged = base.merge(other);

        assert_eq!(merged.len(), 2);
        assert!(merged.get("posts").expect("posts kept").include_count);
        assert!(!merged.get("tags").expect("tags kept").include_count);
    }

    #[test]
    fn test_include_from_spec_and_helper() {
        let includes: Include = include("posts").into();

        assert_eq!(includes.len(), 1);
        assert!(includes.contains("posts"));
        assert_eq!(includes.get("posts").map(|s| s.relation_name.as_str()), Some("posts"));
    }

    #[test]
    fn test_include_from_iter() {
        let includes: Include = vec![IncludeSpec::new("posts"), IncludeSpec::new("comments")]
            .into_iter()
            .collect();

        assert_eq!(includes.len(), 2);
    }

    #[test]
    fn test_include_empty() {
        let includes = Include::new();

        assert!(includes.is_empty());
        assert_eq!(includes.len(), 0);
        assert!(includes.get("posts").is_none());
        assert_eq!(includes.specs().count(), 0);
    }
}