use base64::Engine;
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
/// HMAC-SHA256 instance used to authenticate cursor bodies.
type HmacSha256 = Hmac<Sha256>;
/// Maximum length, in bytes, of an encoded (unpadded base64url) cursor string.
/// Enforced both when issuing and before decoding in `CursorIssuer::verify`.
pub const MAX_CURSOR_BYTES: usize = 512;
/// Maximum length, in bytes, of [`CursorPayload::opaque_state`], enforced at
/// issue time.
pub const MAX_OPAQUE_STATE_BYTES: usize = 256;
/// State carried inside an opaque pagination cursor.
///
/// [`CursorIssuer::issue`] serializes this struct as CBOR and appends an
/// HMAC-SHA256 tag over the serialized bytes; [`CursorIssuer::verify`] checks
/// the tag before trusting any field.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CursorPayload {
    /// Identifier of the tool the cursor was issued for.
    pub tool_id: String,
    /// Optional caller identity; omitted from the encoding when `None` and
    /// defaulted to `None` when absent on decode.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub caller_id: Option<String>,
    /// SHA-256 digest of the call arguments (see [`args_fingerprint`]).
    pub args_fingerprint: [u8; 32],
    /// Page index supplied by the caller; not interpreted by this module.
    pub page_index: u32,
    /// Issue time in whole seconds since the Unix epoch. `verify` rejects the
    /// cursor once `now - issued_at_unix` exceeds the supplied TTL.
    pub issued_at_unix: u64,
    /// Must equal the verifying issuer's session nonce; any other value is
    /// reported as `CursorError::Expired`.
    pub server_session: [u8; 8],
    /// Caller-defined resume state, serialized as a CBOR byte string via
    /// `serde_bytes`. At most `MAX_OPAQUE_STATE_BYTES` (checked by `issue`).
    #[serde(default, with = "serde_bytes")]
    pub opaque_state: Vec<u8>,
}
#[derive(thiserror::Error, Debug)]
pub enum CursorError {
#[error("cursor expired")]
Expired,
#[error("cursor signature invalid")]
InvalidSignature,
#[error("cursor format invalid: {0}")]
Format(String),
#[error("cursor too large: {0} bytes (max 512)")]
TooLarge(usize),
#[error("opaque_state too large: {0} bytes (max 256)")]
OpaqueStateTooLarge(usize),
}
/// Issues and verifies HMAC-signed pagination cursors.
///
/// Each instance carries its own random session nonce, so cursors minted by
/// one instance are rejected as expired by another, even when both share a key.
pub struct CursorIssuer {
    /// HMAC-SHA256 signing key.
    key: [u8; 32],
    /// Random per-instance nonce generated in `new`; cursors whose
    /// `server_session` differs are treated as expired.
    session_nonce: [u8; 8],
}
impl CursorIssuer {
pub fn new(key: [u8; 32]) -> Self {
let mut nonce = [0u8; 8];
getrandom::getrandom(&mut nonce).expect("OS RNG");
Self {
key,
session_nonce: nonce,
}
}
pub fn session_nonce(&self) -> [u8; 8] {
self.session_nonce
}
pub fn issue(&self, payload: CursorPayload) -> Result<String, CursorError> {
if payload.opaque_state.len() > MAX_OPAQUE_STATE_BYTES {
return Err(CursorError::OpaqueStateTooLarge(payload.opaque_state.len()));
}
let mut body = Vec::with_capacity(256);
ciborium::into_writer(&payload, &mut body)
.map_err(|e| CursorError::Format(e.to_string()))?;
let mut mac =
HmacSha256::new_from_slice(&self.key).expect("HMAC accepts arbitrary key lengths");
mac.update(&body);
let tag = mac.finalize().into_bytes();
let mut combined = body;
combined.extend_from_slice(&tag);
let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&combined);
if encoded.len() > MAX_CURSOR_BYTES {
return Err(CursorError::TooLarge(encoded.len()));
}
Ok(encoded)
}
pub fn verify(&self, cursor: &str, ttl_seconds: u64) -> Result<CursorPayload, CursorError> {
if cursor.len() > MAX_CURSOR_BYTES {
return Err(CursorError::TooLarge(cursor.len()));
}
let combined = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(cursor)
.map_err(|e| CursorError::Format(e.to_string()))?;
if combined.len() < 32 {
return Err(CursorError::Format("missing HMAC tag".into()));
}
let (body, tag) = combined.split_at(combined.len() - 32);
let mut mac =
HmacSha256::new_from_slice(&self.key).expect("HMAC accepts arbitrary key lengths");
mac.update(body);
if mac.verify_slice(tag).is_err() {
if let Ok(probe) = ciborium::from_reader::<CursorPayload, _>(body)
&& probe.server_session != self.session_nonce
{
return Err(CursorError::Expired);
}
return Err(CursorError::InvalidSignature);
}
let payload: CursorPayload =
ciborium::from_reader(body).map_err(|e| CursorError::Format(e.to_string()))?;
if payload.server_session != self.session_nonce {
return Err(CursorError::Expired);
}
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
if now.saturating_sub(payload.issued_at_unix) > ttl_seconds {
return Err(CursorError::Expired);
}
Ok(payload)
}
}
/// Draws a fresh 32-byte signing key from the operating system RNG.
///
/// # Panics
///
/// Panics if the OS random number generator is unavailable.
pub fn random_signing_key() -> [u8; 32] {
    let mut key = [0u8; 32];
    getrandom::getrandom(key.as_mut_slice()).expect("OS RNG");
    key
}
/// Loads the cursor signing key from `ATD_CURSOR_SIGNING_KEY`, falling back to
/// a random per-process key.
///
/// The variable must hold exactly 32 bytes encoded as base64. Leading and
/// trailing whitespace is ignored. The URL-safe and standard alphabets are
/// both accepted, with or without `=` padding.
///
/// If the variable is unset or empty, [`random_signing_key`] is used. If it is
/// set but unusable (not UTF-8, not base64, or the wrong length), a warning is
/// printed to stderr and a random key is used instead. Cursors signed with a
/// random key do not survive a process restart.
pub fn signing_key_from_env_or_random() -> [u8; 32] {
    match std::env::var("ATD_CURSOR_SIGNING_KEY") {
        Ok(value) => {
            let trimmed = value.trim();
            if !trimmed.is_empty() {
                match decode_signing_key(trimmed) {
                    Ok(bytes) if bytes.len() == 32 => {
                        let mut k = [0u8; 32];
                        k.copy_from_slice(&bytes);
                        return k;
                    }
                    Ok(bytes) => {
                        eprintln!(
                            "atd-runtime: ATD_CURSOR_SIGNING_KEY decoded to {} bytes; \
                             expected 32. Falling back to random per-process key.",
                            bytes.len()
                        );
                    }
                    Err(e) => {
                        eprintln!(
                            "atd-runtime: ATD_CURSOR_SIGNING_KEY base64 decode failed: {e}. \
                             Falling back to random per-process key."
                        );
                    }
                }
            }
        }
        // Set but not UTF-8: previously ignored silently, which hid misconfiguration.
        Err(std::env::VarError::NotUnicode(_)) => {
            eprintln!(
                "atd-runtime: ATD_CURSOR_SIGNING_KEY is not valid UTF-8. \
                 Falling back to random per-process key."
            );
        }
        Err(std::env::VarError::NotPresent) => {}
    }
    random_signing_key()
}

/// Decodes `text` as base64, trying URL-safe and standard alphabets with and
/// without padding; returns the standard-alphabet error if every variant fails.
fn decode_signing_key(text: &str) -> Result<Vec<u8>, base64::DecodeError> {
    use base64::engine::general_purpose::{STANDARD, STANDARD_NO_PAD, URL_SAFE, URL_SAFE_NO_PAD};
    // At most one variant can succeed for a given string: the alphabets differ
    // only in '-_' vs '+/', and each engine either requires or forbids padding.
    URL_SAFE_NO_PAD
        .decode(text)
        .or_else(|_| URL_SAFE.decode(text))
        .or_else(|_| STANDARD_NO_PAD.decode(text))
        .or_else(|_| STANDARD.decode(text))
}
/// Returns the SHA-256 digest of the compact JSON serialization of `args`.
///
/// The digest is what callers store in [`CursorPayload::args_fingerprint`].
/// Equal values hash equally only if serialization is deterministic. With
/// serde_json's default map type, object keys serialize in sorted order.
/// NOTE(review): the `preserve_order` feature would make key order significant;
/// confirm it is not enabled.
pub fn args_fingerprint(args: &serde_json::Value) -> [u8; 32] {
    use sha2::Digest;
    // Serializing a `Value` is not expected to fail; if it ever does, the
    // empty byte string is hashed instead.
    let serialized = serde_json::to_vec(args).unwrap_or_default();
    Sha256::digest(serialized).into()
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Current wall-clock time in whole seconds since the Unix epoch.
    fn now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Issuer with a fresh random key (and therefore its own session nonce).
    fn fresh_issuer() -> CursorIssuer {
        CursorIssuer::new(random_signing_key())
    }

    /// Payload bound to `issuer`'s session and stamped with the current time.
    fn mk_payload(issuer: &CursorIssuer, page: u32) -> CursorPayload {
        CursorPayload {
            tool_id: "celia:fhir.list_observations".into(),
            caller_id: Some("test-caller".into()),
            args_fingerprint: args_fingerprint(&json!({"patient": "p1"})),
            page_index: page,
            issued_at_unix: now(),
            server_session: issuer.session_nonce(),
            opaque_state: vec![],
        }
    }

    #[test]
    fn issue_round_trips() {
        let issuer = fresh_issuer();
        let payload = mk_payload(&issuer, 2);
        let cursor = issuer.issue(payload.clone()).expect("issue");
        let back = issuer.verify(&cursor, 300).expect("verify");
        assert_eq!(back.tool_id, payload.tool_id);
        assert_eq!(back.page_index, 2);
        assert_eq!(back.args_fingerprint, payload.args_fingerprint);
    }

    #[test]
    fn verify_rejects_tampered_signature() {
        let issuer = fresh_issuer();
        let cursor = issuer.issue(mk_payload(&issuer, 1)).unwrap();
        // Flip one base64 character inside the trailing 32-byte tag. The body
        // is untouched, so the restart probe sees a matching session nonce and
        // the failure must surface as a signature error.
        let mut bytes: Vec<u8> = cursor.bytes().collect();
        let target = bytes.len() - 16;
        bytes[target] = if bytes[target] == b'A' { b'B' } else { b'A' };
        let tampered = String::from_utf8(bytes).unwrap();
        match issuer.verify(&tampered, 300) {
            Err(CursorError::InvalidSignature) => {}
            other => panic!("expected InvalidSignature, got {other:?}"),
        }
    }

    #[test]
    fn verify_rejects_after_ttl() {
        let issuer = fresh_issuer();
        let mut payload = mk_payload(&issuer, 1);
        payload.issued_at_unix = now().saturating_sub(400);
        let cursor = issuer.issue(payload).unwrap();
        match issuer.verify(&cursor, 300) {
            Err(CursorError::Expired) => {}
            other => panic!("expected Expired, got {other:?}"),
        }
    }

    #[test]
    fn verify_rejects_wrong_session_nonce() {
        // Same key, different instances: the tag verifies, the session does not.
        let key = random_signing_key();
        let issuer_a = CursorIssuer::new(key);
        let issuer_b = CursorIssuer::new(key);
        assert_ne!(issuer_a.session_nonce(), issuer_b.session_nonce());
        let cursor = issuer_a.issue(mk_payload(&issuer_a, 1)).unwrap();
        match issuer_b.verify(&cursor, 300) {
            Err(CursorError::Expired) => {}
            other => panic!("expected Expired (cross-session), got {other:?}"),
        }
    }

    #[test]
    fn verify_treats_server_restart_as_expired_not_forgery() {
        // Different keys and sessions: models a restart with a random key.
        let issuer_a = fresh_issuer();
        let issuer_b = fresh_issuer();
        let cursor = issuer_a.issue(mk_payload(&issuer_a, 1)).unwrap();
        match issuer_b.verify(&cursor, 300) {
            Err(CursorError::Expired) => {}
            other => panic!("expected Expired (server restart), got {other:?}"),
        }
    }

    #[test]
    fn verify_rejects_real_forgery_with_invalid_signature() {
        let issuer = fresh_issuer();
        let original = issuer.issue(mk_payload(&issuer, 1)).unwrap();
        let combined = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(&original)
            .unwrap();
        let (_orig_body, orig_tag) = combined.split_at(combined.len() - 32);
        // A forged body that claims the *current* session must not be excused
        // as "expired" by the restart probe.
        let tampered_payload = CursorPayload {
            tool_id: "test:tool".into(),
            caller_id: Some("test-caller".into()),
            args_fingerprint: [0u8; 32],
            page_index: 999,
            issued_at_unix: now(),
            server_session: issuer.session_nonce(),
            opaque_state: vec![],
        };
        let mut tampered_body = Vec::new();
        ciborium::into_writer(&tampered_payload, &mut tampered_body).unwrap();
        let mut forged = tampered_body;
        forged.extend_from_slice(orig_tag);
        let forged_cursor = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&forged);
        match issuer.verify(&forged_cursor, 300) {
            Err(CursorError::InvalidSignature) => {}
            other => panic!("expected InvalidSignature (real forgery), got {other:?}"),
        }
    }

    #[test]
    fn verify_unparseable_body_after_hmac_fail_is_invalid_signature() {
        let issuer = fresh_issuer();
        // 0xff is the CBOR "break" byte, so the probe cannot decode the body.
        let garbage = vec![0xffu8; 64];
        let garbage_cursor = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&garbage);
        match issuer.verify(&garbage_cursor, 300) {
            Err(CursorError::InvalidSignature) => {}
            other => panic!("expected InvalidSignature for unparseable, got {other:?}"),
        }
    }

    #[test]
    fn issue_rejects_oversized_opaque_state() {
        let issuer = fresh_issuer();
        let mut payload = mk_payload(&issuer, 1);
        payload.opaque_state = vec![0u8; MAX_OPAQUE_STATE_BYTES + 1];
        match issuer.issue(payload) {
            Err(CursorError::OpaqueStateTooLarge(n)) => {
                assert_eq!(n, MAX_OPAQUE_STATE_BYTES + 1)
            }
            other => panic!("expected OpaqueStateTooLarge, got {other:?}"),
        }
    }

    #[test]
    fn issue_rejects_oversized_payload_via_long_tool_id() {
        let issuer = fresh_issuer();
        let mut payload = mk_payload(&issuer, 1);
        payload.tool_id = "x".repeat(400);
        payload.opaque_state = vec![0u8; MAX_OPAQUE_STATE_BYTES];
        match issuer.issue(payload) {
            Err(CursorError::TooLarge(n)) => {
                assert!(n > MAX_CURSOR_BYTES, "got {n}");
            }
            other => panic!("expected TooLarge, got {other:?}"),
        }
    }

    #[test]
    fn verify_rejects_malformed_base64() {
        let issuer = fresh_issuer();
        match issuer.verify("not!base64!", 300) {
            Err(CursorError::Format(_)) => {}
            other => panic!("expected Format error, got {other:?}"),
        }
    }

    #[test]
    fn verify_rejects_too_short_combined() {
        let issuer = fresh_issuer();
        let cursor = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b"abcdef");
        match issuer.verify(&cursor, 300) {
            Err(CursorError::Format(_)) => {}
            other => panic!("expected Format (missing HMAC tag), got {other:?}"),
        }
    }

    #[test]
    fn args_fingerprint_stable_for_same_value() {
        let a = args_fingerprint(&json!({"x": 1, "y": [2, 3]}));
        let b = args_fingerprint(&json!({"x": 1, "y": [2, 3]}));
        assert_eq!(a, b);
    }

    #[test]
    fn args_fingerprint_differs_for_different_value() {
        let a = args_fingerprint(&json!({"x": 1}));
        let b = args_fingerprint(&json!({"x": 2}));
        assert_ne!(a, b);
    }

    #[test]
    fn cursor_under_cap_for_typical_payload() {
        let issuer = fresh_issuer();
        let payload = mk_payload(&issuer, 1);
        let cursor = issuer.issue(payload).unwrap();
        assert!(
            cursor.len() < MAX_CURSOR_BYTES,
            "typical cursor over cap: {} > {}",
            cursor.len(),
            MAX_CURSOR_BYTES
        );
    }

    #[test]
    fn opaque_state_round_trips() {
        let issuer = fresh_issuer();
        let mut payload = mk_payload(&issuer, 1);
        payload.opaque_state = b"keyset:last_id=42,page=3".to_vec();
        let cursor = issuer.issue(payload.clone()).unwrap();
        let back = issuer.verify(&cursor, 300).unwrap();
        assert_eq!(back.opaque_state, payload.opaque_state);
    }

    /// Serializes every test that mutates `ATD_CURSOR_SIGNING_KEY`; a poisoned
    /// lock is recovered because the guarded data is `()`.
    fn signing_key_env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
        LOCK.get_or_init(|| std::sync::Mutex::new(()))
            .lock()
            .unwrap_or_else(|p| p.into_inner())
    }

    #[test]
    fn signing_key_from_env_reads_base64url_no_pad() {
        let _g = signing_key_env_lock();
        let want = [7u8; 32];
        let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(want);
        // SAFETY: the env lock serializes all tests that touch this variable, and
        // no other test in this module reads the environment.
        unsafe { std::env::set_var("ATD_CURSOR_SIGNING_KEY", &encoded) };
        let got = signing_key_from_env_or_random();
        // SAFETY: as above.
        unsafe { std::env::remove_var("ATD_CURSOR_SIGNING_KEY") };
        assert_eq!(got, want, "env key must round-trip verbatim");
    }

    #[test]
    fn signing_key_from_env_falls_back_on_wrong_length() {
        let _g = signing_key_env_lock();
        let short_key = [1u8; 16];
        let bad = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(short_key);
        // SAFETY: the env lock serializes all tests that touch this variable, and
        // no other test in this module reads the environment.
        unsafe { std::env::set_var("ATD_CURSOR_SIGNING_KEY", &bad) };
        let first = signing_key_from_env_or_random();
        let second = signing_key_from_env_or_random();
        // SAFETY: as above.
        unsafe { std::env::remove_var("ATD_CURSOR_SIGNING_KEY") };
        // Two calls yielding different keys proves the random fallback was taken
        // rather than some deterministic derivation of the short key.
        assert_ne!(first, second, "wrong-length key must fall back to a random key");
        assert_ne!(&first[..16], &short_key[..], "short key must not be reused as a prefix");
    }

    #[test]
    fn signing_key_from_env_falls_back_on_garbage() {
        let _g = signing_key_env_lock();
        // SAFETY: the env lock serializes all tests that touch this variable, and
        // no other test in this module reads the environment.
        unsafe { std::env::set_var("ATD_CURSOR_SIGNING_KEY", "not-base64!!") };
        let first = signing_key_from_env_or_random();
        let second = signing_key_from_env_or_random();
        // SAFETY: as above.
        unsafe { std::env::remove_var("ATD_CURSOR_SIGNING_KEY") };
        assert_ne!(first, second, "undecodable key must fall back to a random key");
    }

    #[test]
    fn signing_key_from_env_falls_back_when_unset() {
        let _g = signing_key_env_lock();
        // SAFETY: the env lock serializes all tests that touch this variable, and
        // no other test in this module reads the environment.
        unsafe { std::env::remove_var("ATD_CURSOR_SIGNING_KEY") };
        let a = signing_key_from_env_or_random();
        let b = signing_key_from_env_or_random();
        assert_ne!(a, b, "random fallback should yield distinct keys");
    }

    #[test]
    fn ttl_zero_rejects_freshly_issued() {
        let issuer = fresh_issuer();
        let mut payload = mk_payload(&issuer, 1);
        // `verify` measures age in whole seconds and rejects only when the age
        // strictly exceeds the TTL. Backdating by one second makes the age
        // deterministically positive without sleeping.
        payload.issued_at_unix = now().saturating_sub(1);
        let cursor = issuer.issue(payload).unwrap();
        match issuer.verify(&cursor, 0) {
            Err(CursorError::Expired) => {}
            other => panic!("expected Expired with ttl=0, got {other:?}"),
        }
    }
}