1use crate::{
2 error::*,
3 utils::{decode_u64, encode_u64},
4 Container, PageInfo, Pager, Paginator,
5};
6use derive_builder::Builder;
7use itertools::Itertools;
8use serde::{Deserialize, Serialize};
9use snafu::ensure;
10use std::borrow::Cow;
11
/// Largest page size a query may use; `SqlQuery::normalize` clamps larger
/// requested sizes down to this value.
const MAX_PAGE_SIZE: u64 = 100;

/// A paginated SQL `SELECT` query assembled from raw SQL fragments.
///
/// Can be built through [`SqlQueryBuilder`], whose `build` normalizes and then
/// validates the result. Every fragment below is interpolated into the SQL
/// produced by `to_sql` verbatim (no escaping, no parameter binding), so none of
/// them may carry untrusted input.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize, Builder)]
// The derive-generated build fn is renamed to `private_build` so that the
// hand-written `SqlQueryBuilder::build` can normalize and validate before returning.
#[builder(build_fn(name = "private_build"), setter(into, strip_option), default)]
pub struct SqlQuery<'a> {
    /// Table (or other source expression) rendered after `FROM`; must be non-empty.
    pub source: Cow<'a, str>,
    /// Columns rendered after `SELECT`, joined with `", "`; empty means `*`.
    pub projection: Vec<Cow<'a, str>>,
    /// Condition rendered as `WHERE {filter}` when set.
    pub filter: Option<Cow<'a, str>>,
    /// Ordering rendered as `ORDER BY {order}` when set.
    pub order: Option<Cow<'a, str>>,
    /// Encoded row offset (decoded with `decode_u64`); absent or undecodable
    /// values are treated as offset 0.
    pub cursor: Option<Cow<'a, str>>,
    /// Rows per page; `to_sql` emits `LIMIT page_size + 1`. The builder turns 0
    /// into 10 and clamps values above `MAX_PAGE_SIZE`.
    pub page_size: u64,
}
30
31impl<'a> SqlQueryBuilder<'a> {
32 pub fn build(&self) -> Result<SqlQuery<'a>, Error> {
33 let mut data = self
34 .private_build()
35 .expect("failed to build SqlQuery struct");
36 data.normalize();
37 data.validate()?;
38
39 Ok(data)
40 }
41}
42
43impl<'a> SqlQuery<'a> {
44 pub fn to_sql(&self) -> String {
45 let limit = self.page_size + 1;
46 let offset = self.get_cursor().unwrap_or_default();
47
48 let where_clause = if let Some(filter) = &self.filter {
49 Cow::Owned(format!("WHERE {filter}"))
50 } else {
51 Cow::Borrowed("")
52 };
53
54 let order_clause = if let Some(order) = &self.order {
55 Cow::Owned(format!("ORDER BY {order}"))
56 } else {
57 Cow::Borrowed("")
58 };
59
60 [
61 "SELECT",
62 &self.projection(),
63 "FROM",
64 &self.source,
65 &where_clause,
66 &order_clause,
67 "LIMIT",
68 &limit.to_string(),
69 "OFFSET",
70 &offset.to_string(),
71 ]
72 .iter()
73 .filter(|s| !s.is_empty())
74 .join(" ")
75 }
76
77 pub fn get_pager<T: Container>(&self, data: &mut T) -> Pager {
78 let page_info = self.page_info();
79 page_info.get_pager(data)
80 }
81
82 pub fn get_cursor(&self) -> Option<u64> {
83 self.cursor.as_deref().and_then(|c| decode_u64(c).ok())
84 }
85
86 pub fn next_page(&self, pager: &Pager) -> Option<Self> {
87 let page_info = self.page_info();
88 let page_info = page_info.next_page(pager);
89 page_info.map(|page_info| Self {
90 source: self.source.clone(),
91 projection: self.projection.clone(),
92 filter: self.filter.clone(),
93 order: self.order.clone(),
94 cursor: page_info.cursor.map(|c| encode_u64(c).into()),
95 page_size: page_info.page_size,
96 })
97 }
98
99 pub fn validate(&self) -> Result<(), Error> {
100 ensure!(
101 self.page_size > 0 && self.page_size < MAX_PAGE_SIZE,
102 InvalidPageSizeSnafu {
103 size: self.page_size
104 }
105 );
106 ensure!(!self.source.is_empty(), InvalidSourceSnafu);
107
108 Ok(())
109 }
110
111 pub fn normalize(&mut self) {
112 if self.page_size == 0 {
113 self.page_size = 10;
114 } else if self.page_size > MAX_PAGE_SIZE {
115 self.page_size = MAX_PAGE_SIZE;
116 }
117 }
118
119 fn page_info(&self) -> PageInfo {
120 PageInfo {
121 cursor: self.get_cursor(),
122 page_size: self.page_size,
123 }
124 }
125
126 fn projection(&self) -> Cow<'a, str> {
127 if self.projection.is_empty() {
128 return "*".into();
129 }
130
131 self.projection.iter().join(", ").into()
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pager_test_utils::generate_test_ids;
    use anyhow::{Context, Result};

    /// A fully populated query renders every clause, in SQL order.
    #[test]
    fn sql_query_should_generate_right_sql() -> Result<()> {
        let expected =
            "SELECT id, name FROM users WHERE id > 10 ORDER BY id DESC LIMIT 11 OFFSET 10";
        let query = SqlQuery {
            source: "users".into(),
            projection: vec!["id".into(), "name".into()],
            filter: Some("id > 10".into()),
            order: Some("id DESC".into()),
            cursor: Some(encode_u64(10).into()),
            page_size: 10,
        };

        assert_eq!(query.to_sql(), expected);
        Ok(())
    }

    /// Walking forward page by page advances the offset and the pager links.
    #[test]
    fn sql_builder_should_get_correct_page_info() -> Result<()> {
        let first = SqlQueryBuilder::default().source("users").build()?;

        let mut rows = generate_test_ids(1, 11);
        let pager = first.get_pager(&mut rows);
        assert_eq!(pager.prev, None);
        assert_eq!(pager.next, Some(10));

        let second = first.next_page(&pager).context("no next page")?;
        assert_eq!(second.to_sql(), "SELECT * FROM users LIMIT 11 OFFSET 10");

        let mut rows = generate_test_ids(11, 21);
        let pager = second.get_pager(&mut rows);
        assert_eq!(pager.prev, Some(0));
        assert_eq!(pager.next, Some(20));

        let third = second.next_page(&pager).context("no next page")?;
        assert_eq!(third.to_sql(), "SELECT * FROM users LIMIT 11 OFFSET 20");
        Ok(())
    }
}