Skip to main content

karbon_framework/db/
paginated_query.rs

1use std::marker::PhantomData;
2
3use super::{Db, DbArguments, DbPool, DbRow, placeholder};
4use super::pagination::PaginationParams;
5use super::repository::WhereValue;
6use crate::error::AppResult;
7
8/// Builder pour les requêtes paginées avec tri, recherche et filtres WHERE.
9///
10/// # Exemples
11///
12/// ```ignore
13/// // Simple
14/// PaginatedQuery::<MediaFile>::new("SELECT * FROM media_file")
15///     .allowed_sorts(&["id", "filename", "created", "size"])
16///     .default_sort("id")
17///     .execute(pool, params)
18///     .await
19///
20/// // Avec WHERE
21/// PaginatedQuery::<CommentListItem>::new(
22///     "SELECT c.id, c.comment, u.username FROM comment c LEFT JOIN users u ON c.user_id = u.id"
23/// )
24///     .where_eq("c.content_id", content_id.into())
25///     .default_sort("c.created")
26///     .default_order("ASC")
27///     .execute(pool, params)
28///     .await
29///
30/// // Avec recherche
31/// PaginatedQuery::<UserListItem>::new("SELECT u.id, u.username FROM users u")
32///     .allowed_sorts(&["id", "username", "email", "created"])
33///     .sort_prefix("u")
34///     .search_columns(&["u.username", "u.email"])
35///     .execute(pool, params)
36///     .await
37/// ```
38pub struct PaginatedQuery<T> {
39    base_sql: String,
40    count_from: Option<String>,
41    allowed_sorts: Vec<String>,
42    default_sort: String,
43    default_order: String,
44    sort_prefix: Option<String>,
45    search_columns: Vec<String>,
46    where_conditions: Vec<(String, &'static str, WhereValue)>,
47    _phantom: PhantomData<T>,
48}
49
50impl<T> PaginatedQuery<T>
51where
52    T: Send + Unpin + for<'r> sqlx::FromRow<'r, DbRow>,
53{
54    /// Crée un nouveau builder à partir d'une requête SELECT de base (sans ORDER BY/LIMIT).
55    pub fn new(base_sql: &str) -> Self {
56        Self {
57            base_sql: base_sql.to_string(),
58            count_from: None,
59            allowed_sorts: Vec::new(),
60            default_sort: "id".to_string(),
61            default_order: "DESC".to_string(),
62            sort_prefix: None,
63            search_columns: Vec::new(),
64            where_conditions: Vec::new(),
65            _phantom: PhantomData,
66        }
67    }
68
69    /// Colonnes de tri autorisées (whitelist pour éviter l'injection SQL).
70    pub fn allowed_sorts(mut self, sorts: &[&str]) -> Self {
71        self.allowed_sorts = sorts.iter().map(|s| s.to_string()).collect();
72        self
73    }
74
75    /// Colonne de tri par défaut.
76    pub fn default_sort(mut self, sort: &str) -> Self {
77        self.default_sort = sort.to_string();
78        self
79    }
80
81    /// Direction de tri par défaut ("ASC" ou "DESC").
82    pub fn default_order(mut self, order: &str) -> Self {
83        self.default_order = order.to_uppercase();
84        self
85    }
86
87    /// Préfixe de table pour le ORDER BY (ex: "u" → `ORDER BY u.id`).
88    pub fn sort_prefix(mut self, prefix: &str) -> Self {
89        self.sort_prefix = Some(prefix.to_string());
90        self
91    }
92
93    /// Colonnes à rechercher avec LIKE (activé si `params.search` est renseigné).
94    pub fn search_columns(mut self, columns: &[&str]) -> Self {
95        self.search_columns = columns.iter().map(|s| s.to_string()).collect();
96        self
97    }
98
99    /// Ajoute une condition WHERE (col = ?).
100    pub fn where_eq(mut self, column: &str, value: WhereValue) -> Self {
101        self.where_conditions.push((column.to_string(), "=", value));
102        self
103    }
104
105    /// Ajoute une condition WHERE (col >= ?).
106    pub fn where_gte(mut self, column: &str, value: WhereValue) -> Self {
107        self.where_conditions.push((column.to_string(), ">=", value));
108        self
109    }
110
111    /// Ajoute une condition WHERE (col <= ?).
112    pub fn where_lte(mut self, column: &str, value: WhereValue) -> Self {
113        self.where_conditions.push((column.to_string(), "<=", value));
114        self
115    }
116
117    /// Surcharge la requête COUNT (utile si le FROM est complexe).
118    /// Par défaut, le COUNT est déduit automatiquement de la requête de base.
119    pub fn count_from(mut self, sql: &str) -> Self {
120        self.count_from = Some(sql.to_string());
121        self
122    }
123
124    /// Exécute la requête paginée et retourne `(Vec<T>, total)`.
125    pub async fn execute(
126        self,
127        pool: &DbPool,
128        params: &PaginationParams,
129    ) -> AppResult<(Vec<T>, u64)> {
130        // ─── Résoudre le tri ───
131        let sort_col = if self.allowed_sorts.is_empty() {
132            self.default_sort.clone()
133        } else {
134            let allowed_refs: Vec<&str> = self.allowed_sorts.iter().map(|s| s.as_str()).collect();
135            params.sort_column(&allowed_refs, &self.default_sort).to_string()
136        };
137
138        let order = if params.order.to_lowercase() == "asc" {
139            "ASC"
140        } else if params.order.to_lowercase() == "desc" {
141            "DESC"
142        } else {
143            &self.default_order
144        };
145
146        let sort_expr = match &self.sort_prefix {
147            Some(prefix) if !sort_col.contains('.') => format!("{}.{}", prefix, sort_col),
148            _ => sort_col,
149        };
150
151        // ─── Construire WHERE avec placeholder tracking ───
152        let mut where_clauses: Vec<String> = Vec::new();
153        let mut bind_values: Vec<WhereValue> = Vec::new();
154        let mut ph_idx: usize = 0;
155
156        for (col, op, val) in &self.where_conditions {
157            ph_idx += 1;
158            where_clauses.push(format!("{} {} {}", col, op, placeholder(ph_idx)));
159            bind_values.push(val.clone());
160        }
161
162        // Search (LIKE)
163        let has_search = !self.search_columns.is_empty() && params.search.is_some();
164        if has_search {
165            let like_clauses: Vec<String> = self.search_columns
166                .iter()
167                .map(|col| {
168                    ph_idx += 1;
169                    format!("{} LIKE {}", col, placeholder(ph_idx))
170                })
171                .collect();
172            where_clauses.push(format!("({})", like_clauses.join(" OR ")));
173        }
174
175        let where_sql = if where_clauses.is_empty() {
176            String::new()
177        } else {
178            format!(" WHERE {}", where_clauses.join(" AND "))
179        };
180
181        // ─── Requête data ───
182        ph_idx += 1;
183        let limit_ph = placeholder(ph_idx);
184        ph_idx += 1;
185        let offset_ph = placeholder(ph_idx);
186
187        let data_sql = format!(
188            "{}{} ORDER BY {} {} LIMIT {} OFFSET {}",
189            self.base_sql, where_sql, sort_expr, order, limit_ph, offset_ph
190        );
191
192        // ─── Requête count (réutilise les mêmes index de placeholder que le WHERE) ───
193        let count_sql = if let Some(ref custom) = self.count_from {
194            format!("{}{}", custom, where_sql)
195        } else {
196            let from_part = extract_from_clause(&self.base_sql);
197            format!("SELECT COUNT(*) {}{}", from_part, where_sql)
198        };
199
200        // ─── Bind et exécution ───
201        let like_value = params.search.as_ref().map(|s| format!("%{}%", s));
202        let search_col_count = self.search_columns.len();
203
204        // Data query
205        let mut data_query = sqlx::query_as::<Db, T>(&data_sql);
206        for val in &bind_values {
207            data_query = bind_where_value(data_query, val);
208        }
209        if has_search {
210            if let Some(ref like) = like_value {
211                for _ in 0..search_col_count {
212                    data_query = data_query.bind(like.as_str());
213                }
214            }
215        }
216        data_query = data_query.bind(params.safe_per_page() as i64).bind(params.offset() as i64);
217        let items = data_query.fetch_all(pool).await?;
218
219        // Count query
220        let mut count_query = sqlx::query_as::<Db, (i64,)>(&count_sql);
221        for val in &bind_values {
222            count_query = bind_where_value_tuple(count_query, val);
223        }
224        if has_search {
225            if let Some(ref like) = like_value {
226                for _ in 0..search_col_count {
227                    count_query = count_query.bind(like.as_str());
228                }
229            }
230        }
231        let (total,) = count_query.fetch_one(pool).await?;
232
233        Ok((items, total as u64))
234    }
235}
236
/// Returns the FROM part of a `SELECT`.
///
/// The result is everything starting at the whitespace that precedes the first
/// top-level `FROM` keyword, e.g. `" FROM users u"` for `"SELECT u.id FROM users u"`.
///
/// - The keyword is matched ASCII case-insensitively, with any whitespace before
///   it and whitespace or `(` after it, so multi-line SQL works.
/// - `FROM`s inside parentheses (select-list sub-queries, `EXTRACT(… FROM …)`) and
///   inside `'…'`, `"…"` or `` `…` `` quotes are skipped.
/// - The scan works on bytes and only slices at an ASCII byte. It therefore cannot
///   misplace or panic on non-ASCII text, unlike searching an upper-cased copy,
///   whose byte length may differ from the original.
///
/// If no top-level FROM is found, `sql` is returned unchanged. Use
/// `PaginatedQuery::count_from` for such queries.
fn extract_from_clause(sql: &str) -> &str {
    /// `rest` starts with the `FROM` keyword followed by whitespace or `(`.
    fn is_from_keyword(rest: &[u8]) -> bool {
        rest.len() > 4
            && rest[..4].eq_ignore_ascii_case(b"FROM")
            && (rest[4].is_ascii_whitespace() || rest[4] == b'(')
    }

    let bytes = sql.as_bytes();
    let mut depth: usize = 0;
    let mut quote: Option<u8> = None;

    for (i, &b) in bytes.iter().enumerate() {
        if let Some(q) = quote {
            // Doubled quotes ('') simply close and reopen the quoted region.
            if b == q {
                quote = None;
            }
            continue;
        }
        match b {
            b'\'' | b'"' | b'`' => quote = Some(b),
            b'(' => depth += 1,
            b')' => depth = depth.saturating_sub(1),
            _ if depth == 0
                && i > 0
                && bytes[i - 1].is_ascii_whitespace()
                && is_from_keyword(&bytes[i..]) =>
            {
                // `i - 1` is an ASCII whitespace byte, hence a valid char boundary.
                return &sql[i - 1..];
            }
            _ => {}
        }
    }
    sql
}
246
247/// Bind une WhereValue sur un query_as<T>
248fn bind_where_value<'q, T>(
249    query: sqlx::query::QueryAs<'q, Db, T, DbArguments>,
250    value: &'q WhereValue,
251) -> sqlx::query::QueryAs<'q, Db, T, DbArguments>
252where
253    T: Send + Unpin + for<'r> sqlx::FromRow<'r, DbRow>,
254{
255    match value {
256        WhereValue::Int(v) => query.bind(*v),
257        WhereValue::Float(v) => query.bind(*v),
258        WhereValue::String(v) => query.bind(v.as_str()),
259        WhereValue::Bool(v) => query.bind(*v),
260        WhereValue::DateTime(v) => query.bind(*v),
261    }
262}
263
264/// Bind une WhereValue sur un query_as<(i64,)> (pour les COUNT)
265fn bind_where_value_tuple<'q>(
266    query: sqlx::query::QueryAs<'q, Db, (i64,), DbArguments>,
267    value: &'q WhereValue,
268) -> sqlx::query::QueryAs<'q, Db, (i64,), DbArguments> {
269    match value {
270        WhereValue::Int(v) => query.bind(*v),
271        WhereValue::Float(v) => query.bind(*v),
272        WhereValue::String(v) => query.bind(v.as_str()),
273        WhereValue::Bool(v) => query.bind(*v),
274        WhereValue::DateTime(v) => query.bind(*v),
275    }
276}