use a2a_protocol_types::message::{Part, PartContent};
use a2a_protocol_types::task::{ContextId, TaskId, TaskState};
use proptest::prelude::*;
fn arb_task_state() -> impl Strategy<Value = TaskState> {
prop_oneof![
Just(TaskState::Unspecified),
Just(TaskState::Submitted),
Just(TaskState::Working),
Just(TaskState::InputRequired),
Just(TaskState::AuthRequired),
Just(TaskState::Completed),
Just(TaskState::Failed),
Just(TaskState::Canceled),
Just(TaskState::Rejected),
]
}
proptest! {
    // Encoding a state as JSON and decoding it again must give back the
    // original value.
    #[test]
    fn task_state_roundtrip(state in arb_task_state()) {
        let encoded = serde_json::to_string(&state).unwrap();
        let decoded = serde_json::from_str::<TaskState>(&encoded).unwrap();
        prop_assert_eq!(state, decoded);
    }

    // `is_terminal()` must be true for exactly the four end states listed
    // below and false for every other variant.
    #[test]
    fn task_state_terminal_classification(state in arb_task_state()) {
        const TERMINAL_STATES: [TaskState; 4] = [
            TaskState::Completed,
            TaskState::Failed,
            TaskState::Canceled,
            TaskState::Rejected,
        ];
        // Computed before calling `is_terminal()` so `state` is still borrowable.
        let expected_terminal = TERMINAL_STATES.contains(&state);
        prop_assert_eq!(state.is_terminal(), expected_terminal);
    }

    // The JSON form of every state is a quoted kebab-case name from this list.
    #[test]
    fn task_state_wire_format(state in arb_task_state()) {
        const WIRE_NAMES: [&str; 9] = [
            "unspecified",
            "submitted",
            "working",
            "input-required",
            "auth-required",
            "completed",
            "failed",
            "canceled",
            "rejected",
        ];
        let encoded = serde_json::to_string(&state).unwrap();
        // Strip the surrounding JSON string quotes to get the bare name.
        let name = encoded.trim_matches('"');
        let recognised = WIRE_NAMES.iter().any(|&candidate| candidate == name);
        prop_assert!(recognised, "got: {}", name);
    }
}
fn arb_text_part() -> impl Strategy<Value = Part> {
".*".prop_map(Part::text)
}
/// Strategy producing parts built with `Part::raw`.
///
/// Payloads are 0-100 characters drawn from `[a-zA-Z0-9+/=]` (the base64
/// alphabet). The strings are not guaranteed to be well-formed base64
/// (length and padding are unconstrained).
fn arb_raw_part() -> impl Strategy<Value = Part> {
    const RAW_PAYLOAD_PATTERN: &str = "[a-zA-Z0-9+/=]{0,100}";
    RAW_PAYLOAD_PATTERN.prop_map(Part::raw)
}
/// Strategy producing parts built with `Part::url`.
///
/// URLs have the form `http(s)://<1-20 lowercase>.<2-4 lowercase>/<0-20 lowercase>`.
fn arb_url_part() -> impl Strategy<Value = Part> {
    let url_pattern = r"https?://[a-z]{1,20}\.[a-z]{2,4}/[a-z]{0,20}";
    url_pattern.prop_map(Part::url)
}
/// Strategy choosing uniformly among text, raw and URL parts.
fn arb_part() -> impl Strategy<Value = Part> {
    // Equal weights; this is what the unweighted `prop_oneof!` form expands to.
    prop_oneof![
        1 => arb_text_part(),
        1 => arb_raw_part(),
        1 => arb_url_part(),
    ]
}
proptest! {
    // A `Part` survives a JSON round trip with the same content variant.
    // For text parts the text is compared; for file parts only the `bytes`
    // and `uri` fields are compared.
    #[test]
    fn part_roundtrip(part in arb_part()) {
        let encoded = serde_json::to_string(&part).unwrap();
        let decoded = serde_json::from_str::<Part>(&encoded).unwrap();
        let pair = (&part.content, &decoded.content);

        if let (PartContent::Text { text: lhs }, PartContent::Text { text: rhs }) = pair {
            prop_assert_eq!(lhs, rhs);
        } else if let (PartContent::File { file: lhs }, PartContent::File { file: rhs }) = pair {
            prop_assert_eq!(&lhs.bytes, &rhs.bytes);
            prop_assert_eq!(&lhs.uri, &rhs.uri);
        } else {
            // Any other pairing means the variant changed during the round trip,
            // or is one this test does not check.
            prop_assert!(false, "content type mismatch");
        }
    }
}
/// Regex for generated task/context identifiers: 1-50 ASCII alphanumerics,
/// underscores or hyphens.
const ID_PATTERN: &str = "[a-zA-Z0-9_-]{1,50}";

proptest! {
    // Displaying a `TaskId` reproduces the string it was created from.
    #[test]
    fn task_id_display(raw in ID_PATTERN) {
        prop_assert_eq!(TaskId::new(&raw).to_string(), raw);
    }

    // Displaying a `ContextId` reproduces the string it was created from.
    #[test]
    fn context_id_display(raw in ID_PATTERN) {
        prop_assert_eq!(ContextId::new(&raw).to_string(), raw);
    }

    // Two `TaskId`s built from the same string compare equal.
    #[test]
    fn task_id_equality(raw in ID_PATTERN) {
        prop_assert_eq!(TaskId::new(&raw), TaskId::new(&raw));
    }

    // An all-lowercase string and an all-uppercase string can never be equal.
    // The test expects `TaskId`s built from them to compare unequal.
    #[test]
    fn task_id_inequality(lower in "[a-z]{1,10}", upper in "[A-Z]{1,10}") {
        prop_assert_ne!(TaskId::new(&lower), TaskId::new(&upper));
    }
}