cli_batteries/
lib.rs

1// TODO:
2// https://crates.io/crates/shadow-rs
3// https://crates.io/crates/argfile
4// https://docs.rs/wild/latest/wild/
5// https://crates.io/crates/clap_complete
6
7#![doc = include_str!("../Readme.md")]
8#![warn(clippy::all, clippy::pedantic, clippy::cargo, clippy::nursery)]
9
10mod allocator;
11mod build;
12mod heartbeat;
13mod metered_allocator;
14mod prometheus;
15mod rand;
16mod rayon;
17mod shutdown;
18mod trace;
19mod version;
20
21pub use crate::{
22    build::build_rs,
23    heartbeat::heartbeat,
24    shutdown::{await_shutdown, is_shutting_down, shutdown},
25    version::Version,
26};
27use clap::{Args, CommandFactory, FromArgMatches, Parser};
28use eyre::{Error as EyreError, Report, Result as EyreResult, WrapErr};
29use std::{future::Future, ptr::addr_of};
30use tokio::runtime;
31use tracing::{error, info};
32
33#[cfg(feature = "mock-shutdown")]
34pub use crate::shutdown::reset_shutdown;
35
36#[cfg(feature = "metered-allocator")]
37use crate::metered_allocator::MeteredAllocator;
38
39#[cfg(feature = "otlp")]
40pub use crate::trace::{trace_from_headers, trace_to_headers};
41
/// Generate a [`Default`] implementation for a type that implements
/// [`Parser`], obtained by parsing an empty command line.
///
/// Every field of the type must have a clap default value; `parse_from`
/// exits the process if the empty command line does not parse.
#[macro_export]
macro_rules! default_from_clap {
    ($ty:ty) => {
        impl ::std::default::Default for $ty {
            fn default() -> Self {
                // No binary name and no arguments: every field falls back to
                // the default declared in its clap attributes.
                let no_args = ::std::iter::empty::<::std::ffi::OsString>();
                <Self as ::clap::Parser>::parse_from(no_args)
            }
        }
    };
}
56
57// TODO: Use the new command / arg distinction from clap.
58#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default, Parser)]
59#[group(skip)]
60struct Options<O: Args> {
61    #[clap(flatten)]
62    tracing: trace::Options,
63
64    #[cfg(feature = "rand")]
65    #[clap(flatten)]
66    rand: rand::Options,
67
68    #[cfg(feature = "rayon")]
69    #[clap(flatten)]
70    rayon: rayon::Options,
71
72    #[cfg(feature = "prometheus")]
73    #[clap(flatten)]
74    prometheus: prometheus::Options,
75
76    #[clap(flatten)]
77    app: O,
78}
79
80/// Run the program.
81#[allow(clippy::needless_pass_by_value)]
82pub fn run<A, O, F, E>(version: Version, app: A)
83where
84    A: FnOnce(O) -> F,
85    O: Args,
86    F: Future<Output = Result<(), E>>,
87    E: Into<Report> + Send + Sync + 'static,
88{
89    if let Err(report) = run_fallible(&version, app) {
90        error!(?report, "{}", report);
91        error!("Program terminating abnormally");
92        std::process::exit(1);
93    }
94}
95
96fn run_fallible<A, O, F, E>(version: &Version, app: A) -> EyreResult<()>
97where
98    A: FnOnce(O) -> F,
99    O: Args,
100    F: Future<Output = Result<(), E>>,
101    E: Into<Report> + Send + Sync + 'static,
102{
103    // Install panic handler
104    // TODO: write panics to log, like Err results.
105    color_eyre::config::HookBuilder::default()
106        .issue_url(format!("{}/issues/new", version.pkg_repo))
107        .add_issue_metadata(
108            "version",
109            format!("{} {}", version.pkg_name, version.long_version),
110        )
111        .install()
112        .map_err(|err| {
113            eprintln!("Error: {}", err);
114            err
115        })?;
116
117    // Parse CLI and handle help and version (which will stop the application).
118    let matches = Options::<O>::command()
119        .name(version.pkg_name)
120        .version(version.pkg_version)
121        .long_version(version.long_version)
122        .get_matches();
123    let options = Options::<O>::from_arg_matches(&matches)?;
124
125    // Start allocator metering (if enabled)
126    allocator::start_metering();
127
128    // TODO: Early logging to catch errors before we start the runtime.
129
130    // Launch Tokio runtime
131    // TODO: https://docs.rs/tokio/latest/tokio/runtime/struct.Builder.html#method.unhandled_panic
132    runtime::Builder::new_multi_thread()
133        .enable_all()
134        .build()
135        .wrap_err("Error creating Tokio runtime")?
136        .block_on(async {
137            // Start heartbeat
138            let heartbeat = tokio::spawn(heartbeat());
139
140            // Monitor for Ctrl-C
141            #[cfg(feature = "signals")]
142            shutdown::watch_signals();
143
144            // Start log system
145            let load_addr = addr_of!(app) as usize;
146            options.tracing.init(version, load_addr).map_err(|err| {
147                eprintln!("Error: {}", err);
148                err
149            })?;
150
151            #[cfg(feature = "rand")]
152            options.rand.init();
153
154            #[cfg(feature = "rayon")]
155            options.rayon.init()?;
156
157            // Start prometheus
158            #[cfg(feature = "prometheus")]
159            let prometheus = tokio::spawn(prometheus::main(options.prometheus));
160
161            // Start main
162            app(options.app).await.map_err(E::into)?;
163
164            // Initiate shutdown if main returns
165            shutdown::shutdown();
166
167            // Wait for prometheus to finish
168            #[cfg(feature = "prometheus")]
169            prometheus.await??;
170
171            // Submit remaining traces
172            trace::shutdown()?;
173
174            // Join heartbeat thread
175            heartbeat.await?;
176
177            Result::<(), EyreError>::Ok(())
178        })?;
179
180    // Terminate successfully
181    info!("Program terminating normally");
182    Ok(())
183}
184
#[cfg(test)]
pub mod test {
    //! Smoke tests checking that `tracing_test` captures log output in both
    //! sync and async tests.

    use tracing::{error, info, warn};
    use tracing_test::traced_test;

    /// An event emitted directly in a sync test body is visible to `logs_contain`.
    #[test]
    #[traced_test]
    fn test_with_log_output() {
        error!("logged on the error level");
        assert!(logs_contain("logged on the error level"));
    }

    /// Events from the async test body and from a task it spawns are visible
    /// to `logs_contain`, while the message logged by `test_with_log_output`
    /// is not.
    // The flavor is spelled out (it is also `#[tokio::test]`'s default) because
    // the spawned-task assertion below appears to rely on that task being polled
    // on the test's own thread. Presumably that is what lets its event carry the
    // test span `traced_test` uses to scope `logs_contain` -- confirm before
    // switching to `multi_thread`.
    #[tokio::test(flavor = "current_thread")]
    #[traced_test]
    #[allow(clippy::semicolon_if_nothing_returned)] // False positive
    async fn async_test_with_log() {
        // Local log
        info!("This is being logged on the info level");

        // Log from a spawned task. On the current-thread runtime it is polled
        // on this same thread, not on a separate one.
        tokio::spawn(async {
            warn!("This is being logged on the warn level from a spawned task");
        })
        .await
        .unwrap();

        // Ensure that `logs_contain` works as intended
        assert!(logs_contain("logged on the info level"));
        assert!(logs_contain("logged on the warn level"));
        // Logged by `test_with_log_output`; must not be attributed to this test.
        assert!(!logs_contain("logged on the error level"));
    }
}