pyroscope_rbspy_oncpu/sampler/mod.rs

1use anyhow::{Context, Error, Result};
2use std::collections::HashSet;
3use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
4use std::sync::mpsc::{Sender, SyncSender};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7#[cfg(windows)]
8use winapi::um::timeapi;
9
10use crate::core::process::{Pid, Process, ProcessRetry};
11use crate::core::types::{MemoryCopyError, StackTrace};
12
13#[derive(Debug)]
14pub struct Sampler {
15    done: Arc<AtomicBool>,
16    lock_process: bool,
17    root_pid: Pid,
18    sample_rate: u32,
19    time_limit: Option<Duration>,
20    timing_error_traces: Arc<AtomicUsize>,
21    total_traces: Arc<AtomicUsize>,
22    with_subprocesses: bool,
23    force_version: Option<String>,
24    on_cpu: bool,
25}
26
27impl Sampler {
28    pub fn new(
29        pid: Pid,
30        sample_rate: u32,
31        lock_process: bool,
32        time_limit: Option<Duration>,
33        with_subprocesses: bool,
34        force_version: Option<String>,
35        on_cpu: bool,
36    ) -> Self {
37        Sampler {
38            done: Arc::new(AtomicBool::new(false)),
39            lock_process,
40            root_pid: pid,
41            sample_rate,
42            time_limit,
43            timing_error_traces: Arc::new(AtomicUsize::new(0)),
44            total_traces: Arc::new(AtomicUsize::new(0)),
45            with_subprocesses,
46            force_version,
47            on_cpu,
48        }
49    }
50
51    pub fn total_traces(&self) -> usize {
52        self.total_traces.load(Ordering::Relaxed)
53    }
54
55    pub fn timing_error_traces(&self) -> usize {
56        self.timing_error_traces.load(Ordering::Relaxed)
57    }
58
59    /// Start thread(s) recording a PID and possibly its children. Tracks new processes
60    /// Returns a pair of Receivers from which you can consume recorded stacktraces and errors
61    pub fn start(
62        &self,
63        trace_sender: SyncSender<StackTrace>,
64        result_sender: Sender<Result<(), Error>>,
65    ) -> Result<(), Error> {
66        let done = self.done.clone();
67        let root_pid = self.root_pid.clone();
68        let sample_rate = self.sample_rate.clone();
69        let maybe_stop_time = match self.time_limit {
70            Some(duration) => Some(std::time::Instant::now() + duration),
71            None => None,
72        };
73        let lock_process = self.lock_process.clone();
74        let force_version = self.force_version.clone();
75        let on_cpu = self.on_cpu.clone();
76        let result_sender = result_sender.clone();
77        let timing_error_traces = self.timing_error_traces.clone();
78        let total_traces = self.total_traces.clone();
79
80        if self.with_subprocesses {
81            // Start a thread which watches for new descendents and starts new recorders when they
82            // appear
83            let done_clone = self.done.clone();
84            std::thread::spawn(move || {
85                let process = Process::new_with_retry(root_pid)
86                    .expect("couldn't attach to process (is it running?)");
87                let mut pids: HashSet<Pid> = HashSet::new();
88                // we need to exit this loop when the process we're monitoring exits, otherwise the
89                // sender channels won't get closed and rbspy will hang. So we check the done
90                // mutex.
91                while !done_clone.load(Ordering::Relaxed) {
92                    let mut descendents: Vec<Pid> = process
93                        .child_processes()
94                        .expect("Error finding descendents of pid")
95                        .into_iter()
96                        .map(|tuple| tuple.0)
97                        .collect();
98                    descendents.push(root_pid);
99
100                    for pid in descendents {
101                        if pids.contains(&pid) {
102                            // already recording it, no need to start a new recording thread
103                            continue;
104                        }
105                        pids.insert(pid);
106                        let done_root = done.clone();
107                        let done_thread = done.clone();
108                        let result_sender = result_sender.clone();
109                        let timing_error_traces = timing_error_traces.clone();
110                        let total_traces = total_traces.clone();
111                        let trace_sender_clone = trace_sender.clone();
112                        let force_version = force_version.clone();
113                        let on_cpu = on_cpu.clone();
114
115                        std::thread::spawn(move || {
116                            let result = sample(
117                                pid,
118                                sample_rate,
119                                maybe_stop_time,
120                                done_thread,
121                                timing_error_traces,
122                                total_traces,
123                                trace_sender_clone,
124                                lock_process,
125                                force_version,
126                                on_cpu,
127                            );
128                            result_sender.send(result).expect("couldn't send error");
129                            drop(result_sender);
130
131                            if pid == root_pid {
132                                debug!("Root process {} ended", pid);
133                                // we need to store done = true here to signal the other threads here that we
134                                // should stop profiling
135                                done_root.store(true, Ordering::Relaxed);
136                            }
137                        });
138                    }
139                    // TODO: Parameterize subprocess check interval
140                    std::thread::sleep(Duration::from_secs(1));
141                }
142            });
143        } else {
144            // Start a single recorder thread
145            std::thread::spawn(move || {
146                let result = sample(
147                    root_pid,
148                    sample_rate,
149                    maybe_stop_time,
150                    done,
151                    timing_error_traces,
152                    total_traces,
153                    trace_sender,
154                    lock_process,
155                    force_version,
156                    on_cpu,
157                );
158                result_sender.send(result).unwrap();
159                drop(result_sender);
160            });
161        }
162
163        return Ok(());
164    }
165
166    pub fn stop(&self) {
167        self.done.store(true, Ordering::Relaxed);
168    }
169}
170
171/// Samples stack traces and sends them to a channel in another thread where they can be aggregated
172fn sample(
173    pid: Pid,
174    sample_rate: u32,
175    maybe_stop_time: Option<Instant>,
176    done: Arc<AtomicBool>,
177    timing_error_traces: Arc<AtomicUsize>,
178    total_traces: Arc<AtomicUsize>,
179    sender: SyncSender<StackTrace>,
180    lock_process: bool,
181    force_version: Option<String>,
182    on_cpu: bool,
183) -> Result<(), Error> {
184    let mut process =
185        crate::core::ruby_spy::RubySpy::retry_new(pid, 10, force_version).context("new spy")?;
186
187    let mut total = 0;
188    let mut errors = 0;
189
190    let mut sample_time = SampleTime::new(sample_rate);
191    #[cfg(windows)]
192    {
193        // This changes a system-wide setting on Windows so that the OS wakes up every 1ms
194        // instead of the default 15.6ms. This is required to have a sleep call
195        // take less than 15ms, which we need since we usually profile at more than 64hz.
196        // The downside is that this will increase power usage: good discussions are:
197        // https://randomascii.wordpress.com/2013/07/08/windows-timer-resolution-megawatts-wasted/
198        // and http://www.belshe.com/2010/06/04/chrome-cranking-up-the-clock/
199        unsafe {
200            timeapi::timeBeginPeriod(1);
201        }
202    }
203
204    while !done.load(Ordering::Relaxed) {
205        total += 1;
206        let trace = process.get_stack_trace(lock_process, on_cpu);
207        match trace {
208            Ok(Some(ok_trace)) => {
209                sender.send(ok_trace).context("send trace")?;
210            }
211            Ok(None) => {}
212            Err(e) => {
213                if let Some(MemoryCopyError::ProcessEnded) = e.downcast_ref() {
214                    debug!("Process {} ended", pid);
215                    return Ok(());
216                }
217
218                errors += 1;
219                if errors > 20 && (errors as f64) / (total as f64) > 0.5 {
220                    // TODO: Return error type instead of printing here
221                    print_errors(errors, total);
222                    return Err(e);
223                }
224            }
225        }
226        if let Some(stop_time) = maybe_stop_time {
227            if std::time::Instant::now() > stop_time {
228                // need to store done for same reason as above
229                done.store(true, Ordering::Relaxed);
230                break;
231            }
232        }
233        // Sleep until the next expected sample time
234        total_traces.fetch_add(1, Ordering::Relaxed);
235        match sample_time.get_sleep_time() {
236            Ok(sleep_time) => {
237                std::thread::sleep(std::time::Duration::new(0, sleep_time));
238            }
239            Err(_) => {
240                timing_error_traces.fetch_add(1, Ordering::Relaxed);
241            }
242        }
243    }
244
245    // reset time period calls
246    #[cfg(windows)]
247    {
248        unsafe {
249            timeapi::timeEndPeriod(1);
250        }
251    }
252    Ok(())
253}
254
/// Reports on stderr how many of `total` stack traces were dropped because of errors.
///
/// Prints nothing when `errors` is zero.
fn print_errors(errors: usize, total: usize) {
    if errors == 0 {
        return;
    }
    eprintln!(
        "Dropped {}/{} stack traces because of errors. Please consider reporting a GitHub issue -- this isn't normal.",
        errors, total
    );
}
264
/// Keeps sampling on a regular schedule ("exactly" 100 times per second if the sample rate is
/// 100).
///
/// The k-th call to [`SampleTime::get_sleep_time`] does not sleep a fixed interval. Instead it
/// computes the absolute target time `start_time + nanos_between_samples * k` (e.g. for the
/// 1234th sample, `start_time + nanos_between_samples * 1234`) and reports how long to sleep
/// until then. Time spent taking each sample therefore does not accumulate as drift.
struct SampleTime {
    /// When the schedule started.
    start_time: Instant,
    /// Interval between consecutive samples, in nanoseconds.
    nanos_between_samples: u64,
    /// Number of samples scheduled so far.
    num_samples: u64,
}

/// Nanoseconds per second (sleep durations are expressed in nanoseconds).
const BILLION: u64 = 1000 * 1000 * 1000;

impl SampleTime {
    /// Starts a schedule of `rate` samples per second, beginning now.
    ///
    /// # Panics
    ///
    /// Panics if `rate` is zero (division by zero).
    pub fn new(rate: u32) -> SampleTime {
        SampleTime {
            start_time: Instant::now(),
            nanos_between_samples: BILLION / u64::from(rate),
            num_samples: 0,
        }
    }

    /// Advances the schedule by one sample and reports where we stand relative to it.
    ///
    /// Returns `Ok(nanos)` with how long to sleep until the next sample is due. Returns
    /// `Err(nanos)` with how far behind schedule we already are. Both values saturate at
    /// `u32::MAX` (about 4.29s) instead of silently wrapping.
    pub fn get_sleep_time(&mut self) -> Result<u32, u32> {
        self.num_samples += 1;
        let elapsed = self.start_time.elapsed();
        let nanos_elapsed = elapsed.as_secs() * BILLION + u64::from(elapsed.subsec_nanos());
        let target_elapsed = self.num_samples * self.nanos_between_samples;
        if target_elapsed < nanos_elapsed {
            Err(Self::saturating_u32(nanos_elapsed - target_elapsed))
        } else {
            Ok(Self::saturating_u32(target_elapsed - nanos_elapsed))
        }
    }

    /// Narrows a nanosecond count to `u32`, clamping instead of truncating oversized values.
    fn saturating_u32(nanos: u64) -> u32 {
        nanos.min(u64::from(u32::MAX)) as u32
    }
}
301
#[cfg(test)]
mod tests {
    #[cfg(not(target_os = "windows"))]
    use std::collections::HashSet;
    #[cfg(unix)]
    use std::process::Command;

    use crate::core::process::{tests::RubyScript, Pid};
    use crate::sampler::Sampler;

    // Every `Sampler::new` call below passes, in order: pid, sample_rate, lock_process,
    // time_limit, with_subprocesses, force_version, on_cpu. `on_cpu` is `false` so sampling
    // is not restricted to on-CPU threads (presumed meaning of the flag — confirm).

    /// Samples one Ruby process until it is killed and checks the trace and final result.
    #[test]
    fn test_sample_single_process() {
        #[cfg(target_os = "macos")]
        if !nix::unistd::Uid::effective().is_root() {
            println!("Skipping test because we're not running as root");
            return;
        }

        let mut process = RubyScript::new("ci/ruby-programs/infinite.rb");
        let pid = process.id() as Pid;

        let sampler = Sampler::new(pid, 100, true, None, false, None, false);
        let (trace_sender, trace_receiver) = std::sync::mpsc::sync_channel(100);
        let (result_sender, result_receiver) = std::sync::mpsc::channel();
        sampler
            .start(trace_sender, result_sender)
            .expect("sampler failed to start");

        let trace = trace_receiver.recv().expect("failed to receive trace");
        assert_eq!(trace.pid.unwrap(), pid);

        process.kill().expect("failed to kill process");

        let result = result_receiver.recv().expect("failed to receive result");
        result.expect("unexpected error");
    }

    /// Samples with a 500ms time limit; the trace channel must close on its own.
    #[test]
    fn test_sample_single_process_with_time_limit() {
        #[cfg(target_os = "macos")]
        if !nix::unistd::Uid::effective().is_root() {
            println!("Skipping test because we're not running as root");
            return;
        }

        let mut process = RubyScript::new("ci/ruby-programs/infinite.rb");
        let pid = process.id() as Pid;

        let sampler = Sampler::new(
            pid,
            100,
            true,
            Some(std::time::Duration::from_millis(500)),
            false,
            None,
            false,
        );
        let (trace_sender, trace_receiver) = std::sync::mpsc::sync_channel(100);
        let (result_sender, result_receiver) = std::sync::mpsc::channel();
        sampler
            .start(trace_sender, result_sender)
            .expect("sampler failed to start");

        for trace in trace_receiver {
            assert_eq!(trace.pid.unwrap(), pid);
        }

        // At this point the sampler has halted, so we can kill the process
        process.kill().expect("failed to kill process");

        let result = result_receiver.recv().expect("failed to receive result");
        result.expect("unexpected error");
    }

    /// Samples a process tree and checks that all four processes were seen.
    // TODO: Find a more reliable way to test this on Windows hosts
    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_sample_subprocesses() {
        #[cfg(target_os = "macos")]
        if !nix::unistd::Uid::effective().is_root() {
            println!("Skipping test because we're not running as root");
            return;
        }

        // This test is compiled out on Windows, so only the Unix `which` is needed here.
        let output = Command::new("/usr/bin/which")
            .arg("ruby")
            .output()
            .expect("failed to execute process");

        let ruby_binary_path = String::from_utf8(output.stdout).unwrap();

        let ruby_binary_path_str = ruby_binary_path
            .lines()
            .next()
            .expect("failed to execute ruby process");

        let coordination_dir = tempfile::tempdir().unwrap();
        let coordination_dir_name = coordination_dir.path().to_str().unwrap();

        let mut process = Command::new(ruby_binary_path_str)
            .arg("ci/ruby-programs/ruby_forks.rb")
            .arg(coordination_dir_name)
            .spawn()
            .unwrap();
        let pid = process.id() as Pid;

        let sampler = Sampler::new(pid, 5, true, None, true, None, false);
        let (trace_sender, trace_receiver) = std::sync::mpsc::sync_channel(100);
        let (result_sender, result_receiver) = std::sync::mpsc::channel();
        sampler
            .start(trace_sender, result_sender)
            .expect("sampler failed to start");

        let mut pids = HashSet::<Pid>::new();
        for trace in trace_receiver {
            let pid = trace.pid.unwrap();
            if !pids.contains(&pid) {
                // Now that we have a stack trace for this PID, signal to the corresponding
                // ruby process that it can exit
                let coordination_filename = format!("rbspy_ack.{}", pid);
                std::fs::File::create(coordination_dir.path().join(coordination_filename))
                    .expect("couldn't create coordination file");
                pids.insert(pid);
            }
        }

        let _ = process.wait();

        let results: Vec<_> = result_receiver.iter().take(4).collect();
        for r in results {
            r.expect("unexpected error");
        }

        assert_eq!(pids.len(), 4);
    }
}