use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use std::collections::{HashMap, HashSet};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Identifier of a task: `"T-"` followed by 16 hex digits from [`uid_hex`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TaskId(pub String);

impl TaskId {
    /// Generates a fresh task id that is unique within this process.
    pub fn new() -> Self {
        let suffix = uid_hex(8);
        TaskId(format!("T-{}", suffix))
    }
}

impl Default for TaskId {
    /// Same as [`TaskId::new`]: every default value is a newly generated id.
    fn default() -> Self {
        TaskId::new()
    }
}

impl std::fmt::Display for TaskId {
    /// Writes the raw id string; width/alignment flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Identifier of a session: `"S-"` followed by 16 hex digits from [`uid_hex`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SessionId(pub String);

impl SessionId {
    /// Generates a fresh session id that is unique within this process.
    pub fn new() -> Self {
        let suffix = uid_hex(8);
        SessionId(format!("S-{}", suffix))
    }
}

impl Default for SessionId {
    /// Same as [`SessionId::new`]: every default value is a newly generated id.
    fn default() -> Self {
        SessionId::new()
    }
}
/// Identifier of a worker: `"W-"` followed by 16 hex digits from [`uid_hex`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct WorkerId(pub String);

impl WorkerId {
    /// Generates a fresh worker id that is unique within this process.
    pub fn new() -> Self {
        let suffix = uid_hex(8);
        WorkerId(format!("W-{}", suffix))
    }
}

impl Default for WorkerId {
    /// Same as [`WorkerId::new`]: every default value is a newly generated id.
    fn default() -> Self {
        WorkerId::new()
    }
}
/// Kind of principal a [`CapToken`] is issued to; the key of [`RoleVerbGate`].
///
/// Serialized in `snake_case` (e.g. `"operator"`). The `Debug` rendering (the
/// bare variant name) is embedded in [`CapToken::signing_input`], so renaming
/// a variant invalidates the signature of every token already minted for it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
    Operator,
    Worker,
    Observer,
    Senior,
}
/// An action whose use is authorized per [`Role`] by [`RoleVerbGate`].
///
/// Serialized in `snake_case` (e.g. `"start_task"`). The grouping comments
/// below mirror the stock lists that [`default_role_verb_table`] installs; a
/// verb may belong to more than one list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Verb {
    // In OPERATOR_VERBS (several are also in WORKER_SWARM_VERBS; ReadTaskState
    // is also in OBSERVER_VERBS).
    StartTask,
    DispatchAttempt,
    MintWorkerToken,
    ReadTaskState,
    CancelTask,
    QuerySenior,
    MarkPass,
    MarkBlocked,
    AttachSession,
    DetachSession,
    Heartbeat,
    PollTask,
    // In WORKER_LEAF_VERBS.
    FetchPrompt,
    FetchData,
    PostResult,
    VerifyToken,
    EmitOutput,
    // In OBSERVER_VERBS.
    SubscribeEvents,
    ReadTrace,
    // In SENIOR_VERBS.
    AnswerQuery,
    OverrideVerdict,
    PauseLoop,
    ResumeLoop,
    InjectDirective,
}
/// Allow-list mapping each [`Role`] to the set of [`Verb`]s it may invoke.
///
/// Everything not explicitly granted is denied.
#[derive(Debug, Clone)]
pub struct RoleVerbGate {
    table: HashMap<Role, HashSet<Verb>>,
}

impl RoleVerbGate {
    /// Creates a gate with no grants at all, so every check is denied.
    ///
    /// This intentionally differs from [`RoleVerbGate::default`], which returns
    /// the populated stock policy.
    pub fn new() -> Self {
        RoleVerbGate {
            table: HashMap::new(),
        }
    }

    /// Builder step: adds every verb in `verbs` to `role`'s grants.
    ///
    /// Repeated calls for the same role accumulate. Even an empty `verbs`
    /// slice registers the role, with an empty grant set.
    pub fn allow(mut self, role: Role, verbs: &[Verb]) -> Self {
        self.table
            .entry(role)
            .or_default()
            .extend(verbs.iter().copied());
        self
    }

    /// Returns whether `role` has been granted `verb`; unknown roles are denied.
    pub fn is_allowed(&self, role: Role, verb: Verb) -> bool {
        match self.table.get(&role) {
            Some(granted) => granted.contains(&verb),
            None => false,
        }
    }
}

impl Default for RoleVerbGate {
    /// The stock policy produced by [`default_role_verb_table`].
    fn default() -> Self {
        default_role_verb_table()
    }
}
/// Verbs granted to [`Role::Operator`] by [`default_role_verb_table`].
pub const OPERATOR_VERBS: &[Verb] = &[
    Verb::StartTask,
    Verb::DispatchAttempt,
    Verb::MintWorkerToken,
    Verb::ReadTaskState,
    Verb::CancelTask,
    Verb::QuerySenior,
    Verb::MarkPass,
    Verb::MarkBlocked,
    Verb::AttachSession,
    Verb::DetachSession,
    Verb::Heartbeat,
    Verb::PollTask,
];
/// Verbs granted to [`Role::Worker`] by [`default_role_verb_table`], together
/// with [`WORKER_SWARM_VERBS`].
// NOTE(review): the name suggests the actions of a worker executing a single
// attempt; confirm against the worker runtime.
pub const WORKER_LEAF_VERBS: &[Verb] = &[
    Verb::FetchPrompt,
    Verb::FetchData,
    Verb::PostResult,
    Verb::VerifyToken,
    Verb::EmitOutput,
];
/// Additional verbs granted to [`Role::Worker`]: a subset of
/// [`OPERATOR_VERBS`] that notably excludes `MintWorkerToken`.
// NOTE(review): presumably lets a worker start and manage sub-tasks; verify intent.
pub const WORKER_SWARM_VERBS: &[Verb] = &[
    Verb::StartTask,
    Verb::DispatchAttempt,
    Verb::ReadTaskState,
    Verb::PollTask,
    Verb::CancelTask,
];
/// Verbs granted to [`Role::Observer`]: read-only event, trace and task-state access.
pub const OBSERVER_VERBS: &[Verb] = &[Verb::SubscribeEvents, Verb::ReadTrace, Verb::ReadTaskState];
/// Verbs granted to [`Role::Senior`] by [`default_role_verb_table`].
pub const SENIOR_VERBS: &[Verb] = &[
    Verb::AnswerQuery,
    Verb::OverrideVerdict,
    Verb::PauseLoop,
    Verb::ResumeLoop,
    Verb::InjectDirective,
];
pub fn default_role_verb_table() -> RoleVerbGate {
RoleVerbGate::new()
.allow(Role::Operator, OPERATOR_VERBS)
.allow(Role::Worker, WORKER_LEAF_VERBS)
.allow(Role::Worker, WORKER_SWARM_VERBS)
.allow(Role::Observer, OBSERVER_VERBS)
.allow(Role::Senior, SENIOR_VERBS)
}
/// A signed, self-contained capability token.
///
/// Created by [`TokenSigner::mint`] and its helpers, and checked with
/// [`TokenSigner::verify_sig`]. [`CapToken::encode`] and [`CapToken::decode`]
/// carry it as URL-safe, unpadded base64 of its JSON form. Serde writes fields
/// in declaration order, so reordering them changes the encoded string.
// NOTE(review): the derived `Debug` prints `sig_hex`, and the token is a bearer
// credential. Avoid logging tokens with `{:?}`, or consider redacting.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CapToken {
    /// Identity the token was issued to.
    pub agent_id: String,
    /// Role claimed by the holder; covered by the signature.
    pub role: Role,
    /// Scope strings; opaque to `CapToken` and `TokenSigner` (interpreted elsewhere).
    pub scopes: Vec<String>,
    /// Unix time in whole seconds at which the token was minted.
    pub issued_at: u64,
    /// Unix time in whole seconds; the token is expired once `now >= expire_at`.
    pub expire_at: u64,
    /// `Some(n)` for use-limited tokens, `None` for session tokens.
    // NOTE(review): not enforced by `CapToken`/`TokenSigner`; confirm a caller tracks uses.
    pub max_uses: Option<u32>,
    /// Random hex string (32 chars when minted); also serves as the token's `id()`.
    pub nonce: String,
    /// Hex HMAC-SHA256 over `signing_input()`; excluded from the signed bytes.
    pub sig_hex: String,
}
impl CapToken {
    /// Stable identifier of this token: its random nonce.
    pub fn id(&self) -> &str {
        self.nonce.as_str()
    }

    /// Bytes covered by the HMAC signature.
    ///
    /// Every field except `sig_hex` is joined with `|`. `scopes` are joined
    /// with `,`, and `role` and `max_uses` are rendered with `Debug`.
    ///
    /// NOTE(review): the delimiters are not escaped, so an `agent_id` or scope
    /// containing `|` or `,` makes this encoding ambiguous. Changing the format
    /// would invalidate tokens that are already minted, so keep those
    /// characters out of the fields.
    pub fn signing_input(&self) -> Vec<u8> {
        let joined_scopes = self.scopes.join(",");
        let rendered = format!(
            "{}|{:?}|{}|{}|{}|{:?}|{}",
            self.agent_id,
            self.role,
            joined_scopes,
            self.issued_at,
            self.expire_at,
            self.max_uses,
            self.nonce,
        );
        rendered.into_bytes()
    }

    /// Returns `true` once `now_unix` (Unix seconds) has reached `expire_at`.
    pub fn is_expired(&self, now_unix: u64) -> bool {
        self.expire_at <= now_unix
    }

    /// Serializes the token to JSON and encodes it as URL-safe, unpadded base64.
    ///
    /// # Panics
    /// Never in practice: every field is plainly JSON-serializable.
    pub fn encode(&self) -> String {
        use base64::Engine as _;
        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let json = serde_json::to_vec(self).expect("CapToken is always JSON-serializable");
        engine.encode(json)
    }

    /// Parses a string produced by [`CapToken::encode`].
    ///
    /// The signature is not checked here; use `TokenSigner::verify_sig`.
    ///
    /// # Errors
    /// [`CapTokenDecodeError::Base64`] if `s` is not URL-safe unpadded base64.
    /// [`CapTokenDecodeError::Json`] if the decoded bytes are not a `CapToken`.
    pub fn decode(s: &str) -> Result<Self, CapTokenDecodeError> {
        use base64::Engine as _;
        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let raw = match engine.decode(s) {
            Ok(raw) => raw,
            Err(e) => return Err(CapTokenDecodeError::Base64(e.to_string())),
        };
        serde_json::from_slice::<Self>(&raw).map_err(|e| CapTokenDecodeError::Json(e.to_string()))
    }
}
/// Serializable payload describing one attempt of a task for a worker.
// NOTE(review): field meanings below are inferred from names; confirm with producers/consumers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WorkerPayload {
    /// Task identifier, as a plain string.
    pub task_id: String,
    /// Attempt number for this task (base, 0 or 1, not established here).
    pub attempt: u32,
    /// Name of the agent expected to handle the attempt.
    pub agent: String,
    /// Optional system text. The key is omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Prompt text for the attempt.
    pub prompt: String,
}
/// Why [`CapToken::decode`] rejected its input.
///
/// Each variant carries the underlying error's `Display` text.
#[derive(Debug, thiserror::Error)]
pub enum CapTokenDecodeError {
    /// The input was not valid URL-safe, unpadded base64.
    #[error("base64 decode failed: {0}")]
    Base64(String),
    /// The decoded bytes were not JSON matching the `CapToken` shape.
    #[error("json parse failed: {0}")]
    Json(String),
}
/// Mints and verifies HMAC-SHA256-signed [`CapToken`]s using a shared secret.
///
/// `Debug` is implemented by hand instead of derived, so printing a signer
/// never exposes the key material.
#[derive(Clone)]
pub struct TokenSigner {
    secret: Vec<u8>,
}

impl std::fmt::Debug for TokenSigner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit `secret`: a derived Debug would dump the HMAC key
        // into logs, panic messages and any `{:?}` output.
        f.debug_struct("TokenSigner").finish_non_exhaustive()
    }
}

impl TokenSigner {
    /// Creates a signer keyed with a copy of `secret`.
    ///
    /// HMAC accepts any key length. Supplying a high-entropy secret is the
    /// caller's job; an empty secret is accepted but provides no security.
    pub fn new(secret: impl AsRef<[u8]>) -> Self {
        Self {
            secret: secret.as_ref().to_vec(),
        }
    }

    /// HMAC-SHA256 tag over `token.signing_input()` under this signer's key.
    fn tag(&self, token: &CapToken) -> Vec<u8> {
        let mut mac =
            Hmac::<Sha256>::new_from_slice(&self.secret).expect("HMAC accepts any key length");
        mac.update(&token.signing_input());
        mac.finalize().into_bytes().to_vec()
    }

    /// Returns whether `token.signing_input()` maps back to exactly one set of fields.
    ///
    /// `signing_input` joins fields with `|` and scopes with `,`, without
    /// escaping. If those characters are allowed, a holder can move text
    /// between fields and keep the same MAC. For example, agent `"x|Senior"`
    /// with role `Worker` and scope `"s"` has the same signing input as agent
    /// `"x"` with role `Senior` and scope `"Worker|s"`.
    fn has_unambiguous_fields(token: &CapToken) -> bool {
        let is_delimiter = |c: char| c == '|' || c == ',';
        !token.agent_id.contains('|') && !token.scopes.iter().any(|s| s.contains(is_delimiter))
    }

    /// Mints and signs a token for `agent_id`.
    ///
    /// - `issued_at` is the current Unix time in whole seconds.
    /// - `expire_at` is `issued_at + ttl`, with `ttl` truncated to whole
    ///   seconds and the sum saturating at `u64::MAX`, so `Duration::MAX`
    ///   cannot overflow. A sub-second `ttl` therefore yields a token that is
    ///   already expired.
    /// - `nonce` is 16 random bytes from the OS RNG, hex-encoded.
    /// - `sig_hex` is the hex HMAC-SHA256 of [`CapToken::signing_input`].
    ///
    /// A token whose `agent_id` contains `|`, or whose scopes contain `|` or
    /// `,`, is still minted, but [`TokenSigner::verify_sig`] will reject it.
    ///
    /// # Panics
    /// If the system clock is before the Unix epoch or the OS RNG is unavailable.
    pub fn mint(
        &self,
        agent_id: impl Into<String>,
        role: Role,
        scopes: Vec<String>,
        ttl: Duration,
        max_uses: Option<u32>,
    ) -> CapToken {
        let now = now_unix();
        let mut token = CapToken {
            agent_id: agent_id.into(),
            role,
            scopes,
            issued_at: now,
            // Saturate rather than overflow: an unchecked add panics in debug
            // builds and wraps to an already-expired timestamp in release.
            expire_at: now.saturating_add(ttl.as_secs()),
            max_uses,
            nonce: secure_hex(16),
            sig_hex: String::new(),
        };
        let sig = self.tag(&token);
        token.sig_hex = hex::encode(sig);
        token
    }

    /// Returns `true` iff `token.sig_hex` is a valid HMAC under this signer's
    /// secret for the token's current field values.
    ///
    /// Only the signature is checked. Expiry ([`CapToken::is_expired`]) and
    /// `max_uses` are left to the caller. A token whose fields contain the
    /// unescaped `signing_input` delimiters is rejected outright, because its
    /// MAC does not determine a single field assignment.
    pub fn verify_sig(&self, token: &CapToken) -> bool {
        if !Self::has_unambiguous_fields(token) {
            return false;
        }
        let Ok(provided) = hex::decode(&token.sig_hex) else {
            return false;
        };
        ct_eq(&self.tag(token), &provided)
    }

    /// Mints a single-use token (`max_uses = Some(1)`).
    pub fn one_time(
        &self,
        agent_id: impl Into<String>,
        role: Role,
        scopes: Vec<String>,
        ttl: Duration,
    ) -> CapToken {
        self.mint(agent_id, role, scopes, ttl, Some(1))
    }

    /// Mints a token with no use limit (`max_uses = None`), bounded only by `ttl`.
    pub fn session(
        &self,
        agent_id: impl Into<String>,
        role: Role,
        scopes: Vec<String>,
        ttl: Duration,
    ) -> CapToken {
        self.mint(agent_id, role, scopes, ttl, None)
    }

    /// Mints a token that records `max_uses` as its use limit.
    pub fn limited(
        &self,
        agent_id: impl Into<String>,
        role: Role,
        scopes: Vec<String>,
        ttl: Duration,
        max_uses: u32,
    ) -> CapToken {
        self.mint(agent_id, role, scopes, ttl, Some(max_uses))
    }
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
/// If the system clock reports a time before the Unix epoch.
pub(crate) fn now_unix() -> u64 {
    let since_epoch = UNIX_EPOCH
        .elapsed()
        .expect("system clock is before UNIX_EPOCH");
    since_epoch.as_secs()
}
/// Returns a cheap, non-secret identifier of `min(bytes * 2, 32)` hex digits.
///
/// A per-process random 128-bit salt, drawn once from the OS RNG, is XORed
/// with an incrementing counter. The low hex digits of the result are kept.
/// Within a process, the first `2^(8 * bytes)` calls yield distinct values
/// (every call does, for `bytes >= 8`). Consecutive ids differ only in their
/// low bits and are predictable, so use [`secure_hex`] for anything secret.
///
/// # Panics
/// If the OS RNG is unavailable on first use.
pub fn uid_hex(bytes: usize) -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::OnceLock;

    static NEXT_TICKET: AtomicU64 = AtomicU64::new(0);
    static PROCESS_SALT: OnceLock<u128> = OnceLock::new();

    let salt = *PROCESS_SALT.get_or_init(|| {
        let mut seed = [0u8; 16];
        getrandom::fill(&mut seed).expect("OS RNG unavailable");
        u128::from_le_bytes(seed)
    });
    let ticket = u128::from(NEXT_TICKET.fetch_add(1, Ordering::Relaxed));
    // `{:032x}` of a u128 is always exactly 32 characters long.
    let full = format!("{:032x}", salt ^ ticket);
    let digits = (bytes * 2).min(32);
    full[full.len() - digits..].to_owned()
}
/// Returns `bytes` bytes from the OS RNG as a lowercase hex string
/// (`2 * bytes` characters). Suitable for nonces and other secrets.
///
/// # Panics
/// If the OS RNG is unavailable.
pub fn secure_hex(bytes: usize) -> String {
    let mut entropy = vec![0u8; bytes];
    getrandom::fill(entropy.as_mut_slice()).expect("OS RNG unavailable");
    hex::encode(&entropy)
}
/// Byte-slice equality that does not stop at the first differing byte.
///
/// This keeps timing from revealing how much of a MAC matched. Slices of
/// different length compare unequal immediately, so only the length can leak.
/// This is best effort: nothing here prevents the optimizer from
/// reintroducing an early exit.
fn ct_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let acc = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        acc == 0
    } else {
        false
    }
}
#[cfg(test)]
mod cap_token_transport_tests {
    use super::*;
    use std::time::Duration;

    /// Mints the 10-minute worker session token shared by the tests below.
    fn sample_token(signer: &TokenSigner) -> CapToken {
        signer.session(
            "worker-of-task-x",
            Role::Worker,
            vec!["*".into()],
            Duration::from_secs(600),
        )
    }

    #[test]
    fn encode_decode_round_trips() {
        let signer = TokenSigner::new("test-secret");
        let token = sample_token(&signer);
        let s = token.encode();
        // URL_SAFE_NO_PAD must never emit standard-alphabet or padding characters.
        assert!(!s.contains('+'));
        assert!(!s.contains('/'));
        assert!(!s.contains('='));
        let decoded = CapToken::decode(&s).expect("decode ok");
        assert_eq!(decoded, token);
        assert!(
            signer.verify_sig(&decoded),
            "HMAC sig still verifies after round-trip"
        );
    }

    #[test]
    fn decode_rejects_garbage() {
        let err = CapToken::decode("not-base64!!!").expect_err("should fail");
        assert!(matches!(err, CapTokenDecodeError::Base64(_)));
    }

    #[test]
    fn decode_rejects_non_token_json() {
        use base64::Engine as _;
        let bogus = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b"{\"oops\":1}");
        let err = CapToken::decode(&bogus).expect_err("should fail json shape");
        assert!(matches!(err, CapTokenDecodeError::Json(_)));
    }

    /// Every signed field, the signature itself, and the key must all be load-bearing.
    #[test]
    fn verify_rejects_tampered_tokens() {
        let signer = TokenSigner::new("test-secret");
        let token = sample_token(&signer);
        assert!(signer.verify_sig(&token), "freshly minted token verifies");

        let mut escalated = token.clone();
        escalated.role = Role::Senior;
        assert!(!signer.verify_sig(&escalated), "role change must break the MAC");

        let mut extended = token.clone();
        extended.expire_at += 3600;
        assert!(!signer.verify_sig(&extended), "expiry change must break the MAC");

        let mut widened = token.clone();
        widened.scopes.push("admin".into());
        assert!(!signer.verify_sig(&widened), "scope change must break the MAC");

        let mut garbled = token.clone();
        garbled.sig_hex = "not-hex".into();
        assert!(!signer.verify_sig(&garbled), "non-hex signature is rejected");

        let other = TokenSigner::new("other-secret");
        assert!(!other.verify_sig(&token), "a different secret must not verify");
    }

    #[test]
    fn expiry_boundary_is_inclusive() {
        let signer = TokenSigner::new("test-secret");
        let token = sample_token(&signer);
        assert!(!token.is_expired(token.expire_at - 1));
        assert!(token.is_expired(token.expire_at));
    }

    #[test]
    fn mint_helpers_set_max_uses() {
        let signer = TokenSigner::new("test-secret");
        let ttl = Duration::from_secs(60);
        let once = signer.one_time("a", Role::Observer, vec![], ttl);
        let few = signer.limited("a", Role::Observer, vec![], ttl, 3);
        assert_eq!(once.max_uses, Some(1));
        assert_eq!(few.max_uses, Some(3));
        assert_eq!(sample_token(&signer).max_uses, None);
    }
}