pipeworks_tasks/
task_state.rs

1use std::{
2    collections::HashMap,
3    fmt,
4    future::Future,
5    sync::{Arc, RwLock},
6    time::{Duration, Instant},
7};
8
9use tokio::sync::broadcast;
10use tracing::{span, Instrument, Level, Subscriber};
11use tracing_subscriber::{layer::Context, registry::LookupSpan, Layer};
12
13use crate::task_tree::TaskId;
14
15tokio::task_local! {
16    /// All states, indexed by their Task ID in the FULL transitive task tree (including all
17    /// subtrees). The values are the same Arcs stored thread-local in TASK_STATE
18    pub static ALL_TASK_STATES: Arc<RwLock<HashMap<TaskId, Arc<RwLock<TaskState>>>>>;
19
20    /// Shared context with all tasks in the full trasitive tree
21    pub static CURRENT_TASK_STATE: Arc<RwLock<TaskState>>;
22
23    /// An optional MPSC channel to stream all tracing events to.
24    pub static EVENT_CHANNEL: broadcast::Sender<TaskEvent>;
25}
26
27#[derive(Debug)]
28pub struct TaskState {
29    pub name: Arc<String>,
30    pub id: TaskId,
31    pub parent_id: Option<TaskId>,
32    pub started_at: Option<Instant>,
33    pub closed_at: Option<Instant>,
34    pub active_for: Duration,
35    active_start_instant: Option<Instant>,
36    is_registered: bool,
37}
38
39#[derive(Clone, PartialEq, Eq, Hash, Debug)]
40pub enum TaskEvent {
41    Enter {
42        name: Arc<String>,
43        id: TaskId,
44        parent_id: Option<TaskId>,
45    },
46    Event {
47        name: Arc<String>,
48        id: TaskId,
49        parent_id: Option<TaskId>,
50        started_at: Instant,
51        active_for: Duration,
52        message: Option<String>,
53    },
54    Closed {
55        name: Arc<String>,
56        id: TaskId,
57        parent_id: Option<TaskId>,
58        started_at: Instant,
59        closed_at: Instant,
60        active_for: Duration,
61    },
62}
63
/// A `tracing_subscriber` layer that keeps each task's [`TaskState`] timing up to date and
/// publishes [`TaskEvent`]s on the task's `EVENT_CHANNEL`.
///
/// It is stateless: all bookkeeping lives in the task-locals set up by [`TaskState::scope`].
#[derive(Debug, Default, Clone, Copy)]
pub struct TaskStateTrackingLayer;
65
66pub fn subscribe_task_events() -> broadcast::Receiver<TaskEvent> {
67    EVENT_CHANNEL.with(|channel| channel.subscribe())
68}
69
70impl TaskState {
71    pub fn scope<F>(
72        name: impl Into<String>,
73        id: TaskId,
74        parent_id: Option<TaskId>,
75        future: F,
76    ) -> impl Future<Output = F::Output>
77    where
78        F: Future<Output = ()> + Send + Sync + 'static,
79    {
80        let name: Arc<String> = Arc::new(name.into());
81        let this = Arc::new(RwLock::new(Self {
82            name: name.clone(),
83            id,
84            parent_id,
85            started_at: None,
86            closed_at: None,
87            active_for: Duration::default(),
88            active_start_instant: None,
89            is_registered: false,
90        }));
91
92        // Instrument the future
93        let span = tracing::span!(
94            Level::INFO,
95            "pipeworks-task",
96            name = name.as_ref(),
97            task_id = id.0,
98            parent_task_id = parent_id.map(|id| id.0)
99        );
100        let future = future.instrument(span);
101
102        // Provide the current and all-task states.
103        let all_task_states = ALL_TASK_STATES.try_with(Clone::clone).unwrap_or_default();
104        let future = ALL_TASK_STATES.scope(all_task_states, future);
105        let future = CURRENT_TASK_STATE.scope(this.clone(), future);
106        let future = EVENT_CHANNEL.scope(
107            EVENT_CHANNEL
108                .try_with(Clone::clone)
109                .ok()
110                .unwrap_or_else(|| broadcast::channel(1024).0),
111            future,
112        );
113
114        future
115    }
116}
117
118impl<S> Layer<S> for TaskStateTrackingLayer
119where
120    S: Subscriber + for<'a> LookupSpan<'a>, // `LookupSpan` is needed to get parent IDs
121{
122    fn on_enter(&self, _id: &span::Id, _ctx: Context<'_, S>) {
123        let now = Instant::now();
124        let _ = CURRENT_TASK_STATE.try_with(|state_rw| {
125            let mut state = state_rw.write().unwrap();
126            state.started_at.get_or_insert(now);
127            state.active_start_instant = Some(now);
128
129            if !state.is_registered {
130                state.is_registered = true;
131                ALL_TASK_STATES.with(|all_tasks| {
132                    all_tasks
133                        .write()
134                        .unwrap()
135                        .insert(TaskId::current(), state_rw.clone())
136                });
137            }
138
139            let _ = EVENT_CHANNEL.try_with(|channel| {
140                let _ = channel.send(TaskEvent::Enter {
141                    name: state.name.clone(),
142                    id: state.id,
143                    parent_id: state.parent_id,
144                });
145            });
146        });
147    }
148
149    fn on_event(&self, event: &tracing::Event, _ctx: Context<'_, S>) {
150        #[derive(Default)]
151        struct EventMessageVisitor {
152            message: Option<String>,
153        }
154
155        impl tracing::field::Visit for EventMessageVisitor {
156            fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn fmt::Debug) {
157                if field.name() == "message" {
158                    self.message = Some(format!("{:?}", value));
159                }
160            }
161        }
162
163        let _ = CURRENT_TASK_STATE.try_with(|state| {
164            let state = state.read().unwrap();
165
166            let _ = EVENT_CHANNEL.try_with(|channel| {
167                let mut visitor = EventMessageVisitor::default();
168                event.record(&mut visitor);
169
170                let _ = channel.send(TaskEvent::Event {
171                    name: state.name.clone(),
172                    id: state.id,
173                    parent_id: state.parent_id,
174                    started_at: state.started_at.unwrap(),
175                    active_for: state.active_for,
176                    message: visitor.message,
177                });
178            });
179        });
180    }
181
182    fn on_exit(&self, _id: &span::Id, _ctx: Context<'_, S>) {
183        let _ = CURRENT_TASK_STATE.try_with(|state| {
184            let mut state = state.write().unwrap();
185            if let Some(active_start) = state.active_start_instant.take() {
186                state.active_for += Instant::now().duration_since(active_start);
187            }
188        });
189    }
190
191    fn on_close(&self, _id: tracing::Id, _ctx: Context<'_, S>) {
192        let _ = CURRENT_TASK_STATE.try_with(|state| {
193            let mut state = state.write().unwrap();
194            let now = Instant::now();
195            state.closed_at = Some(now);
196            if let Some(active_start) = state.active_start_instant.take() {
197                state.active_for += now.duration_since(active_start);
198            }
199
200            if state.is_registered {
201                ALL_TASK_STATES.with(|all_tasks| {
202                    all_tasks.write().unwrap().remove(&TaskId::current());
203                });
204            }
205
206            let _ = EVENT_CHANNEL.try_with(|channel| {
207                let _ = channel.send(TaskEvent::Closed {
208                    name: state.name.clone(),
209                    id: state.id,
210                    parent_id: state.parent_id,
211                    started_at: state.started_at.unwrap(),
212                    closed_at: state.closed_at.unwrap(),
213                    active_for: state.active_for,
214                });
215            });
216        });
217    }
218}
219
220impl std::fmt::Display for TaskEvent {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        match self {
223            TaskEvent::Enter {
224                name,
225                id,
226                parent_id,
227            } => {
228                write!(
229                    f,
230                    "Enter(name=\"{}\", id={:?}, parent_id={:?})",
231                    name, id, parent_id
232                )
233            }
234            TaskEvent::Event {
235                name,
236                id,
237                parent_id,
238                started_at,
239                active_for,
240                message,
241            } => {
242                let wall_time = std::time::Instant::now()
243                    .duration_since(*started_at)
244                    .as_secs_f64();
245                write!(
246                    f,
247                    "Event(name=\"{}\", id={:?}, parent_id={:?}, started_at={:.3}s ago, active_for={:.3}s): {}",
248                    name,
249                    id,
250                    parent_id,
251                    wall_time,
252                    active_for.as_secs_f64(),
253                    message.as_ref().cloned().unwrap_or_default()
254                )
255            }
256            TaskEvent::Closed {
257                name,
258                id,
259                parent_id,
260                started_at,
261                closed_at,
262                active_for,
263            } => {
264                let wall_time = closed_at.duration_since(*started_at).as_secs_f64();
265                write!(
266                    f,
267                    "Closed(name=\"{}\", id={:?}, parent_id={:?}, wall_time={:.3} active_for={:.3}s)",
268                    name,
269                    id,
270                    parent_id,
271                    wall_time,
272                    active_for.as_secs_f64(),
273                )
274            }
275        }
276    }
277}