use rmcp::ErrorData as McpError;
use rmcp::schemars;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema, PartialEq, Eq)]
#[serde(transparent)]
pub struct Cursor(pub String);
impl Cursor {
pub(super) fn encode_fjall(last_key: &[u8]) -> Self {
Cursor(base64url_encode(last_key))
}
pub(super) fn decode_fjall(&self) -> Result<Vec<u8>, McpError> {
base64url_decode(&self.0)
.map_err(|e| McpError::invalid_params(format!("invalid cursor: {e}"), None))
}
pub(crate) fn encode_in_memory(offset: u64, snapshot_id: u32) -> Self {
let payload = InMemCursor {
o: offset,
s: snapshot_id,
};
let bytes = rmp_serde::to_vec(&payload).expect("encode in-memory cursor never fails");
Cursor(base64url_encode(&bytes))
}
pub(super) fn decode_in_memory(&self) -> Result<(u64, u32), McpError> {
let bytes = base64url_decode(&self.0)
.map_err(|e| McpError::invalid_params(format!("invalid cursor: {e}"), None))?;
let payload: InMemCursor = rmp_serde::from_slice(&bytes).map_err(|e| {
McpError::invalid_params(format!("invalid in-memory cursor payload: {e}"), None)
})?;
Ok((payload.o, payload.s))
}
}
#[derive(Serialize, Deserialize)]
struct InMemCursor {
o: u64,
s: u32,
}
pub(super) fn prefix_upper_bound(prefix: &[u8]) -> Option<Vec<u8>> {
let mut out = prefix.to_vec();
while let Some(last) = out.last_mut() {
if *last == 0xFF {
out.pop();
continue;
}
*last += 1;
return Some(out);
}
None
}
const B64URL: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
fn base64url_encode(bytes: &[u8]) -> String {
let mut out = String::with_capacity(bytes.len().div_ceil(3) * 4);
let mut i = 0;
while i + 3 <= bytes.len() {
let n = ((bytes[i] as u32) << 16) | ((bytes[i + 1] as u32) << 8) | (bytes[i + 2] as u32);
out.push(B64URL[((n >> 18) & 0x3F) as usize] as char);
out.push(B64URL[((n >> 12) & 0x3F) as usize] as char);
out.push(B64URL[((n >> 6) & 0x3F) as usize] as char);
out.push(B64URL[(n & 0x3F) as usize] as char);
i += 3;
}
let rem = bytes.len() - i;
if rem == 1 {
let n = (bytes[i] as u32) << 16;
out.push(B64URL[((n >> 18) & 0x3F) as usize] as char);
out.push(B64URL[((n >> 12) & 0x3F) as usize] as char);
} else if rem == 2 {
let n = ((bytes[i] as u32) << 16) | ((bytes[i + 1] as u32) << 8);
out.push(B64URL[((n >> 18) & 0x3F) as usize] as char);
out.push(B64URL[((n >> 12) & 0x3F) as usize] as char);
out.push(B64URL[((n >> 6) & 0x3F) as usize] as char);
}
out
}
fn base64url_decode(s: &str) -> Result<Vec<u8>, &'static str> {
let bytes = s.as_bytes();
let mut decoded = Vec::with_capacity(bytes.len() * 3 / 4);
let mut buf: u32 = 0;
let mut bits: u32 = 0;
for &b in bytes {
let v = match b {
b'A'..=b'Z' => b - b'A',
b'a'..=b'z' => b - b'a' + 26,
b'0'..=b'9' => b - b'0' + 52,
b'-' => 62,
b'_' => 63,
b'=' => continue, _ => return Err("non-base64url byte"),
};
buf = (buf << 6) | (v as u32);
bits += 6;
if bits >= 8 {
bits -= 8;
decoded.push((buf >> bits) as u8);
buf &= (1 << bits) - 1;
}
}
Ok(decoded)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn base64url_roundtrip_random_lengths() {
for len in 0..32usize {
let bytes: Vec<u8> = (0..len).map(|i| (i as u8).wrapping_mul(37)).collect();
let encoded = base64url_encode(&bytes);
let decoded = base64url_decode(&encoded).expect("decode roundtrips");
assert_eq!(bytes, decoded, "len={len}");
}
}
#[test]
fn base64url_rejects_invalid_chars() {
assert!(base64url_decode("**").is_err());
}
#[test]
fn fjall_cursor_roundtrip() {
let key = b"\x00\x01\x02\xFFabc";
let c = Cursor::encode_fjall(key);
let decoded = c.decode_fjall().expect("decode");
assert_eq!(decoded, key);
}
#[test]
fn in_memory_cursor_roundtrip() {
let c = Cursor::encode_in_memory(42, 7);
let (offset, snapshot_id) = c.decode_in_memory().expect("decode");
assert_eq!(offset, 42);
assert_eq!(snapshot_id, 7);
}
#[test]
fn prefix_upper_bound_increments_last_byte() {
assert_eq!(
prefix_upper_bound(b"prefix\x01"),
Some(b"prefix\x02".to_vec())
);
}
#[test]
fn prefix_upper_bound_strips_trailing_ff() {
assert_eq!(prefix_upper_bound(b"ab\xFF\xFF"), Some(b"ac".to_vec()));
}
#[test]
fn prefix_upper_bound_all_ff_returns_none() {
assert_eq!(prefix_upper_bound(&[0xFF, 0xFF]), None);
}
}