//! Fundamental types: Role / Verb / RoleVerbGate / CapToken / IDs.

use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use std::collections::{HashMap, HashSet};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

9// ─── ID newtypes ───────────────────────────────────────────────────────────
10
11/// Opaque task identifier, e.g. `T-<hex>`. Newtype over `String` so task,
12/// session, and worker ids can't be swapped by accident at call sites.
13#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct TaskId(pub String);
15
16impl TaskId {
17    /// Mint a fresh id with the `T-` prefix and a process-unique nonce.
18    pub fn new() -> Self {
19        Self(format!("T-{}", uid_hex(8)))
20    }
21}
22
23impl Default for TaskId {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl std::fmt::Display for TaskId {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        write!(f, "{}", self.0)
32    }
33}
34
35/// Opaque session identifier, e.g. `S-<hex>`. See [`TaskId`] for the newtype
36/// rationale.
37#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
38pub struct SessionId(pub String);
39
40impl SessionId {
41    /// Mint a fresh id with the `S-` prefix and a process-unique nonce.
42    pub fn new() -> Self {
43        Self(format!("S-{}", uid_hex(8)))
44    }
45}
46
47impl Default for SessionId {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53/// Opaque worker identifier, e.g. `W-<hex>`. See [`TaskId`] for the newtype
54/// rationale.
55#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
56pub struct WorkerId(pub String);
57
58impl WorkerId {
59    /// Mint a fresh id with the `W-` prefix and a process-unique nonce.
60    pub fn new() -> Self {
61        Self(format!("W-{}", uid_hex(8)))
62    }
63}
64
65impl Default for WorkerId {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
// ─── Role × Verb ───────────────────────────────────────────────────────────

/// The four participant roles in the swarm. Every [`Verb`] a caller wants to
/// invoke must be allow-listed for its role in a [`RoleVerbGate`].
///
/// Serialized in `snake_case` (`"operator"`, `"worker"`, `"observer"`,
/// `"senior"`). The derived `Debug` form (the bare variant name) is also
/// part of [`CapToken::signing_input`], so renaming a variant invalidates
/// every outstanding token signature minted for that role.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
    /// Drives task lifecycle: starts tasks, dispatches attempts, reads
    /// state, manages sessions.
    Operator,
    /// Executes a dispatched attempt: fetches its prompt/data, posts a
    /// result, verifies its own token.
    Worker,
    /// Read-only: subscribes to events and reads trace/state without
    /// mutating anything.
    Observer,
    /// Human/oversight role: answers queries, overrides verdicts, and can
    /// pause/resume the loop or inject a directive.
    Senior,
}

/// Every action a participant can request. Grouped by the [`Role`] that
/// typically performs it (see the `// operator` / `// worker` / ... section
/// comments below); the grouping is documentation only — actual
/// authorization is decided by [`RoleVerbGate::is_allowed`].
///
/// Serialized in `snake_case` (e.g. `StartTask` ↔ `"start_task"`). Under
/// [`default_role_verb_table`] several verbs belong to more than one role:
/// `StartTask`, `DispatchAttempt`, `ReadTaskState`, `PollTask`, and
/// `CancelTask` are granted to both Operator and Worker, and
/// `ReadTaskState` to Observer as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Verb {
    // operator
    /// Create a new task.
    StartTask,
    /// Dispatch (or re-dispatch) an attempt for a task.
    DispatchAttempt,
    /// Mint a [`CapToken`] for a worker.
    MintWorkerToken,
    /// Read the current state of a task.
    ReadTaskState,
    /// Cancel a task.
    CancelTask,
    /// Ask a [`Role::Senior`] a question about a task.
    QuerySenior,
    /// Mark a task/attempt as passed.
    MarkPass,
    /// Mark a task/attempt as blocked.
    MarkBlocked,
    /// Attach a session to a task.
    AttachSession,
    /// Detach a session from a task.
    DetachSession,
    /// Emit a liveness heartbeat.
    Heartbeat,
    /// Poll for task progress/completion.
    PollTask,
    // worker
    /// Fetch the rendered prompt for the current attempt.
    FetchPrompt,
    /// Fetch task input data.
    FetchData,
    /// Post the result of an attempt.
    PostResult,
    /// Verify a presented [`CapToken`].
    VerifyToken,
    /// Emit intermediate output for observers.
    EmitOutput,
    // observer
    /// Subscribe to the task's event stream.
    SubscribeEvents,
    /// Read the accumulated trace of a task.
    ReadTrace,
    // senior
    /// Answer a query raised via [`Verb::QuerySenior`].
    AnswerQuery,
    /// Override a previously recorded verdict.
    OverrideVerdict,
    /// Pause the dispatch loop.
    PauseLoop,
    /// Resume a paused dispatch loop.
    ResumeLoop,
    /// Inject a directive into a running task.
    InjectDirective,
}

153/// Role × Verb gate table. Const-style storage.
154#[derive(Debug, Clone)]
155pub struct RoleVerbGate {
156    table: HashMap<Role, HashSet<Verb>>,
157}
158
159impl RoleVerbGate {
160    /// Build an empty gate (nothing allowed until [`Self::allow`] is called).
161    pub fn new() -> Self {
162        Self {
163            table: HashMap::new(),
164        }
165    }
166
167    /// Allow-list `verbs` for `role`, merging with any existing entries.
168    /// Returns `self` for chained construction (see
169    /// [`default_role_verb_table`]).
170    pub fn allow(mut self, role: Role, verbs: &[Verb]) -> Self {
171        let set = self.table.entry(role).or_default();
172        for v in verbs {
173            set.insert(*v);
174        }
175        self
176    }
177
178    /// Whether `role` is allow-listed to invoke `verb`.
179    pub fn is_allowed(&self, role: Role, verb: Verb) -> bool {
180        self.table
181            .get(&role)
182            .map(|s| s.contains(&verb))
183            .unwrap_or(false)
184    }
185}
186
187impl Default for RoleVerbGate {
188    fn default() -> Self {
189        default_role_verb_table()
190    }
191}
192
// ─── Verb tables (const slices, swap-out points for future Role splits) ──

/// Verbs an Operator may invoke — covers task lifecycle, session, and
/// senior interactions.
pub const OPERATOR_VERBS: &[Verb] = &[
    Verb::StartTask,
    Verb::DispatchAttempt,
    Verb::MintWorkerToken,
    Verb::ReadTaskState,
    Verb::CancelTask,
    Verb::QuerySenior,
    Verb::MarkPass,
    Verb::MarkBlocked,
    Verb::AttachSession,
    Verb::DetachSession,
    Verb::Heartbeat,
    Verb::PollTask,
];

/// The Worker verbs shared across all workers — the minimum a leaf
/// needs, with no sub-task spawning. If we introduce
/// `Role::WorkerLeaf` in the future, that role gets allowed against
/// this slice.
pub const WORKER_LEAF_VERBS: &[Verb] = &[
    Verb::FetchPrompt,
    Verb::FetchData,
    Verb::PostResult,
    Verb::VerifyToken,
    Verb::EmitOutput,
];

/// Worker verbs for recursive swarming: sub-task spawn and
/// observation. When `Role::WorkerSwarm` splits out in the future,
/// that role gets allowed against `WORKER_LEAF_VERBS` plus this
/// slice. The safety valves are `EngineCfg.max_spawn_depth` today,
/// and a task-ownership gate down the line.
///
/// Every verb here is also in [`OPERATOR_VERBS`].
///
/// NOTE(review): this table grants `CancelTask` with no notion of which task
/// the worker owns. Confirm the engine scopes cancellation elsewhere; the
/// ownership gate mentioned above is future work.
pub const WORKER_SWARM_VERBS: &[Verb] = &[
    Verb::StartTask,
    Verb::DispatchAttempt,
    Verb::ReadTaskState,
    Verb::PollTask,
    Verb::CancelTask,
];

/// Verbs an Observer may invoke — strictly read-only (event subscription
/// and trace/state reads, no mutation).
pub const OBSERVER_VERBS: &[Verb] = &[Verb::SubscribeEvents, Verb::ReadTrace, Verb::ReadTaskState];

/// Verbs a Senior may invoke — human/oversight actions: answering
/// queries, overriding verdicts, and pausing/resuming/injecting into the
/// dispatch loop.
pub const SENIOR_VERBS: &[Verb] = &[
    Verb::AnswerQuery,
    Verb::OverrideVerdict,
    Verb::PauseLoop,
    Verb::ResumeLoop,
    Verb::InjectDirective,
];

252/// The default Role × Verb table.
253///
254/// Today `Role::Worker` holds both leaf and swarm capabilities. When
255/// we split it into `WorkerLeaf` / `WorkerSwarm` in the future, the
256/// only change needed is swapping the `allow(Role::Worker, ...)` line
257/// here for two lines — the verb slices themselves stay `const` and
258/// get reused as-is.
259pub fn default_role_verb_table() -> RoleVerbGate {
260    RoleVerbGate::new()
261        .allow(Role::Operator, OPERATOR_VERBS)
262        .allow(Role::Worker, WORKER_LEAF_VERBS)
263        .allow(Role::Worker, WORKER_SWARM_VERBS)
264        .allow(Role::Observer, OBSERVER_VERBS)
265        .allow(Role::Senior, SENIOR_VERBS)
266}
267
// ─── CapToken ──────────────────────────────────────────────────────────────

/// Capability token. `max_uses` picks between OneTime / Session /
/// Limited.
///
/// The `uses_left` counter is **server-side, on `EngineState`**: the
/// token stays immutable, and the record holds the counter.
///
/// Every field except `sig_hex` is covered by the HMAC (see
/// [`CapToken::signing_input`]). [`TokenSigner::verify_sig`] checks only
/// that signature, not expiry or remaining uses.
///
/// NOTE(review): the derived `Debug` prints `nonce` and `sig_hex`, i.e. the
/// whole bearer credential (compare [`CapToken::encode`]). Avoid
/// `{:?}`-logging tokens.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CapToken {
    /// Identifier of the agent this token was minted for.
    pub agent_id: String,
    /// The [`Role`] the bearer is authorized to act as.
    pub role: Role,
    /// Free-form scope strings (interpretation is caller-defined; `"*"`
    /// conventionally means unrestricted).
    pub scopes: Vec<String>,
    /// Unix timestamp (seconds) when the token was minted.
    pub issued_at: u64,
    /// Unix timestamp (seconds) at which the token expires:
    /// [`CapToken::is_expired`] reports `true` once `now >= expire_at`.
    pub expire_at: u64,
    /// Total use budget fixed at mint time: `None` = unlimited (session
    /// token), `Some(n)` = at most `n` uses (one-time when `n == 1`). The
    /// token itself never changes; the remaining count is tracked
    /// server-side.
    pub max_uses: Option<u32>,
    /// Random per-mint value; also serves as the token's server-side
    /// lookup key (see [`CapToken::id`]).
    pub nonce: String,
    /// Hex-encoded HMAC-SHA256 signature over [`CapToken::signing_input`].
    pub sig_hex: String,
}

298impl CapToken {
299    /// Use the `nonce` as the token identifier — the server-side
300    /// record key.
301    pub fn id(&self) -> &str {
302        &self.nonce
303    }
304
305    /// Input for the HMAC signature — the concatenation of every field
306    /// except `sig` itself.
307    pub fn signing_input(&self) -> Vec<u8> {
308        let s = format!(
309            "{}|{:?}|{}|{}|{}|{:?}|{}",
310            self.agent_id,
311            self.role,
312            self.scopes.join(","),
313            self.issued_at,
314            self.expire_at,
315            self.max_uses,
316            self.nonce,
317        );
318        s.into_bytes()
319    }
320
321    /// Whether `now_unix` is at or past [`CapToken::expire_at`].
322    pub fn is_expired(&self, now_unix: u64) -> bool {
323        now_unix >= self.expire_at
324    }
325
326    /// Transport-safe string encoding — URL-safe base64 of the
327    /// `serde_json` representation. Used when SubAgents put the token
328    /// on the HTTP path via `Authorization: Bearer <encode()>`. The
329    /// HMAC signature covers every field, so the server verifies with
330    /// `verify_sig` after decoding.
331    pub fn encode(&self) -> String {
332        use base64::Engine as _;
333        let json = serde_json::to_vec(self).expect("CapToken is always JSON-serializable");
334        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json)
335    }
336
337    /// The inverse of `encode()`: base64 decode followed by JSON
338    /// parse. Either failure returns `CapTokenDecodeError` — this is
339    /// the input-validation step when the server receives a `Bearer`
340    /// token.
341    pub fn decode(s: &str) -> Result<Self, CapTokenDecodeError> {
342        use base64::Engine as _;
343        let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
344            .decode(s)
345            .map_err(|e| CapTokenDecodeError::Base64(e.to_string()))?;
346        serde_json::from_slice(&bytes).map_err(|e| CapTokenDecodeError::Json(e.to_string()))
347    }
348}
349
/// Response body for `HTTP /v1/worker/prompt` — the shape that lets a
/// SubAgent pull its task input in a single round-trip.
///
/// - `system`: the rendered `AgentDef.profile.system_prompt` (`None`
///   when the profile is absent).
/// - `prompt`: `TaskSpec.initial_directive` — the value baked into the
///   prompts table during dispatch preparation.
/// - `agent`: `TaskSpec.agent` — the agent name this dispatch is
///   targeting.
/// - `attempt`: the 1-based attempt number, matching the current
///   `task.attempt`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WorkerPayload {
    /// The task this payload was fetched for.
    pub task_id: String,
    /// 1-based attempt number, matching the current `task.attempt`.
    pub attempt: u32,
    /// Name of the agent this dispatch is targeting.
    pub agent: String,
    /// Rendered system prompt, if the agent profile defines one. Omitted
    /// entirely from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// The task's initial directive, baked in at dispatch preparation.
    pub prompt: String,
}

/// Error returned when `CapToken::decode` fails.
///
/// Each variant carries the underlying error rendered as a `String`. A
/// successful decode says nothing about authenticity: the decoded token
/// must still pass `TokenSigner::verify_sig`.
#[derive(Debug, thiserror::Error)]
pub enum CapTokenDecodeError {
    /// The input was not valid URL-safe (unpadded) base64.
    #[error("base64 decode failed: {0}")]
    Base64(String),
    /// The decoded bytes were not valid `CapToken` JSON.
    #[error("json parse failed: {0}")]
    Json(String),
}

/// Server-side machinery for minting and verifying tokens.
///
/// Holds the raw HMAC secret. `Debug` is implemented by hand so the secret
/// never ends up in logs, panic messages, or the `Debug` output of any
/// struct that embeds a signer. It always prints
/// `TokenSigner { secret: "<redacted>" }`.
#[derive(Clone)]
pub struct TokenSigner {
    secret: Vec<u8>,
}

impl std::fmt::Debug for TokenSigner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never format `self.secret`: it is the HMAC key for every token.
        f.debug_struct("TokenSigner")
            .field("secret", &"<redacted>")
            .finish()
    }
}

393impl TokenSigner {
394    /// Build a signer from a raw HMAC secret (any length; HMAC accepts it).
395    pub fn new(secret: impl AsRef<[u8]>) -> Self {
396        Self {
397            secret: secret.as_ref().to_vec(),
398        }
399    }
400
401    /// Mint and sign a [`CapToken`] with an explicit `max_uses` policy.
402    /// Prefer [`Self::one_time`] / [`Self::session`] / [`Self::limited`]
403    /// for the common cases.
404    pub fn mint(
405        &self,
406        agent_id: impl Into<String>,
407        role: Role,
408        scopes: Vec<String>,
409        ttl: Duration,
410        max_uses: Option<u32>,
411    ) -> CapToken {
412        let now = now_unix();
413        let mut token = CapToken {
414            agent_id: agent_id.into(),
415            role,
416            scopes,
417            issued_at: now,
418            expire_at: now + ttl.as_secs(),
419            max_uses,
420            nonce: secure_hex(16),
421            sig_hex: String::new(),
422        };
423        let mut mac =
424            Hmac::<Sha256>::new_from_slice(&self.secret).expect("HMAC accepts any key length");
425        mac.update(&token.signing_input());
426        let sig = mac.finalize().into_bytes();
427        token.sig_hex = hex::encode(sig);
428        token
429    }
430
431    /// HMAC sig verify (constant-time eq for timing side-channel resistance).
432    pub fn verify_sig(&self, token: &CapToken) -> bool {
433        let mut mac =
434            Hmac::<Sha256>::new_from_slice(&self.secret).expect("HMAC accepts any key length");
435        mac.update(&token.signing_input());
436        let expected = mac.finalize().into_bytes();
437        let Ok(provided) = hex::decode(&token.sig_hex) else {
438            return false;
439        };
440        ct_eq(&expected, &provided)
441    }
442
443    /// Builder convenience: one-time token.
444    pub fn one_time(
445        &self,
446        agent_id: impl Into<String>,
447        role: Role,
448        scopes: Vec<String>,
449        ttl: Duration,
450    ) -> CapToken {
451        self.mint(agent_id, role, scopes, ttl, Some(1))
452    }
453
454    /// Builder convenience: session token (unlimited uses until expire).
455    pub fn session(
456        &self,
457        agent_id: impl Into<String>,
458        role: Role,
459        scopes: Vec<String>,
460        ttl: Duration,
461    ) -> CapToken {
462        self.mint(agent_id, role, scopes, ttl, None)
463    }
464
465    /// Builder convenience: limited (N uses).
466    pub fn limited(
467        &self,
468        agent_id: impl Into<String>,
469        role: Role,
470        scopes: Vec<String>,
471        ttl: Duration,
472        max_uses: u32,
473    ) -> CapToken {
474        self.mint(agent_id, role, scopes, ttl, Some(max_uses))
475    }
476}
477
// ─── helpers ───────────────────────────────────────────────────────────────

/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
pub(crate) fn now_unix() -> u64 {
    let elapsed = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        // A pre-epoch clock means the host clock is broken. Falling back to
        // 0 would mint `issued_at: 0` / `expire_at: 0` tokens, which look
        // both "already expired" and "minted at the epoch". Fail loudly
        // rather than hide that behind a bogus timestamp.
        .expect("system clock is before UNIX_EPOCH");
    elapsed.as_secs()
}

492/// In-process-unique, restart-decorrelated hex id.
493///
494/// Combines a monotonic per-process counter (bijective — guarantees no two
495/// calls in the same process ever collide) with a random per-process salt
496/// drawn once from the OS RNG (decorrelates ids across restarts, so a
497/// long-lived id from a previous process run can't be mistaken for one
498/// minted by the current process). The high bits of the 128-bit XOR are
499/// dominated by the salt (a process fingerprint); the low bits change on
500/// every call.
501///
502/// **Not unguessable.** The counter is a public, low-entropy sequence once
503/// the salt leaks (e.g. via any single id from this process) — never use
504/// this for bearer credentials, signing nonces, or anything else that must
505/// resist an adversary who can observe some ids and guess others. Use
506/// [`secure_hex`] for that.
507pub fn uid_hex(bytes: usize) -> String {
508    use std::sync::atomic::{AtomicU64, Ordering};
509    use std::sync::OnceLock;
510    static COUNTER: AtomicU64 = AtomicU64::new(0);
511    static SALT: OnceLock<u128> = OnceLock::new();
512    let salt = *SALT.get_or_init(|| {
513        let mut b = [0u8; 16];
514        getrandom::fill(&mut b).expect("OS RNG unavailable");
515        u128::from_le_bytes(b)
516    });
517    let c = COUNTER.fetch_add(1, Ordering::Relaxed) as u128;
518    // XOR keeps the counter's in-process uniqueness (bijection) while the
519    // per-process random salt decorrelates restarts. High 64 bits are pure
520    // salt (a process fingerprint); low bits change every call.
521    let v = salt ^ c;
522    let raw = format!("{:032x}", v);
523    let n = (bytes * 2).min(32);
524    raw[32 - n..].to_string()
525}
526
527/// OS-RNG hex, safe for bearer credentials.
528///
529/// Every byte comes from the OS random source (`getrandom`) on every call —
530/// unpredictable across calls *and* across process restarts, unlike
531/// [`uid_hex`]. Use this whenever the value itself is the secret: the
532/// [`CapToken`] nonce (its server-side lookup key and part of the signed
533/// material) and worker/session bearer handles.
534pub fn secure_hex(bytes: usize) -> String {
535    let mut buf = vec![0u8; bytes];
536    getrandom::fill(&mut buf).expect("OS RNG unavailable");
537    hex::encode(buf)
538}
539
/// Byte-slice equality that does not short-circuit on the first mismatch.
///
/// All byte pairs are XORed and OR-folded into one accumulator, so the
/// comparison visits every byte no matter where the slices differ. This is
/// best-effort at the source level: the compiler gives no formal
/// constant-time guarantee.
///
/// Slices of different lengths compare unequal immediately; lengths are not
/// treated as secret.
fn ct_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}

#[cfg(test)]
mod cap_token_transport_tests {
    use super::*;
    use std::time::Duration;

    /// A ten-minute session token minted by `signer`, shared by the cases
    /// below so each starts from the same well-formed token.
    fn sample_session(signer: &TokenSigner) -> CapToken {
        signer.session(
            "worker-of-task-x",
            Role::Worker,
            vec!["*".into()],
            Duration::from_secs(600),
        )
    }

    #[test]
    fn encode_decode_round_trips() {
        let signer = TokenSigner::new("test-secret");
        let token = sample_session(&signer);
        let s = token.encode();
        // URL-safe base64 should not contain `+` `/` `=`
        assert!(!s.contains('+'));
        assert!(!s.contains('/'));
        assert!(!s.contains('='));

        let decoded = CapToken::decode(&s).expect("decode ok");
        assert_eq!(decoded, token);
        assert!(
            signer.verify_sig(&decoded),
            "HMAC sig still verifies after round-trip"
        );
    }

    #[test]
    fn decode_rejects_garbage() {
        let err = CapToken::decode("not-base64!!!").expect_err("should fail");
        assert!(matches!(err, CapTokenDecodeError::Base64(_)));
    }

    #[test]
    fn decode_rejects_non_token_json() {
        use base64::Engine as _;
        let bogus = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b"{\"oops\":1}");
        let err = CapToken::decode(&bogus).expect_err("should fail json shape");
        assert!(matches!(err, CapTokenDecodeError::Json(_)));
    }

    #[test]
    fn tampered_role_breaks_signature() {
        let signer = TokenSigner::new("test-secret");
        let mut token = sample_session(&signer);
        assert!(signer.verify_sig(&token));
        token.role = Role::Senior;
        assert!(
            !signer.verify_sig(&token),
            "role escalation must invalidate the signature"
        );
    }

    #[test]
    fn tampered_expiry_breaks_signature() {
        let signer = TokenSigner::new("test-secret");
        let mut token = sample_session(&signer);
        token.expire_at = u64::MAX;
        assert!(!signer.verify_sig(&token));
    }

    #[test]
    fn other_secret_rejects_token() {
        let token = sample_session(&TokenSigner::new("secret-a"));
        assert!(!TokenSigner::new("secret-b").verify_sig(&token));
    }

    #[test]
    fn malformed_signature_hex_is_rejected() {
        let signer = TokenSigner::new("test-secret");
        let mut token = sample_session(&signer);
        token.sig_hex = "zz".into();
        assert!(!signer.verify_sig(&token), "non-hex sig");
        token.sig_hex = "abc".into();
        assert!(!signer.verify_sig(&token), "odd-length hex sig");
        token.sig_hex = String::new();
        assert!(!signer.verify_sig(&token), "empty sig");
    }

    #[test]
    fn expiry_boundary_is_inclusive() {
        let signer = TokenSigner::new("test-secret");
        let token = signer.session("a", Role::Observer, vec![], Duration::from_secs(60));
        assert_eq!(token.expire_at, token.issued_at + 60);
        assert!(!token.is_expired(token.expire_at - 1));
        assert!(token.is_expired(token.expire_at));
    }

    #[test]
    fn builders_set_max_uses() {
        let signer = TokenSigner::new("test-secret");
        let ttl = Duration::from_secs(60);
        assert_eq!(signer.one_time("a", Role::Worker, vec![], ttl).max_uses, Some(1));
        assert_eq!(signer.limited("a", Role::Worker, vec![], ttl, 3).max_uses, Some(3));
        assert_eq!(signer.session("a", Role::Worker, vec![], ttl).max_uses, None);
    }
}