1mod async_graphql;
2
3use std::error::Error;
4use std::fmt::{Debug, Display};
5use std::marker::PhantomData;
6
7use crate::order_by::Direction;
8use crate::types::{accepts, to_sql_checked, IsNull, ToSql, Type};
9use crate::Model;
10use anyhow::Result;
11use base64ct::{Base64, Encoding};
12use bytes::BytesMut;
13use chrono::NaiveDateTime;
14use serde::{Deserialize, Serialize};
15use tracing::debug;
16
17pub use async_graphql::*;
18
19pub trait DefaultSortKeys {
20 fn keys() -> Vec<String>;
21
22 fn order_by_stmt(direction: Direction) -> String {
23 let mut stmt = "".to_string();
24 for (i, order) in Self::keys().iter().enumerate() {
25 if i > 0 {
26 stmt.push_str(", ");
27 }
28 stmt.push_str(order);
29 if i == 0 {
30 if direction == Direction::Asc {
31 stmt.push_str(" ASC");
32 } else {
33 stmt.push_str(" DESC");
34 }
35 } else {
36 stmt.push_str(" ASC");
37 }
38 }
39
40 stmt
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct Row {
46 pub column: String,
47 pub value: crate::model::Value,
48}
49
50impl Row {
51 pub fn new(column: String, value: crate::model::Value) -> Self {
52 Self { column, value }
53 }
54}
55
56#[derive(Debug, Clone)]
57pub struct Cursor {
58 pub values: Vec<Row>,
59}
60
61impl Cursor {
62 pub fn to_where_stmt(&self, direction: Direction) -> (String, Vec<&(dyn ToSql + Sync)>) {
63 let mut columns = vec![];
64 let mut params: Vec<&(dyn ToSql + Sync)> = vec![];
65 for value in &self.values {
66 columns.push(value.column.clone());
67 params.push(&value.value);
68 }
69 let mut stmt = "(".to_string();
70 stmt.push_str(&columns.join(", "));
71 stmt.push_str(") ");
72
73 if direction == Direction::Asc {
74 stmt.push('>');
75 } else {
76 stmt.push('<');
77 }
78 stmt.push_str(" (");
79 stmt.push_str(
80 ¶ms
81 .iter()
82 .enumerate()
83 .map(|(i, _)| format!("${}", i + 1))
84 .collect::<Vec<_>>()
85 .join(", "),
86 );
87 stmt.push(')');
88
89 (stmt, params)
90 }
91
92 pub fn to_order_by_stmt(&self, direction: Direction) -> String {
93 let keys = self
94 .values
95 .iter()
96 .map(|v| v.column.clone())
97 .collect::<Vec<_>>();
98 Self::order_by_stmt_by_keys(&keys, direction)
99 }
100
101 pub fn order_by_stmt_by_keys(keys: &[String], direction: Direction) -> String {
102 let mut stmt = "".to_string();
103 if let Some(value) = keys.first() {
104 stmt.push_str(value);
105 if direction == Direction::Asc {
106 stmt.push_str(" ASC");
107 } else {
108 stmt.push_str(" DESC");
109 }
110 }
111
112 for value in keys.iter().skip(1) {
113 stmt.push_str(", ");
114 stmt.push_str(value);
115 stmt.push_str(" ASC");
116 }
117
118 stmt
119 }
120}
121
122impl Cursor {
123 pub fn new(values: Vec<Row>) -> Self {
124 Self { values }
125 }
126
127 pub fn encode(&self) -> String {
128 let buf = bincode::serialize(&self.values).unwrap();
130 Base64::encode_string(&buf)
131 }
132
133 pub fn decode(encoded: &str) -> Result<Self> {
134 let decoded = Base64::decode_vec(encoded).unwrap();
135 let values: Vec<Row> = bincode::deserialize(&decoded[..])?;
136 Ok(Self { values })
137 }
138}
139
140#[derive(Debug)]
141pub struct Pagination<T>
142where
143 T: Model + Debug,
144{
145 pub items: Vec<T>,
146 pub before: Option<Cursor>,
147 pub after: Option<Cursor>,
148 pub first: Option<i64>,
149 pub last: Option<i64>,
150 pub total_nodes: i64,
151 pub has_next: bool,
152 pub has_previous: bool,
153}
154
155impl<T> Pagination<T>
156where
157 T: Model + Debug,
158{
159 pub fn new(
160 items: Vec<T>,
161 first: Option<i64>,
162 after: Option<Cursor>,
163 last: Option<i64>,
164 before: Option<Cursor>,
165 total_nodes: i64,
166 ) -> Self {
167 let limit = first.or(last).unwrap_or(0);
168 let has_next_or_previous = items.len() as i64 == limit + 1;
169 let has_next = first.is_some() && has_next_or_previous;
170 let has_previous = last.is_some() && has_next_or_previous;
171
172 let mut items = items;
173 if has_next_or_previous {
174 items.pop();
175 }
176
177 Self {
178 items,
179 before,
180 after,
181 first,
182 last,
183 total_nodes,
184 has_next,
185 has_previous,
186 }
187 }
188
189 pub fn end_cursor(&self) -> Option<Cursor> {
190 self.items.last().map(|item| item.cursor())
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDateTime;
    use dojo_macros::Model;
    use googletest::prelude::*;
    use uuid::Uuid;

    /// Fixed timestamp shared by the cursor tests.
    fn sample_created_at() -> anyhow::Result<NaiveDateTime> {
        Ok(NaiveDateTime::parse_from_str(
            "2024-01-07 12:34:56",
            "%Y-%m-%d %H:%M:%S",
        )?)
    }

    /// Two-key cursor (`created_at`, `id`) used by several tests.
    fn two_key_cursor() -> anyhow::Result<Cursor> {
        let uuid = Uuid::parse_str("ce2087a7-bdbc-4453-9fb8-d4dff3584f3e")?;
        Ok(Cursor::new(vec![
            Row::new(
                "created_at".to_string(),
                crate::model::Value::NaiveDateTime(sample_created_at()?),
            ),
            Row::new("id".to_string(), crate::model::Value::Uuid(uuid)),
        ]))
    }

    #[test]
    fn test_cursor_to_sql_with_1_key() -> anyhow::Result<()> {
        let cursor = Cursor::new(vec![Row::new(
            "created_at".to_string(),
            crate::model::Value::NaiveDateTime(sample_created_at()?),
        )]);

        let (sql, params) = cursor.to_where_stmt(Direction::Asc);
        assert_eq!(sql, "(created_at) > ($1)");
        assert_eq!(params.len(), 1);

        Ok(())
    }

    #[test]
    fn test_cursor_to_sql_with_2_key() -> anyhow::Result<()> {
        let cursor = two_key_cursor()?;

        let (sql, params) = cursor.to_where_stmt(Direction::Asc);
        assert_eq!(sql, "(created_at, id) > ($1, $2)");
        assert_eq!(params.len(), 2);

        Ok(())
    }

    #[test]
    fn test_decode_cursor() -> anyhow::Result<()> {
        let created_at = sample_created_at()?;
        let cursor = Cursor::new(vec![Row::new(
            "created_at".to_string(),
            crate::model::Value::NaiveDateTime(created_at),
        )]);
        let encoded = cursor.encode();

        let decoded = Cursor::decode(&encoded)?;
        assert_that!(
            decoded,
            pat!(Cursor {
                values: contains_each![pat!(Row {
                    column: eq("created_at".to_string()),
                    value: eq(crate::model::Value::NaiveDateTime(created_at)),
                }),],
            })
        );

        Ok(())
    }

    #[test]
    fn test_decode_cursor_preserves_row_order() -> anyhow::Result<()> {
        let cursor = two_key_cursor()?;

        let decoded = Cursor::decode(&cursor.encode())?;
        let columns: Vec<&str> = decoded
            .values
            .iter()
            .map(|row| row.column.as_str())
            .collect();
        assert_eq!(columns, vec!["created_at", "id"]);

        Ok(())
    }
}