// py_spy_for_datakit/config.rs

1use clap::{ArgEnum, Arg, Command, crate_description, crate_name, crate_version, PossibleValue, value_parser};
2use remoteprocess::Pid;
3
4/// Options on how to collect samples from a python process
5#[derive(Debug, Clone, Eq, PartialEq)]
6pub struct Config {
7    /// Whether or not we should stop the python process when taking samples.
8    /// Setting this to false will reduce the performance impact on the target
9    /// python process, but can lead to incorrect results like partial stack
10    /// traces being returned or a higher sampling error rate
11    pub blocking: LockingStrategy,
12
13    /// Whether or not to profile native extensions. Note: this option can not be
14    /// used with the nonblocking option, as we have to pause the process to collect
15    /// the native stack traces
16    pub native: bool,
17
18    // The following config options only apply when using py-spy as an application
19    #[doc(hidden)]
20    pub command: String,
21    #[doc(hidden)]
22    pub pid: Option<Pid>,
23    #[doc(hidden)]
24    pub python_program: Option<Vec<String>>,
25    #[doc(hidden)]
26    pub sampling_rate: u64,
27    #[doc(hidden)]
28    pub filename: Option<String>,
29    #[doc(hidden)]
30    pub format: Option<FileFormat>,
31    #[doc(hidden)]
32    pub show_line_numbers: bool,
33    #[doc(hidden)]
34    pub duration: RecordDuration,
35    #[doc(hidden)]
36    pub include_idle: bool,
37    #[doc(hidden)]
38    pub include_thread_ids: bool,
39    #[doc(hidden)]
40    pub subprocesses: bool,
41    #[doc(hidden)]
42    pub gil_only: bool,
43    #[doc(hidden)]
44    pub hide_progress: bool,
45    #[doc(hidden)]
46    pub capture_output: bool,
47    #[doc(hidden)]
48    pub dump_json: bool,
49    #[doc(hidden)]
50    pub dump_locals: u64,
51    #[doc(hidden)]
52    pub full_filenames: bool,
53    #[doc(hidden)]
54    pub lineno: LineNo,
55
56    #[doc(hidden)]
57    pub host: String,
58    pub port: u32,
59    pub service: String,
60    pub env: String,
61    pub version: String,
62    pub loop_duration: u64,
63}
64
65#[allow(non_camel_case_types)]
66#[derive(ArgEnum, Debug, Copy, Clone, Eq, PartialEq)]
67pub enum FileFormat {
68    flamegraph,
69    raw,
70    speedscope
71}
72
73impl FileFormat {
74    pub fn possible_values() -> impl Iterator<Item = PossibleValue<'static>> {
75        FileFormat::value_variants()
76            .iter()
77            .filter_map(ArgEnum::to_possible_value)
78    }
79}
80
81impl std::str::FromStr for FileFormat {
82    type Err = String;
83
84    fn from_str(s: &str) -> Result<Self, Self::Err> {
85        for variant in Self::value_variants() {
86            if variant.to_possible_value().unwrap().matches(s, false) {
87                return Ok(*variant);
88            }
89        }
90        Err(format!("Invalid fileformat: {}", s))
91    }
92}
93
94
95
/// How the target process is treated while a sample is being collected.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum LockingStrategy {
    /// Don't pause the process while sampling (`--nonblocking`).
    NonBlocking,
    /// Presumably: the caller already holds the process paused — TODO confirm.
    /// Never constructed in this module, hence the `dead_code` allowance.
    #[allow(dead_code)]
    AlreadyLocked,
    /// Pause the process while sampling (the default).
    Lock,
}
103
/// How long sampling runs for.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RecordDuration {
    /// No time limit (`record --duration unlimited`, the `record` default).
    Unlimited,
    /// Stop after the given number of seconds.
    Seconds(u64),
}
109
/// Which line number is reported for each stack frame.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LineNo {
    /// Don't report line numbers (`--nolineno`).
    NoLine,
    /// Aggregate by the function's first line number (`--function`).
    FirstLineNo,
    /// Report the current line number (the default).
    LastInstruction,
}
116
117impl Default for Config {
118    /// Initializes a new Config object with default parameters
119    #[allow(dead_code)]
120    fn default() -> Config {
121        Config{pid: None, python_program: None, filename: None, format: None,
122               command: String::from("top"),
123               blocking: LockingStrategy::Lock, show_line_numbers: false, sampling_rate: 100,
124               duration: RecordDuration::Unlimited, native: false,
125               gil_only: false, include_idle: false, include_thread_ids: false,
126               hide_progress: false, capture_output: true, dump_json: false, dump_locals: 0, subprocesses: false,
127               full_filenames: false, lineno: LineNo::LastInstruction,
128               host: String::from("127.0.0.1"),
129               port: 9529,
130               service: String::from("unnamed-service"),
131               env: String::from("unnamed-env"),
132               version: String::from("unnamed-version"),
133               loop_duration: 0,
134        }
135    }
136}
137
138impl Config {
139    /// Uses clap to set config options from commandline arguments
140    pub fn from_commandline() -> Config {
141        let args: Vec<String> = std::env::args().collect();
142        Config::from_args(&args).unwrap_or_else( |e| e.exit() )
143    }
144
145    pub fn from_args(args: &[String]) -> clap::Result<Config> {
146        // pid/native/nonblocking/rate/python_program/subprocesses/full_filenames arguments can be
147        // used across various subcommand - define once here
148        let pid = Arg::new("pid")
149                    .short('p')
150                    .long("pid")
151                    .value_name("pid")
152                    .help("PID of a running python program to spy on")
153                    .takes_value(true);
154
155        #[cfg(unwind)]
156        let native = Arg::new("native")
157                    .short('n')
158                    .long("native")
159                    .help("Collect stack traces from native extensions written in Cython, C or C++");
160
161        #[cfg(not(target_os="freebsd"))]
162        let nonblocking = Arg::new("nonblocking")
163                    .long("nonblocking")
164                    .help("Don't pause the python process when collecting samples. Setting this option will reduce \
165                          the performance impact of sampling, but may lead to inaccurate results");
166
167        let rate = Arg::new("rate")
168                    .short('r')
169                    .long("rate")
170                    .value_name("rate")
171                    .help("The number of samples to collect per second")
172                    .default_value("100")
173                    .takes_value(true);
174
175        let subprocesses = Arg::new("subprocesses")
176                            .short('s')
177                            .long("subprocesses")
178                            .help("Profile subprocesses of the original process");
179
180        let full_filenames = Arg::new("full_filenames")
181                                .long("full-filenames")
182                                .help("Show full Python filenames, instead of shortening to show only the package part");
183        let program = Arg::new("python_program")
184                    .help("commandline of a python program to run")
185                    .multiple_values(true);
186
187        let idle = Arg::new("idle")
188                .short('i')
189                .long("idle")
190                .help("Include stack traces for idle threads");
191
192        let gil = Arg::new("gil")
193                .short('g')
194                .long("gil")
195                .help("Only include traces that are holding on to the GIL");
196        let datakit = Command::new("datakit")
197            .about("It is almost same with \"record\" subcommand except that py-spy will send the output to datakit at intervals instead of writing into a local file only once")
198            .arg(Arg::new("host")
199                .short('H')
200                .long("host")
201                .value_name("host")
202                .help("The target datakit host")
203                .takes_value(true)
204                .default_value("127.0.0.1"))
205            .arg(Arg::new("port")
206                .short('P')
207                .long("port")
208                .value_name("port")
209                .help("The target datakit port")
210                .takes_value(true)
211                .default_value("9529"))
212            .arg(Arg::new("service")
213                .short('S')
214                .long("service")
215                .value_name("service")
216                .help("Your service name")
217                .takes_value(true)
218                .default_value("unnamed-service"))
219            .arg(Arg::new("env")
220                .short('E')
221                .long("env")
222                .value_name("env")
223                .help("Your deployment env, eg: dev, testing, prod...")
224                .takes_value(true)
225                .default_value("unnamed-env"))
226            .arg(Arg::new("version")
227                .short('V')
228                .long("version")
229                .value_name("version")
230                .help("Your service version")
231                .takes_value(true)
232                .default_value("unnamed-version")
233            )
234            .arg(program.clone())
235            .arg(pid.clone().required_unless_present("python_program"))
236            .arg(Arg::new("duration")
237                .short('d')
238                .long("duration")
239                .value_name("duration")
240                .help("The number of seconds to sample for")
241                .default_value("60")
242                .takes_value(true))
243            .arg(Arg::new("loop")
244                .short('L')
245                .long("loop")
246                .value_name("loop")
247                .help("continuously run profiler in a loop within the specified seconds, 0 represents infinity")
248                .default_value("0")
249                .takes_value(true)
250            )
251            .arg(rate.clone())
252            .arg(subprocesses.clone())
253            .arg(gil.clone())
254            .arg(idle.clone())
255            .arg(Arg::new("capture")
256                .long("capture")
257                .hide(true)
258                .help("Captures output from child process"));
259        let record = Command::new("record")
260            .about("Records stack trace information to a flamegraph, speedscope or raw file")
261            .arg(program.clone())
262            .arg(pid.clone().required_unless_present("python_program"))
263            .arg(full_filenames.clone())
264            .arg(Arg::new("output")
265                .short('o')
266                .long("output")
267                .value_name("filename")
268                .help("Output filename")
269                .takes_value(true)
270                .required(false))
271            .arg(Arg::new("format")
272                .short('f')
273                .long("format")
274                .value_name("format")
275                .help("Output file format")
276                .takes_value(true)
277                .possible_values(FileFormat::possible_values())
278                .ignore_case(true)
279                .default_value("flamegraph"))
280            .arg(Arg::new("duration")
281                .short('d')
282                .long("duration")
283                .value_name("duration")
284                .help("The number of seconds to sample for")
285                .default_value("unlimited")
286                .takes_value(true))
287            .arg(rate.clone())
288            .arg(subprocesses.clone())
289            .arg(Arg::new("function")
290                .short('F')
291                .long("function")
292                .help("Aggregate samples by function's first line number, instead of current line number"))
293            .arg(Arg::new("nolineno")
294                .long("nolineno")
295                .help("Do not show line numbers"))
296            .arg(Arg::new("threads")
297                .short('t')
298                .long("threads")
299                .help("Show thread ids in the output"))
300            .arg(gil.clone())
301            .arg(idle.clone())
302            .arg(Arg::new("capture")
303                .long("capture")
304                .hide(true)
305                .help("Captures output from child process"))
306            .arg(Arg::new("hideprogress")
307                .long("hideprogress")
308                .hide(true)
309                .help("Hides progress bar (useful for showing error output on record)"));
310
311        let top = Command::new("top")
312            .about("Displays a top like view of functions consuming CPU")
313            .arg(program.clone())
314            .arg(pid.clone().required_unless_present("python_program"))
315            .arg(rate.clone())
316            .arg(subprocesses.clone())
317            .arg(full_filenames.clone())
318            .arg(gil.clone())
319            .arg(idle.clone());
320
321        let dump = Command::new("dump")
322            .about("Dumps stack traces for a target program to stdout")
323            .arg(pid.clone().required(true))
324            .arg(full_filenames.clone())
325            .arg(Arg::new("locals")
326                .short('l')
327                .long("locals")
328                .multiple_occurrences(true)
329                .help("Show local variables for each frame. Passing multiple times (-ll) increases verbosity"))
330            .arg(Arg::new("json")
331                .short('j')
332                .long("json")
333                .help("Format output as JSON"))
334            .arg(subprocesses.clone());
335
336        let completions = Command::new("completions")
337            .about("Generate shell completions")
338            .hide(true)
339            .arg(Arg::new("shell")
340                .value_parser(value_parser!(clap_complete::Shell))
341                .help("Shell type"));
342
343        // add native unwinding if appropriate
344        #[cfg(unwind)]
345        let record = record.arg(native.clone());
346        #[cfg(unwind)]
347        let datakit = datakit.arg(native.clone());
348        #[cfg(unwind)]
349        let top = top.arg(native.clone());
350        #[cfg(unwind)]
351        let dump = dump.arg(native.clone());
352
353        // Nonblocking isn't an option for freebsd, remove
354        #[cfg(not(target_os="freebsd"))]
355        let record = record.arg(nonblocking.clone());
356        #[cfg(not(target_os="freebsd"))]
357        let datakit = datakit.arg(nonblocking.clone());
358        #[cfg(not(target_os="freebsd"))]
359        let top = top.arg(nonblocking.clone());
360        #[cfg(not(target_os="freebsd"))]
361        let dump = dump.arg(nonblocking.clone());
362
363        let mut app = Command::new(crate_name!())
364            .version(crate_version!())
365            .about(crate_description!())
366            .subcommand_required(true)
367            .infer_subcommands(true)
368            .arg_required_else_help(true)
369            .global_setting(clap::AppSettings::DeriveDisplayOrder)
370            .subcommand(record)
371            .subcommand(datakit)
372            .subcommand(top)
373            .subcommand(dump)
374            .subcommand(completions);
375        let matches = app.clone().try_get_matches_from(args)?;
376        info!("Command line args: {:?}", matches);
377
378        let mut config = Config::default();
379
380        let (subcommand, matches) = matches.subcommand().unwrap();
381
382        match subcommand {
383            "record" | "datakit" => {
384                config.sampling_rate = matches.value_of_t("rate")?;
385                config.duration = match matches.value_of("duration") {
386                    Some("unlimited") | None => RecordDuration::Unlimited,
387                    Some(seconds) => RecordDuration::Seconds(seconds.parse().expect("invalid duration"))
388                };
389                if subcommand == "datakit" {
390                    config.duration = match matches.value_of("duration") {
391                        None => RecordDuration::Seconds(60),
392                        Some(seconds) => RecordDuration::Seconds(seconds.parse().expect("invalid duration"))
393                    };
394                    config.loop_duration = match matches.value_of("loop") {
395                        None => 0,
396                        Some(seconds) => seconds.parse().expect("invalid loop parameter")
397                    };
398                    config.host = match matches.value_of("host") {
399                        Some(host) => host.to_owned(),
400                        None => String::from("127.0.0.1")
401                    };
402                    config.port = matches.value_of_t("port")?;
403                    config.service = match matches.value_of("service") {
404                        Some(service) => service.to_owned(),
405                        None => String::from("unnamed-service")
406                    };
407                    config.env = match matches.value_of("env") {
408                        Some(env) => env.to_owned(),
409                        None => String::from("unnamed-env")
410                    };
411                    config.version = match matches.value_of("version") {
412                        Some(version) => version.to_owned(),
413                        None => String::from("unnamed-version")
414                    };
415                    config.show_line_numbers = true;
416                    config.lineno = LineNo::LastInstruction;
417                    config.include_thread_ids = true;
418                    config.hide_progress = true;
419                } else {
420                    if matches.occurrences_of("nolineno") > 0 && matches.occurrences_of("function") > 0 {
421                        eprintln!("--function & --nolinenos can't be used together");
422                        std::process::exit(1);
423                    }
424                    config.show_line_numbers = matches.occurrences_of("nolineno") == 0;
425                    config.lineno = if matches.occurrences_of("nolineno") > 0 { LineNo::NoLine } else if matches.occurrences_of("function") > 0 { LineNo::FirstLineNo } else { LineNo::LastInstruction };
426                    config.include_thread_ids = matches.occurrences_of("threads") > 0;
427                    config.format = Some(matches.value_of_t("format")?);
428                    config.filename = matches.value_of("output").map(|f| f.to_owned());
429                    config.hide_progress = matches.occurrences_of("hideprogress") > 0;
430                }
431            },
432            "top" => {
433                config.sampling_rate = matches.value_of_t("rate")?;
434            },
435            "dump" => {
436                config.dump_json = matches.occurrences_of("json") > 0;
437                config.dump_locals = matches.occurrences_of("locals");
438            },
439            "completions" => {
440                let shell = matches.get_one::<clap_complete::Shell>("shell").unwrap();
441                let app_name = app.get_name().to_string();
442                clap_complete::generate(*shell, &mut app, app_name, &mut std::io::stdout());
443                std::process::exit(0);
444            }
445            _ => {}
446        }
447
448        match subcommand {
449            "record" | "top" | "datakit" => {
450                config.python_program = matches.values_of("python_program").map(|vals| {
451                    vals.map(|v| v.to_owned()).collect()
452                });
453                config.gil_only = matches.occurrences_of("gil") > 0;
454                config.include_idle = matches.occurrences_of("idle") > 0;
455            },
456            _ => {}
457        }
458
459        config.subprocesses = matches.occurrences_of("subprocesses") > 0;
460        config.command = subcommand.to_owned();
461
462        // options that can be shared between subcommands
463        config.pid = matches.value_of("pid").map(|p| p.parse().expect("invalid pid"));
464        match subcommand {
465            "datakit" => config.full_filenames = true,
466            _ => config.full_filenames = matches.occurrences_of("full_filenames") > 0,
467        }
468
469        if cfg!(unwind) {
470            config.native = matches.occurrences_of("native") > 0;
471        }
472
473        config.capture_output = config.command != "record" || matches.occurrences_of("capture") > 0;
474        if !config.capture_output {
475            config.hide_progress = true;
476        }
477
478        if matches.occurrences_of("nonblocking") > 0 {
479            // disable native profiling if invalidly asked for
480            if config.native  {
481                eprintln!("Can't get native stack traces with the --nonblocking option.");
482                std::process::exit(1);
483            }
484            config.blocking = LockingStrategy::NonBlocking;
485        }
486
487        #[cfg(windows)]
488        {
489            if config.native && config.subprocesses {
490                // the native extension profiling code relies on dbghelp library, which doesn't
491                // seem to work when connecting to multiple processes. disallow
492                eprintln!("Can't get native stack traces with the ---subprocesses option on windows.");
493                std::process::exit(1);
494            }
495        }
496
497        #[cfg(target_os="freebsd")]
498        {
499           if config.pid.is_some() {
500               if std::env::var("PYSPY_ALLOW_FREEBSD_ATTACH").is_err() {
501                    eprintln!("On FreeBSD, running py-spy can cause an exception in the profiled process if the process \
502                        is calling 'socket.connect'.");
503                    eprintln!("While this is fixed in recent versions of python, you need to acknowledge the risk here by \
504                        setting an environment variable PYSPY_ALLOW_FREEBSD_ATTACH to run this command.");
505                    eprintln!("\nSee https://github.com/benfred/py-spy/issues/147 for more information");
506                    std::process::exit(-1);
507               }
508            }
509        }
510        Ok(config)
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits `cmd` on whitespace and parses it as a full argv (program name first).
    fn get_config(cmd: &str) -> clap::Result<Config> {
        #[cfg(target_os = "freebsd")]
        std::env::set_var("PYSPY_ALLOW_FREEBSD_ATTACH", "1");
        let args: Vec<String> = cmd.split_whitespace().map(|x| x.to_owned()).collect();
        Config::from_args(&args)
    }

    #[test]
    fn test_parse_record_args() {
        // basic use case
        let config = get_config("py-spy-for-datakit record --pid 1234 --output foo").unwrap();
        assert_eq!(config.pid, Some(1234));
        assert_eq!(config.filename, Some(String::from("foo")));
        assert_eq!(config.format, Some(FileFormat::flamegraph));
        assert_eq!(config.command, String::from("record"));

        // same command using short versions of everything
        let short_config = get_config("py-spy-for-datakit r -p 1234 -o foo").unwrap();
        assert_eq!(config, short_config);

        // missing the --pid argument should fail
        assert_eq!(get_config("py-spy-for-datakit record -o foo").unwrap_err().kind,
                   clap::ErrorKind::MissingRequiredArgument);

        // but should work when passed a python program
        let program_config = get_config("py-spy-for-datakit r -o foo -- python test.py").unwrap();
        assert_eq!(program_config.python_program, Some(vec![String::from("python"), String::from("test.py")]));
        assert_eq!(program_config.pid, None);

        // passing an invalid file format should fail
        assert_eq!(get_config("py-spy-for-datakit r -p 1234 -o foo -f unknown").unwrap_err().kind,
                   clap::ErrorKind::InvalidValue);

        // test out overriding these params by setting flags
        assert_eq!(config.include_idle, false);
        assert_eq!(config.gil_only, false);
        assert_eq!(config.include_thread_ids, false);

        let config_flags = get_config("py-spy-for-datakit r -p 1234 -o foo --idle --gil --threads").unwrap();
        assert_eq!(config_flags.include_idle, true);
        assert_eq!(config_flags.gil_only, true);
        assert_eq!(config_flags.include_thread_ids, true);
    }

    #[test]
    fn test_parse_dump_args() {
        // basic use case
        let config = get_config("py-spy-for-datakit dump --pid 1234").unwrap();
        assert_eq!(config.pid, Some(1234));
        assert_eq!(config.command, String::from("dump"));

        // short version: "d" alone is ambiguous between "datakit" and "dump"
        // under infer_subcommands, so use the shortest unambiguous prefix
        let short_config = get_config("py-spy-for-datakit du -p 1234").unwrap();
        assert_eq!(config, short_config);

        // missing the --pid argument should fail
        assert_eq!(get_config("py-spy-for-datakit dump").unwrap_err().kind,
                   clap::ErrorKind::MissingRequiredArgument);
    }

    #[test]
    fn test_parse_top_args() {
        // basic use case
        let config = get_config("py-spy-for-datakit top --pid 1234").unwrap();
        assert_eq!(config.pid, Some(1234));
        assert_eq!(config.command, String::from("top"));

        // short version
        let short_config = get_config("py-spy-for-datakit t -p 1234").unwrap();
        assert_eq!(config, short_config);
    }

    #[test]
    fn test_parse_datakit_args() {
        // defaults
        let config = get_config("py-spy-for-datakit datakit --pid 1234").unwrap();
        assert_eq!(config.command, String::from("datakit"));
        assert_eq!(config.pid, Some(1234));
        assert_eq!(config.sampling_rate, 100);
        assert_eq!(config.duration, RecordDuration::Seconds(60));
        assert_eq!(config.loop_duration, 0);
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 9529);
        assert_eq!(config.service, "unnamed-service");
        assert_eq!(config.env, "unnamed-env");
        assert_eq!(config.version, "unnamed-version");
        assert_eq!(config.format, None);
        assert_eq!(config.filename, None);

        // settings datakit always forces
        assert!(config.full_filenames);
        assert!(config.show_line_numbers);
        assert!(config.include_thread_ids);
        assert!(config.hide_progress);
        assert!(config.capture_output);
        assert_eq!(config.lineno, LineNo::LastInstruction);

        // overriding the datakit-specific options
        let custom = get_config("py-spy-for-datakit datakit -p 1234 --host 10.0.0.1 --port 9530 \
                                 --service svc --env prod --duration 30 --loop 600").unwrap();
        assert_eq!(custom.host, "10.0.0.1");
        assert_eq!(custom.port, 9530);
        assert_eq!(custom.service, "svc");
        assert_eq!(custom.env, "prod");
        assert_eq!(custom.duration, RecordDuration::Seconds(30));
        assert_eq!(custom.loop_duration, 600);

        // missing both --pid and a python program should fail
        assert_eq!(get_config("py-spy-for-datakit datakit").unwrap_err().kind,
                   clap::ErrorKind::MissingRequiredArgument);
    }

    #[test]
    fn test_parse_args() {
        assert_eq!(get_config("py-spy-for-datakit dude").unwrap_err().kind,
                   clap::ErrorKind::UnrecognizedSubcommand);
    }
}