Skip to main content

karbon_framework/db/
paginated_query.rs

1use std::marker::PhantomData;
2
3use super::{Db, DbArguments, DbPool, DbRow, placeholder};
4use super::pagination::PaginationParams;
5use super::repository::WhereValue;
6use crate::error::AppResult;
7
8/// Builder pour les requêtes paginées avec tri, recherche et filtres WHERE.
9///
10/// # Exemples
11///
12/// ```ignore
13/// // Simple
14/// PaginatedQuery::<MediaFile>::new("SELECT * FROM media_file")
15///     .allowed_sorts(&["id", "filename", "created", "size"])
16///     .default_sort("id")
17///     .execute(pool, params)
18///     .await
19///
20/// // Avec WHERE
21/// PaginatedQuery::<CommentListItem>::new(
22///     "SELECT c.id, c.comment, u.username FROM comment c LEFT JOIN users u ON c.user_id = u.id"
23/// )
24///     .where_eq("c.content_id", content_id.into())
25///     .default_sort("c.created")
26///     .default_order("ASC")
27///     .execute(pool, params)
28///     .await
29///
30/// // Avec recherche
31/// PaginatedQuery::<UserListItem>::new("SELECT u.id, u.username FROM users u")
32///     .allowed_sorts(&["id", "username", "email", "created"])
33///     .sort_prefix("u")
34///     .search_columns(&["u.username", "u.email"])
35///     .execute(pool, params)
36///     .await
37/// ```
38pub struct PaginatedQuery<T> {
39    base_sql: String,
40    count_from: Option<String>,
41    allowed_sorts: Vec<String>,
42    default_sort: String,
43    default_order: String,
44    sort_prefix: Option<String>,
45    search_columns: Vec<String>,
46    where_conditions: Vec<(String, WhereValue)>,
47    _phantom: PhantomData<T>,
48}
49
50impl<T> PaginatedQuery<T>
51where
52    T: Send + Unpin + for<'r> sqlx::FromRow<'r, DbRow>,
53{
54    /// Crée un nouveau builder à partir d'une requête SELECT de base (sans ORDER BY/LIMIT).
55    pub fn new(base_sql: &str) -> Self {
56        Self {
57            base_sql: base_sql.to_string(),
58            count_from: None,
59            allowed_sorts: Vec::new(),
60            default_sort: "id".to_string(),
61            default_order: "DESC".to_string(),
62            sort_prefix: None,
63            search_columns: Vec::new(),
64            where_conditions: Vec::new(),
65            _phantom: PhantomData,
66        }
67    }
68
69    /// Colonnes de tri autorisées (whitelist pour éviter l'injection SQL).
70    pub fn allowed_sorts(mut self, sorts: &[&str]) -> Self {
71        self.allowed_sorts = sorts.iter().map(|s| s.to_string()).collect();
72        self
73    }
74
75    /// Colonne de tri par défaut.
76    pub fn default_sort(mut self, sort: &str) -> Self {
77        self.default_sort = sort.to_string();
78        self
79    }
80
81    /// Direction de tri par défaut ("ASC" ou "DESC").
82    pub fn default_order(mut self, order: &str) -> Self {
83        self.default_order = order.to_uppercase();
84        self
85    }
86
87    /// Préfixe de table pour le ORDER BY (ex: "u" → `ORDER BY u.id`).
88    pub fn sort_prefix(mut self, prefix: &str) -> Self {
89        self.sort_prefix = Some(prefix.to_string());
90        self
91    }
92
93    /// Colonnes à rechercher avec LIKE (activé si `params.search` est renseigné).
94    pub fn search_columns(mut self, columns: &[&str]) -> Self {
95        self.search_columns = columns.iter().map(|s| s.to_string()).collect();
96        self
97    }
98
99    /// Ajoute une condition WHERE (col = ?).
100    pub fn where_eq(mut self, column: &str, value: WhereValue) -> Self {
101        self.where_conditions.push((column.to_string(), value));
102        self
103    }
104
105    /// Surcharge la requête COUNT (utile si le FROM est complexe).
106    /// Par défaut, le COUNT est déduit automatiquement de la requête de base.
107    pub fn count_from(mut self, sql: &str) -> Self {
108        self.count_from = Some(sql.to_string());
109        self
110    }
111
112    /// Exécute la requête paginée et retourne `(Vec<T>, total)`.
113    pub async fn execute(
114        self,
115        pool: &DbPool,
116        params: &PaginationParams,
117    ) -> AppResult<(Vec<T>, u64)> {
118        // ─── Résoudre le tri ───
119        let sort_col = if self.allowed_sorts.is_empty() {
120            self.default_sort.clone()
121        } else {
122            let allowed_refs: Vec<&str> = self.allowed_sorts.iter().map(|s| s.as_str()).collect();
123            params.sort_column(&allowed_refs, &self.default_sort).to_string()
124        };
125
126        let order = if params.order.to_lowercase() == "asc" {
127            "ASC"
128        } else if params.order.to_lowercase() == "desc" {
129            "DESC"
130        } else {
131            &self.default_order
132        };
133
134        let sort_expr = match &self.sort_prefix {
135            Some(prefix) if !sort_col.contains('.') => format!("{}.{}", prefix, sort_col),
136            _ => sort_col,
137        };
138
139        // ─── Construire WHERE avec placeholder tracking ───
140        let mut where_clauses: Vec<String> = Vec::new();
141        let mut bind_values: Vec<WhereValue> = Vec::new();
142        let mut ph_idx: usize = 0;
143
144        for (col, val) in &self.where_conditions {
145            ph_idx += 1;
146            where_clauses.push(format!("{} = {}", col, placeholder(ph_idx)));
147            bind_values.push(val.clone());
148        }
149
150        // Search (LIKE)
151        let has_search = !self.search_columns.is_empty() && params.search.is_some();
152        if has_search {
153            let like_clauses: Vec<String> = self.search_columns
154                .iter()
155                .map(|col| {
156                    ph_idx += 1;
157                    format!("{} LIKE {}", col, placeholder(ph_idx))
158                })
159                .collect();
160            where_clauses.push(format!("({})", like_clauses.join(" OR ")));
161        }
162
163        let where_sql = if where_clauses.is_empty() {
164            String::new()
165        } else {
166            format!(" WHERE {}", where_clauses.join(" AND "))
167        };
168
169        // ─── Requête data ───
170        ph_idx += 1;
171        let limit_ph = placeholder(ph_idx);
172        ph_idx += 1;
173        let offset_ph = placeholder(ph_idx);
174
175        let data_sql = format!(
176            "{}{} ORDER BY {} {} LIMIT {} OFFSET {}",
177            self.base_sql, where_sql, sort_expr, order, limit_ph, offset_ph
178        );
179
180        // ─── Requête count (réutilise les mêmes index de placeholder que le WHERE) ───
181        let count_sql = if let Some(ref custom) = self.count_from {
182            format!("{}{}", custom, where_sql)
183        } else {
184            let from_part = extract_from_clause(&self.base_sql);
185            format!("SELECT COUNT(*) {}{}", from_part, where_sql)
186        };
187
188        // ─── Bind et exécution ───
189        let like_value = params.search.as_ref().map(|s| format!("%{}%", s));
190        let search_col_count = self.search_columns.len();
191
192        // Data query
193        let mut data_query = sqlx::query_as::<Db, T>(&data_sql);
194        for val in &bind_values {
195            data_query = bind_where_value(data_query, val);
196        }
197        if has_search {
198            if let Some(ref like) = like_value {
199                for _ in 0..search_col_count {
200                    data_query = data_query.bind(like.as_str());
201                }
202            }
203        }
204        data_query = data_query.bind(params.safe_per_page() as i64).bind(params.offset() as i64);
205        let items = data_query.fetch_all(pool).await?;
206
207        // Count query
208        let mut count_query = sqlx::query_as::<Db, (i64,)>(&count_sql);
209        for val in &bind_values {
210            count_query = bind_where_value_tuple(count_query, val);
211        }
212        if has_search {
213            if let Some(ref like) = like_value {
214                for _ in 0..search_col_count {
215                    count_query = count_query.bind(like.as_str());
216                }
217            }
218        }
219        let (total,) = count_query.fetch_one(pool).await?;
220
221        Ok((items, total as u64))
222    }
223}
224
/// Extracts the `FROM` clause of a `SELECT`: everything from the first `FROM` keyword on.
///
/// The keyword is matched case-insensitively and must be surrounded by ASCII whitespace
/// (space, tab, newline, ...), so multi-line statements are handled. The returned slice
/// starts at the whitespace character preceding the keyword. When no such keyword is
/// found, `sql` is returned unchanged.
fn extract_from_clause(sql: &str) -> &str {
    const KEYWORD: &[u8] = b"FROM";

    // Scan the original bytes rather than an upper-cased copy: Unicode upper-casing can
    // change byte lengths, which would make offsets found in the copy invalid for `sql`.
    sql.as_bytes()
        .windows(KEYWORD.len() + 2)
        .position(|window| {
            window[0].is_ascii_whitespace()
                && window[KEYWORD.len() + 1].is_ascii_whitespace()
                && window[1..=KEYWORD.len()].eq_ignore_ascii_case(KEYWORD)
        })
        // The match starts on an ASCII byte, so the offset is always a char boundary.
        .map_or(sql, |pos| &sql[pos..])
}
234
235/// Bind une WhereValue sur un query_as<T>
236fn bind_where_value<'q, T>(
237    query: sqlx::query::QueryAs<'q, Db, T, DbArguments>,
238    value: &'q WhereValue,
239) -> sqlx::query::QueryAs<'q, Db, T, DbArguments>
240where
241    T: Send + Unpin + for<'r> sqlx::FromRow<'r, DbRow>,
242{
243    match value {
244        WhereValue::Int(v) => query.bind(*v),
245        WhereValue::Float(v) => query.bind(*v),
246        WhereValue::String(v) => query.bind(v.as_str()),
247        WhereValue::Bool(v) => query.bind(*v),
248        WhereValue::DateTime(v) => query.bind(*v),
249    }
250}
251
252/// Bind une WhereValue sur un query_as<(i64,)> (pour les COUNT)
253fn bind_where_value_tuple<'q>(
254    query: sqlx::query::QueryAs<'q, Db, (i64,), DbArguments>,
255    value: &'q WhereValue,
256) -> sqlx::query::QueryAs<'q, Db, (i64,), DbArguments> {
257    match value {
258        WhereValue::Int(v) => query.bind(*v),
259        WhereValue::Float(v) => query.bind(*v),
260        WhereValue::String(v) => query.bind(v.as_str()),
261        WhereValue::Bool(v) => query.bind(*v),
262        WhereValue::DateTime(v) => query.bind(*v),
263    }
264}