use turul_a2a_types::TaskState;
/// `Debug`-formatted names of the terminal [`TaskState`] variants.
///
/// Each entry is the `format!("{:?}")` rendering of a variant. The test module
/// in this file checks that the set matches `Completed`, `Failed`, `Canceled`
/// and `Rejected`.
pub const DEBUG_TERMINAL_STATES: &[&str] = &[
    "Completed",
    "Failed",
    "Canceled",
    "Rejected",
];
/// Returns the `TASK_STATE_*` wire name for a [`TaskState`].
///
/// Every variant matched explicitly below gets its own name. Any other variant
/// falls back to `"TASK_STATE_UNKNOWN"`. The catch-all is kept because
/// `TaskState` is defined in another crate.
pub fn task_state_wire_name(state: TaskState) -> &'static str {
    match state {
        // Non-terminal states (not listed in `DEBUG_TERMINAL_STATES`).
        TaskState::Submitted => "TASK_STATE_SUBMITTED",
        TaskState::Working => "TASK_STATE_WORKING",
        TaskState::InputRequired => "TASK_STATE_INPUT_REQUIRED",
        TaskState::AuthRequired => "TASK_STATE_AUTH_REQUIRED",

        // Terminal states (the ones listed in `DEBUG_TERMINAL_STATES`).
        TaskState::Completed => "TASK_STATE_COMPLETED",
        TaskState::Failed => "TASK_STATE_FAILED",
        TaskState::Canceled => "TASK_STATE_CANCELED",
        TaskState::Rejected => "TASK_STATE_REJECTED",

        // Fallback for any variant not matched above.
        _ => "TASK_STATE_UNKNOWN",
    }
}
/// Converts a `Debug`-formatted [`TaskState`] name (e.g. `"Completed"`) into its
/// `TASK_STATE_*` wire name.
///
/// Matching is exact and case-sensitive. Any input that is not one of the eight
/// known variant names is returned unchanged. Unlike `task_state_wire_name`,
/// it is *not* mapped to `"TASK_STATE_UNKNOWN"`.
pub fn debug_state_to_wire_name(debug_state: &str) -> String {
    // (Debug name, wire name) pairs, mirroring `task_state_wire_name`.
    const PAIRS: [(&str, &str); 8] = [
        ("Submitted", "TASK_STATE_SUBMITTED"),
        ("Working", "TASK_STATE_WORKING"),
        ("Completed", "TASK_STATE_COMPLETED"),
        ("Failed", "TASK_STATE_FAILED"),
        ("Canceled", "TASK_STATE_CANCELED"),
        ("InputRequired", "TASK_STATE_INPUT_REQUIRED"),
        ("AuthRequired", "TASK_STATE_AUTH_REQUIRED"),
        ("Rejected", "TASK_STATE_REJECTED"),
    ];

    PAIRS
        .iter()
        .find(|(debug, _)| *debug == debug_state)
        .map_or(debug_state, |&(_, wire)| wire)
        .to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The terminal variants, listed once so every test checks the same set.
    fn terminal_states() -> [TaskState; 4] {
        [
            TaskState::Completed,
            TaskState::Failed,
            TaskState::Canceled,
            TaskState::Rejected,
        ]
    }

    #[test]
    fn all_terminal_states_have_wire_names() {
        for state in terminal_states() {
            // Capture the Debug name first: `task_state_wire_name` takes `state` by value.
            let name = format!("{state:?}");
            let wire = task_state_wire_name(state);
            assert!(wire.starts_with("TASK_STATE_"), "{name} -> {wire}");
            // The catch-all arm also starts with `TASK_STATE_`, so the prefix check
            // alone would not detect a terminal state falling through to it.
            assert_ne!(
                wire, "TASK_STATE_UNKNOWN",
                "{name} has no explicit wire name"
            );
        }
    }

    #[test]
    fn debug_to_wire_roundtrips_all_states() {
        let cases = [
            (TaskState::Submitted, "Submitted", "TASK_STATE_SUBMITTED"),
            (TaskState::Working, "Working", "TASK_STATE_WORKING"),
            (TaskState::Completed, "Completed", "TASK_STATE_COMPLETED"),
            (TaskState::Failed, "Failed", "TASK_STATE_FAILED"),
            (TaskState::Canceled, "Canceled", "TASK_STATE_CANCELED"),
            (
                TaskState::InputRequired,
                "InputRequired",
                "TASK_STATE_INPUT_REQUIRED",
            ),
            (
                TaskState::AuthRequired,
                "AuthRequired",
                "TASK_STATE_AUTH_REQUIRED",
            ),
            (TaskState::Rejected, "Rejected", "TASK_STATE_REJECTED"),
        ];
        for (state, debug, wire) in cases {
            assert_eq!(format!("{state:?}"), debug, "Debug format stability");
            assert_eq!(task_state_wire_name(state), wire);
            assert_eq!(debug_state_to_wire_name(debug), wire);
        }
    }

    #[test]
    fn unknown_debug_names_pass_through_unchanged() {
        // Unlike `task_state_wire_name`, the string-based mapping echoes
        // unrecognised input back instead of returning `TASK_STATE_UNKNOWN`.
        for input in ["", "Bogus", "completed", "TASK_STATE_WORKING"] {
            assert_eq!(debug_state_to_wire_name(input), input);
        }
    }

    #[test]
    fn debug_terminal_states_covers_all_terminals() {
        let mut expected: Vec<&str> = DEBUG_TERMINAL_STATES.to_vec();
        let mut actual: Vec<String> = terminal_states()
            .iter()
            .map(|s| format!("{s:?}"))
            .collect();
        // Compare as sorted multisets. Order in the constant is irrelevant, but
        // missing, extra, or duplicated entries all fail with a readable diff.
        expected.sort_unstable();
        actual.sort_unstable();
        assert_eq!(
            actual, expected,
            "DEBUG_TERMINAL_STATES out of sync with terminal TaskState variants"
        );
    }
}