1use serde::{Deserialize, Serialize};
4
/// The string `"tool-call-progress"`.
///
/// NOTE(review): presumably the activity-type tag under which
/// `ToolCallProgressState` payloads are published; the consumers of this
/// constant are not visible in this file — confirm against callers.
pub const TOOL_CALL_PROGRESS_ACTIVITY_TYPE: &str = "tool-call-progress";
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ToolCallProgressState {
11 #[serde(default = "default_schema")]
13 pub schema: String,
14 pub node_id: String,
16 pub call_id: String,
18 pub tool_name: String,
20 pub status: ProgressStatus,
22 #[serde(default, skip_serializing_if = "Option::is_none")]
24 pub progress: Option<f64>,
25 #[serde(default, skip_serializing_if = "Option::is_none")]
27 pub loaded: Option<u64>,
28 #[serde(default, skip_serializing_if = "Option::is_none")]
30 pub total: Option<u64>,
31 #[serde(default, skip_serializing_if = "Option::is_none")]
33 pub message: Option<String>,
34 #[serde(default, skip_serializing_if = "Option::is_none")]
36 pub parent_node_id: Option<String>,
37 #[serde(default, skip_serializing_if = "Option::is_none")]
39 pub parent_call_id: Option<String>,
40 #[serde(default, skip_serializing_if = "Option::is_none")]
42 pub run_id: Option<String>,
43 #[serde(default, skip_serializing_if = "Option::is_none")]
45 pub parent_run_id: Option<String>,
46 #[serde(default, skip_serializing_if = "Option::is_none")]
48 pub thread_id: Option<String>,
49}
50
/// Serde default for `ToolCallProgressState::schema`: the v1 schema
/// identifier, used when the field is missing from the input.
fn default_schema() -> String {
    String::from("tool-call-progress.v1")
}
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
56#[serde(rename_all = "snake_case")]
57pub enum ProgressStatus {
58 Pending,
59 Running,
60 Done,
61 Failed,
62 Cancelled,
63}
64
/// Unit tests for the serde wire format of `ToolCallProgressState` and
/// `ProgressStatus`.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Every `ProgressStatus` variant, in declaration order.
    const ALL_STATUSES: [ProgressStatus; 5] = [
        ProgressStatus::Pending,
        ProgressStatus::Running,
        ProgressStatus::Done,
        ProgressStatus::Failed,
        ProgressStatus::Cancelled,
    ];

    /// Builds a state with the four required fields set, the v1 schema, and
    /// every optional field `None`. Tests override individual fields with
    /// struct-update syntax.
    fn minimal_state(
        node_id: &str,
        call_id: &str,
        tool_name: &str,
        status: ProgressStatus,
    ) -> ToolCallProgressState {
        ToolCallProgressState {
            schema: "tool-call-progress.v1".into(),
            node_id: node_id.into(),
            call_id: call_id.into(),
            tool_name: tool_name.into(),
            status,
            progress: None,
            loaded: None,
            total: None,
            message: None,
            parent_node_id: None,
            parent_call_id: None,
            run_id: None,
            parent_run_id: None,
            thread_id: None,
        }
    }

    /// Serializes `state` to a JSON string and parses that string back.
    fn roundtrip(state: &ToolCallProgressState) -> ToolCallProgressState {
        let json = serde_json::to_string(state).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    #[test]
    fn progress_state_serde_roundtrip() {
        let state = ToolCallProgressState {
            progress: Some(0.5),
            loaded: Some(50),
            total: Some(100),
            message: Some("Searching...".into()),
            ..minimal_state("call-1", "call-1", "search", ProgressStatus::Running)
        };
        let parsed = roundtrip(&state);
        assert_eq!(parsed.node_id, "call-1");
        assert_eq!(parsed.status, ProgressStatus::Running);
        assert_eq!(parsed.progress, Some(0.5));
        assert_eq!(parsed.loaded, Some(50));
        assert_eq!(parsed.total, Some(100));
        assert_eq!(parsed.message.as_deref(), Some("Searching..."));
    }

    #[test]
    fn progress_state_default_schema() {
        // No "schema" key in the input: the serde default must supply it.
        let json_str = r#"{
            "node_id": "n1",
            "call_id": "c1",
            "tool_name": "t1",
            "status": "pending"
        }"#;
        let parsed: ToolCallProgressState = serde_json::from_str(json_str).unwrap();
        assert_eq!(parsed.schema, "tool-call-progress.v1");
    }

    #[test]
    fn progress_state_omits_none_fields() {
        let state = minimal_state("n1", "c1", "t1", ProgressStatus::Pending);
        let value: serde_json::Value = serde_json::to_value(&state).unwrap();
        let obj = value.as_object().unwrap();
        for key in [
            "progress",
            "loaded",
            "total",
            "message",
            "parent_node_id",
            "parent_call_id",
            "run_id",
            "parent_run_id",
            "thread_id",
        ] {
            assert!(!obj.contains_key(key), "None field was serialized: {key}");
        }
    }

    #[test]
    fn progress_status_all_variants_roundtrip() {
        for status in ALL_STATUSES {
            let json = serde_json::to_value(status).unwrap();
            let parsed: ProgressStatus = serde_json::from_value(json).unwrap();
            assert_eq!(parsed, status);
        }
    }

    #[test]
    fn progress_status_snake_case_serialization() {
        let expected = [
            (ProgressStatus::Pending, json!("pending")),
            (ProgressStatus::Running, json!("running")),
            (ProgressStatus::Done, json!("done")),
            (ProgressStatus::Failed, json!("failed")),
            (ProgressStatus::Cancelled, json!("cancelled")),
        ];
        for (status, wire) in expected {
            assert_eq!(
                serde_json::to_value(status).unwrap(),
                wire,
                "Unexpected wire value for {status:?}"
            );
        }
    }

    #[test]
    fn progress_state_with_parent_fields() {
        let state = ToolCallProgressState {
            parent_node_id: Some("parent-1".into()),
            parent_call_id: Some("parent-1".into()),
            ..minimal_state("child-1", "child-1", "sub_tool", ProgressStatus::Running)
        };
        let json = serde_json::to_string(&state).unwrap();
        assert!(json.contains("parent_node_id"));
        assert!(json.contains("parent_call_id"));
        let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.parent_node_id.as_deref(), Some("parent-1"));
        assert_eq!(parsed.parent_call_id.as_deref(), Some("parent-1"));
    }

    #[test]
    fn progress_state_lineage_fields_roundtrip() {
        let state = ToolCallProgressState {
            parent_node_id: Some("run:run-1".into()),
            run_id: Some("run-1".into()),
            parent_run_id: Some("run-0".into()),
            thread_id: Some("thread-abc".into()),
            ..minimal_state(
                "tool_call:call-42",
                "call-42",
                "search",
                ProgressStatus::Running,
            )
        };
        let json = serde_json::to_string(&state).unwrap();
        assert!(json.contains("run_id"));
        assert!(json.contains("parent_run_id"));
        assert!(json.contains("thread_id"));
        let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.run_id.as_deref(), Some("run-1"));
        assert_eq!(parsed.parent_run_id.as_deref(), Some("run-0"));
        assert_eq!(parsed.thread_id.as_deref(), Some("thread-abc"));
    }

    #[test]
    fn activity_type_constants() {
        assert_eq!(TOOL_CALL_PROGRESS_ACTIVITY_TYPE, "tool-call-progress");
    }

    #[test]
    fn progress_status_all_variants_have_distinct_serialization() {
        use std::collections::HashSet;

        let mut seen = HashSet::new();
        for variant in &ALL_STATUSES {
            let serialized = serde_json::to_string(variant).unwrap();
            assert!(
                seen.insert(serialized.clone()),
                "Duplicate serialization: {serialized} for {variant:?}"
            );
            let parsed: ProgressStatus = serde_json::from_str(&serialized).unwrap();
            assert_eq!(&parsed, variant, "Roundtrip failed for {variant:?}");
        }
        assert_eq!(seen.len(), 5, "Expected 5 distinct serialized strings");
    }

    #[test]
    fn progress_state_with_all_fields_populated() {
        let state = ToolCallProgressState {
            progress: Some(0.75),
            loaded: Some(750),
            total: Some(1000),
            message: Some("Processing batch 3 of 4".into()),
            parent_node_id: Some("run:parent-run".into()),
            parent_call_id: Some("parent-call-1".into()),
            run_id: Some("run-42".into()),
            parent_run_id: Some("run-41".into()),
            thread_id: Some("thread-xyz".into()),
            ..minimal_state(
                "tool_call:call-99",
                "call-99",
                "complex_tool",
                ProgressStatus::Running,
            )
        };

        // With nothing set to `None`, every field must appear on the wire.
        let value: serde_json::Value = serde_json::to_value(&state).unwrap();
        let obj = value.as_object().unwrap();
        for key in [
            "schema",
            "node_id",
            "call_id",
            "tool_name",
            "status",
            "progress",
            "loaded",
            "total",
            "message",
            "parent_node_id",
            "parent_call_id",
            "run_id",
            "parent_run_id",
            "thread_id",
        ] {
            assert!(obj.contains_key(key), "Missing key: {key}");
        }

        let parsed = roundtrip(&state);
        assert_eq!(parsed.schema, "tool-call-progress.v1");
        assert_eq!(parsed.node_id, "tool_call:call-99");
        assert_eq!(parsed.call_id, "call-99");
        assert_eq!(parsed.tool_name, "complex_tool");
        assert_eq!(parsed.status, ProgressStatus::Running);
        assert_eq!(parsed.progress, Some(0.75));
        assert_eq!(parsed.loaded, Some(750));
        assert_eq!(parsed.total, Some(1000));
        assert_eq!(parsed.message.as_deref(), Some("Processing batch 3 of 4"));
        assert_eq!(parsed.parent_node_id.as_deref(), Some("run:parent-run"));
        assert_eq!(parsed.parent_call_id.as_deref(), Some("parent-call-1"));
        assert_eq!(parsed.run_id.as_deref(), Some("run-42"));
        assert_eq!(parsed.parent_run_id.as_deref(), Some("run-41"));
        assert_eq!(parsed.thread_id.as_deref(), Some("thread-xyz"));
    }

    /// Pins how `progress` behaves across its value range. The type performs
    /// no range validation: finite values, including ones outside `0.0..=1.0`,
    /// roundtrip unchanged.
    #[test]
    fn progress_state_validates_progress_range() {
        for val in [0.0_f64, 0.5, 1.0, -1.0, 2.0] {
            let state = ToolCallProgressState {
                progress: Some(val),
                ..minimal_state("n", "c", "t", ProgressStatus::Pending)
            };
            assert_eq!(
                roundtrip(&state).progress,
                Some(val),
                "Roundtrip failed for progress={val}"
            );
        }

        for non_finite in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let state = ToolCallProgressState {
                progress: Some(non_finite),
                ..minimal_state("n", "c", "t", ProgressStatus::Pending)
            };
            // JSON cannot represent NaN or infinities. The serializer may
            // reject the value (accepted here) or emit a substitute; if it
            // emits something, that output must parse and must not bring a
            // non-finite value back.
            if let Ok(json) = serde_json::to_string(&state) {
                let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
                if let Some(value) = parsed.progress {
                    assert!(
                        value.is_finite(),
                        "Non-finite progress {non_finite} survived a JSON roundtrip as {value}"
                    );
                }
            }
        }
    }
}