// potato_workflow/workflow/events.rs

1use chrono::{DateTime, Duration, Utc};
2use potato_agent::agents::{task::TaskStatus, types::ChatResponse};
3use potato_agent::AgentResponse;
4use potato_prompt::Prompt;
5use potato_util::create_uuid7;
6use potato_util::PyHelperFuncs;
7use pyo3::prelude::*;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::sync::RwLock;
12
13#[pyclass]
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
15pub struct TaskEvent {
16    #[pyo3(get)]
17    pub id: String,
18    #[pyo3(get)]
19    pub workflow_id: String,
20    #[pyo3(get)]
21    pub task_id: String,
22    #[pyo3(get)]
23    pub status: TaskStatus,
24    #[pyo3(get)]
25    pub timestamp: DateTime<Utc>,
26    #[pyo3(get)]
27    pub updated_at: DateTime<Utc>,
28    #[pyo3(get)]
29    pub details: EventDetails,
30}
31
32#[pymethods]
33impl TaskEvent {
34    pub fn __str__(&self) -> String {
35        PyHelperFuncs::__str__(self)
36    }
37}
38
39#[pyclass]
40#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
41pub struct EventDetails {
42    #[serde(skip_serializing_if = "Option::is_none")]
43    #[pyo3(get)]
44    pub prompt: Option<Prompt>,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    #[pyo3(get)]
47    pub response: Option<ChatResponse>,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    #[pyo3(get)]
50    pub duration: Option<Duration>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    #[pyo3(get)]
53    pub start_time: Option<DateTime<Utc>>,
54    #[serde(skip_serializing_if = "Option::is_none")]
55    #[pyo3(get)]
56    pub end_time: Option<DateTime<Utc>>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    #[pyo3(get)]
59    pub error: Option<String>,
60}
61
62#[pymethods]
63impl EventDetails {
64    pub fn __str__(&self) -> String {
65        PyHelperFuncs::__str__(self)
66    }
67}
68
69#[derive(Debug, Clone, Default)]
70pub struct EventTracker {
71    workflow_id: String,
72    pub events: Arc<RwLock<Vec<TaskEvent>>>,
73    task_start_times: Arc<RwLock<HashMap<String, DateTime<Utc>>>>,
74}
75
76impl PartialEq for EventTracker {
77    fn eq(&self, other: &Self) -> bool {
78        // Compare workflow_id and events
79        self.workflow_id == other.workflow_id
80    }
81}
82
83impl EventTracker {
84    pub fn new(workflow_id: String) -> Self {
85        Self {
86            workflow_id,
87            events: Arc::new(RwLock::new(Vec::new())),
88            task_start_times: Arc::new(RwLock::new(HashMap::new())),
89        }
90    }
91
92    pub fn is_empty(&self) -> bool {
93        let events = self.events.read().unwrap();
94        events.is_empty()
95    }
96
97    pub fn reset(&self) {
98        let mut events = self.events.write().unwrap();
99        events.clear();
100        let mut task_start_times = self.task_start_times.write().unwrap();
101        task_start_times.clear();
102    }
103
104    /// Creates an event for a task when it is started.
105    /// # Arguments
106    /// * `workflow_id` - The ID of the workflow to which the task belongs.
107    /// * `task_id` - The ID of the task that was started.
108    /// # Returns
109    /// None
110    pub fn record_task_started(&self, task_id: &str) {
111        let now = Utc::now();
112
113        let mut start_times = self.task_start_times.write().unwrap();
114        start_times.insert(task_id.to_string(), now);
115
116        let event = TaskEvent {
117            id: create_uuid7(),
118            workflow_id: self.workflow_id.clone(),
119            task_id: task_id.to_string(),
120            status: TaskStatus::Running,
121            timestamp: now,
122            updated_at: now,
123            details: EventDetails {
124                start_time: Some(now),
125                ..Default::default()
126            },
127        };
128
129        let mut events = self.events.write().unwrap();
130        events.push(event);
131    }
132
133    /// Updates the event for a given task ID when it is completed.
134    /// # Arguments
135    /// * `task_id` - The ID of the task that was completed.
136    /// * `prompt` - The prompt used for the task.
137    /// * `response` - The response received from the task.
138    /// # Returns
139    /// None
140    pub fn record_task_completed(&self, task_id: &str, prompt: &Prompt, response: AgentResponse) {
141        let now = Utc::now();
142        let duration = {
143            let start_times = self.task_start_times.read().unwrap();
144            start_times
145                .get(task_id)
146                .map(|start_time| now.signed_duration_since(*start_time))
147        };
148
149        // update the event details
150        // Update the event
151        let mut events = self.events.write().unwrap();
152
153        // filter to find the event with the matching task_id
154        // and update it
155        let _ = events
156            .iter_mut()
157            .filter_map(|event| {
158                if event.task_id == task_id {
159                    event.status = TaskStatus::Completed;
160                    event.updated_at = now;
161                    event.details.response = Some(response.response.clone());
162                    event.details.duration = duration;
163                    event.details.end_time = Some(now);
164                    event.details.prompt = Some(prompt.clone());
165                    Some(event)
166                } else {
167                    None
168                }
169            })
170            .collect::<Vec<_>>();
171    }
172
173    /// Records a task failure event.
174    /// # Arguments
175    /// * `task_id` - The ID of the task that failed.
176    /// * `error_msg` - The error message associated with the failure.
177    /// * `prompt` - The prompt used for the task.
178    /// # Returns
179    /// None
180    pub fn record_task_failed(&self, task_id: &str, error_msg: &str, prompt: &Prompt) {
181        let now = Utc::now();
182        let duration = {
183            let start_times = self.task_start_times.read().unwrap();
184            start_times
185                .get(task_id)
186                .map(|start_time| now.signed_duration_since(*start_time))
187        };
188
189        // update the event details
190        // Update the event
191        let mut events = self.events.write().unwrap();
192
193        // filter to find the event with the matching task_id
194        // and update it
195        let _ = events
196            .iter_mut()
197            .filter_map(|event| {
198                if event.task_id == task_id {
199                    event.status = TaskStatus::Failed;
200                    event.updated_at = now;
201                    event.details.duration = duration;
202                    event.details.end_time = Some(now);
203                    event.details.prompt = Some(prompt.clone());
204                    event.details.error = Some(error_msg.to_string());
205                    Some(event)
206                } else {
207                    None
208                }
209            })
210            .collect::<Vec<_>>();
211    }
212}