// aster/agents/subagent_execution_tool/notification_events.rs
use serde::{Deserialize, Serialize};
use serde_json::Value;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub enum TaskStatus {
6 Pending,
7 Running,
8 Completed,
9 Failed,
10}
11
12impl std::fmt::Display for TaskStatus {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 match self {
15 TaskStatus::Pending => write!(f, "Pending"),
16 TaskStatus::Running => write!(f, "Running"),
17 TaskStatus::Completed => write!(f, "Completed"),
18 TaskStatus::Failed => write!(f, "Failed"),
19 }
20 }
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(tag = "subtype")]
25pub enum TaskExecutionNotificationEvent {
26 #[serde(rename = "line_output")]
27 LineOutput { task_id: String, output: String },
28 #[serde(rename = "tasks_update")]
29 TasksUpdate {
30 stats: TaskExecutionStats,
31 tasks: Vec<TaskInfo>,
32 },
33 #[serde(rename = "tasks_complete")]
34 TasksComplete {
35 stats: TaskCompletionStats,
36 failed_tasks: Vec<FailedTaskInfo>,
37 },
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct TaskExecutionStats {
42 pub total: usize,
43 pub pending: usize,
44 pub running: usize,
45 pub completed: usize,
46 pub failed: usize,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct TaskCompletionStats {
51 pub total: usize,
52 pub completed: usize,
53 pub failed: usize,
54 pub success_rate: f64,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct TaskInfo {
59 pub id: String,
60 pub status: TaskStatus,
61 pub duration_secs: Option<f64>,
62 pub current_output: String,
63 pub task_type: String,
64 pub task_name: String,
65 pub task_metadata: String,
66 pub error: Option<String>,
67 pub result_data: Option<Value>,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct FailedTaskInfo {
72 pub id: String,
73 pub name: String,
74 pub error: Option<String>,
75}
76
77impl TaskExecutionNotificationEvent {
78 pub fn line_output(task_id: String, output: String) -> Self {
79 Self::LineOutput { task_id, output }
80 }
81
82 pub fn tasks_update(stats: TaskExecutionStats, tasks: Vec<TaskInfo>) -> Self {
83 Self::TasksUpdate { stats, tasks }
84 }
85
86 pub fn tasks_complete(stats: TaskCompletionStats, failed_tasks: Vec<FailedTaskInfo>) -> Self {
87 Self::TasksComplete {
88 stats,
89 failed_tasks,
90 }
91 }
92
93 pub fn to_notification_data(&self) -> serde_json::Value {
95 let mut event_data = serde_json::to_value(self).expect("Failed to serialize event");
96
97 if let serde_json::Value::Object(ref mut map) = event_data {
99 map.insert(
100 "type".to_string(),
101 serde_json::Value::String("task_execution".to_string()),
102 );
103 }
104
105 event_data
106 }
107}
108
109impl TaskExecutionStats {
110 pub fn new(
111 total: usize,
112 pending: usize,
113 running: usize,
114 completed: usize,
115 failed: usize,
116 ) -> Self {
117 Self {
118 total,
119 pending,
120 running,
121 completed,
122 failed,
123 }
124 }
125}
126
127impl TaskCompletionStats {
128 pub fn new(total: usize, completed: usize, failed: usize) -> Self {
129 let success_rate = if total > 0 {
130 (completed as f64 / total as f64) * 100.0
131 } else {
132 0.0
133 };
134
135 Self {
136 total,
137 completed,
138 failed,
139 success_rate,
140 }
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// A `LineOutput` event exposes its fields plus the `type`/`subtype` tags.
    #[test]
    fn test_line_output_event_serialization() {
        let event = TaskExecutionNotificationEvent::line_output(
            "task-1".to_string(),
            "Hello World".to_string(),
        );

        let notification_data = event.to_notification_data();
        assert_eq!(notification_data["type"], "task_execution");
        assert_eq!(notification_data["subtype"], "line_output");
        assert_eq!(notification_data["task_id"], "task-1");
        assert_eq!(notification_data["output"], "Hello World");
    }

    /// A `TasksUpdate` event nests its stats and task list under their field names.
    #[test]
    fn test_tasks_update_event_serialization() {
        let stats = TaskExecutionStats::new(5, 2, 1, 1, 1);
        let tasks = vec![TaskInfo {
            id: "task-1".to_string(),
            status: TaskStatus::Running,
            duration_secs: Some(1.5),
            current_output: "Processing...".to_string(),
            task_type: "sub_recipe".to_string(),
            task_name: "test-task".to_string(),
            task_metadata: "param=value".to_string(),
            error: None,
            result_data: None,
        }];

        let event = TaskExecutionNotificationEvent::tasks_update(stats, tasks);
        let notification_data = event.to_notification_data();

        assert_eq!(notification_data["type"], "task_execution");
        assert_eq!(notification_data["subtype"], "tasks_update");
        assert_eq!(notification_data["stats"]["total"], 5);
        assert_eq!(notification_data["tasks"].as_array().unwrap().len(), 1);
    }

    /// A `TasksComplete` event carries the computed success rate and the
    /// failed-task entries.
    #[test]
    fn test_tasks_complete_event_serialization() {
        let stats = TaskCompletionStats::new(4, 3, 1);
        let failed_tasks = vec![FailedTaskInfo {
            id: "task-4".to_string(),
            name: "broken-task".to_string(),
            error: Some("boom".to_string()),
        }];

        let event = TaskExecutionNotificationEvent::tasks_complete(stats, failed_tasks);
        let notification_data = event.to_notification_data();

        assert_eq!(notification_data["type"], "task_execution");
        assert_eq!(notification_data["subtype"], "tasks_complete");
        assert_eq!(notification_data["stats"]["total"], 4);
        assert_eq!(notification_data["stats"]["completed"], 3);
        assert_eq!(notification_data["stats"]["failed"], 1);
        assert_eq!(notification_data["stats"]["success_rate"], 75.0);

        let failed = notification_data["failed_tasks"].as_array().unwrap();
        assert_eq!(failed.len(), 1);
        assert_eq!(failed[0]["id"], "task-4");
        assert_eq!(failed[0]["name"], "broken-task");
        assert_eq!(failed[0]["error"], "boom");
    }

    /// With no tasks at all the success rate falls back to `0.0` rather than
    /// dividing by zero.
    #[test]
    fn test_completion_stats_zero_total() {
        let stats = TaskCompletionStats::new(0, 0, 0);
        assert_eq!(stats.success_rate, 0.0);
    }

    /// Serializing and then deserializing a `LineOutput` event yields the same fields.
    #[test]
    fn test_event_roundtrip_serialization() {
        let original_event = TaskExecutionNotificationEvent::line_output(
            "task-1".to_string(),
            "Test output".to_string(),
        );

        // The serialized value is owned and not needed afterwards, so it is
        // edited in place instead of being cloned first.
        let mut event_data = original_event.to_notification_data();
        if let serde_json::Value::Object(ref mut map) = event_data {
            // Strip the `type` key added by `to_notification_data` before
            // feeding the object back to the deserializer.
            map.remove("type");
        }

        let deserialized_event: TaskExecutionNotificationEvent =
            serde_json::from_value(event_data).expect("Failed to deserialize");

        match (original_event, deserialized_event) {
            (
                TaskExecutionNotificationEvent::LineOutput {
                    task_id: id1,
                    output: out1,
                },
                TaskExecutionNotificationEvent::LineOutput {
                    task_id: id2,
                    output: out2,
                },
            ) => {
                assert_eq!(id1, id2);
                assert_eq!(out1, out2);
            }
            _ => panic!("Event types don't match after roundtrip"),
        }
    }
}