Skip to main content

potato_workflow/workflow/
events.rs

1use chrono::{DateTime, Duration, Utc};
2use potato_agent::agents::task::TaskStatus;
3use potato_agent::AgentResponse;
4use potato_provider::ChatResponse;
5use potato_type::prompt::Prompt;
6use potato_util::create_uuid7;
7use potato_util::PyHelperFuncs;
8use pyo3::prelude::*;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::sync::RwLock;
13
/// A single task lifecycle record produced by [`EventTracker`].
///
/// `EventTracker::record_task_started` appends one of these, and the
/// completion/failure recorders later mutate it in place.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskEvent {
    /// Event id, generated with `create_uuid7` when the event is created.
    #[pyo3(get)]
    pub id: String,
    /// Id of the workflow the task belongs to.
    #[pyo3(get)]
    pub workflow_id: String,
    /// Id of the task this event describes.
    #[pyo3(get)]
    pub task_id: String,
    /// Current status; `EventTracker` sets `Running` on start, then
    /// `Completed` or `Failed`.
    #[pyo3(get)]
    pub status: TaskStatus,
    /// Time the event was created (the moment the task was recorded as started).
    #[pyo3(get)]
    pub timestamp: DateTime<Utc>,
    /// Time of the most recent status change.
    #[pyo3(get)]
    pub updated_at: DateTime<Utc>,
    /// Prompt, response, timing and error information for the task.
    #[pyo3(get)]
    pub details: EventDetails,
}
32
#[pymethods]
impl TaskEvent {
    /// Python `str()` representation, rendered by `PyHelperFuncs::__str__`.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}
39
/// Optional payload attached to a [`TaskEvent`].
///
/// Every field is an `Option`; fields that are `None` are omitted when the
/// struct is serialized.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct EventDetails {
    /// Prompt used for the task; filled in when the task completes or fails.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub prompt: Option<Prompt>,
    /// Provider response; filled in when the task completes.
    /// No `#[pyo3(get)]` here: Python reads it via the hand-written
    /// `response` getter in the `#[pymethods]` block.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response: Option<ChatResponse>,
    /// Elapsed time from the recorded start to completion/failure
    /// (`None` if no start time was recorded for the task).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub duration: Option<Duration>,
    /// When the task was recorded as started.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub start_time: Option<DateTime<Utc>>,
    /// When the task was recorded as completed or failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub end_time: Option<DateTime<Utc>>,
    /// Error message; filled in when the task fails.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub error: Option<String>,
}
61
62#[pymethods]
63impl EventDetails {
64    pub fn __str__(&self) -> String {
65        PyHelperFuncs::__str__(self)
66    }
67
68    #[getter]
69    pub fn response<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> {
70        match &self.response {
71            Some(resp) => {
72                let response = resp.to_bound_py_object(py)?;
73                Ok(Some(response))
74            }
75            None => Ok(None),
76        }
77    }
78}
79
/// In-memory log of task events for a single workflow.
///
/// Both collections sit behind `Arc<RwLock<…>>`, so cloning an
/// `EventTracker` yields a handle to the same event list and start-time map.
#[derive(Debug, Clone, Default)]
pub struct EventTracker {
    /// Workflow id copied onto every event this tracker creates.
    workflow_id: String,
    /// Recorded events, in the order they were started.
    pub events: Arc<RwLock<Vec<TaskEvent>>>,
    /// Start time per task id; used to compute each task's duration.
    task_start_times: Arc<RwLock<HashMap<String, DateTime<Utc>>>>,
}
86
87impl PartialEq for EventTracker {
88    fn eq(&self, other: &Self) -> bool {
89        // Compare workflow_id and events
90        self.workflow_id == other.workflow_id
91    }
92}
93
94impl EventTracker {
95    pub fn new(workflow_id: String) -> Self {
96        Self {
97            workflow_id,
98            events: Arc::new(RwLock::new(Vec::new())),
99            task_start_times: Arc::new(RwLock::new(HashMap::new())),
100        }
101    }
102
103    pub fn is_empty(&self) -> bool {
104        let events = self.events.read().unwrap();
105        events.is_empty()
106    }
107
108    pub fn reset(&self) {
109        let mut events = self.events.write().unwrap();
110        events.clear();
111        let mut task_start_times = self.task_start_times.write().unwrap();
112        task_start_times.clear();
113    }
114
115    /// Creates an event for a task when it is started.
116    /// # Arguments
117    /// * `workflow_id` - The ID of the workflow to which the task belongs.
118    /// * `task_id` - The ID of the task that was started.
119    /// # Returns
120    /// None
121    pub fn record_task_started(&self, task_id: &str) {
122        let now = Utc::now();
123
124        let mut start_times = self.task_start_times.write().unwrap();
125        start_times.insert(task_id.to_string(), now);
126
127        let event = TaskEvent {
128            id: create_uuid7(),
129            workflow_id: self.workflow_id.clone(),
130            task_id: task_id.to_string(),
131            status: TaskStatus::Running,
132            timestamp: now,
133            updated_at: now,
134            details: EventDetails {
135                start_time: Some(now),
136                ..Default::default()
137            },
138        };
139
140        let mut events = self.events.write().unwrap();
141        events.push(event);
142    }
143
144    /// Updates the event for a given task ID when it is completed.
145    /// # Arguments
146    /// * `task_id` - The ID of the task that was completed.
147    /// * `prompt` - The prompt used for the task.
148    /// * `response` - The response received from the task.
149    /// # Returns
150    /// None
151    pub fn record_task_completed(&self, task_id: &str, prompt: &Prompt, response: AgentResponse) {
152        let now = Utc::now();
153        let duration = {
154            let start_times = self.task_start_times.read().unwrap();
155            start_times
156                .get(task_id)
157                .map(|start_time| now.signed_duration_since(*start_time))
158        };
159
160        // update the event details
161        // Update the event
162        let mut events = self.events.write().unwrap();
163
164        // filter to find the event with the matching task_id
165        // and update it
166        let _ = events
167            .iter_mut()
168            .filter_map(|event| {
169                if event.task_id == task_id {
170                    event.status = TaskStatus::Completed;
171                    event.updated_at = now;
172                    event.details.response = Some(response.response.clone());
173                    event.details.duration = duration;
174                    event.details.end_time = Some(now);
175                    event.details.prompt = Some(prompt.clone());
176                    Some(event)
177                } else {
178                    None
179                }
180            })
181            .collect::<Vec<_>>();
182    }
183
184    /// Records a task failure event.
185    /// # Arguments
186    /// * `task_id` - The ID of the task that failed.
187    /// * `error_msg` - The error message associated with the failure.
188    /// * `prompt` - The prompt used for the task.
189    /// # Returns
190    /// None
191    pub fn record_task_failed(&self, task_id: &str, error_msg: &str, prompt: &Prompt) {
192        let now = Utc::now();
193        let duration = {
194            let start_times = self.task_start_times.read().unwrap();
195            start_times
196                .get(task_id)
197                .map(|start_time| now.signed_duration_since(*start_time))
198        };
199
200        // update the event details
201        // Update the event
202        let mut events = self.events.write().unwrap();
203
204        // filter to find the event with the matching task_id
205        // and update it
206        let _ = events
207            .iter_mut()
208            .filter_map(|event| {
209                if event.task_id == task_id {
210                    event.status = TaskStatus::Failed;
211                    event.updated_at = now;
212                    event.details.duration = duration;
213                    event.details.end_time = Some(now);
214                    event.details.prompt = Some(prompt.clone());
215                    event.details.error = Some(error_msg.to_string());
216                    Some(event)
217                } else {
218                    None
219                }
220            })
221            .collect::<Vec<_>>();
222    }
223}