use serde::{Deserialize, Serialize};
use super::CursorEncoder;
use crate::exception::Result;
use std::sync::Arc;
/// One item of a [`Connection`], paired with the opaque cursor that marks
/// its position in the paginated list.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Edge<T> {
    /// The paginated item itself.
    pub node: T,
    /// Encoded position of `node`; pass it back as `after` or `before` to
    /// continue paging from this item.
    pub cursor: String,
}
/// Navigation metadata describing where a page sits in the full list.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PageInfo {
    /// `true` when at least one item follows the last edge of the page.
    pub has_next_page: bool,
    /// `true` when at least one item precedes the first edge of the page.
    pub has_previous_page: bool,
    /// Cursor of the first edge, or `None` for an empty page.
    pub start_cursor: Option<String>,
    /// Cursor of the last edge, or `None` for an empty page.
    pub end_cursor: Option<String>,
}
/// A single page of results in the Relay cursor-connection shape.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Connection<T> {
    /// Items on this page, in list order, each carrying its own cursor.
    pub edges: Vec<Edge<T>>,
    /// Boundary cursors and has-more flags for navigating from this page.
    pub page_info: PageInfo,
    /// Length of the whole list, if the paginator was configured to report it.
    pub total_count: Option<usize>,
}
/// Slices an in-memory list into Relay-style pages.
///
/// Configure it with [`RelayPagination::new`] and the chained setters. Cursors
/// are produced and parsed by a pluggable [`CursorEncoder`].
#[derive(Clone)]
pub struct RelayPagination {
    /// Page size used when the caller supplies neither `first` nor `last`.
    pub default_page_size: usize,
    /// Upper bound applied to `first`/`last`; `None` disables the cap.
    pub max_page_size: Option<usize>,
    /// Whether [`Connection::total_count`] is filled in.
    pub include_total_count: bool,
    /// Reference-counted so cloning the paginator shares one encoder.
    encoder: Arc<dyn CursorEncoder>,
}
impl RelayPagination {
    /// Creates a paginator with a default page size of 10, a maximum page
    /// size of 100, total counts enabled and the default
    /// `Base64CursorEncoder`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            default_page_size: 10,
            max_page_size: Some(100),
            include_total_count: true,
            encoder: Arc::new(super::Base64CursorEncoder::new()),
        }
    }

    /// Sets the page size used when the caller passes neither `first` nor `last`.
    ///
    /// This value is trusted server configuration and is not clamped by
    /// [`max_page_size`](Self::max_page_size), which only caps client-supplied
    /// `first`/`last` counts.
    #[must_use]
    pub fn default_page_size(mut self, size: usize) -> Self {
        self.default_page_size = size;
        self
    }

    /// Caps client-supplied `first`/`last` counts at `size`.
    #[must_use]
    pub fn max_page_size(mut self, size: usize) -> Self {
        self.max_page_size = Some(size);
        self
    }

    /// Controls whether [`Connection::total_count`] is populated.
    #[must_use]
    pub fn include_total_count(mut self, include: bool) -> Self {
        self.include_total_count = include;
        self
    }

    /// Replaces the encoder used to produce and parse cursors.
    #[must_use]
    pub fn with_encoder<E: CursorEncoder + 'static>(mut self, encoder: E) -> Self {
        self.encoder = Arc::new(encoder);
        self
    }

    /// Returns one page of `items` as a Relay-style [`Connection`].
    ///
    /// Arguments are applied in the order prescribed by the Relay
    /// cursor-connection specification:
    ///
    /// 1. `after` and `before` narrow the list to the items strictly between
    ///    the two cursors (either may be absent). Cursors pointing past the end
    ///    of `items` are clamped to the end rather than rejected.
    /// 2. `first` keeps at most that many items from the front of that window,
    ///    then `last` keeps at most that many from its back. Both counts are
    ///    capped by `max_page_size`.
    /// 3. If neither `first` nor `last` is given, `default_page_size` items are
    ///    taken from the back of the window when only `before` is supplied
    ///    (backward paging), and from the front otherwise.
    ///
    /// `has_previous_page` / `has_next_page` report whether any item of the
    /// full list lies before / after the returned page.
    ///
    /// # Errors
    ///
    /// Returns the encoder's error if `after` or `before` cannot be decoded,
    /// or if encoding an edge cursor fails.
    pub fn paginate<T: Clone>(
        &self,
        items: &[T],
        first: Option<usize>,
        after: Option<&str>,
        last: Option<usize>,
        before: Option<&str>,
    ) -> Result<Connection<T>> {
        let total_count = items.len();
        let (start, end) = self.page_bounds(total_count, first, after, last, before)?;

        // `page_bounds` guarantees `start <= end <= total_count`, so this slice cannot panic.
        let edges = items[start..end]
            .iter()
            .enumerate()
            .map(|(offset, item)| -> Result<Edge<T>> {
                let cursor = self.encoder.encode(start + offset)?;
                Ok(Edge {
                    node: item.clone(),
                    cursor,
                })
            })
            .collect::<Result<Vec<_>>>()?;

        let page_info = PageInfo {
            has_next_page: end < total_count,
            has_previous_page: start > 0,
            start_cursor: edges.first().map(|edge| edge.cursor.clone()),
            end_cursor: edges.last().map(|edge| edge.cursor.clone()),
        };

        Ok(Connection {
            edges,
            page_info,
            total_count: self.include_total_count.then_some(total_count),
        })
    }

    /// Resolves the paging arguments into a half-open index range
    /// `start..end` with `start <= end <= total_count`.
    fn page_bounds(
        &self,
        total_count: usize,
        first: Option<usize>,
        after: Option<&str>,
        last: Option<usize>,
        before: Option<&str>,
    ) -> Result<(usize, usize)> {
        // Items strictly after `after`. `saturating_add` keeps a decoded
        // `usize::MAX` from overflowing; the clamp turns a cursor past the end
        // into an empty window instead of an out-of-bounds slice.
        let mut start = match after {
            Some(cursor) => self
                .encoder
                .decode(cursor)?
                .saturating_add(1)
                .min(total_count),
            None => 0,
        };
        // ...and strictly before `before`. `end` never drops below `start`, so
        // an `after` at or beyond `before` yields an empty window.
        let before_bound = match before {
            Some(cursor) => self.encoder.decode(cursor)?.min(total_count),
            None => total_count,
        };
        let mut end = before_bound.max(start);

        if first.is_none() && last.is_none() {
            let size = self.default_page_size;
            if before.is_some() && after.is_none() {
                // Backward paging: the page ends immediately before `before`.
                start = end.saturating_sub(size).max(start);
            } else {
                end = start.saturating_add(size).min(end);
            }
        } else {
            if let Some(count) = first {
                end = start.saturating_add(self.capped_page_size(count)).min(end);
            }
            if let Some(count) = last {
                start = end.saturating_sub(self.capped_page_size(count)).max(start);
            }
        }

        Ok((start, end))
    }

    /// Applies `max_page_size` to a client-requested `first`/`last` count.
    fn capped_page_size(&self, requested: usize) -> usize {
        self.max_page_size
            .map_or(requested, |max| requested.min(max))
    }
}
impl Default for RelayPagination {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for RelayPagination {
    /// Formats the configuration fields. The boxed `encoder` is a trait object
    /// with no `Debug` bound, so it is omitted. `finish_non_exhaustive` appends
    /// `..` to make that omission visible in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RelayPagination")
            .field("default_page_size", &self.default_page_size)
            .field("max_page_size", &self.max_page_size)
            .field("include_total_count", &self.include_total_count)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Fixed key for tests that build their own keyed cursor encoder.
    const SECRET: &[u8; 32] = b"test-secret-key-for-unit-tests!!";

    /// Builds the fixture list `1, 2, ..., n`.
    fn numbers(n: i32) -> Vec<i32> {
        (1..=n).collect()
    }

    #[rstest]
    fn test_relay_pagination_forward() {
        let data = numbers(100);
        let page = RelayPagination::new()
            .default_page_size(10)
            .paginate(&data, Some(10), None, None, None)
            .unwrap();

        assert_eq!(page.edges.len(), 10);
        assert_eq!(page.edges[0].node, 1);
        assert_eq!(page.edges[9].node, 10);
        assert!(page.page_info.has_next_page);
        assert!(!page.page_info.has_previous_page);
        assert_eq!(page.total_count, Some(100));
    }

    #[rstest]
    fn test_relay_pagination_forward_with_after() {
        let data = numbers(100);
        let paginator = RelayPagination::new();

        let first_page = paginator
            .paginate(&data, Some(10), None, None, None)
            .unwrap();
        let cursor = first_page
            .page_info
            .end_cursor
            .expect("a non-empty page has an end cursor");
        let second_page = paginator
            .paginate(&data, Some(10), Some(cursor.as_str()), None, None)
            .unwrap();

        assert_eq!(second_page.edges.len(), 10);
        assert_eq!(second_page.edges[0].node, 11);
        assert_eq!(second_page.edges[9].node, 20);
        assert!(second_page.page_info.has_previous_page);
        assert!(second_page.page_info.has_next_page);
    }

    #[rstest]
    fn test_relay_pagination_backward() {
        let data = numbers(100);
        let page = RelayPagination::new()
            .paginate(&data, None, None, Some(10), None)
            .unwrap();

        assert_eq!(page.edges.len(), 10);
        assert_eq!(page.edges[0].node, 91);
        assert_eq!(page.edges[9].node, 100);
        assert!(!page.page_info.has_next_page);
        assert!(page.page_info.has_previous_page);
    }

    #[rstest]
    fn test_relay_pagination_edge_structure() {
        let letters = vec!["a", "b", "c"];
        let page = RelayPagination::new()
            .paginate(&letters, Some(2), None, None, None)
            .unwrap();

        assert_eq!(page.edges.len(), 2);
        assert_eq!(page.edges[0].node, "a");
        assert!(!page.edges[0].cursor.is_empty());
        assert_eq!(page.edges[1].node, "b");
        assert!(!page.edges[1].cursor.is_empty());
    }

    #[rstest]
    fn test_relay_pagination_page_info() {
        let data = numbers(5);
        let info = RelayPagination::new()
            .paginate(&data, Some(3), None, None, None)
            .unwrap()
            .page_info;

        assert!(info.start_cursor.is_some());
        assert!(info.end_cursor.is_some());
        assert!(info.has_next_page);
        assert!(!info.has_previous_page);
    }

    #[rstest]
    fn test_relay_pagination_max_page_size() {
        let data = numbers(100);
        let page = RelayPagination::new()
            .max_page_size(20)
            .paginate(&data, Some(50), None, None, None)
            .unwrap();

        assert_eq!(page.edges.len(), 20);
    }

    #[rstest]
    fn test_relay_pagination_without_total_count() {
        let data = numbers(100);
        let page = RelayPagination::new()
            .include_total_count(false)
            .paginate(&data, Some(10), None, None, None)
            .unwrap();

        assert_eq!(page.total_count, None);
    }

    #[rstest]
    fn test_relay_pagination_empty_list() {
        let data: Vec<i32> = Vec::new();
        let page = RelayPagination::new()
            .paginate(&data, Some(10), None, None, None)
            .unwrap();

        assert!(page.edges.is_empty());
        assert!(!page.page_info.has_next_page);
        assert!(!page.page_info.has_previous_page);
        assert!(page.page_info.start_cursor.is_none());
        assert!(page.page_info.end_cursor.is_none());
    }

    #[rstest]
    fn relay_pagination_does_not_panic_on_out_of_range_after_cursor() {
        let encoder = crate::pagination::cursor::Base64CursorEncoder::with_secret_key(SECRET);
        let data = numbers(10);
        let past_the_end = encoder.encode(100).unwrap();

        let outcome = RelayPagination::new().with_encoder(encoder).paginate(
            &data,
            Some(5),
            Some(past_the_end.as_str()),
            None,
            None,
        );
        let page = outcome.expect("an out-of-range cursor must not produce an error");

        assert!(page.edges.is_empty());
        assert!(!page.page_info.has_next_page);
    }

    #[rstest]
    fn relay_pagination_does_not_panic_on_out_of_range_before_cursor() {
        let encoder = crate::pagination::cursor::Base64CursorEncoder::with_secret_key(SECRET);
        let data = numbers(10);
        let past_the_end = encoder.encode(100).unwrap();

        let outcome = RelayPagination::new().with_encoder(encoder).paginate(
            &data,
            None,
            None,
            Some(5),
            Some(past_the_end.as_str()),
        );

        assert!(outcome.is_ok());
    }

    #[rstest]
    fn relay_pagination_empty_dataset_with_after_cursor() {
        let encoder = crate::pagination::cursor::Base64CursorEncoder::with_secret_key(SECRET);
        let data: Vec<i32> = Vec::new();
        let origin = encoder.encode(0).unwrap();

        let outcome = RelayPagination::new().with_encoder(encoder).paginate(
            &data,
            Some(10),
            Some(origin.as_str()),
            None,
            None,
        );
        let page = outcome.expect("an empty dataset must paginate cleanly");

        assert!(page.edges.is_empty());
    }
}