// awaken_runtime/agent/state/tool_call_lifecycle.rs

1use crate::state::{MergeStrategy, StateKey};
2use awaken_contract::contract::suspension::{ToolCallResume, ToolCallResumeMode, ToolCallStatus};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6
7/// Per-tool-call lifecycle state.
8#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
9pub struct ToolCallState {
10    pub call_id: String,
11    pub tool_name: String,
12    pub arguments: Value,
13    pub status: ToolCallStatus,
14    pub updated_at: u64,
15    /// Resume mode from the `SuspendTicket` (set when status becomes Suspended).
16    #[serde(default)]
17    pub resume_mode: ToolCallResumeMode,
18    /// External-facing suspension id used by protocols that distinguish
19    /// approval/interrupt ids from the underlying tool call id.
20    #[serde(default, skip_serializing_if = "Option::is_none")]
21    pub suspension_id: Option<String>,
22    /// Suspension reason/action from the active `SuspendTicket`.
23    #[serde(default, skip_serializing_if = "Option::is_none")]
24    pub suspension_reason: Option<String>,
25    /// Most recent external resume input applied to this suspended tool call.
26    #[serde(default, skip_serializing_if = "Option::is_none")]
27    pub resume_input: Option<ToolCallResume>,
28}
29
30impl ToolCallState {
31    pub fn new(
32        call_id: impl Into<String>,
33        tool_name: impl Into<String>,
34        arguments: Value,
35        status: ToolCallStatus,
36        updated_at: u64,
37    ) -> Self {
38        Self {
39            call_id: call_id.into(),
40            tool_name: tool_name.into(),
41            arguments,
42            status,
43            updated_at,
44            resume_mode: ToolCallResumeMode::default(),
45            suspension_id: None,
46            suspension_reason: None,
47            resume_input: None,
48        }
49    }
50
51    #[must_use]
52    pub fn with_resume_mode(mut self, resume_mode: ToolCallResumeMode) -> Self {
53        self.resume_mode = resume_mode;
54        self
55    }
56
57    #[must_use]
58    pub fn with_suspension(
59        mut self,
60        suspension_id: Option<String>,
61        suspension_reason: Option<String>,
62    ) -> Self {
63        self.suspension_id = normalize_optional_string(suspension_id);
64        self.suspension_reason = normalize_optional_string(suspension_reason);
65        self
66    }
67
68    #[must_use]
69    pub fn with_resume_input(mut self, resume_input: Option<ToolCallResume>) -> Self {
70        self.resume_input = resume_input;
71        self
72    }
73}
74
75/// Keyed collection of tool call states for the current step.
76#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
77pub struct ToolCallStateMap {
78    pub calls: HashMap<String, ToolCallState>,
79}
80
/// Trims surrounding whitespace from an optional string.
///
/// Returns `None` when the input is `None` or is empty after trimming.
/// Otherwise returns the trimmed text as a new `String`.
fn normalize_optional_string(value: Option<String>) -> Option<String> {
    let raw = value?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
87
88pub enum ToolCallStatesUpdate {
89    /// Replace a tool call's lifecycle state (validates transition).
90    Put(Box<ToolCallState>),
91    /// Clear all tool call states (at step boundary).
92    Clear,
93}
94
95impl ToolCallStatesUpdate {
96    #[must_use]
97    pub fn put(state: ToolCallState) -> Self {
98        Self::Put(Box::new(state))
99    }
100}
101
/// State key for tool call lifecycle tracking within a step.
///
/// A zero-sized marker type: the tracked data lives in
/// `ToolCallStateMap` (its `StateKey::Value`).
#[derive(Debug, Clone, Copy, Default)]
pub struct ToolCallStates;
104
105impl StateKey for ToolCallStates {
106    const KEY: &'static str = "__runtime.tool_call_states";
107    const MERGE: MergeStrategy = MergeStrategy::Commutative;
108
109    type Value = ToolCallStateMap;
110    type Update = ToolCallStatesUpdate;
111
112    fn apply(value: &mut Self::Value, update: Self::Update) {
113        match update {
114            ToolCallStatesUpdate::Put(state) => {
115                let call_id = state.call_id.clone();
116                let existing = value.calls.get(&call_id);
117                let current_status = existing.map(|s| s.status).unwrap_or(ToolCallStatus::New);
118                let next_status = state.status;
119
120                if !current_status.can_transition_to(next_status) {
121                    tracing::error!(
122                        from = ?current_status,
123                        to = ?next_status,
124                        call_id = %call_id,
125                        "invalid tool call transition — skipping update"
126                    );
127                    return;
128                }
129
130                let mut state = state;
131                state.suspension_id = normalize_optional_string(state.suspension_id);
132                state.suspension_reason = normalize_optional_string(state.suspension_reason);
133                value.calls.insert(call_id, *state);
134            }
135            ToolCallStatesUpdate::Clear => {
136                value.calls.clear();
137            }
138        }
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use awaken_contract::contract::suspension::ResumeDecisionAction;

    /// Applies a `Put` for a state with empty JSON arguments and no suspension data.
    fn upsert(
        states: &mut ToolCallStateMap,
        call_id: &str,
        tool: &str,
        status: ToolCallStatus,
        ts: u64,
    ) {
        ToolCallStates::apply(
            states,
            ToolCallStatesUpdate::put(ToolCallState::new(
                call_id,
                tool,
                serde_json::json!({}),
                status,
                ts,
            )),
        );
    }

    #[test]
    fn tool_call_new_to_running() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Running);
    }

    #[test]
    fn tool_call_running_to_succeeded() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Succeeded, 200);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Succeeded);
    }

    #[test]
    fn tool_call_running_to_failed() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Failed, 200);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Failed);
    }

    #[test]
    fn tool_call_running_to_suspended_to_resuming() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Suspended, 200);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Resuming, 300);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Resuming);
    }

    #[test]
    fn tool_call_suspended_to_cancelled() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Suspended, 200);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Cancelled, 300);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Cancelled);
        assert!(states.calls["c1"].status.is_terminal());
    }

    #[test]
    fn tool_call_rejects_succeeded_to_running() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Succeeded, 200);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 300);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Succeeded);
        assert_eq!(states.calls["c1"].updated_at, 200);
    }

    #[test]
    fn tool_call_rejects_failed_to_running() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Failed, 200);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 300);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Failed);
        assert_eq!(states.calls["c1"].updated_at, 200);
    }

    #[test]
    fn tool_call_multiple_calls_independent() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c2", "calc", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Succeeded, 200);
        upsert(&mut states, "c2", "calc", ToolCallStatus::Failed, 200);

        assert_eq!(states.calls["c1"].status, ToolCallStatus::Succeeded);
        assert_eq!(states.calls["c2"].status, ToolCallStatus::Failed);
    }

    #[test]
    fn tool_call_clear_removes_all() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c2", "calc", ToolCallStatus::Running, 100);
        ToolCallStates::apply(&mut states, ToolCallStatesUpdate::Clear);
        assert!(states.calls.is_empty());
    }

    #[test]
    fn tool_call_state_serde_roundtrip() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Succeeded, 200);
        let json = serde_json::to_string(&states).unwrap();
        let parsed: ToolCallStateMap = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, states);
    }

    #[test]
    fn tool_call_full_lifecycle_suspend_resume_succeed() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "dangerous", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "dangerous", ToolCallStatus::Suspended, 200);
        upsert(&mut states, "c1", "dangerous", ToolCallStatus::Resuming, 300);
        upsert(&mut states, "c1", "dangerous", ToolCallStatus::Running, 400);
        upsert(&mut states, "c1", "dangerous", ToolCallStatus::Succeeded, 500);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Succeeded);
        assert_eq!(states.calls["c1"].updated_at, 500);
    }

    // -----------------------------------------------------------------------
    // Migrated from uncarve: additional tool call lifecycle tests
    // -----------------------------------------------------------------------

    #[test]
    fn tool_call_new_can_transition_to_any() {
        // Only the New -> Succeeded edge is exercised here.
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Succeeded, 100);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Succeeded);
    }

    #[test]
    fn tool_call_new_to_running_is_typical_path() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Running);
    }

    #[test]
    fn tool_call_suspended_to_succeeded_not_allowed() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Suspended, 200);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Succeeded, 300);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Suspended);
        assert_eq!(states.calls["c1"].updated_at, 200);
    }

    #[test]
    fn tool_call_map_default_is_empty() {
        let states = ToolCallStateMap::default();
        assert!(states.calls.is_empty());
    }

    #[test]
    fn tool_call_preserves_tool_name_and_arguments() {
        let mut states = ToolCallStateMap::default();
        ToolCallStates::apply(
            &mut states,
            ToolCallStatesUpdate::put(ToolCallState::new(
                "c1",
                "search",
                serde_json::json!({"query": "test"}),
                ToolCallStatus::Running,
                100,
            )),
        );
        let call = &states.calls["c1"];
        assert_eq!(call.tool_name, "search");
        assert_eq!(call.arguments["query"], "test");
    }

    #[test]
    fn tool_call_suspension_context_roundtrip() {
        let mut states = ToolCallStateMap::default();
        ToolCallStates::apply(
            &mut states,
            ToolCallStatesUpdate::put(
                ToolCallState::new(
                    "c1",
                    "dangerous",
                    serde_json::json!({"cmd": "rm"}),
                    ToolCallStatus::Suspended,
                    100,
                )
                .with_resume_mode(ToolCallResumeMode::ReplayToolCall)
                .with_suspension(
                    Some("perm_c1".into()),
                    Some("tool:PermissionConfirm".into()),
                ),
            ),
        );
        ToolCallStates::apply(
            &mut states,
            ToolCallStatesUpdate::put(
                ToolCallState::new(
                    "c1",
                    "dangerous",
                    serde_json::json!({"cmd": "rm"}),
                    ToolCallStatus::Cancelled,
                    200,
                )
                .with_resume_mode(ToolCallResumeMode::ReplayToolCall)
                .with_suspension(
                    Some("perm_c1".into()),
                    Some("tool:PermissionConfirm".into()),
                )
                .with_resume_input(Some(ToolCallResume {
                    decision_id: "d1".into(),
                    action: ResumeDecisionAction::Cancel,
                    result: serde_json::json!({"approved": false}),
                    reason: Some("user denied".into()),
                    updated_at: 200,
                })),
            ),
        );
        let call = &states.calls["c1"];
        assert_eq!(call.suspension_id.as_deref(), Some("perm_c1"));
        assert_eq!(
            call.suspension_reason.as_deref(),
            Some("tool:PermissionConfirm")
        );
        assert_eq!(
            call.resume_input.as_ref().map(|resume| &resume.result),
            Some(&serde_json::json!({"approved": false}))
        );
    }

    #[test]
    fn tool_call_clear_then_reuse() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Succeeded, 200);

        ToolCallStates::apply(&mut states, ToolCallStatesUpdate::Clear);
        assert!(states.calls.is_empty());

        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 300);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Running);
    }

    #[test]
    fn tool_call_cancelled_is_terminal() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Suspended, 200);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Cancelled, 300);
        assert!(states.calls["c1"].status.is_terminal());
    }

    #[test]
    fn tool_call_succeeded_is_terminal() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Succeeded, 200);
        assert!(states.calls["c1"].status.is_terminal());
    }

    #[test]
    fn tool_call_failed_is_terminal() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Failed, 200);
        assert!(states.calls["c1"].status.is_terminal());
    }

    #[test]
    fn tool_call_running_is_not_terminal() {
        let mut states = ToolCallStateMap::default();
        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        assert!(!states.calls["c1"].status.is_terminal());
    }

    #[test]
    fn tool_call_many_calls_independent_lifecycle() {
        let mut states = ToolCallStateMap::default();

        upsert(&mut states, "c1", "echo", ToolCallStatus::Running, 100);
        upsert(&mut states, "c1", "echo", ToolCallStatus::Succeeded, 200);

        upsert(&mut states, "c2", "calc", ToolCallStatus::Running, 100);
        upsert(&mut states, "c2", "calc", ToolCallStatus::Failed, 200);

        upsert(&mut states, "c3", "search", ToolCallStatus::Running, 100);
        upsert(&mut states, "c3", "search", ToolCallStatus::Suspended, 200);
        upsert(&mut states, "c3", "search", ToolCallStatus::Resuming, 300);
        upsert(&mut states, "c3", "search", ToolCallStatus::Running, 400);
        upsert(&mut states, "c3", "search", ToolCallStatus::Succeeded, 500);

        assert_eq!(states.calls.len(), 3);
        assert_eq!(states.calls["c1"].status, ToolCallStatus::Succeeded);
        assert_eq!(states.calls["c2"].status, ToolCallStatus::Failed);
        assert_eq!(states.calls["c3"].status, ToolCallStatus::Succeeded);
    }

    // -----------------------------------------------------------------------
    // Normalization and serialization coverage
    // -----------------------------------------------------------------------

    #[test]
    fn normalize_optional_string_trims_and_drops_blank() {
        assert_eq!(normalize_optional_string(None), None);
        assert_eq!(normalize_optional_string(Some(String::new())), None);
        assert_eq!(normalize_optional_string(Some(" \t\n ".into())), None);
        assert_eq!(
            normalize_optional_string(Some("  perm_c1  ".into())),
            Some("perm_c1".to_string())
        );
    }

    #[test]
    fn tool_call_with_suspension_normalizes_fields() {
        let state = ToolCallState::new(
            "c1",
            "echo",
            serde_json::json!({}),
            ToolCallStatus::Suspended,
            100,
        )
        .with_suspension(Some("\tperm_c1\n".into()), Some(String::new()));
        assert_eq!(state.suspension_id.as_deref(), Some("perm_c1"));
        assert_eq!(state.suspension_reason, None);
    }

    #[test]
    fn tool_call_put_normalizes_suspension_fields() {
        let mut states = ToolCallStateMap::default();
        let mut state = ToolCallState::new(
            "c1",
            "dangerous",
            serde_json::json!({}),
            ToolCallStatus::Suspended,
            100,
        );
        // Assign directly (bypassing `with_suspension`) so only `apply` normalizes.
        state.suspension_id = Some("  perm_c1  ".into());
        state.suspension_reason = Some("   ".into());
        ToolCallStates::apply(&mut states, ToolCallStatesUpdate::put(state));

        let call = &states.calls["c1"];
        assert_eq!(call.suspension_id.as_deref(), Some("perm_c1"));
        assert_eq!(call.suspension_reason, None);
    }

    #[test]
    fn tool_call_rejected_transition_keeps_suspension_metadata() {
        let mut states = ToolCallStateMap::default();
        ToolCallStates::apply(
            &mut states,
            ToolCallStatesUpdate::put(
                ToolCallState::new(
                    "c1",
                    "dangerous",
                    serde_json::json!({}),
                    ToolCallStatus::Suspended,
                    100,
                )
                .with_suspension(
                    Some("perm_c1".into()),
                    Some("tool:PermissionConfirm".into()),
                ),
            ),
        );
        // Suspended -> Succeeded is rejected, so the stored entry must be untouched.
        upsert(&mut states, "c1", "dangerous", ToolCallStatus::Succeeded, 200);

        let call = &states.calls["c1"];
        assert_eq!(call.status, ToolCallStatus::Suspended);
        assert_eq!(call.updated_at, 100);
        assert_eq!(call.suspension_id.as_deref(), Some("perm_c1"));
        assert_eq!(
            call.suspension_reason.as_deref(),
            Some("tool:PermissionConfirm")
        );
    }

    #[test]
    fn tool_call_state_omits_absent_optional_fields_when_serialized() {
        let state = ToolCallState::new(
            "c1",
            "echo",
            serde_json::json!({}),
            ToolCallStatus::Running,
            100,
        );
        let json = serde_json::to_value(&state).unwrap();
        assert!(json.get("suspension_id").is_none());
        assert!(json.get("suspension_reason").is_none());
        assert!(json.get("resume_input").is_none());
        assert!(json.get("resume_mode").is_some());

        let parsed: ToolCallState = serde_json::from_value(json).unwrap();
        assert_eq!(parsed, state);
    }
}