1use std::borrow::Cow;
16
17use base64::Engine;
18use base64::engine::general_purpose::URL_SAFE_NO_PAD;
19use rmcp::schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
20use serde::de::{Error as DeError, Unexpected};
21use serde::{Deserialize, Deserializer, Serialize, Serializer};
22
/// JSON payload carried inside an encoded [`Cursor`] string.
///
/// The field name is part of the cursor wire format: a cursor whose decoded
/// JSON lacks an `offset` key is rejected as malformed, so renaming this field
/// would invalidate every cursor already handed out.
#[derive(Serialize, Deserialize)]
struct Payload {
    /// Number of items to skip before the page the cursor points at.
    offset: u64,
}
28
/// Pagination cursor exchanged with clients.
///
/// Serialized, it is an opaque string (unpadded URL-safe base64 of a small
/// JSON payload). In Rust it is simply the offset of the first item on the
/// page it refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Cursor {
    /// Number of items to skip before the page this cursor refers to.
    pub offset: u64,
}
40
41impl Serialize for Cursor {
42 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
43 serializer.serialize_str(&encode_cursor(self.offset))
44 }
45}
46
47impl<'de> Deserialize<'de> for Cursor {
48 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
49 let raw = <Cow<'de, str>>::deserialize(deserializer)?;
50 decode_cursor(&raw)
51 .map(|offset| Self { offset })
52 .map_err(|msg| D::Error::invalid_value(Unexpected::Str(&raw), &msg))
53 }
54}
55
56impl JsonSchema for Cursor {
57 fn inline_schema() -> bool {
58 true
59 }
60
61 fn schema_name() -> Cow<'static, str> {
62 "Cursor".into()
63 }
64
65 fn json_schema(_: &mut SchemaGenerator) -> Schema {
66 json_schema!({
67 "type": "string",
68 "description": "Opaque pagination cursor. Echo the `nextCursor` from a prior response; do not parse or modify."
69 })
70 }
71}
72
73#[derive(Debug, Clone, Copy)]
80pub struct Pager {
81 offset: u64,
82 size: u16,
83}
84
85impl Pager {
86 #[must_use]
88 pub fn new(cursor: Option<Cursor>, size: u16) -> Self {
89 Self {
90 offset: cursor.map_or(0, |c| c.offset),
91 size,
92 }
93 }
94
95 #[must_use]
101 pub fn offset(&self) -> i64 {
102 i64::try_from(self.offset).unwrap_or(i64::MAX)
103 }
104
105 #[must_use]
110 pub fn limit(&self) -> i64 {
111 i64::from(self.size) + 1
112 }
113
114 #[must_use]
120 pub fn paginate<T>(&self, mut items: Vec<T>) -> (Vec<T>, Option<Cursor>) {
121 let size = usize::from(self.size);
122 if items.len() > size {
123 items.truncate(size);
124 let offset = self.offset + u64::from(self.size);
125 (items, Some(Cursor { offset }))
126 } else {
127 (items, None)
128 }
129 }
130}
131
132fn encode_cursor(offset: u64) -> String {
134 let payload = Payload { offset };
135 let json = serde_json::to_vec(&payload).expect("Payload is infallible to serialize");
136 URL_SAFE_NO_PAD.encode(&json)
137}
138
139fn decode_cursor(raw: &str) -> Result<u64, &'static str> {
141 let bytes = URL_SAFE_NO_PAD
142 .decode(raw.as_bytes())
143 .map_err(|_| "invalid pagination cursor: not valid base64")?;
144 let payload: Payload =
145 serde_json::from_slice(&bytes).map_err(|_| "invalid pagination cursor: payload is malformed")?;
146 Ok(payload.offset)
147}
148
#[cfg(test)]
mod tests {
    use base64::Engine;
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use serde_json::{Value, json};

    use super::{Cursor, Pager, decode_cursor, encode_cursor};

    /// Deserializes `input` as a [`Cursor`], requires that this fails, and
    /// returns the error's display text.
    fn rejection(input: Value) -> String {
        let err = serde_json::from_value::<Cursor>(input).expect_err("should fail");
        err.to_string()
    }

    /// Encodes raw bytes with the same base64 flavour the cursor codec uses.
    fn b64(bytes: &[u8]) -> String {
        URL_SAFE_NO_PAD.encode(bytes)
    }

    // ---- raw codec -------------------------------------------------------

    #[test]
    fn encode_decode_round_trips_representative_offsets() {
        let offsets = [0u64, 1, 99, 100, 101, 12_345, u64::MAX / 2, u64::MAX];
        offsets.iter().copied().for_each(|offset| {
            let encoded = encode_cursor(offset);
            let decoded = decode_cursor(&encoded).expect("valid cursor should decode");
            assert_eq!(decoded, offset);
        });
    }

    #[test]
    fn encoded_cursor_is_url_safe_base64() {
        let cursor = encode_cursor(100);
        for forbidden in ['+', '/', '='] {
            assert!(!cursor.contains(forbidden));
        }
    }

    // ---- serde integration ----------------------------------------------

    #[test]
    fn cursor_serializes_as_base64_string() {
        let value = serde_json::to_value(Cursor { offset: 100 }).unwrap();
        let Value::String(s) = value else {
            panic!("expected string, got {value:?}");
        };
        assert_eq!(decode_cursor(&s).unwrap(), 100);
    }

    #[test]
    fn cursor_deserializes_from_valid_base64() {
        let cursor: Cursor = serde_json::from_value(Value::String(encode_cursor(42))).unwrap();
        assert_eq!(cursor.offset, 42);
    }

    #[test]
    fn cursor_deserialization_round_trips_through_serde_json() {
        let original = Cursor { offset: 7 };
        let back = serde_json::to_string(&original)
            .and_then(|text| serde_json::from_str::<Cursor>(&text))
            .unwrap();
        assert_eq!(back, original);
    }

    #[test]
    fn cursor_deserialization_rejects_non_base64() {
        let msg = rejection(json!("!!!not-base64"));
        assert!(msg.contains("base64"), "error: {msg}");
    }

    #[test]
    fn cursor_deserialization_rejects_base64_of_non_json() {
        let msg = rejection(Value::String(b64(b"not json")));
        assert!(msg.contains("malformed"), "error: {msg}");
    }

    #[test]
    fn cursor_deserialization_rejects_payload_missing_fields() {
        let msg = rejection(Value::String(b64(b"{}")));
        assert!(msg.contains("malformed"), "error: {msg}");
    }

    #[test]
    fn cursor_deserialization_rejects_negative_offset() {
        let msg = rejection(Value::String(b64(br#"{"offset":-1}"#)));
        assert!(msg.contains("malformed"), "error: {msg}");
    }

    #[test]
    fn encoded_cursor_payload_uses_offset_key() {
        let raw = serde_json::to_value(Cursor { offset: 100 }).unwrap();
        let Value::String(s) = raw else {
            panic!("expected string cursor, got {raw:?}");
        };
        let bytes = URL_SAFE_NO_PAD.decode(s.as_bytes()).unwrap();
        let payload: Value = serde_json::from_slice(&bytes).unwrap();
        let obj = payload.as_object().expect("payload should be a JSON object");
        assert_eq!(
            obj.get("offset").and_then(Value::as_u64),
            Some(100),
            "payload should carry offset under the `offset` key: {obj:?}"
        );
    }

    // ---- pager -------------------------------------------------------------

    #[test]
    fn page_defaults_to_offset_zero_without_cursor() {
        let pager = Pager::new(None, 50);
        assert_eq!((pager.offset(), pager.limit()), (0, 51));
    }

    #[test]
    fn page_inherits_offset_from_cursor() {
        let pager = Pager::new(Some(Cursor { offset: 200 }), 50);
        assert_eq!((pager.offset(), pager.limit()), (200, 51));
    }

    #[test]
    fn page_paginate_emits_next_cursor_when_over_fetched() {
        let (items, next) = Pager::new(None, 3).paginate(vec!["a", "b", "c", "d"]);
        assert_eq!(items, ["a", "b", "c"]);
        assert_eq!(next, Some(Cursor { offset: 3 }));
    }

    #[test]
    fn page_paginate_drops_next_cursor_on_exact_fit() {
        let (items, next) = Pager::new(None, 3).paginate(vec!["a", "b", "c"]);
        assert_eq!(items, ["a", "b", "c"]);
        assert!(next.is_none());
    }

    #[test]
    fn page_paginate_drops_next_cursor_on_short_page() {
        let (items, next) = Pager::new(None, 3).paginate(vec!["a"]);
        assert_eq!(items, ["a"]);
        assert!(next.is_none());
    }

    #[test]
    fn page_paginate_drops_next_cursor_on_empty_result() {
        let pager = Pager::new(Some(Cursor { offset: 99 }), 3);
        let (items, next) = pager.paginate(Vec::<&str>::new());
        assert!(items.is_empty());
        assert!(next.is_none());
    }

    #[test]
    fn page_paginate_advances_offset_by_page_size() {
        let pager = Pager::new(Some(Cursor { offset: 100 }), 50);
        let (_, next) = pager.paginate((0..51).collect::<Vec<u32>>());
        assert_eq!(next, Some(Cursor { offset: 150 }));
    }
}