Skip to main content

aster/agents/subagent_execution_tool/
notification_events.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub enum TaskStatus {
6    Pending,
7    Running,
8    Completed,
9    Failed,
10}
11
12impl std::fmt::Display for TaskStatus {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        match self {
15            TaskStatus::Pending => write!(f, "Pending"),
16            TaskStatus::Running => write!(f, "Running"),
17            TaskStatus::Completed => write!(f, "Completed"),
18            TaskStatus::Failed => write!(f, "Failed"),
19        }
20    }
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(tag = "subtype")]
25pub enum TaskExecutionNotificationEvent {
26    #[serde(rename = "line_output")]
27    LineOutput { task_id: String, output: String },
28    #[serde(rename = "tasks_update")]
29    TasksUpdate {
30        stats: TaskExecutionStats,
31        tasks: Vec<TaskInfo>,
32    },
33    #[serde(rename = "tasks_complete")]
34    TasksComplete {
35        stats: TaskCompletionStats,
36        failed_tasks: Vec<FailedTaskInfo>,
37    },
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct TaskExecutionStats {
42    pub total: usize,
43    pub pending: usize,
44    pub running: usize,
45    pub completed: usize,
46    pub failed: usize,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct TaskCompletionStats {
51    pub total: usize,
52    pub completed: usize,
53    pub failed: usize,
54    pub success_rate: f64,
55}
56
/// Description of a single task as listed in the `tasks_update` notification.
///
/// Field names are part of the serialized JSON shape, since no serde renaming
/// is applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskInfo {
    /// Identifier of the task.
    pub id: String,
    /// State the task is reported in.
    pub status: TaskStatus,
    /// Duration in seconds, when one is available.
    /// NOTE(review): whether this is elapsed-so-far or total run time is not
    /// visible here; confirm with the code that builds this struct.
    pub duration_secs: Option<f64>,
    /// Output text for the task (always present, unlike the optional fields).
    pub current_output: String,
    /// Kind of task as a free-form string (the tests in this file use `"sub_recipe"`).
    pub task_type: String,
    /// Name of the task.
    pub task_name: String,
    /// Free-form metadata string (the tests in this file use `"param=value"`).
    pub task_metadata: String,
    /// Error text, if any; serialized as `null` when `None`.
    pub error: Option<String>,
    /// Arbitrary JSON result attached to the task, if any.
    pub result_data: Option<Value>,
}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct FailedTaskInfo {
72    pub id: String,
73    pub name: String,
74    pub error: Option<String>,
75}
76
77impl TaskExecutionNotificationEvent {
78    pub fn line_output(task_id: String, output: String) -> Self {
79        Self::LineOutput { task_id, output }
80    }
81
82    pub fn tasks_update(stats: TaskExecutionStats, tasks: Vec<TaskInfo>) -> Self {
83        Self::TasksUpdate { stats, tasks }
84    }
85
86    pub fn tasks_complete(stats: TaskCompletionStats, failed_tasks: Vec<FailedTaskInfo>) -> Self {
87        Self::TasksComplete {
88            stats,
89            failed_tasks,
90        }
91    }
92
93    /// Convert event to JSON format for MCP notification
94    pub fn to_notification_data(&self) -> serde_json::Value {
95        let mut event_data = serde_json::to_value(self).expect("Failed to serialize event");
96
97        // Add the type field at the root level
98        if let serde_json::Value::Object(ref mut map) = event_data {
99            map.insert(
100                "type".to_string(),
101                serde_json::Value::String("task_execution".to_string()),
102            );
103        }
104
105        event_data
106    }
107}
108
109impl TaskExecutionStats {
110    pub fn new(
111        total: usize,
112        pending: usize,
113        running: usize,
114        completed: usize,
115        failed: usize,
116    ) -> Self {
117        Self {
118            total,
119            pending,
120            running,
121            completed,
122            failed,
123        }
124    }
125}
126
127impl TaskCompletionStats {
128    pub fn new(total: usize, completed: usize, failed: usize) -> Self {
129        let success_rate = if total > 0 {
130            (completed as f64 / total as f64) * 100.0
131        } else {
132            0.0
133        };
134
135        Self {
136            total,
137            completed,
138            failed,
139            success_rate,
140        }
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// `line_output` events expose their fields at the root next to `type`/`subtype`.
    #[test]
    fn test_line_output_event_serialization() {
        let event = TaskExecutionNotificationEvent::line_output(
            "task-1".to_string(),
            "Hello World".to_string(),
        );

        let notification_data = event.to_notification_data();
        assert_eq!(notification_data["type"], "task_execution");
        assert_eq!(notification_data["subtype"], "line_output");
        assert_eq!(notification_data["task_id"], "task-1");
        assert_eq!(notification_data["output"], "Hello World");
    }

    /// `tasks_update` events nest the counters under `stats` and the tasks under `tasks`.
    #[test]
    fn test_tasks_update_event_serialization() {
        let stats = TaskExecutionStats::new(5, 2, 1, 1, 1);
        let tasks = vec![TaskInfo {
            id: "task-1".to_string(),
            status: TaskStatus::Running,
            duration_secs: Some(1.5),
            current_output: "Processing...".to_string(),
            task_type: "sub_recipe".to_string(),
            task_name: "test-task".to_string(),
            task_metadata: "param=value".to_string(),
            error: None,
            result_data: None,
        }];

        let event = TaskExecutionNotificationEvent::tasks_update(stats, tasks);
        let notification_data = event.to_notification_data();

        assert_eq!(notification_data["type"], "task_execution");
        assert_eq!(notification_data["subtype"], "tasks_update");
        assert_eq!(notification_data["stats"]["total"], 5);
        assert_eq!(notification_data["stats"]["pending"], 2);
        assert_eq!(notification_data["tasks"].as_array().unwrap().len(), 1);

        // Per-task fields: the status is the bare variant name, `None` becomes null.
        let task = &notification_data["tasks"][0];
        assert_eq!(task["id"], "task-1");
        assert_eq!(task["status"], "Running");
        assert!(task["error"].is_null());
    }

    /// `tasks_complete` events carry the derived success rate and the failed tasks.
    #[test]
    fn test_tasks_complete_event_serialization() {
        let stats = TaskCompletionStats::new(4, 3, 1);
        let failed_tasks = vec![FailedTaskInfo {
            id: "task-4".to_string(),
            name: "flaky-task".to_string(),
            error: Some("boom".to_string()),
        }];

        let event = TaskExecutionNotificationEvent::tasks_complete(stats, failed_tasks);
        let notification_data = event.to_notification_data();

        assert_eq!(notification_data["type"], "task_execution");
        assert_eq!(notification_data["subtype"], "tasks_complete");
        assert_eq!(notification_data["stats"]["total"], 4);
        assert_eq!(notification_data["stats"]["completed"], 3);
        assert_eq!(notification_data["stats"]["failed"], 1);
        assert_eq!(notification_data["stats"]["success_rate"], 75.0);

        let failed = notification_data["failed_tasks"].as_array().unwrap();
        assert_eq!(failed.len(), 1);
        assert_eq!(failed[0]["id"], "task-4");
        assert_eq!(failed[0]["name"], "flaky-task");
        assert_eq!(failed[0]["error"], "boom");
    }

    /// The success rate is a percentage, and an empty run reports 0 instead of dividing by zero.
    #[test]
    fn test_completion_stats_success_rate() {
        assert_eq!(TaskCompletionStats::new(4, 3, 1).success_rate, 75.0);
        assert_eq!(TaskCompletionStats::new(2, 2, 0).success_rate, 100.0);
        assert_eq!(TaskCompletionStats::new(0, 0, 0).success_rate, 0.0);
    }

    /// `Display` prints the bare variant name.
    #[test]
    fn test_task_status_display() {
        assert_eq!(TaskStatus::Pending.to_string(), "Pending");
        assert_eq!(TaskStatus::Running.to_string(), "Running");
        assert_eq!(TaskStatus::Completed.to_string(), "Completed");
        assert_eq!(TaskStatus::Failed.to_string(), "Failed");
    }

    /// Notification JSON deserializes back into the same event once `type` is removed.
    #[test]
    fn test_event_roundtrip_serialization() {
        let original_event = TaskExecutionNotificationEvent::line_output(
            "task-1".to_string(),
            "Test output".to_string(),
        );

        // Serialize to JSON
        let json_data = original_event.to_notification_data();

        // Deserialize back to event (excluding the type field)
        let mut event_data = json_data.clone();
        if let serde_json::Value::Object(ref mut map) = event_data {
            map.remove("type");
        }

        let deserialized_event: TaskExecutionNotificationEvent =
            serde_json::from_value(event_data).expect("Failed to deserialize");

        // Compared field by field so the test does not require `PartialEq` on the event.
        match (original_event, deserialized_event) {
            (
                TaskExecutionNotificationEvent::LineOutput {
                    task_id: id1,
                    output: out1,
                },
                TaskExecutionNotificationEvent::LineOutput {
                    task_id: id2,
                    output: out2,
                },
            ) => {
                assert_eq!(id1, id2);
                assert_eq!(out1, out2);
            }
            _ => panic!("Event types don't match after roundtrip"),
        }
    }
}