Skip to main content

modo/db/
select.rs

1use crate::error::{Error, Result};
2
3use super::conn::ConnExt;
4use super::filter::ValidatedFilter;
5use super::from_row::FromRow;
6use super::page::{CursorPage, CursorRequest, Page, PageRequest};
7
8/// Composable query builder combining filters, sorting, and pagination.
9///
10/// Created via [`ConnExt::select`] with a base SQL query. Chain
11/// [`filter`](Self::filter), [`order_by`](Self::order_by), and
12/// [`cursor_column`](Self::cursor_column) before executing with
13/// [`page`](Self::page), [`cursor`](Self::cursor), [`fetch_all`](Self::fetch_all),
14/// [`fetch_one`](Self::fetch_one), or [`fetch_optional`](Self::fetch_optional).
15pub struct SelectBuilder<'a, C: ConnExt> {
16    conn: &'a C,
17    base_sql: String,
18    filter: Option<ValidatedFilter>,
19    order_by: Option<String>,
20    cursor_column: String,
21    cursor_desc: bool,
22}
23
24impl<'a, C: ConnExt> SelectBuilder<'a, C> {
25    pub(crate) fn new(conn: &'a C, sql: &str) -> Self {
26        Self {
27            conn,
28            base_sql: sql.to_string(),
29            filter: None,
30            order_by: None,
31            cursor_column: "id".to_string(),
32            cursor_desc: true,
33        }
34    }
35
36    /// Apply a validated filter (WHERE clauses).
37    pub fn filter(mut self, filter: ValidatedFilter) -> Self {
38        self.filter = Some(filter);
39        self
40    }
41
42    /// Set ORDER BY clause. This is raw SQL — not user input.
43    /// If a filter has a sort_clause, it takes precedence over this.
44    pub fn order_by(mut self, order: &str) -> Self {
45        self.order_by = Some(order.to_string());
46        self
47    }
48
49    /// Set the column used for cursor pagination (default: `"id"`).
50    ///
51    /// The column must appear in the SELECT list and be sortable (e.g., ULID,
52    /// timestamp, auto-increment). Cursor pagination will ORDER BY this column
53    /// ascending and use it for the `WHERE col > ?` condition.
54    pub fn cursor_column(mut self, col: &str) -> Self {
55        self.cursor_column = col.to_string();
56        self
57    }
58
59    /// Use ascending (oldest-first) cursor ordering instead of the default DESC.
60    ///
61    /// By default cursor pagination orders descending (newest-first), which is the
62    /// common pattern for feeds and timelines. Call this to switch to ascending
63    /// order when you need chronological (oldest-first) traversal.
64    pub fn oldest_first(mut self) -> Self {
65        self.cursor_desc = false;
66        self
67    }
68
69    /// Build WHERE clause and params from filter.
70    fn build_where(&self) -> (String, Vec<libsql::Value>) {
71        match &self.filter {
72            Some(f) if !f.clauses.is_empty() => {
73                let where_sql = format!(" WHERE {}", f.clauses.join(" AND "));
74                (where_sql, f.params.clone())
75            }
76            _ => (String::new(), Vec::new()),
77        }
78    }
79
80    /// Resolve ORDER BY — filter sort takes precedence, then explicit order_by.
81    fn resolve_order(&self) -> Option<String> {
82        self.filter
83            .as_ref()
84            .and_then(|f| f.sort_clause.clone())
85            .or_else(|| self.order_by.clone())
86    }
87
88    /// Execute with offset pagination, returning a [`Page<T>`].
89    ///
90    /// Runs a `COUNT(*)` subquery for the total, then fetches the
91    /// requested page with `LIMIT`/`OFFSET`.
92    ///
93    /// # Errors
94    ///
95    /// Returns an error if the query or row conversion fails.
96    pub async fn page<T: FromRow + serde::Serialize>(self, req: PageRequest) -> Result<Page<T>> {
97        let (where_sql, mut params) = self.build_where();
98        let order = self.resolve_order();
99
100        // Count query
101        let count_sql = format!(
102            "SELECT COUNT(*) FROM ({}{}) AS _count",
103            self.base_sql, where_sql
104        );
105        let mut rows = self
106            .conn
107            .query_raw(&count_sql, params.clone())
108            .await
109            .map_err(Error::from)?;
110        let total: i64 = rows
111            .next()
112            .await
113            .map_err(Error::from)?
114            .ok_or_else(|| Error::internal("count query returned no rows"))?
115            .get(0)
116            .map_err(Error::from)?;
117
118        // Data query
119        let order_sql = order.map(|o| format!(" ORDER BY {o}")).unwrap_or_default();
120        let data_sql = format!(
121            "{}{}{} LIMIT ? OFFSET ?",
122            self.base_sql, where_sql, order_sql
123        );
124        params.push(libsql::Value::from(req.per_page));
125        params.push(libsql::Value::from(req.offset()));
126
127        let mut rows = self
128            .conn
129            .query_raw(&data_sql, params)
130            .await
131            .map_err(Error::from)?;
132        let mut items = Vec::new();
133        while let Some(row) = rows.next().await.map_err(Error::from)? {
134            items.push(T::from_row(&row)?);
135        }
136
137        Ok(Page::new(items, total, req.page, req.per_page))
138    }
139
140    /// Execute with cursor pagination. Returns [`CursorPage<T>`].
141    ///
142    /// By default orders descending (newest-first). Use
143    /// [`oldest_first`](Self::oldest_first) to switch to ascending order.
144    /// The cursor column can be changed with [`cursor_column`](Self::cursor_column).
145    ///
146    /// # Errors
147    ///
148    /// Returns an error if the query or row conversion fails.
149    pub async fn cursor<T: FromRow + serde::Serialize>(
150        self,
151        req: CursorRequest,
152    ) -> Result<CursorPage<T>> {
153        let (where_sql, mut params) = self.build_where();
154        let col = &self.cursor_column;
155
156        let (op, dir) = if self.cursor_desc {
157            ("<", "DESC")
158        } else {
159            (">", "ASC")
160        };
161
162        // Add cursor condition
163        let cursor_condition = if let Some(ref after) = req.after {
164            params.push(libsql::Value::from(after.clone()));
165            if where_sql.is_empty() {
166                format!(" WHERE \"{col}\" {op} ?")
167            } else {
168                format!(" AND \"{col}\" {op} ?")
169            }
170        } else {
171            String::new()
172        };
173
174        // Fetch one extra to determine has_more
175        let limit = req.per_page + 1;
176        let sql = format!(
177            "{}{}{} ORDER BY \"{col}\" {dir} LIMIT ?",
178            self.base_sql, where_sql, cursor_condition
179        );
180        params.push(libsql::Value::from(limit));
181
182        let mut rows = self
183            .conn
184            .query_raw(&sql, params)
185            .await
186            .map_err(Error::from)?;
187
188        // Track cursor values alongside items for cursor extraction.
189        // Find the cursor column index dynamically on the first row.
190        let mut items = Vec::new();
191        let mut cursor_values: Vec<Option<String>> = Vec::new();
192        let mut cursor_col_idx: Option<i32> = None;
193        while let Some(row) = rows.next().await.map_err(Error::from)? {
194            if cursor_col_idx.is_none() {
195                cursor_col_idx = Some(
196                    (0..row.column_count())
197                        .find(|&i| row.column_name(i) == Some(col))
198                        .ok_or_else(|| {
199                            Error::internal(format!(
200                                "cursor column '{col}' not found in result set"
201                            ))
202                        })?,
203                );
204            }
205            let idx = cursor_col_idx.expect("cursor column index was set on first row");
206            let cursor_val = match row.get_value(idx) {
207                Ok(libsql::Value::Text(s)) => Some(s),
208                Ok(libsql::Value::Integer(n)) => Some(n.to_string()),
209                Ok(libsql::Value::Real(f)) => Some(f.to_string()),
210                _ => None,
211            };
212            cursor_values.push(cursor_val);
213            items.push(T::from_row(&row)?);
214        }
215
216        let has_more = items.len() as i64 > req.per_page;
217        if has_more {
218            items.pop();
219            cursor_values.pop();
220        }
221
222        let next_cursor = if has_more {
223            cursor_values.last().cloned().flatten()
224        } else {
225            None
226        };
227
228        Ok(CursorPage {
229            items,
230            next_cursor,
231            has_more,
232            per_page: req.per_page,
233        })
234    }
235
236    /// Execute without pagination, returning all matching rows.
237    ///
238    /// # Errors
239    ///
240    /// Returns an error if the query or row conversion fails.
241    pub async fn fetch_all<T: FromRow>(self) -> Result<Vec<T>> {
242        let (where_sql, params) = self.build_where();
243        let order = self.resolve_order();
244        let order_sql = order.map(|o| format!(" ORDER BY {o}")).unwrap_or_default();
245        let sql = format!("{}{}{}", self.base_sql, where_sql, order_sql);
246
247        let mut rows = self
248            .conn
249            .query_raw(&sql, params)
250            .await
251            .map_err(Error::from)?;
252        let mut items = Vec::new();
253        while let Some(row) = rows.next().await.map_err(Error::from)? {
254            items.push(T::from_row(&row)?);
255        }
256        Ok(items)
257    }
258
259    /// Execute without pagination, returning the first row.
260    ///
261    /// # Errors
262    ///
263    /// Returns [`Error::not_found`](crate::Error::not_found) if no rows match.
264    pub async fn fetch_one<T: FromRow>(self) -> Result<T> {
265        let (where_sql, params) = self.build_where();
266        let sql = format!("{}{} LIMIT 1", self.base_sql, where_sql);
267
268        let mut rows = self
269            .conn
270            .query_raw(&sql, params)
271            .await
272            .map_err(Error::from)?;
273        let row = rows
274            .next()
275            .await
276            .map_err(Error::from)?
277            .ok_or_else(|| Error::not_found("record not found"))?;
278        T::from_row(&row)
279    }
280
281    /// Execute without pagination, returning the first row or `None`.
282    ///
283    /// # Errors
284    ///
285    /// Returns an error if the query or row conversion fails.
286    pub async fn fetch_optional<T: FromRow>(self) -> Result<Option<T>> {
287        let (where_sql, params) = self.build_where();
288        let sql = format!("{}{} LIMIT 1", self.base_sql, where_sql);
289
290        let mut rows = self
291            .conn
292            .query_raw(&sql, params)
293            .await
294            .map_err(Error::from)?;
295        match rows.next().await.map_err(Error::from)? {
296            Some(row) => Ok(Some(T::from_row(&row)?)),
297            None => Ok(None),
298        }
299    }
300}