Skip to main content

roboticus_agent/
typestate.rs

1use roboticus_core::RiskLevel;
2use serde_json::Value;
3use std::marker::PhantomData;
4
pub mod states {
    //! Zero-sized marker types for the tool-call typestate.
    //!
    //! Each marker is only ever used as the `State` parameter of
    //! `ToolCallRequest`; moving between them is done by consuming methods on
    //! that type.

    /// Starting state: the policy engine has not looked at the call yet.
    #[derive(Debug)]
    pub struct Unevaluated;

    /// The policy engine accepted the call (reached via `approve`).
    #[derive(Debug)]
    pub struct Approved;

    /// The policy engine rejected the call (reached via `deny`).
    #[derive(Debug)]
    pub struct Denied;

    /// The approved call has been carried out (reached via `mark_executed`).
    #[derive(Debug)]
    pub struct Executed;
}
23
24/// A tool call request with compile-time state tracking.
25/// Only `ToolCallRequest<Approved>` can be executed.
26#[derive(Debug)]
27pub struct ToolCallRequest<State> {
28    pub tool_name: String,
29    pub parameters: Value,
30    pub risk_level: RiskLevel,
31    _state: PhantomData<State>,
32}
33
34impl ToolCallRequest<states::Unevaluated> {
35    /// Create a new unevaluated tool call request.
36    pub fn new(tool_name: String, parameters: Value, risk_level: RiskLevel) -> Self {
37        Self {
38            tool_name,
39            parameters,
40            risk_level,
41            _state: PhantomData,
42        }
43    }
44
45    /// Approve the tool call (transitions Unevaluated -> Approved).
46    pub fn approve(self) -> ToolCallRequest<states::Approved> {
47        ToolCallRequest {
48            tool_name: self.tool_name,
49            parameters: self.parameters,
50            risk_level: self.risk_level,
51            _state: PhantomData,
52        }
53    }
54
55    /// Deny the tool call (transitions Unevaluated -> Denied).
56    pub fn deny(self) -> ToolCallRequest<states::Denied> {
57        ToolCallRequest {
58            tool_name: self.tool_name,
59            parameters: self.parameters,
60            risk_level: self.risk_level,
61            _state: PhantomData,
62        }
63    }
64}
65
66impl ToolCallRequest<states::Approved> {
67    /// Mark as executed (transitions Approved -> Executed).
68    pub fn mark_executed(self) -> ToolCallRequest<states::Executed> {
69        ToolCallRequest {
70            tool_name: self.tool_name,
71            parameters: self.parameters,
72            risk_level: self.risk_level,
73            _state: PhantomData,
74        }
75    }
76}
77
78impl<S> ToolCallRequest<S> {
79    pub fn tool_name(&self) -> &str {
80        &self.tool_name
81    }
82
83    pub fn parameters(&self) -> &Value {
84        &self.parameters
85    }
86
87    pub fn risk_level(&self) -> &RiskLevel {
88        &self.risk_level
89    }
90}
91
pub mod lifecycle {
    //! Zero-sized marker types for the agent lifecycle typestate.
    //!
    //! Each marker is used as the `State` parameter of `AgentHandle`; the
    //! transitions available in each state are the methods implemented for
    //! that `AgentHandle<State>`.

    /// Initial state of a handle from `AgentHandle::new`; it can only `wake`.
    #[derive(Debug)]
    pub struct Setup;

    /// Entered via `wake` (from `Setup` or `Sleeping`); `start` leads to `Running`.
    #[derive(Debug)]
    pub struct Waking;

    /// Active state; the agent can `sleep` or `terminate`.
    #[derive(Debug)]
    pub struct Running;

    /// Dormant state; the agent can `wake` again or `terminate`.
    #[derive(Debug)]
    pub struct Sleeping;

    /// Final state; no transitions out of it are defined.
    #[derive(Debug)]
    pub struct Dead;
}
109
/// Handle to an agent whose lifecycle stage is encoded in its type parameter.
///
/// `State` is one of the markers from the `lifecycle` module. Since `_state`
/// is private, code outside this module obtains handles only through
/// `AgentHandle::new` (in `Setup`) and the consuming transition methods.
#[derive(Debug)]
pub struct AgentHandle<State> {
    /// Identifier of the agent; copied unchanged by every transition method.
    pub agent_id: String,
    /// Zero-sized tag carrying the compile-time lifecycle state.
    _state: PhantomData<State>,
}
116
117impl AgentHandle<lifecycle::Setup> {
118    pub fn new(agent_id: String) -> Self {
119        Self {
120            agent_id,
121            _state: PhantomData,
122        }
123    }
124
125    pub fn wake(self) -> AgentHandle<lifecycle::Waking> {
126        AgentHandle {
127            agent_id: self.agent_id,
128            _state: PhantomData,
129        }
130    }
131}
132
133impl AgentHandle<lifecycle::Waking> {
134    pub fn start(self) -> AgentHandle<lifecycle::Running> {
135        AgentHandle {
136            agent_id: self.agent_id,
137            _state: PhantomData,
138        }
139    }
140}
141
142impl AgentHandle<lifecycle::Running> {
143    pub fn sleep(self) -> AgentHandle<lifecycle::Sleeping> {
144        AgentHandle {
145            agent_id: self.agent_id,
146            _state: PhantomData,
147        }
148    }
149
150    pub fn terminate(self) -> AgentHandle<lifecycle::Dead> {
151        AgentHandle {
152            agent_id: self.agent_id,
153            _state: PhantomData,
154        }
155    }
156}
157
158impl AgentHandle<lifecycle::Sleeping> {
159    pub fn wake(self) -> AgentHandle<lifecycle::Waking> {
160        AgentHandle {
161            agent_id: self.agent_id,
162            _state: PhantomData,
163        }
164    }
165
166    pub fn terminate(self) -> AgentHandle<lifecycle::Dead> {
167        AgentHandle {
168            agent_id: self.agent_id,
169            _state: PhantomData,
170        }
171    }
172}
173
174impl<S> AgentHandle<S> {
175    pub fn agent_id(&self) -> &str {
176        &self.agent_id
177    }
178}
179
/// Balance tracker whose spending limits are fixed at compile time.
///
/// `MAX_PER_TX` caps any single spend and `MAX_DAILY` caps the running total
/// kept in `spent_today`. Both are const generic parameters, so each pair of
/// limits is a distinct type.
///
/// NOTE(review): both fields are `pub`, so code with `&mut` access can change
/// them directly and sidestep the checks performed by `spend`.
#[derive(Debug)]
pub struct BoundedTreasury<const MAX_PER_TX: u64, const MAX_DAILY: u64> {
    /// Funds currently available; reduced by `spend`.
    pub balance: u64,
    /// Running total added by `spend`; zeroed by `new` and `reset_daily`.
    pub spent_today: u64,
}
186
187impl<const MAX_PER_TX: u64, const MAX_DAILY: u64> BoundedTreasury<MAX_PER_TX, MAX_DAILY> {
188    pub fn new(balance: u64) -> Self {
189        Self {
190            balance,
191            spent_today: 0,
192        }
193    }
194
195    pub fn can_spend(&self, amount: u64) -> bool {
196        amount <= MAX_PER_TX && self.spent_today + amount <= MAX_DAILY && amount <= self.balance
197    }
198
199    pub fn spend(&mut self, amount: u64) -> Result<(), &'static str> {
200        if !self.can_spend(amount) {
201            return Err("spending limit exceeded or insufficient balance");
202        }
203        self.balance -= amount;
204        self.spent_today += amount;
205        Ok(())
206    }
207
208    pub fn reset_daily(&mut self) {
209        self.spent_today = 0;
210    }
211
212    pub fn max_per_tx() -> u64 {
213        MAX_PER_TX
214    }
215
216    pub fn max_daily() -> u64 {
217        MAX_DAILY
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tool_call_lifecycle() {
        let request = ToolCallRequest::new(
            "memory_search".into(),
            serde_json::json!({"query": "test"}),
            RiskLevel::Safe,
        );
        assert_eq!(request.tool_name(), "memory_search");

        let approved = request.approve();
        assert_eq!(approved.tool_name(), "memory_search");

        let executed = approved.mark_executed();
        assert_eq!(executed.tool_name(), "memory_search");
        // Parameters must survive every transition unchanged.
        assert_eq!(executed.parameters(), &serde_json::json!({"query": "test"}));
    }

    #[test]
    fn tool_call_deny() {
        let request = ToolCallRequest::new(
            "dangerous_tool".into(),
            serde_json::json!({}),
            RiskLevel::Dangerous,
        );
        let denied = request.deny();
        assert_eq!(denied.tool_name(), "dangerous_tool");
    }

    #[test]
    fn agent_lifecycle() {
        let agent = AgentHandle::<lifecycle::Setup>::new("agent-1".into());
        assert_eq!(agent.agent_id(), "agent-1");

        let waking = agent.wake();
        let running = waking.start();
        let sleeping = running.sleep();
        let waking_again = sleeping.wake();
        let running_again = waking_again.start();
        let _dead = running_again.terminate();
    }

    #[test]
    fn agent_setup_to_dead_via_sleep() {
        let agent = AgentHandle::<lifecycle::Setup>::new("a".into());
        let _dead = agent.wake().start().sleep().terminate();
    }

    #[test]
    fn bounded_treasury_limits() {
        let mut treasury = BoundedTreasury::<100, 500>::new(1000);
        assert!(treasury.can_spend(100));
        assert!(!treasury.can_spend(101));

        treasury.spend(100).unwrap();
        treasury.spend(100).unwrap();
        treasury.spend(100).unwrap();
        treasury.spend(100).unwrap();
        treasury.spend(100).unwrap();
        assert!(!treasury.can_spend(1));

        treasury.reset_daily();
        assert!(treasury.can_spend(100));
    }

    #[test]
    fn bounded_treasury_insufficient_balance() {
        let mut treasury = BoundedTreasury::<1000, 10000>::new(50);
        assert!(!treasury.can_spend(51));
        assert!(treasury.can_spend(50));
        treasury.spend(50).unwrap();
        assert!(!treasury.can_spend(1));
    }

    #[test]
    fn bounded_treasury_rejected_spend_leaves_state_unchanged() {
        let mut treasury = BoundedTreasury::<100, 150>::new(1000);

        // Over the per-transaction limit.
        assert!(treasury.spend(101).is_err());
        assert_eq!(treasury.balance, 1000);
        assert_eq!(treasury.spent_today, 0);

        treasury.spend(100).unwrap();
        // Within the per-transaction limit, but would push the daily total past 150.
        assert_eq!(
            treasury.spend(100),
            Err("spending limit exceeded or insufficient balance")
        );
        assert_eq!(treasury.balance, 900);
        assert_eq!(treasury.spent_today, 100);
    }

    #[test]
    fn bounded_treasury_const_accessors() {
        assert_eq!(BoundedTreasury::<42, 999>::max_per_tx(), 42);
        assert_eq!(BoundedTreasury::<42, 999>::max_daily(), 999);
    }
}