use bitcode::{Decode, Encode};
use serde::{Deserialize, Serialize};
/// Identifier of a cursor, stored in [`CursorState::id`] and carried by the
/// `Expired` / `NotFound` variants of [`CursorError`].
pub type CursorId = String;
/// Kind of results a cursor pages over.
///
/// Its `Display` form is the lowercase snake_case variant name
/// (e.g. `PatternMatch` -> `"pattern_match"`).
///
/// NOTE(review): this enum is part of the bitcode-encoded cursor token; the
/// variant order presumably determines the encoded discriminant, so reordering
/// or inserting variants would likely break previously issued tokens — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Encode, Decode, PartialEq, Eq, Hash)]
pub enum CursorResultType {
Rows,
Nodes,
Edges,
Similar,
Unified,
PatternMatch,
}
impl std::fmt::Display for CursorResultType {
    /// Writes the lowercase, snake_case name of the variant
    /// (`Rows` -> `rows`, `PatternMatch` -> `pattern_match`).
    ///
    /// Width/fill flags on the formatter are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            Self::Rows => "rows",
            Self::Nodes => "nodes",
            Self::Edges => "edges",
            Self::Similar => "similar",
            Self::Unified => "unified",
            Self::PatternMatch => "pattern_match",
        };
        f.write_str(name)
    }
}
/// Pagination state for a query result set.
///
/// A state is turned into a URL-safe base64 token by [`CursorState::encode`]
/// and restored (with an expiry check) by [`CursorState::decode`].
///
/// NOTE(review): the field order presumably defines the bitcode token layout,
/// so reordering or inserting fields would likely invalidate previously issued
/// tokens — confirm before changing.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode, PartialEq, Eq)]
pub struct CursorState {
/// Identifier of this cursor; returned inside `CursorError::Expired` by `decode`.
pub id: CursorId,
/// Query text the cursor pages over, stored and round-tripped verbatim.
pub query: String,
/// Kind of results the query produces.
pub result_type: CursorResultType,
/// Zero-based index of the first item of the current page; moved by
/// `page_size` in `next_page` / `prev_page`.
pub offset: usize,
/// Number of items per page.
pub page_size: usize,
/// Total number of results when known; `None` makes `has_more` return `true`.
pub total_count: Option<usize>,
/// Creation time in Unix seconds.
pub created_at: i64,
/// Last access time in Unix seconds; refreshed by `touch`, `next_page` and
/// `prev_page`, and used as the reference point for expiry.
pub last_accessed_at: i64,
/// Idle time-to-live in seconds, measured from `last_accessed_at`.
pub ttl_secs: u32,
}
impl CursorState {
    /// Default number of items per page.
    ///
    /// Not applied automatically: [`CursorState::new`] takes an explicit `page_size`.
    pub const DEFAULT_PAGE_SIZE: usize = 100;
    /// Default idle time-to-live in seconds.
    ///
    /// Not applied automatically: [`CursorState::new`] takes an explicit `ttl_secs`.
    pub const DEFAULT_TTL_SECS: u32 = 300;
    /// Upper bound on the idle time-to-live in seconds.
    ///
    /// [`CursorState::is_expired`] never honours a `ttl_secs` larger than this.
    pub const MAX_TTL_SECS: u32 = 1800;

    /// Creates a cursor positioned at offset 0.
    ///
    /// `created_at` and `last_accessed_at` are both set to the current Unix
    /// time in seconds. `page_size` and `ttl_secs` are stored exactly as given.
    #[must_use]
    pub fn new(
        id: CursorId,
        query: String,
        result_type: CursorResultType,
        page_size: usize,
        total_count: Option<usize>,
        ttl_secs: u32,
    ) -> Self {
        let now = current_timestamp();
        Self {
            id,
            query,
            result_type,
            offset: 0,
            page_size,
            total_count,
            created_at: now,
            last_accessed_at: now,
            ttl_secs,
        }
    }

    /// Returns a copy of this cursor advanced by one page, with a refreshed
    /// access time.
    ///
    /// This does not consult [`CursorState::has_more`]. The offset saturates
    /// at `usize::MAX` instead of overflowing. Offsets can come from a
    /// client-supplied token, so an unchecked add could panic in debug builds.
    #[must_use]
    pub fn next_page(&self) -> Self {
        let mut next = self.clone();
        next.offset = self.offset.saturating_add(self.page_size);
        next.last_accessed_at = current_timestamp();
        next
    }

    /// Returns a copy of this cursor moved back by one page (clamped at
    /// offset 0), with a refreshed access time.
    ///
    /// Returns `None` when the cursor is already at offset 0.
    #[must_use]
    pub fn prev_page(&self) -> Option<Self> {
        if self.offset == 0 {
            return None;
        }
        let mut prev = self.clone();
        // A partial first page (offset < page_size) rewinds to 0 rather than underflowing.
        prev.offset = self.offset.saturating_sub(self.page_size);
        prev.last_accessed_at = current_timestamp();
        Some(prev)
    }

    /// Reports whether another page may follow the current one.
    ///
    /// When `total_count` is `None` the total is unknown and this returns `true`.
    #[must_use]
    pub const fn has_more(&self) -> bool {
        match self.total_count {
            // Saturating add: a huge offset/page_size (e.g. from a decoded
            // token) must not overflow; a saturated end is never below `total`.
            Some(total) => self.offset.saturating_add(self.page_size) < total,
            None => true,
        }
    }

    /// Reports whether more than the TTL has elapsed since `last_accessed_at`.
    ///
    /// The effective TTL is `ttl_secs` capped at [`Self::MAX_TTL_SECS`]. This
    /// stops a cursor (for example one decoded from a client-supplied token)
    /// from outliving the configured maximum. The elapsed time is computed
    /// with saturating arithmetic, so extreme timestamps cannot overflow.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        let ttl = self.ttl_secs.min(Self::MAX_TTL_SECS);
        let elapsed = current_timestamp().saturating_sub(self.last_accessed_at);
        elapsed > i64::from(ttl)
    }

    /// Sets `last_accessed_at` to the current Unix time in seconds.
    pub fn touch(&mut self) {
        self.last_accessed_at = current_timestamp();
    }

    /// Serializes this state with bitcode and returns it as URL-safe base64
    /// without padding.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`.
    pub fn encode(&self) -> Result<String, CursorError> {
        use base64::Engine;
        let encoded = bitcode::encode(self);
        Ok(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&encoded))
    }

    /// Parses a token produced by [`CursorState::encode`].
    ///
    /// The decoded state is returned as-is. This method does not call
    /// [`CursorState::touch`].
    ///
    /// NOTE(review): tokens are neither signed nor encrypted. A client can
    /// forge every field (offset, page_size, ttl_secs, timestamps and `query`).
    /// If callers execute `query` from a decoded token, it must be treated as
    /// untrusted input. Consider authenticating tokens (e.g. an HMAC) — confirm
    /// the threat model.
    ///
    /// # Errors
    ///
    /// * [`CursorError::InvalidToken`] if the token is not valid URL-safe
    ///   base64 or the bytes are not a valid bitcode-encoded `CursorState`.
    /// * [`CursorError::Expired`] (carrying the cursor id) if the decoded
    ///   cursor [`is_expired`](CursorState::is_expired).
    pub fn decode(token: &str) -> Result<Self, CursorError> {
        use base64::Engine;
        let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(token)
            .map_err(|e| CursorError::InvalidToken(format!("base64 decode failed: {e}")))?;
        let state: Self = bitcode::decode(&bytes)
            .map_err(|e| CursorError::InvalidToken(format!("bitcode decode failed: {e}")))?;
        if state.is_expired() {
            return Err(CursorError::Expired(state.id));
        }
        Ok(state)
    }
}
/// Errors produced when working with pagination cursors.
///
/// Within this file, `InvalidToken` and `Expired` are returned by
/// `CursorState::decode`. `NotFound` and `CapacityExceeded` are not produced
/// here; presumably they are raised by a cursor store elsewhere — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CursorError {
    /// The token could not be decoded; the payload describes the failing stage.
    InvalidToken(String),
    /// The cursor with this id has outlived its TTL.
    Expired(CursorId),
    /// No cursor with this id is known.
    NotFound(CursorId),
    /// No further cursors can be created.
    CapacityExceeded,
}

impl std::fmt::Display for CursorError {
    /// Renders `"<label>: <detail>"` for variants that carry a payload and
    /// just the label for `CapacityExceeded`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::InvalidToken(msg) => ("Invalid cursor token", Some(msg)),
            Self::Expired(id) => ("Cursor expired", Some(id)),
            Self::NotFound(id) => ("Cursor not found", Some(id)),
            Self::CapacityExceeded => ("Maximum cursor capacity exceeded", None),
        };
        match detail {
            Some(text) => write!(f, "{label}: {text}"),
            None => f.write_str(label),
        }
    }
}

impl std::error::Error for CursorError {}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. Returns
/// `i64::MAX` if the second count does not fit in an `i64`.
fn current_timestamp() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_secs()).unwrap_or(i64::MAX),
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh cursor: id "test-cursor-id", 100 rows per page, 500 total, 300 s TTL.
    fn create_test_cursor() -> CursorState {
        CursorState::new(
            "test-cursor-id".to_string(),
            "SELECT users".to_string(),
            CursorResultType::Rows,
            100,
            Some(500),
            300,
        )
    }

    #[test]
    fn test_cursor_state_new() {
        let cursor = create_test_cursor();
        assert_eq!(cursor.id, "test-cursor-id");
        assert_eq!(cursor.query, "SELECT users");
        assert_eq!(cursor.result_type, CursorResultType::Rows);
        assert_eq!(cursor.offset, 0);
        assert_eq!(cursor.page_size, 100);
        assert_eq!(cursor.total_count, Some(500));
        assert_eq!(cursor.ttl_secs, 300);
    }

    #[test]
    fn test_cursor_next_page() {
        let cursor = create_test_cursor();
        let next = cursor.next_page();
        assert_eq!(next.offset, 100);
        assert_eq!(next.page_size, 100);
        assert_eq!(next.id, cursor.id);
    }

    #[test]
    fn test_cursor_prev_page_at_start() {
        let cursor = create_test_cursor();
        assert!(cursor.prev_page().is_none());
    }

    #[test]
    fn test_cursor_prev_page() {
        let mut cursor = create_test_cursor();
        cursor.offset = 200;
        let prev = cursor.prev_page().unwrap();
        assert_eq!(prev.offset, 100);
    }

    #[test]
    fn test_cursor_has_more() {
        let cursor = create_test_cursor();
        assert!(cursor.has_more());
        let mut last_page = cursor.clone();
        last_page.offset = 400;
        assert!(!last_page.has_more());
    }

    #[test]
    fn test_cursor_has_more_unknown_total() {
        let cursor = CursorState::new(
            "test".to_string(),
            "SELECT users".to_string(),
            CursorResultType::Rows,
            100,
            None,
            300,
        );
        assert!(cursor.has_more());
    }

    #[test]
    fn test_cursor_is_expired() {
        let mut cursor = create_test_cursor();
        assert!(!cursor.is_expired());
        cursor.last_accessed_at = current_timestamp() - 400;
        assert!(cursor.is_expired());
    }

    #[test]
    fn test_cursor_touch() {
        let mut cursor = create_test_cursor();
        let original_time = cursor.last_accessed_at;
        // Rewind the access time so the assertion proves `touch` refreshed it.
        // Timestamps have one-second resolution, so a short sleep could never
        // produce an observable change.
        cursor.last_accessed_at = 0;
        cursor.touch();
        assert!(cursor.last_accessed_at >= original_time);
    }

    #[test]
    fn test_cursor_encode_decode_roundtrip() {
        let cursor = create_test_cursor();
        let token = cursor.encode().unwrap();
        let decoded = CursorState::decode(&token).unwrap();
        assert_eq!(cursor.id, decoded.id);
        assert_eq!(cursor.query, decoded.query);
        assert_eq!(cursor.result_type, decoded.result_type);
        assert_eq!(cursor.offset, decoded.offset);
        assert_eq!(cursor.page_size, decoded.page_size);
        assert_eq!(cursor.total_count, decoded.total_count);
        // `decode` does not touch the state, so every field round-trips.
        assert_eq!(cursor, decoded);
    }

    #[test]
    fn test_cursor_decode_invalid_token() {
        let result = CursorState::decode("not-valid-base64!!!");
        assert!(matches!(result, Err(CursorError::InvalidToken(_))));
    }

    #[test]
    fn test_cursor_decode_expired() {
        let mut cursor = create_test_cursor();
        cursor.last_accessed_at = current_timestamp() - 400;
        let token = cursor.encode().unwrap();
        let result = CursorState::decode(&token);
        assert!(matches!(result, Err(CursorError::Expired(_))));
    }

    #[test]
    fn test_cursor_result_type_display() {
        assert_eq!(CursorResultType::Rows.to_string(), "rows");
        assert_eq!(CursorResultType::Nodes.to_string(), "nodes");
        assert_eq!(CursorResultType::Edges.to_string(), "edges");
        assert_eq!(CursorResultType::Similar.to_string(), "similar");
        assert_eq!(CursorResultType::Unified.to_string(), "unified");
        assert_eq!(CursorResultType::PatternMatch.to_string(), "pattern_match");
    }

    #[test]
    fn test_cursor_error_display() {
        let err = CursorError::InvalidToken("bad".to_string());
        assert!(err.to_string().contains("Invalid cursor token"));
        let err = CursorError::Expired("cursor-1".to_string());
        assert!(err.to_string().contains("expired"));
        let err = CursorError::NotFound("cursor-2".to_string());
        assert!(err.to_string().contains("not found"));
        let err = CursorError::CapacityExceeded;
        assert!(err.to_string().contains("capacity"));
    }

    #[test]
    fn test_cursor_constants() {
        assert_eq!(CursorState::DEFAULT_PAGE_SIZE, 100);
        assert_eq!(CursorState::DEFAULT_TTL_SECS, 300);
        assert_eq!(CursorState::MAX_TTL_SECS, 1800);
    }

    #[test]
    fn test_cursor_state_partial_eq() {
        let cursor1 = create_test_cursor();
        let cursor2 = create_test_cursor();
        assert_eq!(cursor1.id, cursor2.id);
        assert_eq!(cursor1.query, cursor2.query);
        // Compare whole structs via clones: two independently created cursors
        // may differ in `created_at` if a second boundary passes between them.
        let same = cursor1.clone();
        assert_eq!(cursor1, same);
        let mut moved = cursor1.clone();
        moved.offset += 1;
        assert_ne!(cursor1, moved);
    }

    #[test]
    fn test_cursor_result_type_equality() {
        assert_eq!(CursorResultType::Rows, CursorResultType::Rows);
        assert_ne!(CursorResultType::Rows, CursorResultType::Nodes);
    }

    #[test]
    fn test_cursor_clone() {
        let cursor = create_test_cursor();
        let cloned = cursor.clone();
        assert_eq!(cursor.id, cloned.id);
        assert_eq!(cursor.offset, cloned.offset);
    }

    #[test]
    fn test_cursor_debug() {
        let cursor = create_test_cursor();
        let debug = format!("{cursor:?}");
        assert!(debug.contains("CursorState"));
        assert!(debug.contains("test-cursor-id"));
    }

    #[test]
    fn test_prev_page_saturating_sub() {
        let mut cursor = create_test_cursor();
        cursor.offset = 50;
        let prev = cursor.prev_page().unwrap();
        assert_eq!(prev.offset, 0);
    }

    #[test]
    fn test_cursor_next_page_chain() {
        let cursor = create_test_cursor();
        let page2 = cursor.next_page();
        let page3 = page2.next_page();
        let page4 = page3.next_page();
        assert_eq!(page4.offset, 300);
    }

    #[test]
    fn test_has_more_exact_boundary() {
        let mut cursor = create_test_cursor();
        cursor.total_count = Some(200);
        cursor.offset = 100;
        assert!(!cursor.has_more());
    }

    #[test]
    fn test_is_expired_boundary() {
        let mut cursor = create_test_cursor();
        cursor.ttl_secs = 10;
        // `is_expired` samples the clock again, and the clock may have advanced
        // one second since `current_timestamp()` was read here. Keep a
        // one-second margin on the "not expired" side so a tick cannot make
        // the test flaky.
        cursor.last_accessed_at = current_timestamp() - 9;
        assert!(!cursor.is_expired());
        cursor.last_accessed_at = current_timestamp() - 11;
        assert!(cursor.is_expired());
    }
}