1use std::marker::PhantomData;
2
3use super::{Db, DbArguments, DbPool, DbRow, placeholder};
4use super::pagination::PaginationParams;
5use super::repository::WhereValue;
6use crate::error::AppResult;
7
8pub struct PaginatedQuery<T> {
39 base_sql: String,
40 count_from: Option<String>,
41 allowed_sorts: Vec<String>,
42 default_sort: String,
43 default_order: String,
44 sort_prefix: Option<String>,
45 search_columns: Vec<String>,
46 where_conditions: Vec<(String, WhereValue)>,
47 _phantom: PhantomData<T>,
48}
49
50impl<T> PaginatedQuery<T>
51where
52 T: Send + Unpin + for<'r> sqlx::FromRow<'r, DbRow>,
53{
54 pub fn new(base_sql: &str) -> Self {
56 Self {
57 base_sql: base_sql.to_string(),
58 count_from: None,
59 allowed_sorts: Vec::new(),
60 default_sort: "id".to_string(),
61 default_order: "DESC".to_string(),
62 sort_prefix: None,
63 search_columns: Vec::new(),
64 where_conditions: Vec::new(),
65 _phantom: PhantomData,
66 }
67 }
68
69 pub fn allowed_sorts(mut self, sorts: &[&str]) -> Self {
71 self.allowed_sorts = sorts.iter().map(|s| s.to_string()).collect();
72 self
73 }
74
75 pub fn default_sort(mut self, sort: &str) -> Self {
77 self.default_sort = sort.to_string();
78 self
79 }
80
81 pub fn default_order(mut self, order: &str) -> Self {
83 self.default_order = order.to_uppercase();
84 self
85 }
86
87 pub fn sort_prefix(mut self, prefix: &str) -> Self {
89 self.sort_prefix = Some(prefix.to_string());
90 self
91 }
92
93 pub fn search_columns(mut self, columns: &[&str]) -> Self {
95 self.search_columns = columns.iter().map(|s| s.to_string()).collect();
96 self
97 }
98
99 pub fn where_eq(mut self, column: &str, value: WhereValue) -> Self {
101 self.where_conditions.push((column.to_string(), value));
102 self
103 }
104
105 pub fn count_from(mut self, sql: &str) -> Self {
108 self.count_from = Some(sql.to_string());
109 self
110 }
111
112 pub async fn execute(
114 self,
115 pool: &DbPool,
116 params: &PaginationParams,
117 ) -> AppResult<(Vec<T>, u64)> {
118 let sort_col = if self.allowed_sorts.is_empty() {
120 self.default_sort.clone()
121 } else {
122 let allowed_refs: Vec<&str> = self.allowed_sorts.iter().map(|s| s.as_str()).collect();
123 params.sort_column(&allowed_refs, &self.default_sort).to_string()
124 };
125
126 let order = if params.order.to_lowercase() == "asc" {
127 "ASC"
128 } else if params.order.to_lowercase() == "desc" {
129 "DESC"
130 } else {
131 &self.default_order
132 };
133
134 let sort_expr = match &self.sort_prefix {
135 Some(prefix) if !sort_col.contains('.') => format!("{}.{}", prefix, sort_col),
136 _ => sort_col,
137 };
138
139 let mut where_clauses: Vec<String> = Vec::new();
141 let mut bind_values: Vec<WhereValue> = Vec::new();
142 let mut ph_idx: usize = 0;
143
144 for (col, val) in &self.where_conditions {
145 ph_idx += 1;
146 where_clauses.push(format!("{} = {}", col, placeholder(ph_idx)));
147 bind_values.push(val.clone());
148 }
149
150 let has_search = !self.search_columns.is_empty() && params.search.is_some();
152 if has_search {
153 let like_clauses: Vec<String> = self.search_columns
154 .iter()
155 .map(|col| {
156 ph_idx += 1;
157 format!("{} LIKE {}", col, placeholder(ph_idx))
158 })
159 .collect();
160 where_clauses.push(format!("({})", like_clauses.join(" OR ")));
161 }
162
163 let where_sql = if where_clauses.is_empty() {
164 String::new()
165 } else {
166 format!(" WHERE {}", where_clauses.join(" AND "))
167 };
168
169 ph_idx += 1;
171 let limit_ph = placeholder(ph_idx);
172 ph_idx += 1;
173 let offset_ph = placeholder(ph_idx);
174
175 let data_sql = format!(
176 "{}{} ORDER BY {} {} LIMIT {} OFFSET {}",
177 self.base_sql, where_sql, sort_expr, order, limit_ph, offset_ph
178 );
179
180 let count_sql = if let Some(ref custom) = self.count_from {
182 format!("{}{}", custom, where_sql)
183 } else {
184 let from_part = extract_from_clause(&self.base_sql);
185 format!("SELECT COUNT(*) {}{}", from_part, where_sql)
186 };
187
188 let like_value = params.search.as_ref().map(|s| format!("%{}%", s));
190 let search_col_count = self.search_columns.len();
191
192 let mut data_query = sqlx::query_as::<Db, T>(&data_sql);
194 for val in &bind_values {
195 data_query = bind_where_value(data_query, val);
196 }
197 if has_search {
198 if let Some(ref like) = like_value {
199 for _ in 0..search_col_count {
200 data_query = data_query.bind(like.as_str());
201 }
202 }
203 }
204 data_query = data_query.bind(params.safe_per_page() as i64).bind(params.offset() as i64);
205 let items = data_query.fetch_all(pool).await?;
206
207 let mut count_query = sqlx::query_as::<Db, (i64,)>(&count_sql);
209 for val in &bind_values {
210 count_query = bind_where_value_tuple(count_query, val);
211 }
212 if has_search {
213 if let Some(ref like) = like_value {
214 for _ in 0..search_col_count {
215 count_query = count_query.bind(like.as_str());
216 }
217 }
218 }
219 let (total,) = count_query.fetch_one(pool).await?;
220
221 Ok((items, total as u64))
222 }
223}
224
/// Returns the tail of `sql` starting at its top-level `FROM` keyword,
/// including the single whitespace character before it. Returns the whole
/// string when no such keyword is found.
///
/// `FROM` is matched case-insensitively and may be surrounded by any ASCII
/// whitespace (spaces, tabs, newlines). Occurrences inside parentheses
/// (e.g. scalar subqueries in the select list) or inside quoted text
/// (`'...'`, `"..."`, `` `...` ``) are skipped.
///
/// The scan is byte-wise over ASCII only, so non-ASCII text can never shift
/// the returned slice off a char boundary.
fn extract_from_clause(sql: &str) -> &str {
    // Does `rest` (the bytes right after a whitespace byte) start with the
    // keyword FROM followed by more whitespace?
    fn starts_with_from_keyword(rest: &[u8]) -> bool {
        rest.len() > 4 && rest[..4].eq_ignore_ascii_case(b"FROM") && rest[4].is_ascii_whitespace()
    }

    let bytes = sql.as_bytes();
    let mut depth: usize = 0;
    let mut quote: Option<u8> = None;

    for (i, &b) in bytes.iter().enumerate() {
        if let Some(q) = quote {
            // A doubled quote ('') closes and immediately reopens the literal,
            // so escaped quotes need no special casing.
            if b == q {
                quote = None;
            }
            continue;
        }
        match b {
            b'\'' | b'"' | b'`' => quote = Some(b),
            b'(' => depth += 1,
            b')' => depth = depth.saturating_sub(1),
            // Slicing at `i` is safe: an ASCII byte is always a char boundary.
            _ if depth == 0
                && b.is_ascii_whitespace()
                && starts_with_from_keyword(&bytes[i + 1..]) =>
            {
                return &sql[i..];
            }
            _ => {}
        }
    }
    sql
}
234
235fn bind_where_value<'q, T>(
237 query: sqlx::query::QueryAs<'q, Db, T, DbArguments>,
238 value: &'q WhereValue,
239) -> sqlx::query::QueryAs<'q, Db, T, DbArguments>
240where
241 T: Send + Unpin + for<'r> sqlx::FromRow<'r, DbRow>,
242{
243 match value {
244 WhereValue::Int(v) => query.bind(*v),
245 WhereValue::Float(v) => query.bind(*v),
246 WhereValue::String(v) => query.bind(v.as_str()),
247 WhereValue::Bool(v) => query.bind(*v),
248 WhereValue::DateTime(v) => query.bind(*v),
249 }
250}
251
252fn bind_where_value_tuple<'q>(
254 query: sqlx::query::QueryAs<'q, Db, (i64,), DbArguments>,
255 value: &'q WhereValue,
256) -> sqlx::query::QueryAs<'q, Db, (i64,), DbArguments> {
257 match value {
258 WhereValue::Int(v) => query.bind(*v),
259 WhereValue::Float(v) => query.bind(*v),
260 WhereValue::String(v) => query.bind(v.as_str()),
261 WhereValue::Bool(v) => query.bind(*v),
262 WhereValue::DateTime(v) => query.bind(*v),
263 }
264}