// dojo_orm/pagination.rs

1mod async_graphql;
2
3use std::error::Error;
4use std::fmt::{Debug, Display};
5use std::marker::PhantomData;
6
7use crate::order_by::Direction;
8use crate::types::{accepts, to_sql_checked, IsNull, ToSql, Type};
9use crate::Model;
10use anyhow::Result;
11use base64ct::{Base64, Encoding};
12use bytes::BytesMut;
13use chrono::NaiveDateTime;
14use serde::{Deserialize, Serialize};
15use tracing::debug;
16
17pub use async_graphql::*;
18
19pub trait DefaultSortKeys {
20    fn keys() -> Vec<String>;
21
22    fn order_by_stmt(direction: Direction) -> String {
23        let mut stmt = "".to_string();
24        for (i, order) in Self::keys().iter().enumerate() {
25            if i > 0 {
26                stmt.push_str(", ");
27            }
28            stmt.push_str(order);
29            if i == 0 {
30                if direction == Direction::Asc {
31                    stmt.push_str(" ASC");
32                } else {
33                    stmt.push_str(" DESC");
34                }
35            } else {
36                stmt.push_str(" ASC");
37            }
38        }
39
40        stmt
41    }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct Row {
46    pub column: String,
47    pub value: crate::model::Value,
48}
49
50impl Row {
51    pub fn new(column: String, value: crate::model::Value) -> Self {
52        Self { column, value }
53    }
54}
55
56#[derive(Debug, Clone)]
57pub struct Cursor {
58    pub values: Vec<Row>,
59}
60
61impl Cursor {
62    pub fn to_where_stmt(&self, direction: Direction) -> (String, Vec<&(dyn ToSql + Sync)>) {
63        let mut columns = vec![];
64        let mut params: Vec<&(dyn ToSql + Sync)> = vec![];
65        for value in &self.values {
66            columns.push(value.column.clone());
67            params.push(&value.value);
68        }
69        let mut stmt = "(".to_string();
70        stmt.push_str(&columns.join(", "));
71        stmt.push_str(") ");
72
73        if direction == Direction::Asc {
74            stmt.push('>');
75        } else {
76            stmt.push('<');
77        }
78        stmt.push_str(" (");
79        stmt.push_str(
80            &params
81                .iter()
82                .enumerate()
83                .map(|(i, _)| format!("${}", i + 1))
84                .collect::<Vec<_>>()
85                .join(", "),
86        );
87        stmt.push(')');
88
89        (stmt, params)
90    }
91
92    pub fn to_order_by_stmt(&self, direction: Direction) -> String {
93        let keys = self
94            .values
95            .iter()
96            .map(|v| v.column.clone())
97            .collect::<Vec<_>>();
98        Self::order_by_stmt_by_keys(&keys, direction)
99    }
100
101    pub fn order_by_stmt_by_keys(keys: &[String], direction: Direction) -> String {
102        let mut stmt = "".to_string();
103        if let Some(value) = keys.first() {
104            stmt.push_str(value);
105            if direction == Direction::Asc {
106                stmt.push_str(" ASC");
107            } else {
108                stmt.push_str(" DESC");
109            }
110        }
111
112        for value in keys.iter().skip(1) {
113            stmt.push_str(", ");
114            stmt.push_str(value);
115            stmt.push_str(" ASC");
116        }
117
118        stmt
119    }
120}
121
122impl Cursor {
123    pub fn new(values: Vec<Row>) -> Self {
124        Self { values }
125    }
126
127    pub fn encode(&self) -> String {
128        // it's safe, trust me bro.
129        let buf = bincode::serialize(&self.values).unwrap();
130        Base64::encode_string(&buf)
131    }
132
133    pub fn decode(encoded: &str) -> Result<Self> {
134        let decoded = Base64::decode_vec(encoded).unwrap();
135        let values: Vec<Row> = bincode::deserialize(&decoded[..])?;
136        Ok(Self { values })
137    }
138}
139
140#[derive(Debug)]
141pub struct Pagination<T>
142where
143    T: Model + Debug,
144{
145    pub items: Vec<T>,
146    pub before: Option<Cursor>,
147    pub after: Option<Cursor>,
148    pub first: Option<i64>,
149    pub last: Option<i64>,
150    pub total_nodes: i64,
151    pub has_next: bool,
152    pub has_previous: bool,
153}
154
155impl<T> Pagination<T>
156where
157    T: Model + Debug,
158{
159    pub fn new(
160        items: Vec<T>,
161        first: Option<i64>,
162        after: Option<Cursor>,
163        last: Option<i64>,
164        before: Option<Cursor>,
165        total_nodes: i64,
166    ) -> Self {
167        let limit = first.or(last).unwrap_or(0);
168        let has_next_or_previous = items.len() as i64 == limit + 1;
169        let has_next = first.is_some() && has_next_or_previous;
170        let has_previous = last.is_some() && has_next_or_previous;
171
172        let mut items = items;
173        if has_next_or_previous {
174            items.pop();
175        }
176
177        Self {
178            items,
179            before,
180            after,
181            first,
182            last,
183            total_nodes,
184            has_next,
185            has_previous,
186        }
187    }
188
189    pub fn end_cursor(&self) -> Option<Cursor> {
190        self.items.last().map(|item| item.cursor())
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDateTime;
    use dojo_macros::Model;
    use googletest::prelude::*;
    use uuid::Uuid;

    #[test]
    fn test_cursor_to_sql_with_1_key() -> anyhow::Result<()> {
        let created_at = NaiveDateTime::parse_from_str("2024-01-07 12:34:56", "%Y-%m-%d %H:%M:%S")?;
        let cursor_value = Row {
            column: "created_at".to_string(),
            value: crate::model::Value::NaiveDateTime(created_at),
        };
        let cursor = Cursor::new(vec![cursor_value]);
        let (sql, params) = cursor.to_where_stmt(Direction::Asc);

        assert_eq!(sql, "(created_at) > ($1)");
        assert_eq!(params.len(), 1);

        Ok(())
    }

    #[test]
    fn test_cursor_to_sql_with_1_key_desc() -> anyhow::Result<()> {
        let created_at = NaiveDateTime::parse_from_str("2024-01-07 12:34:56", "%Y-%m-%d %H:%M:%S")?;
        let cursor = Cursor::new(vec![Row::new(
            "created_at".to_string(),
            crate::model::Value::NaiveDateTime(created_at),
        )]);
        let (sql, params) = cursor.to_where_stmt(Direction::Desc);

        assert_eq!(sql, "(created_at) < ($1)");
        assert_eq!(params.len(), 1);

        Ok(())
    }

    #[test]
    fn test_cursor_to_sql_with_2_key() -> anyhow::Result<()> {
        let created_at = NaiveDateTime::parse_from_str("2024-01-07 12:34:56", "%Y-%m-%d %H:%M:%S")?;
        let uuid = Uuid::parse_str("ce2087a7-bdbc-4453-9fb8-d4dff3584f3e")?;
        let cursor = Cursor::new(vec![
            Row {
                column: "created_at".to_string(),
                value: crate::model::Value::NaiveDateTime(created_at),
            },
            Row {
                column: "id".to_string(),
                value: crate::model::Value::Uuid(uuid),
            },
        ]);
        let (sql, params) = cursor.to_where_stmt(Direction::Asc);

        assert_eq!(sql, "(created_at, id) > ($1, $2)");
        assert_eq!(params.len(), 2);

        Ok(())
    }

    #[test]
    fn test_order_by_stmt_by_keys() {
        let keys = vec!["created_at".to_string(), "id".to_string()];
        assert_eq!(
            Cursor::order_by_stmt_by_keys(&keys, Direction::Asc),
            "created_at ASC, id ASC"
        );

        let single = vec!["created_at".to_string()];
        assert_eq!(
            Cursor::order_by_stmt_by_keys(&single, Direction::Desc),
            "created_at DESC"
        );

        assert_eq!(Cursor::order_by_stmt_by_keys(&[], Direction::Asc), "");
    }

    #[test]
    fn test_decode_cursor() -> anyhow::Result<()> {
        let created_at = NaiveDateTime::parse_from_str("2024-01-07 12:34:56", "%Y-%m-%d %H:%M:%S")?;
        let cursor_value = Row {
            column: "created_at".to_string(),
            value: crate::model::Value::NaiveDateTime(created_at),
        };
        let cursor = Cursor::new(vec![cursor_value]);
        let encoded = cursor.encode();

        let decoded = Cursor::decode(&encoded).unwrap();
        assert_that!(
            decoded,
            pat!(Cursor {
                values: contains_each![pat!(Row {
                    column: eq("created_at".to_string()),
                    value: eq(crate::model::Value::NaiveDateTime(created_at)),
                }),],
            })
        );

        Ok(())
    }
}