htsget_config/config/
mod.rs

1//! Structs to serialize and deserialize the htsget-rs config options.
2//!
3
4use std::fmt::Debug;
5use std::io;
6use std::path::{Path, PathBuf};
7
8use crate::config::advanced::FormattingStyle;
9use crate::config::data_server::DataServerEnabled;
10use crate::config::location::{Location, LocationEither, Locations};
11use crate::config::parser::from_path;
12use crate::config::service_info::ServiceInfo;
13use crate::config::ticket_server::TicketServerConfig;
14use crate::error::Error::{ArgParseError, ParseError, TracingError};
15use crate::error::Result;
16use crate::storage::file::File;
17use crate::storage::Backend;
18use clap::{Args as ClapArgs, Command, FromArgMatches, Parser};
19use serde::{Deserialize, Serialize};
20use tracing::subscriber::set_global_default;
21use tracing_subscriber::fmt::{format, layer};
22use tracing_subscriber::layer::SubscriberExt;
23use tracing_subscriber::{EnvFilter, Registry};
24
// Advanced options; this file uses `FormattingStyle` from here.
pub mod advanced;
// Data server config; this file uses `DataServerEnabled` from here.
pub mod data_server;
// Location config; this file uses `Location`, `LocationEither` and `Locations` from here.
pub mod location;
// Config parsing helpers; this file uses `from_path` (and `from_str` in tests) from here.
pub mod parser;
// Service info config; this file uses `ServiceInfo` from here.
pub mod service_info;
// Ticket server config; this file uses `TicketServerConfig` from here.
pub mod ticket_server;
31
/// The usage string for htsget-rs, shown as the long `about` text of the command line interface.
pub const USAGE: &str = concat!(
  "To configure htsget-rs use a config file or environment variables. ",
  "See the documentation of the htsget-config crate for more information."
);
35
36/// The command line arguments allowed for the htsget-rs executables.
37#[derive(Parser, Debug)]
38#[command(author, version, about, long_about = USAGE)]
39struct Args {
40  #[arg(
41    short,
42    long,
43    env = "HTSGET_CONFIG",
44    help = "Set the location of the config file"
45  )]
46  config: Option<PathBuf>,
47  #[arg(short, long, exclusive = true, help = "Print a default config file")]
48  print_default_config: bool,
49}
50
/// Simplified config.
///
/// This is the top-level htsget-rs config. Because of `#[serde(default)]`, any field missing from
/// the input takes its value from `Config::default()`. Because of `deny_unknown_fields`,
/// unrecognised fields are rejected.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(default, deny_unknown_fields)]
pub struct Config {
  /// Config for the ticket server.
  ticket_server: TicketServerConfig,
  /// Whether the data server is enabled, and its config when it is.
  data_server: DataServerEnabled,
  /// Service info config.
  service_info: ServiceInfo,
  /// The configured locations, each either simple or regex based.
  locations: Locations,
  /// The tracing event format used by `Config::setup_tracing`.
  formatting_style: FormattingStyle,
}
61
62impl Config {
63  /// Create a config.
64  pub fn new(
65    formatting_style: FormattingStyle,
66    ticket_server: TicketServerConfig,
67    data_server: DataServerEnabled,
68    service_info: ServiceInfo,
69    locations: Locations,
70  ) -> Self {
71    Self {
72      formatting_style,
73      ticket_server,
74      data_server,
75      service_info,
76      locations,
77    }
78  }
79
80  /// Get the ticket server config.
81  pub fn formatting_style(&self) -> FormattingStyle {
82    self.formatting_style
83  }
84
85  /// Get the ticket server config.
86  pub fn ticket_server(&self) -> &TicketServerConfig {
87    &self.ticket_server
88  }
89
90  /// Get the data server config.
91  pub fn data_server(&self) -> &DataServerEnabled {
92    &self.data_server
93  }
94
95  /// Get the service info config.
96  pub fn service_info(&self) -> &ServiceInfo {
97    &self.service_info
98  }
99
100  /// Get a mutable instance of the service info config.
101  pub fn service_info_mut(&mut self) -> &mut ServiceInfo {
102    &mut self.service_info
103  }
104
105  /// Get the location.
106  pub fn locations(&self) -> &[LocationEither] {
107    self.locations.as_slice()
108  }
109
110  pub fn into_locations(self) -> Locations {
111    self.locations
112  }
113
114  /// Parse the command line arguments. Returns the config path, or prints the default config.
115  /// Augment the `Command` args from the `clap` parser. Returns an error if the
116  pub fn parse_args_with_command(augment_args: Command) -> Result<Option<PathBuf>> {
117    let args = Args::from_arg_matches(&Args::augment_args(augment_args).get_matches())
118      .map_err(|err| ArgParseError(err.to_string()))?;
119
120    if args.config.as_ref().is_some_and(|path| !path.exists()) {
121      return Err(ParseError("config file not found".to_string()));
122    }
123
124    Ok(Self::parse_with_args(args))
125  }
126
127  /// Parse the command line arguments. Returns the config path, or prints the default config.
128  pub fn parse_args() -> Option<PathBuf> {
129    Self::parse_with_args(Args::parse())
130  }
131
132  fn parse_with_args(args: Args) -> Option<PathBuf> {
133    if args.print_default_config {
134      println!(
135        "{}",
136        toml::ser::to_string_pretty(&Config::default()).unwrap()
137      );
138      None
139    } else {
140      Some(args.config.unwrap_or_else(|| "".into()))
141    }
142  }
143
144  /// Read a config struct from a TOML file.
145  pub fn from_path(path: &Path) -> io::Result<Self> {
146    let config: Self = from_path(path)?;
147    Ok(config.resolvers_from_data_server_config()?)
148  }
149
150  /// Setup tracing, using a global subscriber.
151  pub fn setup_tracing(&self) -> Result<()> {
152    let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
153
154    let subscriber = Registry::default().with(env_filter);
155
156    match self.formatting_style() {
157      FormattingStyle::Full => set_global_default(subscriber.with(layer())),
158      FormattingStyle::Compact => {
159        set_global_default(subscriber.with(layer().event_format(format().compact())))
160      }
161      FormattingStyle::Pretty => {
162        set_global_default(subscriber.with(layer().event_format(format().pretty())))
163      }
164      FormattingStyle::Json => {
165        set_global_default(subscriber.with(layer().event_format(format().json())))
166      }
167    }
168    .map_err(|err| TracingError(err.to_string()))?;
169
170    Ok(())
171  }
172
173  /// Set the local resolvers from the data server config.
174  pub fn resolvers_from_data_server_config(mut self) -> Result<Self> {
175    self
176      .locations
177      .as_mut_slice()
178      .iter_mut()
179      .map(|location| {
180        if let LocationEither::Simple(simple) = location {
181          // Fall through only if the backend is File and default
182          let file_location = if let Ok(location) = simple.backend().as_file() {
183            location
184          } else {
185            return Ok(());
186          };
187
188          if let DataServerEnabled::Some(ref data_server) = self.data_server {
189            let prefix = simple.prefix().to_string();
190
191            // Don't update the local path as that comes in from the config.
192            let file: File = data_server.try_into()?;
193            let file = file.set_local_path(file_location.local_path().to_string());
194
195            *location = LocationEither::Simple(Location::new(Backend::File(file), prefix));
196          }
197        }
198
199        Ok(())
200      })
201      .collect::<Result<Vec<()>>>()?;
202
203    Ok(self)
204  }
205}
206
207impl Default for Config {
208  fn default() -> Self {
209    Self {
210      formatting_style: FormattingStyle::Full,
211      ticket_server: Default::default(),
212      data_server: DataServerEnabled::Some(Default::default()),
213      service_info: Default::default(),
214      locations: Default::default(),
215    }
216  }
217}
218
#[cfg(test)]
pub(crate) mod tests {
  use std::fmt::Display;

  use super::*;
  use crate::config::parser::from_str;
  use crate::tls::tests::with_test_certificates;
  use crate::types::Scheme;
  use figment::Jail;
  use http::uri::Authority;
  #[cfg(feature = "url")]
  use http::Uri;
  use serde::de::DeserializeOwned;
  use serde_json::json;

  /// Load a config inside a `figment` jail and run `test_fn` against it.
  ///
  /// If `contents` is given, it is written to `test.toml` inside the jail. Each entry of
  /// `env_variables` is set as an environment variable. The config is then loaded three ways:
  /// through `Config::from_path`, through the generic `from_path` parser, and through `from_str`
  /// on `contents`. `test_fn` is called on each result, so every loading path is exercised.
  fn test_config<K, V, F>(contents: Option<&str>, env_variables: Vec<(K, V)>, test_fn: F)
  where
    K: AsRef<str>,
    V: Display,
    F: Fn(Config),
  {
    Jail::expect_with(|jail| {
      let file = "test.toml";

      if let Some(contents) = contents {
        jail.create_file(file, contents)?;
      }

      for (key, value) in env_variables {
        jail.set_env(key, value);
      }

      let path = Path::new(file);
      test_fn(Config::from_path(path).map_err(|err| err.to_string())?);

      test_fn(
        from_path::<Config>(path)
          .map_err(|err| err.to_string())?
          .resolvers_from_data_server_config()
          .map_err(|err| err.to_string())?,
      );
      test_fn(
        from_str::<Config>(contents.unwrap_or(""))
          .map_err(|err| err.to_string())?
          .resolvers_from_data_server_config()
          .map_err(|err| err.to_string())?,
      );

      Ok(())
    });
  }

  /// Run `test_fn` against a config built only from environment variables. No config file is
  /// written.
  pub(crate) fn test_config_from_env<K, V, F>(env_variables: Vec<(K, V)>, test_fn: F)
  where
    K: AsRef<str>,
    V: Display,
    F: Fn(Config),
  {
    test_config(None, env_variables, test_fn);
  }

  /// Run `test_fn` against a config built from the TOML `contents`, with no extra environment
  /// variables.
  pub(crate) fn test_config_from_file<F>(contents: &str, test_fn: F)
  where
    F: Fn(Config),
  {
    test_config(Some(contents), Vec::<(&str, &str)>::new(), test_fn);
  }

  /// Check that deserializing `input` gives a value for which `get_result` returns `expected`.
  /// Then check that this still holds after a TOML serialize/deserialize round trip.
  pub(crate) fn test_serialize_and_deserialize<T, D, F>(input: &str, expected: T, get_result: F)
  where
    T: Debug + PartialEq,
    F: Fn(D) -> T,
    D: DeserializeOwned + Serialize + Clone,
  {
    let config: D = toml::from_str(input).unwrap();
    assert_eq!(expected, get_result(config.clone()));

    let serialized = toml::to_string(&config).unwrap();
    let deserialized = toml::from_str(&serialized).unwrap();
    assert_eq!(expected, get_result(deserialized));
  }

  #[test]
  fn config_ticket_server_addr_env() {
    test_config_from_env(
      vec![("HTSGET_TICKET_SERVER_ADDR", "127.0.0.1:8082")],
      |config| {
        assert_eq!(
          config.ticket_server().addr(),
          "127.0.0.1:8082".parse().unwrap()
        );
      },
    );
  }

  #[test]
  fn config_ticket_server_cors_allow_credentials_env() {
    test_config_from_env(
      vec![("HTSGET_TICKET_SERVER_CORS_ALLOW_CREDENTIALS", true)],
      |config| {
        assert!(config.ticket_server().cors().allow_credentials());
      },
    );
  }

  #[test]
  fn config_service_info_id_env() {
    test_config_from_env(vec![("HTSGET_SERVICE_INFO", "{ id=id }")], |config| {
      assert_eq!(config.service_info().as_ref().get("id"), Some(&json!("id")));
    });
  }

  #[test]
  fn config_data_server_addr_env() {
    test_config_from_env(
      vec![("HTSGET_DATA_SERVER_ADDR", "127.0.0.1:8082")],
      |config| {
        assert_eq!(
          config
            .data_server()
            .clone()
            .as_data_server_config()
            .unwrap()
            .addr(),
          "127.0.0.1:8082".parse().unwrap()
        );
      },
    );
  }

  #[test]
  fn config_ticket_server_addr_file() {
    test_config_from_file(r#"ticket_server.addr = "127.0.0.1:8082""#, |config| {
      assert_eq!(
        config.ticket_server().addr(),
        "127.0.0.1:8082".parse().unwrap()
      );
    });
  }

  #[test]
  fn config_ticket_server_cors_allow_credentials_file() {
    test_config_from_file(r#"ticket_server.cors.allow_credentials = true"#, |config| {
      assert!(config.ticket_server().cors().allow_credentials());
    });
  }

  #[test]
  fn config_service_info_id_file() {
    test_config_from_file(r#"service_info.id = "id""#, |config| {
      assert_eq!(config.service_info().as_ref().get("id"), Some(&json!("id")));
    });
  }

  #[test]
  fn config_data_server_addr_file() {
    test_config_from_file(r#"data_server.addr = "127.0.0.1:8082""#, |config| {
      assert_eq!(
        config
          .data_server()
          .clone()
          .as_data_server_config()
          .unwrap()
          .addr(),
        "127.0.0.1:8082".parse().unwrap()
      );
    });
  }

  #[test]
  #[should_panic]
  fn config_data_server_tls_no_cert() {
    with_test_certificates(|path, _, _| {
      let key_path = path.join("key.pem");

      test_config_from_file(
        &format!(
          r#"
        data_server.tls.key = "{}"
        "#,
          key_path.to_string_lossy().escape_default()
        ),
        |config| {
          assert!(config
            .data_server()
            .clone()
            .as_data_server_config()
            .unwrap()
            .tls()
            .is_none());
        },
      );
    });
  }

  #[test]
  fn config_data_server_tls() {
    with_test_certificates(|path, _, _| {
      let key_path = path.join("key.pem");
      let cert_path = path.join("cert.pem");

      test_config_from_file(
        &format!(
          r#"
          data_server.tls.key = "{}"
          data_server.tls.cert = "{}"
          "#,
          key_path.to_string_lossy().escape_default(),
          cert_path.to_string_lossy().escape_default()
        ),
        |config| {
          assert!(config
            .data_server()
            .clone()
            .as_data_server_config()
            .unwrap()
            .tls()
            .is_some());
        },
      );
    });
  }

  #[test]
  fn config_data_server_tls_env() {
    with_test_certificates(|path, _, _| {
      let key_path = path.join("key.pem");
      let cert_path = path.join("cert.pem");

      test_config_from_env(
        vec![
          ("HTSGET_DATA_SERVER_TLS_KEY", key_path.to_string_lossy()),
          ("HTSGET_DATA_SERVER_TLS_CERT", cert_path.to_string_lossy()),
        ],
        |config| {
          assert!(config
            .data_server()
            .clone()
            .as_data_server_config()
            .unwrap()
            .tls()
            .is_some());
        },
      );
    });
  }

  #[test]
  #[should_panic]
  fn config_ticket_server_tls_no_cert() {
    with_test_certificates(|path, _, _| {
      let key_path = path.join("key.pem");

      test_config_from_file(
        &format!(
          r#"
        ticket_server.tls.key = "{}"
        "#,
          key_path.to_string_lossy().escape_default()
        ),
        |config| {
          assert!(config.ticket_server().tls().is_none());
        },
      );
    });
  }

  #[test]
  fn config_ticket_server_tls() {
    with_test_certificates(|path, _, _| {
      let key_path = path.join("key.pem");
      let cert_path = path.join("cert.pem");

      test_config_from_file(
        &format!(
          r#"
        ticket_server.tls.key = "{}"
        ticket_server.tls.cert = "{}"
        "#,
          key_path.to_string_lossy().escape_default(),
          cert_path.to_string_lossy().escape_default()
        ),
        |config| {
          assert!(config.ticket_server().tls().is_some());
        },
      );
    });
  }

  #[test]
  fn config_ticket_server_tls_env() {
    with_test_certificates(|path, _, _| {
      let key_path = path.join("key.pem");
      let cert_path = path.join("cert.pem");

      test_config_from_env(
        vec![
          ("HTSGET_TICKET_SERVER_TLS_KEY", key_path.to_string_lossy()),
          ("HTSGET_TICKET_SERVER_TLS_CERT", cert_path.to_string_lossy()),
        ],
        |config| {
          assert!(config.ticket_server().tls().is_some());
        },
      );
    });
  }

  #[test]
  fn locations_from_data_server_config() {
    test_config_from_file(
      r#"
    data_server.addr = "127.0.0.1:8080"
    data_server.local_path = "path"

    [[locations]]
    regex = "123"
    backend.kind = "File"
    backend.local_path = "path"
    "#,
      |config| {
        assert_eq!(config.locations().len(), 1);
        let config = config.locations.into_inner();
        let regex = config[0].as_regex().unwrap();
        // Regex locations are not resolved against the data server, so this backend keeps its
        // own authority rather than the data server's `127.0.0.1:8080`.
        assert!(matches!(
          regex.backend(),
          Backend::File(file) if file.local_path() == "path"
            && file.scheme() == Scheme::Http
            && file.authority() == &Authority::from_static("127.0.0.1:8081")
        ));
      },
    );
  }

  #[test]
  fn simple_locations_env() {
    test_config_from_env(
      vec![
        ("HTSGET_DATA_SERVER_ADDR", "127.0.0.1:8080"),
        ("HTSGET_LOCATIONS", "[file://data/bam, file://data/cram]"),
      ],
      |config| {
        assert_multiple(config);
      },
    );
  }

  #[test]
  fn simple_locations() {
    test_config_from_file(
      r#"
    data_server.addr = "127.0.0.1:8080"
    data_server.local_path = "path"

    locations = "file://data"
    "#,
      |config| {
        assert_eq!(config.locations().len(), 1);
        let config = config.locations.into_inner();
        let location = config[0].as_simple().unwrap();
        assert_eq!(location.prefix(), "");
        assert_file_location(location, "data");
      },
    );
  }

  #[cfg(feature = "aws")]
  #[test]
  fn simple_locations_s3() {
    test_config_from_file(
      r#"
    locations = "s3://bucket"
    "#,
      |config| {
        assert_eq!(config.locations().len(), 1);
        let config = config.locations.into_inner();
        let location = config[0].as_simple().unwrap();
        assert_eq!(location.prefix(), "");
        assert!(matches!(
          location.backend(),
          Backend::S3(s3) if s3.bucket() == "bucket"
        ));
      },
    );
  }

  #[cfg(feature = "url")]
  #[test]
  fn simple_locations_url() {
    test_config_from_file(
      r#"
    locations = "https://example.com"
    "#,
      |config| {
        assert_eq!(config.locations().len(), 1);
        let config = config.locations.into_inner();
        let location = config[0].as_simple().unwrap();
        assert_eq!(location.prefix(), "");
        assert!(matches!(
          location.backend(),
          Backend::Url(url) if url.url() == &"https://example.com".parse::<Uri>().unwrap()
        ));
      },
    );
  }

  #[test]
  fn simple_locations_multiple() {
    test_config_from_file(
      r#"
    data_server.addr = "127.0.0.1:8080"
    locations = ["file://data/bam", "file://data/cram"]
    "#,
      |config| {
        assert_multiple(config);
      },
    );
  }

  #[cfg(feature = "aws")]
  #[test]
  fn simple_locations_multiple_mixed() {
    test_config_from_file(
      r#"
    data_server.addr = "127.0.0.1:8080"
    data_server.local_path = "root"
    locations = ["file://dir_one/bam", "file://dir_two/cram", "s3://bucket/vcf"]
    "#,
      |config| {
        assert_eq!(config.locations().len(), 3);
        let config = config.locations.into_inner();

        let location = config[0].as_simple().unwrap();
        assert_eq!(location.prefix(), "bam");
        assert_file_location(location, "dir_one");

        let location = config[1].as_simple().unwrap();
        assert_eq!(location.prefix(), "cram");
        assert_file_location(location, "dir_two");

        let location = config[2].as_simple().unwrap();
        assert_eq!(location.prefix(), "vcf");
        assert!(matches!(
          location.backend(),
          Backend::S3(s3) if s3.bucket() == "bucket"
        ));
      },
    );
  }

  #[test]
  fn no_data_server() {
    test_config_from_file(
      r#"
      data_server = "None"
    "#,
      |config| {
        assert!(config.data_server().as_data_server_config().is_err());
      },
    );
  }

  /// Assert that `config` holds exactly the two simple file locations `data/bam` and `data/cram`,
  /// both resolved against the data server at `127.0.0.1:8080`.
  fn assert_multiple(config: Config) {
    assert_eq!(config.locations().len(), 2);
    let config = config.locations.into_inner();

    let location = config[0].as_simple().unwrap();
    assert_eq!(location.prefix(), "bam");
    assert_file_location(location, "data");

    let location = config[1].as_simple().unwrap();
    assert_eq!(location.prefix(), "cram");
    assert_file_location(location, "data");
  }

  /// Assert that `location` has an HTTP `File` backend with the given local path, served from
  /// `127.0.0.1:8080`.
  fn assert_file_location(location: &Location, local_path: &str) {
    assert!(matches!(
      location.backend(),
      Backend::File(file) if file.local_path() == local_path
        && file.scheme() == Scheme::Http
        && file.authority() == &Authority::from_static("127.0.0.1:8080")
    ));
  }
}