// klauthed_data/pagination/cursor.rs
use std::fmt;
5use std::str::FromStr;
6
7use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
8use serde::{Deserialize, Serialize, de::DeserializeOwned};
9
10use crate::error::DataError;
11
12use super::{DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE, SortKey};
13
14#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
19#[serde(transparent)]
20pub struct Cursor(String);
21
22impl Cursor {
23 pub fn encode<T: Serialize>(value: &T) -> Result<Cursor, DataError> {
25 let json =
26 serde_json::to_string(value).map_err(|e| DataError::InvalidCursor(e.to_string()))?;
27 Ok(Cursor(URL_SAFE_NO_PAD.encode(json.as_bytes())))
28 }
29
30 pub fn decode<T: DeserializeOwned>(&self) -> Result<T, DataError> {
32 let bytes = URL_SAFE_NO_PAD
33 .decode(self.0.as_bytes())
34 .map_err(|e| DataError::InvalidCursor(format!("base64 decode: {e}")))?;
35 serde_json::from_slice(&bytes)
36 .map_err(|e| DataError::InvalidCursor(format!("json decode: {e}")))
37 }
38
39 pub fn as_str(&self) -> &str {
41 &self.0
42 }
43}
44
45impl fmt::Display for Cursor {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 f.write_str(&self.0)
48 }
49}
50
51impl FromStr for Cursor {
52 type Err = std::convert::Infallible;
53
54 fn from_str(s: &str) -> Result<Self, Self::Err> {
55 Ok(Cursor(s.to_owned()))
56 }
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct CursorPageRequest {
62 pub after: Option<Cursor>,
64 pub before: Option<Cursor>,
66 pub limit: u32,
68 pub sort: Vec<SortKey>,
70}
71
72impl CursorPageRequest {
73 pub fn new(limit: u32) -> Result<Self, DataError> {
76 if limit < 1 {
77 return Err(DataError::InvalidPage("cursor limit must be >= 1".into()));
78 }
79 let limit = limit.min(MAX_PAGE_SIZE);
80 Ok(CursorPageRequest { after: None, before: None, limit, sort: Vec::new() })
81 }
82
83 #[must_use]
85 pub fn after(mut self, cursor: Cursor) -> Self {
86 self.after = Some(cursor);
87 self
88 }
89
90 #[must_use]
92 pub fn before(mut self, cursor: Cursor) -> Self {
93 self.before = Some(cursor);
94 self
95 }
96
97 #[must_use]
99 pub fn sort(mut self, keys: Vec<SortKey>) -> Self {
100 self.sort = keys;
101 self
102 }
103}
104
105impl Default for CursorPageRequest {
106 fn default() -> Self {
107 CursorPageRequest { after: None, before: None, limit: DEFAULT_PAGE_SIZE, sort: Vec::new() }
108 }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct CursorPage<T> {
114 pub items: Vec<T>,
116 pub start_cursor: Option<Cursor>,
118 pub end_cursor: Option<Cursor>,
120 pub has_next_page: bool,
122 pub has_prev_page: bool,
124}
125
126impl<T> CursorPage<T> {
127 pub fn from_items<C, F>(
132 items: Vec<T>,
133 encode: F,
134 has_prev: bool,
135 has_next: bool,
136 ) -> Result<Self, DataError>
137 where
138 F: Fn(&T) -> C,
139 C: Serialize,
140 {
141 let start_cursor = items.first().map(|item| Cursor::encode(&encode(item))).transpose()?;
142 let end_cursor = items.last().map(|item| Cursor::encode(&encode(item))).transpose()?;
143 Ok(CursorPage {
144 items,
145 start_cursor,
146 end_cursor,
147 has_next_page: has_next,
148 has_prev_page: has_prev,
149 })
150 }
151
152 pub fn map<U, F: FnMut(T) -> U>(self, f: F) -> CursorPage<U> {
154 CursorPage {
155 items: self.items.into_iter().map(f).collect(),
156 start_cursor: self.start_cursor,
157 end_cursor: self.end_cursor,
158 has_next_page: self.has_next_page,
159 has_prev_page: self.has_prev_page,
160 }
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test unless `err` is `DataError::InvalidCursor`.
    #[track_caller]
    fn assert_invalid_cursor(err: DataError) {
        if !matches!(err, DataError::InvalidCursor(_)) {
            let other = err;
            panic!("expected InvalidCursor, got {other:?}");
        }
    }

    /// A struct survives an encode/decode cycle unchanged.
    #[test]
    fn cursor_encode_decode_round_trip() {
        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct Pos {
            id: u64,
            ts: i64,
        }

        let original = Pos { id: 42, ts: 1_700_000_000 };
        let round_tripped = Cursor::encode(&original)
            .and_then(|cursor| cursor.decode::<Pos>())
            .unwrap();
        assert_eq!(round_tripped, original);
    }

    /// `as_str` and `Display` expose the same text.
    #[test]
    fn cursor_as_str_and_display() {
        let cursor = Cursor::encode(&42u32).unwrap();
        let displayed = cursor.to_string();
        assert_eq!(cursor.as_str(), displayed);
    }

    /// Parsing stores the input verbatim.
    #[test]
    fn cursor_from_str() {
        let parsed = "abc123".parse::<Cursor>().unwrap();
        assert_eq!(parsed.as_str(), "abc123");
    }

    /// Text that is not base64 is rejected at decode time.
    #[test]
    fn cursor_decode_garbage_returns_invalid_cursor() {
        let bad = "!!!!not-valid-base64!!!!".parse::<Cursor>().unwrap();
        assert_invalid_cursor(bad.decode::<u32>().unwrap_err());
    }

    /// Well-formed base64 whose payload is not JSON is rejected too.
    #[test]
    fn cursor_decode_valid_base64_but_bad_json_returns_invalid_cursor() {
        let token = URL_SAFE_NO_PAD.encode(b"not-json");
        let cursor = token.parse::<Cursor>().unwrap();
        assert_invalid_cursor(cursor.decode::<u32>().unwrap_err());
    }

    /// Boundary cursors come from the first and last item; flags pass through.
    #[test]
    fn cursor_page_from_items_sets_cursors_and_flags() {
        let page = CursorPage::from_items(vec![10u32, 20, 30], |x| *x, false, true).unwrap();
        let CursorPage { items, start_cursor, end_cursor, has_next_page, has_prev_page } = page;

        assert_eq!(items, vec![10, 20, 30]);
        assert!(start_cursor.is_some());
        assert!(end_cursor.is_some());
        assert!(has_next_page);
        assert!(!has_prev_page);

        let start = start_cursor.unwrap().decode::<u32>().unwrap();
        let end = end_cursor.unwrap().decode::<u32>().unwrap();
        assert_eq!(start, 10);
        assert_eq!(end, 30);
    }

    /// An empty page carries no cursors.
    #[test]
    fn cursor_page_empty_has_no_cursors() {
        let page = CursorPage::<u32>::from_items(Vec::new(), |x| *x, false, false).unwrap();
        assert!(page.start_cursor.is_none());
        assert!(page.end_cursor.is_none());
    }

    /// `map` converts the items and keeps the paging flags.
    #[test]
    fn cursor_page_map() {
        let mapped = CursorPage::from_items(vec![1u32, 2, 3], |x| *x, true, false)
            .unwrap()
            .map(|x| x.to_string());
        assert_eq!(mapped.items, vec!["1", "2", "3"]);
        assert!(mapped.has_prev_page);
        assert!(!mapped.has_next_page);
    }
}
249
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Any (u64, i64, printable-ASCII string) tuple survives encode -> decode.
        #[test]
        fn encode_decode_round_trips(id in any::<u64>(), ts in any::<i64>(), name in "[ -~]{0,32}") {
            let original = (id, ts, name);
            let decoded = Cursor::encode(&original)
                .unwrap()
                .decode::<(u64, i64, String)>()
                .unwrap();
            prop_assert_eq!(decoded, original);
        }

        // `as_str`, `Display` and a parse round trip all yield the same text.
        #[test]
        fn string_forms_agree(id in any::<u64>()) {
            let cursor = Cursor::encode(&id).unwrap();
            let raw = cursor.as_str();
            let shown = cursor.to_string();
            prop_assert_eq!(raw, shown.as_str());

            let reparsed = raw.parse::<Cursor>().unwrap();
            prop_assert_eq!(reparsed.as_str(), raw);
        }

        // Decoding arbitrary text may fail, but must return rather than panic.
        #[test]
        fn decode_arbitrary_text_never_panics(s in ".*") {
            let cursor = s.parse::<Cursor>().unwrap();
            let _ = cursor.decode::<(u64, i64, String)>();
        }
    }
}