use serde::{Deserialize, Serialize};
pub const TOOL_CALL_PROGRESS_ACTIVITY_TYPE: &str = "tool-call-progress";
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallProgressState {
#[serde(default = "default_schema")]
pub schema: String,
pub node_id: String,
pub call_id: String,
pub tool_name: String,
pub status: ProgressStatus,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub progress: Option<f64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub loaded: Option<u64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub total: Option<u64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parent_node_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parent_call_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub run_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parent_run_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub thread_id: Option<String>,
}
fn default_schema() -> String {
"tool-call-progress.v1".into()
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProgressStatus {
Pending,
Running,
Done,
Failed,
Cancelled,
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn progress_state_serde_roundtrip() {
let state = ToolCallProgressState {
schema: "tool-call-progress.v1".into(),
node_id: "call-1".into(),
call_id: "call-1".into(),
tool_name: "search".into(),
status: ProgressStatus::Running,
progress: Some(0.5),
loaded: Some(50),
total: Some(100),
message: Some("Searching...".into()),
parent_node_id: None,
parent_call_id: None,
run_id: None,
parent_run_id: None,
thread_id: None,
};
let json = serde_json::to_string(&state).unwrap();
let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.node_id, "call-1");
assert_eq!(parsed.status, ProgressStatus::Running);
assert_eq!(parsed.progress, Some(0.5));
assert_eq!(parsed.loaded, Some(50));
assert_eq!(parsed.total, Some(100));
assert_eq!(parsed.message.as_deref(), Some("Searching..."));
}
#[test]
fn progress_state_default_schema() {
let json_str = r#"{
"node_id": "n1",
"call_id": "c1",
"tool_name": "t1",
"status": "pending"
}"#;
let parsed: ToolCallProgressState = serde_json::from_str(json_str).unwrap();
assert_eq!(parsed.schema, "tool-call-progress.v1");
}
#[test]
fn progress_state_omits_none_fields() {
let state = ToolCallProgressState {
schema: "tool-call-progress.v1".into(),
node_id: "n1".into(),
call_id: "c1".into(),
tool_name: "t1".into(),
status: ProgressStatus::Pending,
progress: None,
loaded: None,
total: None,
message: None,
parent_node_id: None,
parent_call_id: None,
run_id: None,
parent_run_id: None,
thread_id: None,
};
let value: serde_json::Value = serde_json::to_value(&state).unwrap();
let obj = value.as_object().unwrap();
assert!(!obj.contains_key("progress"));
assert!(!obj.contains_key("loaded"));
assert!(!obj.contains_key("total"));
assert!(!obj.contains_key("message"));
assert!(!obj.contains_key("parent_node_id"));
assert!(!obj.contains_key("parent_call_id"));
assert!(!obj.contains_key("run_id"));
assert!(!obj.contains_key("parent_run_id"));
assert!(!obj.contains_key("thread_id"));
}
#[test]
fn progress_status_all_variants_roundtrip() {
for status in [
ProgressStatus::Pending,
ProgressStatus::Running,
ProgressStatus::Done,
ProgressStatus::Failed,
ProgressStatus::Cancelled,
] {
let json = serde_json::to_value(status).unwrap();
let parsed: ProgressStatus = serde_json::from_value(json).unwrap();
assert_eq!(parsed, status);
}
}
#[test]
fn progress_status_snake_case_serialization() {
assert_eq!(
serde_json::to_value(ProgressStatus::Pending).unwrap(),
json!("pending")
);
assert_eq!(
serde_json::to_value(ProgressStatus::Running).unwrap(),
json!("running")
);
assert_eq!(
serde_json::to_value(ProgressStatus::Done).unwrap(),
json!("done")
);
assert_eq!(
serde_json::to_value(ProgressStatus::Failed).unwrap(),
json!("failed")
);
assert_eq!(
serde_json::to_value(ProgressStatus::Cancelled).unwrap(),
json!("cancelled")
);
}
#[test]
fn progress_state_with_parent_fields() {
let state = ToolCallProgressState {
schema: "tool-call-progress.v1".into(),
node_id: "child-1".into(),
call_id: "child-1".into(),
tool_name: "sub_tool".into(),
status: ProgressStatus::Running,
progress: None,
loaded: None,
total: None,
message: None,
parent_node_id: Some("parent-1".into()),
parent_call_id: Some("parent-1".into()),
run_id: None,
parent_run_id: None,
thread_id: None,
};
let json = serde_json::to_string(&state).unwrap();
assert!(json.contains("parent_node_id"));
assert!(json.contains("parent_call_id"));
let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.parent_node_id.as_deref(), Some("parent-1"));
assert_eq!(parsed.parent_call_id.as_deref(), Some("parent-1"));
}
#[test]
fn progress_state_lineage_fields_roundtrip() {
let state = ToolCallProgressState {
schema: "tool-call-progress.v1".into(),
node_id: "tool_call:call-42".into(),
call_id: "call-42".into(),
tool_name: "search".into(),
status: ProgressStatus::Running,
progress: None,
loaded: None,
total: None,
message: None,
parent_node_id: Some("run:run-1".into()),
parent_call_id: None,
run_id: Some("run-1".into()),
parent_run_id: Some("run-0".into()),
thread_id: Some("thread-abc".into()),
};
let json = serde_json::to_string(&state).unwrap();
assert!(json.contains("run_id"));
assert!(json.contains("parent_run_id"));
assert!(json.contains("thread_id"));
let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.run_id.as_deref(), Some("run-1"));
assert_eq!(parsed.parent_run_id.as_deref(), Some("run-0"));
assert_eq!(parsed.thread_id.as_deref(), Some("thread-abc"));
}
#[test]
fn activity_type_constants() {
assert_eq!(TOOL_CALL_PROGRESS_ACTIVITY_TYPE, "tool-call-progress");
}
#[test]
fn progress_status_all_variants_have_distinct_serialization() {
use std::collections::HashSet;
let variants = [
ProgressStatus::Pending,
ProgressStatus::Running,
ProgressStatus::Done,
ProgressStatus::Failed,
ProgressStatus::Cancelled,
];
let mut seen = HashSet::new();
for variant in &variants {
let serialized = serde_json::to_string(variant).unwrap();
assert!(
seen.insert(serialized.clone()),
"Duplicate serialization: {serialized} for {variant:?}"
);
let parsed: ProgressStatus = serde_json::from_str(&serialized).unwrap();
assert_eq!(&parsed, variant, "Roundtrip failed for {variant:?}");
}
assert_eq!(seen.len(), 5, "Expected 5 distinct serialized strings");
}
#[test]
fn progress_state_with_all_fields_populated() {
let state = ToolCallProgressState {
schema: "tool-call-progress.v1".into(),
node_id: "tool_call:call-99".into(),
call_id: "call-99".into(),
tool_name: "complex_tool".into(),
status: ProgressStatus::Running,
progress: Some(0.75),
loaded: Some(750),
total: Some(1000),
message: Some("Processing batch 3 of 4".into()),
parent_node_id: Some("run:parent-run".into()),
parent_call_id: Some("parent-call-1".into()),
run_id: Some("run-42".into()),
parent_run_id: Some("run-41".into()),
thread_id: Some("thread-xyz".into()),
};
let value: serde_json::Value = serde_json::to_value(&state).unwrap();
let obj = value.as_object().unwrap();
let expected_keys = [
"schema",
"node_id",
"call_id",
"tool_name",
"status",
"progress",
"loaded",
"total",
"message",
"parent_node_id",
"parent_call_id",
"run_id",
"parent_run_id",
"thread_id",
];
for key in &expected_keys {
assert!(obj.contains_key(*key), "Missing key: {key}");
}
let json = serde_json::to_string(&state).unwrap();
let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.schema, "tool-call-progress.v1");
assert_eq!(parsed.node_id, "tool_call:call-99");
assert_eq!(parsed.call_id, "call-99");
assert_eq!(parsed.tool_name, "complex_tool");
assert_eq!(parsed.status, ProgressStatus::Running);
assert_eq!(parsed.progress, Some(0.75));
assert_eq!(parsed.loaded, Some(750));
assert_eq!(parsed.total, Some(1000));
assert_eq!(parsed.message.as_deref(), Some("Processing batch 3 of 4"));
assert_eq!(parsed.parent_node_id.as_deref(), Some("run:parent-run"));
assert_eq!(parsed.parent_call_id.as_deref(), Some("parent-call-1"));
assert_eq!(parsed.run_id.as_deref(), Some("run-42"));
assert_eq!(parsed.parent_run_id.as_deref(), Some("run-41"));
assert_eq!(parsed.thread_id.as_deref(), Some("thread-xyz"));
}
#[test]
fn progress_state_validates_progress_range() {
let finite_cases: &[f64] = &[0.0, 0.5, 1.0, -1.0, 2.0];
for &val in finite_cases {
let state = ToolCallProgressState {
schema: "tool-call-progress.v1".into(),
node_id: "n".into(),
call_id: "c".into(),
tool_name: "t".into(),
status: ProgressStatus::Pending,
progress: Some(val),
loaded: None,
total: None,
message: None,
parent_node_id: None,
parent_call_id: None,
run_id: None,
parent_run_id: None,
thread_id: None,
};
let json = serde_json::to_string(&state).unwrap();
let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
assert_eq!(
parsed.progress,
Some(val),
"Roundtrip failed for progress={val}"
);
}
for &non_finite in &[f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
let state = ToolCallProgressState {
schema: "tool-call-progress.v1".into(),
node_id: "n".into(),
call_id: "c".into(),
tool_name: "t".into(),
status: ProgressStatus::Pending,
progress: Some(non_finite),
loaded: None,
total: None,
message: None,
parent_node_id: None,
parent_call_id: None,
run_id: None,
parent_run_id: None,
thread_id: None,
};
match serde_json::to_string(&state) {
Ok(json) => {
let _parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
}
Err(_) => {
}
}
}
}
}