Skip to main content

py_spy/
config.rs

1use clap::builder::{styling::AnsiColor, EnumValueParser, Styles};
2use clap::{
3    crate_description, crate_name, crate_version, value_parser, Arg, ArgAction, Command, ValueEnum,
4};
5use remoteprocess::Pid;
6
/// Options on how to collect samples from a python process
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    /// How the python process is locked (paused) while taking samples.
    /// Using `LockingStrategy::NonBlocking` will reduce the performance impact
    /// on the target python process, but can lead to incorrect results like
    /// partial stack traces being returned or a higher sampling error rate
    pub blocking: LockingStrategy,

    /// Whether or not to profile native extensions. Note: this option can not be
    /// used with the nonblocking option, as we have to pause the process to collect
    /// the native stack traces
    pub native: bool,

    // The following config options only apply when using py-spy as an application
    /// Name of the subcommand that was selected (e.g. "record", "top" or "dump").
    #[doc(hidden)]
    pub command: String,
    /// Process id given with `--pid` (decimal, or hexadecimal with a `0x` prefix).
    #[doc(hidden)]
    pub pid: Option<Pid>,
    /// Command line of a python program to run, taken from the positional arguments.
    #[doc(hidden)]
    pub python_program: Option<Vec<String>>,
    /// Number of samples to collect per second (`--rate`).
    #[doc(hidden)]
    pub sampling_rate: u64,
    /// Output filename (`--output`).
    #[doc(hidden)]
    pub filename: Option<String>,
    /// Output file format (`--format`).
    #[doc(hidden)]
    pub format: Option<FileFormat>,
    /// Set to `false` when `--nolineno` is passed to `record`.
    #[doc(hidden)]
    pub show_line_numbers: bool,
    /// How long to sample for (`--duration`).
    #[doc(hidden)]
    pub duration: RecordDuration,
    /// Include stack traces for idle threads (`--idle`).
    #[doc(hidden)]
    pub include_idle: bool,
    /// Show thread ids in the output (`--threads`).
    #[doc(hidden)]
    pub include_thread_ids: bool,
    /// Profile subprocesses of the original process (`--subprocesses`).
    #[doc(hidden)]
    pub subprocesses: bool,
    /// Only include traces that are holding on to the GIL (`--gil`).
    #[doc(hidden)]
    pub gil_only: bool,
    /// Hide the progress bar (`--hideprogress`); also forced on when output is not captured.
    #[doc(hidden)]
    pub hide_progress: bool,
    /// Capture output from the child process; for `record` this requires `--capture`.
    #[doc(hidden)]
    pub capture_output: bool,
    /// Format `dump` output as JSON (`--json`).
    #[doc(hidden)]
    pub dump_json: bool,
    /// Number of times `--locals` was passed to `dump`; higher values increase verbosity.
    #[doc(hidden)]
    pub dump_locals: u64,
    /// Show full Python filenames instead of shortened ones (`--full-filenames`).
    #[doc(hidden)]
    pub full_filenames: bool,
    /// Which line number is attributed to each frame (`--function` / `--nolineno`).
    #[doc(hidden)]
    pub lineno: LineNo,
    /// Delay in seconds between 'top' refreshes (`--delay`).
    #[doc(hidden)]
    pub refresh_seconds: f64,
    /// Filename of a coredump to read stack traces from (`--core`, only defined on linux).
    #[doc(hidden)]
    pub core_filename: Option<String>,
}
63
/// Output file formats accepted by the `--format` option of `py-spy record`.
///
/// The variants are spelled in lower case (hence the `non_camel_case_types`
/// allowance); presumably so that they read exactly like the values typed on
/// the command line — confirm against the `ValueEnum` derive if this matters.
#[allow(non_camel_case_types)]
#[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq)]
pub enum FileFormat {
    /// The default value of `--format`.
    flamegraph,
    raw,
    speedscope,
    chrometrace,
}
72
73impl std::str::FromStr for FileFormat {
74    type Err = String;
75
76    fn from_str(s: &str) -> Result<Self, Self::Err> {
77        for variant in Self::value_variants() {
78            if variant.to_possible_value().unwrap().matches(s, false) {
79                return Ok(*variant);
80            }
81        }
82        Err(format!("Invalid fileformat: {s}"))
83    }
84}
85
/// How the target python process is locked while samples are collected.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum LockingStrategy {
    /// Don't pause the python process when collecting samples
    /// (selected by the `--nonblocking` flag).
    NonBlocking,
    /// Not constructed anywhere in this file (hence the `dead_code` allowance);
    /// presumably for callers that have already locked the process themselves —
    /// confirm against the users of this enum.
    #[allow(dead_code)]
    AlreadyLocked,
    /// The value used by `Config::default()`.
    Lock,
}
93
/// How long `py-spy record` should sample for (the `--duration` option).
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum RecordDuration {
    /// No time limit; this is what the literal value `unlimited` (the default) maps to.
    Unlimited,
    /// Sample for the given number of seconds.
    Seconds(u64),
}
99
/// Which line number is reported for a frame.
#[derive(Debug, Clone, Eq, PartialEq, Copy)]
pub enum LineNo {
    /// Do not show line numbers (`--nolineno`).
    NoLine,
    /// Use the function's first line number (`--function`).
    First,
    /// Use the current line number; this is the default.
    LastInstruction,
}
106
107impl Default for Config {
108    /// Initializes a new Config object with default parameters
109    #[allow(dead_code)]
110    fn default() -> Config {
111        Config {
112            pid: None,
113            python_program: None,
114            filename: None,
115            format: None,
116            command: String::from("top"),
117            blocking: LockingStrategy::Lock,
118            show_line_numbers: false,
119            sampling_rate: 100,
120            duration: RecordDuration::Unlimited,
121            native: false,
122            gil_only: false,
123            include_idle: false,
124            include_thread_ids: false,
125            hide_progress: false,
126            capture_output: true,
127            dump_json: false,
128            dump_locals: 0,
129            subprocesses: false,
130            full_filenames: false,
131            lineno: LineNo::LastInstruction,
132            refresh_seconds: 1.0,
133            core_filename: None,
134        }
135    }
136}
137
138impl Config {
139    /// Uses clap to set config options from commandline arguments
140    pub fn from_commandline() -> Config {
141        let args: Vec<String> = std::env::args().collect();
142        Config::from_args(&args).unwrap_or_else(|e| e.exit())
143    }
144
145    pub fn from_args(args: &[String]) -> clap::error::Result<Config> {
146        // pid/native/nonblocking/rate/python_program/subprocesses/full_filenames arguments can be
147        // used across various subcommand - define once here
148        let pid = Arg::new("pid")
149            .short('p')
150            .long("pid")
151            .value_name("pid")
152            .help("PID of a running python program to spy on, in decimal or hex")
153            .action(ArgAction::Set);
154
155        let mut native = Arg::new("native")
156            .short('n')
157            .long("native")
158            .help("Collect stack traces from native extensions written in Cython, C or C++")
159            .action(ArgAction::SetTrue);
160
161        // Only show `--native` on platforms where it's supported
162        if !cfg!(feature = "unwind") {
163            native = native.hide(true);
164        }
165
166        #[cfg(not(target_os="freebsd"))]
167        let nonblocking = Arg::new("nonblocking")
168                    .long("nonblocking")
169                    .help("Don't pause the python process when collecting samples. Setting this option will reduce \
170                          the performance impact of sampling, but may lead to inaccurate results")
171                    .action(ArgAction::SetTrue);
172
173        let rate = Arg::new("rate")
174            .short('r')
175            .long("rate")
176            .value_name("rate")
177            .help("The number of samples to collect per second")
178            .default_value("100")
179            .value_parser(value_parser!(u64))
180            .action(ArgAction::Set);
181
182        let subprocesses = Arg::new("subprocesses")
183            .short('s')
184            .long("subprocesses")
185            .help("Profile subprocesses of the original process")
186            .action(ArgAction::SetTrue);
187
188        let full_filenames = Arg::new("full_filenames")
189            .long("full-filenames")
190            .help("Show full Python filenames, instead of shortening to show only the package part")
191            .action(ArgAction::SetTrue);
192        let program = Arg::new("python_program")
193            .help("commandline of a python program to run")
194            .action(ArgAction::Append);
195
196        let idle = Arg::new("idle")
197            .short('i')
198            .long("idle")
199            .help("Include stack traces for idle threads")
200            .action(ArgAction::SetTrue);
201
202        let gil = Arg::new("gil")
203            .short('g')
204            .long("gil")
205            .help("Only include traces that are holding on to the GIL")
206            .action(ArgAction::SetTrue);
207
208        let top_delay = Arg::new("delay")
209            .long("delay")
210            .value_name("seconds")
211            .help("Delay between 'top' refreshes.")
212            .default_value("1.0")
213            .value_parser(clap::value_parser!(f64))
214            .action(ArgAction::Set);
215
216        let record = Command::new("record")
217            .about("Records stack trace information to a flamegraph, speedscope or raw file")
218            .arg(program.clone())
219            .arg(pid.clone().required_unless_present("python_program"))
220            .arg(full_filenames.clone())
221            .arg(
222                Arg::new("output")
223                    .short('o')
224                    .long("output")
225                    .value_name("filename")
226                    .help("Output filename")
227                    .action(ArgAction::Set)
228                    .required(false),
229            )
230            .arg(
231                Arg::new("format")
232                    .short('f')
233                    .long("format")
234                    .value_name("format")
235                    .help("Output file format")
236                    .action(ArgAction::Set)
237                    .value_parser(EnumValueParser::<FileFormat>::new())
238                    .ignore_case(true)
239                    .default_value("flamegraph"),
240            )
241            .arg(
242                Arg::new("duration")
243                    .short('d')
244                    .long("duration")
245                    .value_name("duration")
246                    .help("The number of seconds to sample for")
247                    .default_value("unlimited")
248                    .action(ArgAction::Set),
249            )
250            .arg(rate.clone())
251            .arg(subprocesses.clone())
252            .arg(Arg::new("function").short('F').long("function").help(
253                "Aggregate samples by function's first line number, instead of current line number",
254            ).action(ArgAction::SetTrue))
255            .arg(
256                Arg::new("nolineno")
257                    .long("nolineno")
258                    .help("Do not show line numbers")
259                    .action(ArgAction::SetTrue),
260            )
261            .arg(
262                Arg::new("threads")
263                    .short('t')
264                    .long("threads")
265                    .help("Show thread ids in the output")
266                    .action(ArgAction::SetTrue),
267            )
268            .arg(gil.clone())
269            .arg(idle.clone())
270            .arg(
271                Arg::new("capture")
272                    .long("capture")
273                    .hide(true)
274                    .help("Captures output from child process")
275                    .action(ArgAction::SetTrue),
276            )
277            .arg(
278                Arg::new("hideprogress")
279                    .long("hideprogress")
280                    .hide(true)
281                    .help("Hides progress bar (useful for showing error output on record)")
282                    .action(ArgAction::SetTrue),
283            );
284
285        let top = Command::new("top")
286            .about("Displays a top like view of functions consuming CPU")
287            .arg(program.clone())
288            .arg(pid.clone().required_unless_present("python_program"))
289            .arg(rate.clone())
290            .arg(subprocesses.clone())
291            .arg(full_filenames.clone())
292            .arg(gil.clone())
293            .arg(idle.clone())
294            .arg(top_delay.clone());
295
296        #[cfg(target_os = "linux")]
297        let dump_pid = pid.clone().required_unless_present("core");
298
299        #[cfg(not(target_os = "linux"))]
300        let dump_pid = pid.clone().required(true);
301
302        let dump = Command::new("dump")
303            .about("Dumps stack traces for a target program to stdout")
304            .arg(dump_pid);
305
306        #[cfg(target_os = "linux")]
307        let dump = dump.arg(
308            Arg::new("core")
309                .short('c')
310                .long("core")
311                .help("Filename of coredump to display python stack traces from")
312                .value_name("core")
313                .action(ArgAction::Set),
314        );
315
316        let dump = dump.arg(full_filenames.clone())
317            .arg(Arg::new("locals")
318                .short('l')
319                .long("locals")
320                .action(ArgAction::Count)
321                .help("Show local variables for each frame. Passing multiple times (-ll) increases verbosity"))
322            .arg(Arg::new("json")
323                .short('j')
324                .long("json")
325                .help("Format output as JSON")
326                .action(ArgAction::SetTrue))
327            .arg(subprocesses.clone());
328
329        let completions = Command::new("completions")
330            .about("Generate shell completions")
331            .hide(true)
332            .arg(
333                Arg::new("shell")
334                    .value_parser(value_parser!(clap_complete::Shell))
335                    .help("Shell type")
336                    .required(true)
337                    .action(ArgAction::Set),
338            );
339
340        let record = record.arg(native.clone());
341        let top = top.arg(native.clone());
342        let dump = dump.arg(native.clone());
343
344        // Nonblocking isn't an option for freebsd, remove
345        #[cfg(not(target_os = "freebsd"))]
346        let record = record.arg(nonblocking.clone());
347        #[cfg(not(target_os = "freebsd"))]
348        let top = top.arg(nonblocking.clone());
349        #[cfg(not(target_os = "freebsd"))]
350        let dump = dump.arg(nonblocking.clone());
351
352        let styles = Styles::styled()
353            .header(AnsiColor::Yellow.on_default())
354            .usage(AnsiColor::Yellow.on_default())
355            .literal(AnsiColor::Green.on_default())
356            .placeholder(AnsiColor::Green.on_default());
357
358        let mut app = Command::new(crate_name!())
359            .version(crate_version!())
360            .about(crate_description!())
361            .subcommand_required(true)
362            .infer_subcommands(true)
363            .arg_required_else_help(true)
364            .styles(styles)
365            .subcommand(record)
366            .subcommand(top)
367            .subcommand(dump)
368            .subcommand(completions);
369        let matches = app.clone().try_get_matches_from(args)?;
370        debug!("Command line args: {:?}", matches);
371
372        let mut config = Config::default();
373
374        let (subcommand, matches) = matches.subcommand().unwrap();
375
376        // Check if `--native` was used on an unsupported platform
377        if subcommand != "completions" && !cfg!(feature = "unwind") && matches.get_flag("native") {
378            eprintln!(
379                "Collecting stack traces from native extensions (`--native`) is not supported on your platform."
380            );
381            std::process::exit(1);
382        }
383
384        match subcommand {
385            "record" => {
386                config.sampling_rate = *matches.get_one("rate").unwrap();
387                config.duration = match matches.get_one::<String>("duration").map(|d| d.as_str()) {
388                    Some("unlimited") | None => RecordDuration::Unlimited,
389                    Some(seconds) => {
390                        RecordDuration::Seconds(seconds.parse().expect("invalid duration"))
391                    }
392                };
393                config.format = matches.get_one("format").copied();
394                config.filename = matches.get_one::<String>("output").cloned();
395                config.show_line_numbers = !matches.get_flag("nolineno");
396                config.lineno = if matches.get_flag("nolineno") {
397                    LineNo::NoLine
398                } else if matches.get_flag("function") {
399                    LineNo::First
400                } else {
401                    LineNo::LastInstruction
402                };
403                config.include_thread_ids = matches.get_flag("threads");
404                if matches.get_flag("nolineno") && matches.get_flag("function") {
405                    eprintln!("--function & --nolinenos can't be used together");
406                    std::process::exit(1);
407                }
408                config.hide_progress = matches.get_flag("hideprogress");
409            }
410            "top" => {
411                config.sampling_rate = *matches.get_one("rate").unwrap();
412                config.refresh_seconds = *matches.get_one::<f64>("delay").unwrap();
413            }
414            "dump" => {
415                config.dump_json = matches.get_flag("json");
416                config.dump_locals = matches.get_count("locals").into();
417
418                #[cfg(target_os = "linux")]
419                {
420                    config.core_filename = matches.get_one("core").cloned();
421                }
422            }
423            "completions" => {
424                let shell = matches.get_one::<clap_complete::Shell>("shell").unwrap();
425                let app_name = app.get_name().to_string();
426                clap_complete::generate(*shell, &mut app, app_name, &mut std::io::stdout());
427                std::process::exit(0);
428            }
429            _ => {}
430        }
431
432        match subcommand {
433            "record" | "top" => {
434                config.python_program = matches
435                    .get_many::<String>("python_program")
436                    .map(|vals| vals.map(|v| v.to_owned()).collect());
437                config.gil_only = matches.get_flag("gil");
438                config.include_idle = matches.get_flag("idle");
439            }
440            _ => {}
441        }
442
443        config.subprocesses = matches.get_flag("subprocesses");
444        config.command = subcommand.to_owned();
445
446        // options that can be shared between subcommands
447        config.pid = matches.get_one::<String>("pid").map(|p| {
448            // allow pid to be specified as a hexadecimal value
449            match p.to_lowercase().strip_prefix("0x") {
450                Some(prefix) => Pid::from_str_radix(prefix, 16).expect("invalid pid"),
451                None => p.parse().expect("invalid pid"),
452            }
453        });
454
455        config.full_filenames = matches.get_flag("full_filenames");
456        if cfg!(feature = "unwind") {
457            config.native = matches.get_flag("native");
458        }
459
460        config.capture_output = config.command != "record" || matches.get_flag("capture");
461        if !config.capture_output {
462            config.hide_progress = true;
463        }
464
465        if matches.get_flag("nonblocking") {
466            // disable native profiling if invalidly asked for
467            if config.native {
468                eprintln!("Can't get native stack traces with the --nonblocking option.");
469                std::process::exit(1);
470            }
471            config.blocking = LockingStrategy::NonBlocking;
472        }
473
474        #[cfg(windows)]
475        {
476            if config.native && config.subprocesses {
477                // the native extension profiling code relies on dbghelp library, which doesn't
478                // seem to work when connecting to multiple processes. disallow
479                eprintln!(
480                    "Can't get native stack traces with the ---subprocesses option on windows."
481                );
482                std::process::exit(1);
483            }
484        }
485
486        #[cfg(target_os = "freebsd")]
487        {
488            if config.pid.is_some() {
489                if std::env::var("PYSPY_ALLOW_FREEBSD_ATTACH").is_err() {
490                    eprintln!("On FreeBSD, running py-spy can cause an exception in the profiled process if the process \
491                        is calling 'socket.connect'.");
492                    eprintln!("While this is fixed in recent versions of python, you need to acknowledge the risk here by \
493                        setting an environment variable PYSPY_ALLOW_FREEBSD_ATTACH to run this command.");
494                    eprintln!(
495                        "\nSee https://github.com/benfred/py-spy/issues/147 for more information"
496                    );
497                    std::process::exit(-1);
498                }
499            }
500        }
501        info!("config {:#?}", config);
502        Ok(config)
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use clap::error::ErrorKind;

    /// Splits `cmd` on whitespace and parses the pieces as a py-spy command line.
    fn get_config(cmd: &str) -> clap::error::Result<Config> {
        // attaching by pid is refused on freebsd unless this variable is set
        #[cfg(target_os = "freebsd")]
        std::env::set_var("PYSPY_ALLOW_FREEBSD_ATTACH", "1");
        let argv: Vec<String> = cmd.split_whitespace().map(String::from).collect();
        Config::from_args(&argv)
    }

    /// Parses `cmd`, which is expected to be rejected, and returns the kind of the error.
    fn error_kind(cmd: &str) -> ErrorKind {
        get_config(cmd).unwrap_err().kind()
    }

    #[test]
    fn test_parse_record_args() {
        // basic use case
        let long_form = get_config("py-spy record --pid 1234 --output foo").unwrap();
        assert_eq!(long_form.pid, Some(1234));
        assert_eq!(long_form.filename, Some("foo".to_owned()));
        assert_eq!(long_form.format, Some(FileFormat::flamegraph));
        assert_eq!(long_form.command, "record".to_owned());

        // the abbreviated spelling of the same command gives the same config
        let short_form = get_config("py-spy r -p 1234 -o foo").unwrap();
        assert_eq!(long_form, short_form);

        // neither a pid nor a program to run: rejected
        assert_eq!(
            error_kind("py-spy record -o foo"),
            ErrorKind::MissingRequiredArgument
        );

        // a python program can be given instead of a pid
        let with_program = get_config("py-spy r -o foo -- python test.py").unwrap();
        let expected_program = vec!["python".to_owned(), "test.py".to_owned()];
        assert_eq!(with_program.python_program, Some(expected_program));
        assert_eq!(with_program.pid, None);

        // unknown file formats are rejected
        assert_eq!(
            error_kind("py-spy r -p 1234 -o foo -f unknown"),
            ErrorKind::InvalidValue
        );

        // these flags are off unless asked for ...
        assert!(!long_form.include_idle);
        assert!(!long_form.gil_only);
        assert!(!long_form.include_thread_ids);

        // ... and switched on by their options
        let flagged = get_config("py-spy r -p 1234 -o foo --idle --gil --threads").unwrap();
        assert!(flagged.include_idle);
        assert!(flagged.gil_only);
        assert!(flagged.include_thread_ids);
    }

    #[test]
    fn test_parse_dump_args() {
        // basic use case
        let long_form = get_config("py-spy dump --pid 1234").unwrap();
        assert_eq!(long_form.pid, Some(1234));
        assert_eq!(long_form.command, "dump".to_owned());

        // abbreviated spelling
        let short_form = get_config("py-spy d -p 1234").unwrap();
        assert_eq!(long_form, short_form);

        // a target is required
        assert_eq!(
            error_kind("py-spy dump"),
            ErrorKind::MissingRequiredArgument
        );
    }

    #[test]
    fn test_parse_top_args() {
        // basic use case
        let long_form = get_config("py-spy top --pid 1234").unwrap();
        assert_eq!(long_form.pid, Some(1234));
        assert_eq!(long_form.command, "top".to_owned());

        // abbreviated spelling
        let short_form = get_config("py-spy t -p 1234").unwrap();
        assert_eq!(long_form, short_form);
    }

    #[test]
    fn test_parse_args() {
        // subcommands that don't exist are rejected
        assert_eq!(error_kind("py-spy dude"), ErrorKind::InvalidSubcommand);
    }
}