Skip to main content

shaperail_runtime/db/
pagination.rs

1use serde::{Deserialize, Serialize};
2
/// How a list query should be paged.
#[derive(Debug, Clone)]
pub enum PageRequest {
    /// Keyset pagination over the `id` column.
    Cursor {
        /// Opaque cursor of the last row already seen (base64-encoded id);
        /// `None` requests the first page.
        after: Option<String>,
        /// Maximum number of rows for the page.
        limit: i64,
    },
    /// Positional pagination: skip `offset` rows, then take up to `limit`.
    Offset {
        /// Number of leading rows to skip.
        offset: i64,
        /// Maximum number of rows for the page.
        limit: i64,
    },
}
11
12impl PageRequest {
13    /// Default page size.
14    pub const DEFAULT_LIMIT: i64 = 25;
15    /// Maximum allowed page size.
16    pub const MAX_LIMIT: i64 = 100;
17
18    /// Clamps the limit to the allowed range [1, MAX_LIMIT].
19    pub fn clamped_limit(limit: Option<i64>) -> i64 {
20        limit
21            .unwrap_or(Self::DEFAULT_LIMIT)
22            .clamp(1, Self::MAX_LIMIT)
23    }
24}
25
26/// Response metadata for cursor-based pagination.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct CursorPage {
29    /// Opaque cursor for the next page (base64-encoded id of last row).
30    pub cursor: Option<String>,
31    /// Whether there are more rows after this page.
32    pub has_more: bool,
33}
34
35/// Response metadata for offset-based pagination.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct OffsetPage {
38    /// Current offset.
39    pub offset: i64,
40    /// Page size used.
41    pub limit: i64,
42    /// Total number of matching rows.
43    pub total: i64,
44}
45
46impl PageRequest {
47    /// Appends cursor pagination clauses to the SQL string.
48    ///
49    /// For cursor pagination, adds `WHERE "id" > $N ORDER BY "id" ASC LIMIT N+1`
50    /// (fetches one extra row to determine `has_more`).
51    /// Returns the new parameter offset.
52    pub fn apply_cursor_to_sql(
53        &self,
54        sql: &mut String,
55        has_where: bool,
56        param_offset: usize,
57    ) -> usize {
58        match self {
59            PageRequest::Cursor { after, limit } => {
60                let mut offset = param_offset;
61                if after.is_some() {
62                    if has_where {
63                        sql.push_str(" AND ");
64                    } else {
65                        sql.push_str(" WHERE ");
66                    }
67                    sql.push_str(&format!("\"id\" > ${offset}"));
68                    offset += 1;
69                }
70                sql.push_str(" ORDER BY \"id\" ASC");
71                // Fetch limit+1 to detect has_more
72                sql.push_str(&format!(" LIMIT {}", limit + 1));
73                offset
74            }
75            PageRequest::Offset { offset: off, limit } => {
76                sql.push_str(&format!(" LIMIT {limit} OFFSET {off}"));
77                param_offset
78            }
79        }
80    }
81}
82
83/// Decodes a cursor string (base64-encoded UUID) to a UUID string.
84pub fn decode_cursor(cursor: &str) -> Result<String, shaperail_core::ShaperailError> {
85    use std::str;
86    // We use simple base64 encoding of the UUID string
87    let bytes = base64_decode(cursor).map_err(|_| {
88        shaperail_core::ShaperailError::Validation(vec![shaperail_core::FieldError {
89            field: "cursor".to_string(),
90            message: "Invalid cursor format".to_string(),
91            code: "invalid_cursor".to_string(),
92        }])
93    })?;
94    let id = str::from_utf8(&bytes).map_err(|_| {
95        shaperail_core::ShaperailError::Validation(vec![shaperail_core::FieldError {
96            field: "cursor".to_string(),
97            message: "Invalid cursor encoding".to_string(),
98            code: "invalid_cursor".to_string(),
99        }])
100    })?;
101    Ok(id.to_string())
102}
103
104/// Encodes a UUID string as a base64 cursor.
105pub fn encode_cursor(id: &str) -> String {
106    base64_encode(id.as_bytes())
107}
108
// Minimal base64 (standard alphabet, `=`-padded) encoder/decoder, kept local so
// cursor handling needs no extra dependency.

/// Encodes `data` as standard-alphabet base64 with `=` padding.
///
/// Each 3-byte group becomes 4 characters. A trailing group of 1 or 2 bytes
/// is completed with `==` or `=` respectively. Empty input yields `""`.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);
    for group in data.chunks(3) {
        // Pack the group big-endian into the low 24 bits; absent bytes stay zero.
        let packed = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (u32::from(byte) << (16 - 8 * i)));
        // A group of n bytes carries n + 1 meaningful sextets; the rest are padding.
        let significant = group.len() + 1;
        for slot in 0..4 {
            if slot < significant {
                let sextet = (packed >> (18 - 6 * slot)) & 0x3F;
                encoded.push(char::from(ALPHABET[sextet as usize]));
            } else {
                encoded.push('=');
            }
        }
    }
    encoded
}
133
/// Decodes standard-alphabet, `=`-padded base64 into raw bytes.
///
/// Decoding is strict, so every byte string has exactly one accepted
/// encoding: the one `base64_encode` produces. The rules are:
/// - the input length must be a multiple of 4;
/// - `=` may appear only as the final one or two characters of the last
///   4-character group (`xx==` or `xxx=`), never earlier;
/// - the unused low-order bits of the last data character before the padding
///   must be zero.
///
/// An empty input decodes to an empty vector.
///
/// # Errors
/// Returns a static description of the first problem found: bad length,
/// a character outside the alphabet, misplaced padding, or non-zero
/// padding bits.
fn base64_decode(input: &str) -> Result<Vec<u8>, &'static str> {
    // Maps one alphabet character to its 6-bit value. `=` is deliberately not
    // accepted here; padding is recognised positionally below.
    fn char_to_val(c: u8) -> Result<u32, &'static str> {
        match c {
            b'A'..=b'Z' => Ok(u32::from(c - b'A')),
            b'a'..=b'z' => Ok(u32::from(c - b'a') + 26),
            b'0'..=b'9' => Ok(u32::from(c - b'0') + 52),
            b'+' => Ok(62),
            b'/' => Ok(63),
            _ => Err("invalid base64 character"),
        }
    }

    // `chunks_exact` enforces the length rule (via `remainder`) and guarantees
    // that every group below has exactly four characters.
    let groups = input.as_bytes().chunks_exact(4);
    if !groups.remainder().is_empty() {
        return Err("invalid base64 length");
    }
    let last_group = groups.len().saturating_sub(1);

    let mut result = Vec::with_capacity(input.len() / 4 * 3);
    for (index, group) in groups.enumerate() {
        // Number of trailing '=' characters in this group.
        let padding = match (group[2], group[3]) {
            (b'=', b'=') => 2,
            (_, b'=') => 1,
            (b'=', _) => return Err("invalid base64 padding"),
            _ => 0,
        };
        if padding > 0 && index != last_group {
            return Err("invalid base64 padding");
        }

        // Accumulate the four sextets into 24 bits; padding slots contribute zero.
        let mut triple = 0u32;
        for (slot, &c) in group.iter().enumerate() {
            let sextet = if slot < 4 - padding { char_to_val(c)? } else { 0 };
            triple = (triple << 6) | sextet;
        }

        // The high byte of `to_be_bytes()` is always zero; the three decoded
        // bytes follow it.
        let [_, b0, b1, b2] = triple.to_be_bytes();
        let decoded = [b0, b1, b2];
        let kept = 3 - padding;
        // Bytes dropped because of padding hold only the unused low bits of the
        // last data character; a canonical encoding leaves them zero.
        if decoded[kept..].iter().any(|&b| b != 0) {
            return Err("non-canonical base64 padding bits");
        }
        result.extend_from_slice(&decoded[..kept]);
    }
    Ok(result)
}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `result` is the single-field `cursor` validation error that
    /// `decode_cursor` produces for bad input.
    fn assert_invalid_cursor(result: Result<String, shaperail_core::ShaperailError>) {
        match result {
            Err(shaperail_core::ShaperailError::Validation(errors)) => {
                assert_eq!(errors.len(), 1);
                assert_eq!(errors[0].field, "cursor");
                assert_eq!(errors[0].code, "invalid_cursor");
            }
            other => panic!("expected invalid_cursor validation error, got {other:?}"),
        }
    }

    #[test]
    fn clamped_limit_default() {
        assert_eq!(PageRequest::clamped_limit(None), 25);
    }

    #[test]
    fn clamped_limit_within_range() {
        assert_eq!(PageRequest::clamped_limit(Some(10)), 10);
        assert_eq!(PageRequest::clamped_limit(Some(50)), 50);
    }

    #[test]
    fn clamped_limit_boundaries() {
        assert_eq!(PageRequest::clamped_limit(Some(1)), 1);
        assert_eq!(PageRequest::clamped_limit(Some(PageRequest::MAX_LIMIT)), 100);
    }

    #[test]
    fn clamped_limit_too_high() {
        assert_eq!(PageRequest::clamped_limit(Some(500)), 100);
    }

    #[test]
    fn clamped_limit_too_low() {
        assert_eq!(PageRequest::clamped_limit(Some(0)), 1);
        assert_eq!(PageRequest::clamped_limit(Some(-5)), 1);
    }

    #[test]
    fn cursor_encode_decode_roundtrip() {
        let id = "550e8400-e29b-41d4-a716-446655440000";
        let encoded = encode_cursor(id);
        let decoded = decode_cursor(&encoded).unwrap();
        assert_eq!(decoded, id);
    }

    #[test]
    fn cursor_roundtrip_covers_every_padding_length() {
        // Lengths 0..=5 exercise no padding, "==" and "=" in turn.
        for id in ["", "a", "ab", "abc", "abcd", "abcde"] {
            assert_eq!(decode_cursor(&encode_cursor(id)).unwrap(), id);
        }
    }

    #[test]
    fn cursor_encoding_matches_standard_base64() {
        assert_eq!(encode_cursor("f"), "Zg==");
        assert_eq!(encode_cursor("fo"), "Zm8=");
        assert_eq!(encode_cursor("foo"), "Zm9v");
        assert_eq!(encode_cursor("foobar"), "Zm9vYmFy");
    }

    #[test]
    fn invalid_cursor_returns_error() {
        assert_invalid_cursor(decode_cursor("!!invalid!!"));
    }

    #[test]
    fn invalid_cursor_length_returns_error() {
        assert_invalid_cursor(decode_cursor("Zm9"));
    }

    #[test]
    fn non_utf8_cursor_returns_error() {
        // "/w==" is the base64 encoding of the single byte 0xFF, which is not UTF-8.
        assert_invalid_cursor(decode_cursor("/w=="));
    }

    #[test]
    fn cursor_pagination_sql_no_cursor() {
        let page = PageRequest::Cursor {
            after: None,
            limit: 25,
        };
        let mut sql = "SELECT * FROM users".to_string();
        let offset = page.apply_cursor_to_sql(&mut sql, false, 1);

        assert_eq!(sql, "SELECT * FROM users ORDER BY \"id\" ASC LIMIT 26");
        assert_eq!(offset, 1);
    }

    #[test]
    fn cursor_pagination_no_cursor_keeps_existing_where() {
        let page = PageRequest::Cursor {
            after: None,
            limit: 25,
        };
        let mut sql = "SELECT * FROM users WHERE \"role\" = $1".to_string();
        let offset = page.apply_cursor_to_sql(&mut sql, true, 2);

        assert_eq!(
            sql,
            "SELECT * FROM users WHERE \"role\" = $1 ORDER BY \"id\" ASC LIMIT 26"
        );
        assert_eq!(offset, 2);
    }

    #[test]
    fn cursor_pagination_sql_with_cursor() {
        let page = PageRequest::Cursor {
            after: Some("some-uuid".to_string()),
            limit: 10,
        };
        let mut sql = "SELECT * FROM users".to_string();
        let offset = page.apply_cursor_to_sql(&mut sql, false, 1);

        assert_eq!(
            sql,
            "SELECT * FROM users WHERE \"id\" > $1 ORDER BY \"id\" ASC LIMIT 11"
        );
        assert_eq!(offset, 2);
    }

    #[test]
    fn cursor_pagination_with_existing_where() {
        let page = PageRequest::Cursor {
            after: Some("some-uuid".to_string()),
            limit: 10,
        };
        let mut sql = "SELECT * FROM users WHERE \"role\" = $1".to_string();
        let offset = page.apply_cursor_to_sql(&mut sql, true, 2);

        assert_eq!(
            sql,
            "SELECT * FROM users WHERE \"role\" = $1 AND \"id\" > $2 ORDER BY \"id\" ASC LIMIT 11"
        );
        assert_eq!(offset, 3);
    }

    #[test]
    fn offset_pagination_sql() {
        let page = PageRequest::Offset {
            offset: 20,
            limit: 10,
        };
        let mut sql = "SELECT * FROM users".to_string();
        let offset = page.apply_cursor_to_sql(&mut sql, false, 1);

        assert_eq!(sql, "SELECT * FROM users LIMIT 10 OFFSET 20");
        assert_eq!(offset, 1);
    }

    #[test]
    fn offset_pagination_ignores_existing_where() {
        let page = PageRequest::Offset {
            offset: 20,
            limit: 10,
        };
        let mut sql = "SELECT * FROM users WHERE \"role\" = $1".to_string();
        let offset = page.apply_cursor_to_sql(&mut sql, true, 3);

        assert_eq!(
            sql,
            "SELECT * FROM users WHERE \"role\" = $1 LIMIT 10 OFFSET 20"
        );
        assert_eq!(offset, 3);
    }
}