use crate::runtime::phase::SuspendTicket;
use crate::thread::ToolCall;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use tirea_state::State;
/// Outcome of a resume decision, stored in [`ToolCallResume::action`].
///
/// Serialized in `snake_case` (`"resume"` / `"cancel"`).
/// NOTE(review): `Resume` presumably continues the suspended call and `Cancel` aborts it —
/// confirm against the runtime that consumes this value.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ResumeDecisionAction {
    Resume,
    Cancel,
}
/// Strategy for feeding a resume decision back into a suspended tool call.
///
/// Carried by a suspend ticket and serialized in `snake_case`. `ReplayToolCall` is the
/// default.
/// NOTE(review): the meanings below are inferred from the variant names only — confirm
/// against the resume logic in the runtime.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallResumeMode {
    /// Presumably re-executes the original tool call on resume (default).
    #[default]
    ReplayToolCall,
    /// Presumably uses the decision payload directly as the tool's result.
    UseDecisionAsToolResult,
    /// Presumably hands the decision to the tool when it is re-invoked.
    PassDecisionToTool,
}
/// A tool call (id, tool name, JSON arguments) recorded as pending inside a suspend ticket.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct PendingToolCall {
    /// Identifier of the pending call.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// JSON arguments for the tool.
    pub arguments: Value,
}
impl PendingToolCall {
    /// Builds a pending tool call, converting `id` and `name` into owned strings and
    /// storing `arguments` as-is.
    pub fn new(id: impl Into<String>, name: impl Into<String>, arguments: Value) -> Self {
        let id: String = id.into();
        let name: String = name.into();
        Self {
            id,
            name,
            arguments,
        }
    }
}
/// A tool call that has been suspended, together with its [`SuspendTicket`].
///
/// `call_id`, `tool_name` and `arguments` fall back to their defaults when absent during
/// deserialization. The ticket is flattened, so its fields (e.g. `suspension`, `pending`,
/// `resume_mode`) sit at the top level of this object rather than under a `ticket` key.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct SuspendedCall {
    /// Id of the suspended tool call; also used as its tool-call scope key.
    #[serde(default)]
    pub call_id: String,
    /// Name of the tool that was called.
    #[serde(default)]
    pub tool_name: String,
    /// JSON arguments of the original call.
    #[serde(default)]
    pub arguments: Value,
    /// Suspension details; flattened into this object on the wire.
    #[serde(flatten)]
    pub ticket: SuspendTicket,
}
impl SuspendedCall {
    /// Captures the id, tool name and arguments of `call` together with the `ticket`
    /// it was suspended with.
    pub fn new(call: &ToolCall, ticket: SuspendTicket) -> Self {
        Self {
            ticket,
            call_id: call.id.clone(),
            tool_name: call.name.clone(),
            arguments: call.arguments.clone(),
        }
    }

    /// Wraps this call in `SuspendedCallAction::Set`, addressed to the tool-call scope
    /// identified by `call_id`, so it can be applied as a state action.
    pub fn into_state_action(self) -> crate::runtime::state::AnyStateAction {
        use crate::runtime::state::AnyStateAction;

        // `self` is moved into the action below, so grab the scope id first.
        let scope_id = self.call_id.clone();
        let action = SuspendedCallAction::Set(self);
        AnyStateAction::new_for_call::<SuspendedCallState>(action, scope_id)
    }
}
/// Tool-call-scoped state slot holding a [`SuspendedCall`].
///
/// Registered through `#[tirea(...)]` at path `suspended_call` with `scope = "tool_call"`,
/// and updated via [`SuspendedCallAction`]. `call` is flattened, so the suspended call's
/// fields are stored directly at this path rather than under a `call` key.
#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
#[tirea(
    path = "suspended_call",
    action = "SuspendedCallAction",
    scope = "tool_call"
)]
pub struct SuspendedCallState {
    /// The suspended call recorded for this tool call.
    #[serde(flatten)]
    pub call: SuspendedCall,
}
/// Actions accepted by [`SuspendedCallState`].
///
/// Derives `Debug` and `Clone` (both implemented by [`SuspendedCall`]) so actions can be
/// logged, asserted on in tests, and duplicated like the other public types in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SuspendedCallAction {
    /// Replace the stored suspended call with the given one.
    Set(SuspendedCall),
}
impl SuspendedCallState {
    /// Applies a [`SuspendedCallAction`]; `Set` overwrites the stored call.
    fn reduce(&mut self, action: SuspendedCallAction) {
        // Single-variant enum: the pattern is irrefutable.
        let SuspendedCallAction::Set(replacement) = action;
        self.call = replacement;
    }
}
/// Actions accepted by [`ToolCallState`].
///
/// Derives `Debug` and `Clone` (both implemented by [`ToolCallState`]) so actions can be
/// logged, asserted on in tests, and duplicated like the other public types in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolCallStateAction {
    /// Replace the whole tool-call record with the given one.
    Set(ToolCallState),
}
/// Lifecycle status of a tool call.
///
/// Serialized in `snake_case`. `New` is the default. `Succeeded`, `Failed` and
/// `Cancelled` are the terminal statuses reported by `ToolCallStatus::is_terminal`;
/// permitted moves between statuses are defined by `ToolCallStatus::can_transition_to`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallStatus {
    /// Initial status; also the `Default`.
    #[default]
    New,
    Running,
    Suspended,
    Resuming,
    /// Terminal.
    Succeeded,
    /// Terminal.
    Failed,
    /// Terminal.
    Cancelled,
}
impl ToolCallStatus {
pub const ASCII_STATE_MACHINE: &str = r#"new ------------> running
| |
| v
+------------> suspended -----> resuming
| |
+---------------+
running/resuming ---> succeeded
running/resuming ---> failed
running/suspended/resuming ---> cancelled"#;
pub fn is_terminal(self) -> bool {
matches!(
self,
ToolCallStatus::Succeeded | ToolCallStatus::Failed | ToolCallStatus::Cancelled
)
}
pub fn can_transition_to(self, next: Self) -> bool {
if self == next {
return true;
}
match self {
ToolCallStatus::New => true,
ToolCallStatus::Running => matches!(
next,
ToolCallStatus::Suspended
| ToolCallStatus::Succeeded
| ToolCallStatus::Failed
| ToolCallStatus::Cancelled
),
ToolCallStatus::Suspended => {
matches!(next, ToolCallStatus::Resuming | ToolCallStatus::Cancelled)
}
ToolCallStatus::Resuming => matches!(
next,
ToolCallStatus::Running
| ToolCallStatus::Suspended
| ToolCallStatus::Succeeded
| ToolCallStatus::Failed
| ToolCallStatus::Cancelled
),
ToolCallStatus::Succeeded | ToolCallStatus::Failed | ToolCallStatus::Cancelled => false,
}
}
}
/// A resume decision recorded for a tool call (see `ToolCallState::resume`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResume {
    /// Identifier of the decision; empty string when absent on input.
    #[serde(default)]
    pub decision_id: String,
    /// Whether the call should be resumed or cancelled. Required on input.
    pub action: ResumeDecisionAction,
    /// Decision payload; `null` when absent and omitted from output when `null`.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub result: Value,
    /// Optional reason text; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
    /// Update marker, `0` when absent on input.
    /// NOTE(review): presumably a timestamp — unit is not established in this file.
    #[serde(default)]
    pub updated_at: u64,
}
/// Lifecycle record for a single tool call.
///
/// Registered through `#[tirea(...)]` at path `tool_call_state` with `scope = "tool_call"`,
/// and replaced wholesale by `ToolCallStateAction::Set`. Every field falls back to its
/// default when missing on input; empty strings, `null` values and `None` options are
/// omitted from output.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, State)]
#[tirea(
    path = "tool_call_state",
    action = "ToolCallStateAction",
    scope = "tool_call"
)]
pub struct ToolCallState {
    /// Id of the tool call; `into_state_action` uses it as the scope key.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub call_id: String,
    /// Name of the tool being called.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub tool_name: String,
    /// JSON arguments of the call.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub arguments: Value,
    /// Current lifecycle status; defaults to `ToolCallStatus::New`.
    #[serde(default)]
    pub status: ToolCallStatus,
    /// Optional resume token.
    /// NOTE(review): how this token is issued/consumed is not visible in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resume_token: Option<String>,
    /// Resume decision, once one has been recorded.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resume: Option<ToolCallResume>,
    /// Free-form JSON value.
    /// NOTE(review): presumably tool-private scratch data — confirm with its writers.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub scratch: Value,
    /// Update marker, `0` when absent on input.
    /// NOTE(review): presumably a timestamp — unit is not established in this file.
    #[serde(default)]
    pub updated_at: u64,
}
impl ToolCallState {
    /// Wraps this record in `ToolCallStateAction::Set`, addressed to the tool-call scope
    /// identified by `call_id`, so it can be applied as a state action.
    pub fn into_state_action(self) -> crate::runtime::state::AnyStateAction {
        use crate::runtime::state::AnyStateAction;

        // `self` is moved into the action, so capture the scope id beforehand.
        let scope_id = self.call_id.clone();
        AnyStateAction::new_for_call::<Self>(ToolCallStateAction::Set(self), scope_id)
    }
}
impl ToolCallState {
    /// Applies a [`ToolCallStateAction`]: `Set` replaces the entire record.
    ///
    /// No status-transition validation happens here; the incoming record wins as-is.
    fn reduce(&mut self, action: ToolCallStateAction) {
        let replacement = match action {
            ToolCallStateAction::Set(state) => state,
        };
        *self = replacement;
    }
}
/// Top-level key of the state document under which per-tool-call scopes live,
/// as read by the extraction functions below.
const TOOL_CALL_SCOPE_KEY: &str = "__tool_call_scope";

/// Walks `state[TOOL_CALL_SCOPE_KEY]` (a JSON object keyed by call id), reads the value at
/// `path` inside each scope, and keeps the entries `parse` accepts, keyed by call id.
///
/// Returns an empty map when the scope key is missing or is not a JSON object. Scopes that
/// lack `path`, or whose value `parse` rejects, are skipped silently (best-effort read).
fn collect_tool_call_scoped<T>(
    state: &Value,
    path: &str,
    parse: impl Fn(&Value) -> Option<T>,
) -> HashMap<String, T> {
    let Some(Value::Object(scopes)) = state.get(TOOL_CALL_SCOPE_KEY) else {
        return HashMap::new();
    };
    scopes
        .iter()
        .filter_map(|(call_id, scope)| {
            scope
                .get(path)
                .and_then(&parse)
                .map(|parsed| (call_id.clone(), parsed))
        })
        .collect()
}

/// Extracts every [`SuspendedCall`] stored under the `suspended_call` path of each
/// tool-call scope in `state`, keyed by call id.
///
/// Entries that fail to deserialize are skipped rather than reported.
pub fn suspended_calls_from_state(state: &Value) -> HashMap<String, SuspendedCall> {
    collect_tool_call_scoped(state, "suspended_call", |value| {
        SuspendedCallState::from_value(value).ok().map(|slot| slot.call)
    })
}

/// Extracts every [`ToolCallState`] stored under the `tool_call_state` path of each
/// tool-call scope in `state`, keyed by call id.
///
/// Entries that fail to deserialize are skipped rather than reported.
pub fn tool_call_states_from_state(state: &Value) -> HashMap<String, ToolCallState> {
    collect_tool_call_scoped(state, "tool_call_state", |value| {
        ToolCallState::from_value(value).ok()
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn suspended_call_state_default() {
        let SuspendedCallState { call } = SuspendedCallState::default();
        assert_eq!(call.call_id, "");
        assert_eq!(call.tool_name, "");
    }

    #[test]
    fn tool_call_status_transitions_match_lifecycle() {
        use super::ToolCallStatus as S;
        let expected_allowed = [
            (S::New, S::Running),
            (S::Running, S::Suspended),
            (S::Suspended, S::Resuming),
            (S::Resuming, S::Running),
            (S::Resuming, S::Failed),
            (S::Running, S::Succeeded),
            (S::Running, S::Failed),
            (S::Suspended, S::Cancelled),
        ];
        for (from, to) in expected_allowed {
            assert!(from.can_transition_to(to), "{from:?} -> {to:?} must be allowed");
        }
    }

    #[test]
    fn tool_call_status_rejects_terminal_reopen_transitions() {
        use super::ToolCallStatus as S;
        let reopenings = [
            (S::Succeeded, S::Running),
            (S::Failed, S::Resuming),
            (S::Cancelled, S::Suspended),
        ];
        for (from, to) in reopenings {
            assert!(!from.can_transition_to(to), "{from:?} -> {to:?} must be rejected");
        }
    }

    #[test]
    fn suspended_call_serde_flatten_roundtrip() {
        use crate::runtime::tool_call::Suspension;

        let arguments = json!({"key": "val"});
        let ticket = SuspendTicket::new(
            Suspension::new("susp_1", "confirm"),
            PendingToolCall::new("pending_1", "my_tool", arguments.clone()),
            ToolCallResumeMode::UseDecisionAsToolResult,
        );
        let call = SuspendedCall {
            call_id: "call_1".into(),
            tool_name: "my_tool".into(),
            arguments,
            ticket,
        };

        let encoded = serde_json::to_value(&call).unwrap();
        assert!(encoded.get("ticket").is_none(), "ticket should be flattened");
        for key in ["suspension", "pending", "resume_mode"] {
            assert!(encoded.get(key).is_some(), "{key} should be at top level");
        }
        assert_eq!(encoded["call_id"], "call_1");
        assert_eq!(encoded["suspension"]["id"], "susp_1");
        assert_eq!(encoded["pending"]["id"], "pending_1");

        let decoded: SuspendedCall = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded, call);
    }

    #[test]
    fn tool_call_ascii_state_machine_contains_all_states() {
        let diagram = ToolCallStatus::ASCII_STATE_MACHINE;
        let states = [
            "new",
            "running",
            "suspended",
            "resuming",
            "succeeded",
            "failed",
            "cancelled",
        ];
        for state in states {
            assert!(diagram.contains(state), "diagram is missing `{state}`");
        }
    }
}