Skip to main content

awaken_runtime_contract/contract/
progress.rs

//! Canonical progress and file activity types for tool call execution.
2
3use serde::{Deserialize, Serialize};
4
/// Activity type identifier string for tool call progress activities.
pub const TOOL_CALL_PROGRESS_ACTIVITY_TYPE: &str = "tool-call-progress";
7
8/// Canonical progress state for a tool call execution.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ToolCallProgressState {
11    /// Schema identifier.
12    #[serde(default = "default_schema")]
13    pub schema: String,
14    /// Unique node ID for this progress entry (typically tool_call_id).
15    pub node_id: String,
16    /// Tool call ID.
17    pub call_id: String,
18    /// Tool name.
19    pub tool_name: String,
20    /// Current status.
21    pub status: ProgressStatus,
22    /// Normalized progress (0.0 - 1.0). None if indeterminate.
23    #[serde(default, skip_serializing_if = "Option::is_none")]
24    pub progress: Option<f64>,
25    /// Absolute progress loaded count.
26    #[serde(default, skip_serializing_if = "Option::is_none")]
27    pub loaded: Option<u64>,
28    /// Absolute progress total count.
29    #[serde(default, skip_serializing_if = "Option::is_none")]
30    pub total: Option<u64>,
31    /// Human-readable status message.
32    #[serde(default, skip_serializing_if = "Option::is_none")]
33    pub message: Option<String>,
34    /// Parent node ID (for nested tool calls).
35    #[serde(default, skip_serializing_if = "Option::is_none")]
36    pub parent_node_id: Option<String>,
37    /// Parent tool call ID.
38    #[serde(default, skip_serializing_if = "Option::is_none")]
39    pub parent_call_id: Option<String>,
40    /// Run ID of the owning agent run.
41    #[serde(default, skip_serializing_if = "Option::is_none")]
42    pub run_id: Option<String>,
43    /// Parent run ID (set when this run was spawned by another run).
44    #[serde(default, skip_serializing_if = "Option::is_none")]
45    pub parent_run_id: Option<String>,
46    /// Thread ID of the owning thread.
47    #[serde(default, skip_serializing_if = "Option::is_none")]
48    pub thread_id: Option<String>,
49}
50
/// Returns the schema identifier used when a serialized state omits `schema`.
fn default_schema() -> String {
    String::from("tool-call-progress.v1")
}
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
56#[serde(rename_all = "snake_case")]
57pub enum ProgressStatus {
58    Pending,
59    Running,
60    Done,
61    Failed,
62    Cancelled,
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Every `ProgressStatus` variant, in declaration order.
    const ALL_STATUSES: [ProgressStatus; 5] = [
        ProgressStatus::Pending,
        ProgressStatus::Running,
        ProgressStatus::Done,
        ProgressStatus::Failed,
        ProgressStatus::Cancelled,
    ];

    /// Builds a state with the v1 schema, the given required fields, and every
    /// optional field set to `None`. Tests override the optional fields they
    /// care about with struct-update syntax.
    fn minimal_state(
        node_id: &str,
        call_id: &str,
        tool_name: &str,
        status: ProgressStatus,
    ) -> ToolCallProgressState {
        ToolCallProgressState {
            schema: "tool-call-progress.v1".into(),
            node_id: node_id.into(),
            call_id: call_id.into(),
            tool_name: tool_name.into(),
            status,
            progress: None,
            loaded: None,
            total: None,
            message: None,
            parent_node_id: None,
            parent_call_id: None,
            run_id: None,
            parent_run_id: None,
            thread_id: None,
        }
    }

    #[test]
    fn progress_state_serde_roundtrip() {
        let state = ToolCallProgressState {
            progress: Some(0.5),
            loaded: Some(50),
            total: Some(100),
            message: Some("Searching...".into()),
            ..minimal_state("call-1", "call-1", "search", ProgressStatus::Running)
        };
        let json = serde_json::to_string(&state).unwrap();
        let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.node_id, "call-1");
        assert_eq!(parsed.status, ProgressStatus::Running);
        assert_eq!(parsed.progress, Some(0.5));
        assert_eq!(parsed.loaded, Some(50));
        assert_eq!(parsed.total, Some(100));
        assert_eq!(parsed.message.as_deref(), Some("Searching..."));
    }

    #[test]
    fn progress_state_default_schema() {
        // No "schema" key: the serde default must fill it in.
        let json_str = r#"{
            "node_id": "n1",
            "call_id": "c1",
            "tool_name": "t1",
            "status": "pending"
        }"#;
        let parsed: ToolCallProgressState = serde_json::from_str(json_str).unwrap();
        assert_eq!(parsed.schema, "tool-call-progress.v1");
    }

    #[test]
    fn progress_state_omits_none_fields() {
        let state = minimal_state("n1", "c1", "t1", ProgressStatus::Pending);
        let value: serde_json::Value = serde_json::to_value(&state).unwrap();
        let obj = value.as_object().unwrap();
        let optional_keys = [
            "progress",
            "loaded",
            "total",
            "message",
            "parent_node_id",
            "parent_call_id",
            "run_id",
            "parent_run_id",
            "thread_id",
        ];
        for key in optional_keys {
            assert!(!obj.contains_key(key), "Unexpected key: {key}");
        }
    }

    #[test]
    fn progress_status_all_variants_roundtrip() {
        for status in ALL_STATUSES {
            let json = serde_json::to_value(status).unwrap();
            let parsed: ProgressStatus = serde_json::from_value(json).unwrap();
            assert_eq!(parsed, status);
        }
    }

    #[test]
    fn progress_status_snake_case_serialization() {
        let expected_wire_forms = [
            (ProgressStatus::Pending, "pending"),
            (ProgressStatus::Running, "running"),
            (ProgressStatus::Done, "done"),
            (ProgressStatus::Failed, "failed"),
            (ProgressStatus::Cancelled, "cancelled"),
        ];
        for (status, wire) in expected_wire_forms {
            assert_eq!(
                serde_json::to_value(status).unwrap(),
                json!(wire),
                "Wrong serialization for {status:?}"
            );
        }
    }

    #[test]
    fn progress_state_with_parent_fields() {
        let state = ToolCallProgressState {
            parent_node_id: Some("parent-1".into()),
            parent_call_id: Some("parent-1".into()),
            ..minimal_state("child-1", "child-1", "sub_tool", ProgressStatus::Running)
        };
        let json = serde_json::to_string(&state).unwrap();
        assert!(json.contains("parent_node_id"));
        assert!(json.contains("parent_call_id"));
        let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.parent_node_id.as_deref(), Some("parent-1"));
        assert_eq!(parsed.parent_call_id.as_deref(), Some("parent-1"));
    }

    #[test]
    fn progress_state_lineage_fields_roundtrip() {
        let state = ToolCallProgressState {
            parent_node_id: Some("run:run-1".into()),
            run_id: Some("run-1".into()),
            parent_run_id: Some("run-0".into()),
            thread_id: Some("thread-abc".into()),
            ..minimal_state(
                "tool_call:call-42",
                "call-42",
                "search",
                ProgressStatus::Running,
            )
        };
        let json = serde_json::to_string(&state).unwrap();
        assert!(json.contains("run_id"));
        assert!(json.contains("parent_run_id"));
        assert!(json.contains("thread_id"));
        let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.run_id.as_deref(), Some("run-1"));
        assert_eq!(parsed.parent_run_id.as_deref(), Some("run-0"));
        assert_eq!(parsed.thread_id.as_deref(), Some("thread-abc"));
    }

    #[test]
    fn activity_type_constants() {
        assert_eq!(TOOL_CALL_PROGRESS_ACTIVITY_TYPE, "tool-call-progress");
    }

    #[test]
    fn progress_status_all_variants_have_distinct_serialization() {
        use std::collections::HashSet;

        let mut seen = HashSet::new();
        for variant in ALL_STATUSES {
            let serialized = serde_json::to_string(&variant).unwrap();
            assert!(
                seen.insert(serialized.clone()),
                "Duplicate serialization: {serialized} for {variant:?}"
            );
            let parsed: ProgressStatus = serde_json::from_str(&serialized).unwrap();
            assert_eq!(parsed, variant, "Roundtrip failed for {variant:?}");
        }
        assert_eq!(
            seen.len(),
            ALL_STATUSES.len(),
            "Expected 5 distinct serialized strings"
        );
    }

    #[test]
    fn progress_state_with_all_fields_populated() {
        let state = ToolCallProgressState {
            schema: "tool-call-progress.v1".into(),
            node_id: "tool_call:call-99".into(),
            call_id: "call-99".into(),
            tool_name: "complex_tool".into(),
            status: ProgressStatus::Running,
            progress: Some(0.75),
            loaded: Some(750),
            total: Some(1000),
            message: Some("Processing batch 3 of 4".into()),
            parent_node_id: Some("run:parent-run".into()),
            parent_call_id: Some("parent-call-1".into()),
            run_id: Some("run-42".into()),
            parent_run_id: Some("run-41".into()),
            thread_id: Some("thread-xyz".into()),
        };

        let value: serde_json::Value = serde_json::to_value(&state).unwrap();
        let obj = value.as_object().unwrap();

        // Verify all fields are present in serialized output.
        let expected_keys = [
            "schema",
            "node_id",
            "call_id",
            "tool_name",
            "status",
            "progress",
            "loaded",
            "total",
            "message",
            "parent_node_id",
            "parent_call_id",
            "run_id",
            "parent_run_id",
            "thread_id",
        ];
        for key in expected_keys {
            assert!(obj.contains_key(key), "Missing key: {key}");
        }

        // Roundtrip check.
        let json = serde_json::to_string(&state).unwrap();
        let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.schema, "tool-call-progress.v1");
        assert_eq!(parsed.node_id, "tool_call:call-99");
        assert_eq!(parsed.call_id, "call-99");
        assert_eq!(parsed.tool_name, "complex_tool");
        assert_eq!(parsed.status, ProgressStatus::Running);
        assert_eq!(parsed.progress, Some(0.75));
        assert_eq!(parsed.loaded, Some(750));
        assert_eq!(parsed.total, Some(1000));
        assert_eq!(parsed.message.as_deref(), Some("Processing batch 3 of 4"));
        assert_eq!(parsed.parent_node_id.as_deref(), Some("run:parent-run"));
        assert_eq!(parsed.parent_call_id.as_deref(), Some("parent-call-1"));
        assert_eq!(parsed.run_id.as_deref(), Some("run-42"));
        assert_eq!(parsed.parent_run_id.as_deref(), Some("run-41"));
        assert_eq!(parsed.thread_id.as_deref(), Some("thread-xyz"));
    }

    #[test]
    fn progress_state_validates_progress_range() {
        // The progress field has no built-in validation; this test documents
        // that arbitrary finite f64 values serialize and deserialize without error.
        for val in [0.0, 0.5, 1.0, -1.0, 2.0] {
            let state = ToolCallProgressState {
                progress: Some(val),
                ..minimal_state("n", "c", "t", ProgressStatus::Pending)
            };
            let json = serde_json::to_string(&state).unwrap();
            let parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
            assert_eq!(
                parsed.progress,
                Some(val),
                "Roundtrip failed for progress={val}"
            );
        }

        // Non-finite values: serde_json may reject or serialize them depending
        // on the version. Either outcome is accepted; the only requirement is
        // that anything it does emit can be parsed back.
        for non_finite in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let state = ToolCallProgressState {
                progress: Some(non_finite),
                ..minimal_state("n", "c", "t", ProgressStatus::Pending)
            };
            if let Ok(json) = serde_json::to_string(&state) {
                // The value may lose fidelity (e.g. null for NaN), so only
                // parseability is checked, not the recovered value.
                let _parsed: ToolCallProgressState = serde_json::from_str(&json).unwrap();
            }
        }
    }
}