// code_analyze_mcp/pagination.rs

1use base64::engine::general_purpose::STANDARD;
2use base64::{DecodeError, engine::Engine};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
/// Page size of 100 items.
///
/// Not referenced by the functions visible in this file; presumably callers pass it as the
/// `page_size` argument of `paginate_slice` when the request gives none — confirm at call sites.
pub const DEFAULT_PAGE_SIZE: usize = 100;
7
/// Tag stored inside a cursor that records which listing the cursor belongs to.
///
/// Serde serializes the variants with lowercase names (`"default"`, `"callers"`,
/// `"callees"`), so those strings are part of the cursor wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PaginationMode {
    /// NOTE(review): presumably the general-purpose listing — confirm against callers.
    Default,
    /// NOTE(review): presumably a listing of callers of a symbol — confirm against callers.
    Callers,
    /// NOTE(review): presumably a listing of callees of a symbol — confirm against callers.
    Callees,
}
15
/// Payload carried by a pagination cursor.
///
/// `encode_cursor` serializes this struct to JSON and base64-encodes it; `decode_cursor`
/// reverses the process.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CursorData {
    /// Listing the cursor was issued for.
    pub mode: PaginationMode,
    /// Index of the first item of the page the cursor points at (`paginate_slice` stores the
    /// end index of the page it just produced here).
    pub offset: usize,
}
21
/// Error type returned by the cursor and pagination functions in this file.
#[derive(Debug, Error)]
pub enum PaginationError {
    /// A cursor could not be encoded or decoded. The payload is a human-readable description
    /// of the underlying base64, UTF-8 or JSON failure.
    #[error("Invalid cursor: {0}")]
    InvalidCursor(String),
}
27
28impl From<DecodeError> for PaginationError {
29    fn from(err: DecodeError) -> Self {
30        PaginationError::InvalidCursor(format!("Base64 decode error: {}", err))
31    }
32}
33
34impl From<serde_json::Error> for PaginationError {
35    fn from(err: serde_json::Error) -> Self {
36        PaginationError::InvalidCursor(format!("JSON parse error: {}", err))
37    }
38}
39
40pub fn encode_cursor(data: &CursorData) -> Result<String, PaginationError> {
41    let json = serde_json::to_string(data)?;
42    Ok(STANDARD.encode(json))
43}
44
45pub fn decode_cursor(cursor: &str) -> Result<CursorData, PaginationError> {
46    let decoded = STANDARD.decode(cursor)?;
47    let json_str = String::from_utf8(decoded)
48        .map_err(|e| PaginationError::InvalidCursor(format!("UTF-8 decode error: {}", e)))?;
49    Ok(serde_json::from_str(&json_str)?)
50}
51
/// One page of results together with the information needed to fetch the next page.
#[derive(Debug, Clone)]
pub struct PaginationResult<T> {
    /// Items that belong to this page.
    pub items: Vec<T>,
    /// Encoded cursor for the following page, or `None` when no further page was produced.
    pub next_cursor: Option<String>,
    /// Number of items in the full collection being paginated (not just this page).
    pub total: usize,
}
58
59pub fn paginate_slice<T: Clone>(
60    items: &[T],
61    offset: usize,
62    page_size: usize,
63    mode: PaginationMode,
64) -> Result<PaginationResult<T>, PaginationError> {
65    let total = items.len();
66
67    if offset >= total {
68        return Ok(PaginationResult {
69            items: vec![],
70            next_cursor: None,
71            total,
72        });
73    }
74
75    let end = std::cmp::min(offset + page_size, total);
76    let page_items = items[offset..end].to_vec();
77
78    let next_cursor = if end < total {
79        let cursor_data = CursorData { mode, offset: end };
80        Some(encode_cursor(&cursor_data)?)
81    } else {
82        None
83    };
84
85    Ok(PaginationResult {
86        items: page_items,
87        next_cursor,
88        total,
89    })
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// A cursor survives an encode/decode round trip with mode and offset intact.
    #[test]
    fn test_cursor_encode_decode_roundtrip() {
        let source = CursorData {
            mode: PaginationMode::Default,
            offset: 42,
        };

        let token = encode_cursor(&source).expect("encode failed");
        let restored = decode_cursor(&token).expect("decode failed");

        assert_eq!(restored.mode, source.mode);
        assert_eq!(restored.offset, source.offset);
    }

    /// The mode is written to JSON in lowercase and read back unchanged.
    #[test]
    fn test_pagination_mode_wire_format() {
        let payload = CursorData {
            mode: PaginationMode::Callers,
            offset: 0,
        };

        let token = encode_cursor(&payload).expect("encode failed");
        let restored = decode_cursor(&token).expect("decode failed");
        assert_eq!(restored.mode, PaginationMode::Callers);

        let wire = serde_json::to_string(&payload).expect("serialize failed");
        let has_lowercase_mode = wire.contains("\"mode\":\"callers\"");
        assert!(
            has_lowercase_mode,
            "expected lowercase 'callers' in JSON, got: {}",
            wire
        );
    }

    /// A page taken from the middle is full-sized and comes with a follow-up cursor.
    #[test]
    fn test_paginate_slice_middle_page() {
        let data: Vec<i32> = (0..250).collect();
        let page =
            paginate_slice(&data, 100, 100, PaginationMode::Default).expect("paginate failed");

        assert_eq!(page.items.len(), 100);
        assert_eq!(page.items[0], 100);
        assert_eq!(page.items[99], 199);
        assert!(page.next_cursor.is_some());
        assert_eq!(page.total, 250);
    }

    /// Empty input and an offset past the end both give an empty page without a cursor.
    #[test]
    fn test_paginate_slice_empty_and_beyond() {
        // Case 1: nothing to paginate at all.
        let nothing: Vec<i32> = Vec::new();
        let page =
            paginate_slice(&nothing, 0, 100, PaginationMode::Default).expect("paginate failed");
        assert_eq!(page.items.len(), 0);
        assert!(page.next_cursor.is_none());
        assert_eq!(page.total, 0);

        // Case 2: offset lies beyond the last item.
        let data: Vec<i32> = (0..50).collect();
        let page =
            paginate_slice(&data, 100, 100, PaginationMode::Default).expect("paginate failed");
        assert_eq!(page.items.len(), 0);
        assert!(page.next_cursor.is_none());
        assert_eq!(page.total, 50);
    }

    /// A page that ends exactly on the last item does not produce a cursor.
    #[test]
    fn test_paginate_slice_exact_boundary() {
        let data: Vec<i32> = (0..200).collect();
        let page =
            paginate_slice(&data, 100, 100, PaginationMode::Default).expect("paginate failed");

        assert_eq!(page.items.len(), 100);
        assert_eq!(page.items[0], 100);
        assert!(page.next_cursor.is_none());
        assert_eq!(page.total, 200);
    }

    /// Input that is not valid base64 is reported as `InvalidCursor`.
    #[test]
    fn test_invalid_cursor_error() {
        let outcome = decode_cursor("not-valid-base64!!!");
        assert!(outcome.is_err());

        if let Err(PaginationError::InvalidCursor(message)) = outcome {
            assert!(message.contains("Base64") || message.contains("decode"));
        } else {
            panic!("Expected InvalidCursor error");
        }
    }
}
183}