1use std::borrow::Cow;
16
17use base64::Engine;
18use base64::engine::general_purpose::URL_SAFE_NO_PAD;
19use rmcp::schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
20use serde::de::{Error as DeError, Unexpected};
21use serde::{Deserialize, Deserializer, Serialize, Serializer};
22
23#[derive(Serialize, Deserialize)]
25struct Payload {
26 offset: u64,
27}
28
/// Pagination cursor exchanged with clients.
///
/// In memory it is just an offset; its `Serialize`/`Deserialize` impls in this file turn it
/// into (and parse it from) a single encoded string produced by `encode_cursor`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Cursor {
    /// Offset at which the page identified by this cursor starts.
    pub offset: u64,
}
40
41impl Serialize for Cursor {
42 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
43 serializer.serialize_str(&encode_cursor(self.offset))
44 }
45}
46
47impl<'de> Deserialize<'de> for Cursor {
48 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
49 let raw = <Cow<'de, str>>::deserialize(deserializer)?;
50 decode_cursor(&raw)
51 .map(|offset| Self { offset })
52 .map_err(|msg| D::Error::invalid_value(Unexpected::Str(&raw), &msg))
53 }
54}
55
56impl JsonSchema for Cursor {
57 fn inline_schema() -> bool {
58 true
59 }
60
61 fn schema_name() -> Cow<'static, str> {
62 "Cursor".into()
63 }
64
65 fn json_schema(_: &mut SchemaGenerator) -> Schema {
66 json_schema!({
67 "type": "string",
68 "description": "Opaque pagination cursor. Echo the `nextCursor` from a prior response; do not parse or modify."
69 })
70 }
71}
72
/// State for producing a single page of results: where the page starts and how many items
/// it may hold.
#[derive(Clone, Copy, Debug)]
pub struct Pager {
    /// Offset at which this page starts.
    offset: u64,
    /// Maximum number of items returned in this page.
    size: usize,
}
84
85impl Pager {
86 #[must_use]
88 pub fn new(cursor: Option<Cursor>, size: u16) -> Self {
89 Self {
90 offset: cursor.map_or(0, |c| c.offset),
91 size: usize::from(size),
92 }
93 }
94
95 #[must_use]
97 pub fn offset(&self) -> u64 {
98 self.offset
99 }
100
101 #[must_use]
103 pub fn limit(&self) -> usize {
104 self.size + 1
105 }
106
107 #[must_use]
113 pub fn finalize<T>(&self, mut items: Vec<T>) -> (Vec<T>, Option<Cursor>) {
114 if items.len() > self.size {
115 items.truncate(self.size);
116 let offset = self.offset + self.size as u64;
117 (items, Some(Cursor { offset }))
118 } else {
119 (items, None)
120 }
121 }
122}
123
124fn encode_cursor(offset: u64) -> String {
126 let payload = Payload { offset };
127 let json = serde_json::to_vec(&payload).expect("Payload is infallible to serialize");
128 URL_SAFE_NO_PAD.encode(&json)
129}
130
131fn decode_cursor(raw: &str) -> Result<u64, &'static str> {
133 let bytes = URL_SAFE_NO_PAD
134 .decode(raw.as_bytes())
135 .map_err(|_| "invalid pagination cursor: not valid base64")?;
136 let payload: Payload =
137 serde_json::from_slice(&bytes).map_err(|_| "invalid pagination cursor: payload is malformed")?;
138 Ok(payload.offset)
139}
140
#[cfg(test)]
mod tests {
    use base64::Engine;
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use serde_json::{Value, json};

    use super::{Cursor, Pager, decode_cursor, encode_cursor};

    /// Tries to deserialize `value` as a [`Cursor`] and returns the rendered error.
    ///
    /// Panics with "should fail" when deserialization unexpectedly succeeds.
    fn rejection_text(value: Value) -> String {
        serde_json::from_value::<Cursor>(value)
            .expect_err("should fail")
            .to_string()
    }

    /// Base64-encodes `payload` the way cursors are encoded, then returns the error produced
    /// by deserializing the result as a [`Cursor`].
    fn rejection_text_for_payload(payload: &[u8]) -> String {
        let raw = URL_SAFE_NO_PAD.encode(payload);
        rejection_text(json!(raw))
    }

    // --- encode_cursor / decode_cursor ---

    #[test]
    fn encode_decode_round_trips_representative_offsets() {
        let offsets = [0u64, 1, 99, 100, 101, 12_345, u64::MAX / 2, u64::MAX];
        for expected in offsets {
            let encoded = encode_cursor(expected);
            let actual = decode_cursor(&encoded).expect("valid cursor should decode");
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn encoded_cursor_is_url_safe_base64() {
        let encoded = encode_cursor(100);
        for forbidden in ['+', '/', '='] {
            assert!(!encoded.contains(forbidden));
        }
    }

    // --- serde integration ---

    #[test]
    fn cursor_serializes_as_base64_string() {
        let encoded = match serde_json::to_value(Cursor { offset: 100 }).unwrap() {
            Value::String(text) => text,
            other => panic!("expected string, got {other:?}"),
        };
        assert_eq!(decode_cursor(&encoded).unwrap(), 100);
    }

    #[test]
    fn cursor_deserializes_from_valid_base64() {
        let input = Value::String(encode_cursor(42));
        let parsed: Cursor = serde_json::from_value(input).unwrap();
        assert_eq!(parsed.offset, 42);
    }

    #[test]
    fn cursor_deserialization_round_trips_through_serde_json() {
        let original = Cursor { offset: 7 };
        let text = serde_json::to_string(&original).unwrap();
        let restored: Cursor = serde_json::from_str(&text).unwrap();
        assert_eq!(restored, original);
    }

    #[test]
    fn cursor_deserialization_rejects_non_base64() {
        let err = rejection_text(json!("!!!not-base64"));
        assert!(err.contains("base64"), "error: {err}");
    }

    #[test]
    fn cursor_deserialization_rejects_base64_of_non_json() {
        let err = rejection_text_for_payload(b"not json");
        assert!(err.contains("malformed"), "error: {err}");
    }

    #[test]
    fn cursor_deserialization_rejects_payload_missing_fields() {
        let err = rejection_text_for_payload(b"{}");
        assert!(err.contains("malformed"), "error: {err}");
    }

    #[test]
    fn cursor_deserialization_rejects_negative_offset() {
        let err = rejection_text_for_payload(b"{\"offset\":-1}");
        assert!(err.contains("malformed"), "error: {err}");
    }

    #[test]
    fn encoded_cursor_payload_uses_offset_key() {
        let encoded = match serde_json::to_value(Cursor { offset: 100 }).unwrap() {
            Value::String(text) => text,
            other => panic!("expected string cursor, got {other:?}"),
        };
        let decoded = URL_SAFE_NO_PAD.decode(encoded.as_bytes()).unwrap();
        let payload: Value = serde_json::from_slice(&decoded).unwrap();
        let obj = payload.as_object().expect("payload should be a JSON object");
        let offset = obj.get("offset").and_then(Value::as_u64);
        assert_eq!(
            offset,
            Some(100),
            "payload should carry offset under the `offset` key: {obj:?}"
        );
    }

    // --- Pager ---

    #[test]
    fn page_defaults_to_offset_zero_without_cursor() {
        let pager = Pager::new(None, 50);
        assert_eq!((pager.offset(), pager.limit()), (0, 51));
    }

    #[test]
    fn page_inherits_offset_from_cursor() {
        let start = Cursor { offset: 200 };
        let pager = Pager::new(Some(start), 50);
        assert_eq!((pager.offset(), pager.limit()), (200, 51));
    }

    #[test]
    fn page_finalize_emits_next_cursor_when_over_fetched() {
        let fetched = vec!["a", "b", "c", "d"];
        let (page, next) = Pager::new(None, 3).finalize(fetched);
        assert_eq!(page, ["a", "b", "c"]);
        assert_eq!(next, Some(Cursor { offset: 3 }));
    }

    #[test]
    fn page_finalize_drops_next_cursor_on_exact_fit() {
        let fetched = vec!["a", "b", "c"];
        let (page, next) = Pager::new(None, 3).finalize(fetched);
        assert_eq!(page, ["a", "b", "c"]);
        assert!(next.is_none());
    }

    #[test]
    fn page_finalize_drops_next_cursor_on_short_page() {
        let fetched = vec!["a"];
        let (page, next) = Pager::new(None, 3).finalize(fetched);
        assert_eq!(page, ["a"]);
        assert!(next.is_none());
    }

    #[test]
    fn page_finalize_drops_next_cursor_on_empty_result() {
        let start = Cursor { offset: 99 };
        let fetched: Vec<&str> = Vec::new();
        let (page, next) = Pager::new(Some(start), 3).finalize(fetched);
        assert!(page.is_empty());
        assert!(next.is_none());
    }

    #[test]
    fn page_finalize_advances_offset_by_page_size() {
        let start = Cursor { offset: 100 };
        let fetched = (0..51).collect::<Vec<u32>>();
        let (_, next) = Pager::new(Some(start), 50).finalize(fetched);
        assert_eq!(next, Some(Cursor { offset: 150 }));
    }
}