potato_workflow/workflow/
events.rs1use chrono::{DateTime, Duration, Utc};
2use potato_agent::agents::task::TaskStatus;
3use potato_agent::AgentResponse;
4use potato_prompt::Prompt;
5use potato_provider::ChatResponse;
6use potato_util::create_uuid7;
7use potato_util::PyHelperFuncs;
8use pyo3::prelude::*;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::sync::RwLock;
13
14#[pyclass]
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16pub struct TaskEvent {
17 #[pyo3(get)]
18 pub id: String,
19 #[pyo3(get)]
20 pub workflow_id: String,
21 #[pyo3(get)]
22 pub task_id: String,
23 #[pyo3(get)]
24 pub status: TaskStatus,
25 #[pyo3(get)]
26 pub timestamp: DateTime<Utc>,
27 #[pyo3(get)]
28 pub updated_at: DateTime<Utc>,
29 #[pyo3(get)]
30 pub details: EventDetails,
31}
32
33#[pymethods]
34impl TaskEvent {
35 pub fn __str__(&self) -> String {
36 PyHelperFuncs::__str__(self)
37 }
38}
39
40#[pyclass]
41#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
42pub struct EventDetails {
43 #[serde(skip_serializing_if = "Option::is_none")]
44 #[pyo3(get)]
45 pub prompt: Option<Prompt>,
46 #[serde(skip_serializing_if = "Option::is_none")]
47 #[pyo3(get)]
48 pub response: Option<ChatResponse>,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 #[pyo3(get)]
51 pub duration: Option<Duration>,
52 #[serde(skip_serializing_if = "Option::is_none")]
53 #[pyo3(get)]
54 pub start_time: Option<DateTime<Utc>>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 #[pyo3(get)]
57 pub end_time: Option<DateTime<Utc>>,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 #[pyo3(get)]
60 pub error: Option<String>,
61}
62
63#[pymethods]
64impl EventDetails {
65 pub fn __str__(&self) -> String {
66 PyHelperFuncs::__str__(self)
67 }
68}
69
70#[derive(Debug, Clone, Default)]
71pub struct EventTracker {
72 workflow_id: String,
73 pub events: Arc<RwLock<Vec<TaskEvent>>>,
74 task_start_times: Arc<RwLock<HashMap<String, DateTime<Utc>>>>,
75}
76
77impl PartialEq for EventTracker {
78 fn eq(&self, other: &Self) -> bool {
79 self.workflow_id == other.workflow_id
81 }
82}
83
84impl EventTracker {
85 pub fn new(workflow_id: String) -> Self {
86 Self {
87 workflow_id,
88 events: Arc::new(RwLock::new(Vec::new())),
89 task_start_times: Arc::new(RwLock::new(HashMap::new())),
90 }
91 }
92
93 pub fn is_empty(&self) -> bool {
94 let events = self.events.read().unwrap();
95 events.is_empty()
96 }
97
98 pub fn reset(&self) {
99 let mut events = self.events.write().unwrap();
100 events.clear();
101 let mut task_start_times = self.task_start_times.write().unwrap();
102 task_start_times.clear();
103 }
104
105 pub fn record_task_started(&self, task_id: &str) {
112 let now = Utc::now();
113
114 let mut start_times = self.task_start_times.write().unwrap();
115 start_times.insert(task_id.to_string(), now);
116
117 let event = TaskEvent {
118 id: create_uuid7(),
119 workflow_id: self.workflow_id.clone(),
120 task_id: task_id.to_string(),
121 status: TaskStatus::Running,
122 timestamp: now,
123 updated_at: now,
124 details: EventDetails {
125 start_time: Some(now),
126 ..Default::default()
127 },
128 };
129
130 let mut events = self.events.write().unwrap();
131 events.push(event);
132 }
133
134 pub fn record_task_completed(&self, task_id: &str, prompt: &Prompt, response: AgentResponse) {
142 let now = Utc::now();
143 let duration = {
144 let start_times = self.task_start_times.read().unwrap();
145 start_times
146 .get(task_id)
147 .map(|start_time| now.signed_duration_since(*start_time))
148 };
149
150 let mut events = self.events.write().unwrap();
153
154 let _ = events
157 .iter_mut()
158 .filter_map(|event| {
159 if event.task_id == task_id {
160 event.status = TaskStatus::Completed;
161 event.updated_at = now;
162 event.details.response = Some(response.response.clone());
163 event.details.duration = duration;
164 event.details.end_time = Some(now);
165 event.details.prompt = Some(prompt.clone());
166 Some(event)
167 } else {
168 None
169 }
170 })
171 .collect::<Vec<_>>();
172 }
173
174 pub fn record_task_failed(&self, task_id: &str, error_msg: &str, prompt: &Prompt) {
182 let now = Utc::now();
183 let duration = {
184 let start_times = self.task_start_times.read().unwrap();
185 start_times
186 .get(task_id)
187 .map(|start_time| now.signed_duration_since(*start_time))
188 };
189
190 let mut events = self.events.write().unwrap();
193
194 let _ = events
197 .iter_mut()
198 .filter_map(|event| {
199 if event.task_id == task_id {
200 event.status = TaskStatus::Failed;
201 event.updated_at = now;
202 event.details.duration = duration;
203 event.details.end_time = Some(now);
204 event.details.prompt = Some(prompt.clone());
205 event.details.error = Some(error_msg.to_string());
206 Some(event)
207 } else {
208 None
209 }
210 })
211 .collect::<Vec<_>>();
212 }
213}