1use base64::Engine;
2use chrono::{DateTime, TimeZone, Utc};
3use serde::{Deserialize, Serialize};
4use sqlx::{database::HasValueRef, Decode, Postgres, TypeInfo, ValueRef as _};
5use uuid::Uuid;
6
7use crate::var::Var;
8
9pub trait Iterable {
10 fn cursor(&self) -> Cursor;
11}
12
13impl<T> Iterable for Option<T>
14where
15 T: Default + Iterable,
16{
17 fn cursor(&self) -> Cursor {
18 match self {
19 Some(v) => v.cursor(),
20 None => T::default().cursor(),
21 }
22 }
23}
24
25impl Iterable for i32 {
26 fn cursor(&self) -> Cursor {
27 Cursor::I32
28 }
29}
30
31impl Iterable for i64 {
32 fn cursor(&self) -> Cursor {
33 Cursor::I64
34 }
35}
36
37impl Iterable for u32 {
38 fn cursor(&self) -> Cursor {
39 Cursor::I32
40 }
41}
42
43impl Iterable for u64 {
44 fn cursor(&self) -> Cursor {
45 Cursor::I64
46 }
47}
48
49impl Iterable for String {
50 fn cursor(&self) -> Cursor {
51 Cursor::String
52 }
53}
54
55impl Iterable for Uuid {
56 fn cursor(&self) -> Cursor {
57 Cursor::Uuid
58 }
59}
60
61impl Iterable for DateTime<Utc> {
62 fn cursor(&self) -> Cursor {
63 Cursor::DateTime
64 }
65}
66
67#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
68pub enum Cursor {
69 I32,
70 I64,
71 String,
72 Uuid,
73 DateTime,
74}
75
76impl Cursor {
77 pub fn infer(column: <Postgres as HasValueRef<'_>>::ValueRef) -> sqlx::Result<String> {
78 Ok(match column.type_info().as_ref().name() {
79 "INT" | "INTEGER" => I32Cursor::encode(
80 &<i32 as Decode<'_, Postgres>>::decode(column).map_err(sqlx::Error::Decode)?,
81 ),
82 "BIGINT" | "BIT INTEGER" => I64Cursor::encode(
83 &<i64 as Decode<'_, Postgres>>::decode(column).map_err(sqlx::Error::Decode)?,
84 ),
85 "TEXT" | "VARCHAR" => StringCursor::encode(
86 &<String as Decode<'_, Postgres>>::decode(column).map_err(sqlx::Error::Decode)?,
87 ),
88 "UUID" => UuidCursor::encode(
89 &<Uuid as Decode<'_, Postgres>>::decode(column).map_err(sqlx::Error::Decode)?,
90 ),
91 "TIMESTAMP" | "TIMESTAMPTZ" => DateTimeCursor::encode(
92 &<DateTime<Utc> as Decode<'_, Postgres>>::decode(column)
93 .map_err(sqlx::Error::Decode)?,
94 ),
95 x => {
96 return Err(sqlx::Error::Decode(
97 format!("invalid cursor type during inference: {}", x).into(),
98 ))
99 }
100 })
101 }
102
103 pub fn decode(&self, encoded: &str) -> Var {
104 match self {
105 Self::I32 => Var::I32(I32Cursor::decode(encoded)),
106 Self::I64 => Var::I64(I64Cursor::decode(encoded)),
107 Self::String => Var::String(StringCursor::decode(encoded)),
108 Self::Uuid => Var::Uuid(UuidCursor::decode(encoded)),
109 Self::DateTime => Var::DateTime(DateTimeCursor::decode(encoded)),
110 }
111 }
112
113 pub fn encode(literal: &Var) -> String {
114 match literal {
115 Var::Bool(_) => panic!("invalid cursor type: bool"),
116 Var::I32(v) => I32Cursor::encode(v),
117 Var::I64(v) => I64Cursor::encode(v),
118 Var::String(v) => StringCursor::encode(v),
119 Var::Uuid(v) => UuidCursor::encode(v),
120 Var::DateTime(v) => DateTimeCursor::encode(v),
121 }
122 }
123
124 pub fn min(self) -> Var {
125 match self {
126 Self::I32 => Var::I32(I32Cursor::min()),
127 Self::I64 => Var::I64(I64Cursor::min()),
128 Self::String => Var::String(StringCursor::min()),
129 Self::Uuid => Var::Uuid(UuidCursor::min()),
130 Self::DateTime => Var::DateTime(DateTimeCursor::min()),
131 }
132 }
133
134 pub fn max(self) -> Var {
135 match self {
136 Self::I32 => Var::I32(I32Cursor::max()),
137 Self::I64 => Var::I64(I64Cursor::max()),
138 Self::String => Var::String(StringCursor::max()),
139 Self::Uuid => Var::Uuid(UuidCursor::max()),
140 Self::DateTime => Var::DateTime(DateTimeCursor::max()),
141 }
142 }
143}
144
145impl From<I32Cursor> for Cursor {
146 fn from(_cursor: I32Cursor) -> Self {
147 Self::I32
148 }
149}
150
151impl From<I64Cursor> for Cursor {
152 fn from(_cursor: I64Cursor) -> Self {
153 Self::I64
154 }
155}
156
157impl From<StringCursor> for Cursor {
158 fn from(_cursor: StringCursor) -> Self {
159 Self::String
160 }
161}
162
163impl From<UuidCursor> for Cursor {
164 fn from(_cursor: UuidCursor) -> Self {
165 Self::Uuid
166 }
167}
168
169impl From<DateTimeCursor> for Cursor {
170 fn from(_cursor: DateTimeCursor) -> Self {
171 Self::DateTime
172 }
173}
174
175#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize)]
176pub struct I32Cursor;
177
178impl I32Cursor {
179 pub fn new() -> Self {
180 Self
181 }
182
183 pub fn decode(encoded: &str) -> i32 {
184 base64::engine::general_purpose::STANDARD
185 .decode(encoded)
186 .ok()
187 .and_then(|buf| buf.as_slice().try_into().ok())
188 .map(i32::from_be_bytes)
189 .unwrap_or_else(|| {
190 tracing::warn!("invalid i32 cursor '{}'", encoded);
191 Self::min()
192 })
193 }
194
195 pub fn encode(decoded: &i32) -> String {
196 base64::engine::general_purpose::STANDARD.encode(decoded.to_be_bytes())
197 }
198
199 pub fn min() -> i32 {
200 i32::MIN
201 }
202
203 pub fn max() -> i32 {
204 i32::MAX
205 }
206}
207
208#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize)]
209pub struct I64Cursor;
210
211impl I64Cursor {
212 pub fn new() -> Self {
213 Self
214 }
215
216 pub fn decode(encoded: &str) -> i64 {
217 base64::engine::general_purpose::STANDARD
218 .decode(encoded)
219 .ok()
220 .and_then(|buf| buf.as_slice().try_into().ok())
221 .map(i64::from_be_bytes)
222 .unwrap_or_else(|| {
223 tracing::warn!("invalid i64 cursor '{}'", encoded);
224 Self::min()
225 })
226 }
227
228 pub fn encode(decoded: &i64) -> String {
229 base64::engine::general_purpose::STANDARD.encode(decoded.to_be_bytes())
230 }
231
232 pub fn min() -> i64 {
233 i64::MIN
234 }
235
236 pub fn max() -> i64 {
237 i64::MAX
238 }
239}
240
241#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize)]
242pub struct StringCursor;
243
244impl StringCursor {
245 pub fn new() -> Self {
246 Self
247 }
248
249 pub fn decode(encoded: &str) -> String {
250 base64::engine::general_purpose::STANDARD
251 .decode(encoded)
252 .ok()
253 .and_then(|buf| String::from_utf8(buf.as_slice().to_vec()).ok())
254 .unwrap_or_else(|| {
255 tracing::warn!("invalid string cursor '{}'", encoded);
256 Self::min()
257 })
258 }
259
260 pub fn encode(decoded: &String) -> String {
261 base64::engine::general_purpose::STANDARD.encode(decoded.as_bytes())
262 }
263
264 pub fn min() -> String {
265 "".to_string()
266 }
267
268 #[allow(dead_code)]
269 pub fn max() -> String {
270 "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~".to_string()
271 }
272}
273
274#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize)]
275pub struct UuidCursor;
276
277impl UuidCursor {
278 pub fn new() -> Self {
279 Self
280 }
281
282 pub fn decode(encoded: &str) -> Uuid {
283 base64::engine::general_purpose::STANDARD
284 .decode(encoded)
285 .ok()
286 .and_then(|buf| buf.as_slice().try_into().ok())
287 .map(Uuid::from_bytes)
288 .unwrap_or_else(|| {
289 tracing::warn!("invalid uuid cursor '{}'", encoded);
290 Self::min()
291 })
292 }
293
294 pub fn encode(decoded: &Uuid) -> String {
295 base64::engine::general_purpose::STANDARD.encode(decoded.as_bytes())
296 }
297
298 pub fn min() -> Uuid {
299 Uuid::from_bytes([0; 16])
300 }
301
302 pub fn max() -> Uuid {
303 Uuid::from_bytes([255; 16])
304 }
305}
306
307#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize)]
308pub struct DateTimeCursor;
309
310impl DateTimeCursor {
311 pub fn new() -> Self {
312 Self
313 }
314
315 pub fn decode(encoded: &str) -> DateTime<Utc> {
316 base64::engine::general_purpose::STANDARD
317 .decode(encoded)
318 .ok()
319 .and_then(|buf| buf.as_slice().try_into().ok())
320 .map(|buf| Utc.timestamp_nanos(i64::from_be_bytes(buf)))
321 .unwrap_or_else(|| {
322 tracing::warn!("invalid datetime cursor '{}'", encoded);
323 Self::min()
324 })
325 }
326
327 pub fn encode(decoded: &DateTime<Utc>) -> String {
328 base64::engine::general_purpose::STANDARD.encode(
329 decoded
330 .timestamp_nanos_opt()
331 .expect("timestamp must be valid")
332 .to_be_bytes(),
333 )
334 }
335
336 pub fn min() -> DateTime<Utc> {
337 Utc.timestamp_nanos(i64::MIN)
338 }
339
340 pub fn max() -> DateTime<Utc> {
341 Utc.timestamp_nanos(i64::MAX)
342 }
343}