data_pager/
sql.rs

1use crate::{
2    error::*,
3    utils::{decode_u64, encode_u64},
4    Container, PageInfo, Pager, Paginator,
5};
6use derive_builder::Builder;
7use itertools::Itertools;
8use serde::{Deserialize, Serialize};
9use snafu::ensure;
10use std::borrow::Cow;
11
/// Upper bound on `SqlQuery::page_size`; `SqlQuery::normalize` caps larger values at this number.
const MAX_PAGE_SIZE: u64 = 100;
13
/// Description of a paginated SQL `SELECT`, rendered to a statement by `SqlQuery::to_sql`.
///
/// Construct it with `SqlQueryBuilder`, whose `build` normalizes and validates the result.
/// `source`, `projection`, `filter` and `order` are inserted into the SQL text verbatim,
/// so they must not carry untrusted user input.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize, Builder)]
#[builder(build_fn(name = "private_build"), setter(into, strip_option), default)]
pub struct SqlQuery<'a> {
    /// Table or view to select from (the `FROM` target).
    pub source: Cow<'a, str>,
    /// Columns to select, joined with `", "`; an empty list selects `*`.
    pub projection: Vec<Cow<'a, str>>,
    /// Optional condition rendered after `WHERE`.
    pub filter: Option<Cow<'a, str>>,
    /// Optional expression rendered after `ORDER BY`.
    pub order: Option<Cow<'a, str>>,
    /// Encoded cursor for the page this query fetches, produced by `encode_u64` and
    /// read back with `decode_u64`. The decoded value is used as the SQL `OFFSET`
    /// (number of rows to skip); `None` or an undecodable value means offset 0.
    /// NOTE(review): previously documented as base64 — confirm against `utils::encode_u64`.
    pub cursor: Option<Cow<'a, str>>,
    /// Number of rows per page. `SqlQuery::normalize` replaces 0 with a default and
    /// caps values above `MAX_PAGE_SIZE`.
    pub page_size: u64,
}
30
31impl<'a> SqlQueryBuilder<'a> {
32    pub fn build(&self) -> Result<SqlQuery<'a>, Error> {
33        let mut data = self
34            .private_build()
35            .expect("failed to build SqlQuery struct");
36        data.normalize();
37        data.validate()?;
38
39        Ok(data)
40    }
41}
42
43impl<'a> SqlQuery<'a> {
44    pub fn to_sql(&self) -> String {
45        let limit = self.page_size + 1;
46        let offset = self.get_cursor().unwrap_or_default();
47
48        let where_clause = if let Some(filter) = &self.filter {
49            Cow::Owned(format!("WHERE {filter}"))
50        } else {
51            Cow::Borrowed("")
52        };
53
54        let order_clause = if let Some(order) = &self.order {
55            Cow::Owned(format!("ORDER BY {order}"))
56        } else {
57            Cow::Borrowed("")
58        };
59
60        [
61            "SELECT",
62            &self.projection(),
63            "FROM",
64            &self.source,
65            &where_clause,
66            &order_clause,
67            "LIMIT",
68            &limit.to_string(),
69            "OFFSET",
70            &offset.to_string(),
71        ]
72        .iter()
73        .filter(|s| !s.is_empty())
74        .join(" ")
75    }
76
77    pub fn get_pager<T: Container>(&self, data: &mut T) -> Pager {
78        let page_info = self.page_info();
79        page_info.get_pager(data)
80    }
81
82    pub fn get_cursor(&self) -> Option<u64> {
83        self.cursor.as_deref().and_then(|c| decode_u64(c).ok())
84    }
85
86    pub fn next_page(&self, pager: &Pager) -> Option<Self> {
87        let page_info = self.page_info();
88        let page_info = page_info.next_page(pager);
89        page_info.map(|page_info| Self {
90            source: self.source.clone(),
91            projection: self.projection.clone(),
92            filter: self.filter.clone(),
93            order: self.order.clone(),
94            cursor: page_info.cursor.map(|c| encode_u64(c).into()),
95            page_size: page_info.page_size,
96        })
97    }
98
99    pub fn validate(&self) -> Result<(), Error> {
100        ensure!(
101            self.page_size > 0 && self.page_size < MAX_PAGE_SIZE,
102            InvalidPageSizeSnafu {
103                size: self.page_size
104            }
105        );
106        ensure!(!self.source.is_empty(), InvalidSourceSnafu);
107
108        Ok(())
109    }
110
111    pub fn normalize(&mut self) {
112        if self.page_size == 0 {
113            self.page_size = 10;
114        } else if self.page_size > MAX_PAGE_SIZE {
115            self.page_size = MAX_PAGE_SIZE;
116        }
117    }
118
119    fn page_info(&self) -> PageInfo {
120        PageInfo {
121            cursor: self.get_cursor(),
122            page_size: self.page_size,
123        }
124    }
125
126    fn projection(&self) -> Cow<'a, str> {
127        if self.projection.is_empty() {
128            return "*".into();
129        }
130
131        self.projection.iter().join(", ").into()
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pager_test_utils::generate_test_ids;
    use anyhow::{Context, Result};

    /// With every field set, each clause is rendered in order.
    #[test]
    fn sql_query_should_generate_right_sql() -> Result<()> {
        let expected =
            "SELECT id, name FROM users WHERE id > 10 ORDER BY id DESC LIMIT 11 OFFSET 10";
        let query = SqlQuery {
            source: "users".into(),
            projection: vec!["id".into(), "name".into()],
            filter: Some("id > 10".into()),
            order: Some("id DESC".into()),
            cursor: Some(encode_u64(10).into()),
            page_size: 10,
        };

        assert_eq!(query.to_sql(), expected);
        Ok(())
    }

    /// Walks forward two pages, checking the pager links and the SQL for each step.
    #[test]
    fn sql_builder_should_get_correct_page_info() -> Result<()> {
        let first = SqlQueryBuilder::default().source("users").build()?;

        // first page of test data
        let mut first_rows = generate_test_ids(1, 11);
        let first_pager = first.get_pager(&mut first_rows);
        assert_eq!(first_pager.prev, None);
        assert_eq!(first_pager.next, Some(10));

        let second = first.next_page(&first_pager).context("no next page")?;
        assert_eq!(second.to_sql(), "SELECT * FROM users LIMIT 11 OFFSET 10");

        // second page of test data
        let mut second_rows = generate_test_ids(11, 21);
        let second_pager = second.get_pager(&mut second_rows);
        assert_eq!(second_pager.prev, Some(0));
        assert_eq!(second_pager.next, Some(20));

        let third = second.next_page(&second_pager).context("no next page")?;
        assert_eq!(third.to_sql(), "SELECT * FROM users LIMIT 11 OFFSET 20");
        Ok(())
    }
}