use crate::TaskState;
use crate::error::A2aTypeError;
/// Validates a single task state change from `from` to `to`.
///
/// Permitted transitions out of each non-terminal state:
///
/// - `Submitted` → `Working`, `Rejected`, `Failed`, `Canceled`
/// - `Working` → `Completed`, `Failed`, `Canceled`, `Rejected`, `InputRequired`,
///   `AuthRequired`
/// - `InputRequired` → `Working`, `Completed`, `Failed`, `Canceled`, `Rejected`
/// - `AuthRequired` → `Working`, `Failed`, `Canceled`, `Rejected`
///   (unlike `InputRequired`, it cannot go directly to `Completed`)
///
/// No state may transition to itself, and no state may move back to `Submitted`.
///
/// # Errors
///
/// - [`A2aTypeError::TerminalState`] carrying `from` when `from` is `Completed`,
///   `Failed`, `Canceled` or `Rejected`; such states accept no target at all.
/// - [`A2aTypeError::InvalidTransition`] carrying `from` as `current` and `to` as
///   `requested` when `from` is non-terminal but `to` is not one of its permitted
///   targets listed above.
pub fn validate_transition(from: TaskState, to: TaskState) -> Result<(), A2aTypeError> {
    let permitted = match from {
        // Terminal states reject every request, regardless of the target.
        TaskState::Completed | TaskState::Failed | TaskState::Canceled | TaskState::Rejected => {
            return Err(A2aTypeError::TerminalState(from));
        }
        TaskState::Submitted => matches!(
            to,
            TaskState::Working | TaskState::Rejected | TaskState::Failed | TaskState::Canceled
        ),
        TaskState::Working => matches!(
            to,
            TaskState::Completed
                | TaskState::Failed
                | TaskState::Canceled
                | TaskState::Rejected
                | TaskState::InputRequired
                | TaskState::AuthRequired
        ),
        TaskState::InputRequired => matches!(
            to,
            TaskState::Working
                | TaskState::Completed
                | TaskState::Failed
                | TaskState::Canceled
                | TaskState::Rejected
        ),
        TaskState::AuthRequired => matches!(
            to,
            TaskState::Working | TaskState::Failed | TaskState::Canceled | TaskState::Rejected
        ),
    };

    if permitted {
        Ok(())
    } else {
        Err(A2aTypeError::InvalidTransition {
            current: from,
            requested: to,
        })
    }
}
/// Reports whether `state` is final: `Completed`, `Failed`, `Canceled` or
/// `Rejected` return `true`; `Submitted`, `Working`, `InputRequired` and
/// `AuthRequired` return `false`.
///
/// Every variant is listed explicitly (no `_` arm), so adding a new `TaskState`
/// variant produces a compile error here until it is classified.
pub fn is_terminal(state: TaskState) -> bool {
    match state {
        TaskState::Completed | TaskState::Failed | TaskState::Canceled | TaskState::Rejected => {
            true
        }
        TaskState::Submitted
        | TaskState::Working
        | TaskState::InputRequired
        | TaskState::AuthRequired => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `TaskState` variant, used to sweep the full `from x to` grid.
    const ALL_STATES: [TaskState; 8] = [
        TaskState::Submitted,
        TaskState::Working,
        TaskState::InputRequired,
        TaskState::AuthRequired,
        TaskState::Completed,
        TaskState::Failed,
        TaskState::Canceled,
        TaskState::Rejected,
    ];

    /// Asserts that `from -> to` is accepted.
    fn assert_valid(from: TaskState, to: TaskState) {
        let result = validate_transition(from, to);
        assert!(
            result.is_ok(),
            "Expected {from:?} -> {to:?} to be valid, got: {result:?}"
        );
    }

    /// Asserts that `from -> to` is rejected specifically as `InvalidTransition`
    /// echoing both states. A bare `is_err()` would also accept a wrongly
    /// returned `TerminalState` or swapped `current`/`requested` payloads.
    fn assert_invalid(from: TaskState, to: TaskState) {
        match validate_transition(from, to) {
            Err(A2aTypeError::InvalidTransition { current, requested }) => {
                assert_eq!(current, from, "wrong `current` for {from:?} -> {to:?}");
                assert_eq!(requested, to, "wrong `requested` for {from:?} -> {to:?}");
            }
            other => panic!(
                "Expected InvalidTransition for {from:?} -> {to:?}, got: {other:?}"
            ),
        }
    }

    #[test]
    fn valid_submitted_transitions() {
        assert_valid(TaskState::Submitted, TaskState::Working);
        assert_valid(TaskState::Submitted, TaskState::Rejected);
        assert_valid(TaskState::Submitted, TaskState::Failed);
        assert_valid(TaskState::Submitted, TaskState::Canceled);
    }

    #[test]
    fn invalid_submitted_transitions() {
        assert_invalid(TaskState::Submitted, TaskState::Completed);
        assert_invalid(TaskState::Submitted, TaskState::InputRequired);
        assert_invalid(TaskState::Submitted, TaskState::AuthRequired);
        assert_invalid(TaskState::Submitted, TaskState::Submitted);
    }

    #[test]
    fn valid_working_transitions() {
        assert_valid(TaskState::Working, TaskState::Completed);
        assert_valid(TaskState::Working, TaskState::Failed);
        assert_valid(TaskState::Working, TaskState::Canceled);
        assert_valid(TaskState::Working, TaskState::Rejected);
        assert_valid(TaskState::Working, TaskState::InputRequired);
        assert_valid(TaskState::Working, TaskState::AuthRequired);
    }

    #[test]
    fn invalid_working_transitions() {
        assert_invalid(TaskState::Working, TaskState::Working);
        assert_invalid(TaskState::Working, TaskState::Submitted);
    }

    #[test]
    fn valid_input_required_transitions() {
        assert_valid(TaskState::InputRequired, TaskState::Working);
        assert_valid(TaskState::InputRequired, TaskState::Completed);
        assert_valid(TaskState::InputRequired, TaskState::Failed);
        assert_valid(TaskState::InputRequired, TaskState::Canceled);
        assert_valid(TaskState::InputRequired, TaskState::Rejected);
    }

    #[test]
    fn invalid_input_required_transitions() {
        assert_invalid(TaskState::InputRequired, TaskState::InputRequired);
        assert_invalid(TaskState::InputRequired, TaskState::AuthRequired);
        assert_invalid(TaskState::InputRequired, TaskState::Submitted);
    }

    #[test]
    fn valid_auth_required_transitions() {
        assert_valid(TaskState::AuthRequired, TaskState::Working);
        assert_valid(TaskState::AuthRequired, TaskState::Failed);
        assert_valid(TaskState::AuthRequired, TaskState::Canceled);
        assert_valid(TaskState::AuthRequired, TaskState::Rejected);
    }

    #[test]
    fn invalid_auth_required_transitions() {
        assert_invalid(TaskState::AuthRequired, TaskState::Completed);
        assert_invalid(TaskState::AuthRequired, TaskState::AuthRequired);
        assert_invalid(TaskState::AuthRequired, TaskState::InputRequired);
        assert_invalid(TaskState::AuthRequired, TaskState::Submitted);
    }

    #[test]
    fn terminal_states_reject_all_transitions() {
        for terminal in [
            TaskState::Completed,
            TaskState::Failed,
            TaskState::Canceled,
            TaskState::Rejected,
        ] {
            for target in ALL_STATES {
                let result = validate_transition(terminal, target);
                assert!(
                    result.is_err(),
                    "Expected error for {terminal:?} -> {target:?}"
                );
                match result.unwrap_err() {
                    A2aTypeError::TerminalState(s) => assert_eq!(s, terminal),
                    other => panic!("Expected TerminalState, got: {other:?}"),
                }
            }
        }
    }

    /// `is_terminal` and `validate_transition` each hard-code the terminal set;
    /// this keeps the two lists from drifting apart.
    #[test]
    fn is_terminal_agrees_with_validate_transition() {
        for from in ALL_STATES {
            for to in ALL_STATES {
                let rejected_as_terminal = matches!(
                    validate_transition(from, to),
                    Err(A2aTypeError::TerminalState(_))
                );
                assert_eq!(
                    rejected_as_terminal,
                    is_terminal(from),
                    "is_terminal({from:?}) disagrees with validate_transition({from:?}, {to:?})"
                );
            }
        }
    }

    #[test]
    fn is_terminal_correct() {
        assert!(!is_terminal(TaskState::Submitted));
        assert!(!is_terminal(TaskState::Working));
        assert!(!is_terminal(TaskState::InputRequired));
        assert!(!is_terminal(TaskState::AuthRequired));
        assert!(is_terminal(TaskState::Completed));
        assert!(is_terminal(TaskState::Failed));
        assert!(is_terminal(TaskState::Canceled));
        assert!(is_terminal(TaskState::Rejected));
    }
}