modality_plugin_utils/
lib.rs

1#![deny(warnings, clippy::all)]
2
3use clap::Parser;
4use modality_auth_token::{AuthToken, MODALITY_AUTH_TOKEN_ENV_VAR};
5use modality_reflector_config::{AttrKeyEqValuePair, ConfigLoadError, TopLevelIngest};
6use std::collections::BTreeMap;
7use std::future::Future;
8use std::path::{Path, PathBuf};
9use std::pin::Pin;
10use std::str::FromStr;
11use url::Url;
12
/// Default port number for the Modality storage service.
pub const MODALITY_STORAGE_SERVICE_PORT_DEFAULT: u16 = 14_182;
14
/// Help-output template: the `about` text, a `USAGE:` section, then every argument.
///
/// `{about}`, `{usage}` and `{all-args}` are clap help-template placeholders.
pub const CLI_TEMPLATE: &str = concat!(
    "{about}\n",
    "\n",
    "USAGE:\n",
    "    {usage}\n",
    "\n",
    "{all-args}",
);
21
22/// Handles boilerplate setup for:
23/// * tracing_subscriber configuration
24/// * Signal pipe fixup
25/// * Printing out errors
26/// * Exit code management
27///
28/// The server constructor function consumes config, custom cli args, and a shutdown signal future,
29/// then returns an indefinitely-running future that represents the server.
30///
31/// This function blocks waiting for either the constructed server future to finish
32/// or a CTRL+C style signal.
33///
34/// Returns the process's desired exit code.
35pub fn server_main<Opts, ServerFuture, ServerConstructor>(
36    server_constructor: ServerConstructor,
37) -> i32
38where
39    Opts: Parser,
40    Opts: BearingConfigFilePath,
41    ServerFuture: Future<Output = Result<(), Box<dyn std::error::Error + 'static>>> + 'static,
42    ServerConstructor: FnOnce(
43        modality_reflector_config::Config,
44        AuthToken,
45        Opts,
46        Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
47    ) -> ServerFuture,
48{
49    let _ = reset_signal_pipe_handler();
50    let opts = match Opts::try_parse_from(std::env::args()) {
51        Ok(opts) => opts,
52        Err(e)
53            if e.kind() == clap::error::ErrorKind::DisplayHelp
54                || e.kind() == clap::error::ErrorKind::DisplayVersion =>
55        {
56            // Need to print to stdout for these command variants in support of manual generation
57            if let Err(e) = e.print() {
58                error_print(&e);
59                return exitcode::SOFTWARE;
60            }
61            return exitcode::OK;
62        }
63        Err(e) => {
64            error_print(&e);
65            return exitcode::SOFTWARE;
66        }
67    };
68
69    let config = if let Some(config_file) = opts.config_file_path() {
70        match modality_reflector_config::try_from_file(config_file) {
71            Ok(c) => c,
72            Err(config_load_error) => {
73                // N.B. tracing subscriber is not configured yet, this may disappear
74                tracing::error!(
75                    err = &config_load_error as &dyn std::error::Error,
76                    "Failed to load config file provided by command line args, exiting."
77                );
78                let exit_code = match &config_load_error {
79                    ConfigLoadError::Io(_) => exitcode::IOERR,
80                    _ => exitcode::CONFIG,
81                };
82                error_print(&config_load_error);
83                return exit_code;
84            }
85        }
86    } else if let Ok(config_file) = std::env::var(modality_reflector_config::CONFIG_ENV_VAR) {
87        match modality_reflector_config::try_from_file(&PathBuf::from(config_file)) {
88            Ok(c) => c,
89            Err(config_load_error) => {
90                // N.B. tracing subscriber is not configured yet, this may disappear
91                tracing::error!(
92                    err = &config_load_error as &dyn std::error::Error,
93                    "Failed to load config file provided by environment variable, exiting."
94                );
95                let exit_code = match &config_load_error {
96                    ConfigLoadError::Io(_) => exitcode::IOERR,
97                    _ => exitcode::CONFIG,
98                };
99                error_print(&config_load_error);
100                return exit_code;
101            }
102        }
103    } else {
104        // N.B. tracing subscriber is not configured yet, this may disappear
105        tracing::warn!("No config file specified, using default configuration.");
106        modality_reflector_config::Config::default()
107    };
108
109    // setup custom tracer including ModalityLayer
110    let maybe_modality = {
111        let mut modality_tracing_options = tracing_modality::Options::default();
112        let maybe_preferred_ingest_parent_socket = if let Some(ingest_parent_url) = config
113            .ingest
114            .as_ref()
115            .and_then(|ing| ing.protocol_parent_url.as_ref())
116        {
117            ingest_parent_url
118                .socket_addrs(|| Some(14182))
119                .ok()
120                .and_then(|sockets| sockets.into_iter().next())
121        } else {
122            None
123        };
124        if let Some(socket) = maybe_preferred_ingest_parent_socket {
125            modality_tracing_options = modality_tracing_options.with_server_address(socket);
126        }
127
128        use tracing_subscriber::layer::{Layer, SubscriberExt};
129
130        use tracing_subscriber::filter::{EnvFilter, LevelFilter};
131        let (disp, maybe_modality_ingest_handle) =
132            match tracing_modality::blocking::ModalityLayer::init_with_options(
133                modality_tracing_options,
134            ) {
135                Ok((modality_layer, modality_ingest_handle)) => {
136                    // Trace output through both the stdout formatter and modality's ingest pipeline
137                    (
138                        tracing::Dispatch::new(
139                            tracing_subscriber::Registry::default()
140                                .with(
141                                    modality_layer.with_filter(
142                                        EnvFilter::builder()
143                                            .with_default_directive(LevelFilter::INFO.into())
144                                            .from_env_lossy(),
145                                    ),
146                                )
147                                .with(
148                                    tracing_subscriber::fmt::Layer::default().with_filter(
149                                        EnvFilter::builder()
150                                            .with_default_directive(LevelFilter::INFO.into())
151                                            .from_env_lossy(),
152                                    ),
153                                ),
154                        ),
155                        Some(modality_ingest_handle),
156                    )
157                }
158                Err(modality_init_err) => {
159                    eprintln!("Modality tracing layer initialization error.");
160                    error_print(&modality_init_err);
161                    // Only do trace output through the stdout formatter
162                    (
163                        tracing::Dispatch::new(
164                            tracing_subscriber::Registry::default().with(
165                                tracing_subscriber::fmt::Layer::default().with_filter(
166                                    EnvFilter::builder()
167                                        .with_default_directive(LevelFilter::INFO.into())
168                                        .from_env_lossy(),
169                                ),
170                            ),
171                        ),
172                        None,
173                    )
174                }
175            };
176
177        tracing::dispatcher::set_global_default(disp).expect("set global tracer");
178
179        maybe_modality_ingest_handle
180    };
181
182    let auth_token = if let Ok(auth_token_env_str) = std::env::var(MODALITY_AUTH_TOKEN_ENV_VAR) {
183        match modality_auth_token::decode_auth_token_hex(auth_token_env_str.as_str()) {
184            Ok(at) => at,
185            Err(auth_token_deserialization_err) => {
186                tracing::error!(
187                    err = &auth_token_deserialization_err as &dyn std::error::Error,
188                    "Failed to interpret auth token provide by environment variable, exiting."
189                );
190                error_print(&auth_token_deserialization_err);
191                return exitcode::CONFIG;
192            }
193        }
194    } else {
195        tracing::warn!(
196            "No auth token provided by environment variable {}, falling back to empty auth token",
197            MODALITY_AUTH_TOKEN_ENV_VAR
198        );
199        AuthToken::from(vec![])
200    };
201
202    let runtime = tokio::runtime::Builder::new_multi_thread()
203        .enable_all()
204        .build()
205        .expect("Could not construct tokio runtime");
206
207    let ctrlc = tokio::signal::ctrl_c();
208    let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
209    let server_done = server_constructor(
210        config,
211        auth_token,
212        opts,
213        Box::pin(async {
214            let _ = shutdown_rx.await.map_err(|_recv_err| {
215                tracing::error!("Shutdown signal channel unexpectedly closed early.");
216            });
217        }),
218    );
219
220    let mut maybe_shutdown_tx = Some(shutdown_tx);
221    let out_exit_code = runtime.block_on(async {
222        tokio::select! {
223            signal_result = ctrlc => {
224                match signal_result {
225                    Ok(()) => {
226                        if let Some(shutdown_tx) = maybe_shutdown_tx.take() {
227                            let _ = shutdown_tx.send(());
228                        }
229                        tracing::info!("Received ctrl+c, exiting.");
230                        exitcode::OK
231                    },
232                    Err(io_err) => {
233                        if let Some(shutdown_tx) = maybe_shutdown_tx.take() {
234                            let _ = shutdown_tx.send(());
235                        }
236                        error_print(&io_err);
237                        tracing::error!("Failed to install ctrl+c handler, exiting.");
238                        exitcode::IOERR
239                    }
240                }
241            }
242            server_result = server_done => {
243                match server_result {
244                    Ok(()) => {
245                        tracing::info!("Done.");
246                        exitcode::OK
247                    },
248                    Err(e) => {
249                        tracing::error!("Server crashed early, exiting.");
250                        error_print(e.as_ref());
251                        exitcode::SOFTWARE
252                    }
253                }
254            }
255        }
256    });
257    // Drop the runtime a little ahead of function exit
258    // in order to ensure that the shutdown_tx side of
259    // the shutdown signal channel does not drop first.
260    std::mem::drop(runtime);
261    if let Some(modality_ingest_handle) = maybe_modality {
262        modality_ingest_handle.finish();
263    }
264    let _maybe_shutdown_tx = maybe_shutdown_tx;
265    out_exit_code
266}
267
/// Prints `err` to stderr, followed by one `Caused by: `-prefixed line for every error in
/// its `source()` chain, outermost first.
pub(crate) fn error_print(err: &dyn std::error::Error) {
    let chain = std::iter::successors(Some(err), |current| current.source());
    for (depth, link) in chain.enumerate() {
        if depth != 0 {
            eprint!("Caused by: ");
        }
        eprintln!("{link}");
    }
}
282
283// Used to prevent panics on broken pipes.
284// See:
285//   https://github.com/rust-lang/rust/issues/46016#issuecomment-605624865
286fn reset_signal_pipe_handler() -> Result<(), Box<dyn std::error::Error>> {
287    #[cfg(target_family = "unix")]
288    {
289        use nix::sys::signal;
290
291        unsafe {
292            signal::signal(signal::Signal::SIGPIPE, signal::SigHandler::SigDfl)?;
293        }
294    }
295
296    Ok(())
297}
298
/// Implemented by command line option types that may carry a path to a reflector config file.
///
/// `server_main` requires its parsed options type to implement this trait.
pub trait BearingConfigFilePath {
    /// Returns the config file path carried by these options, if any.
    ///
    /// When this returns `None`, `server_main` falls back to the config environment variable
    /// and then to the default configuration.
    fn config_file_path(&self) -> Option<&Path>;
}
302
303pub fn merge_ingest_protocol_parent_url(
304    cli_provided: Option<&Url>,
305    cfg: &modality_reflector_config::Config,
306) -> Url {
307    if let Some(parent_url) = cli_provided {
308        parent_url.clone()
309    } else if let Some(TopLevelIngest {
310        protocol_parent_url: Some(parent_url),
311        ..
312    }) = &cfg.ingest
313    {
314        parent_url.clone()
315    } else {
316        let fallback = Url::from_str("modality-ingest://127.0.0.1").unwrap();
317        tracing::warn!(
318            "Plugin falling back to an ingest protocol parent URL of {}",
319            &fallback
320        );
321        fallback
322    }
323}
324
/// Errors describing why an ingest protocol parent could not be determined.
#[derive(Debug, thiserror::Error)]
pub enum ProtocolParentError {
    /// No ingest protocol parent URL was available.
    #[error("Failed to provide an ingest protocol parent URL.")]
    IngestProtocolParentUrlMissing,

    /// The contained ingest protocol parent URL could not be resolved to an address.
    #[error("Failed to resolve ingest protocol parent URL to an address '{0}'.")]
    IngestProtocolParentAddressResolution(Url),
}
333
334pub fn merge_timeline_attrs(
335    cli_provided_attrs: &[AttrKeyEqValuePair],
336    cfg: &modality_reflector_config::Config,
337) -> BTreeMap<modality_reflector_config::AttrKey, modality_reflector_config::AttrVal> {
338    // Merge additional and override timeline attrs from cfg and opts
339    // TODO deal with conflicting reserved attrs in #2098
340    let mut timeline_attrs = BTreeMap::new();
341
342    use modality_reflector_config::AttrKey;
343    fn ensure_timeline_prefix(k: AttrKey) -> AttrKey {
344        if k.as_ref().starts_with("timeline.") {
345            k
346        } else if k.as_ref().starts_with('.') {
347            AttrKey::from("timeline".to_owned() + k.as_ref())
348        } else {
349            AttrKey::from("timeline.".to_owned() + k.as_ref())
350        }
351    }
352    if let Some(tli) = &cfg.ingest {
353        for kvp in tli
354            .timeline_attributes
355            .additional_timeline_attributes
356            .iter()
357            .cloned()
358        {
359            let _ = timeline_attrs.insert(ensure_timeline_prefix(kvp.0), kvp.1);
360        }
361        for kvp in tli
362            .timeline_attributes
363            .override_timeline_attributes
364            .iter()
365            .cloned()
366        {
367            let _ = timeline_attrs.insert(ensure_timeline_prefix(kvp.0), kvp.1);
368        }
369    }
370    // The CLI-provided attrs will take precedence over config
371    for kvp in cli_provided_attrs.iter().cloned() {
372        let _ = timeline_attrs.insert(ensure_timeline_prefix(kvp.0), kvp.1);
373    }
374    timeline_attrs
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// A URL supplied on the command line must win over whatever the config holds.
    #[test]
    fn cli_provided_parent_url_takes_precedence_over_config() {
        let cli_url = Url::parse("modality-ingest://10.0.0.1:14182").expect("valid test URL");
        let merged = merge_ingest_protocol_parent_url(
            Some(&cli_url),
            &modality_reflector_config::Config::default(),
        );
        assert_eq!(merged, cli_url);
    }
}