1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::sync::Arc;
12use std::time::Duration;
13
14use tailtriage_core::{unix_time_ms, RuntimeSnapshot, Tailtriage};
15use tokio::runtime::Handle;
16use tokio::sync::oneshot;
17use tokio::task::JoinHandle;
18
/// Returns the package name of this crate, `"tailtriage-tokio"`.
///
/// The value is a compile-time constant, so this can be used in `const` contexts.
#[must_use]
pub const fn crate_name() -> &'static str {
    const NAME: &str = "tailtriage-tokio";
    NAME
}
24
/// Reasons why [`RuntimeSampler::start`] refuses to launch a sampler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplerStartError {
    /// The requested sampling interval was zero; a sampler must wait a positive duration
    /// between snapshots.
    ZeroInterval,
}

impl std::fmt::Display for SamplerStartError {
    /// Writes a human-readable description of the error, without applying any padding flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            SamplerStartError::ZeroInterval => {
                "runtime sampling interval must be greater than zero"
            }
        };
        f.write_str(message)
    }
}

/// `SamplerStartError` has no underlying cause, so the default `source()` (`None`) applies.
impl std::error::Error for SamplerStartError {}
41
42#[derive(Debug)]
44pub struct RuntimeSampler {
45 stop_tx: Option<oneshot::Sender<()>>,
46 task: JoinHandle<()>,
47}
48
49impl RuntimeSampler {
50 pub fn start(
60 tailtriage: Arc<Tailtriage>,
61 interval: Duration,
62 ) -> Result<Self, SamplerStartError> {
63 if interval.is_zero() {
64 return Err(SamplerStartError::ZeroInterval);
65 }
66
67 let handle = Handle::current();
68 let (stop_tx, mut stop_rx) = oneshot::channel();
69 let mut ticker = tokio::time::interval(interval);
70
71 let task = tokio::spawn(async move {
72 loop {
73 tokio::select! {
74 _ = &mut stop_rx => break,
75 _ = ticker.tick() => {
76 tailtriage.record_runtime_snapshot(capture_runtime_snapshot(&handle));
77 }
78 }
79 }
80 });
81
82 Ok(Self {
83 stop_tx: Some(stop_tx),
84 task,
85 })
86 }
87
88 pub async fn shutdown(mut self) {
90 if let Some(stop_tx) = self.stop_tx.take() {
91 let _ = stop_tx.send(());
92 }
93 let _ = self.task.await;
94 }
95}
96
97#[must_use]
99pub fn capture_runtime_snapshot(handle: &Handle) -> RuntimeSnapshot {
100 let metrics = handle.metrics();
101
102 #[cfg(tokio_unstable)]
103 let local_queue_depth = {
104 let worker_count: usize = metrics.num_workers();
105 (0..worker_count).try_fold(0_u64, |sum, worker| {
106 let worker_depth: u64 = metrics.worker_local_queue_depth(worker).try_into().ok()?;
107 sum.checked_add(worker_depth)
108 })
109 };
110
111 #[cfg(not(tokio_unstable))]
112 let local_queue_depth = None;
113
114 #[cfg(tokio_unstable)]
115 let blocking_queue_depth = u64::try_from(metrics.blocking_queue_depth()).ok();
116
117 #[cfg(not(tokio_unstable))]
118 let blocking_queue_depth = None;
119
120 #[cfg(tokio_unstable)]
121 let remote_schedule_count = Some(metrics.remote_schedule_count());
122
123 #[cfg(not(tokio_unstable))]
124 let remote_schedule_count = None;
125
126 RuntimeSnapshot {
127 at_unix_ms: unix_time_ms(),
128 alive_tasks: u64::try_from(metrics.num_alive_tasks()).ok(),
129 global_queue_depth: u64::try_from(metrics.global_queue_depth()).ok(),
130 local_queue_depth,
131 blocking_queue_depth,
132 remote_schedule_count,
133 }
134}
135
#[cfg(test)]
mod tests {
    use std::path::PathBuf;
    use std::sync::Arc;
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    use tailtriage_core::Tailtriage;

    use super::{capture_runtime_snapshot, crate_name, RuntimeSampler, SamplerStartError};

    /// Returns a temp-dir output path unique to this call.
    ///
    /// Concurrent or repeated test runs then never share an output file.
    fn unique_output_path(label: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time before epoch")
            .as_nanos();
        std::env::temp_dir().join(format!("tailtriage_tokio_{label}_{nanos}.json"))
    }

    /// Builds a `Tailtriage` run named `runtime-test` that writes to a unique temp file.
    fn test_tailtriage(label: &str) -> Arc<Tailtriage> {
        Arc::new(
            Tailtriage::builder("runtime-test")
                .output(unique_output_path(label))
                .build()
                .expect("build should succeed"),
        )
    }

    #[test]
    fn crate_name_is_stable() {
        assert_eq!(crate_name(), "tailtriage-tokio");
    }

    #[test]
    fn zero_interval_error_has_readable_message() {
        assert_eq!(
            SamplerStartError::ZeroInterval.to_string(),
            "runtime sampling interval must be greater than zero"
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn runtime_sampler_records_snapshots() {
        let tailtriage = test_tailtriage("sampler");
        let sampler = RuntimeSampler::start(Arc::clone(&tailtriage), Duration::from_millis(5))
            .expect("sampler should start");

        tokio::time::sleep(Duration::from_millis(20)).await;
        sampler.shutdown().await;

        let snapshot = tailtriage.snapshot();
        assert!(
            !snapshot.runtime_snapshots.is_empty(),
            "sampler should record runtime snapshots"
        );

        let first = &snapshot.runtime_snapshots[0];
        assert!(first.alive_tasks.is_some());
        assert!(first.global_queue_depth.is_some());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn runtime_sampler_rejects_zero_interval() {
        let tailtriage = test_tailtriage("zero_interval");

        let err = RuntimeSampler::start(tailtriage, Duration::ZERO)
            .expect_err("zero interval should fail");
        assert_eq!(err, SamplerStartError::ZeroInterval);
    }

    /// Unstable-only gauges must be `None` without `--cfg tokio_unstable` and `Some` with it.
    #[tokio::test(flavor = "current_thread")]
    async fn snapshot_optional_metrics_depend_on_tokio_unstable() {
        let snapshot = capture_runtime_snapshot(&tokio::runtime::Handle::current());

        assert!(snapshot.alive_tasks.is_some());
        assert!(snapshot.global_queue_depth.is_some());

        #[cfg(not(tokio_unstable))]
        {
            assert_eq!(snapshot.local_queue_depth, None);
            assert_eq!(snapshot.blocking_queue_depth, None);
            assert_eq!(snapshot.remote_schedule_count, None);
        }

        #[cfg(tokio_unstable)]
        {
            assert!(snapshot.local_queue_depth.is_some());
            assert!(snapshot.blocking_queue_depth.is_some());
            assert!(snapshot.remote_schedule_count.is_some());
        }
    }
}