1use std::{
2 collections::HashMap,
3 fmt,
4 future::Future,
5 sync::{Arc, RwLock},
6 time::{Duration, Instant},
7};
8
9use tokio::sync::broadcast;
10use tracing::{span, Instrument, Level, Subscriber};
11use tracing_subscriber::{layer::Context, registry::LookupSpan, Layer};
12
13use crate::task_tree::TaskId;
14
tokio::task_local! {
    /// Registry of task states keyed by [`TaskId`].
    ///
    /// [`TaskState::scope`] reuses the registry that is already in scope when
    /// there is one and otherwise starts from an empty map. The tracking layer
    /// inserts a task's state on its first enter and removes it on close.
    pub static ALL_TASK_STATES: Arc<RwLock<HashMap<TaskId, Arc<RwLock<TaskState>>>>>;

    /// State of the task whose [`TaskState::scope`] future is currently
    /// executing.
    pub static CURRENT_TASK_STATE: Arc<RwLock<TaskState>>;

    /// Sending half of the broadcast channel on which [`TaskEvent`]s are
    /// published.
    ///
    /// [`TaskState::scope`] reuses the sender that is already in scope when
    /// there is one and otherwise creates a new channel with capacity 1024.
    pub static EVENT_CHANNEL: broadcast::Sender<TaskEvent>;
}
26
/// Timing and identity information tracked for a single task.
///
/// Instances are created by [`TaskState::scope`] and updated by
/// [`TaskStateTrackingLayer`] as the task's span is entered, exited and
/// closed.
#[derive(Debug)]
pub struct TaskState {
    /// Name given to the task when it was scoped.
    pub name: Arc<String>,
    /// Identifier of this task.
    pub id: TaskId,
    /// Identifier of the parent task, if any.
    pub parent_id: Option<TaskId>,
    /// Instant of the first enter; `None` until the layer has seen one.
    pub started_at: Option<Instant>,
    /// Instant at which the layer observed the close; `None` before that.
    pub closed_at: Option<Instant>,
    /// Accumulated time between enter and exit (or close) callbacks.
    pub active_for: Duration,
    /// Instant of the most recent enter that has not yet been matched by an
    /// exit or close; consumed when `active_for` is updated.
    active_start_instant: Option<Instant>,
    /// Set once the layer has inserted this state into `ALL_TASK_STATES`.
    is_registered: bool,
}
38
/// Notification published on `EVENT_CHANNEL` by [`TaskStateTrackingLayer`].
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum TaskEvent {
    /// Sent from the layer's `on_enter` callback.
    Enter {
        /// Name of the task.
        name: Arc<String>,
        /// Identifier of the task.
        id: TaskId,
        /// Identifier of the parent task, if any.
        parent_id: Option<TaskId>,
    },
    /// Sent from the layer's `on_event` callback for a `tracing` event
    /// recorded while a task state is in scope.
    Event {
        /// Name of the task.
        name: Arc<String>,
        /// Identifier of the task.
        id: TaskId,
        /// Identifier of the parent task, if any.
        parent_id: Option<TaskId>,
        /// Instant at which the task was first entered.
        started_at: Instant,
        /// Active time accumulated by the task when the event was recorded.
        active_for: Duration,
        /// `Debug` rendering of the event's `message` field, or `None` when
        /// the event has no such field.
        message: Option<String>,
    },
    /// Sent from the layer's `on_close` callback.
    Closed {
        /// Name of the task.
        name: Arc<String>,
        /// Identifier of the task.
        id: TaskId,
        /// Identifier of the parent task, if any.
        parent_id: Option<TaskId>,
        /// Instant at which the task was first entered.
        started_at: Instant,
        /// Instant at which the close was observed.
        closed_at: Instant,
        /// Total active time accumulated by the task.
        active_for: Duration,
    },
}
63
/// `tracing_subscriber` layer that keeps the task-local [`TaskState`] up to
/// date and publishes [`TaskEvent`]s on the task-local event channel.
///
/// The layer is stateless: everything it reads or writes lives in the task
/// locals installed by [`TaskState::scope`].
#[derive(Debug, Clone, Copy, Default)]
pub struct TaskStateTrackingLayer;
65
66pub fn subscribe_task_events() -> broadcast::Receiver<TaskEvent> {
67 EVENT_CHANNEL.with(|channel| channel.subscribe())
68}
69
70impl TaskState {
71 pub fn scope<F>(
72 name: impl Into<String>,
73 id: TaskId,
74 parent_id: Option<TaskId>,
75 future: F,
76 ) -> impl Future<Output = F::Output>
77 where
78 F: Future<Output = ()> + Send + Sync + 'static,
79 {
80 let name: Arc<String> = Arc::new(name.into());
81 let this = Arc::new(RwLock::new(Self {
82 name: name.clone(),
83 id,
84 parent_id,
85 started_at: None,
86 closed_at: None,
87 active_for: Duration::default(),
88 active_start_instant: None,
89 is_registered: false,
90 }));
91
92 let span = tracing::span!(
94 Level::INFO,
95 "pipeworks-task",
96 name = name.as_ref(),
97 task_id = id.0,
98 parent_task_id = parent_id.map(|id| id.0)
99 );
100 let future = future.instrument(span);
101
102 let all_task_states = ALL_TASK_STATES.try_with(Clone::clone).unwrap_or_default();
104 let future = ALL_TASK_STATES.scope(all_task_states, future);
105 let future = CURRENT_TASK_STATE.scope(this.clone(), future);
106 let future = EVENT_CHANNEL.scope(
107 EVENT_CHANNEL
108 .try_with(Clone::clone)
109 .ok()
110 .unwrap_or_else(|| broadcast::channel(1024).0),
111 future,
112 );
113
114 future
115 }
116}
117
118impl<S> Layer<S> for TaskStateTrackingLayer
119where
120 S: Subscriber + for<'a> LookupSpan<'a>, {
122 fn on_enter(&self, _id: &span::Id, _ctx: Context<'_, S>) {
123 let now = Instant::now();
124 let _ = CURRENT_TASK_STATE.try_with(|state_rw| {
125 let mut state = state_rw.write().unwrap();
126 state.started_at.get_or_insert(now);
127 state.active_start_instant = Some(now);
128
129 if !state.is_registered {
130 state.is_registered = true;
131 ALL_TASK_STATES.with(|all_tasks| {
132 all_tasks
133 .write()
134 .unwrap()
135 .insert(TaskId::current(), state_rw.clone())
136 });
137 }
138
139 let _ = EVENT_CHANNEL.try_with(|channel| {
140 let _ = channel.send(TaskEvent::Enter {
141 name: state.name.clone(),
142 id: state.id,
143 parent_id: state.parent_id,
144 });
145 });
146 });
147 }
148
149 fn on_event(&self, event: &tracing::Event, _ctx: Context<'_, S>) {
150 #[derive(Default)]
151 struct EventMessageVisitor {
152 message: Option<String>,
153 }
154
155 impl tracing::field::Visit for EventMessageVisitor {
156 fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn fmt::Debug) {
157 if field.name() == "message" {
158 self.message = Some(format!("{:?}", value));
159 }
160 }
161 }
162
163 let _ = CURRENT_TASK_STATE.try_with(|state| {
164 let state = state.read().unwrap();
165
166 let _ = EVENT_CHANNEL.try_with(|channel| {
167 let mut visitor = EventMessageVisitor::default();
168 event.record(&mut visitor);
169
170 let _ = channel.send(TaskEvent::Event {
171 name: state.name.clone(),
172 id: state.id,
173 parent_id: state.parent_id,
174 started_at: state.started_at.unwrap(),
175 active_for: state.active_for,
176 message: visitor.message,
177 });
178 });
179 });
180 }
181
182 fn on_exit(&self, _id: &span::Id, _ctx: Context<'_, S>) {
183 let _ = CURRENT_TASK_STATE.try_with(|state| {
184 let mut state = state.write().unwrap();
185 if let Some(active_start) = state.active_start_instant.take() {
186 state.active_for += Instant::now().duration_since(active_start);
187 }
188 });
189 }
190
191 fn on_close(&self, _id: tracing::Id, _ctx: Context<'_, S>) {
192 let _ = CURRENT_TASK_STATE.try_with(|state| {
193 let mut state = state.write().unwrap();
194 let now = Instant::now();
195 state.closed_at = Some(now);
196 if let Some(active_start) = state.active_start_instant.take() {
197 state.active_for += now.duration_since(active_start);
198 }
199
200 if state.is_registered {
201 ALL_TASK_STATES.with(|all_tasks| {
202 all_tasks.write().unwrap().remove(&TaskId::current());
203 });
204 }
205
206 let _ = EVENT_CHANNEL.try_with(|channel| {
207 let _ = channel.send(TaskEvent::Closed {
208 name: state.name.clone(),
209 id: state.id,
210 parent_id: state.parent_id,
211 started_at: state.started_at.unwrap(),
212 closed_at: state.closed_at.unwrap(),
213 active_for: state.active_for,
214 });
215 });
216 });
217 }
218}
219
220impl std::fmt::Display for TaskEvent {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 match self {
223 TaskEvent::Enter {
224 name,
225 id,
226 parent_id,
227 } => {
228 write!(
229 f,
230 "Enter(name=\"{}\", id={:?}, parent_id={:?})",
231 name, id, parent_id
232 )
233 }
234 TaskEvent::Event {
235 name,
236 id,
237 parent_id,
238 started_at,
239 active_for,
240 message,
241 } => {
242 let wall_time = std::time::Instant::now()
243 .duration_since(*started_at)
244 .as_secs_f64();
245 write!(
246 f,
247 "Event(name=\"{}\", id={:?}, parent_id={:?}, started_at={:.3}s ago, active_for={:.3}s): {}",
248 name,
249 id,
250 parent_id,
251 wall_time,
252 active_for.as_secs_f64(),
253 message.as_ref().cloned().unwrap_or_default()
254 )
255 }
256 TaskEvent::Closed {
257 name,
258 id,
259 parent_id,
260 started_at,
261 closed_at,
262 active_for,
263 } => {
264 let wall_time = closed_at.duration_since(*started_at).as_secs_f64();
265 write!(
266 f,
267 "Closed(name=\"{}\", id={:?}, parent_id={:?}, wall_time={:.3} active_for={:.3}s)",
268 name,
269 id,
270 parent_id,
271 wall_time,
272 active_for.as_secs_f64(),
273 )
274 }
275 }
276 }
277}