1#![forbid(unsafe_code)]
2
3use base64::{
4 alphabet,
5 engine::{general_purpose, GeneralPurpose},
6 Engine,
7};
8use chrono::{DateTime, NaiveDateTime, Utc};
9use harsh::Harsh;
10use serde::{Deserialize, Serialize};
11#[cfg(feature = "pg")]
12use sqlx::{postgres::PgArguments, query::QueryAs, FromRow, Postgres};
13use std::{fmt::Debug, str::FromStr};
14
15use crate::error::QueryError;
16
17#[derive(Debug, PartialEq, Deserialize, Serialize, Clone)]
18pub struct CursorType(pub String);
19
20impl From<String> for CursorType {
21 fn from(val: String) -> Self {
22 CursorType(val)
23 }
24}
25
26impl AsRef<[u8]> for CursorType {
27 fn as_ref(&self) -> &[u8] {
28 self.0.as_bytes()
29 }
30}
31
/// Sort direction of the paginated key columns.
///
/// Combined with the `backward` flag in [`Cursor::to_pg_filter_opts`] and
/// [`Cursor::to_pg_order`] to pick the comparison operator and SQL sort direction.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CursorOrder {
    /// Ascending order.
    Asc,
    /// Descending order.
    Desc,
}
37
38pub trait Cursor: Sized {
39 fn keys() -> Vec<&'static str>;
40 #[cfg(feature = "pg")]
41 fn bind<'q, O>(
42 self,
43 query: QueryAs<Postgres, O, PgArguments>,
44 ) -> QueryAs<Postgres, O, PgArguments>
45 where
46 O: for<'r> FromRow<'r, <sqlx::Postgres as sqlx::Database>::Row>,
47 O: 'q + std::marker::Send,
48 O: 'q + Unpin,
49 O: 'q + Cursor;
50 fn serialize(&self) -> Vec<String>;
51 fn deserialize(values: Vec<&str>) -> Result<Self, QueryError>;
52
53 fn serialize_utc(value: DateTime<Utc>) -> String {
54 Harsh::default().encode(&[value.timestamp_micros() as u64])
55 }
56
57 fn deserialize_as<F: Into<String>, D: FromStr>(
58 field: F,
59 value: Option<&&str>,
60 ) -> Result<D, QueryError> {
61 let field = field.into();
62 value
63 .ok_or(QueryError::MissingField(field.to_owned()))
64 .and_then(|v| {
65 v.to_string().parse::<D>().map_err(|_| {
66 QueryError::Unknown(
67 field,
68 v.to_string(),
69 "failed to deserialize_as_string".to_owned(),
70 )
71 })
72 })
73 }
74
75 fn deserialize_as_utc<F: Into<String>>(
76 field: F,
77 value: Option<&&str>,
78 ) -> Result<DateTime<Utc>, QueryError> {
79 let field = field.into();
80 value
81 .ok_or(QueryError::MissingField(field))
82 .and_then(|v| {
83 Harsh::default()
84 .decode(v)
85 .map(|v| v[0])
86 .map_err(QueryError::Harsh)
87 })
88 .and_then(|timestamp| {
89 NaiveDateTime::from_timestamp_micros(timestamp as i64).ok_or(QueryError::Unknown(
90 "field".to_owned(),
91 "NaiveDateTime::from_timestamp_opt".to_owned(),
92 "none".to_owned(),
93 ))
94 })
95 .map(|datetime| DateTime::from_naive_utc_and_offset(datetime, Utc))
96 }
97
98 fn to_cursor(&self) -> CursorType {
99 let data = self.serialize().join("|");
100 let engine = GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::PAD);
101
102 CursorType(engine.encode(data))
103 }
104
105 fn from_cursor(cursor: &CursorType) -> Result<Self, QueryError> {
106 let engine = GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::PAD);
107 let decoded = engine.decode(cursor)?;
108 let data = std::str::from_utf8(&decoded)?;
109
110 Self::deserialize(data.split('|').collect())
111 }
112
113 fn to_pg_filter_opts(
114 order: &CursorOrder,
115 backward: bool,
116 keys: Option<Vec<&str>>,
117 pos: Option<usize>,
118 ) -> String {
119 let pos = pos.unwrap_or(1);
120 let with_braket = keys.is_some();
121 let mut keys = keys.unwrap_or(Self::keys());
122 let key = keys.remove(0);
123
124 let sign = match (order, backward) {
125 (CursorOrder::Asc, true) | (CursorOrder::Desc, false) => "<",
126 (CursorOrder::Asc, false) | (CursorOrder::Desc, true) => ">",
127 };
128 let filter = format!("{key} {sign} ${pos}");
129
130 if keys.is_empty() {
131 return filter;
132 }
133
134 let filter = format!(
135 "{filter} OR ({key} = ${pos} AND {})",
136 Self::to_pg_filter_opts(order, backward, Some(keys), Some(pos + 1))
137 );
138
139 if with_braket {
140 format!("({filter})")
141 } else {
142 filter
143 }
144 }
145
146 fn to_pg_order(order: &CursorOrder, backward: bool) -> String {
147 let order = match (order, backward) {
148 (CursorOrder::Asc, true) | (CursorOrder::Desc, false) => "DESC",
149 (CursorOrder::Asc, false) | (CursorOrder::Desc, true) => "ASC",
150 };
151
152 Self::keys()
153 .iter()
154 .map(|key| format!("{key} {order}"))
155 .collect::<Vec<_>>()
156 .join(", ")
157 }
158}