1use std::marker::PhantomData;
2
3use super::{Db, DbArguments, DbPool, DbRow, placeholder};
4use super::pagination::PaginationParams;
5use super::repository::WhereValue;
6use crate::error::AppResult;
7
8pub struct PaginatedQuery<T> {
39 base_sql: String,
40 count_from: Option<String>,
41 allowed_sorts: Vec<String>,
42 default_sort: String,
43 default_order: String,
44 sort_prefix: Option<String>,
45 search_columns: Vec<String>,
46 where_conditions: Vec<(String, &'static str, WhereValue)>,
47 _phantom: PhantomData<T>,
48}
49
50impl<T> PaginatedQuery<T>
51where
52 T: Send + Unpin + for<'r> sqlx::FromRow<'r, DbRow>,
53{
54 pub fn new(base_sql: &str) -> Self {
56 Self {
57 base_sql: base_sql.to_string(),
58 count_from: None,
59 allowed_sorts: Vec::new(),
60 default_sort: "id".to_string(),
61 default_order: "DESC".to_string(),
62 sort_prefix: None,
63 search_columns: Vec::new(),
64 where_conditions: Vec::new(),
65 _phantom: PhantomData,
66 }
67 }
68
69 pub fn allowed_sorts(mut self, sorts: &[&str]) -> Self {
71 self.allowed_sorts = sorts.iter().map(|s| s.to_string()).collect();
72 self
73 }
74
75 pub fn default_sort(mut self, sort: &str) -> Self {
77 self.default_sort = sort.to_string();
78 self
79 }
80
81 pub fn default_order(mut self, order: &str) -> Self {
83 self.default_order = order.to_uppercase();
84 self
85 }
86
87 pub fn sort_prefix(mut self, prefix: &str) -> Self {
89 self.sort_prefix = Some(prefix.to_string());
90 self
91 }
92
93 pub fn search_columns(mut self, columns: &[&str]) -> Self {
95 self.search_columns = columns.iter().map(|s| s.to_string()).collect();
96 self
97 }
98
99 pub fn where_eq(mut self, column: &str, value: WhereValue) -> Self {
101 self.where_conditions.push((column.to_string(), "=", value));
102 self
103 }
104
105 pub fn where_gte(mut self, column: &str, value: WhereValue) -> Self {
107 self.where_conditions.push((column.to_string(), ">=", value));
108 self
109 }
110
111 pub fn where_lte(mut self, column: &str, value: WhereValue) -> Self {
113 self.where_conditions.push((column.to_string(), "<=", value));
114 self
115 }
116
117 pub fn count_from(mut self, sql: &str) -> Self {
120 self.count_from = Some(sql.to_string());
121 self
122 }
123
124 pub async fn execute(
126 self,
127 pool: &DbPool,
128 params: &PaginationParams,
129 ) -> AppResult<(Vec<T>, u64)> {
130 let sort_col = if self.allowed_sorts.is_empty() {
132 self.default_sort.clone()
133 } else {
134 let allowed_refs: Vec<&str> = self.allowed_sorts.iter().map(|s| s.as_str()).collect();
135 params.sort_column(&allowed_refs, &self.default_sort).to_string()
136 };
137
138 let order = if params.order.to_lowercase() == "asc" {
139 "ASC"
140 } else if params.order.to_lowercase() == "desc" {
141 "DESC"
142 } else {
143 &self.default_order
144 };
145
146 let sort_expr = match &self.sort_prefix {
147 Some(prefix) if !sort_col.contains('.') => format!("{}.{}", prefix, sort_col),
148 _ => sort_col,
149 };
150
151 let mut where_clauses: Vec<String> = Vec::new();
153 let mut bind_values: Vec<WhereValue> = Vec::new();
154 let mut ph_idx: usize = 0;
155
156 for (col, op, val) in &self.where_conditions {
157 ph_idx += 1;
158 where_clauses.push(format!("{} {} {}", col, op, placeholder(ph_idx)));
159 bind_values.push(val.clone());
160 }
161
162 let has_search = !self.search_columns.is_empty() && params.search.is_some();
164 if has_search {
165 let like_clauses: Vec<String> = self.search_columns
166 .iter()
167 .map(|col| {
168 ph_idx += 1;
169 format!("{} LIKE {}", col, placeholder(ph_idx))
170 })
171 .collect();
172 where_clauses.push(format!("({})", like_clauses.join(" OR ")));
173 }
174
175 let where_sql = if where_clauses.is_empty() {
176 String::new()
177 } else {
178 format!(" WHERE {}", where_clauses.join(" AND "))
179 };
180
181 ph_idx += 1;
183 let limit_ph = placeholder(ph_idx);
184 ph_idx += 1;
185 let offset_ph = placeholder(ph_idx);
186
187 let data_sql = format!(
188 "{}{} ORDER BY {} {} LIMIT {} OFFSET {}",
189 self.base_sql, where_sql, sort_expr, order, limit_ph, offset_ph
190 );
191
192 let count_sql = if let Some(ref custom) = self.count_from {
194 format!("{}{}", custom, where_sql)
195 } else {
196 let from_part = extract_from_clause(&self.base_sql);
197 format!("SELECT COUNT(*) {}{}", from_part, where_sql)
198 };
199
200 let like_value = params.search.as_ref().map(|s| format!("%{}%", s));
202 let search_col_count = self.search_columns.len();
203
204 let mut data_query = sqlx::query_as::<Db, T>(&data_sql);
206 for val in &bind_values {
207 data_query = bind_where_value(data_query, val);
208 }
209 if has_search {
210 if let Some(ref like) = like_value {
211 for _ in 0..search_col_count {
212 data_query = data_query.bind(like.as_str());
213 }
214 }
215 }
216 data_query = data_query.bind(params.safe_per_page() as i64).bind(params.offset() as i64);
217 let items = data_query.fetch_all(pool).await?;
218
219 let mut count_query = sqlx::query_as::<Db, (i64,)>(&count_sql);
221 for val in &bind_values {
222 count_query = bind_where_value_tuple(count_query, val);
223 }
224 if has_search {
225 if let Some(ref like) = like_value {
226 for _ in 0..search_col_count {
227 count_query = count_query.bind(like.as_str());
228 }
229 }
230 }
231 let (total,) = count_query.fetch_one(pool).await?;
232
233 Ok((items, total as u64))
234 }
235}
236
/// Returns the tail of `sql` starting at the first standalone `FROM` keyword,
/// including the whitespace character just before it.
///
/// The keyword is matched ASCII-case-insensitively and must be surrounded by
/// ASCII whitespace (space, tab, newline, ...). If no such keyword exists,
/// `sql` is returned unchanged.
///
/// Matching is done on the original bytes. Upper-casing first would break
/// offsets, because Unicode case mapping can change byte length (e.g. `ı` → `I`).
/// That could slice the wrong place or panic on a non-char boundary.
///
/// NOTE(review): this is a textual scan. A `FROM` inside a subquery in the
/// select list, or inside a string literal, is found first.
fn extract_from_clause(sql: &str) -> &str {
    const KEYWORD: &[u8] = b"FROM";
    let found = sql
        .as_bytes()
        .windows(KEYWORD.len() + 2)
        .position(|w| {
            w[0].is_ascii_whitespace()
                && w[1..=KEYWORD.len()].eq_ignore_ascii_case(KEYWORD)
                && w[KEYWORD.len() + 1].is_ascii_whitespace()
        });
    match found {
        // `pos` indexes an ASCII whitespace byte, so it is always a char boundary.
        Some(pos) => &sql[pos..],
        None => sql,
    }
}
246
247fn bind_where_value<'q, T>(
249 query: sqlx::query::QueryAs<'q, Db, T, DbArguments>,
250 value: &'q WhereValue,
251) -> sqlx::query::QueryAs<'q, Db, T, DbArguments>
252where
253 T: Send + Unpin + for<'r> sqlx::FromRow<'r, DbRow>,
254{
255 match value {
256 WhereValue::Int(v) => query.bind(*v),
257 WhereValue::Float(v) => query.bind(*v),
258 WhereValue::String(v) => query.bind(v.as_str()),
259 WhereValue::Bool(v) => query.bind(*v),
260 WhereValue::DateTime(v) => query.bind(*v),
261 }
262}
263
264fn bind_where_value_tuple<'q>(
266 query: sqlx::query::QueryAs<'q, Db, (i64,), DbArguments>,
267 value: &'q WhereValue,
268) -> sqlx::query::QueryAs<'q, Db, (i64,), DbArguments> {
269 match value {
270 WhereValue::Int(v) => query.bind(*v),
271 WhereValue::Float(v) => query.bind(*v),
272 WhereValue::String(v) => query.bind(v.as_str()),
273 WhereValue::Bool(v) => query.bind(*v),
274 WhereValue::DateTime(v) => query.bind(*v),
275 }
276}