use crate::mcp::{Tool, result_error, result_ok_json, tool_def};
use crate::state::WorkflowState;
use bote::ToolDef as BoteToolDef;
use serde_json::json;
use std::pin::Pin;
/// Canonical table pairing every wire-format state name with its
/// [`WorkflowState`] variant.
///
/// The tools in this module iterate this slice in order, so the order here is
/// also the order in which transition targets are reported.
fn all_workflow_states() -> &'static [(&'static str, WorkflowState)] {
    const TABLE: &[(&str, WorkflowState)] = &[
        ("created", WorkflowState::Created),
        ("running", WorkflowState::Running),
        ("paused", WorkflowState::Paused),
        ("completed", WorkflowState::Completed),
        ("failed", WorkflowState::Failed),
        ("rolling_back", WorkflowState::RollingBack),
        ("rolled_back", WorkflowState::RolledBack),
        ("cancelled", WorkflowState::Cancelled),
    ];
    TABLE
}
/// Resolves a wire-format state name to its [`WorkflowState`].
///
/// The comparison is exact (case-sensitive). Returns `None` when `s` does not
/// name any entry in the canonical state table.
fn parse_state(s: &str) -> Option<WorkflowState> {
    for &(name, state) in all_workflow_states() {
        if name == s {
            return Some(state);
        }
    }
    None
}
/// MCP tool `szal_state_check`: reports whether a single workflow state is
/// terminal and lists the states it may transition to.
pub struct StateCheck;

impl Tool for StateCheck {
    fn definition(&self) -> BoteToolDef {
        // Build the schema's `enum` from the canonical state table instead of a
        // hand-written list, so the advertised values can never drift from what
        // `parse_state` actually accepts. Order matches the table.
        let state_names: Vec<&str> = all_workflow_states().iter().map(|(n, _)| *n).collect();
        tool_def(
            "szal_state_check",
            "Check if a workflow state is terminal and list its valid transitions",
            json!({ "state": { "type": "string", "enum": state_names } }),
            vec!["state".into()],
        )
    }

    /// Expects `{"state": "<name>"}`.
    ///
    /// On success, returns a JSON object holding the echoed state name, its
    /// `is_terminal` flag, and `valid_transitions`, the names of every state
    /// for which `valid_transition` returns true, in table order. Returns an
    /// error result when `state` is absent or not a string, or when it is not
    /// a known state name.
    fn call(
        &self,
        args: serde_json::Value,
    ) -> Pin<Box<dyn std::future::Future<Output = serde_json::Value> + Send + '_>> {
        Box::pin(async move {
            let state_str = match args.get("state").and_then(|v| v.as_str()) {
                Some(s) => s,
                None => return result_error("missing required field: state"),
            };
            let state = match parse_state(state_str) {
                Some(s) => s,
                None => return result_error(format!("invalid state: {state_str}")),
            };
            let valid_targets: Vec<&str> = all_workflow_states()
                .iter()
                .filter(|(_, s)| state.valid_transition(s))
                .map(|(n, _)| *n)
                .collect();
            result_ok_json(&json!({
                "state": state_str,
                "is_terminal": state.is_terminal(),
                "valid_transitions": valid_targets,
            }))
        })
    }
}
pub struct StateTransition;
impl Tool for StateTransition {
fn definition(&self) -> BoteToolDef {
tool_def(
"szal_state_transition",
"Check if a state transition from one state to another is valid",
json!({
"from": { "type": "string" },
"to": { "type": "string" }
}),
vec!["from".into(), "to".into()],
)
}
fn call(
&self,
args: serde_json::Value,
) -> Pin<Box<dyn std::future::Future<Output = serde_json::Value> + Send + '_>> {
Box::pin(async move {
let from = match args
.get("from")
.and_then(|v| v.as_str())
.and_then(parse_state)
{
Some(s) => s,
None => return result_error("missing or invalid 'from' state"),
};
let to = match args
.get("to")
.and_then(|v| v.as_str())
.and_then(parse_state)
{
Some(s) => s,
None => return result_error("missing or invalid 'to' state"),
};
result_ok_json(&json!({
"from": args["from"],
"to": args["to"],
"valid": from.valid_transition(&to),
}))
})
}
}
/// MCP tool `szal_state_lifecycle`: dumps the entire workflow state machine.
pub struct StateLifecycle;

impl Tool for StateLifecycle {
    fn definition(&self) -> BoteToolDef {
        tool_def(
            "szal_state_lifecycle",
            "Show the complete workflow state machine — all states, transitions, and terminal states",
            json!({}),
            vec![],
        )
    }

    /// Ignores its arguments and returns a JSON array with one object per known
    /// state, in table order. Each object holds the state's name, its
    /// `is_terminal` flag, and `transitions_to`, the names of every state for
    /// which `valid_transition` returns true.
    fn call(
        &self,
        _args: serde_json::Value,
    ) -> Pin<Box<dyn std::future::Future<Output = serde_json::Value> + Send + '_>> {
        Box::pin(async {
            let table = all_workflow_states();
            let mut entries: Vec<serde_json::Value> = Vec::with_capacity(table.len());
            for (name, current) in table {
                let reachable: Vec<&str> = table
                    .iter()
                    .filter(|(_, candidate)| current.valid_transition(candidate))
                    .map(|(target, _)| *target)
                    .collect();
                entries.push(json!({
                    "state": name,
                    "is_terminal": current.is_terminal(),
                    "transitions_to": reachable,
                }));
            }
            result_ok_json(&json!(entries))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the JSON text carried in the first content item of a tool result,
    /// so assertions do not depend on how `result_ok_json` formats whitespace.
    fn payload(result: &serde_json::Value) -> serde_json::Value {
        let text = result["content"][0]["text"]
            .as_str()
            .expect("tool result should carry a text payload");
        serde_json::from_str(text).expect("tool payload should be valid JSON")
    }

    #[tokio::test]
    async fn state_check_running() {
        let result = StateCheck.call(json!({"state": "running"})).await;
        assert_eq!(result["isError"], false);
        let body = payload(&result);
        assert_eq!(body["state"], "running");
        assert_eq!(body["is_terminal"], false);
        assert!(body["valid_transitions"].is_array());
    }

    #[tokio::test]
    async fn state_check_completed() {
        let result = StateCheck.call(json!({"state": "completed"})).await;
        assert_eq!(result["isError"], false);
        assert_eq!(payload(&result)["is_terminal"], true);
    }

    #[tokio::test]
    async fn state_check_invalid() {
        let result = StateCheck.call(json!({"state": "nope"})).await;
        assert_eq!(result["isError"], true);
    }

    #[tokio::test]
    async fn state_check_missing_field() {
        let result = StateCheck.call(json!({})).await;
        assert_eq!(result["isError"], true);
    }

    #[tokio::test]
    async fn state_check_agrees_with_transition_tool() {
        // created -> running is valid (see state_transition_valid), so the
        // check tool must list "running" among created's targets.
        let result = StateCheck.call(json!({"state": "created"})).await;
        assert_eq!(result["isError"], false);
        let body = payload(&result);
        let targets = body["valid_transitions"]
            .as_array()
            .expect("valid_transitions should be an array");
        assert!(targets.iter().any(|t| t == "running"));
    }

    #[tokio::test]
    async fn state_transition_valid() {
        let result = StateTransition
            .call(json!({"from": "created", "to": "running"}))
            .await;
        assert_eq!(result["isError"], false);
        let body = payload(&result);
        assert_eq!(body["from"], "created");
        assert_eq!(body["to"], "running");
        assert_eq!(body["valid"], true);
    }

    #[tokio::test]
    async fn state_transition_invalid() {
        let result = StateTransition
            .call(json!({"from": "completed", "to": "running"}))
            .await;
        assert_eq!(result["isError"], false);
        assert_eq!(payload(&result)["valid"], false);
    }

    #[tokio::test]
    async fn state_transition_unknown_state() {
        let result = StateTransition
            .call(json!({"from": "bogus", "to": "running"}))
            .await;
        assert_eq!(result["isError"], true);
    }

    #[tokio::test]
    async fn state_transition_missing_to() {
        let result = StateTransition.call(json!({"from": "created"})).await;
        assert_eq!(result["isError"], true);
    }

    #[tokio::test]
    async fn state_lifecycle() {
        let result = StateLifecycle.call(json!({})).await;
        assert_eq!(result["isError"], false);
        let body = payload(&result);
        let states = body.as_array().expect("lifecycle payload should be an array");
        assert_eq!(states.len(), all_workflow_states().len());
        assert_eq!(states.len(), 8);
        for (entry, (name, _)) in states.iter().zip(all_workflow_states()) {
            assert_eq!(entry["state"], *name);
            assert!(entry["transitions_to"].is_array());
        }
    }
}