potato_workflow/workflow/
events.rs

1use chrono::{DateTime, Duration, Utc};
2use potato_agent::agents::task::TaskStatus;
3use potato_agent::AgentResponse;
4use potato_prompt::Prompt;
5use potato_provider::ChatResponse;
6use potato_util::create_uuid7;
7use potato_util::PyHelperFuncs;
8use pyo3::prelude::*;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::sync::RwLock;
13
14#[pyclass]
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16pub struct TaskEvent {
17    #[pyo3(get)]
18    pub id: String,
19    #[pyo3(get)]
20    pub workflow_id: String,
21    #[pyo3(get)]
22    pub task_id: String,
23    #[pyo3(get)]
24    pub status: TaskStatus,
25    #[pyo3(get)]
26    pub timestamp: DateTime<Utc>,
27    #[pyo3(get)]
28    pub updated_at: DateTime<Utc>,
29    #[pyo3(get)]
30    pub details: EventDetails,
31}
32
33#[pymethods]
34impl TaskEvent {
35    pub fn __str__(&self) -> String {
36        PyHelperFuncs::__str__(self)
37    }
38}
39
40#[pyclass]
41#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
42pub struct EventDetails {
43    #[serde(skip_serializing_if = "Option::is_none")]
44    #[pyo3(get)]
45    pub prompt: Option<Prompt>,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    #[pyo3(get)]
48    pub response: Option<ChatResponse>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    #[pyo3(get)]
51    pub duration: Option<Duration>,
52    #[serde(skip_serializing_if = "Option::is_none")]
53    #[pyo3(get)]
54    pub start_time: Option<DateTime<Utc>>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    #[pyo3(get)]
57    pub end_time: Option<DateTime<Utc>>,
58    #[serde(skip_serializing_if = "Option::is_none")]
59    #[pyo3(get)]
60    pub error: Option<String>,
61}
62
63#[pymethods]
64impl EventDetails {
65    pub fn __str__(&self) -> String {
66        PyHelperFuncs::__str__(self)
67    }
68}
69
70#[derive(Debug, Clone, Default)]
71pub struct EventTracker {
72    workflow_id: String,
73    pub events: Arc<RwLock<Vec<TaskEvent>>>,
74    task_start_times: Arc<RwLock<HashMap<String, DateTime<Utc>>>>,
75}
76
77impl PartialEq for EventTracker {
78    fn eq(&self, other: &Self) -> bool {
79        // Compare workflow_id and events
80        self.workflow_id == other.workflow_id
81    }
82}
83
84impl EventTracker {
85    pub fn new(workflow_id: String) -> Self {
86        Self {
87            workflow_id,
88            events: Arc::new(RwLock::new(Vec::new())),
89            task_start_times: Arc::new(RwLock::new(HashMap::new())),
90        }
91    }
92
93    pub fn is_empty(&self) -> bool {
94        let events = self.events.read().unwrap();
95        events.is_empty()
96    }
97
98    pub fn reset(&self) {
99        let mut events = self.events.write().unwrap();
100        events.clear();
101        let mut task_start_times = self.task_start_times.write().unwrap();
102        task_start_times.clear();
103    }
104
105    /// Creates an event for a task when it is started.
106    /// # Arguments
107    /// * `workflow_id` - The ID of the workflow to which the task belongs.
108    /// * `task_id` - The ID of the task that was started.
109    /// # Returns
110    /// None
111    pub fn record_task_started(&self, task_id: &str) {
112        let now = Utc::now();
113
114        let mut start_times = self.task_start_times.write().unwrap();
115        start_times.insert(task_id.to_string(), now);
116
117        let event = TaskEvent {
118            id: create_uuid7(),
119            workflow_id: self.workflow_id.clone(),
120            task_id: task_id.to_string(),
121            status: TaskStatus::Running,
122            timestamp: now,
123            updated_at: now,
124            details: EventDetails {
125                start_time: Some(now),
126                ..Default::default()
127            },
128        };
129
130        let mut events = self.events.write().unwrap();
131        events.push(event);
132    }
133
134    /// Updates the event for a given task ID when it is completed.
135    /// # Arguments
136    /// * `task_id` - The ID of the task that was completed.
137    /// * `prompt` - The prompt used for the task.
138    /// * `response` - The response received from the task.
139    /// # Returns
140    /// None
141    pub fn record_task_completed(&self, task_id: &str, prompt: &Prompt, response: AgentResponse) {
142        let now = Utc::now();
143        let duration = {
144            let start_times = self.task_start_times.read().unwrap();
145            start_times
146                .get(task_id)
147                .map(|start_time| now.signed_duration_since(*start_time))
148        };
149
150        // update the event details
151        // Update the event
152        let mut events = self.events.write().unwrap();
153
154        // filter to find the event with the matching task_id
155        // and update it
156        let _ = events
157            .iter_mut()
158            .filter_map(|event| {
159                if event.task_id == task_id {
160                    event.status = TaskStatus::Completed;
161                    event.updated_at = now;
162                    event.details.response = Some(response.response.clone());
163                    event.details.duration = duration;
164                    event.details.end_time = Some(now);
165                    event.details.prompt = Some(prompt.clone());
166                    Some(event)
167                } else {
168                    None
169                }
170            })
171            .collect::<Vec<_>>();
172    }
173
174    /// Records a task failure event.
175    /// # Arguments
176    /// * `task_id` - The ID of the task that failed.
177    /// * `error_msg` - The error message associated with the failure.
178    /// * `prompt` - The prompt used for the task.
179    /// # Returns
180    /// None
181    pub fn record_task_failed(&self, task_id: &str, error_msg: &str, prompt: &Prompt) {
182        let now = Utc::now();
183        let duration = {
184            let start_times = self.task_start_times.read().unwrap();
185            start_times
186                .get(task_id)
187                .map(|start_time| now.signed_duration_since(*start_time))
188        };
189
190        // update the event details
191        // Update the event
192        let mut events = self.events.write().unwrap();
193
194        // filter to find the event with the matching task_id
195        // and update it
196        let _ = events
197            .iter_mut()
198            .filter_map(|event| {
199                if event.task_id == task_id {
200                    event.status = TaskStatus::Failed;
201                    event.updated_at = now;
202                    event.details.duration = duration;
203                    event.details.end_time = Some(now);
204                    event.details.prompt = Some(prompt.clone());
205                    event.details.error = Some(error_msg.to_string());
206                    Some(event)
207                } else {
208                    None
209                }
210            })
211            .collect::<Vec<_>>();
212    }
213}